解决ConfigBase问题,更严格测试,实际测试

This commit is contained in:
UnCLAS-Prommer
2026-01-15 17:05:23 +08:00
parent fd46d8a302
commit 9186d14100
11 changed files with 871 additions and 1139 deletions

View File

@@ -1,19 +1,12 @@
import os
from pathlib import Path
from typing import TypeVar
from datetime import datetime
from typing import Any
import tomlkit
import shutil
import sys
from datetime import datetime
from tomlkit import TOMLDocument
from tomlkit.items import Table, KeyType
from dataclasses import field, dataclass
from rich.traceback import install
from typing import List, Optional
from src.common.logger import get_logger
from src.common.toml_utils import format_toml_string
from src.config.config_base import ConfigBase
from src.config.official_configs import (
from .official_configs import (
BotConfig,
PersonalityConfig,
ExpressionConfig,
@@ -34,345 +27,112 @@ from src.config.official_configs import (
MemoryConfig,
DebugConfig,
DreamConfig,
WebUIConfig,
)
from .model_configs import ModelInfo, ModelTaskConfig, APIProvider
from .config_base import ConfigBase, Field, AttributeData
from .config_utils import recursive_parse_item_to_table, output_config_changes, compare_versions
from .api_ada_configs import (
ModelTaskConfig,
ModelInfo,
APIProvider,
)
from src.common.logger import get_logger
"""
如果你想要修改配置文件请递增version的值
install(extra_lines=3)
版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
主版本号MMC版本更新
次版本号:配置文件内容大更新
修订号:配置文件内容小更新
"""
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
CONFIG_DIR: Path = PROJECT_ROOT / "config"
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
MMC_VERSION: str = "0.13.0"
CONFIG_VERSION: str = "8.0.0"
MODEL_CONFIG_VERSION: str = "1.12.0"
# 配置主程序日志格式
logger = get_logger("config")
# 获取当前文件所在目录的父目录的父目录即MaiBot项目根目录
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
CONFIG_DIR = os.path.join(PROJECT_ROOT, "config")
TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.13.0-snapshot.1"
T = TypeVar("T", bound="ConfigBase")
def get_key_comment(toml_table, key):
# 获取key的注释如果有
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
return toml_table.trivia.comment
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
item = toml_table.value.get(key)
if item is not None and hasattr(item, "trivia"):
return item.trivia.comment
if hasattr(toml_table, "keys"):
for k in toml_table.keys():
if isinstance(k, KeyType) and k.key == key: # type: ignore
return k.trivia.comment # type: ignore
return None
def compare_dicts(new, old, path=None, logs=None):
# 递归比较两个dict找出新增和删减项收集注释
if path is None:
path = []
if logs is None:
logs = []
# 新增项
for key in new:
if key == "version":
continue
if key not in old:
comment = get_key_comment(new, key)
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
compare_dicts(new[key], old[key], path + [str(key)], logs)
# 删减项
for key in old:
if key == "version":
continue
if key not in new:
comment = get_key_comment(old, key)
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
return logs
def get_value_by_path(d, path):
for k in path:
if isinstance(d, dict) and k in d:
d = d[k]
else:
return None
return d
def set_value_by_path(d, path, value):
"""设置嵌套字典中指定路径的值"""
for k in path[:-1]:
if k not in d or not isinstance(d[k], dict):
d[k] = {}
d = d[k]
# 使用 tomlkit.item 来保持 TOML 格式
try:
d[path[-1]] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
d[path[-1]] = value
def compare_default_values(new, old, path=None, logs=None, changes=None):
# 递归比较两个dict找出默认值变化项
if path is None:
path = []
if logs is None:
logs = []
if changes is None:
changes = []
for key in new:
if key == "version":
continue
if key in old:
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
elif new[key] != old[key]:
logs.append(f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
changes.append((path + [str(key)], old[key], new[key]))
return logs, changes
def _get_version_from_toml(toml_path) -> Optional[str]:
"""从TOML文件中获取版本号"""
if not os.path.exists(toml_path):
return None
with open(toml_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f)
if "inner" in doc and "version" in doc["inner"]: # type: ignore
return doc["inner"]["version"] # type: ignore
return None
def _version_tuple(v):
"""将版本字符串转换为元组以便比较"""
if v is None:
return (0,)
return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中如果target中存在相同的键
"""
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
_update_dict(target_value, value)
else:
try:
# 统一使用 tomlkit.item 来保持原生类型与转义,不对列表做字符串化处理
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
def _update_config_generic(config_name: str, template_name: str):
"""
通用的配置文件更新函数
Args:
config_name: 配置文件名(不含扩展名),如 'bot_config''model_config'
template_name: 模板文件名(不含扩展名),如 'bot_config_template''model_config_template'
"""
# 获取根目录路径
old_config_dir = os.path.join(CONFIG_DIR, "old")
compare_dir = os.path.join(TEMPLATE_DIR, "compare")
# 定义文件路径
template_path = os.path.join(TEMPLATE_DIR, f"{template_name}.toml")
old_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
new_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
compare_path = os.path.join(compare_dir, f"{template_name}.toml")
# 创建compare目录如果不存在
os.makedirs(compare_dir, exist_ok=True)
template_version = _get_version_from_toml(template_path)
compare_version = _get_version_from_toml(compare_path)
# 检查配置文件是否存在
if not os.path.exists(old_config_path):
logger.info(f"{config_name}.toml配置文件不存在从模板创建新配置")
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
shutil.copy2(template_path, old_config_path) # 复制模板文件
logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}")
# 新创建配置文件,退出
sys.exit(0)
compare_config = None
new_config = None
old_config = None
# 先读取 compare 下的模板(如果有),用于默认值变动检测
if os.path.exists(compare_path):
with open(compare_path, "r", encoding="utf-8") as f:
compare_config = tomlkit.load(f)
# 读取当前模板
with open(template_path, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 检查默认值变化并处理(只有 compare_config 存在时才做)
if compare_config:
# 读取旧配置
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
logs, changes = compare_default_values(new_config, compare_config)
if logs:
logger.info(f"检测到{config_name}模板默认值变动如下:")
for log in logs:
logger.info(log)
# 检查旧配置是否等于旧默认值,如果是则更新为新默认值
config_updated = False
for path, old_default, new_default in changes:
old_value = get_value_by_path(old_config, path)
if old_value == old_default:
set_value_by_path(old_config, path, new_default)
logger.info(
f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
)
config_updated = True
# 如果配置有更新,立即保存到文件
if config_updated:
with open(old_config_path, "w", encoding="utf-8") as f:
f.write(format_toml_string(old_config))
logger.info(f"已保存更新后的{config_name}配置文件")
else:
logger.info(f"未检测到{config_name}模板默认值变动")
# 检查 compare 下没有模板,或新模板版本更高,则复制
if not os.path.exists(compare_path):
shutil.copy2(template_path, compare_path)
logger.info(f"已将{config_name}模板文件复制到: {compare_path}")
elif _version_tuple(template_version) > _version_tuple(compare_version):
shutil.copy2(template_path, compare_path)
logger.info(f"{config_name}模板版本较新已替换compare下的模板: {compare_path}")
else:
logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}")
# 读取旧配置文件和模板文件(如果前面没读过 old_config这里再读一次
if old_config is None:
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
# new_config 已经读取
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version:
logger.info(f"检测到{config_name}配置文件版本号相同 (v{old_version}),跳过更新")
return
else:
logger.info(
f"\n----------------------------------------\n检测到{config_name}版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
)
else:
logger.info(f"已有{config_name}配置文件未检测到版本号,可能是旧版本。将进行更新")
# 创建old目录如果不存在
os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = os.path.join(old_config_dir, f"{config_name}_{timestamp}.toml")
# 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path)
logger.info(f"已备份旧{config_name}配置文件到: {old_backup_path}")
# 复制模板文件到配置目录
shutil.copy2(template_path, new_config_path)
logger.info(f"已创建新{config_name}配置文件: {new_config_path}")
# 输出新增和删减项及注释
if old_config:
logger.info(f"{config_name}配置项变动如下:\n----------------------------------------")
if logs := compare_dicts(new_config, old_config):
for log in logs:
logger.info(log)
else:
logger.info("无新增或删减项")
# 将旧配置的值更新到新配置中
logger.info(f"开始合并{config_name}新旧配置...")
_update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式,数组多行格式化)
with open(new_config_path, "w", encoding="utf-8") as f:
f.write(format_toml_string(new_config))
logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
def update_config():
"""更新bot_config.toml配置文件"""
_update_config_generic("bot_config", "bot_config_template")
def update_model_config():
"""更新model_config.toml配置文件"""
_update_config_generic("model_config", "model_config_template")
@dataclass
class Config(ConfigBase):
"""总配置类"""
MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息
bot: BotConfig = Field(default_factory=BotConfig)
"""机器人配置类"""
bot: BotConfig
personality: PersonalityConfig
relationship: RelationshipConfig
chat: ChatConfig
message_receive: MessageReceiveConfig
emoji: EmojiConfig
expression: ExpressionConfig
keyword_reaction: KeywordReactionConfig
chinese_typo: ChineseTypoConfig
response_post_process: ResponsePostProcessConfig
response_splitter: ResponseSplitterConfig
telemetry: TelemetryConfig
webui: WebUIConfig
experimental: ExperimentalConfig
maim_message: MaimMessageConfig
lpmm_knowledge: LPMMKnowledgeConfig
tool: ToolConfig
memory: MemoryConfig
debug: DebugConfig
voice: VoiceConfig
dream: DreamConfig
personality: PersonalityConfig = Field(default_factory=PersonalityConfig)
"""人格配置类"""
expression: ExpressionConfig = Field(default_factory=ExpressionConfig)
"""表达配置类"""
chat: ChatConfig = Field(default_factory=ChatConfig)
"""聊天配置类"""
memory: MemoryConfig = Field(default_factory=MemoryConfig)
"""记忆配置类"""
relationship: RelationshipConfig = Field(default_factory=RelationshipConfig)
"""关系配置类"""
message_receive: MessageReceiveConfig = Field(default_factory=MessageReceiveConfig)
"""消息接收配置类"""
dream: DreamConfig = Field(default_factory=DreamConfig)
"""做梦配置类"""
tool: ToolConfig = Field(default_factory=ToolConfig)
"""工具配置类"""
voice: VoiceConfig = Field(default_factory=VoiceConfig)
"""语音配置类"""
emoji: EmojiConfig = Field(default_factory=EmojiConfig)
"""表情包配置类"""
keyword_reaction: KeywordReactionConfig = Field(default_factory=KeywordReactionConfig)
"""关键词反应配置类"""
response_post_process: ResponsePostProcessConfig = Field(default_factory=ResponsePostProcessConfig)
"""回复后处理配置类"""
chinese_typo: ChineseTypoConfig = Field(default_factory=ChineseTypoConfig)
"""中文错别字生成器配置类"""
response_splitter: ResponseSplitterConfig = Field(default_factory=ResponseSplitterConfig)
"""回复分割器配置类"""
telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig)
"""遥测配置类"""
debug: DebugConfig = Field(default_factory=DebugConfig)
"""调试配置类"""
experimental: ExperimentalConfig = Field(default_factory=ExperimentalConfig)
"""实验性功能配置类"""
maim_message: MaimMessageConfig = Field(default_factory=MaimMessageConfig)
"""maim_message配置类"""
lpmm_knowledge: LPMMKnowledgeConfig = Field(default_factory=LPMMKnowledgeConfig)
"""LPMM知识库配置类"""
@dataclass
class APIAdapterConfig(ConfigBase):
"""API Adapter配置类"""
class ModelConfig(ConfigBase):
"""模型配置类"""
models: List[ModelInfo]
"""模型列表"""
models: list[ModelInfo] = Field(default_factory=list)
"""模型配置列表"""
model_task_config: ModelTaskConfig
model_task_config: ModelTaskConfig = Field(default_factory=ModelTaskConfig)
"""模型任务配置"""
api_providers: List[APIProvider] = field(default_factory=list)
api_providers: list[APIProvider] = Field(default_factory=list)
"""API提供商列表"""
def __post_init__(self):
def model_post_init(self, context: Any = None):
if not self.models:
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
if not self.api_providers:
@@ -388,78 +148,134 @@ class APIAdapterConfig(ConfigBase):
if len(model_names) != len(set(model_names)):
raise ValueError("模型名称存在重复,请检查配置文件。")
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models}
api_providers_dict = {provider.name: provider for provider in self.api_providers}
for model in self.models:
if not model.model_identifier:
raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
if not model.api_provider or model.api_provider not in self.api_providers_dict:
if not model.api_provider or model.api_provider not in api_providers_dict:
raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在")
def get_model_info(self, model_name: str) -> ModelInfo:
"""根据模型名称获取模型信息"""
if not model_name:
raise ValueError("模型名称不能为空")
if model_name not in self.models_dict:
raise KeyError(f"模型 '{model_name}' 不存在")
return self.models_dict[model_name]
def get_provider(self, provider_name: str) -> APIProvider:
"""根据提供商名称获取API提供商信息"""
if not provider_name:
raise ValueError("API提供商名称不能为空")
if provider_name not in self.api_providers_dict:
raise KeyError(f"API提供商 '{provider_name}' 不存在")
return self.api_providers_dict[provider_name]
return super().model_post_init(context)
def load_config(config_path: str) -> Config:
class ConfigManager:
"""总配置管理类"""
def __init__(self):
self.bot_config_path: Path = BOT_CONFIG_PATH
self.model_config_path: Path = MODEL_CONFIG_PATH
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
def initialize(self):
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
logger.info("正在品鉴配置文件...")
self.global_config: Config = self._load_global_config()
self.model_config: ModelConfig = self._load_model_config()
logger.info("非常的新鲜,非常的美味!")
def _load_global_config(self) -> Config:
config, updated = load_config_from_file(Config, self.bot_config_path, CONFIG_VERSION)
if updated:
sys.exit(0) # 先直接退出
return config
def _load_model_config(self) -> ModelConfig:
config, updated = load_config_from_file(ModelConfig, self.model_config_path, MODEL_CONFIG_VERSION, True)
if updated:
sys.exit(0) # 先直接退出
return config
def get_global_config(self) -> Config:
return self.global_config
def get_model_config(self) -> ModelConfig:
return self.model_config
def generate_new_config_file(config_class: type[T], config_path: Path, inner_config_version: str) -> None:
"""生成新的配置文件
:param config_class: 配置类
:param config_path: 配置文件路径
:param inner_config_version: 配置文件版本号
"""
加载配置文件
Args:
config_path: 配置文件路径
Returns:
Config对象
"""
# 读取配置文件
config = config_class()
write_config_to_file(config, config_path, inner_config_version)
def load_config_from_file(
config_class: type[T], config_path: Path, new_ver: str, override_repr: bool = False
) -> tuple[T, bool]:
attribute_data = AttributeData()
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建Config对象
old_ver: str = config_data["inner"]["version"] # type: ignore
config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理
config_data = config_data.unwrap() # 转换为普通字典,方便后续处理
try:
return Config.from_dict(config_data)
updated: bool = False
target_config = config_class.from_dict(attribute_data, config_data)
if compare_versions(old_ver, new_ver):
output_config_changes(attribute_data, logger, old_ver, new_ver, config_path.name)
write_config_to_file(target_config, config_path, new_ver, override_repr)
updated = True
return target_config, updated
except Exception as e:
logger.critical("配置文件解析失败")
logger.critical(f"配置文件{config_path.name}解析失败")
raise e
def api_ada_load_config(config_path: str) -> APIAdapterConfig:
def write_config_to_file(
config: ConfigBase, config_path: Path, inner_config_version: str, override_repr: bool = False
) -> None:
"""将配置写入文件
:param config: 配置对象
:param config_path: 配置文件路径
"""
加载API适配器配置文件
Args:
config_path: 配置文件路径
Returns:
APIAdapterConfig对象
"""
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建空TOMLDocument
full_config_data = tomlkit.document()
# 创建APIAdapterConfig对象
try:
return APIAdapterConfig.from_dict(config_data)
except Exception as e:
logger.critical("API适配器配置文件解析失败")
raise e
# 首先写入配置文件版本信息
version_table = tomlkit.table()
version_table.add("version", inner_config_version)
full_config_data.add("inner", version_table)
# 递归解析配置项为表格
for config_item_name, config_item in type(config).model_fields.items():
if not config_item.repr and not override_repr:
continue
if config_item_name in ["field_docs", "_validate_any", "suppress_any_warning"]:
continue
config_field = getattr(config, config_item_name)
if isinstance(config_field, ConfigBase):
full_config_data.add(
config_item_name, recursive_parse_item_to_table(config_field, override_repr=override_repr)
)
elif isinstance(config_field, list):
aot = tomlkit.aot()
for item in config_field:
if not isinstance(item, ConfigBase):
raise TypeError("配置写入只支持ConfigBase子类")
aot.append(recursive_parse_item_to_table(item, override_repr=override_repr))
full_config_data.add(config_item_name, aot)
else:
raise TypeError("配置写入只支持ConfigBase子类")
# 备份旧文件
if config_path.exists():
backup_root = config_path.parent / "old"
backup_root.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = backup_root / f"{config_path.stem}_{timestamp}.toml"
config_path.replace(backup_path)
# 写入文件
with open(config_path, "w", encoding="utf-8") as f:
tomlkit.dump(full_config_data, f)
# 获取配置文件路径
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
update_config()
update_model_config()
logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
logger.info("非常的新鲜,非常的美味!")
# generate_new_config_file(Config, BOT_CONFIG_PATH, CONFIG_VERSION)
config_manager = ConfigManager()
config_manager.initialize()
# global_config = config_manager.get_global_config()