From b3a81754e68ba704f1ca67ce9a8642387f0b3c80 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Wed, 4 Mar 2026 21:39:26 +0800
Subject: [PATCH] feat(config): harden file watcher hot-reload flow and add test coverage

refactor FileWatcher to subscription-based dispatch with path/change filters
add callback timeout, failure cooldown, auto-retry loop, and runtime stats
strengthen ConfigManager hot-reload with throttling, timeout guard, and watcher stats logging
add pytest suites for watcher behavior and config hot-reload edge cases
---
Review notes (this section is ignored by `git am`):

* Line breaks and indentation were restored from a whitespace-collapsed copy
  of this patch. Added lines follow the Python structure and the original hunk
  line counts; the indentation and blank-line placement of context lines are
  inferred, and the `index` blob ids still describe the original submission,
  so re-check against the target tree before applying.
* ConfigManager: `_last_hot_reload_monotonic` now starts at -inf instead of
  0.0. The value is compared against `loop.time()`, whose reference point is
  undefined, so a 0.0 baseline could throttle away the very first reload
  whenever the clock reads less than `_hot_reload_min_interval_s` (the
  throttle test uses 100 s).
* FileWatcher._invoke_callback: the value returned by a callback that is not
  an `async def` function is now awaited when it is awaitable. Previously a
  callable object with `async def __call__`, or a sync wrapper returning a
  coroutine, was run in a worker thread, its coroutine was never awaited, and
  the call was still counted as a success. Covered by the new test
  `test_async_callable_object_is_awaited`.
* Not changed, worth a follow-up: a change batch that arrives inside the
  throttle window in `_handle_file_changes` is dropped, not deferred, so the
  last edit of a quick save-save sequence is only picked up by a later change.

 .../test_config_manager_hot_reload.py         | 104 ++++++++++
 pytests/config_test/test_file_watcher.py      | 126 ++++++++++++
 src/config/config.py                          |  42 +++-
 src/config/file_watcher.py                    | 193 ++++++++++++++++--
 4 files changed, 445 insertions(+), 20 deletions(-)
 create mode 100644 pytests/config_test/test_config_manager_hot_reload.py
 create mode 100644 pytests/config_test/test_file_watcher.py

