# Source path: mai-bot/plugin-templates/MaiBot-Napcat-Adapter/services/history_recovery_store.py
# Repository viewer metadata: 424 lines, 14 KiB, Python

"""NapCat 历史补拉状态持久化仓库。"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, List, Optional, TypeVar
import asyncio
import sqlite3
import time
from ..constants import DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
_DEFAULT_STORAGE_PATH = _PROJECT_ROOT / "data" / "napcat_adapter" / "history_recovery.sqlite3"
_SCHEMA_STATEMENTS = (
"""
CREATE TABLE IF NOT EXISTS napcat_chat_checkpoint (
account_id TEXT NOT NULL,
scope TEXT NOT NULL,
chat_type TEXT NOT NULL,
chat_id TEXT NOT NULL,
last_message_id TEXT NOT NULL,
last_message_time REAL NOT NULL,
last_message_seq INTEGER,
updated_at REAL NOT NULL,
PRIMARY KEY (account_id, scope, chat_type, chat_id)
)
""",
"""
CREATE INDEX IF NOT EXISTS ix_napcat_chat_checkpoint_updated_at
ON napcat_chat_checkpoint (updated_at DESC)
""",
"""
CREATE TABLE IF NOT EXISTS napcat_recovery_seen (
account_id TEXT NOT NULL,
scope TEXT NOT NULL,
chat_type TEXT NOT NULL,
chat_id TEXT NOT NULL,
external_message_id TEXT NOT NULL,
seen_at REAL NOT NULL,
PRIMARY KEY (account_id, scope, chat_type, chat_id, external_message_id)
)
""",
"""
CREATE INDEX IF NOT EXISTS ix_napcat_recovery_seen_seen_at
ON napcat_recovery_seen (seen_at DESC)
""",
)
T = TypeVar("T")
@dataclass(frozen=True)
class NapCatChatCheckpoint:
    """Most recent inbound message anchor recorded for a single chat."""

    account_id: str
    scope: str
    chat_type: str
    chat_id: str
    last_message_id: str
    last_message_time: float
    # None whenever the stored value is not an integer (e.g. SQL NULL).
    last_message_seq: int | None
    updated_at: float

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "NapCatChatCheckpoint":
        """Rebuild a checkpoint from a SQLite row.

        Text columns are stringified and stripped, with falsy values becoming
        ``""``.  Float columns fall back to ``0.0`` for falsy values.  The
        sequence column is kept only when it is an ``int``.

        Args:
            row: Row exposing the eight checkpoint columns by name.

        Returns:
            The immutable checkpoint described by ``row``.
        """

        def _text(column: str) -> str:
            # Falsy values (NULL, empty string, 0) collapse to "".
            return str(row[column] or "").strip()

        def _number(column: str) -> float:
            # Falsy values (NULL, 0) collapse to 0.0.
            return float(row[column] or 0.0)

        raw_seq = row["last_message_seq"]
        seq = int(raw_seq) if isinstance(raw_seq, int) else None

        return cls(
            account_id=_text("account_id"),
            scope=_text("scope"),
            chat_type=_text("chat_type"),
            chat_id=_text("chat_id"),
            last_message_id=_text("last_message_id"),
            last_message_time=_number("last_message_time"),
            last_message_seq=seq,
            updated_at=_number("updated_at"),
        )
class NapCatHistoryRecoveryStore:
    """Persist the chat checkpoints and de-duplication marks used by history recovery.

    State lives in a SQLite file.  Every public coroutine opens a fresh
    connection, runs one operation, commits, and closes the connection again,
    all while holding a single ``asyncio.Lock``.  The SQLite calls run
    synchronously inside the coroutine; nothing is offloaded to an executor.
    """

    def __init__(self, logger: Any, storage_path: Path = _DEFAULT_STORAGE_PATH) -> None:
        """Create the store without touching the file system.

        Args:
            logger: Object exposing ``debug``; used by ``load`` to report pruning.
            storage_path: Location of the SQLite database file.
        """
        self._logger = logger
        self._storage_path = storage_path
        self._store_lock = asyncio.Lock()
        self._schema_ready = False

    async def load(self) -> None:
        """Create the schema if needed and prune expired de-duplication rows."""
        await self._execute_locked(self._ensure_schema)
        pruned_count = await self.prune_recovery_seen(DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC)
        if pruned_count > 0:
            self._logger.debug(f"NapCat 历史补拉去重表已清理 {pruned_count} 条过期记录")

    async def list_checkpoints(self, account_id: str, scope: str = "", limit: int = 50) -> List[NapCatChatCheckpoint]:
        """List the most recently updated checkpoints of one account and scope.

        Args:
            account_id: Account whose checkpoints are wanted; blank yields ``[]``.
            scope: Scope filter; blank values are matched as the empty string.
            limit: Maximum number of rows; values below 1 are raised to 1.

        Returns:
            Checkpoints ordered by ``updated_at`` descending.
        """
        normalized_account_id = self._normalize_text(account_id)
        if not normalized_account_id:
            return []
        normalized_scope = self._normalize_scope(scope)
        normalized_limit = max(1, int(limit))

        def _operation(conn: sqlite3.Connection) -> List[NapCatChatCheckpoint]:
            cursor = conn.execute(
                """
