393 lines
15 KiB
Python
393 lines
15 KiB
Python
from pathlib import Path
|
||
from typing import Any, Callable, Mapping, Sequence, TypeVar
|
||
from datetime import datetime
|
||
import asyncio
|
||
import copy
|
||
|
||
import tomlkit
|
||
import sys
|
||
|
||
from .legacy_migration import try_migrate_legacy_bot_config_dict
|
||
|
||
from .official_configs import (
|
||
BotConfig,
|
||
PersonalityConfig,
|
||
ExpressionConfig,
|
||
ChatConfig,
|
||
EmojiConfig,
|
||
KeywordReactionConfig,
|
||
ChineseTypoConfig,
|
||
ResponsePostProcessConfig,
|
||
ResponseSplitterConfig,
|
||
TelemetryConfig,
|
||
ExperimentalConfig,
|
||
MessageReceiveConfig,
|
||
MaimMessageConfig,
|
||
LPMMKnowledgeConfig,
|
||
RelationshipConfig,
|
||
ToolConfig,
|
||
VoiceConfig,
|
||
MemoryConfig,
|
||
DebugConfig,
|
||
DreamConfig,
|
||
WebUIConfig,
|
||
DatabaseConfig,
|
||
)
|
||
from .model_configs import ModelInfo, ModelTaskConfig, APIProvider
|
||
from .config_base import ConfigBase, Field, AttributeData
|
||
from .config_utils import recursive_parse_item_to_table, output_config_changes, compare_versions
|
||
|
||
from src.common.logger import get_logger
|
||
from src.config.file_watcher import FileChange, FileWatcher
|
||
|
||
"""
|
||
如果你想要修改配置文件,请递增version的值
|
||
|
||
版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
|
||
主版本号:MMC版本更新
|
||
次版本号:配置文件内容大更新
|
||
修订号:配置文件内容小更新
|
||
"""
|
||
|
||
# Repository root: three directory levels above this file (this module presumably lives at
# src/config/<name>.py — TODO confirm). Note the order: parents are taken first, then the
# path is made absolute and resolved (symlinks are resolved last).
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()

# Directory holding the user-editable TOML files; ConfigManager.__init__ creates it if missing.
CONFIG_DIR: Path = PROJECT_ROOT / "config"

# Absolute paths of the two config files managed by ConfigManager.
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()

MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()

# Application version; logged by ConfigManager.initialize().
MMC_VERSION: str = "1.0.0"

# Expected [inner].version of bot_config.toml. When compare_versions(file_version, this)
# reports a difference, load_config_from_file rewrites the file with this version.
CONFIG_VERSION: str = "8.0.0"

# Expected [inner].version of model_config.toml (same upgrade mechanism as above).
MODEL_CONFIG_VERSION: str = "1.12.0"


logger = get_logger("config")