diff --git a/pytests/config_test/test_config_manager_hot_reload.py b/pytests/config_test/test_config_manager_hot_reload.py
new file mode 100644
index 00000000..ab2dd898
--- /dev/null
+++ b/pytests/config_test/test_config_manager_hot_reload.py
@@ -0,0 +1,104 @@
+from pathlib import Path
+
+from watchfiles import Change
+
+import asyncio
+import pytest
+
+from src.config.config import ConfigManager
+from src.config.file_watcher import FileChange, FileWatcherStats
+
+
+@pytest.mark.asyncio
+async def test_handle_file_changes_throttles_reload():
+    manager = ConfigManager()
+    manager._hot_reload_min_interval_s = 100.0
+
+    called = 0
+
+    async def reload_stub() -> bool:
+        nonlocal called
+        called += 1
+        return True
+
+    manager.reload_config = reload_stub  # type: ignore[method-assign]
+    changes = [FileChange(change_type=Change.modified, path=Path("/tmp/bot_config.toml"))]
+
+    await manager._handle_file_changes(changes)
+    await manager._handle_file_changes(changes)
+
+    assert called == 1
+
+
+@pytest.mark.asyncio
+async def test_handle_file_changes_timeout_logged(caplog):
+    manager = ConfigManager()
+    manager._hot_reload_min_interval_s = 0.0
+    manager._hot_reload_timeout_s = 0.01
+
+    async def reload_stub() -> bool:
+        await asyncio.sleep(0.05)
+        return True
+
+    manager.reload_config = reload_stub  # type: ignore[method-assign]
+    changes = [FileChange(change_type=Change.modified, path=Path("/tmp/model_config.toml"))]
+
+    with caplog.at_level("ERROR"):
+        await manager._handle_file_changes(changes)
+
+    assert "配置热重载超时" in caplog.text
+
+
+@pytest.mark.asyncio
+async def test_handle_file_changes_empty_skips_reload():
+    manager = ConfigManager()
+
+    called = 0
+
+    async def reload_stub() -> bool:
+        nonlocal called
+        called += 1
+        return True
+
+    manager.reload_config = reload_stub  # type: ignore[method-assign]
+
+    await manager._handle_file_changes([])
+
+    assert called == 0
+
+
+class _FakeWatcher:
+    def __init__(self):
+        self.unsubscribe_called_with: str | None = None
+        self.stop_called = False
+        self.stats = FileWatcherStats(
+            batches_seen=1,
+            changes_seen=2,
+            callbacks_succeeded=3,
+            callbacks_failed=4,
+            callbacks_timed_out=5,
+            callbacks_skipped_cooldown=6,
+            restart_count=7,
+        )
+
+    def unsubscribe(self, subscription_id: str) -> bool:
+        self.unsubscribe_called_with = subscription_id
+        return True
+
+    async def stop(self) -> None:
+        self.stop_called = True
+
+
+@pytest.mark.asyncio
+async def test_stop_file_watcher_cleans_state():
+    manager = ConfigManager()
+    fake_watcher = _FakeWatcher()
+    manager._file_watcher = fake_watcher  # type: ignore[assignment]
+    manager._file_watcher_subscription_id = "sub-1"
+
+    await manager.stop_file_watcher()
+
+    assert fake_watcher.unsubscribe_called_with == "sub-1"
+    assert fake_watcher.stop_called is True
+    assert manager._file_watcher is None
+    assert manager._file_watcher_subscription_id is None
diff --git a/pytests/config_test/test_file_watcher.py b/pytests/config_test/test_file_watcher.py
new file mode 100644
index 00000000..6fa7693b
--- /dev/null
+++ b/pytests/config_test/test_file_watcher.py
@@ -0,0 +1,126 @@
+from pathlib import Path
+
+from watchfiles import Change
+
+import asyncio
+import pytest
+
+from src.config.file_watcher import FileChange, FileWatcher
+
+
+@pytest.mark.asyncio
+async def test_dispatch_changes_with_path_and_change_type_filters(tmp_path: Path):
+    watcher = FileWatcher(paths=[tmp_path])
+    target_file = (tmp_path / "bot_config.toml").resolve()
+
+    received: list[list[FileChange]] = []
+
+    async def callback(changes):
+        received.append(list(changes))
+
+    watcher.subscribe(callback, paths=[target_file], change_types=[Change.modified])
+
+    await watcher._dispatch_changes(
+        [
+            FileChange(change_type=Change.added, path=target_file),
+            FileChange(change_type=Change.modified, path=target_file),
+            FileChange(change_type=Change.modified, path=(tmp_path / "other.toml").resolve()),
+        ]
+    )
+
+    assert len(received) == 1
+    assert len(received[0]) == 1
+    assert received[0][0].change_type == Change.modified
+    assert received[0][0].path == target_file
+
+
+@pytest.mark.asyncio
+async def test_sync_callback_supported(tmp_path: Path):
+    watcher = FileWatcher(paths=[tmp_path])
+    target_file = (tmp_path / "model_config.toml").resolve()
+
+    received_paths: list[Path] = []
+
+    def sync_callback(changes):
+        received_paths.extend(change.path for change in changes)
+
+    watcher.subscribe(sync_callback, paths=[target_file])
+
+    await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
+
+    assert received_paths == [target_file]
+
+
+@pytest.mark.asyncio
+async def test_callback_timeout_and_cooldown(tmp_path: Path):
+    watcher = FileWatcher(
+        paths=[tmp_path],
+        callback_timeout_s=0.05,
+        callback_failure_threshold=2,
+        callback_cooldown_s=0.2,
+    )
+    target_file = (tmp_path / "bot_config.toml").resolve()
+
+    async def slow_callback(changes):
+        await asyncio.sleep(0.2)
+
+    watcher.subscribe(slow_callback, paths=[target_file])
+
+    await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
+    await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
+
+    stats_after_failures = watcher.stats
+    assert stats_after_failures.callbacks_timed_out == 2
+    assert stats_after_failures.callbacks_failed == 2
+
+    await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
+    stats_after_cooldown_skip = watcher.stats
+    assert stats_after_cooldown_skip.callbacks_skipped_cooldown >= 1
+
+
+@pytest.mark.asyncio
+async def test_start_requires_subscription(tmp_path: Path):
+    watcher = FileWatcher(paths=[tmp_path])
+
+    with pytest.raises(RuntimeError):
+        await watcher.start()
+
+
+@pytest.mark.asyncio
+async def test_unsubscribe_stops_dispatch(tmp_path: Path):
+    watcher = FileWatcher(paths=[tmp_path])
+    target_file = (tmp_path / "bot_config.toml").resolve()
+
+    calls = 0
+
+    async def callback(changes):
+        nonlocal calls
+        calls += 1
+
+    subscription_id = watcher.subscribe(callback, paths=[target_file])
+    assert watcher.unsubscribe(subscription_id) is True
+
+    await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
+
+    assert calls == 0
+
+
+@pytest.mark.asyncio
+async def test_async_callable_object_is_awaited(tmp_path: Path):
+    watcher = FileWatcher(paths=[tmp_path])
+    target_file = (tmp_path / "bot_config.toml").resolve()
+
+    class AsyncCallable:
+        def __init__(self) -> None:
+            self.calls = 0
+
+        async def __call__(self, changes) -> None:
+            self.calls += 1
+
+    callback = AsyncCallable()
+    watcher.subscribe(callback, paths=[target_file])
+
+    await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
+
+    assert callback.calls == 1
+    assert watcher.stats.callbacks_succeeded == 1
diff --git a/src/config/config.py b/src/config/config.py
index c2ebd345..a30fa778 100644
--- a/src/config/config.py
+++ b/src/config/config.py
@@ -182,6 +182,10 @@ class ConfigManager:
         self._reload_lock: asyncio.Lock = asyncio.Lock()
         self._reload_callbacks: list[Callable[[], object]] = []
         self._file_watcher: FileWatcher | None = None