SELECT
account_id,
scope,
chat_type,
chat_id,
last_message_id,
last_message_time,
last_message_seq,
updated_at
FROM napcat_chat_checkpoint
WHERE account_id = ? AND scope = ?
ORDER BY updated_at DESC
LIMIT ?
""",
                (normalized_account_id, normalized_scope, normalized_limit),
            )
            return [NapCatChatCheckpoint.from_row(row) for row in cursor.fetchall()]

        return await self._execute_locked(_operation)

    async def record_checkpoint(
        self,
        *,
        account_id: str,
        scope: str = "",
        chat_type: str,
        chat_id: str,
        message_id: str,
        message_time: float,
        message_seq: int | None = None,
    ) -> None:
        """Record the latest inbound message anchor accepted by the Host.

        The call is a silent no-op when the account id, chat type, chat id or
        message id is blank, and also when the stored checkpoint is judged at
        least as recent as the new one (see ``_should_advance_checkpoint``).

        Args:
            account_id: Account that received the message.
            scope: Optional scope; blank values are stored as ``""``.
            chat_type: Chat type component of the checkpoint key.
            chat_id: Chat id component of the checkpoint key.
            message_id: Identifier of the message that becomes the anchor.
            message_time: Message timestamp; falsy values are stored as ``0.0``.
            message_seq: Optional sequence number; unusable values become ``None``.
        """
        key = self._normalize_message_key(account_id, scope, chat_type, chat_id, message_id)
        if key is None:
            return
        chat_key = key[:4]
        normalized_message_id = key[4]
        normalized_message_time = float(message_time or 0.0)
        normalized_message_seq = self._normalize_message_seq(message_seq)
        updated_at = time.time()

        def _operation(conn: sqlite3.Connection) -> None:
            cursor = conn.execute(
                """
SELECT last_message_id, last_message_time, last_message_seq
FROM napcat_chat_checkpoint
WHERE account_id = ? AND scope = ? AND chat_type = ? AND chat_id = ?
""",
                chat_key,
            )
            existing_row = cursor.fetchone()
            # Keep the stored anchor unless the new message is strictly newer.
            if existing_row is not None and not self._should_advance_checkpoint(
                existing_row=existing_row,
                message_id=normalized_message_id,
                message_time=normalized_message_time,
                message_seq=normalized_message_seq,
            ):
                return
            conn.execute(
                """
