feat:新增napcat断线后重连重新拉取历史消息的机制
This commit is contained in:
@@ -8,3 +8,6 @@ DEFAULT_RECONNECT_DELAY_SEC = 5.0
|
|||||||
DEFAULT_HEARTBEAT_INTERVAL_SEC = 30.0
|
DEFAULT_HEARTBEAT_INTERVAL_SEC = 30.0
|
||||||
DEFAULT_ACTION_TIMEOUT_SEC = 15.0
|
DEFAULT_ACTION_TIMEOUT_SEC = 15.0
|
||||||
DEFAULT_CHAT_LIST_TYPE = "whitelist"
|
DEFAULT_CHAT_LIST_TYPE = "whitelist"
|
||||||
|
DEFAULT_HISTORY_RECOVERY_BATCH_SIZE = 20
|
||||||
|
DEFAULT_HISTORY_RECOVERY_CHECKPOINT_LIMIT = 50
|
||||||
|
DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC = 86400.0 * 7
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from ..services import (
|
|||||||
NapCatActionService,
|
NapCatActionService,
|
||||||
NapCatBanStateStore,
|
NapCatBanStateStore,
|
||||||
NapCatBanTracker,
|
NapCatBanTracker,
|
||||||
|
NapCatHistoryRecoveryStore,
|
||||||
NapCatOfficialBotGuard,
|
NapCatOfficialBotGuard,
|
||||||
NapCatQueryService,
|
NapCatQueryService,
|
||||||
)
|
)
|
||||||
@@ -66,6 +67,7 @@ class NapCatRuntimeBuilder:
|
|||||||
action_service = NapCatActionService(self._logger, transport)
|
action_service = NapCatActionService(self._logger, transport)
|
||||||
query_service = NapCatQueryService(action_service, self._logger)
|
query_service = NapCatQueryService(action_service, self._logger)
|
||||||
ban_state_store = NapCatBanStateStore(self._logger)
|
ban_state_store = NapCatBanStateStore(self._logger)
|
||||||
|
history_recovery_store = NapCatHistoryRecoveryStore(self._logger)
|
||||||
inbound_codec = NapCatInboundCodec(self._logger, query_service)
|
inbound_codec = NapCatInboundCodec(self._logger, query_service)
|
||||||
notice_codec = NapCatNoticeCodec(self._logger, query_service)
|
notice_codec = NapCatNoticeCodec(self._logger, query_service)
|
||||||
runtime_state = NapCatRuntimeStateManager(
|
runtime_state = NapCatRuntimeStateManager(
|
||||||
@@ -92,6 +94,7 @@ class NapCatRuntimeBuilder:
|
|||||||
ban_tracker=ban_tracker,
|
ban_tracker=ban_tracker,
|
||||||
chat_filter=chat_filter,
|
chat_filter=chat_filter,
|
||||||
heartbeat_monitor=heartbeat_monitor,
|
heartbeat_monitor=heartbeat_monitor,
|
||||||
|
history_recovery_store=history_recovery_store,
|
||||||
inbound_codec=inbound_codec,
|
inbound_codec=inbound_codec,
|
||||||
notice_codec=notice_codec,
|
notice_codec=notice_codec,
|
||||||
official_bot_guard=official_bot_guard,
|
official_bot_guard=official_bot_guard,
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from ..services import (
|
|||||||
NapCatActionService,
|
NapCatActionService,
|
||||||
NapCatBanStateStore,
|
NapCatBanStateStore,
|
||||||
NapCatBanTracker,
|
NapCatBanTracker,
|
||||||
|
NapCatHistoryRecoveryStore,
|
||||||
NapCatOfficialBotGuard,
|
NapCatOfficialBotGuard,
|
||||||
NapCatQueryService,
|
NapCatQueryService,
|
||||||
)
|
)
|
||||||
@@ -29,6 +30,7 @@ class NapCatRuntimeBundle:
|
|||||||
ban_tracker: NapCatBanTracker
|
ban_tracker: NapCatBanTracker
|
||||||
chat_filter: NapCatChatFilter
|
chat_filter: NapCatChatFilter
|
||||||
heartbeat_monitor: NapCatHeartbeatMonitor
|
heartbeat_monitor: NapCatHeartbeatMonitor
|
||||||
|
history_recovery_store: NapCatHistoryRecoveryStore
|
||||||
inbound_codec: NapCatInboundCodec
|
inbound_codec: NapCatInboundCodec
|
||||||
notice_codec: NapCatNoticeCodec
|
notice_codec: NapCatNoticeCodec
|
||||||
official_bot_guard: NapCatOfficialBotGuard
|
official_bot_guard: NapCatOfficialBotGuard
|
||||||
|
|||||||
@@ -2,11 +2,14 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, Mapping, Optional, Protocol
|
from typing import Any, Callable, Dict, Mapping, Optional, Protocol
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from ..config import NapCatPluginSettings
|
from ..config import NapCatPluginSettings
|
||||||
|
from ..constants import DEFAULT_HISTORY_RECOVERY_BATCH_SIZE, DEFAULT_HISTORY_RECOVERY_CHECKPOINT_LIMIT
|
||||||
|
from ..services import NapCatChatCheckpoint
|
||||||
from ..types import NapCatPayloadDict
|
from ..types import NapCatPayloadDict
|
||||||
from .bundle import NapCatRuntimeBundle
|
from .bundle import NapCatRuntimeBundle
|
||||||
|
|
||||||
@@ -27,6 +30,14 @@ class _GatewayCapabilityProtocol(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _NapCatChatIdentity:
|
||||||
|
"""描述一条 NapCat 消息所属的会话身份。"""
|
||||||
|
|
||||||
|
chat_type: str
|
||||||
|
chat_id: str
|
||||||
|
|
||||||
|
|
||||||
class NapCatEventRouter:
|
class NapCatEventRouter:
|
||||||
"""协调 NapCat 运行时组件处理各类平台事件。"""
|
"""协调 NapCat 运行时组件处理各类平台事件。"""
|
||||||
|
|
||||||
@@ -50,6 +61,7 @@ class NapCatEventRouter:
|
|||||||
self._gateway_name = gateway_name
|
self._gateway_name = gateway_name
|
||||||
self._load_settings = load_settings
|
self._load_settings = load_settings
|
||||||
self._runtime: Optional[NapCatRuntimeBundle] = None
|
self._runtime: Optional[NapCatRuntimeBundle] = None
|
||||||
|
self._recovery_task: Optional[asyncio.Task[None]] = None
|
||||||
|
|
||||||
def bind_runtime(self, runtime: NapCatRuntimeBundle) -> None:
|
def bind_runtime(self, runtime: NapCatRuntimeBundle) -> None:
|
||||||
"""绑定当前路由器使用的运行时依赖。
|
"""绑定当前路由器使用的运行时依赖。
|
||||||
@@ -64,6 +76,7 @@ class NapCatEventRouter:
|
|||||||
runtime = self._runtime
|
runtime = self._runtime
|
||||||
if runtime is None:
|
if runtime is None:
|
||||||
return
|
return
|
||||||
|
self._cancel_recovery_task()
|
||||||
runtime.official_bot_guard.clear_cache()
|
runtime.official_bot_guard.clear_cache()
|
||||||
|
|
||||||
async def handle_transport_payload(self, payload: NapCatPayloadDict) -> None:
|
async def handle_transport_payload(self, payload: NapCatPayloadDict) -> None:
|
||||||
@@ -82,7 +95,7 @@ class NapCatEventRouter:
|
|||||||
if post_type == "meta_event":
|
if post_type == "meta_event":
|
||||||
await self.handle_meta_event(payload)
|
await self.handle_meta_event(payload)
|
||||||
|
|
||||||
async def handle_inbound_message(self, payload: NapCatPayloadDict) -> None:
|
async def handle_inbound_message(self, payload: NapCatPayloadDict) -> bool:
|
||||||
"""处理单条 NapCat 入站消息并注入 Host。
|
"""处理单条 NapCat 入站消息并注入 Host。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -101,25 +114,25 @@ class NapCatEventRouter:
|
|||||||
|
|
||||||
sender_user_id = str(payload.get("user_id") or sender.get("user_id") or "").strip()
|
sender_user_id = str(payload.get("user_id") or sender.get("user_id") or "").strip()
|
||||||
if not sender_user_id:
|
if not sender_user_id:
|
||||||
return
|
return False
|
||||||
|
|
||||||
group_id = str(payload.get("group_id") or "").strip()
|
group_id = str(payload.get("group_id") or "").strip()
|
||||||
if self_id and sender_user_id == self_id and settings.filters.ignore_self_message:
|
if self_id and sender_user_id == self_id and settings.filters.ignore_self_message:
|
||||||
return
|
return False
|
||||||
if not runtime.chat_filter.is_inbound_chat_allowed(sender_user_id, group_id, settings.chat):
|
if not runtime.chat_filter.is_inbound_chat_allowed(sender_user_id, group_id, settings.chat):
|
||||||
return
|
return False
|
||||||
if await runtime.official_bot_guard.should_reject(
|
if await runtime.official_bot_guard.should_reject(
|
||||||
sender_user_id=sender_user_id,
|
sender_user_id=sender_user_id,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
ban_qq_bot=settings.chat.ban_qq_bot,
|
ban_qq_bot=settings.chat.ban_qq_bot,
|
||||||
):
|
):
|
||||||
return
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
message_dict = await runtime.inbound_codec.build_message_dict(payload, self_id, sender_user_id, sender)
|
message_dict = await runtime.inbound_codec.build_message_dict(payload, self_id, sender_user_id, sender)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
self._logger.warning(f"NapCat 入站消息格式不受支持,已丢弃: {exc}")
|
self._logger.warning(f"NapCat 入站消息格式不受支持,已丢弃: {exc}")
|
||||||
return
|
return False
|
||||||
|
|
||||||
route_metadata = self._build_route_metadata(self_id, settings.napcat_server.connection_id)
|
route_metadata = self._build_route_metadata(self_id, settings.napcat_server.connection_id)
|
||||||
external_message_id = str(payload.get("message_id") or "").strip()
|
external_message_id = str(payload.get("message_id") or "").strip()
|
||||||
@@ -132,6 +145,15 @@ class NapCatEventRouter:
|
|||||||
)
|
)
|
||||||
if not accepted:
|
if not accepted:
|
||||||
self._logger.debug(f"Host 丢弃了 NapCat 入站消息: {external_message_id or '无消息 ID'}")
|
self._logger.debug(f"Host 丢弃了 NapCat 入站消息: {external_message_id or '无消息 ID'}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
await self._record_inbound_checkpoint(
|
||||||
|
payload=payload,
|
||||||
|
self_id=self_id,
|
||||||
|
external_message_id=external_message_id or str(message_dict.get("message_id") or "").strip(),
|
||||||
|
scope=settings.napcat_server.connection_id,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
async def handle_notice_event(self, payload: NapCatPayloadDict) -> None:
|
async def handle_notice_event(self, payload: NapCatPayloadDict) -> None:
|
||||||
"""处理 NapCat ``notice`` 事件并注入 Host。
|
"""处理 NapCat ``notice`` 事件并注入 Host。
|
||||||
@@ -232,6 +254,8 @@ class NapCatEventRouter:
|
|||||||
await runtime.runtime_state.report_connected(self_id, settings.napcat_server)
|
await runtime.runtime_state.report_connected(self_id, settings.napcat_server)
|
||||||
await runtime.heartbeat_monitor.start(self_id, settings.napcat_server.heartbeat_interval)
|
await runtime.heartbeat_monitor.start(self_id, settings.napcat_server.heartbeat_interval)
|
||||||
await runtime.ban_tracker.start()
|
await runtime.ban_tracker.start()
|
||||||
|
await runtime.history_recovery_store.load()
|
||||||
|
self._schedule_history_recovery(self_id=self_id, scope=settings.napcat_server.connection_id)
|
||||||
return
|
return
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
@@ -279,6 +303,274 @@ class NapCatEventRouter:
|
|||||||
raise RuntimeError("NapCat 运行时尚未初始化")
|
raise RuntimeError("NapCat 运行时尚未初始化")
|
||||||
return runtime
|
return runtime
|
||||||
|
|
||||||
|
def _schedule_history_recovery(self, self_id: str, scope: str) -> None:
|
||||||
|
"""在连接恢复后调度一次历史补拉任务。"""
|
||||||
|
|
||||||
|
self._cancel_recovery_task()
|
||||||
|
runtime = self._runtime
|
||||||
|
if runtime is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._recovery_task = asyncio.create_task(
|
||||||
|
self._recover_recent_history(self_id=self_id, scope=scope),
|
||||||
|
name="napcat_adapter.history_recovery",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _cancel_recovery_task(self) -> None:
|
||||||
|
"""取消当前仍在运行的历史补拉任务。"""
|
||||||
|
|
||||||
|
recovery_task = self._recovery_task
|
||||||
|
self._recovery_task = None
|
||||||
|
if recovery_task is not None and not recovery_task.done():
|
||||||
|
recovery_task.cancel()
|
||||||
|
|
||||||
|
async def _recover_recent_history(self, *, self_id: str, scope: str) -> None:
|
||||||
|
"""按 checkpoint 列表逐个尝试补拉断线期间遗漏的消息。"""
|
||||||
|
|
||||||
|
runtime = self._require_runtime()
|
||||||
|
checkpoints = await runtime.history_recovery_store.list_checkpoints(
|
||||||
|
self_id,
|
||||||
|
scope=scope,
|
||||||
|
limit=DEFAULT_HISTORY_RECOVERY_CHECKPOINT_LIMIT,
|
||||||
|
)
|
||||||
|
if not checkpoints:
|
||||||
|
return
|
||||||
|
|
||||||
|
recovered_count = 0
|
||||||
|
for checkpoint in checkpoints:
|
||||||
|
recovered_count += await self._recover_chat_history_from_checkpoint(
|
||||||
|
self_id=self_id,
|
||||||
|
scope=scope,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
if recovered_count > 0:
|
||||||
|
self._logger.info(f"NapCat 历史补拉完成,共补回 {recovered_count} 条消息")
|
||||||
|
|
||||||
|
async def _recover_chat_history_from_checkpoint(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
self_id: str,
|
||||||
|
scope: str,
|
||||||
|
checkpoint: NapCatChatCheckpoint,
|
||||||
|
) -> int:
|
||||||
|
"""针对单个会话执行一次小批量历史补拉。"""
|
||||||
|
|
||||||
|
runtime = self._require_runtime()
|
||||||
|
history_messages = await self._query_history_messages(checkpoint, limit=DEFAULT_HISTORY_RECOVERY_BATCH_SIZE)
|
||||||
|
if not history_messages:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
ordered_messages = sorted(
|
||||||
|
history_messages,
|
||||||
|
key=lambda item: (
|
||||||
|
self._extract_message_timestamp(item),
|
||||||
|
self._extract_message_seq(item),
|
||||||
|
str(item.get("message_id") or "").strip(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
recovered_count = 0
|
||||||
|
for history_payload in ordered_messages:
|
||||||
|
external_message_id = str(history_payload.get("message_id") or "").strip()
|
||||||
|
if not external_message_id:
|
||||||
|
continue
|
||||||
|
if external_message_id == checkpoint.last_message_id:
|
||||||
|
continue
|
||||||
|
if await runtime.history_recovery_store.has_recovered_message_seen(
|
||||||
|
account_id=self_id,
|
||||||
|
scope=scope,
|
||||||
|
chat_type=checkpoint.chat_type,
|
||||||
|
chat_id=checkpoint.chat_id,
|
||||||
|
external_message_id=external_message_id,
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
if not self._is_message_after_checkpoint(history_payload, checkpoint):
|
||||||
|
continue
|
||||||
|
accepted = await self._reinject_history_payload(history_payload, self_id=self_id)
|
||||||
|
if not accepted:
|
||||||
|
continue
|
||||||
|
await runtime.history_recovery_store.mark_recovered_message_seen(
|
||||||
|
account_id=self_id,
|
||||||
|
scope=scope,
|
||||||
|
chat_type=checkpoint.chat_type,
|
||||||
|
chat_id=checkpoint.chat_id,
|
||||||
|
external_message_id=external_message_id,
|
||||||
|
)
|
||||||
|
recovered_count += 1
|
||||||
|
|
||||||
|
return recovered_count
|
||||||
|
|
||||||
|
async def _query_history_messages(
|
||||||
|
self,
|
||||||
|
checkpoint: NapCatChatCheckpoint,
|
||||||
|
*,
|
||||||
|
limit: int,
|
||||||
|
) -> list[NapCatPayloadDict]:
|
||||||
|
"""查询某个会话在 checkpoint 之后的一小批历史消息。"""
|
||||||
|
|
||||||
|
runtime = self._require_runtime()
|
||||||
|
payload_collections: list[list[NapCatPayloadDict]] = []
|
||||||
|
if checkpoint.last_message_seq is not None:
|
||||||
|
payload_collections.append(
|
||||||
|
await self._fetch_history_messages(
|
||||||
|
chat_type=checkpoint.chat_type,
|
||||||
|
chat_id=checkpoint.chat_id,
|
||||||
|
message_seq=checkpoint.last_message_seq,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
payload_collections.append(
|
||||||
|
await self._fetch_history_messages(
|
||||||
|
chat_type=checkpoint.chat_type,
|
||||||
|
chat_id=checkpoint.chat_id,
|
||||||
|
message_seq=None,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
merged_payloads: list[NapCatPayloadDict] = []
|
||||||
|
seen_message_ids: set[str] = set()
|
||||||
|
for payloads in payload_collections:
|
||||||
|
for payload in payloads:
|
||||||
|
external_message_id = str(payload.get("message_id") or "").strip()
|
||||||
|
dedupe_key = external_message_id or repr(sorted(payload.items()))
|
||||||
|
if dedupe_key in seen_message_ids:
|
||||||
|
continue
|
||||||
|
seen_message_ids.add(dedupe_key)
|
||||||
|
merged_payloads.append(payload)
|
||||||
|
return merged_payloads
|
||||||
|
|
||||||
|
async def _fetch_history_messages(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
chat_type: str,
|
||||||
|
chat_id: str,
|
||||||
|
message_seq: int | None,
|
||||||
|
limit: int,
|
||||||
|
) -> list[NapCatPayloadDict]:
|
||||||
|
"""调用查询服务获取一批历史消息。"""
|
||||||
|
|
||||||
|
runtime = self._require_runtime()
|
||||||
|
if chat_type == "group":
|
||||||
|
history_payloads = await runtime.query_service.get_group_message_history(
|
||||||
|
chat_id,
|
||||||
|
message_seq=message_seq,
|
||||||
|
count=limit,
|
||||||
|
reverse_order=False,
|
||||||
|
)
|
||||||
|
elif chat_type == "private":
|
||||||
|
history_payloads = await runtime.query_service.get_friend_message_history(
|
||||||
|
chat_id,
|
||||||
|
message_seq=message_seq,
|
||||||
|
count=limit,
|
||||||
|
reverse_order=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if history_payloads is None:
|
||||||
|
return []
|
||||||
|
return [dict(payload) for payload in history_payloads if isinstance(payload, Mapping)]
|
||||||
|
|
||||||
|
async def _reinject_history_payload(self, payload: NapCatPayloadDict, *, self_id: str) -> bool:
|
||||||
|
"""将补拉到的历史消息重新送回实时入站路径。"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
normalized_payload = dict(payload)
|
||||||
|
if self_id and not str(normalized_payload.get("self_id") or "").strip():
|
||||||
|
normalized_payload["self_id"] = self_id
|
||||||
|
return await self.handle_inbound_message(normalized_payload)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
external_message_id = str(payload.get("message_id") or "").strip() or "unknown"
|
||||||
|
self._logger.warning(f"NapCat 历史消息补拉注入失败: message_id={external_message_id} error={exc}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _record_inbound_checkpoint(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
payload: NapCatPayloadDict,
|
||||||
|
self_id: str,
|
||||||
|
external_message_id: str,
|
||||||
|
scope: str,
|
||||||
|
) -> None:
|
||||||
|
"""在消息被 Host 接受后更新该会话的最新 checkpoint。"""
|
||||||
|
|
||||||
|
runtime = self._require_runtime()
|
||||||
|
chat_identity = self._extract_chat_identity(payload)
|
||||||
|
if chat_identity is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
await runtime.history_recovery_store.record_checkpoint(
|
||||||
|
account_id=self_id,
|
||||||
|
scope=scope,
|
||||||
|
chat_type=chat_identity.chat_type,
|
||||||
|
chat_id=chat_identity.chat_id,
|
||||||
|
message_id=external_message_id,
|
||||||
|
message_time=self._extract_message_timestamp(payload),
|
||||||
|
message_seq=self._extract_message_seq(payload),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_chat_identity(payload: Mapping[str, Any]) -> _NapCatChatIdentity | None:
|
||||||
|
"""从 NapCat 载荷中提取会话身份。"""
|
||||||
|
|
||||||
|
group_id = str(payload.get("group_id") or "").strip()
|
||||||
|
user_id = str(payload.get("user_id") or "").strip()
|
||||||
|
|
||||||
|
if group_id:
|
||||||
|
return _NapCatChatIdentity(chat_type="group", chat_id=group_id)
|
||||||
|
if user_id:
|
||||||
|
return _NapCatChatIdentity(chat_type="private", chat_id=user_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_message_seq(payload: Mapping[str, Any]) -> int | None:
|
||||||
|
"""从 NapCat 载荷中提取历史接口可复用的消息序号。"""
|
||||||
|
|
||||||
|
for field_name in ("message_seq", "messageSeq", "msg_seq"):
|
||||||
|
raw_value = payload.get(field_name)
|
||||||
|
if raw_value is None or str(raw_value).strip() == "":
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
return int(raw_value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_message_timestamp(payload: Mapping[str, Any]) -> float:
|
||||||
|
"""从 NapCat 载荷中提取消息时间戳。"""
|
||||||
|
|
||||||
|
raw_timestamp = payload.get("time")
|
||||||
|
if isinstance(raw_timestamp, (int, float)):
|
||||||
|
return float(raw_timestamp)
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _is_message_after_checkpoint(
|
||||||
|
cls,
|
||||||
|
payload: Mapping[str, Any],
|
||||||
|
checkpoint: NapCatChatCheckpoint,
|
||||||
|
) -> bool:
|
||||||
|
"""判断历史消息是否位于 checkpoint 之后。"""
|
||||||
|
|
||||||
|
payload_message_id = str(payload.get("message_id") or "").strip()
|
||||||
|
if payload_message_id == checkpoint.last_message_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
payload_message_seq = cls._extract_message_seq(payload)
|
||||||
|
if payload_message_seq is not None and checkpoint.last_message_seq is not None:
|
||||||
|
return payload_message_seq > checkpoint.last_message_seq
|
||||||
|
|
||||||
|
payload_timestamp = cls._extract_message_timestamp(payload)
|
||||||
|
if payload_timestamp != checkpoint.last_message_time:
|
||||||
|
return payload_timestamp > checkpoint.last_message_time
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_route_metadata(self_id: str, connection_id: str) -> Dict[str, Any]:
|
def _build_route_metadata(self_id: str, connection_id: str) -> Dict[str, Any]:
|
||||||
"""构造注入 Host 时使用的路由元数据。
|
"""构造注入 Host 时使用的路由元数据。
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from .action_service import NapCatActionService
|
from .action_service import NapCatActionService
|
||||||
from .ban_tracker import NapCatBanTracker
|
from .ban_tracker import NapCatBanTracker
|
||||||
from .ban_state_store import NapCatBanRecord, NapCatBanStateStore
|
from .ban_state_store import NapCatBanRecord, NapCatBanStateStore
|
||||||
|
from .history_recovery_store import NapCatChatCheckpoint, NapCatHistoryRecoveryStore
|
||||||
from .official_bot_guard import NapCatOfficialBotGuard
|
from .official_bot_guard import NapCatOfficialBotGuard
|
||||||
from .query_service import NapCatQueryService
|
from .query_service import NapCatQueryService
|
||||||
|
|
||||||
@@ -11,6 +12,8 @@ __all__ = [
|
|||||||
"NapCatBanRecord",
|
"NapCatBanRecord",
|
||||||
"NapCatBanStateStore",
|
"NapCatBanStateStore",
|
||||||
"NapCatBanTracker",
|
"NapCatBanTracker",
|
||||||
|
"NapCatChatCheckpoint",
|
||||||
|
"NapCatHistoryRecoveryStore",
|
||||||
"NapCatOfficialBotGuard",
|
"NapCatOfficialBotGuard",
|
||||||
"NapCatQueryService",
|
"NapCatQueryService",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,423 @@
|
|||||||
|
"""NapCat 历史补拉状态持久化仓库。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, List, Optional, TypeVar
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
|
||||||
|
from ..constants import DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC
|
||||||
|
|
||||||
|
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||||
|
_DEFAULT_STORAGE_PATH = _PROJECT_ROOT / "data" / "napcat_adapter" / "history_recovery.sqlite3"
|
||||||
|
|
||||||
|
_SCHEMA_STATEMENTS = (
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS napcat_chat_checkpoint (
|
||||||
|
account_id TEXT NOT NULL,
|
||||||
|
scope TEXT NOT NULL,
|
||||||
|
chat_type TEXT NOT NULL,
|
||||||
|
chat_id TEXT NOT NULL,
|
||||||
|
last_message_id TEXT NOT NULL,
|
||||||
|
last_message_time REAL NOT NULL,
|
||||||
|
last_message_seq INTEGER,
|
||||||
|
updated_at REAL NOT NULL,
|
||||||
|
PRIMARY KEY (account_id, scope, chat_type, chat_id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE INDEX IF NOT EXISTS ix_napcat_chat_checkpoint_updated_at
|
||||||
|
ON napcat_chat_checkpoint (updated_at DESC)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS napcat_recovery_seen (
|
||||||
|
account_id TEXT NOT NULL,
|
||||||
|
scope TEXT NOT NULL,
|
||||||
|
chat_type TEXT NOT NULL,
|
||||||
|
chat_id TEXT NOT NULL,
|
||||||
|
external_message_id TEXT NOT NULL,
|
||||||
|
seen_at REAL NOT NULL,
|
||||||
|
PRIMARY KEY (account_id, scope, chat_type, chat_id, external_message_id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE INDEX IF NOT EXISTS ix_napcat_recovery_seen_seen_at
|
||||||
|
ON napcat_recovery_seen (seen_at DESC)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class NapCatChatCheckpoint:
|
||||||
|
"""描述一个会话的最近入站锚点。"""
|
||||||
|
|
||||||
|
account_id: str
|
||||||
|
scope: str
|
||||||
|
chat_type: str
|
||||||
|
chat_id: str
|
||||||
|
last_message_id: str
|
||||||
|
last_message_time: float
|
||||||
|
last_message_seq: int | None
|
||||||
|
updated_at: float
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_row(cls, row: sqlite3.Row) -> "NapCatChatCheckpoint":
|
||||||
|
"""从 SQLite 行对象恢复 checkpoint。"""
|
||||||
|
|
||||||
|
last_message_seq = row["last_message_seq"]
|
||||||
|
normalized_seq = int(last_message_seq) if isinstance(last_message_seq, int) else None
|
||||||
|
return cls(
|
||||||
|
account_id=str(row["account_id"] or "").strip(),
|
||||||
|
scope=str(row["scope"] or "").strip(),
|
||||||
|
chat_type=str(row["chat_type"] or "").strip(),
|
||||||
|
chat_id=str(row["chat_id"] or "").strip(),
|
||||||
|
last_message_id=str(row["last_message_id"] or "").strip(),
|
||||||
|
last_message_time=float(row["last_message_time"] or 0.0),
|
||||||
|
last_message_seq=normalized_seq,
|
||||||
|
updated_at=float(row["updated_at"] or 0.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NapCatHistoryRecoveryStore:
|
||||||
|
"""负责持久化历史补拉所需的会话状态与去重状态。"""
|
||||||
|
|
||||||
|
def __init__(self, logger: Any, storage_path: Path = _DEFAULT_STORAGE_PATH) -> None:
|
||||||
|
"""初始化历史补拉状态仓库。"""
|
||||||
|
|
||||||
|
self._logger = logger
|
||||||
|
self._storage_path = storage_path
|
||||||
|
self._store_lock = asyncio.Lock()
|
||||||
|
self._schema_ready = False
|
||||||
|
|
||||||
|
async def load(self) -> None:
|
||||||
|
"""初始化 SQLite 文件并清理过期去重记录。"""
|
||||||
|
|
||||||
|
await self._execute_locked(self._ensure_schema)
|
||||||
|
pruned_count = await self.prune_recovery_seen(DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC)
|
||||||
|
if pruned_count > 0:
|
||||||
|
self._logger.debug(f"NapCat 历史补拉去重表已清理 {pruned_count} 条过期记录")
|
||||||
|
|
||||||
|
async def list_checkpoints(self, account_id: str, scope: str = "", limit: int = 50) -> List[NapCatChatCheckpoint]:
|
||||||
|
"""列出指定账号与作用域下的最近会话 checkpoint。"""
|
||||||
|
|
||||||
|
normalized_account_id = str(account_id or "").strip()
|
||||||
|
if not normalized_account_id:
|
||||||
|
return []
|
||||||
|
|
||||||
|
normalized_scope = self._normalize_scope(scope)
|
||||||
|
normalized_limit = max(1, int(limit))
|
||||||
|
|
||||||
|
def _operation(conn: sqlite3.Connection) -> List[NapCatChatCheckpoint]:
|
||||||
|
cursor = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT
|
||||||
|
account_id,
|
||||||
|
scope,
|
||||||
|
chat_type,
|
||||||
|
chat_id,
|
||||||
|
last_message_id,
|
||||||
|
last_message_time,
|
||||||
|
last_message_seq,
|
||||||
|
updated_at
|
||||||
|
FROM napcat_chat_checkpoint
|
||||||
|
WHERE account_id = ? AND scope = ?
|
||||||
|
ORDER BY updated_at DESC
|
||||||
|
LIMIT ?
|
||||||
|
""",
|
||||||
|
(normalized_account_id, normalized_scope, normalized_limit),
|
||||||
|
)
|
||||||
|
return [NapCatChatCheckpoint.from_row(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
return await self._execute_locked(_operation)
|
||||||
|
|
||||||
|
async def record_checkpoint(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
account_id: str,
|
||||||
|
scope: str = "",
|
||||||
|
chat_type: str,
|
||||||
|
chat_id: str,
|
||||||
|
message_id: str,
|
||||||
|
message_time: float,
|
||||||
|
message_seq: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""记录一条已被 Host 接受的最新入站消息锚点。"""
|
||||||
|
|
||||||
|
normalized_account_id = str(account_id or "").strip()
|
||||||
|
normalized_scope = self._normalize_scope(scope)
|
||||||
|
normalized_chat_type = str(chat_type or "").strip()
|
||||||
|
normalized_chat_id = str(chat_id or "").strip()
|
||||||
|
normalized_message_id = str(message_id or "").strip()
|
||||||
|
|
||||||
|
if not (
|
||||||
|
normalized_account_id
|
||||||
|
and normalized_chat_type
|
||||||
|
and normalized_chat_id
|
||||||
|
and normalized_message_id
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
normalized_message_time = float(message_time or 0.0)
|
||||||
|
normalized_message_seq = self._normalize_message_seq(message_seq)
|
||||||
|
updated_at = time.time()
|
||||||
|
|
||||||
|
def _operation(conn: sqlite3.Connection) -> None:
|
||||||
|
cursor = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT last_message_id, last_message_time, last_message_seq
|
||||||
|
FROM napcat_chat_checkpoint
|
||||||
|
WHERE account_id = ? AND scope = ? AND chat_type = ? AND chat_id = ?
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
normalized_account_id,
|
||||||
|
normalized_scope,
|
||||||
|
normalized_chat_type,
|
||||||
|
normalized_chat_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
existing_row = cursor.fetchone()
|
||||||
|
if existing_row is not None and not self._should_advance_checkpoint(
|
||||||
|
existing_row=existing_row,
|
||||||
|
message_id=normalized_message_id,
|
||||||
|
message_time=normalized_message_time,
|
||||||
|
message_seq=normalized_message_seq,
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO napcat_chat_checkpoint (
|
||||||
|
account_id,
|
||||||
|
scope,
|
||||||
|
chat_type,
|
||||||
|
chat_id,
|
||||||
|
last_message_id,
|
||||||
|
last_message_time,
|
||||||
|
last_message_seq,
|
||||||
|
updated_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(account_id, scope, chat_type, chat_id) DO UPDATE SET
|
||||||
|
last_message_id = excluded.last_message_id,
|
||||||
|
last_message_time = excluded.last_message_time,
|
||||||
|
last_message_seq = excluded.last_message_seq,
|
||||||
|
updated_at = excluded.updated_at
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
normalized_account_id,
|
||||||
|
normalized_scope,
|
||||||
|
normalized_chat_type,
|
||||||
|
normalized_chat_id,
|
||||||
|
normalized_message_id,
|
||||||
|
normalized_message_time,
|
||||||
|
normalized_message_seq,
|
||||||
|
updated_at,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._execute_locked(_operation)
|
||||||
|
|
||||||
|
async def has_recovered_message_seen(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
account_id: str,
|
||||||
|
scope: str = "",
|
||||||
|
chat_type: str,
|
||||||
|
chat_id: str,
|
||||||
|
external_message_id: str,
|
||||||
|
) -> bool:
|
||||||
|
"""判断某条历史补拉消息是否已经被当前仓库记录过。"""
|
||||||
|
|
||||||
|
normalized_account_id = str(account_id or "").strip()
|
||||||
|
normalized_scope = self._normalize_scope(scope)
|
||||||
|
normalized_chat_type = str(chat_type or "").strip()
|
||||||
|
normalized_chat_id = str(chat_id or "").strip()
|
||||||
|
normalized_external_message_id = str(external_message_id or "").strip()
|
||||||
|
|
||||||
|
if not (
|
||||||
|
normalized_account_id
|
||||||
|
and normalized_chat_type
|
||||||
|
and normalized_chat_id
|
||||||
|
and normalized_external_message_id
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _operation(conn: sqlite3.Connection) -> bool:
|
||||||
|
cursor = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT 1
|
||||||
|
FROM napcat_recovery_seen
|
||||||
|
WHERE account_id = ?
|
||||||
|
AND scope = ?
|
||||||
|
AND chat_type = ?
|
||||||
|
AND chat_id = ?
|
||||||
|
AND external_message_id = ?
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
normalized_account_id,
|
||||||
|
normalized_scope,
|
||||||
|
normalized_chat_type,
|
||||||
|
normalized_chat_id,
|
||||||
|
normalized_external_message_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return cursor.fetchone() is not None
|
||||||
|
|
||||||
|
return await self._execute_locked(_operation)
|
||||||
|
|
||||||
|
async def mark_recovered_message_seen(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
account_id: str,
|
||||||
|
scope: str = "",
|
||||||
|
chat_type: str,
|
||||||
|
chat_id: str,
|
||||||
|
external_message_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""将一条历史补拉消息标记为已尝试处理。"""
|
||||||
|
|
||||||
|
normalized_account_id = str(account_id or "").strip()
|
||||||
|
normalized_scope = self._normalize_scope(scope)
|
||||||
|
normalized_chat_type = str(chat_type or "").strip()
|
||||||
|
normalized_chat_id = str(chat_id or "").strip()
|
||||||
|
normalized_external_message_id = str(external_message_id or "").strip()
|
||||||
|
|
||||||
|
if not (
|
||||||
|
normalized_account_id
|
||||||
|
and normalized_chat_type
|
||||||
|
and normalized_chat_id
|
||||||
|
and normalized_external_message_id
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
def _operation(conn: sqlite3.Connection) -> None:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT OR REPLACE INTO napcat_recovery_seen (
|
||||||
|
account_id,
|
||||||
|
scope,
|
||||||
|
chat_type,
|
||||||
|
chat_id,
|
||||||
|
external_message_id,
|
||||||
|
seen_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
normalized_account_id,
|
||||||
|
normalized_scope,
|
||||||
|
normalized_chat_type,
|
||||||
|
normalized_chat_id,
|
||||||
|
normalized_external_message_id,
|
||||||
|
time.time(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._execute_locked(_operation)
|
||||||
|
|
||||||
|
async def prune_recovery_seen(self, ttl_seconds: float) -> int:
|
||||||
|
"""删除超过保留期的历史补拉去重记录。"""
|
||||||
|
|
||||||
|
normalized_ttl_seconds = max(0.0, float(ttl_seconds or 0.0))
|
||||||
|
if normalized_ttl_seconds <= 0.0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
cutoff_timestamp = time.time() - normalized_ttl_seconds
|
||||||
|
|
||||||
|
def _operation(conn: sqlite3.Connection) -> int:
|
||||||
|
cursor = conn.execute(
|
||||||
|
"DELETE FROM napcat_recovery_seen WHERE seen_at < ?",
|
||||||
|
(cutoff_timestamp,),
|
||||||
|
)
|
||||||
|
return int(cursor.rowcount or 0)
|
||||||
|
|
||||||
|
return await self._execute_locked(_operation)
|
||||||
|
|
||||||
|
async def _execute_locked(self, operation: Callable[[sqlite3.Connection], T]) -> T:
|
||||||
|
"""在锁保护下打开 SQLite 并执行一次原子操作。"""
|
||||||
|
|
||||||
|
async with self._store_lock:
|
||||||
|
conn = self._open_connection()
|
||||||
|
try:
|
||||||
|
self._ensure_schema(conn)
|
||||||
|
result = operation(conn)
|
||||||
|
conn.commit()
|
||||||
|
return result
|
||||||
|
except Exception:
|
||||||
|
conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def _open_connection(self) -> sqlite3.Connection:
|
||||||
|
"""打开一个带 Row 工厂的 SQLite 连接。"""
|
||||||
|
|
||||||
|
self._storage_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
conn = sqlite3.connect(str(self._storage_path))
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def _ensure_schema(self, conn: sqlite3.Connection) -> None:
|
||||||
|
"""确保 SQLite 表结构已经准备完成。"""
|
||||||
|
|
||||||
|
if self._schema_ready:
|
||||||
|
return
|
||||||
|
|
||||||
|
for statement in _SCHEMA_STATEMENTS:
|
||||||
|
conn.execute(statement)
|
||||||
|
self._schema_ready = True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_scope(scope: str | None) -> str:
|
||||||
|
"""将空作用域统一折叠为空字符串。"""
|
||||||
|
|
||||||
|
return str(scope or "").strip()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_message_seq(message_seq: object) -> int | None:
|
||||||
|
"""将消息序号规范化为可选整数。"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
if message_seq is None or str(message_seq).strip() == "":
|
||||||
|
return None
|
||||||
|
return int(message_seq)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _should_advance_checkpoint(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
existing_row: sqlite3.Row,
|
||||||
|
message_id: str,
|
||||||
|
message_time: float,
|
||||||
|
message_seq: int | None,
|
||||||
|
) -> bool:
|
||||||
|
"""判断新的锚点是否应覆盖旧锚点。"""
|
||||||
|
|
||||||
|
existing_message_id = str(existing_row["last_message_id"] or "").strip()
|
||||||
|
existing_message_time = float(existing_row["last_message_time"] or 0.0)
|
||||||
|
existing_message_seq = cls._normalize_message_seq(existing_row["last_message_seq"])
|
||||||
|
|
||||||
|
if message_seq is not None and existing_message_seq is not None:
|
||||||
|
if message_seq != existing_message_seq:
|
||||||
|
return message_seq > existing_message_seq
|
||||||
|
if message_id == existing_message_id:
|
||||||
|
return False
|
||||||
|
return message_time >= existing_message_time
|
||||||
|
|
||||||
|
if message_time != existing_message_time:
|
||||||
|
return message_time > existing_message_time
|
||||||
|
|
||||||
|
if message_id == existing_message_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if message_seq is not None and existing_message_seq is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return True
|
||||||
@@ -180,6 +180,46 @@ class NapCatQueryService:
|
|||||||
response_data = await self._safe_call_action_data("get_msg", {"message_id": message_id})
|
response_data = await self._safe_call_action_data("get_msg", {"message_id": message_id})
|
||||||
return response_data if isinstance(response_data, dict) else None
|
return response_data if isinstance(response_data, dict) else None
|
||||||
|
|
||||||
|
async def get_friend_message_history(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
*,
|
||||||
|
message_seq: int | None = None,
|
||||||
|
count: int = 20,
|
||||||
|
reverse_order: bool = False,
|
||||||
|
) -> Optional[NapCatPayloadList]:
|
||||||
|
"""获取私聊历史消息列表。"""
|
||||||
|
|
||||||
|
params: NapCatActionResponse = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"count": max(1, int(count)),
|
||||||
|
"reverse_order": bool(reverse_order),
|
||||||
|
}
|
||||||
|
if message_seq is not None:
|
||||||
|
params["message_seq"] = int(message_seq)
|
||||||
|
response_data = await self._safe_call_action_data("get_friend_msg_history", params)
|
||||||
|
return self._normalize_payload_list(response_data, action_name="get_friend_msg_history")
|
||||||
|
|
||||||
|
async def get_group_message_history(
|
||||||
|
self,
|
||||||
|
group_id: str,
|
||||||
|
*,
|
||||||
|
message_seq: int | None = None,
|
||||||
|
count: int = 20,
|
||||||
|
reverse_order: bool = False,
|
||||||
|
) -> Optional[NapCatPayloadList]:
|
||||||
|
"""获取群聊历史消息列表。"""
|
||||||
|
|
||||||
|
params: NapCatActionResponse = {
|
||||||
|
"group_id": group_id,
|
||||||
|
"count": max(1, int(count)),
|
||||||
|
"reverse_order": bool(reverse_order),
|
||||||
|
}
|
||||||
|
if message_seq is not None:
|
||||||
|
params["message_seq"] = int(message_seq)
|
||||||
|
response_data = await self._safe_call_action_data("get_group_msg_history", params)
|
||||||
|
return self._normalize_payload_list(response_data, action_name="get_group_msg_history")
|
||||||
|
|
||||||
async def get_forward_message(
|
async def get_forward_message(
|
||||||
self,
|
self,
|
||||||
message_id: Optional[str] = None,
|
message_id: Optional[str] = None,
|
||||||
|
|||||||
@@ -12,8 +12,10 @@ import pytest
|
|||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||||
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
|
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
|
||||||
|
PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT / "plugin-templates"
|
||||||
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
|
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
|
||||||
NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter"
|
NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter"
|
||||||
|
NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT / "MaiBot-Napcat-Adapter"
|
||||||
NAPCAT_TEST_MODULE = "_test_napcat_adapter"
|
NAPCAT_TEST_MODULE = "_test_napcat_adapter"
|
||||||
|
|
||||||
for import_path in (str(SDK_ROOT),):
|
for import_path in (str(SDK_ROOT),):
|
||||||
@@ -204,12 +206,14 @@ def _load_napcat_sdk_modules() -> Tuple[Any, Any, Any, Any]:
|
|||||||
依次返回常量模块、配置模块、插件模块和运行时状态模块。
|
依次返回常量模块、配置模块、插件模块和运行时状态模块。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
plugin_dir = NAPCAT_PLUGIN_DIR if NAPCAT_PLUGIN_DIR.is_dir() else NAPCAT_TEMPLATE_DIR
|
||||||
|
|
||||||
if NAPCAT_TEST_MODULE not in sys.modules:
|
if NAPCAT_TEST_MODULE not in sys.modules:
|
||||||
plugin_path = NAPCAT_PLUGIN_DIR / "plugin.py"
|
plugin_path = plugin_dir / "plugin.py"
|
||||||
spec = util.spec_from_file_location(
|
spec = util.spec_from_file_location(
|
||||||
NAPCAT_TEST_MODULE,
|
NAPCAT_TEST_MODULE,
|
||||||
plugin_path,
|
plugin_path,
|
||||||
submodule_search_locations=[str(NAPCAT_PLUGIN_DIR)],
|
submodule_search_locations=[str(plugin_dir)],
|
||||||
)
|
)
|
||||||
if spec is None or spec.loader is None:
|
if spec is None or spec.loader is None:
|
||||||
raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}")
|
raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}")
|
||||||
|
|||||||
376
pytests/test_napcat_history_recovery.py
Normal file
376
pytests/test_napcat_history_recovery.py
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
"""NapCat 历史补拉与恢复状态测试。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from importlib import import_module, util
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
|
||||||
|
PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT / "plugin-templates"
|
||||||
|
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
|
||||||
|
NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter"
|
||||||
|
NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT / "MaiBot-Napcat-Adapter"
|
||||||
|
NAPCAT_TEST_MODULE = "_test_napcat_adapter_history_recovery"
|
||||||
|
|
||||||
|
for import_path in (str(SDK_ROOT),):
|
||||||
|
if import_path not in sys.path:
|
||||||
|
sys.path.insert(0, import_path)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeGatewayCapability:
|
||||||
|
"""用于测试入站注入的网关替身。"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""初始化测试替身。"""
|
||||||
|
|
||||||
|
self.calls: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def route_message(
|
||||||
|
self,
|
||||||
|
gateway_name: str,
|
||||||
|
message: Dict[str, Any],
|
||||||
|
*,
|
||||||
|
route_metadata: Dict[str, Any] | None = None,
|
||||||
|
external_message_id: str = "",
|
||||||
|
dedupe_key: str = "",
|
||||||
|
) -> bool:
|
||||||
|
"""记录入站注入请求并始终模拟成功。"""
|
||||||
|
|
||||||
|
self.calls.append(
|
||||||
|
{
|
||||||
|
"gateway_name": gateway_name,
|
||||||
|
"message": dict(message),
|
||||||
|
"route_metadata": dict(route_metadata or {}),
|
||||||
|
"external_message_id": external_message_id,
|
||||||
|
"dedupe_key": dedupe_key,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_napcat_plugin_dir() -> Path:
|
||||||
|
"""返回当前测试可用的 NapCat 插件目录。"""
|
||||||
|
|
||||||
|
if NAPCAT_PLUGIN_DIR.is_dir():
|
||||||
|
return NAPCAT_PLUGIN_DIR
|
||||||
|
return NAPCAT_TEMPLATE_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def _load_napcat_module(module_suffix: str) -> Any:
|
||||||
|
"""动态加载 NapCat 测试模块。"""
|
||||||
|
|
||||||
|
plugin_dir = _resolve_napcat_plugin_dir()
|
||||||
|
if NAPCAT_TEST_MODULE not in sys.modules:
|
||||||
|
plugin_path = plugin_dir / "plugin.py"
|
||||||
|
spec = util.spec_from_file_location(
|
||||||
|
NAPCAT_TEST_MODULE,
|
||||||
|
plugin_path,
|
||||||
|
submodule_search_locations=[str(plugin_dir)],
|
||||||
|
)
|
||||||
|
if spec is None or spec.loader is None:
|
||||||
|
raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}")
|
||||||
|
|
||||||
|
module = util.module_from_spec(spec)
|
||||||
|
sys.modules[NAPCAT_TEST_MODULE] = module
|
||||||
|
try:
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
except Exception:
|
||||||
|
sys.modules.pop(NAPCAT_TEST_MODULE, None)
|
||||||
|
raise
|
||||||
|
|
||||||
|
return import_module(f"{NAPCAT_TEST_MODULE}.{module_suffix}")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_history_recovery_store_cls() -> Any:
|
||||||
|
"""动态加载历史恢复状态仓库类。"""
|
||||||
|
|
||||||
|
return _load_napcat_module("services.history_recovery_store").NapCatHistoryRecoveryStore
|
||||||
|
|
||||||
|
|
||||||
|
def _load_query_service_cls() -> Any:
|
||||||
|
"""动态加载查询服务类。"""
|
||||||
|
|
||||||
|
return _load_napcat_module("services.query_service").NapCatQueryService
|
||||||
|
|
||||||
|
|
||||||
|
def _load_router_cls() -> Any:
|
||||||
|
"""动态加载事件路由器类。"""
|
||||||
|
|
||||||
|
return _load_napcat_module("runtime.router").NapCatEventRouter
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeActionService:
|
||||||
|
"""用于查询服务的动作服务替身。"""
|
||||||
|
|
||||||
|
def __init__(self, response_data: Any) -> None:
|
||||||
|
"""初始化动作服务替身。"""
|
||||||
|
|
||||||
|
self._response_data = response_data
|
||||||
|
self.action_data_calls: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def safe_call_action_data(self, action_name: str, params: Dict[str, Any]) -> Any:
|
||||||
|
"""记录安全查询动作。"""
|
||||||
|
|
||||||
|
self.action_data_calls.append({"action_name": action_name, "params": dict(params)})
|
||||||
|
return self._response_data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_history_recovery_store_persists_checkpoint_and_seen_state(tmp_path: Path) -> None:
|
||||||
|
"""历史恢复状态仓库应持久化 checkpoint 与已补拉标记。"""
|
||||||
|
|
||||||
|
store_cls = _load_history_recovery_store_cls()
|
||||||
|
store = store_cls(
|
||||||
|
logger=logging.getLogger("test.napcat.history_store"),
|
||||||
|
storage_path=tmp_path / "history.sqlite3",
|
||||||
|
)
|
||||||
|
|
||||||
|
await store.load()
|
||||||
|
await store.record_checkpoint(
|
||||||
|
account_id="10001",
|
||||||
|
scope="primary",
|
||||||
|
chat_type="group",
|
||||||
|
chat_id="20001",
|
||||||
|
message_id="msg-2",
|
||||||
|
message_time=200.0,
|
||||||
|
message_seq=2,
|
||||||
|
)
|
||||||
|
await store.record_checkpoint(
|
||||||
|
account_id="10001",
|
||||||
|
scope="primary",
|
||||||
|
chat_type="group",
|
||||||
|
chat_id="20001",
|
||||||
|
message_id="msg-1",
|
||||||
|
message_time=100.0,
|
||||||
|
message_seq=1,
|
||||||
|
)
|
||||||
|
await store.mark_recovered_message_seen(
|
||||||
|
account_id="10001",
|
||||||
|
scope="primary",
|
||||||
|
chat_type="group",
|
||||||
|
chat_id="20001",
|
||||||
|
external_message_id="history-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoints = await store.list_checkpoints("10001", scope="primary")
|
||||||
|
|
||||||
|
assert len(checkpoints) == 1
|
||||||
|
assert checkpoints[0].last_message_id == "msg-2"
|
||||||
|
assert checkpoints[0].last_message_seq == 2
|
||||||
|
assert (
|
||||||
|
await store.has_recovered_message_seen(
|
||||||
|
account_id="10001",
|
||||||
|
scope="primary",
|
||||||
|
chat_type="group",
|
||||||
|
chat_id="20001",
|
||||||
|
external_message_id="history-1",
|
||||||
|
)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_service_wraps_group_and_friend_history_actions() -> None:
|
||||||
|
"""查询服务应按官方动作名封装历史消息接口。"""
|
||||||
|
|
||||||
|
query_service_cls = _load_query_service_cls()
|
||||||
|
action_service = _FakeActionService([{"message_id": "msg-1"}])
|
||||||
|
query_service = query_service_cls(
|
||||||
|
action_service=action_service,
|
||||||
|
logger=logging.getLogger("test.napcat.history_query"),
|
||||||
|
)
|
||||||
|
|
||||||
|
group_payload = await query_service.get_group_message_history("20001", message_seq=123, count=10)
|
||||||
|
friend_payload = await query_service.get_friend_message_history("30001", count=5, reverse_order=True)
|
||||||
|
|
||||||
|
assert group_payload == [{"message_id": "msg-1"}]
|
||||||
|
assert friend_payload == [{"message_id": "msg-1"}]
|
||||||
|
assert action_service.action_data_calls == [
|
||||||
|
{
|
||||||
|
"action_name": "get_group_msg_history",
|
||||||
|
"params": {"group_id": "20001", "count": 10, "reverse_order": False, "message_seq": 123},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"action_name": "get_friend_msg_history",
|
||||||
|
"params": {"user_id": "30001", "count": 5, "reverse_order": True},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_router_recover_recent_history_reinjects_messages_in_order(tmp_path: Path) -> None:
|
||||||
|
"""重连补拉应按时间顺序将历史消息重新注入原入站路径。"""
|
||||||
|
|
||||||
|
history_store_cls = _load_history_recovery_store_cls()
|
||||||
|
router_cls = _load_router_cls()
|
||||||
|
gateway_capability = _FakeGatewayCapability()
|
||||||
|
router = router_cls(
|
||||||
|
gateway_capability=gateway_capability,
|
||||||
|
logger=logging.getLogger("test.napcat.history_router"),
|
||||||
|
gateway_name="napcat_gateway",
|
||||||
|
load_settings=lambda: SimpleNamespace(
|
||||||
|
napcat_server=SimpleNamespace(connection_id="primary", heartbeat_interval=30.0),
|
||||||
|
filters=SimpleNamespace(ignore_self_message=True),
|
||||||
|
chat=SimpleNamespace(ban_qq_bot=False),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
history_store = history_store_cls(
|
||||||
|
logger=logging.getLogger("test.napcat.history_router.store"),
|
||||||
|
storage_path=tmp_path / "router.sqlite3",
|
||||||
|
)
|
||||||
|
await history_store.load()
|
||||||
|
await history_store.record_checkpoint(
|
||||||
|
account_id="10001",
|
||||||
|
scope="primary",
|
||||||
|
chat_type="group",
|
||||||
|
chat_id="20001",
|
||||||
|
message_id="msg-1",
|
||||||
|
message_time=100.0,
|
||||||
|
message_seq=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
history_calls: List[Dict[str, Any]] = []
|
||||||
|
history_payloads = [
|
||||||
|
{
|
||||||
|
"post_type": "message",
|
||||||
|
"message_type": "group",
|
||||||
|
"self_id": "10001",
|
||||||
|
"group_id": "20001",
|
||||||
|
"user_id": "30002",
|
||||||
|
"message_id": "msg-3",
|
||||||
|
"message_seq": 12,
|
||||||
|
"time": 102,
|
||||||
|
"message": [{"type": "text", "data": {"text": "第三条"}}],
|
||||||
|
"sender": {"user_id": "30002", "nickname": "用户二"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"post_type": "message",
|
||||||
|
"message_type": "group",
|
||||||
|
"self_id": "10001",
|
||||||
|
"group_id": "20001",
|
||||||
|
"user_id": "30001",
|
||||||
|
"message_id": "msg-2",
|
||||||
|
"message_seq": 11,
|
||||||
|
"time": 101,
|
||||||
|
"message": [{"type": "text", "data": {"text": "第二条"}}],
|
||||||
|
"sender": {"user_id": "30001", "nickname": "用户一"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
class _FakeQueryService:
|
||||||
|
async def get_group_message_history(
|
||||||
|
self,
|
||||||
|
group_id: str,
|
||||||
|
*,
|
||||||
|
message_seq: int | None = None,
|
||||||
|
count: int = 20,
|
||||||
|
reverse_order: bool = False,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
history_calls.append(
|
||||||
|
{
|
||||||
|
"group_id": group_id,
|
||||||
|
"message_seq": message_seq,
|
||||||
|
"count": count,
|
||||||
|
"reverse_order": reverse_order,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return list(history_payloads)
|
||||||
|
|
||||||
|
async def get_friend_message_history(self, user_id: str, **kwargs: Any) -> List[Dict[str, Any]]:
|
||||||
|
del user_id
|
||||||
|
del kwargs
|
||||||
|
return []
|
||||||
|
|
||||||
|
class _FakeInboundCodec:
|
||||||
|
@staticmethod
|
||||||
|
async def build_message_dict(
|
||||||
|
payload: Dict[str, Any],
|
||||||
|
self_id: str,
|
||||||
|
sender_user_id: str,
|
||||||
|
sender: Dict[str, Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
del self_id
|
||||||
|
del sender_user_id
|
||||||
|
del sender
|
||||||
|
return {
|
||||||
|
"message_id": str(payload["message_id"]),
|
||||||
|
"platform": "qq",
|
||||||
|
"timestamp": str(float(payload["time"])),
|
||||||
|
"message_info": {
|
||||||
|
"user_info": {"user_id": str(payload["user_id"]), "user_nickname": "测试用户"},
|
||||||
|
"group_info": {"group_id": str(payload["group_id"]), "group_name": "测试群"},
|
||||||
|
"additional_config": {},
|
||||||
|
},
|
||||||
|
"raw_message": [{"type": "text", "data": str(payload["message"][0]["data"]["text"])}],
|
||||||
|
"processed_plain_text": str(payload["message"][0]["data"]["text"]),
|
||||||
|
"display_message": str(payload["message"][0]["data"]["text"]),
|
||||||
|
"is_mentioned": False,
|
||||||
|
"is_at": False,
|
||||||
|
"is_emoji": False,
|
||||||
|
"is_picture": False,
|
||||||
|
"is_command": False,
|
||||||
|
"is_notify": False,
|
||||||
|
"session_id": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
router.bind_runtime(
|
||||||
|
SimpleNamespace(
|
||||||
|
runtime_state=SimpleNamespace(report_connected=lambda *args, **kwargs: _noop_async(), report_disconnected=_noop_async),
|
||||||
|
chat_filter=SimpleNamespace(is_inbound_chat_allowed=lambda *args, **kwargs: True),
|
||||||
|
official_bot_guard=SimpleNamespace(
|
||||||
|
should_reject=lambda *args, **kwargs: _return_false_async(),
|
||||||
|
clear_cache=lambda: None,
|
||||||
|
),
|
||||||
|
inbound_codec=_FakeInboundCodec(),
|
||||||
|
history_recovery_store=history_store,
|
||||||
|
query_service=_FakeQueryService(),
|
||||||
|
heartbeat_monitor=SimpleNamespace(start=_noop_async, stop=_noop_async),
|
||||||
|
ban_tracker=SimpleNamespace(start=_noop_async, stop=_noop_async, record_notice=_noop_async),
|
||||||
|
notice_codec=SimpleNamespace(handle_meta_event=_noop_async, build_notice_message_dict=_return_none_async),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await router._recover_recent_history(self_id="10001", scope="primary")
|
||||||
|
|
||||||
|
assert history_calls == [
|
||||||
|
{
|
||||||
|
"group_id": "20001",
|
||||||
|
"message_seq": 10,
|
||||||
|
"count": 20,
|
||||||
|
"reverse_order": False,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
assert [call["external_message_id"] for call in gateway_capability.calls] == ["msg-2", "msg-3"]
|
||||||
|
assert [call["message"]["message_id"] for call in gateway_capability.calls] == ["msg-2", "msg-3"]
|
||||||
|
|
||||||
|
|
||||||
|
async def _noop_async(*args: Any, **kwargs: Any) -> None:
|
||||||
|
"""无操作异步函数。"""
|
||||||
|
|
||||||
|
del args
|
||||||
|
del kwargs
|
||||||
|
|
||||||
|
|
||||||
|
async def _return_false_async(*args: Any, **kwargs: Any) -> bool:
|
||||||
|
"""返回 ``False`` 的异步测试替身。"""
|
||||||
|
|
||||||
|
del args
|
||||||
|
del kwargs
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def _return_none_async(*args: Any, **kwargs: Any) -> None:
|
||||||
|
"""返回 ``None`` 的异步测试替身。"""
|
||||||
|
|
||||||
|
del args
|
||||||
|
del kwargs
|
||||||
|
return None
|
||||||
Reference in New Issue
Block a user