+        self._file_watcher_subscription_id: str | None = None
+        self._hot_reload_min_interval_s: float = 1.0
+        self._hot_reload_timeout_s: float = 20.0
+        self._last_hot_reload_monotonic: float = float("-inf")  # no reload yet: never throttle the first one
 
     def initialize(self):
         logger.info(f"MaiCore当前版本: {MMC_VERSION}")
@@ -261,21 +265,53 @@ class ConfigManager:
     async def start_file_watcher(self) -> None:
         if self._file_watcher is not None and self._file_watcher.running:
             return
-        self._file_watcher = FileWatcher(paths=[self.bot_config_path, self.model_config_path])
-        await self._file_watcher.start(self._handle_file_changes)
+        self._file_watcher = FileWatcher(
+            paths=[self.bot_config_path, self.model_config_path],
+            debounce_ms=600,
+            callback_timeout_s=15.0,
+            callback_failure_threshold=3,
+            callback_cooldown_s=30.0,
+        )
+        self._file_watcher_subscription_id = self._file_watcher.subscribe(
+            self._handle_file_changes,
+            paths=[self.bot_config_path, self.model_config_path],
+        )
+        await self._file_watcher.start()
         logger.info("配置文件监视器已启动")
 
     async def stop_file_watcher(self) -> None:
         if self._file_watcher is None:
             return
+        if self._file_watcher_subscription_id is not None:
+            self._file_watcher.unsubscribe(self._file_watcher_subscription_id)
+            self._file_watcher_subscription_id = None
+        watcher_stats = self._file_watcher.stats
+        logger.info(
+            "配置文件监视器停止统计: "
+            f"batches={watcher_stats.batches_seen}, "
+            f"changes={watcher_stats.changes_seen}, "
+            f"ok={watcher_stats.callbacks_succeeded}, "
+            f"failed={watcher_stats.callbacks_failed}, "
+            f"timeout={watcher_stats.callbacks_timed_out}, "
+            f"cooldown_skip={watcher_stats.callbacks_skipped_cooldown}, "
+            f"restart={watcher_stats.restart_count}"
+        )
         await self._file_watcher.stop()
         self._file_watcher = None
 
     async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None:
         if not changes:
             return
+        now_monotonic = asyncio.get_running_loop().time()
+        if now_monotonic - self._last_hot_reload_monotonic < self._hot_reload_min_interval_s:
+            logger.debug("文件变更触发过于频繁,已跳过本次重载")
+            return
+        self._last_hot_reload_monotonic = now_monotonic
         logger.info("检测到配置文件变更,触发热重载")
-        await self.reload_config()
+        try:
+            await asyncio.wait_for(self.reload_config(), timeout=self._hot_reload_timeout_s)
+        except asyncio.TimeoutError:
+            logger.error(f"配置热重载超时(>{self._hot_reload_timeout_s}s)")
 
 
 def generate_new_config_file(config_class: type[T], config_path: Path, inner_config_version: str) -> None:
diff --git a/src/config/file_watcher.py b/src/config/file_watcher.py
index fd0dc0f1..0ec81e4a 100644
--- a/src/config/file_watcher.py
+++ b/src/config/file_watcher.py
@@ -5,6 +5,7 @@ from typing import Awaitable, Callable, Iterable, Sequence
 from watchfiles import Change, awatch
 
 import asyncio
