167 lines
7.3 KiB
Python
167 lines
7.3 KiB
Python
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Set, Tuple, Union, get_args, get_origin
|
|
|
|
from pydantic.fields import FieldInfo
|
|
from tomlkit import items
|
|
|
|
import tomlkit
|
|
import types
|
|
|
|
from .config_base import ConfigBase
|
|
from src.common.i18n import t
|
|
|
|
if TYPE_CHECKING:
|
|
from .config_base import AttributeData
|
|
|
|
|
|
def recursive_parse_item_to_table(
|
|
config: ConfigBase, is_inline_table: bool = False, override_repr: bool = False
|
|
) -> items.Table | items.InlineTable:
|
|
# sourcery skip: merge-else-if-into-elif, reintroduce-else
|
|
"""递归解析配置项为表格"""
|
|
config_table = tomlkit.table()
|
|
if is_inline_table:
|
|
config_table = tomlkit.inline_table()
|
|
for config_item_name, config_item_info in type(config).model_fields.items():
|
|
if not config_item_info.repr and not override_repr:
|
|
continue
|
|
value = getattr(config, config_item_name)
|
|
if config_item_name in ["field_docs", "_validate_any", "suppress_any_warning"]:
|
|
continue
|
|
if value is None:
|
|
continue
|
|
if isinstance(value, ConfigBase):
|
|
config_table.add(
|
|
config_item_name,
|
|
recursive_parse_item_to_table(
|
|
value,
|
|
is_inline_table=is_inline_table,
|
|
override_repr=override_repr,
|
|
),
|
|
)
|
|
else:
|
|
config_table.add(
|
|
config_item_name, convert_field(config_item_name, config_item_info, value, override_repr=override_repr)
|
|
)
|
|
if not is_inline_table:
|
|
config_table = comment_doc_string(config, config_item_name, config_table)
|
|
return config_table
|
|
|
|
|
|
def comment_doc_string(
|
|
config: ConfigBase, field_name: str, toml_table: items.Table | items.InlineTable
|
|
) -> items.Table | items.InlineTable:
|
|
"""将配置类中的注释加入toml表格中"""
|
|
if doc_string := config.field_docs.get(field_name, ""):
|
|
doc_string_splitted = doc_string.splitlines()
|
|
if len(doc_string_splitted) == 1 and not doc_string_splitted[0].strip().startswith("_wrap_"):
|
|
if isinstance(toml_table[field_name], bool):
|
|
# tomlkit 故意设计的行为,布尔值不能直接添加注释
|
|
value = toml_table[field_name]
|
|
item = tomlkit.item(value)
|
|
item.comment(doc_string_splitted[0])
|
|
toml_table[field_name] = item
|
|
else:
|
|
toml_table[field_name].comment(doc_string_splitted[0])
|
|
else:
|
|
if doc_string_splitted[0].strip().startswith("_wrap_"):
|
|
doc_string_splitted[0] = doc_string_splitted[0].replace("_wrap_", "", 1).strip()
|
|
for line in doc_string_splitted:
|
|
toml_table.add(tomlkit.comment(line))
|
|
toml_table.add(tomlkit.nl())
|
|
return toml_table
|
|
|
|
|
|
def convert_field(config_item_name: str, config_item_info: FieldInfo, value: Any, override_repr: bool = False):
|
|
# sourcery skip: extract-method
|
|
"""将非可直接表达类转换为toml可表达类"""
|
|
field_type_origin = get_origin(config_item_info.annotation)
|
|
field_type_args = get_args(config_item_info.annotation)
|
|
|
|
# 处理 Optional[T] / Union[T, None] / PEP604 的 T | None
|
|
if field_type_origin in (Union, types.UnionType):
|
|
# 只处理 "某类型 + None" 的情况,等价于 Optional[T]
|
|
non_none_args = tuple(a for a in field_type_args if a is not type(None))
|
|
if len(non_none_args) == 1:
|
|
inner = non_none_args[0]
|
|
inner_origin = get_origin(inner)
|
|
inner_args = get_args(inner)
|
|
# Optional[基础类型] 直接按基础类型处理
|
|
if inner_origin is None and isinstance(inner, type) and inner in (int, float, str, bool):
|
|
return value
|
|
# Optional[Literal[...]] 的情况
|
|
if inner_origin is Literal:
|
|
if value not in inner_args:
|
|
raise ValueError(f"Value {value} not in Literal options {inner_args} for {config_item_name}")
|
|
return value
|
|
# 其它 Optional[...],后续按去掉 None 的泛型再走一遍逻辑
|
|
field_type_origin = inner_origin
|
|
field_type_args = inner_args
|
|
else:
|
|
# 复杂 Union 不支持写回,只能报错
|
|
raise TypeError(f"Unsupported Union type for {config_item_name}: {config_item_info.annotation}")
|
|
|
|
if not field_type_origin: # 基础类型 int,bool,str,float 等直接添加
|
|
return value
|
|
elif field_type_origin in {list, set, List, Set}:
|
|
toml_list = tomlkit.array()
|
|
if field_type_args and isinstance(field_type_args[0], type) and issubclass(field_type_args[0], ConfigBase):
|
|
for item in value:
|
|
toml_list.append(recursive_parse_item_to_table(item, True, override_repr))
|
|
else:
|
|
for item in value:
|
|
toml_list.append(item)
|
|
return toml_list
|
|
elif field_type_origin in (tuple, Tuple):
|
|
toml_list = tomlkit.array()
|
|
for field_arg, item in zip(field_type_args, value, strict=True):
|
|
if isinstance(field_arg, type) and issubclass(field_arg, ConfigBase):
|
|
toml_list.append(recursive_parse_item_to_table(item, True, override_repr))
|
|
else:
|
|
toml_list.append(item)
|
|
return toml_list
|
|
elif field_type_origin in (dict, Dict):
|
|
if len(field_type_args) != 2:
|
|
raise TypeError(f"Expected a dictionary with two type arguments for {config_item_name}")
|
|
toml_sub_table = tomlkit.inline_table()
|
|
key_type, value_type = field_type_args
|
|
if key_type is not str:
|
|
raise TypeError(f"TOML only supports string keys for tables, got {key_type} for {config_item_name}")
|
|
for k, v in value.items():
|
|
if isinstance(value_type, type) and issubclass(value_type, ConfigBase):
|
|
toml_sub_table.add(k, recursive_parse_item_to_table(v, True, override_repr))
|
|
else:
|
|
toml_sub_table.add(k, v)
|
|
return toml_sub_table
|
|
elif field_type_origin is Literal:
|
|
if value not in field_type_args:
|
|
raise ValueError(f"Value {value} not in Literal options {field_type_args} for {config_item_name}")
|
|
return value
|
|
else:
|
|
raise TypeError(f"Unsupported field type for {config_item_name}: {config_item_info.annotation}")
|
|
|
|
|
|
def output_config_changes(attr_data: "AttributeData", logger, old_ver: str, new_ver: str, file_name: str):
|
|
"""输出配置变更信息"""
|
|
logger.info(t("config.change_summary_header"))
|
|
logger.info(t("config.added_count", count=len(attr_data.missing_attributes)))
|
|
for attr in attr_data.missing_attributes:
|
|
logger.info(t("config.added_item", attribute=attr))
|
|
logger.info(t("config.removed_count", count=len(attr_data.redundant_attributes)))
|
|
for attr in attr_data.redundant_attributes:
|
|
logger.warning(t("config.removed_item", attribute=attr))
|
|
logger.info(
|
|
t(
|
|
"config.file_updated",
|
|
file_name=file_name,
|
|
new_version=new_ver,
|
|
old_version=old_ver,
|
|
)
|
|
)
|
|
|
|
|
|
def compare_versions(old_ver: str, new_ver: str) -> bool:
|
|
"""比较版本号,返回是否有更新"""
|
|
old_parts = [int(part) for part in old_ver.split(".")]
|
|
new_parts = [int(part) for part in new_ver.split(".")]
|
|
return new_parts > old_parts
|