# Generic config-model type used by the load/generate helpers so they return the concrete class.
T = TypeVar("T", bound="ConfigBase")
|
||
|
||
|
||
# Root model for bot_config.toml. Each field is one top-level TOML table. write_config_to_file
# iterates ``model_fields`` in order when writing, so the field order here presumably defines
# the table order in the generated file — keep it stable.
# NOTE(review): the Chinese string after each field looks like an attribute docstring that
# ConfigBase may harvest as field documentation (the writer skips a "field_docs" field).
# These strings are therefore treated as runtime-relevant and left untouched.
class Config(ConfigBase):
    """总配置类"""

    # Basic bot settings.
    bot: BotConfig = Field(default_factory=BotConfig)
    """机器人配置类"""

    # Personality / persona settings.
    personality: PersonalityConfig = Field(default_factory=PersonalityConfig)
    """人格配置类"""

    # Expression (speaking style) settings.
    expression: ExpressionConfig = Field(default_factory=ExpressionConfig)
    """表达配置类"""

    # Chat behaviour settings.
    chat: ChatConfig = Field(default_factory=ChatConfig)
    """聊天配置类"""

    # Memory settings.
    memory: MemoryConfig = Field(default_factory=MemoryConfig)
    """记忆配置类"""

    # Relationship settings.
    relationship: RelationshipConfig = Field(default_factory=RelationshipConfig)
    """关系配置类"""

    # Message-receiving settings.
    message_receive: MessageReceiveConfig = Field(default_factory=MessageReceiveConfig)
    """消息接收配置类"""

    # "Dream" feature settings.
    dream: DreamConfig = Field(default_factory=DreamConfig)
    """做梦配置类"""

    # Tool settings.
    tool: ToolConfig = Field(default_factory=ToolConfig)
    """工具配置类"""

    # Voice settings.
    voice: VoiceConfig = Field(default_factory=VoiceConfig)
    """语音配置类"""

    # Emoji / sticker settings.
    emoji: EmojiConfig = Field(default_factory=EmojiConfig)
    """表情包配置类"""

    # Keyword-reaction settings.
    keyword_reaction: KeywordReactionConfig = Field(default_factory=KeywordReactionConfig)
    """关键词反应配置类"""

    # Reply post-processing settings.
    response_post_process: ResponsePostProcessConfig = Field(default_factory=ResponsePostProcessConfig)
    """回复后处理配置类"""

    # Chinese typo generator settings.
    chinese_typo: ChineseTypoConfig = Field(default_factory=ChineseTypoConfig)
    """中文错别字生成器配置类"""

    # Reply splitter settings.
    response_splitter: ResponseSplitterConfig = Field(default_factory=ResponseSplitterConfig)
    """回复分割器配置类"""

    # Telemetry settings.
    telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig)
    """遥测配置类"""

    # Debug settings.
    debug: DebugConfig = Field(default_factory=DebugConfig)
    """调试配置类"""

    # Experimental feature settings.
    experimental: ExperimentalConfig = Field(default_factory=ExperimentalConfig)
    """实验性功能配置类"""

    # maim_message settings.
    maim_message: MaimMessageConfig = Field(default_factory=MaimMessageConfig)
    """maim_message配置类"""

    # LPMM knowledge-base settings.
    lpmm_knowledge: LPMMKnowledgeConfig = Field(default_factory=LPMMKnowledgeConfig)
    """LPMM知识库配置类"""

    # WebUI settings.
    webui: WebUIConfig = Field(default_factory=WebUIConfig)
    """WebUI配置类"""

    # Database settings.
    database: DatabaseConfig = Field(default_factory=DatabaseConfig)
    """数据库配置类"""
|
||
|
||
|
||
class ModelConfig(ConfigBase):
    """模型配置类"""

    models: list[ModelInfo] = Field(default_factory=list)
    """模型配置列表"""

    model_task_config: ModelTaskConfig = Field(default_factory=ModelTaskConfig)
    """模型任务配置"""

    api_providers: list[APIProvider] = Field(default_factory=list)
    """API提供商列表"""

    def model_post_init(self, context: Any = None):
        """Validate model/provider lists and their cross-references after parsing.

        Checks, in order: both lists are non-empty, provider names are unique,
        model names are unique, and every model has a non-empty
        ``model_identifier`` and an ``api_provider`` that names a declared provider.

        :param context: Post-init context forwarded unchanged to the base class.
        :raises ValueError: On the first failed check. Duplicate-name errors list
            the offending names so the user can locate them in the TOML file.
        """
        from collections import Counter

        if not self.models:
            raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
        if not self.api_providers:
            raise ValueError("API提供商列表不能为空,请在配置中设置有效的API提供商列表。")

        # 检查API提供商名称是否重复 (names reported in first-occurrence order)
        duplicated_providers = [
            name for name, count in Counter(provider.name for provider in self.api_providers).items() if count > 1
        ]
        if duplicated_providers:
            raise ValueError(
                f"API提供商名称存在重复: {', '.join(map(str, duplicated_providers))},请检查配置文件。"
            )

        # 检查模型名称是否重复
        duplicated_models = [name for name, count in Counter(model.name for model in self.models).items() if count > 1]
        if duplicated_models:
            raise ValueError(f"模型名称存在重复: {', '.join(map(str, duplicated_models))},请检查配置文件。")

        api_providers_dict = {provider.name: provider for provider in self.api_providers}

        for model in self.models:
            if not model.model_identifier:
                raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
            if not model.api_provider or model.api_provider not in api_providers_dict:
                raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在")
        return super().model_post_init(context)