+import uuid
 
 from src.common.logger import get_logger
 
@@ -18,25 +19,105 @@ class FileChange:
     path: Path
 
 
-ChangeCallback = Callable[[Sequence[FileChange]], Awaitable[None]]
+ChangeCallback = Callable[[Sequence[FileChange]], Awaitable[None] | None]
+
+
+@dataclass(frozen=True)
+class FileWatchSubscription:
+    subscription_id: str
+    callback: ChangeCallback
+    paths: tuple[Path, ...]
+    change_types: frozenset[Change] | None
+
+
+@dataclass
+class SubscriptionState:
+    consecutive_failures: int = 0
+    cooldown_until_monotonic: float = 0.0
+
+
+@dataclass
+class FileWatcherStats:
+    batches_seen: int = 0
+    changes_seen: int = 0
+    callbacks_succeeded: int = 0
+    callbacks_failed: int = 0
+    callbacks_timed_out: int = 0
+    callbacks_skipped_cooldown: int = 0
+    restart_count: int = 0
 
 
 class FileWatcher:
-    def __init__(self, paths: Iterable[Path], debounce_ms: int = 600) -> None:
+    def __init__(
+        self,
+        paths: Iterable[Path],
+        debounce_ms: int = 600,
+        callback_timeout_s: float = 10.0,
+        callback_failure_threshold: int = 3,
+        callback_cooldown_s: float = 30.0,
+    ) -> None:
         self._paths = [path.resolve() for path in paths]
         self._debounce_ms = debounce_ms
+        self._callback_timeout_s = callback_timeout_s
+        self._callback_failure_threshold = callback_failure_threshold
+        self._callback_cooldown_s = callback_cooldown_s
         self._running = False
         self._task: asyncio.Task[None] | None = None
+        self._subscriptions: dict[str, FileWatchSubscription] = {}
+        self._subscription_states: dict[str, SubscriptionState] = {}
+        self._stats = FileWatcherStats()
 
     @property
     def running(self) -> bool:
         return self._running
 
-    async def start(self, callback: ChangeCallback) -> None:
+    @property
+    def stats(self) -> FileWatcherStats:
+        return FileWatcherStats(
+            batches_seen=self._stats.batches_seen,
+            changes_seen=self._stats.changes_seen,
+            callbacks_succeeded=self._stats.callbacks_succeeded,
+            callbacks_failed=self._stats.callbacks_failed,
+            callbacks_timed_out=self._stats.callbacks_timed_out,
+            callbacks_skipped_cooldown=self._stats.callbacks_skipped_cooldown,
+            restart_count=self._stats.restart_count,
+        )
+
+    def subscribe(
+        self,
+        callback: ChangeCallback,
+        *,
+        paths: Iterable[Path] | None = None,
+        change_types: Iterable[Change] | None = None,
+    ) -> str:
+        if not callable(callback):
+            raise TypeError("callback 必须是可调用对象")
+
+        normalized_paths = tuple(path.resolve() for path in paths) if paths is not None else ()
+        normalized_change_types = frozenset(change_types) if change_types is not None else None
+
+        subscription_id = str(uuid.uuid4())
+        self._subscriptions[subscription_id] = FileWatchSubscription(
+            subscription_id=subscription_id,
+            callback=callback,
+            paths=normalized_paths,
+            change_types=normalized_change_types,
+        )
+        self._subscription_states[subscription_id] = SubscriptionState()
+        return subscription_id
+
+    def unsubscribe(self, subscription_id: str) -> bool:
+        removed = self._subscriptions.pop(subscription_id, None) is not None
+        self._subscription_states.pop(subscription_id, None)
+        return removed
+
+    async def start(self) -> None:
         if self._running:
             return
+        if not self._subscriptions:
+            raise RuntimeError("启动文件监视器前必须至少注册一个订阅")
         self._running = True
-        self._task = asyncio.create_task(self._run(callback))
+        self._task = asyncio.create_task(self._run())
 
     async def stop(self) -> None:
         if not self._running:
@@ -49,20 +130,98 @@ class FileWatcher:
             await self._task
         except asyncio.CancelledError:
             return
+        finally:
+            self._task = None
 
