feat(config): harden file watcher hot-reload flow and add test coverage

refactor FileWatcher to subscription-based dispatch with path/change filters
add callback timeout, failure cooldown, auto-retry loop, and runtime stats
strengthen ConfigManager hot-reload with throttling, timeout guard, and watcher stats logging
add pytest suites for watcher behavior and config hot-reload edge cases
This commit is contained in:
DrSmoothl
2026-03-04 21:39:26 +08:00
parent 5cccdf6715
commit b3a81754e6
4 changed files with 419 additions and 20 deletions

View File

@@ -0,0 +1,104 @@
from pathlib import Path
from watchfiles import Change
import asyncio
import pytest
from src.config.config import ConfigManager
from src.config.file_watcher import FileChange, FileWatcherStats
@pytest.mark.asyncio
async def test_handle_file_changes_throttles_reload():
    """A second change batch inside the throttle window must not reload again."""
    reload_calls: list[int] = []

    async def reload_stub() -> bool:
        reload_calls.append(1)
        return True

    manager = ConfigManager()
    # A very large minimum interval guarantees the second batch is throttled.
    manager._hot_reload_min_interval_s = 100.0
    manager.reload_config = reload_stub  # type: ignore[method-assign]

    batch = [FileChange(change_type=Change.modified, path=Path("/tmp/bot_config.toml"))]
    for _ in range(2):
        await manager._handle_file_changes(batch)

    assert len(reload_calls) == 1
@pytest.mark.asyncio
async def test_handle_file_changes_timeout_logged(caplog):
    """A reload that outlives the hot-reload timeout is reported at ERROR level."""

    async def reload_stub() -> bool:
        # Sleeps longer than the 0.01 s timeout configured below.
        await asyncio.sleep(0.05)
        return True

    manager = ConfigManager()
    manager._hot_reload_min_interval_s = 0.0  # no throttling in this test
    manager._hot_reload_timeout_s = 0.01
    manager.reload_config = reload_stub  # type: ignore[method-assign]

    changed_file = Path("/tmp/model_config.toml")
    batch = [FileChange(change_type=Change.modified, path=changed_file)]
    with caplog.at_level("ERROR"):
        await manager._handle_file_changes(batch)

    assert "配置热重载超时" in caplog.text
@pytest.mark.asyncio
async def test_handle_file_changes_empty_skips_reload():
    """An empty change batch returns early without calling reload_config."""
    reload_calls: list[int] = []

    async def reload_stub() -> bool:
        reload_calls.append(1)
        return True

    manager = ConfigManager()
    manager.reload_config = reload_stub  # type: ignore[method-assign]

    await manager._handle_file_changes([])

    assert len(reload_calls) == 0
class _FakeWatcher:
    """Test double that records the calls ConfigManager.stop_file_watcher makes."""

    def __init__(self):
        self.stop_called = False
        self.unsubscribe_called_with: str | None = None
        # Each counter gets a distinct value so a mix-up between fields is visible.
        counters = {
            "batches_seen": 1,
            "changes_seen": 2,
            "callbacks_succeeded": 3,
            "callbacks_failed": 4,
            "callbacks_timed_out": 5,
            "callbacks_skipped_cooldown": 6,
            "restart_count": 7,
        }
        self.stats = FileWatcherStats(**counters)

    def unsubscribe(self, subscription_id: str) -> bool:
        """Remember which subscription id was removed and report success."""
        self.unsubscribe_called_with = subscription_id
        return True

    async def stop(self) -> None:
        """Remember that the watcher was asked to stop."""
        self.stop_called = True
