177 lines
6.3 KiB
Python
177 lines
6.3 KiB
Python
"""NapCat 群禁言状态跟踪服务。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import time
|
|
|
|
from .ban_state_store import NapCatBanRecord, NapCatBanStateStore
|
|
from .query_service import NapCatQueryService
|
|
|
|
|
|
class NapCatBanTracker:
|
|
"""NapCat 群禁言状态跟踪器。"""
|
|
|
|
def __init__(
|
|
self,
|
|
logger: Any,
|
|
query_service: NapCatQueryService,
|
|
on_natural_lift: Callable[[Dict[str, Any]], Awaitable[None]],
|
|
state_store: NapCatBanStateStore,
|
|
) -> None:
|
|
"""初始化群禁言状态跟踪器。
|
|
|
|
Args:
|
|
logger: 插件日志对象。
|
|
query_service: NapCat 查询服务。
|
|
on_natural_lift: 检测到自然解除禁言后的回调。
|
|
state_store: 禁言状态存储仓库。
|
|
"""
|
|
self._logger = logger
|
|
self._query_service = query_service
|
|
self._on_natural_lift = on_natural_lift
|
|
self._state_store = state_store
|
|
self._poll_task: Optional[asyncio.Task[None]] = None
|
|
|
|
async def start(self) -> None:
|
|
"""启动禁言状态跟踪。"""
|
|
await self._state_store.load()
|
|
await self._refresh_records_from_remote()
|
|
if self._poll_task is None or self._poll_task.done():
|
|
self._poll_task = asyncio.create_task(self._poll_loop(), name="napcat_adapter.ban_tracker")
|
|
|
|
async def stop(self) -> None:
|
|
"""停止禁言状态跟踪并落盘当前记录。"""
|
|
poll_task = self._poll_task
|
|
self._poll_task = None
|
|
if poll_task is not None:
|
|
poll_task.cancel()
|
|
with contextlib.suppress(asyncio.CancelledError):
|
|
await poll_task
|
|
await self._state_store.persist()
|
|
|
|
async def record_notice(self, payload: Mapping[str, Any]) -> None:
|
|
"""根据实际 notice 事件更新禁言状态。
|
|
|
|
Args:
|
|
payload: NapCat 推送的原始通知事件。
|
|
"""
|
|
notice_type = str(payload.get("notice_type") or "").strip()
|
|
if notice_type != "group_ban":
|
|
return
|
|
|
|
sub_type = str(payload.get("sub_type") or "").strip()
|
|
group_id = str(payload.get("group_id") or "").strip()
|
|
user_id = str(payload.get("user_id") or "0").strip() or "0"
|
|
if not group_id:
|
|
return
|
|
|
|
if sub_type == "ban":
|
|
duration = self._normalize_int(payload.get("duration"), default=-1)
|
|
lift_time = -1 if user_id == "0" or duration <= 0 else int(time.time()) + duration
|
|
await self._state_store.upsert(NapCatBanRecord(group_id=group_id, user_id=user_id, lift_time=lift_time))
|
|
return
|
|
|
|
if sub_type in {"lift_ban", "whole_lift_ban"}:
|
|
await self._state_store.remove(group_id=group_id, user_id=user_id)
|
|
|
|
async def _refresh_records_from_remote(self) -> None:
|
|
"""基于当前 QQ 平台状态校正本地禁言记录。"""
|
|
for record in await self._state_store.snapshot():
|
|
if record.user_id == "0":
|
|
await self._refresh_whole_ban_record(record)
|
|
continue
|
|
await self._refresh_member_ban_record(record)
|
|
|
|
async def _refresh_whole_ban_record(self, record: NapCatBanRecord) -> None:
|
|
"""刷新全体禁言记录。
|
|
|
|
Args:
|
|
record: 待刷新的禁言记录。
|
|
"""
|
|
group_info = await self._query_service.get_group_info(record.group_id)
|
|
if group_info is None:
|
|
await self._emit_natural_lift(record)
|
|
return
|
|
|
|
group_all_shut = self._normalize_int(group_info.get("group_all_shut"), default=0)
|
|
if group_all_shut == 0:
|
|
await self._emit_natural_lift(record)
|
|
|
|
async def _refresh_member_ban_record(self, record: NapCatBanRecord) -> None:
|
|
"""刷新成员禁言记录。
|
|
|
|
Args:
|
|
record: 待刷新的禁言记录。
|
|
"""
|
|
member_info = await self._query_service.get_group_member_info(record.group_id, record.user_id, no_cache=True)
|
|
if member_info is None:
|
|
await self._emit_natural_lift(record)
|
|
return
|
|
|
|
shut_up_timestamp = self._normalize_int(member_info.get("shut_up_timestamp"), default=0)
|
|
if shut_up_timestamp == 0:
|
|
await self._emit_natural_lift(record)
|
|
return
|
|
|
|
if shut_up_timestamp != record.lift_time:
|
|
await self._state_store.upsert(
|
|
NapCatBanRecord(group_id=record.group_id, user_id=record.user_id, lift_time=shut_up_timestamp)
|
|
)
|
|
|
|
async def _poll_loop(self) -> None:
|
|
"""后台轮询自然解除禁言。"""
|
|
while True:
|
|
await asyncio.sleep(5.0)
|
|
current_timestamp = int(time.time())
|
|
for record in await self._state_store.snapshot():
|
|
if record.user_id == "0":
|
|
await self._refresh_whole_ban_record(record)
|
|
continue
|
|
if record.lift_time != -1 and record.lift_time <= current_timestamp:
|
|
await self._emit_natural_lift(record)
|
|
|
|
async def _emit_natural_lift(self, record: NapCatBanRecord) -> None:
|
|
"""上报自然解除禁言事件。
|
|
|
|
Args:
|
|
record: 已解除的禁言记录。
|
|
"""
|
|
removed_record = await self._state_store.pop(record.record_key)
|
|
if removed_record is None:
|
|
return
|
|
|
|
payload: Dict[str, Any] = {
|
|
"post_type": "notice",
|
|
"notice_type": "group_ban",
|
|
"sub_type": "whole_lift_ban" if record.user_id == "0" else "lift_ban",
|
|
"group_id": record.group_id,
|
|
"user_id": record.user_id,
|
|
"operator_id": None,
|
|
"time": time.time(),
|
|
"is_natural_lift": True,
|
|
}
|
|
try:
|
|
await self._on_natural_lift(payload)
|
|
except Exception as exc:
|
|
self._logger.warning(f"NapCat 自然解除禁言回调失败: {exc}")
|
|
|
|
@staticmethod
|
|
def _normalize_int(value: Any, default: int) -> int:
|
|
"""将任意值规范化为整数。
|
|
|
|
Args:
|
|
value: 待规范化的值。
|
|
default: 转换失败时的默认值。
|
|
|
|
Returns:
|
|
int: 规范化后的整数结果。
|
|
"""
|
|
try:
|
|
return int(value)
|
|
except (TypeError, ValueError):
|
|
return default
|