from pathlib import Path
from typing import Any, Callable, Mapping, Sequence, TypeVar
from datetime import datetime
import asyncio
import copy
import inspect
import sys

import tomlkit

from .legacy_migration import try_migrate_legacy_bot_config_dict
from .official_configs import (
    BotConfig,
    PersonalityConfig,
    ExpressionConfig,
    ChatConfig,
    EmojiConfig,
    KeywordReactionConfig,
    ChineseTypoConfig,
    ResponsePostProcessConfig,
    ResponseSplitterConfig,
    TelemetryConfig,
    ExperimentalConfig,
    MessageReceiveConfig,
    MaimMessageConfig,
    LPMMKnowledgeConfig,
    RelationshipConfig,
    ToolConfig,
    VoiceConfig,
    MemoryConfig,
    DebugConfig,
    DreamConfig,
    WebUIConfig,
    DatabaseConfig,
)
from .model_configs import ModelInfo, ModelTaskConfig, APIProvider
from .config_base import ConfigBase, Field, AttributeData
from .config_utils import recursive_parse_item_to_table, output_config_changes, compare_versions
from src.common.logger import get_logger
from src.config.file_watcher import FileChange, FileWatcher

# If you change the configuration schema, bump the matching version constant
# below. Version format is MAJOR.MINOR.PATCH:
#   MAJOR - MMC release update
#   MINOR - large change to the configuration file contents
#   PATCH - small change to the configuration file contents

PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
CONFIG_DIR: Path = PROJECT_ROOT / "config"
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
MMC_VERSION: str = "1.0.0"
CONFIG_VERSION: str = "8.0.0"
MODEL_CONFIG_VERSION: str = "1.12.0"

logger = get_logger("config")

T = TypeVar("T", bound="ConfigBase")


class Config(ConfigBase):
    """总配置类"""

    # Root bot configuration: one nested section object per field, each built
    # from its own default factory.
    # NOTE(review): the class/attribute docstrings below are kept verbatim.
    # ConfigBase presumably harvests them at runtime (write_config_to_file
    # skips a "field_docs" entry), so rewording them could change generated
    # files - confirm before translating.

    bot: BotConfig = Field(default_factory=BotConfig)
    """机器人配置类"""

    personality: PersonalityConfig = Field(default_factory=PersonalityConfig)
    """人格配置类"""

    expression: ExpressionConfig = Field(default_factory=ExpressionConfig)
    """表达配置类"""

    chat: ChatConfig = Field(default_factory=ChatConfig)
    """聊天配置类"""

    memory: MemoryConfig = Field(default_factory=MemoryConfig)
    """记忆配置类"""

    relationship: RelationshipConfig = Field(default_factory=RelationshipConfig)
    """关系配置类"""

    message_receive: MessageReceiveConfig = Field(default_factory=MessageReceiveConfig)
    """消息接收配置类"""

    dream: DreamConfig = Field(default_factory=DreamConfig)
    """做梦配置类"""

    tool: ToolConfig = Field(default_factory=ToolConfig)
    """工具配置类"""

    voice: VoiceConfig = Field(default_factory=VoiceConfig)
    """语音配置类"""

    emoji: EmojiConfig = Field(default_factory=EmojiConfig)
    """表情包配置类"""

    keyword_reaction: KeywordReactionConfig = Field(default_factory=KeywordReactionConfig)
    """关键词反应配置类"""

    response_post_process: ResponsePostProcessConfig = Field(default_factory=ResponsePostProcessConfig)
    """回复后处理配置类"""

    chinese_typo: ChineseTypoConfig = Field(default_factory=ChineseTypoConfig)
    """中文错别字生成器配置类"""

    response_splitter: ResponseSplitterConfig = Field(default_factory=ResponseSplitterConfig)
    """回复分割器配置类"""

    telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig)
    """遥测配置类"""

    debug: DebugConfig = Field(default_factory=DebugConfig)
    """调试配置类"""

    experimental: ExperimentalConfig = Field(default_factory=ExperimentalConfig)
    """实验性功能配置类"""

    maim_message: MaimMessageConfig = Field(default_factory=MaimMessageConfig)
    """maim_message配置类"""

    lpmm_knowledge: LPMMKnowledgeConfig = Field(default_factory=LPMMKnowledgeConfig)
    """LPMM知识库配置类"""

    webui: WebUIConfig = Field(default_factory=WebUIConfig)
    """WebUI配置类"""

    database: DatabaseConfig = Field(default_factory=DatabaseConfig)
    """数据库配置类"""