@pytest.mark.asyncio
async def test_stop_file_watcher_cleans_state():
    """Stopping unsubscribes, stops the watcher, and clears both manager fields."""
    watcher_double = _FakeWatcher()
    manager = ConfigManager()
    manager._file_watcher_subscription_id = "sub-1"
    manager._file_watcher = watcher_double  # type: ignore[assignment]

    await manager.stop_file_watcher()

    # The watcher saw both calls ...
    assert watcher_double.unsubscribe_called_with == "sub-1"
    assert watcher_double.stop_called is True
    # ... and the manager no longer references it.
    assert manager._file_watcher is None
    assert manager._file_watcher_subscription_id is None

View File

@@ -0,0 +1,105 @@
from pathlib import Path
from watchfiles import Change
import asyncio
import pytest
from src.config.file_watcher import FileChange, FileWatcher
@pytest.mark.asyncio
async def test_dispatch_changes_with_path_and_change_type_filters(tmp_path: Path):
    """Only changes matching both the path and the change-type filter are delivered."""
    target_file = (tmp_path / "bot_config.toml").resolve()
    unrelated_file = (tmp_path / "other.toml").resolve()
    received: list[list[FileChange]] = []

    async def callback(changes):
        received.append(list(changes))

    watcher = FileWatcher(paths=[tmp_path])
    watcher.subscribe(callback, paths=[target_file], change_types=[Change.modified])

    batch = [
        FileChange(change_type=Change.added, path=target_file),  # wrong change type
        FileChange(change_type=Change.modified, path=target_file),  # the only match
        FileChange(change_type=Change.modified, path=unrelated_file),  # wrong path
    ]
    await watcher._dispatch_changes(batch)

    assert len(received) == 1
    delivered = received[0]
    assert len(delivered) == 1
    assert delivered[0].change_type == Change.modified
    assert delivered[0].path == target_file
@pytest.mark.asyncio
async def test_sync_callback_supported(tmp_path: Path):
    """A plain (non-async) callback receives the matched changes."""
    target_file = (tmp_path / "model_config.toml").resolve()
    received_paths: list[Path] = []

    def sync_callback(changes):
        for change in changes:
            received_paths.append(change.path)

    watcher = FileWatcher(paths=[tmp_path])
    watcher.subscribe(sync_callback, paths=[target_file])

    batch = [FileChange(change_type=Change.modified, path=target_file)]
    await watcher._dispatch_changes(batch)

    assert received_paths == [target_file]
@pytest.mark.asyncio
async def test_callback_timeout_and_cooldown(tmp_path: Path):
    """Two consecutive timeouts reach the threshold; the next dispatch is skipped."""
    target_file = (tmp_path / "bot_config.toml").resolve()

    async def slow_callback(changes):
        # Four times the callback timeout configured below.
        await asyncio.sleep(0.2)

    watcher = FileWatcher(
        paths=[tmp_path],
        callback_timeout_s=0.05,
        callback_failure_threshold=2,
        callback_cooldown_s=0.2,
    )
    watcher.subscribe(slow_callback, paths=[target_file])

    def modification_batch() -> list[FileChange]:
        return [FileChange(change_type=Change.modified, path=target_file)]

    for _ in range(2):
        await watcher._dispatch_changes(modification_batch())

    stats_after_failures = watcher.stats
    assert stats_after_failures.callbacks_timed_out == 2
    assert stats_after_failures.callbacks_failed == 2

    # The subscription is now cooling down, so this batch must be skipped.
    await watcher._dispatch_changes(modification_batch())

    stats_after_cooldown_skip = watcher.stats
    assert stats_after_cooldown_skip.callbacks_skipped_cooldown >= 1
@pytest.mark.asyncio
async def test_start_requires_subscription(tmp_path: Path):
    """Starting a watcher that has no subscriptions raises RuntimeError."""
    watcher_without_subscribers = FileWatcher(paths=[tmp_path])

    with pytest.raises(RuntimeError):
        await watcher_without_subscribers.start()
