ConfigBase and Test

This commit is contained in:
UnCLAS-Prommer
2026-01-12 18:20:03 +08:00
parent 3ab0a2c737
commit 207dc460cb
6 changed files with 1076 additions and 141 deletions

0
src/config/__init__.py Normal file
View File

View File

@@ -1,169 +1,228 @@
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Union
import ast
import inspect
import types
T = TypeVar("T", bound="ConfigBase")
from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field
from typing import Union, get_args, get_origin, Tuple, Any, List, Dict, Set
TOML_DICT_TYPE = {
int,
float,
str,
bool,
list,
dict,
}
__all__ = ["ConfigBase", "Field"]
from src.common.logger import get_logger
logger = get_logger("ConfigBase")
@dataclass
class ConfigBase:
"""配置类的基类"""
class AttrDocBase:
"""解析字段说明的基类"""
field_docs: dict[str, str] = {}
def __post_init__(self):
self.field_docs = self._get_field_docs() # 全局仅获取一次并保留
@classmethod
def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
"""从字典加载配置字段"""
if not isinstance(data, dict):
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
def _get_field_docs(cls) -> dict[str, str]:
"""
获取字段的说明字符串
init_args: dict[str, Any] = {}
for f in fields(cls):
field_name = f.name
if field_name.startswith("_"):
# 跳过以 _ 开头的字段
continue
if field_name not in data:
if f.default is not MISSING or f.default_factory is not MISSING:
# 跳过未提供且有默认值/默认构造方法的字段
continue
else:
raise ValueError(f"Missing required field: '{field_name}'")
value = data[field_name]
field_type = f.type
try:
init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
except TypeError as e:
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
except Exception as e:
raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
return cls(**init_args)
:param cls: 配置类
:return: 字段说明字典,键为字段名,值为说明字符串
"""
# 获取类的源代码文本
class_source = cls._get_class_source()
# 解析源代码,找到对应的类定义节点
class_node = cls._find_class_node(class_source)
# 从类定义节点中提取字段文档
return cls._extract_field_docs(class_node)
@classmethod
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
"""
转换字段值为指定类型
def _get_class_source(cls) -> str:
"""获取类定义所在文件的完整源代码"""
# 使用 inspect 模块获取类定义所在的文件路径
class_file = inspect.getfile(cls)
# 读取文件内容并以 UTF-8 编码返回
return Path(class_file).read_text(encoding="utf-8")
1. 对于嵌套的 dataclass递归调用相应的 from_dict 方法
2. 对于泛型集合类型list, set, tuple递归转换每个元素
3. 对于基础类型int, str, float, bool直接转换
4. 对于其他类型,尝试直接转换,如果失败则抛出异常
"""
@classmethod
def _find_class_node(cls, class_source: str) -> ast.ClassDef:
"""在源代码中找到类定义的AST节点"""
tree = ast.parse(class_source)
# 遍历 AST 中的所有节点
for node in ast.walk(tree):
# 查找类定义节点,且类名与当前类名匹配
if isinstance(node, ast.ClassDef) and node.name == cls.__name__:
"""类名匹配,返回节点"""
return node
# 如果没有找到匹配的类定义,抛出异常
raise AttributeError(f"Class {cls.__name__} not found in source.")
# 如果是嵌套的 dataclass递归调用 from_dict 方法
if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
return field_type.from_dict(value)
@classmethod
def _extract_field_docs(cls, class_node: ast.ClassDef) -> dict[str, str]:
"""从类的 AST 节点中提取字段的文档字符串"""
doc_dict: dict[str, str] = {}
class_body = class_node.body # 类属性节点列表
for i in range(len(class_body)):
body_item = class_body[i]
# 处理泛型集合类型list, set, tuple
field_origin_type = get_origin(field_type)
field_type_args = get_args(field_type)
# 检查是否有非 model_post_init 的方法定义,如果有则抛出异常
# 这个限制确保 AttrDocBase 子类只包含字段定义和 model_post_init 方法
if isinstance(body_item, ast.FunctionDef) and body_item.name != "model_post_init":
"""检验ConfigBase子类中是否有除model_post_init以外的方法规范配置类的定义"""
raise AttributeError(
f"Methods are not allowed in AttrDocBase subclasses except model_post_init, found {str(body_item.name)}"
) from None
if field_origin_type in {list, set, tuple}:
# 检查提供的value是否为list
if not isinstance(value, list):
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
# 检查当前语句是否为带注解的赋值语句 (类型注解的字段定义)
# 并且下一个语句存在
if (
i + 1 < len(class_body)
and isinstance(body_item, ast.AnnAssign) # 例如: field_name: int = 10
and isinstance(body_item.target, ast.Name) # 目标是一个简单的名称
):
"""字段定义后紧跟的字符串表达式即为字段说明"""
expr_item = class_body[i + 1]
if field_origin_type is list:
# 如果列表元素类型是ConfigBase的子类则对每个元素调用from_dict
# 检查下一个语句是否为字符串常量表达式 (文档字符串)
if (
field_type_args
and isinstance(field_type_args[0], type)
and issubclass(field_type_args[0], ConfigBase)
isinstance(expr_item, ast.Expr) # 表达式语句
and isinstance(expr_item.value, ast.Constant) # 常量值
and isinstance(expr_item.value.value, str) # 字符串常量
):
return [field_type_args[0].from_dict(item) for item in value]
return [cls._convert_field(item, field_type_args[0]) for item in value]
elif field_origin_type is set:
return {cls._convert_field(item, field_type_args[0]) for item in value}
elif field_origin_type is tuple:
# 检查提供的value长度是否与类型参数一致
if len(value) != len(field_type_args):
raise TypeError(
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
)
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False))
doc_string = expr_item.value.value.strip() # 获取说明字符串并去除首尾空白
processed_doc_lines = [line.strip() for line in doc_string.splitlines()] # 多行处理
if field_origin_type is dict:
# 检查提供的value是否为dict
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
# 删除开头的所有空行
while processed_doc_lines and not processed_doc_lines[0]:
processed_doc_lines.pop(0)
# 检查字典的键值类型
if len(field_type_args) != 2:
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
key_type, value_type = field_type_args
# 删除结尾的所有空行
while processed_doc_lines and not processed_doc_lines[-1]:
processed_doc_lines.pop()
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
# 将处理后的行重新组合,并存入字典
# 键是字段名,值是清理后的文档字符串
doc_dict[body_item.target.id] = "\n".join(processed_doc_lines)
# 处理 Union/Optional 类型(包括 float | None 这种 Python 3.10+ 语法)
# 注意:
# - Optional[float] 等价于 Union[float, None]get_origin() 返回 typing.Union
# - float | None 是 types.UnionTypeget_origin() 返回 None
is_union_type = (
field_origin_type is Union # typing.Optional / typing.Union
or isinstance(field_type, types.UnionType) # Python 3.10+ 的 | 语法
)
if is_union_type:
union_args = field_type_args if field_type_args else get_args(field_type)
# 安全检查:只允许 T | None 形式的 Optional 类型,禁止 float | str 这种多类型 Union
non_none_types = [arg for arg in union_args if arg is not type(None)]
if len(non_none_types) > 1:
return doc_dict
class ConfigBase(BaseModel, AttrDocBase):
model_config = ConfigDict(validate_assignment=True, extra="forbid")
_validate_any: bool = True # 是否验证 Any 类型的使用,默认为 True
def _discourage_any_usage(self, field_name: str) -> None:
"""警告使用 Any 类型的字段可被suppress"""
if self._validate_any:
raise TypeError(f"字段'{field_name}'中不允许使用 Any 类型注解")
else:
logger.warning(f"字段'{field_name}'中使用了 Any 类型注解,建议避免使用。")
def _get_real_type(self, annotation: type[Any] | Any | None):
"""获取真实类型,处理 dict 等没有参数的情况"""
origin_type = get_origin(annotation)
args_type = get_args(annotation)
if origin_type is None:
origin_type = annotation
args_type = ()
return origin_type, args_type
def _validate_union_type(self, annotation: type[Any] | Any | None, field_name: str):
"""
验证 Union 类型的使用可被suppress
明确禁止 Union / PEP 604 的 | 表示法
允许 Optional[T](即 Union[T, None]"""
origin, args = self._get_real_type(annotation)
other = annotation
if origin in (Union, types.UnionType):
if len(args) != 2 or all(a is not type(None) for a in args):
raise TypeError(f"'{type(self).__name__}'字段'{field_name}'中不允许使用 Union 类型注解")
# 将注解替换为 Optional 的内部类型,继续后续校验(允许原子或容器类型)
other = args[0] if args[1] is type(None) else args[1]
origin, args = self._get_real_type(other)
if origin in (Union, types.UnionType):
raise TypeError(f"'{type(self).__name__}'字段'{field_name}'中不允许嵌套使用 Union/Optional 类型注解")
return origin, args, other
def _validate_list_set_type(self, annotation: Any | None, field_name: str):
"""验证 list/set 类型的使用"""
origin, args = self._get_real_type(annotation)
if origin in (list, set, List, Set):
if len(args) != 1:
raise TypeError(
f"配置字段不支持多类型 Union如 float | str只支持 Optional 类型(如 float | None"
f"当前类型: {field_type}"
f"'{type(self).__name__}'字段'{field_name}'中必须指定且仅指定一个类型参数,使用了: {annotation}"
)
elem = args[0]
if elem is Any:
self._discourage_any_usage(field_name)
if get_origin(elem) is not None:
raise TypeError(
f"'{type(self).__name__}'字段'{field_name}'中不允许嵌套泛型类型: {annotation},请使用自定义类代替。"
)
# 如果值是 None 且 None 在 Union 中,直接返回
if value is None and type(None) in union_args:
return None
# 尝试转换为非 None 的类型
for arg in union_args:
if arg is not type(None):
try:
return cls._convert_field(value, arg)
except (ValueError, TypeError):
continue
# 如果所有类型都转换失败,抛出异常
raise TypeError("Cannot convert value to any type in Union")
# 处理基础类型,例如 int, str 等
if field_origin_type is type(None) and value is None: # 处理Optional类型
return None
# 处理Literal类型
if field_origin_type is Literal or get_origin(field_type) is Literal:
# 获取Literal的允许值
allowed_values = get_args(field_type)
if value in allowed_values:
return value
def _validate_dict_type(self, annotation: Any | None, field_name: str):
"""验证 dict 类型的使用"""
_, args = self._get_real_type(annotation)
if len(args) != 2:
raise TypeError(f"'{type(self).__name__}'字段'{field_name}'中必须指定键和值的类型参数: {annotation}")
_, val_t = args
if val_t is Any:
self._discourage_any_usage(field_name)
if get_origin(val_t):
origin_type = get_origin(val_t)
if origin_type is None:
return
origin_type, _, anno = self._validate_union_type(val_t, field_name)
if origin_type in (list, set, List, Set):
self._validate_list_set_type(anno, field_name)
elif origin_type is Any:
self._discourage_any_usage(field_name)
else:
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
raise TypeError(
f"'{type(self).__name__}'字段'{field_name}'中不允许嵌套泛型类型: {annotation},请使用自定义类代替。"
)
if field_type is Any or isinstance(value, field_type):
return value
def model_post_init(self, context: Any = None) -> None:
"""验证字段的类型注解
# 其他类型,尝试直接转换
try:
return field_type(value)
except (ValueError, TypeError) as e:
raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
规则:
- 允许原子注解(非泛型,且不为 Any
- 允许 list[T], set[T],其中 T 为原子注解
- 允许 dict[K, V],其中 K、V 为原子注解
- 禁止使用 Union不包含 Optional和 tuple及 Tuple
- 禁止嵌套泛型(例如 list[list[int]])和使用 Any
"""
for field_name, field_info in type(self).model_fields.items():
annotation = field_info.annotation
origin_type, _ = self._get_real_type(annotation)
# 处理 Union (含Optional) 类型
origin_type, _, annotation = self._validate_union_type(annotation, field_name)
# 禁止 tuple / Tuple
if origin_type in (tuple, Tuple):
raise TypeError(f"'{type(self).__name__}'字段'{field_name}'中不允许使用 Tuple 类型注解")
# 处理 Any 类型
if origin_type is Any:
self._discourage_any_usage(field_name)
def __str__(self):
"""返回配置类的字符串表示"""
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"
# 非泛型注解视为原子类型,允许
if origin_type in (int, float, str, bool, complex, bytes, type(None), Any):
continue
# 允许嵌套的ConfigBase自定义类
if inspect.isclass(origin_type) and issubclass(origin_type, ConfigBase): # type: ignore
continue
# 只允许 list, set, dict 三类泛型
if origin_type not in (list, set, dict, List, Set, Dict):
raise TypeError(
f"仅允许使用list, set, dict三种泛型类型注解'{type(self).__name__}'字段'{field_name}'中使用了: {annotation}"
)
# list/set: 必须指定且仅指定一个类型参数,且参数为原子类型
if origin_type in (list, set, List, Set):
self._validate_list_set_type(annotation, field_name)
# dict: 必须指定两个类型参数,且 key/value 为原子类型或者set/list类型
if origin_type in (dict, Dict):
self._validate_dict_type(annotation, field_name)
super().model_post_init(context)
super().__post_init__() # 获取字段说明