|
||
|
||
|
||
class ConfigManager:
    """Owns the loaded bot/model configuration and supports hot reloading (总配置管理类)."""

    def __init__(self):
        self.bot_config_path: Path = BOT_CONFIG_PATH
        self.model_config_path: Path = MODEL_CONFIG_PATH
        CONFIG_DIR.mkdir(parents=True, exist_ok=True)
        self.global_config: Config | None = None
        self.model_config: ModelConfig | None = None
        # Serializes reload_config so overlapping file-change events cannot interleave.
        self._reload_lock: asyncio.Lock = asyncio.Lock()
        # Callables invoked after every successful reload; may be sync or return an awaitable.
        self._reload_callbacks: list[Callable[[], object]] = []
        self._file_watcher: FileWatcher | None = None

    def initialize(self):
        """Load both config files into memory (may exit the process, see the load_* methods)."""
        logger.info(f"MaiCore当前版本: {MMC_VERSION}")
        logger.info("正在品鉴配置文件...")
        self.global_config = self.load_global_config()
        self.model_config = self.load_model_config()
        logger.info("非常的新鲜,非常的美味!")

    def load_global_config(self) -> Config:
        """Load bot_config.toml; exits with status 0 if the file was rewritten to a new version."""
        config, updated = load_config_from_file(Config, self.bot_config_path, CONFIG_VERSION)
        if updated:
            sys.exit(0)  # 先直接退出 (exit so the user can review the upgraded file)
        return config

    def load_model_config(self) -> ModelConfig:
        """Load model_config.toml; exits with status 0 if the file was rewritten to a new version."""
        config, updated = load_config_from_file(ModelConfig, self.model_config_path, MODEL_CONFIG_VERSION, True)
        if updated:
            sys.exit(0)  # 先直接退出 (exit so the user can review the upgraded file)
        return config

    def get_global_config(self) -> Config:
        """Return the loaded bot config.

        :raises RuntimeError: If ``initialize`` has not run yet.
        """
        if self.global_config is None:
            raise RuntimeError("global_config 未初始化")
        return self.global_config

    def get_model_config(self) -> ModelConfig:
        """Return the loaded model config.

        :raises RuntimeError: If ``initialize`` has not run yet.
        """
        if self.model_config is None:
            raise RuntimeError("model_config 未初始化")
        return self.model_config

    def register_reload_callback(self, callback: Callable[[], object]) -> None:
        """Register ``callback`` to run after each successful reload (duplicates are allowed)."""
        self._reload_callbacks.append(callback)

    def unregister_reload_callback(self, callback: Callable[[], object]) -> None:
        """Remove one registration of ``callback``; silently ignores unknown callbacks."""
        try:
            self._reload_callbacks.remove(callback)
        except ValueError:
            return

    async def reload_config(self) -> bool:
        """Re-read both config files and swap them in, then run the reload callbacks.

        Nothing is swapped unless BOTH files load successfully. Callback failures
        are logged and do not abort the remaining callbacks.

        :returns: ``True`` if the new configuration was applied, ``False`` if loading failed.
        """
        async with self._reload_lock:
            try:
                global_config_new, global_updated = load_config_from_file(
                    Config,
                    self.bot_config_path,
                    CONFIG_VERSION,
                )
                model_config_new, model_updated = load_config_from_file(
                    ModelConfig,
                    self.model_config_path,
                    MODEL_CONFIG_VERSION,
                    True,
                )
            except Exception as exc:
                logger.error(f"配置重载失败: {exc}")
                return False

            if global_updated or model_updated:
                logger.warning("检测到配置版本更新,热重载仅更新内存数据")

            self.global_config = global_config_new
            self.model_config = model_config_new
            # Rebind the module-level convenience globals as well. Modules that copied the
            # old objects via ``from ... import global_config`` keep the stale references.
            global global_config, model_config
            global_config = global_config_new
            model_config = model_config_new
            logger.info("配置热重载完成")

            # Iterate over a snapshot so callbacks may (un)register callbacks safely.
            for callback in list(self._reload_callbacks):
                try:
                    result = callback()
                    # Await coroutines AND futures/tasks; otherwise a callback returning
                    # e.g. asyncio.create_task(...) would be left running unobserved and
                    # its exception would never reach the warning below.
                    if asyncio.iscoroutine(result) or asyncio.isfuture(result):
                        await result
                except Exception as exc:
                    logger.warning(f"配置重载回调执行失败: {exc}")
            return True

    async def start_file_watcher(self) -> None:
        """Start watching both config files; no-op if a watcher is already running."""
        if self._file_watcher is not None and self._file_watcher.running:
            return
        self._file_watcher = FileWatcher(paths=[self.bot_config_path, self.model_config_path])
        await self._file_watcher.start(self._handle_file_changes)
        logger.info("配置文件监视器已启动")

    async def stop_file_watcher(self) -> None:
        """Stop and discard the file watcher, if any."""
        if self._file_watcher is None:
            return
        await self._file_watcher.stop()
        self._file_watcher = None

    async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None:
        """File-watcher callback: trigger a hot reload for any non-empty batch of changes."""
        if not changes:
            return
        logger.info("检测到配置文件变更,触发热重载")
        await self.reload_config()