-    async def _run(self, callback: ChangeCallback) -> None:
-        try:
-            async for changes in awatch(*self._paths, debounce=self._debounce_ms):
-                if not self._running:
-                    break
-                try:
-                    await callback(self._normalize_changes(changes))
-                except Exception as exc:
-                    logger.warning(f"文件变更回调执行失败: {exc}")
-        except asyncio.CancelledError:
+    async def _run(self) -> None:
+        while self._running:
+            try:
+                async for changes in awatch(*self._paths, debounce=self._debounce_ms):
+                    if not self._running:
+                        break
+                    normalized_changes = self._normalize_changes(changes)
+                    if not normalized_changes:
+                        continue
+                    self._stats.batches_seen += 1
+                    self._stats.changes_seen += len(normalized_changes)
+                    try:
+                        await self._dispatch_changes(normalized_changes)
+                    except Exception as exc:
+                        logger.warning(f"文件变更分发失败: {exc}")
+            except asyncio.CancelledError:
+                return
+            except Exception as exc:
+                self._stats.restart_count += 1
+                logger.error(f"文件监视器运行异常,将在1秒后重试: {exc}")
+                if self._running:
+                    await asyncio.sleep(1.0)
+
+    async def _dispatch_changes(self, changes: Sequence[FileChange]) -> None:
+        for subscription in list(self._subscriptions.values()):
+            matched_changes = self._match_changes(changes, subscription)
+            if not matched_changes:
+                continue
+            state = self._subscription_states.get(subscription.subscription_id)
+            if state is None:
+                continue
+            now_monotonic = asyncio.get_running_loop().time()
+            if state.cooldown_until_monotonic > now_monotonic:
+                self._stats.callbacks_skipped_cooldown += 1
+                continue
+            try:
+                await asyncio.wait_for(self._invoke_callback(subscription.callback, matched_changes), timeout=self._callback_timeout_s)
+                state.consecutive_failures = 0
+                self._stats.callbacks_succeeded += 1
+            except asyncio.TimeoutError:
+                self._stats.callbacks_timed_out += 1
+                self._stats.callbacks_failed += 1
+                self._mark_callback_failure(subscription.subscription_id)
+                logger.warning(
+                    f"文件变更回调执行超时(subscription_id={subscription.subscription_id}, timeout={self._callback_timeout_s}s)"
+                )
+            except Exception as exc:
+                self._stats.callbacks_failed += 1
+                self._mark_callback_failure(subscription.subscription_id)
+                logger.warning(f"文件变更回调执行失败(subscription_id={subscription.subscription_id}): {exc}")
+
+    async def _invoke_callback(self, callback: ChangeCallback, changes: Sequence[FileChange]) -> None:
+        if asyncio.iscoroutinefunction(callback):
+            await callback(changes)
             return
-        except Exception as exc:
-            logger.error(f"文件监视器运行异常: {exc}")
+        # Plain callables run in a worker thread so they cannot block the event loop.
+        result = await asyncio.to_thread(callback, changes)
+        # Callable objects with an async __call__ and sync wrappers around coroutine
+        # functions are not detected by iscoroutinefunction(); await what they return.
+        if isinstance(result, Awaitable):
+            await result
+
+    def _mark_callback_failure(self, subscription_id: str) -> None:
+        state = self._subscription_states.get(subscription_id)
+        if state is None:
+            return
+        state.consecutive_failures += 1
+        if state.consecutive_failures >= self._callback_failure_threshold:
+            now_monotonic = asyncio.get_running_loop().time()
+            state.cooldown_until_monotonic = now_monotonic + self._callback_cooldown_s
+            state.consecutive_failures = 0
+            logger.warning(
+                f"文件变更回调进入冷却(subscription_id={subscription_id}, cooldown={self._callback_cooldown_s}s)"
+            )
+
+    def _match_changes(self, changes: Sequence[FileChange], subscription: FileWatchSubscription) -> list[FileChange]:
+        matched: list[FileChange] = []
+        for change in changes:
+            if subscription.change_types is not None and change.change_type not in subscription.change_types:
+                continue
+            if subscription.paths and not any(self._path_matches(change.path, path) for path in subscription.paths):
+                continue
+            matched.append(change)
+        return matched
+
+    def _path_matches(self, changed_path: Path, subscribed_path: Path) -> bool:
+        if subscribed_path.is_dir():
+            return changed_path == subscribed_path or changed_path.is_relative_to(subscribed_path)
+        return changed_path == subscribed_path
 
     def _normalize_changes(self, changes: set[tuple[Change, str]]) -> list[FileChange]:
-        return [FileChange(change_type=change, path=Path(path)) for change, path in changes]
+        return [FileChange(change_type=change, path=Path(path).resolve()) for change, path in changes]