class ModelConfig(ConfigBase):
    """模型配置类"""

    # Model configuration: the model list, the per-task model assignment and
    # the API providers the models refer to.
    # NOTE(review): docstrings kept verbatim for the same reason as in Config.

    models: list[ModelInfo] = Field(default_factory=list)
    """模型配置列表"""

    model_task_config: ModelTaskConfig = Field(default_factory=ModelTaskConfig)
    """模型任务配置"""

    api_providers: list[APIProvider] = Field(default_factory=list)
    """API提供商列表"""

    def model_post_init(self, context: Any = None):
        # Validates cross-references between models and API providers, then
        # delegates to the parent hook.
        #
        # Raises ValueError when the model or provider list is empty, when a
        # provider or model name is duplicated, when a model has an empty
        # model_identifier, or when a model names an unknown provider.
        if not self.models:
            raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
        if not self.api_providers:
            raise ValueError("API提供商列表不能为空,请在配置中设置有效的API提供商列表。")

        # Provider names must be unique; the set doubles as the lookup used
        # for the per-model membership check below.
        known_providers = {provider.name for provider in self.api_providers}
        if len(known_providers) != len(self.api_providers):
            raise ValueError("API提供商名称存在重复,请检查配置文件。")

        # Model names must be unique as well.
        if len({model.name for model in self.models}) != len(self.models):
            raise ValueError("模型名称存在重复,请检查配置文件。")

        for model in self.models:
            if not model.model_identifier:
                raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
            if not model.api_provider or model.api_provider not in known_providers:
                raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在")

        return super().model_post_init(context)