@pytest.mark.asyncio
async def test_unsubscribe_stops_dispatch(tmp_path: Path):
    """After unsubscribe, matching changes no longer reach the callback."""
    target_file = (tmp_path / "bot_config.toml").resolve()
    invocations: list[int] = []

    async def callback(changes):
        invocations.append(1)

    watcher = FileWatcher(paths=[tmp_path])
    subscription_id = watcher.subscribe(callback, paths=[target_file])
    assert watcher.unsubscribe(subscription_id) is True

    batch = [FileChange(change_type=Change.modified, path=target_file)]
    await watcher._dispatch_changes(batch)

    assert len(invocations) == 0

View File

@@ -182,6 +182,10 @@ class ConfigManager:
self._reload_lock: asyncio.Lock = asyncio.Lock()
self._reload_callbacks: list[Callable[[], object]] = []
self._file_watcher: FileWatcher | None = None
self._file_watcher_subscription_id: str | None = None
self._hot_reload_min_interval_s: float = 1.0
self._hot_reload_timeout_s: float = 20.0
self._last_hot_reload_monotonic: float = 0.0
def initialize(self):
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
@@ -261,21 +265,53 @@ class ConfigManager:
async def start_file_watcher(self) -> None:
    """Create the config-file watcher, subscribe the reload handler, and start it.

    Does nothing when a watcher already exists and reports itself as running.
    Both the watcher and the subscription are limited to the bot and model
    config paths.
    """
    if self._file_watcher is not None and self._file_watcher.running:
        return

    watched_paths = [self.bot_config_path, self.model_config_path]
    # The watcher-level callback timeout must outlast the hot-reload timeout
    # applied inside _handle_file_changes. If the outer limit were shorter, it
    # would cancel the handler first and the handler's own timeout log could
    # never be emitted.
    callback_timeout_s = max(15.0, self._hot_reload_timeout_s + 5.0)
    self._file_watcher = FileWatcher(
        paths=watched_paths,
        debounce_ms=600,
        callback_timeout_s=callback_timeout_s,
        callback_failure_threshold=3,
        callback_cooldown_s=30.0,
    )
    self._file_watcher_subscription_id = self._file_watcher.subscribe(
        self._handle_file_changes,
        paths=watched_paths,
    )
    await self._file_watcher.start()
    logger.info("配置文件监视器已启动")
async def stop_file_watcher(self) -> None:
    """Unsubscribe the reload handler, log the watcher's counters, and stop it.

    Does nothing when no watcher is set. On completion both the watcher
    reference and the stored subscription id are cleared.
    """
    if self._file_watcher is None:
        return

    subscription_id = self._file_watcher_subscription_id
    if subscription_id is not None:
        self._file_watcher.unsubscribe(subscription_id)
        self._file_watcher_subscription_id = None

    # Read the counters before stopping so the summary reflects the final run.
    snapshot = self._file_watcher.stats
    summary = (
        "配置文件监视器停止统计: "
        f"batches={snapshot.batches_seen}, "
        f"changes={snapshot.changes_seen}, "
        f"ok={snapshot.callbacks_succeeded}, "
        f"failed={snapshot.callbacks_failed}, "
        f"timeout={snapshot.callbacks_timed_out}, "
        f"cooldown_skip={snapshot.callbacks_skipped_cooldown}, "
        f"restart={snapshot.restart_count}"
    )
    logger.info(summary)

    await self._file_watcher.stop()
    self._file_watcher = None
