fix: 修复保存toml时偶发的空行累计bug和注释丢失问题

This commit is contained in:
Ronifue
2025-12-02 16:00:05 +08:00
parent fa0211c87c
commit 2b7559b8cc
2 changed files with 67 additions and 43 deletions

View File

@@ -4,6 +4,7 @@ TOML 工具函数
提供 TOML 文件的格式化保存功能,确保数组等元素以美观的多行格式输出。 提供 TOML 文件的格式化保存功能,确保数组等元素以美观的多行格式输出。
""" """
import re
from typing import Any from typing import Any
import tomlkit import tomlkit
from tomlkit.items import AoT, Table, Array from tomlkit.items import AoT, Table, Array
@@ -54,14 +55,71 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
return obj return obj
def save_toml_with_format(data: Any, file_path: str, multiline_threshold: int = 1) -> None: def _update_toml_doc(target: Any, source: Any) -> None:
"""格式化 TOML 数据并保存到文件""" """
递归合并字典,将 source 的值更新到 target 中,保留 target 的注释和格式。
- 已存在的键:更新值(递归处理嵌套字典)
- 新增的键:添加到 target
- 跳过 version 字段
"""
if isinstance(source, list) or not isinstance(source, dict) or not isinstance(target, dict):
return
for key, value in source.items():
if key == "version":
continue
if key in target:
# 已存在的键:递归更新或直接赋值
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, dict):
_update_toml_doc(target_value, value)
else:
try:
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
target[key] = value
else:
# 新增的键:添加到 target
try:
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
target[key] = value
def save_toml_with_format(
data: Any, file_path: str, multiline_threshold: int = 1, preserve_comments: bool = True
) -> None:
"""
格式化 TOML 数据并保存到文件。
Args:
data: 要保存的数据dict 或 tomlkit 文档)
file_path: 保存路径
multiline_threshold: 数组多行格式化阈值,-1 表示不格式化
preserve_comments: 是否保留原文件的注释和格式(默认 True
若为 True 且文件已存在且 data 不是 tomlkit 文档,会先读取原文件,再将 data 合并进去
"""
import os
from tomlkit import TOMLDocument
# 如果需要保留注释、文件存在、且 data 不是已有的 tomlkit 文档,先读取原文件再合并
if preserve_comments and os.path.exists(file_path) and not isinstance(data, TOMLDocument):
with open(file_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f)
_update_toml_doc(doc, data)
data = doc
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
output = re.sub(r'\n{3,}', '\n\n', output)
with open(file_path, "w", encoding="utf-8") as f: with open(file_path, "w", encoding="utf-8") as f:
tomlkit.dump(formatted, f) f.write(output)
def format_toml_string(data: Any, multiline_threshold: int = 1) -> str: def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
"""格式化 TOML 数据并返回字符串""" """格式化 TOML 数据并返回字符串"""
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
return tomlkit.dumps(formatted) output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
return re.sub(r'\n{3,}', '\n\n', output)

View File

@@ -8,7 +8,7 @@ from fastapi import APIRouter, HTTPException, Body
from typing import Any, Annotated from typing import Any, Annotated
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.toml_utils import save_toml_with_format from src.common.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.official_configs import ( from src.config.official_configs import (
BotConfig, BotConfig,
@@ -51,40 +51,6 @@ PathBody = Annotated[dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"]) router = APIRouter(prefix="/config", tags=["config"])
# ===== 辅助函数 =====
def _update_dict_preserve_comments(target: Any, source: Any) -> None:
"""
递归合并字典,保留 target 中的注释和格式
将 source 的值更新到 target 中(仅更新已存在的键)
Args:
target: 目标字典tomlkit 对象,包含注释)
source: 源字典(普通 dict 或 list
"""
# 如果 source 是列表,直接替换(数组表没有注释保留的意义)
if isinstance(source, list):
return # 调用者需要直接赋值
# 如果都是字典,递归合并
if isinstance(source, dict) and isinstance(target, dict):
for key, value in source.items():
if key == "version":
continue # 跳过版本号
if key in target:
target_value = target[key]
# 递归处理嵌套字典
if isinstance(value, dict) and isinstance(target_value, dict):
_update_dict_preserve_comments(target_value, value)
else:
# 使用 tomlkit.item 保持类型
try:
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
target[key] = value
# ===== 架构获取接口 ===== # ===== 架构获取接口 =====
@@ -238,7 +204,7 @@ async def update_bot_config(config_data: ConfigBody):
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件(格式化数组为多行 # 保存配置文件(自动保留注释和格式)
config_path = os.path.join(CONFIG_DIR, "bot_config.toml") config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
save_toml_with_format(config_data, config_path) save_toml_with_format(config_data, config_path)
@@ -261,7 +227,7 @@ async def update_model_config(config_data: ConfigBody):
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件(格式化数组为多行 # 保存配置文件(自动保留注释和格式)
config_path = os.path.join(CONFIG_DIR, "model_config.toml") config_path = os.path.join(CONFIG_DIR, "model_config.toml")
save_toml_with_format(config_data, config_path) save_toml_with_format(config_data, config_path)
@@ -300,7 +266,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
config_data[section_name] = section_data config_data[section_name] = section_data
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict): elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
# 字典递归合并 # 字典递归合并
_update_dict_preserve_comments(config_data[section_name], section_data) _update_toml_doc(config_data[section_name], section_data)
else: else:
# 其他类型直接替换 # 其他类型直接替换
config_data[section_name] = section_data config_data[section_name] = section_data
@@ -398,7 +364,7 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
config_data[section_name] = section_data config_data[section_name] = section_data
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict): elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
# 字典递归合并 # 字典递归合并
_update_dict_preserve_comments(config_data[section_name], section_data) _update_toml_doc(config_data[section_name], section_data)
else: else:
# 其他类型直接替换 # 其他类型直接替换
config_data[section_name] = section_data config_data[section_name] = section_data