|
||
|
||
|
||
def generate_new_config_file(config_class: type[T], config_path: Path, inner_config_version: str) -> None:
    """Write a default-constructed ``config_class`` instance to ``config_path``.

    :param config_class: Config model class; instantiated with no arguments, so every
        field takes its default value.
    :param config_path: Destination TOML file (an existing file is backed up first by
        ``write_config_to_file``).
    :param inner_config_version: Version string stored under ``[inner].version``.
    """
    write_config_to_file(config_class(), config_path, inner_config_version)
|
||
|
||
|
||
def load_config_from_file(
    config_class: type[T], config_path: Path, new_ver: str, override_repr: bool = False
) -> tuple[T, bool]:
    """Parse a TOML config file into ``config_class``, upgrading the file if its version changed.

    The ``[inner]`` table carries the file's version and is stripped before parsing.
    For bot_config.toml / ``Config`` only, a ``TypeError`` during parsing triggers one
    attempt at legacy-structure migration before giving up.

    :param config_class: Config model class exposing ``from_dict(attribute_data, data)``.
    :param config_path: TOML file to read.
    :param new_ver: Version this code expects; when ``compare_versions(old, new_ver)`` is
        truthy, the parsed config is written back to ``config_path`` with this version.
    :param override_repr: Forwarded to ``write_config_to_file`` (write fields even when
        their ``repr`` flag is off).
    :returns: ``(config, updated)`` where ``updated`` is ``True`` iff the file was rewritten.
    :raises TypeError: If ``[inner]`` / ``inner.version`` is missing or not a string, or if
        parsing fails and no migration applies.
    """
    attribute_data = AttributeData()
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = tomlkit.load(f)

    inner_table = config_data.get("inner")
    if not isinstance(inner_table, Mapping):
        raise TypeError("配置文件缺少 inner 版本信息")
    inner_version = inner_table.get("version")
    if not isinstance(inner_version, str):
        raise TypeError("配置文件 inner.version 类型错误")
    old_ver: str = inner_version

    config_data.remove("inner")  # 移除 inner 部分,避免干扰后续处理
    config_data = config_data.unwrap()  # 转换为普通字典,方便后续处理
    # Keep a pristine deep copy: the first from_dict may mutate config_data in place.
    original_data: dict[str, Any] = copy.deepcopy(config_data)

    try:
        updated: bool = False
        try:
            target_config = config_class.from_dict(attribute_data, config_data)
        except TypeError:
            # Pluggable legacy fix-up, only for the known bot_config.toml structure changes.
            if config_path.name != "bot_config.toml" or config_class.__name__ != "Config":
                raise
            # Migrate from the untouched copy, not from the partially consumed dict.
            mig = try_migrate_legacy_bot_config_dict(original_data)
            if not mig.migrated:
                raise
            logger.warning(
                f"检测到旧版配置结构,已尝试自动修复: {mig.reason}。建议稍后检查并保存生成的新配置文件。"
            )
            # Use a fresh collector: the failed first pass may already have recorded
            # entries for the old structure, which output_config_changes would report.
            attribute_data = AttributeData()
            target_config = config_class.from_dict(attribute_data, mig.data)

        if compare_versions(old_ver, new_ver):
            output_config_changes(attribute_data, logger, old_ver, new_ver, config_path.name)
            write_config_to_file(target_config, config_path, new_ver, override_repr)
            updated = True
        return target_config, updated
    except Exception:
        logger.critical(f"配置文件{config_path.name}解析失败")
        raise