async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None:
    """Trigger a throttled, time-limited config reload for a batch of changes.

    Args:
        changes: The file changes reported by the watcher. An empty batch is
            ignored.

    The reload is skipped when the previous one started less than
    ``self._hot_reload_min_interval_s`` seconds ago. A reload that exceeds
    ``self._hot_reload_timeout_s`` is cancelled by ``asyncio.wait_for`` and
    logged as an error; the timeout is not re-raised.
    """
    if not changes:
        return

    now_monotonic = asyncio.get_running_loop().time()
    last_reload = self._last_hot_reload_monotonic
    # 0.0 is the "never reloaded" start value. The loop clock has an arbitrary
    # reference point, so comparing against it directly could throttle the
    # very first reload whenever the clock reads less than the interval.
    if last_reload > 0.0 and now_monotonic - last_reload < self._hot_reload_min_interval_s:
        logger.debug("文件变更触发过于频繁,已跳过本次重载")
        return

    # Stamp before reloading so batches arriving during the reload are throttled too.
    self._last_hot_reload_monotonic = now_monotonic
    logger.info("检测到配置文件变更,触发热重载")
    try:
        # Reload exactly once, always under the timeout guard.
        await asyncio.wait_for(self.reload_config(), timeout=self._hot_reload_timeout_s)
    except asyncio.TimeoutError:
        logger.error(f"配置热重载超时(>{self._hot_reload_timeout_s}s)")
def generate_new_config_file(config_class: type[T], config_path: Path, inner_config_version: str) -> None:

View File

@@ -5,6 +5,7 @@ from typing import Awaitable, Callable, Iterable, Sequence
from watchfiles import Change, awatch
import asyncio
import uuid
from src.common.logger import get_logger
@@ -18,25 +19,105 @@ class FileChange:
path: Path
ChangeCallback = Callable[[Sequence[FileChange]], Awaitable[None]]
ChangeCallback = Callable[[Sequence[FileChange]], Awaitable[None] | None]
@dataclass(frozen=True)
class FileWatchSubscription:
    """Immutable record of one registered change callback and its filters."""

    # Key under which FileWatcher stores this subscription and its state.
    subscription_id: str
    # Callable invoked with the changes that pass both filters below.
    callback: ChangeCallback
    # Paths to match changes against; an empty tuple disables path filtering.
    paths: tuple[Path, ...]
    # Accepted change kinds; None disables change-type filtering.
    change_types: frozenset[Change] | None
@dataclass
class SubscriptionState:
    """Mutable failure bookkeeping kept per subscription by FileWatcher."""

    # Failures since the last successful callback or the last cooldown start.
    consecutive_failures: int = 0
    # Event-loop clock value until which dispatch to the subscription is skipped.
    cooldown_until_monotonic: float = 0.0
@dataclass
class FileWatcherStats:
    """Runtime counters maintained by FileWatcher."""

    # Non-empty change batches received by the watch loop.
    batches_seen: int = 0
    # Total number of individual changes across those batches.
    changes_seen: int = 0
    # Callback invocations that completed without timeout or exception.
    callbacks_succeeded: int = 0
    # Callback invocations that timed out or raised (timeouts are counted here too).
    callbacks_failed: int = 0
    # Callback invocations that exceeded the callback timeout.
    callbacks_timed_out: int = 0
    # Dispatches skipped because the subscription was cooling down.
    callbacks_skipped_cooldown: int = 0
    # Times the watch loop caught an exception and went round again.
    restart_count: int = 0
class FileWatcher:
def __init__(
    self,
    paths: Iterable[Path],
    debounce_ms: int = 600,
    callback_timeout_s: float = 10.0,
    callback_failure_threshold: int = 3,
    callback_cooldown_s: float = 30.0,
) -> None:
    """Configure a watcher; nothing is watched until ``start`` is called.

    Args:
        paths: Paths to watch. Each one is resolved immediately.
        debounce_ms: Debounce value passed to ``awatch``.
        callback_timeout_s: Time limit applied to each callback invocation.
        callback_failure_threshold: Consecutive failures after which a
            subscription enters cooldown.
        callback_cooldown_s: Length of that cooldown in seconds.
    """
    # Watch configuration.
    self._paths = [path.resolve() for path in paths]
    self._debounce_ms = debounce_ms
    # Callback failure policy.
    self._callback_timeout_s = callback_timeout_s
    self._callback_failure_threshold = callback_failure_threshold
    self._callback_cooldown_s = callback_cooldown_s
    # Runtime state.
    self._running = False
    self._task: asyncio.Task[None] | None = None
    self._subscriptions: dict[str, FileWatchSubscription] = {}
    self._subscription_states: dict[str, SubscriptionState] = {}
    self._stats = FileWatcherStats()