class ConfigManager:
    """Owns the loaded bot/model configuration and its hot-reload machinery."""

    def __init__(self):
        """Record the config paths, ensure CONFIG_DIR exists, start unloaded."""
        self.bot_config_path: Path = BOT_CONFIG_PATH
        self.model_config_path: Path = MODEL_CONFIG_PATH
        CONFIG_DIR.mkdir(parents=True, exist_ok=True)
        self.global_config: Config | None = None
        self.model_config: ModelConfig | None = None
        # Serialises concurrent reload_config() calls.
        self._reload_lock: asyncio.Lock = asyncio.Lock()
        self._reload_callbacks: list[Callable[[], object]] = []
        self._file_watcher: FileWatcher | None = None

    def initialize(self):
        """Load both configuration files and store them on the manager."""
        logger.info(f"MaiCore当前版本: {MMC_VERSION}")
        logger.info("正在品鉴配置文件...")
        self.global_config = self.load_global_config()
        self.model_config = self.load_model_config()
        logger.info("非常的新鲜,非常的美味!")

    def load_global_config(self) -> Config:
        """Load the bot config; exit the process if the file was upgraded."""
        config, updated = load_config_from_file(Config, self.bot_config_path, CONFIG_VERSION)
        if updated:
            # The file on disk was rewritten for a new version: exit for now.
            sys.exit(0)
        return config

    def load_model_config(self) -> ModelConfig:
        """Load the model config; exit the process if the file was upgraded."""
        config, updated = load_config_from_file(ModelConfig, self.model_config_path, MODEL_CONFIG_VERSION, True)
        if updated:
            # The file on disk was rewritten for a new version: exit for now.
            sys.exit(0)
        return config

    def get_global_config(self) -> Config:
        """Return the loaded bot config.

        :raises RuntimeError: if no bot config has been loaded yet
        """
        if self.global_config is None:
            raise RuntimeError("global_config 未初始化")
        return self.global_config

    def get_model_config(self) -> ModelConfig:
        """Return the loaded model config.

        :raises RuntimeError: if no model config has been loaded yet
        """
        if self.model_config is None:
            raise RuntimeError("model_config 未初始化")
        return self.model_config

    def register_reload_callback(self, callback: Callable[[], object]) -> None:
        """Register a zero-argument callable to run after a successful reload."""
        self._reload_callbacks.append(callback)

    def unregister_reload_callback(self, callback: Callable[[], object]) -> None:
        """Remove a previously registered callback; unknown callbacks are ignored."""
        try:
            self._reload_callbacks.remove(callback)
        except ValueError:
            return

    async def reload_config(self) -> bool:
        """Re-read both config files and publish the new objects.

        On success the manager attributes and the module-level
        ``global_config`` / ``model_config`` names are rebound, then every
        registered callback is invoked. A failing callback is logged and does
        not stop the remaining callbacks.

        :return: ``True`` when both files loaded, ``False`` when loading failed
            (the previous configuration is then left untouched)
        """
        global global_config, model_config

        async with self._reload_lock:
            try:
                global_config_new, global_updated = load_config_from_file(
                    Config,
                    self.bot_config_path,
                    CONFIG_VERSION,
                )
                model_config_new, model_updated = load_config_from_file(
                    ModelConfig,
                    self.model_config_path,
                    MODEL_CONFIG_VERSION,
                    True,
                )
            except Exception as exc:
                logger.error(f"配置重载失败: {exc}")
                return False

            if global_updated or model_updated:
                logger.warning("检测到配置版本更新,热重载仅更新内存数据")

            self.global_config = global_config_new
            self.model_config = model_config_new
            global_config = global_config_new
            model_config = model_config_new
            logger.info("配置热重载完成")

            # NOTE(review): callbacks are kept inside the reload lock; the
            # collapsed source does not show the original nesting - confirm.
            # Iterate over a snapshot so callbacks may (un)register others.
            for callback in list(self._reload_callbacks):
                try:
                    result = callback()
                    # Await anything awaitable (coroutine, Task, Future, ...),
                    # not only bare coroutine objects, so async work started by
                    # a callback is never silently dropped.
                    if inspect.isawaitable(result):
                        await result
                except Exception as exc:
                    logger.warning(f"配置重载回调执行失败: {exc}")
        return True

    async def start_file_watcher(self) -> None:
        """Start watching both config files; no-op if a watcher is running."""
        if self._file_watcher is not None and self._file_watcher.running:
            return
        self._file_watcher = FileWatcher(paths=[self.bot_config_path, self.model_config_path])
        await self._file_watcher.start(self._handle_file_changes)
        logger.info("配置文件监视器已启动")

    async def stop_file_watcher(self) -> None:
        """Stop and discard the file watcher, if one exists."""
        if self._file_watcher is None:
            return
        await self._file_watcher.stop()
        self._file_watcher = None

    async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None:
        """Watcher callback: trigger a hot reload when any change is reported."""
        if not changes:
            return
        logger.info("检测到配置文件变更,触发热重载")
        await self.reload_config()


def generate_new_config_file(config_class: type[T], config_path: Path, inner_config_version: str) -> None:
    """Write a fresh configuration file built from the class defaults.

    :param config_class: configuration class to instantiate with no arguments
    :param config_path: destination file path
    :param inner_config_version: version string stored under ``inner.version``
    """
    config = config_class()
    write_config_to_file(config, config_path, inner_config_version)