|
||
|
||
|
||
def write_config_to_file(
    config: ConfigBase, config_path: Path, inner_config_version: str, override_repr: bool = False
) -> None:
    """Serialize ``config`` to TOML at ``config_path``, backing up any existing file.

    The document starts with ``[inner] version = inner_config_version``, followed by one
    table (or array of tables, for list fields) per top-level field of ``config``.
    The new content is fully written to a sibling ``.tmp`` file BEFORE the existing file
    is touched, so a serialization or write error leaves the current config intact.
    The existing file is then moved to ``<dir>/old/<stem>_<YYYYmmdd_HHMMSS>.toml`` and
    the temp file is moved into place.

    :param config: Config object to write; every written field must be a ``ConfigBase``
        or a list of ``ConfigBase``.
    :param config_path: Destination TOML file.
    :param inner_config_version: Version string stored under ``[inner].version``.
    :param override_repr: Also write fields whose ``repr`` flag is off.
    :raises TypeError: If a written field is not a ``ConfigBase`` (or list of them).
    """
    full_config_data = tomlkit.document()

    # Version header first.
    version_table = tomlkit.table()
    version_table.add("version", inner_config_version)
    full_config_data.add("inner", version_table)

    # Internal/bookkeeping fields that never belong in the TOML file.
    skipped_fields = {"field_docs", "_validate_any", "suppress_any_warning"}

    for config_item_name, config_item in type(config).model_fields.items():
        if not config_item.repr and not override_repr:
            continue
        if config_item_name in skipped_fields:
            continue
        config_field = getattr(config, config_item_name)
        if isinstance(config_field, ConfigBase):
            full_config_data.add(
                config_item_name, recursive_parse_item_to_table(config_field, override_repr=override_repr)
            )
        elif isinstance(config_field, list):
            aot = tomlkit.aot()
            for item in config_field:
                if not isinstance(item, ConfigBase):
                    raise TypeError("配置写入只支持ConfigBase子类")
                aot.append(recursive_parse_item_to_table(item, override_repr=override_repr))
            full_config_data.add(config_item_name, aot)
        else:
            raise TypeError("配置写入只支持ConfigBase子类")

    # Stage the new content next to the target so a failure cannot destroy the live file.
    # NOTE(review): if FileWatcher watches the whole directory, it may see this temp file.
    tmp_path = config_path.with_name(f"{config_path.name}.tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            tomlkit.dump(full_config_data, f)
    except BaseException:
        # Clean up the partial temp file, then propagate the original error.
        tmp_path.unlink(missing_ok=True)
        raise

    # 备份旧文件 (move the old file away, preserving it byte-for-byte)
    if config_path.exists():
        backup_root = config_path.parent / "old"
        backup_root.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_path = backup_root / f"{config_path.stem}_{timestamp}.toml"
        config_path.replace(backup_path)

    # 写入文件 (publish the staged content)
    tmp_path.replace(config_path)
|
||
|
||
|
||
# Manual helper: uncomment to regenerate bot_config.toml from the Config defaults.
# generate_new_config_file(Config, BOT_CONFIG_PATH, CONFIG_VERSION)

# Module-level singleton. Importing this module loads both config files immediately.
# The import itself calls sys.exit(0) if either file was rewritten to a newer version
# (see ConfigManager.load_global_config / load_model_config).
config_manager = ConfigManager()
config_manager.initialize()

# Convenience module globals. ConfigManager.reload_config rebinds these names on hot reload.
# References copied elsewhere via ``from ... import global_config`` keep the old object.
global_config = config_manager.get_global_config()
model_config = config_manager.get_model_config()
|