@property
def running(self) -> bool:
    """Return the watcher's running flag (set by ``start``)."""
    return self._running
@property
def stats(self) -> FileWatcherStats:
    """Return an independent snapshot of the runtime counters.

    The returned object is a copy, so callers can keep or modify it without
    affecting the watcher's live counters.
    """
    from dataclasses import replace

    # replace() with no overrides copies every field, so a counter added to
    # the stats dataclass later is included without touching this property.
    return replace(self._stats)
def subscribe(
    self,
    callback: ChangeCallback,
    *,
    paths: Iterable[Path] | None = None,
    change_types: Iterable[Change] | None = None,
) -> str:
    """Register a callback and return the id needed to unsubscribe it.

    Args:
        callback: Callable invoked with the changes that match the filters.
        paths: Optional path filter; each path is resolved. ``None`` means
            no path filtering.
        change_types: Optional set of accepted change kinds. ``None`` means
            no change-type filtering.

    Returns:
        A freshly generated UUID4 string identifying the subscription.

    Raises:
        TypeError: If ``callback`` is not callable.
    """
    if not callable(callback):
        raise TypeError("callback 必须是可调用对象")

    if paths is None:
        path_filter = ()
    else:
        path_filter = tuple(candidate.resolve() for candidate in paths)
    type_filter = None if change_types is None else frozenset(change_types)

    new_id = str(uuid.uuid4())
    record = FileWatchSubscription(
        subscription_id=new_id,
        callback=callback,
        paths=path_filter,
        change_types=type_filter,
    )
    self._subscriptions[new_id] = record
    # Every subscription gets its own failure/cooldown bookkeeping.
    self._subscription_states[new_id] = SubscriptionState()
    return new_id
def unsubscribe(self, subscription_id: str) -> bool:
    """Remove a subscription and its state.

    Returns:
        True when the id was registered, False otherwise.
    """
    subscription = self._subscriptions.pop(subscription_id, None)
    # The state entry is dropped even when the subscription itself was missing.
    self._subscription_states.pop(subscription_id, None)
    return subscription is not None
async def start(self) -> None:
    """Start the background watch task.

    Does nothing when the watcher is already running.

    Raises:
        RuntimeError: If no subscription has been registered yet.
    """
    if self._running:
        return
    if not self._subscriptions:
        raise RuntimeError("启动文件监视器前必须至少注册一个订阅")
    self._running = True
    # Callbacks are registered through subscribe(), so the watch loop takes
    # no arguments.
    self._task = asyncio.create_task(self._run())
async def stop(self) -> None:
if not self._running:
@@ -49,20 +130,93 @@ class FileWatcher:
await self._task
except asyncio.CancelledError:
return
finally:
self._task = None
async def _run(self) -> None:
    """Watch the configured paths and dispatch change batches until stopped.

    The loop keeps going while ``self._running`` is true. Cancellation ends
    it silently. Any other exception from the watch iterator is counted in
    ``restart_count``, logged, and followed by a one-second pause before the
    watch is started again.
    """
    while self._running:
        try:
            async for changes in awatch(*self._paths, debounce=self._debounce_ms):
                if not self._running:
                    break
                normalized_changes = self._normalize_changes(changes)
                if not normalized_changes:
                    continue
                self._stats.batches_seen += 1
                self._stats.changes_seen += len(normalized_changes)
                try:
                    await self._dispatch_changes(normalized_changes)
                except Exception as exc:
                    # A failing dispatch must not end the watch loop.
                    logger.warning(f"文件变更分发失败: {exc}")
        except asyncio.CancelledError:
            return
        except Exception as exc:
            self._stats.restart_count += 1
            logger.error(f"文件监视器运行异常将在1秒后重试: {exc}")
            # NOTE(review): nesting reconstructed from a flattened diff;
            # confirm the retry pause belongs to this except branch.
            if self._running:
                await asyncio.sleep(1.0)
