"""NapCat 历史补拉状态持久化仓库。""" from __future__ import annotations from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, List, Optional, TypeVar import asyncio import sqlite3 import time from ..constants import DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC _PROJECT_ROOT = Path(__file__).resolve().parents[3] _DEFAULT_STORAGE_PATH = _PROJECT_ROOT / "data" / "napcat_adapter" / "history_recovery.sqlite3" _SCHEMA_STATEMENTS = ( """ CREATE TABLE IF NOT EXISTS napcat_chat_checkpoint ( account_id TEXT NOT NULL, scope TEXT NOT NULL, chat_type TEXT NOT NULL, chat_id TEXT NOT NULL, last_message_id TEXT NOT NULL, last_message_time REAL NOT NULL, last_message_seq INTEGER, updated_at REAL NOT NULL, PRIMARY KEY (account_id, scope, chat_type, chat_id) ) """, """ CREATE INDEX IF NOT EXISTS ix_napcat_chat_checkpoint_updated_at ON napcat_chat_checkpoint (updated_at DESC) """, """ CREATE TABLE IF NOT EXISTS napcat_recovery_seen ( account_id TEXT NOT NULL, scope TEXT NOT NULL, chat_type TEXT NOT NULL, chat_id TEXT NOT NULL, external_message_id TEXT NOT NULL, seen_at REAL NOT NULL, PRIMARY KEY (account_id, scope, chat_type, chat_id, external_message_id) ) """, """ CREATE INDEX IF NOT EXISTS ix_napcat_recovery_seen_seen_at ON napcat_recovery_seen (seen_at DESC) """, ) T = TypeVar("T") @dataclass(frozen=True) class NapCatChatCheckpoint: """描述一个会话的最近入站锚点。""" account_id: str scope: str chat_type: str chat_id: str last_message_id: str last_message_time: float last_message_seq: int | None updated_at: float @classmethod def from_row(cls, row: sqlite3.Row) -> "NapCatChatCheckpoint": """从 SQLite 行对象恢复 checkpoint。""" last_message_seq = row["last_message_seq"] normalized_seq = int(last_message_seq) if isinstance(last_message_seq, int) else None return cls( account_id=str(row["account_id"] or "").strip(), scope=str(row["scope"] or "").strip(), chat_type=str(row["chat_type"] or "").strip(), chat_id=str(row["chat_id"] or "").strip(), last_message_id=str(row["last_message_id"] or "").strip(), last_message_time=float(row["last_message_time"] or 0.0), last_message_seq=normalized_seq, updated_at=float(row["updated_at"] or 0.0), ) class NapCatHistoryRecoveryStore: """负责持久化历史补拉所需的会话状态与去重状态。""" def __init__(self, logger: Any, storage_path: Path = _DEFAULT_STORAGE_PATH) -> None: """初始化历史补拉状态仓库。""" self._logger = logger self._storage_path = storage_path self._store_lock = asyncio.Lock() self._schema_ready = False async def load(self) -> None: """初始化 SQLite 文件并清理过期去重记录。""" await self._execute_locked(self._ensure_schema) pruned_count = await self.prune_recovery_seen(DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC) if pruned_count > 0: self._logger.debug(f"NapCat 历史补拉去重表已清理 {pruned_count} 条过期记录") async def list_checkpoints(self, account_id: str, scope: str = "", limit: int = 50) -> List[NapCatChatCheckpoint]: """列出指定账号与作用域下的最近会话 checkpoint。""" normalized_account_id = str(account_id or "").strip() if not normalized_account_id: return [] normalized_scope = self._normalize_scope(scope) normalized_limit = max(1, int(limit)) def _operation(conn: sqlite3.Connection) -> List[NapCatChatCheckpoint]: cursor = conn.execute( """ SELECT account_id, scope, chat_type, chat_id, last_message_id, last_message_time, last_message_seq, updated_at FROM napcat_chat_checkpoint WHERE account_id = ? AND scope = ? ORDER BY updated_at DESC LIMIT ? """, (normalized_account_id, normalized_scope, normalized_limit), ) return [NapCatChatCheckpoint.from_row(row) for row in cursor.fetchall()] return await self._execute_locked(_operation) async def record_checkpoint( self, *, account_id: str, scope: str = "", chat_type: str, chat_id: str, message_id: str, message_time: float, message_seq: int | None = None, ) -> None: """记录一条已被 Host 接受的最新入站消息锚点。""" normalized_account_id = str(account_id or "").strip() normalized_scope = self._normalize_scope(scope) normalized_chat_type = str(chat_type or "").strip() normalized_chat_id = str(chat_id or "").strip() normalized_message_id = str(message_id or "").strip() if not ( normalized_account_id and normalized_chat_type and normalized_chat_id and normalized_message_id ): return normalized_message_time = float(message_time or 0.0) normalized_message_seq = self._normalize_message_seq(message_seq) updated_at = time.time() def _operation(conn: sqlite3.Connection) -> None: cursor = conn.execute( """ SELECT last_message_id, last_message_time, last_message_seq FROM napcat_chat_checkpoint WHERE account_id = ? AND scope = ? AND chat_type = ? AND chat_id = ? """, ( normalized_account_id, normalized_scope, normalized_chat_type, normalized_chat_id, ), ) existing_row = cursor.fetchone() if existing_row is not None and not self._should_advance_checkpoint( existing_row=existing_row, message_id=normalized_message_id, message_time=normalized_message_time, message_seq=normalized_message_seq, ): return conn.execute( """ INSERT INTO napcat_chat_checkpoint ( account_id, scope, chat_type, chat_id, last_message_id, last_message_time, last_message_seq, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(account_id, scope, chat_type, chat_id) DO UPDATE SET last_message_id = excluded.last_message_id, last_message_time = excluded.last_message_time, last_message_seq = excluded.last_message_seq, updated_at = excluded.updated_at """, ( normalized_account_id, normalized_scope, normalized_chat_type, normalized_chat_id, normalized_message_id, normalized_message_time, normalized_message_seq, updated_at, ), ) await self._execute_locked(_operation) async def has_recovered_message_seen( self, *, account_id: str, scope: str = "", chat_type: str, chat_id: str, external_message_id: str, ) -> bool: """判断某条历史补拉消息是否已经被当前仓库记录过。""" normalized_account_id = str(account_id or "").strip() normalized_scope = self._normalize_scope(scope) normalized_chat_type = str(chat_type or "").strip() normalized_chat_id = str(chat_id or "").strip() normalized_external_message_id = str(external_message_id or "").strip() if not ( normalized_account_id and normalized_chat_type and normalized_chat_id and normalized_external_message_id ): return False def _operation(conn: sqlite3.Connection) -> bool: cursor = conn.execute( """ SELECT 1 FROM napcat_recovery_seen WHERE account_id = ? AND scope = ? AND chat_type = ? AND chat_id = ? AND external_message_id = ? LIMIT 1 """, ( normalized_account_id, normalized_scope, normalized_chat_type, normalized_chat_id, normalized_external_message_id, ), ) return cursor.fetchone() is not None return await self._execute_locked(_operation) async def mark_recovered_message_seen( self, *, account_id: str, scope: str = "", chat_type: str, chat_id: str, external_message_id: str, ) -> None: """将一条历史补拉消息标记为已尝试处理。""" normalized_account_id = str(account_id or "").strip() normalized_scope = self._normalize_scope(scope) normalized_chat_type = str(chat_type or "").strip() normalized_chat_id = str(chat_id or "").strip() normalized_external_message_id = str(external_message_id or "").strip() if not ( normalized_account_id and normalized_chat_type and normalized_chat_id and normalized_external_message_id ): return def _operation(conn: sqlite3.Connection) -> None: conn.execute( """ INSERT OR REPLACE INTO napcat_recovery_seen ( account_id, scope, chat_type, chat_id, external_message_id, seen_at ) VALUES (?, ?, ?, ?, ?, ?) """, ( normalized_account_id, normalized_scope, normalized_chat_type, normalized_chat_id, normalized_external_message_id, time.time(), ), ) await self._execute_locked(_operation) async def prune_recovery_seen(self, ttl_seconds: float) -> int: """删除超过保留期的历史补拉去重记录。""" normalized_ttl_seconds = max(0.0, float(ttl_seconds or 0.0)) if normalized_ttl_seconds <= 0.0: return 0 cutoff_timestamp = time.time() - normalized_ttl_seconds def _operation(conn: sqlite3.Connection) -> int: cursor = conn.execute( "DELETE FROM napcat_recovery_seen WHERE seen_at < ?", (cutoff_timestamp,), ) return int(cursor.rowcount or 0) return await self._execute_locked(_operation) async def _execute_locked(self, operation: Callable[[sqlite3.Connection], T]) -> T: """在锁保护下打开 SQLite 并执行一次原子操作。""" async with self._store_lock: conn = self._open_connection() try: self._ensure_schema(conn) result = operation(conn) conn.commit() return result except Exception: conn.rollback() raise finally: conn.close() def _open_connection(self) -> sqlite3.Connection: """打开一个带 Row 工厂的 SQLite 连接。""" self._storage_path.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(str(self._storage_path)) conn.row_factory = sqlite3.Row return conn def _ensure_schema(self, conn: sqlite3.Connection) -> None: """确保 SQLite 表结构已经准备完成。""" if self._schema_ready: return for statement in _SCHEMA_STATEMENTS: conn.execute(statement) self._schema_ready = True @staticmethod def _normalize_scope(scope: str | None) -> str: """将空作用域统一折叠为空字符串。""" return str(scope or "").strip() @staticmethod def _normalize_message_seq(message_seq: object) -> int | None: """将消息序号规范化为可选整数。""" try: if message_seq is None or str(message_seq).strip() == "": return None return int(message_seq) except (TypeError, ValueError): return None @classmethod def _should_advance_checkpoint( cls, *, existing_row: sqlite3.Row, message_id: str, message_time: float, message_seq: int | None, ) -> bool: """判断新的锚点是否应覆盖旧锚点。""" existing_message_id = str(existing_row["last_message_id"] or "").strip() existing_message_time = float(existing_row["last_message_time"] or 0.0) existing_message_seq = cls._normalize_message_seq(existing_row["last_message_seq"]) if message_seq is not None and existing_message_seq is not None: if message_seq != existing_message_seq: return message_seq > existing_message_seq if message_id == existing_message_id: return False return message_time >= existing_message_time if message_time != existing_message_time: return message_time > existing_message_time if message_id == existing_message_id: return False if message_seq is not None and existing_message_seq is None: return True return True