INSERT INTO napcat_chat_checkpoint (
account_id,
scope,
chat_type,
chat_id,
last_message_id,
last_message_time,
last_message_seq,
updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(account_id, scope, chat_type, chat_id) DO UPDATE SET
last_message_id = excluded.last_message_id,
last_message_time = excluded.last_message_time,
last_message_seq = excluded.last_message_seq,
updated_at = excluded.updated_at
""",
                (
                    *chat_key,
                    normalized_message_id,
                    normalized_message_time,
                    normalized_message_seq,
                    updated_at,
                ),
            )

        await self._execute_locked(_operation)

    async def has_recovered_message_seen(
        self,
        *,
        account_id: str,
        scope: str = "",
        chat_type: str,
        chat_id: str,
        external_message_id: str,
    ) -> bool:
        """Tell whether a recovered history message is already recorded as seen.

        Returns:
            ``True`` when a matching row exists; ``False`` otherwise, including
            when the account id, chat type, chat id or message id is blank.
        """
        key = self._normalize_message_key(
            account_id, scope, chat_type, chat_id, external_message_id
        )
        if key is None:
            return False

        def _operation(conn: sqlite3.Connection) -> bool:
            cursor = conn.execute(
                """
SELECT 1
FROM napcat_recovery_seen
WHERE account_id = ?
AND scope = ?
AND chat_type = ?
AND chat_id = ?
AND external_message_id = ?
LIMIT 1
""",
                key,
            )
            return cursor.fetchone() is not None

        return await self._execute_locked(_operation)

    async def mark_recovered_message_seen(
        self,
        *,
        account_id: str,
        scope: str = "",
        chat_type: str,
        chat_id: str,
        external_message_id: str,
    ) -> None:
        """Mark a recovered history message as having been attempted.

        An existing mark for the same key is replaced, which refreshes its
        ``seen_at`` timestamp.  The call is a silent no-op when the account id,
        chat type, chat id or message id is blank.
        """
        key = self._normalize_message_key(
            account_id, scope, chat_type, chat_id, external_message_id
        )
        if key is None:
            return

        def _operation(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
INSERT OR REPLACE INTO napcat_recovery_seen (
account_id,
scope,
chat_type,
chat_id,
external_message_id,
seen_at
) VALUES (?, ?, ?, ?, ?, ?)
""",
                (*key, time.time()),
            )

        await self._execute_locked(_operation)

    async def prune_recovery_seen(self, ttl_seconds: float) -> int:
        """Delete de-duplication rows older than ``ttl_seconds``.

        Args:
            ttl_seconds: Retention period; a falsy or non-positive value
                disables pruning.

        Returns:
            Number of deleted rows (``0`` when pruning is disabled).
        """
        normalized_ttl_seconds = max(0.0, float(ttl_seconds or 0.0))
        if normalized_ttl_seconds <= 0.0:
            return 0
        cutoff_timestamp = time.time() - normalized_ttl_seconds

        def _operation(conn: sqlite3.Connection) -> int:
            cursor = conn.execute(
                "DELETE FROM napcat_recovery_seen WHERE seen_at < ?",
                (cutoff_timestamp,),
            )
            return int(cursor.rowcount or 0)

        return await self._execute_locked(_operation)

    async def _execute_locked(self, operation: Callable[[sqlite3.Connection], T]) -> T:
        """Run ``operation`` on a fresh connection while holding the store lock.

        The schema check runs first (a no-op once the schema flag is set).
        The connection is committed when ``operation`` returns, rolled back
        when it raises an ``Exception`` (which is re-raised), and always closed.

        Args:
            operation: Callable receiving the open connection.

        Returns:
            Whatever ``operation`` returns.
        """
        async with self._store_lock:
            conn = self._open_connection()
            try:
                self._ensure_schema(conn)
                result = operation(conn)
                conn.commit()
                return result
            except Exception:
                conn.rollback()
                raise
            finally:
                conn.close()

    def _open_connection(self) -> sqlite3.Connection:
        """Open a connection with ``sqlite3.Row`` rows, creating the parent directory."""
        self._storage_path.parent.mkdir(parents=True, exist_ok=True)
        conn = sqlite3.connect(str(self._storage_path))
        conn.row_factory = sqlite3.Row
        return conn

    def _ensure_schema(self, conn: sqlite3.Connection) -> None:
        """Execute the schema statements once per store instance."""
        if self._schema_ready:
            return
        for statement in _SCHEMA_STATEMENTS:
            conn.execute(statement)
        self._schema_ready = True

    @staticmethod
    def _normalize_text(value: object) -> str:
        """Stringify and strip ``value``; falsy values become ``""``."""
        return str(value or "").strip()

    @staticmethod
    def _normalize_scope(scope: str | None) -> str:
        """Fold a missing or blank scope into the empty string."""
        return str(scope or "").strip()

    @classmethod
    def _normalize_message_key(
        cls,
        account_id: object,
        scope: object,
        chat_type: object,
        chat_id: object,
        message_id: object,
    ) -> tuple[str, str, str, str, str] | None:
        """Normalize the five key parts of a message.

        Returns:
            ``(account_id, scope, chat_type, chat_id, message_id)`` with every
            part stripped, or ``None`` when any part other than ``scope`` is
            blank (``scope`` is allowed to be empty).
        """
        normalized = (
            cls._normalize_text(account_id),
            cls._normalize_scope(scope),
            cls._normalize_text(chat_type),
            cls._normalize_text(chat_id),
            cls._normalize_text(message_id),
        )
        account, _scope, kind, chat, message = normalized
        if not (account and kind and chat and message):
            return None
        return normalized

    @staticmethod
    def _normalize_message_seq(message_seq: object) -> int | None:
        """Convert a message sequence number to ``int``, or ``None`` if unusable.

        ``None``, blank strings and values ``int()`` rejects all yield ``None``.
        That includes infinite floats, for which ``int()`` raises
        ``OverflowError``.
        """
        if message_seq is None:
            return None
        try:
            if str(message_seq).strip() == "":
                return None
            return int(message_seq)
        except (TypeError, ValueError, OverflowError):
            return None

    @classmethod
    def _should_advance_checkpoint(
        cls,
        *,
        existing_row: sqlite3.Row,
        message_id: str,
        message_time: float,
        message_seq: int | None,
    ) -> bool:
        """Decide whether a new anchor should replace the stored one.

        When both anchors carry a sequence number, the larger one wins.  On
        equal sequence numbers the same message id never advances, and a
        different id advances only if it is not older.  Otherwise the later
        timestamp wins, and on equal timestamps only a different message id
        advances.

        Args:
            existing_row: Stored row exposing ``last_message_id``,
                ``last_message_time`` and ``last_message_seq``.
            message_id: Normalized id of the candidate message.
            message_time: Timestamp of the candidate message.
            message_seq: Sequence number of the candidate message, if known.

        Returns:
            ``True`` when the stored checkpoint should be overwritten.
        """
        existing_message_id = str(existing_row["last_message_id"] or "").strip()
        existing_message_time = float(existing_row["last_message_time"] or 0.0)
        existing_message_seq = cls._normalize_message_seq(existing_row["last_message_seq"])

        if message_seq is not None and existing_message_seq is not None:
            if message_seq != existing_message_seq:
                return message_seq > existing_message_seq
            if message_id == existing_message_id:
                return False
            return message_time >= existing_message_time

        if message_time != existing_message_time:
            return message_time > existing_message_time

        # Same timestamp and no usable sequence pair: only a different
        # message is treated as progress.
        return message_id != existing_message_id