def load_config_from_file(
    config_class: type[T], config_path: Path, new_ver: str, override_repr: bool = False
) -> tuple[T, bool]:
    """Parse a TOML config file into ``config_class`` and upgrade it if needed.

    The ``inner.version`` entry is read and stripped before the remaining data
    is handed to ``config_class.from_dict``. If that raises ``TypeError`` for
    the bot config, a legacy-structure migration is attempted on a pristine
    copy of the data. When ``compare_versions(old, new)`` is truthy the
    changes are reported and the file is rewritten with ``new_ver``.

    :param config_class: configuration class to build
    :param config_path: path of the TOML file to read
    :param new_ver: version the running code expects
    :param override_repr: forwarded to :func:`write_config_to_file`
    :return: ``(config, updated)`` where ``updated`` tells whether the file on
        disk was rewritten
    :raises TypeError: if ``inner`` / ``inner.version`` is missing or mistyped
    """
    attribute_data = AttributeData()
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = tomlkit.load(f)

    inner_table = config_data.get("inner")
    if not isinstance(inner_table, Mapping):
        raise TypeError("配置文件缺少 inner 版本信息")
    old_ver = inner_table.get("version")
    if not isinstance(old_ver, str):
        raise TypeError("配置文件 inner.version 类型错误")

    # Drop the "inner" table so it does not interfere with parsing, then
    # convert to plain Python containers.
    config_data.remove("inner")
    config_data = config_data.unwrap()
    # Keep a clean copy: the first from_dict pass may mutate the dict in place.
    original_data: dict[str, Any] = copy.deepcopy(config_data)

    try:
        updated: bool = False
        try:
            target_config = config_class.from_dict(attribute_data, config_data)
        except TypeError:
            # Pluggable repair for old configs; only known structure changes
            # of bot_config.toml are handled.
            if config_path.name != "bot_config.toml" or config_class.__name__ != "Config":
                raise
            # Migrate from original_data, which no partial construction has
            # touched.
            mig = try_migrate_legacy_bot_config_dict(original_data)
            if not mig.migrated:
                raise
            logger.warning(
                f"检测到旧版配置结构,已尝试自动修复: {mig.reason}。建议稍后检查并保存生成的新配置文件。"
            )
            target_config = config_class.from_dict(attribute_data, mig.data)

        if compare_versions(old_ver, new_ver):
            output_config_changes(attribute_data, logger, old_ver, new_ver, config_path.name)
            write_config_to_file(target_config, config_path, new_ver, override_repr)
            updated = True
        return target_config, updated
    except Exception:
        logger.critical(f"配置文件{config_path.name}解析失败")
        # Bare raise keeps the original traceback intact.
        raise


def write_config_to_file(
    config: ConfigBase, config_path: Path, inner_config_version: str, override_repr: bool = False
) -> None:
    """Serialise ``config`` to TOML at ``config_path``, backing up any old file.

    The document is fully rendered before the existing file is touched, so a
    serialisation failure cannot leave the config file missing or truncated.
    An existing file is moved to ``<config dir>/old/<stem>_<timestamp>.toml``.

    :param config: configuration object to write
    :param config_path: destination file path
    :param inner_config_version: version string stored under ``inner.version``
    :param override_repr: when true, also write fields whose ``repr`` flag is
        false
    :raises TypeError: if a top-level field is neither a ``ConfigBase`` nor a
        list of ``ConfigBase`` instances
    """
    full_config_data = tomlkit.document()

    # The version table goes first.
    version_table = tomlkit.table()
    version_table.add("version", inner_config_version)
    full_config_data.add("inner", version_table)

    # Recursively turn every configuration section into a TOML table.
    for config_item_name, config_item in type(config).model_fields.items():
        if not (config_item.repr or override_repr):
            continue
        # Bookkeeping fields that are skipped rather than written.
        if config_item_name in ["field_docs", "_validate_any", "suppress_any_warning"]:
            continue
        config_field = getattr(config, config_item_name)
        if isinstance(config_field, ConfigBase):
            full_config_data.add(
                config_item_name, recursive_parse_item_to_table(config_field, override_repr=override_repr)
            )
        elif isinstance(config_field, list):
            # Lists of sections become an array of tables.
            aot = tomlkit.aot()
            for item in config_field:
                if not isinstance(item, ConfigBase):
                    raise TypeError("配置写入只支持ConfigBase子类")
                aot.append(recursive_parse_item_to_table(item, override_repr=override_repr))
            full_config_data.add(config_item_name, aot)
        else:
            raise TypeError("配置写入只支持ConfigBase子类")

    # Render up front: if this fails, nothing on disk has been changed yet.
    rendered = tomlkit.dumps(full_config_data)

    # Back up the old file by moving it into the "old" directory.
    if config_path.exists():
        backup_root = config_path.parent / "old"
        backup_root.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_path = backup_root / f"{config_path.stem}_{timestamp}.toml"
        config_path.replace(backup_path)

    with open(config_path, "w", encoding="utf-8") as f:
        f.write(rendered)


# generate_new_config_file(Config, BOT_CONFIG_PATH, CONFIG_VERSION)

# Module-level bootstrap: the configuration is loaded at import time and
# exposed as module globals (rebound by ConfigManager.reload_config).
config_manager = ConfigManager()
config_manager.initialize()
global_config = config_manager.get_global_config()
model_config = config_manager.get_model_config()