async def _dispatch_changes(self, changes: Sequence[FileChange]) -> None:
    """Deliver a batch of changes to every subscription whose filters match.

    Subscriptions are handled one after another. Each callback runs under
    ``self._callback_timeout_s``. Timeouts and exceptions derived from
    ``Exception`` are counted and logged rather than re-raised.

    Args:
        changes: The changes to filter and hand to the callbacks.
    """
    # Iterate over a snapshot so subscriptions added or removed while a
    # callback is awaited cannot disturb the loop.
    for subscription in list(self._subscriptions.values()):
        matched_changes = self._match_changes(changes, subscription)
        if not matched_changes:
            continue
        state = self._subscription_states.get(subscription.subscription_id)
        if state is None:
            # No state entry: unsubscribe() removes it together with the subscription.
            continue
        now_monotonic = asyncio.get_running_loop().time()
        if state.cooldown_until_monotonic > now_monotonic:
            # Still cooling down after repeated failures.
            self._stats.callbacks_skipped_cooldown += 1
            continue
        try:
            await asyncio.wait_for(self._invoke_callback(subscription.callback, matched_changes), timeout=self._callback_timeout_s)
            # A success resets the consecutive-failure streak.
            state.consecutive_failures = 0
            self._stats.callbacks_succeeded += 1
        except asyncio.TimeoutError:
            # A timeout counts both as a timeout and as a failure.
            self._stats.callbacks_timed_out += 1
            self._stats.callbacks_failed += 1
            self._mark_callback_failure(subscription.subscription_id)
            logger.warning(
                f"文件变更回调执行超时subscription_id={subscription.subscription_id}, timeout={self._callback_timeout_s}s"
            )
        except Exception as exc:
            self._stats.callbacks_failed += 1
            self._mark_callback_failure(subscription.subscription_id)
            logger.warning(f"文件变更回调执行失败subscription_id={subscription.subscription_id}: {exc}")
async def _invoke_callback(self, callback: ChangeCallback, changes: Sequence[FileChange]) -> None:
if asyncio.iscoroutinefunction(callback):
await callback(changes)
return
except Exception as exc:
logger.error(f"文件监视器运行异常: {exc}")
await asyncio.to_thread(callback, changes)
def _mark_callback_failure(self, subscription_id: str) -> None:
state = self._subscription_states.get(subscription_id)
if state is None:
return
state.consecutive_failures += 1
if state.consecutive_failures >= self._callback_failure_threshold:
now_monotonic = asyncio.get_running_loop().time()
state.cooldown_until_monotonic = now_monotonic + self._callback_cooldown_s
state.consecutive_failures = 0
logger.warning(
f"文件变更回调进入冷却subscription_id={subscription_id}, cooldown={self._callback_cooldown_s}s"
)
def _match_changes(self, changes: Sequence[FileChange], subscription: FileWatchSubscription) -> list[FileChange]:
matched: list[FileChange] = []
for change in changes:
if subscription.change_types is not None and change.change_type not in subscription.change_types:
continue
if subscription.paths and not any(self._path_matches(change.path, path) for path in subscription.paths):
continue
matched.append(change)
return matched
def _path_matches(self, changed_path: Path, subscribed_path: Path) -> bool:
if subscribed_path.is_dir():
return changed_path == subscribed_path or changed_path.is_relative_to(subscribed_path)
return changed_path == subscribed_path
def _normalize_changes(self, changes: set[tuple[Change, str]]) -> list[FileChange]:
return [FileChange(change_type=change, path=Path(path)) for change, path in changes]
return [FileChange(change_type=change, path=Path(path).resolve()) for change, path in changes]