424 lines
14 KiB
Python
424 lines
14 KiB
Python
"""NapCat 历史补拉状态持久化仓库。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any, Callable, List, Optional, TypeVar
|
|
|
|
import asyncio
|
|
import sqlite3
|
|
import time
|
|
|
|
from ..constants import DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC
|
|
|
|
# Project root, taken as the fourth ancestor directory of this file
# (``parents[3]``). NOTE(review): this depends on the package layout; confirm
# it still points at the project root if this module is moved.
_PROJECT_ROOT = Path(__file__).resolve().parents[3]

# SQLite file used by NapCatHistoryRecoveryStore when no storage_path is given.
_DEFAULT_STORAGE_PATH = _PROJECT_ROOT / "data" / "napcat_adapter" / "history_recovery.sqlite3"

# Idempotent (IF NOT EXISTS) DDL executed before the first store operation.
# - napcat_chat_checkpoint: latest accepted inbound anchor per chat, keyed by
#   (account_id, scope, chat_type, chat_id); indexed on updated_at so the most
#   recently touched chats can be listed first.
# - napcat_recovery_seen: recovered message ids that were already attempted,
#   with the time they were marked; indexed on seen_at for age-based pruning.
_SCHEMA_STATEMENTS = (
    """
    CREATE TABLE IF NOT EXISTS napcat_chat_checkpoint (
        account_id TEXT NOT NULL,
        scope TEXT NOT NULL,
        chat_type TEXT NOT NULL,
        chat_id TEXT NOT NULL,
        last_message_id TEXT NOT NULL,
        last_message_time REAL NOT NULL,
        last_message_seq INTEGER,
        updated_at REAL NOT NULL,
        PRIMARY KEY (account_id, scope, chat_type, chat_id)
    )
    """,
    """
    CREATE INDEX IF NOT EXISTS ix_napcat_chat_checkpoint_updated_at
    ON napcat_chat_checkpoint (updated_at DESC)
    """,
    """
    CREATE TABLE IF NOT EXISTS napcat_recovery_seen (
        account_id TEXT NOT NULL,
        scope TEXT NOT NULL,
        chat_type TEXT NOT NULL,
        chat_id TEXT NOT NULL,
        external_message_id TEXT NOT NULL,
        seen_at REAL NOT NULL,
        PRIMARY KEY (account_id, scope, chat_type, chat_id, external_message_id)
    )
    """,
    """
    CREATE INDEX IF NOT EXISTS ix_napcat_recovery_seen_seen_at
    ON napcat_recovery_seen (seen_at DESC)
    """,
)

# Result type of a callable run by the store's locked SQLite executor.
T = TypeVar("T")
|
|
|
|
|
|
@dataclass(frozen=True)
class NapCatChatCheckpoint:
    """Most recent inbound anchor recorded for a single chat.

    Fields mirror the columns of the ``napcat_chat_checkpoint`` table.
    """

    account_id: str  # account that owns the chat
    scope: str  # optional partition key; "" when no scope is used
    chat_type: str
    chat_id: str
    last_message_id: str  # id of the anchor message
    last_message_time: float  # anchor message time as supplied by the recorder
    last_message_seq: int | None  # anchor sequence number, None when unknown
    updated_at: float  # wall-clock time the row was last written

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "NapCatChatCheckpoint":
        """Rebuild a checkpoint from a ``napcat_chat_checkpoint`` row.

        Text columns become stripped strings (NULL -> ""), the two time
        columns become floats (NULL -> 0.0), and the sequence column is kept
        only when the stored value is an ``int``; anything else maps to None.
        """

        def text(column: str) -> str:
            return str(row[column] or "").strip()

        def number(column: str) -> float:
            return float(row[column] or 0.0)

        raw_seq = row["last_message_seq"]
        seq = int(raw_seq) if isinstance(raw_seq, int) else None

        return cls(
            account_id=text("account_id"),
            scope=text("scope"),
            chat_type=text("chat_type"),
            chat_id=text("chat_id"),
            last_message_id=text("last_message_id"),
            last_message_time=number("last_message_time"),
            last_message_seq=seq,
            updated_at=number("updated_at"),
        )
|
|
|
|
|
|
class NapCatHistoryRecoveryStore:
    """Persist per-chat checkpoints and de-duplication state for history recovery.

    Every public coroutine opens a fresh SQLite connection, runs a single
    operation, commits and closes the connection while holding an
    ``asyncio.Lock``. Operations issued through one instance therefore never
    interleave.
    """

    def __init__(self, logger: Any, storage_path: Path = _DEFAULT_STORAGE_PATH) -> None:
        """Initialize the store without touching the database.

        Args:
            logger: Logger-like object; only ``logger.debug`` is called (from
                :meth:`load`).
            storage_path: Location of the SQLite file. Its parent directory is
                created on demand whenever a connection is opened.
        """

        self._logger = logger
        self._storage_path = storage_path
        # Serializes every database operation issued through this instance.
        self._store_lock = asyncio.Lock()
        # Becomes True once all schema statements have run successfully, so
        # later connections skip the DDL.
        self._schema_ready = False

    async def load(self) -> None:
        """Create the SQLite file and schema, then prune expired de-dup records."""

        await self._execute_locked(self._ensure_schema)
        pruned_count = await self.prune_recovery_seen(DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC)
        if pruned_count > 0:
            self._logger.debug(f"NapCat 历史补拉去重表已清理 {pruned_count} 条过期记录")

    async def list_checkpoints(self, account_id: str, scope: str = "", limit: int = 50) -> List[NapCatChatCheckpoint]:
        """List the most recently updated checkpoints of one account and scope.

        Args:
            account_id: Account to query. A blank value returns ``[]`` without
                opening the database.
            scope: Optional scope; ``None`` or blank is treated as ``""``.
            limit: Maximum number of rows; values below 1 are raised to 1.

        Returns:
            Checkpoints ordered by ``updated_at`` descending.
        """

        normalized_account_id = self._normalize_text(account_id)
        if not normalized_account_id:
            return []

        query_params = (
            normalized_account_id,
            self._normalize_scope(scope),
            max(1, int(limit)),
        )

        def _operation(conn: sqlite3.Connection) -> List[NapCatChatCheckpoint]:
            cursor = conn.execute(
                """
                SELECT
                    account_id,
                    scope,
                    chat_type,
                    chat_id,
                    last_message_id,
                    last_message_time,
                    last_message_seq,
                    updated_at
                FROM napcat_chat_checkpoint
                WHERE account_id = ? AND scope = ?
                ORDER BY updated_at DESC
                LIMIT ?
                """,
                query_params,
            )
            return [NapCatChatCheckpoint.from_row(row) for row in cursor.fetchall()]

        return await self._execute_locked(_operation)

    async def record_checkpoint(
        self,
        *,
        account_id: str,
        scope: str = "",
        chat_type: str,
        chat_id: str,
        message_id: str,
        message_time: float,
        message_seq: int | None = None,
    ) -> None:
        """Record the latest inbound message accepted by the Host as the chat anchor.

        An existing anchor is replaced only when :meth:`_should_advance_checkpoint`
        ranks the new message later. The call is silently ignored when the
        account id, chat type, chat id or message id is blank.

        Args:
            account_id: Account owning the chat.
            scope: Optional scope; ``None`` or blank is stored as ``""``.
            chat_type: Chat type component of the key.
            chat_id: Chat id component of the key.
            message_id: Id of the accepted message.
            message_time: Message time; falsy values are stored as ``0.0``.
            message_seq: Optional sequence number; unparsable values become None.
        """

        chat_key = self._normalize_chat_key(account_id, scope, chat_type, chat_id)
        normalized_message_id = self._normalize_text(message_id)
        if not self._is_complete_key(chat_key, normalized_message_id):
            return

        normalized_message_time = float(message_time or 0.0)
        normalized_message_seq = self._normalize_message_seq(message_seq)
        updated_at = time.time()

        def _operation(conn: sqlite3.Connection) -> None:
            existing_row = conn.execute(
                """
                SELECT last_message_id, last_message_time, last_message_seq
                FROM napcat_chat_checkpoint
                WHERE account_id = ? AND scope = ? AND chat_type = ? AND chat_id = ?
                """,
                chat_key,
            ).fetchone()
            if existing_row is not None and not self._should_advance_checkpoint(
                existing_row=existing_row,
                message_id=normalized_message_id,
                message_time=normalized_message_time,
                message_seq=normalized_message_seq,
            ):
                return

            conn.execute(
                """
                INSERT INTO napcat_chat_checkpoint (
                    account_id,
                    scope,
                    chat_type,
                    chat_id,
                    last_message_id,
                    last_message_time,
                    last_message_seq,
                    updated_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(account_id, scope, chat_type, chat_id) DO UPDATE SET
                    last_message_id = excluded.last_message_id,
                    last_message_time = excluded.last_message_time,
                    last_message_seq = excluded.last_message_seq,
                    updated_at = excluded.updated_at
                """,
                (
                    *chat_key,
                    normalized_message_id,
                    normalized_message_time,
                    normalized_message_seq,
                    updated_at,
                ),
            )

        await self._execute_locked(_operation)

    async def has_recovered_message_seen(
        self,
        *,
        account_id: str,
        scope: str = "",
        chat_type: str,
        chat_id: str,
        external_message_id: str,
    ) -> bool:
        """Return whether a recovered message has already been marked as seen.

        Rows count until they are removed by :meth:`prune_recovery_seen`; the
        lookup itself does not look at ``seen_at``. Returns False without
        querying when any required key part or the message id is blank.
        """

        chat_key = self._normalize_chat_key(account_id, scope, chat_type, chat_id)
        normalized_external_message_id = self._normalize_text(external_message_id)
        if not self._is_complete_key(chat_key, normalized_external_message_id):
            return False

        def _operation(conn: sqlite3.Connection) -> bool:
            cursor = conn.execute(
                """
                SELECT 1
                FROM napcat_recovery_seen
                WHERE account_id = ?
                  AND scope = ?
                  AND chat_type = ?
                  AND chat_id = ?
                  AND external_message_id = ?
                LIMIT 1
                """,
                (*chat_key, normalized_external_message_id),
            )
            return cursor.fetchone() is not None

        return await self._execute_locked(_operation)

    async def mark_recovered_message_seen(
        self,
        *,
        account_id: str,
        scope: str = "",
        chat_type: str,
        chat_id: str,
        external_message_id: str,
    ) -> None:
        """Mark a recovered message as attempted.

        Re-marking an already seen message refreshes its ``seen_at`` timestamp
        (INSERT OR REPLACE). Ignored when any required key part or the message
        id is blank.
        """

        chat_key = self._normalize_chat_key(account_id, scope, chat_type, chat_id)
        normalized_external_message_id = self._normalize_text(external_message_id)
        if not self._is_complete_key(chat_key, normalized_external_message_id):
            return

        def _operation(conn: sqlite3.Connection) -> None:
            conn.execute(
                """
                INSERT OR REPLACE INTO napcat_recovery_seen (
                    account_id,
                    scope,
                    chat_type,
                    chat_id,
                    external_message_id,
                    seen_at
                ) VALUES (?, ?, ?, ?, ?, ?)
                """,
                (*chat_key, normalized_external_message_id, time.time()),
            )

        await self._execute_locked(_operation)

    async def prune_recovery_seen(self, ttl_seconds: float) -> int:
        """Delete de-dup records whose ``seen_at`` is older than ``ttl_seconds``.

        Args:
            ttl_seconds: Retention window in seconds. Falsy or non-positive
                values disable pruning.

        Returns:
            Number of deleted rows (0 when pruning is disabled).
        """

        normalized_ttl_seconds = max(0.0, float(ttl_seconds or 0.0))
        if normalized_ttl_seconds <= 0.0:
            return 0

        cutoff_timestamp = time.time() - normalized_ttl_seconds

        def _operation(conn: sqlite3.Connection) -> int:
            cursor = conn.execute(
                "DELETE FROM napcat_recovery_seen WHERE seen_at < ?",
                (cutoff_timestamp,),
            )
            return int(cursor.rowcount or 0)

        return await self._execute_locked(_operation)

    async def _execute_locked(self, operation: Callable[[sqlite3.Connection], T]) -> T:
        """Run ``operation`` on a fresh connection under the store lock.

        The schema is ensured first, then the operation runs and is committed
        on success. On any exception the transaction is rolled back and the
        exception re-raised. The connection is always closed.

        NOTE(review): the sqlite3 calls here are synchronous and execute on the
        event-loop thread; consider offloading (e.g. ``asyncio.to_thread``) if
        database latency ever stalls the loop.
        """

        async with self._store_lock:
            conn = self._open_connection()
            try:
                self._ensure_schema(conn)
                result = operation(conn)
                conn.commit()
                return result
            except Exception:
                conn.rollback()
                raise
            finally:
                conn.close()

    def _open_connection(self) -> sqlite3.Connection:
        """Open the SQLite file (creating its directory) with ``sqlite3.Row`` rows."""

        self._storage_path.parent.mkdir(parents=True, exist_ok=True)
        conn = sqlite3.connect(str(self._storage_path))
        conn.row_factory = sqlite3.Row
        return conn

    def _ensure_schema(self, conn: sqlite3.Connection) -> None:
        """Run the schema statements once per store instance."""

        if self._schema_ready:
            return

        for statement in _SCHEMA_STATEMENTS:
            conn.execute(statement)
        # Only reached when every statement succeeded.
        self._schema_ready = True

    @staticmethod
    def _normalize_text(value: object) -> str:
        """Coerce a key component to a stripped string; falsy values become ``""``."""

        return str(value or "").strip()

    @staticmethod
    def _normalize_scope(scope: str | None) -> str:
        """Collapse an empty or missing scope to ``""``."""

        return str(scope or "").strip()

    @classmethod
    def _normalize_chat_key(
        cls,
        account_id: object,
        scope: object,
        chat_type: object,
        chat_id: object,
    ) -> tuple[str, str, str, str]:
        """Normalize the four columns that identify a chat, in table column order."""

        return (
            cls._normalize_text(account_id),
            cls._normalize_scope(scope),  # type: ignore[arg-type]
            cls._normalize_text(chat_type),
            cls._normalize_text(chat_id),
        )

    @staticmethod
    def _is_complete_key(chat_key: tuple[str, str, str, str], item_id: str) -> bool:
        """Return True when account, chat type, chat id and ``item_id`` are all non-empty.

        The scope component is allowed to be empty.
        """

        account_id, _scope, chat_type, chat_id = chat_key
        return bool(account_id and chat_type and chat_id and item_id)

    @staticmethod
    def _normalize_message_seq(message_seq: object) -> int | None:
        """Return ``int(message_seq)``, or None when it is None, blank or not convertible."""

        try:
            if message_seq is None or str(message_seq).strip() == "":
                return None
            return int(message_seq)
        except (TypeError, ValueError):
            return None

    @classmethod
    def _should_advance_checkpoint(
        cls,
        *,
        existing_row: sqlite3.Row,
        message_id: str,
        message_time: float,
        message_seq: int | None,
    ) -> bool:
        """Decide whether a new anchor should replace the stored one.

        Ordering rules:
        * both anchors have a sequence number -> the larger sequence wins; on
          an equal sequence a different message wins if its time is not older;
        * otherwise -> the later ``message_time`` wins;
        * on an exact tie the same message never replaces itself, while a
          different message always does.
        """

        existing_message_id = str(existing_row["last_message_id"] or "").strip()
        existing_message_time = float(existing_row["last_message_time"] or 0.0)
        existing_message_seq = cls._normalize_message_seq(existing_row["last_message_seq"])

        if message_seq is not None and existing_message_seq is not None:
            if message_seq != existing_message_seq:
                return message_seq > existing_message_seq
            if message_id == existing_message_id:
                return False
            return message_time >= existing_message_time

        if message_time != existing_message_time:
            return message_time > existing_message_time

        # NOTE(review): on a time tie a seq-less message also replaces a
        # seq-bearing anchor (newest write wins); confirm this is intended.
        return message_id != existing_message_id