feat:新增napcat断线后重连重新拉取历史消息的机制
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
from .action_service import NapCatActionService
|
||||
from .ban_tracker import NapCatBanTracker
|
||||
from .ban_state_store import NapCatBanRecord, NapCatBanStateStore
|
||||
from .history_recovery_store import NapCatChatCheckpoint, NapCatHistoryRecoveryStore
|
||||
from .official_bot_guard import NapCatOfficialBotGuard
|
||||
from .query_service import NapCatQueryService
|
||||
|
||||
@@ -11,6 +12,8 @@ __all__ = [
|
||||
"NapCatBanRecord",
|
||||
"NapCatBanStateStore",
|
||||
"NapCatBanTracker",
|
||||
"NapCatChatCheckpoint",
|
||||
"NapCatHistoryRecoveryStore",
|
||||
"NapCatOfficialBotGuard",
|
||||
"NapCatQueryService",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,423 @@
|
||||
"""NapCat 历史补拉状态持久化仓库。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Optional, TypeVar
|
||||
|
||||
import asyncio
|
||||
import sqlite3
|
||||
import time
|
||||
|
||||
from ..constants import DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||
_DEFAULT_STORAGE_PATH = _PROJECT_ROOT / "data" / "napcat_adapter" / "history_recovery.sqlite3"
|
||||
|
||||
_SCHEMA_STATEMENTS = (
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS napcat_chat_checkpoint (
|
||||
account_id TEXT NOT NULL,
|
||||
scope TEXT NOT NULL,
|
||||
chat_type TEXT NOT NULL,
|
||||
chat_id TEXT NOT NULL,
|
||||
last_message_id TEXT NOT NULL,
|
||||
last_message_time REAL NOT NULL,
|
||||
last_message_seq INTEGER,
|
||||
updated_at REAL NOT NULL,
|
||||
PRIMARY KEY (account_id, scope, chat_type, chat_id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS ix_napcat_chat_checkpoint_updated_at
|
||||
ON napcat_chat_checkpoint (updated_at DESC)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS napcat_recovery_seen (
|
||||
account_id TEXT NOT NULL,
|
||||
scope TEXT NOT NULL,
|
||||
chat_type TEXT NOT NULL,
|
||||
chat_id TEXT NOT NULL,
|
||||
external_message_id TEXT NOT NULL,
|
||||
seen_at REAL NOT NULL,
|
||||
PRIMARY KEY (account_id, scope, chat_type, chat_id, external_message_id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS ix_napcat_recovery_seen_seen_at
|
||||
ON napcat_recovery_seen (seen_at DESC)
|
||||
""",
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NapCatChatCheckpoint:
|
||||
"""描述一个会话的最近入站锚点。"""
|
||||
|
||||
account_id: str
|
||||
scope: str
|
||||
chat_type: str
|
||||
chat_id: str
|
||||
last_message_id: str
|
||||
last_message_time: float
|
||||
last_message_seq: int | None
|
||||
updated_at: float
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row: sqlite3.Row) -> "NapCatChatCheckpoint":
|
||||
"""从 SQLite 行对象恢复 checkpoint。"""
|
||||
|
||||
last_message_seq = row["last_message_seq"]
|
||||
normalized_seq = int(last_message_seq) if isinstance(last_message_seq, int) else None
|
||||
return cls(
|
||||
account_id=str(row["account_id"] or "").strip(),
|
||||
scope=str(row["scope"] or "").strip(),
|
||||
chat_type=str(row["chat_type"] or "").strip(),
|
||||
chat_id=str(row["chat_id"] or "").strip(),
|
||||
last_message_id=str(row["last_message_id"] or "").strip(),
|
||||
last_message_time=float(row["last_message_time"] or 0.0),
|
||||
last_message_seq=normalized_seq,
|
||||
updated_at=float(row["updated_at"] or 0.0),
|
||||
)
|
||||
|
||||
|
||||
class NapCatHistoryRecoveryStore:
|
||||
"""负责持久化历史补拉所需的会话状态与去重状态。"""
|
||||
|
||||
def __init__(self, logger: Any, storage_path: Path = _DEFAULT_STORAGE_PATH) -> None:
|
||||
"""初始化历史补拉状态仓库。"""
|
||||
|
||||
self._logger = logger
|
||||
self._storage_path = storage_path
|
||||
self._store_lock = asyncio.Lock()
|
||||
self._schema_ready = False
|
||||
|
||||
async def load(self) -> None:
|
||||
"""初始化 SQLite 文件并清理过期去重记录。"""
|
||||
|
||||
await self._execute_locked(self._ensure_schema)
|
||||
pruned_count = await self.prune_recovery_seen(DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC)
|
||||
if pruned_count > 0:
|
||||
self._logger.debug(f"NapCat 历史补拉去重表已清理 {pruned_count} 条过期记录")
|
||||
|
||||
async def list_checkpoints(self, account_id: str, scope: str = "", limit: int = 50) -> List[NapCatChatCheckpoint]:
|
||||
"""列出指定账号与作用域下的最近会话 checkpoint。"""
|
||||
|
||||
normalized_account_id = str(account_id or "").strip()
|
||||
if not normalized_account_id:
|
||||
return []
|
||||
|
||||
normalized_scope = self._normalize_scope(scope)
|
||||
normalized_limit = max(1, int(limit))
|
||||
|
||||
def _operation(conn: sqlite3.Connection) -> List[NapCatChatCheckpoint]:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
account_id,
|
||||
scope,
|
||||
chat_type,
|
||||
chat_id,
|
||||
last_message_id,
|
||||
last_message_time,
|
||||
last_message_seq,
|
||||
updated_at
|
||||
FROM napcat_chat_checkpoint
|
||||
WHERE account_id = ? AND scope = ?
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(normalized_account_id, normalized_scope, normalized_limit),
|
||||
)
|
||||
return [NapCatChatCheckpoint.from_row(row) for row in cursor.fetchall()]
|
||||
|
||||
return await self._execute_locked(_operation)
|
||||
|
||||
async def record_checkpoint(
|
||||
self,
|
||||
*,
|
||||
account_id: str,
|
||||
scope: str = "",
|
||||
chat_type: str,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
message_time: float,
|
||||
message_seq: int | None = None,
|
||||
) -> None:
|
||||
"""记录一条已被 Host 接受的最新入站消息锚点。"""
|
||||
|
||||
normalized_account_id = str(account_id or "").strip()
|
||||
normalized_scope = self._normalize_scope(scope)
|
||||
normalized_chat_type = str(chat_type or "").strip()
|
||||
normalized_chat_id = str(chat_id or "").strip()
|
||||
normalized_message_id = str(message_id or "").strip()
|
||||
|
||||
if not (
|
||||
normalized_account_id
|
||||
and normalized_chat_type
|
||||
and normalized_chat_id
|
||||
and normalized_message_id
|
||||
):
|
||||
return
|
||||
|
||||
normalized_message_time = float(message_time or 0.0)
|
||||
normalized_message_seq = self._normalize_message_seq(message_seq)
|
||||
updated_at = time.time()
|
||||
|
||||
def _operation(conn: sqlite3.Connection) -> None:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT last_message_id, last_message_time, last_message_seq
|
||||
FROM napcat_chat_checkpoint
|
||||
WHERE account_id = ? AND scope = ? AND chat_type = ? AND chat_id = ?
|
||||
""",
|
||||
(
|
||||
normalized_account_id,
|
||||
normalized_scope,
|
||||
normalized_chat_type,
|
||||
normalized_chat_id,
|
||||
),
|
||||
)
|
||||
existing_row = cursor.fetchone()
|
||||
if existing_row is not None and not self._should_advance_checkpoint(
|
||||
existing_row=existing_row,
|
||||
message_id=normalized_message_id,
|
||||
message_time=normalized_message_time,
|
||||
message_seq=normalized_message_seq,
|
||||
):
|
||||
return
|
||||
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO napcat_chat_checkpoint (
|
||||
account_id,
|
||||
scope,
|
||||
chat_type,
|
||||
chat_id,
|
||||
last_message_id,
|
||||
last_message_time,
|
||||
last_message_seq,
|
||||
updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(account_id, scope, chat_type, chat_id) DO UPDATE SET
|
||||
last_message_id = excluded.last_message_id,
|
||||
last_message_time = excluded.last_message_time,
|
||||
last_message_seq = excluded.last_message_seq,
|
||||
updated_at = excluded.updated_at
|
||||
""",
|
||||
(
|
||||
normalized_account_id,
|
||||
normalized_scope,
|
||||
normalized_chat_type,
|
||||
normalized_chat_id,
|
||||
normalized_message_id,
|
||||
normalized_message_time,
|
||||
normalized_message_seq,
|
||||
updated_at,
|
||||
),
|
||||
)
|
||||
|
||||
await self._execute_locked(_operation)
|
||||
|
||||
async def has_recovered_message_seen(
|
||||
self,
|
||||
*,
|
||||
account_id: str,
|
||||
scope: str = "",
|
||||
chat_type: str,
|
||||
chat_id: str,
|
||||
external_message_id: str,
|
||||
) -> bool:
|
||||
"""判断某条历史补拉消息是否已经被当前仓库记录过。"""
|
||||
|
||||
normalized_account_id = str(account_id or "").strip()
|
||||
normalized_scope = self._normalize_scope(scope)
|
||||
normalized_chat_type = str(chat_type or "").strip()
|
||||
normalized_chat_id = str(chat_id or "").strip()
|
||||
normalized_external_message_id = str(external_message_id or "").strip()
|
||||
|
||||
if not (
|
||||
normalized_account_id
|
||||
and normalized_chat_type
|
||||
and normalized_chat_id
|
||||
and normalized_external_message_id
|
||||
):
|
||||
return False
|
||||
|
||||
def _operation(conn: sqlite3.Connection) -> bool:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT 1
|
||||
FROM napcat_recovery_seen
|
||||
WHERE account_id = ?
|
||||
AND scope = ?
|
||||
AND chat_type = ?
|
||||
AND chat_id = ?
|
||||
AND external_message_id = ?
|
||||
LIMIT 1
|
||||
""",
|
||||
(
|
||||
normalized_account_id,
|
||||
normalized_scope,
|
||||
normalized_chat_type,
|
||||
normalized_chat_id,
|
||||
normalized_external_message_id,
|
||||
),
|
||||
)
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
return await self._execute_locked(_operation)
|
||||
|
||||
async def mark_recovered_message_seen(
|
||||
self,
|
||||
*,
|
||||
account_id: str,
|
||||
scope: str = "",
|
||||
chat_type: str,
|
||||
chat_id: str,
|
||||
external_message_id: str,
|
||||
) -> None:
|
||||
"""将一条历史补拉消息标记为已尝试处理。"""
|
||||
|
||||
normalized_account_id = str(account_id or "").strip()
|
||||
normalized_scope = self._normalize_scope(scope)
|
||||
normalized_chat_type = str(chat_type or "").strip()
|
||||
normalized_chat_id = str(chat_id or "").strip()
|
||||
normalized_external_message_id = str(external_message_id or "").strip()
|
||||
|
||||
if not (
|
||||
normalized_account_id
|
||||
and normalized_chat_type
|
||||
and normalized_chat_id
|
||||
and normalized_external_message_id
|
||||
):
|
||||
return
|
||||
|
||||
def _operation(conn: sqlite3.Connection) -> None:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO napcat_recovery_seen (
|
||||
account_id,
|
||||
scope,
|
||||
chat_type,
|
||||
chat_id,
|
||||
external_message_id,
|
||||
seen_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
normalized_account_id,
|
||||
normalized_scope,
|
||||
normalized_chat_type,
|
||||
normalized_chat_id,
|
||||
normalized_external_message_id,
|
||||
time.time(),
|
||||
),
|
||||
)
|
||||
|
||||
await self._execute_locked(_operation)
|
||||
|
||||
async def prune_recovery_seen(self, ttl_seconds: float) -> int:
|
||||
"""删除超过保留期的历史补拉去重记录。"""
|
||||
|
||||
normalized_ttl_seconds = max(0.0, float(ttl_seconds or 0.0))
|
||||
if normalized_ttl_seconds <= 0.0:
|
||||
return 0
|
||||
|
||||
cutoff_timestamp = time.time() - normalized_ttl_seconds
|
||||
|
||||
def _operation(conn: sqlite3.Connection) -> int:
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM napcat_recovery_seen WHERE seen_at < ?",
|
||||
(cutoff_timestamp,),
|
||||
)
|
||||
return int(cursor.rowcount or 0)
|
||||
|
||||
return await self._execute_locked(_operation)
|
||||
|
||||
async def _execute_locked(self, operation: Callable[[sqlite3.Connection], T]) -> T:
|
||||
"""在锁保护下打开 SQLite 并执行一次原子操作。"""
|
||||
|
||||
async with self._store_lock:
|
||||
conn = self._open_connection()
|
||||
try:
|
||||
self._ensure_schema(conn)
|
||||
result = operation(conn)
|
||||
conn.commit()
|
||||
return result
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _open_connection(self) -> sqlite3.Connection:
|
||||
"""打开一个带 Row 工厂的 SQLite 连接。"""
|
||||
|
||||
self._storage_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(self._storage_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def _ensure_schema(self, conn: sqlite3.Connection) -> None:
|
||||
"""确保 SQLite 表结构已经准备完成。"""
|
||||
|
||||
if self._schema_ready:
|
||||
return
|
||||
|
||||
for statement in _SCHEMA_STATEMENTS:
|
||||
conn.execute(statement)
|
||||
self._schema_ready = True
|
||||
|
||||
@staticmethod
|
||||
def _normalize_scope(scope: str | None) -> str:
|
||||
"""将空作用域统一折叠为空字符串。"""
|
||||
|
||||
return str(scope or "").strip()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_message_seq(message_seq: object) -> int | None:
|
||||
"""将消息序号规范化为可选整数。"""
|
||||
|
||||
try:
|
||||
if message_seq is None or str(message_seq).strip() == "":
|
||||
return None
|
||||
return int(message_seq)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _should_advance_checkpoint(
|
||||
cls,
|
||||
*,
|
||||
existing_row: sqlite3.Row,
|
||||
message_id: str,
|
||||
message_time: float,
|
||||
message_seq: int | None,
|
||||
) -> bool:
|
||||
"""判断新的锚点是否应覆盖旧锚点。"""
|
||||
|
||||
existing_message_id = str(existing_row["last_message_id"] or "").strip()
|
||||
existing_message_time = float(existing_row["last_message_time"] or 0.0)
|
||||
existing_message_seq = cls._normalize_message_seq(existing_row["last_message_seq"])
|
||||
|
||||
if message_seq is not None and existing_message_seq is not None:
|
||||
if message_seq != existing_message_seq:
|
||||
return message_seq > existing_message_seq
|
||||
if message_id == existing_message_id:
|
||||
return False
|
||||
return message_time >= existing_message_time
|
||||
|
||||
if message_time != existing_message_time:
|
||||
return message_time > existing_message_time
|
||||
|
||||
if message_id == existing_message_id:
|
||||
return False
|
||||
|
||||
if message_seq is not None and existing_message_seq is None:
|
||||
return True
|
||||
|
||||
return True
|
||||
@@ -180,6 +180,46 @@ class NapCatQueryService:
|
||||
response_data = await self._safe_call_action_data("get_msg", {"message_id": message_id})
|
||||
return response_data if isinstance(response_data, dict) else None
|
||||
|
||||
async def get_friend_message_history(
|
||||
self,
|
||||
user_id: str,
|
||||
*,
|
||||
message_seq: int | None = None,
|
||||
count: int = 20,
|
||||
reverse_order: bool = False,
|
||||
) -> Optional[NapCatPayloadList]:
|
||||
"""获取私聊历史消息列表。"""
|
||||
|
||||
params: NapCatActionResponse = {
|
||||
"user_id": user_id,
|
||||
"count": max(1, int(count)),
|
||||
"reverse_order": bool(reverse_order),
|
||||
}
|
||||
if message_seq is not None:
|
||||
params["message_seq"] = int(message_seq)
|
||||
response_data = await self._safe_call_action_data("get_friend_msg_history", params)
|
||||
return self._normalize_payload_list(response_data, action_name="get_friend_msg_history")
|
||||
|
||||
async def get_group_message_history(
|
||||
self,
|
||||
group_id: str,
|
||||
*,
|
||||
message_seq: int | None = None,
|
||||
count: int = 20,
|
||||
reverse_order: bool = False,
|
||||
) -> Optional[NapCatPayloadList]:
|
||||
"""获取群聊历史消息列表。"""
|
||||
|
||||
params: NapCatActionResponse = {
|
||||
"group_id": group_id,
|
||||
"count": max(1, int(count)),
|
||||
"reverse_order": bool(reverse_order),
|
||||
}
|
||||
if message_seq is not None:
|
||||
params["message_seq"] = int(message_seq)
|
||||
response_data = await self._safe_call_action_data("get_group_msg_history", params)
|
||||
return self._normalize_payload_list(response_data, action_name="get_group_msg_history")
|
||||
|
||||
async def get_forward_message(
|
||||
self,
|
||||
message_id: Optional[str] = None,
|
||||
|
||||
Reference in New Issue
Block a user