解决ConfigBase问题,更严格测试,实际测试

This commit is contained in:
UnCLAS-Prommer
2026-01-15 17:05:23 +08:00
parent fd46d8a302
commit 9186d14100
11 changed files with 871 additions and 1139 deletions

View File

@@ -2,27 +2,35 @@ import ast
import inspect
import types
from dataclasses import dataclass, field
from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field
from typing import Union, get_args, get_origin, Tuple, Any, List, Dict, Set
from typing import Union, get_args, get_origin, Tuple, Any, List, Dict, Set, Literal
__all__ = ["ConfigBase", "Field"]
__all__ = ["ConfigBase", "Field", "AttributeData"]
from src.common.logger import get_logger
logger = get_logger("ConfigBase")
@dataclass
class AttributeData:
missing_attributes: list[str] = field(default_factory=list)
"""缺失的属性列表"""
redundant_attributes: list[str] = field(default_factory=list)
"""多余的属性列表"""
class AttrDocBase:
"""解析字段说明的基类"""
field_docs: dict[str, str] = {}
def __post_init__(self):
self.field_docs = self._get_field_docs() # 全局仅获取一次并保留
def __post_init__(self, allow_extra_methods: bool = False):
self.field_docs = self._get_field_docs(allow_extra_methods) # 全局仅获取一次并保留
@classmethod
def _get_field_docs(cls) -> dict[str, str]:
def _get_field_docs(self, allow_extra_methods: bool) -> dict[str, str]:
"""
获取字段的说明字符串
@@ -30,11 +38,11 @@ class AttrDocBase:
:return: 字段说明字典,键为字段名,值为说明字符串
"""
# 获取类的源代码文本
class_source = cls._get_class_source()
class_source = self._get_class_source()
# 解析源代码,找到对应的类定义节点
class_node = cls._find_class_node(class_source)
class_node = self._find_class_node(class_source)
# 从类定义节点中提取字段文档
return cls._extract_field_docs(class_node)
return self._extract_field_docs(class_node, allow_extra_methods)
@classmethod
def _get_class_source(cls) -> str:
@@ -57,21 +65,22 @@ class AttrDocBase:
# 如果没有找到匹配的类定义,抛出异常
raise AttributeError(f"Class {cls.__name__} not found in source.")
@classmethod
def _extract_field_docs(cls, class_node: ast.ClassDef) -> dict[str, str]:
def _extract_field_docs(self, class_node: ast.ClassDef, allow_extra_methods: bool) -> dict[str, str]:
"""从类的 AST 节点中提取字段的文档字符串"""
# sourcery skip: merge-nested-ifs
doc_dict: dict[str, str] = {}
class_body = class_node.body # 类属性节点列表
for i in range(len(class_body)):
body_item = class_body[i]
# 检查是否有非 model_post_init 的方法定义,如果有则抛出异常
# 这个限制确保 AttrDocBase 子类只包含字段定义和 model_post_init 方法
if isinstance(body_item, ast.FunctionDef) and body_item.name != "model_post_init":
"""检验ConfigBase子类中是否有除model_post_init以外的方法规范配置类的定义"""
raise AttributeError(
f"Methods are not allowed in AttrDocBase subclasses except model_post_init, found {str(body_item.name)}"
) from None
if not allow_extra_methods:
# 检查是否有非 model_post_init 方法定义,如果有则抛出异常
# 这个限制确保 AttrDocBase 子类只包含字段定义和 model_post_init 方法
if isinstance(body_item, ast.FunctionDef) and body_item.name != "model_post_init":
"""检验ConfigBase子类中是否有除model_post_init以外的方法规范配置类的定义"""
raise AttributeError(
f"Methods are not allowed in AttrDocBase subclasses except model_post_init, found {str(body_item.name)}"
) from None
# 检查当前语句是否为带注解的赋值语句 (类型注解的字段定义)
# 并且下一个语句存在
@@ -110,13 +119,53 @@ class AttrDocBase:
class ConfigBase(BaseModel, AttrDocBase):
model_config = ConfigDict(validate_assignment=True, extra="forbid")
_validate_any: bool = True # 是否验证 Any 类型的使用,默认为 True
suppress_any_warning: bool = False # 是否抑制 Any 类型使用的警告,默认为 False仅仅在_validate_any 为 False 时生效
@classmethod
def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]):
"""从字典创建配置对象,并收集缺失和多余的属性信息"""
class_fields = set(cls.model_fields.keys())
class_fields.remove("field_docs") # 忽略 field_docs 字段
if "_validate_any" in class_fields:
class_fields.remove("_validate_any") # 忽略 _validate_any 字段
if "suppress_any_warning" in class_fields:
class_fields.remove("suppress_any_warning") # 忽略 suppress_any_warning 字
for class_field in class_fields:
if class_field not in data:
attribute_data.missing_attributes.append(class_field) # 记录缺失的属性
cleaned_data_list: list[str] = []
for data_field in data:
if data_field not in class_fields:
cleaned_data_list.append(data_field)
attribute_data.redundant_attributes.append(data_field) # 记录多余的属性
for redundant_field in cleaned_data_list:
data.pop(redundant_field) # 移除多余的属性
# 对于是ConfigBase子类的字段递归调用from_dict
class_field_infos = dict(cls.model_fields.items())
for field_data in data:
if info := class_field_infos.get(field_data):
field_type = info.annotation
if inspect.isclass(field_type) and issubclass(field_type, ConfigBase):
data[field_data] = field_type.from_dict(attribute_data, data[field_data])
if get_origin(field_type) in {list, List}:
elem_type = get_args(field_type)[0]
if inspect.isclass(elem_type) and issubclass(elem_type, ConfigBase):
data[field_data] = [elem_type.from_dict(attribute_data, item) for item in data[field_data]]
# 没有set因为ConfigBase is not Hashable
if get_origin(field_type) in {dict, Dict}:
val_type = get_args(field_type)[1]
if inspect.isclass(val_type) and issubclass(val_type, ConfigBase):
data[field_data] = {
key: val_type.from_dict(attribute_data, val) for key, val in data[field_data].items()
}
return cls(**data)
def _discourage_any_usage(self, field_name: str) -> None:
"""警告使用 Any 类型的字段可被suppress"""
if self._validate_any:
raise TypeError(f"字段'{field_name}'中不允许使用 Any 类型注解")
else:
logger.warning(f"字段'{field_name}'中使用了 Any 类型注解,建议避免使用。")
if not self.suppress_any_warning:
logger.warning(f"字段'{field_name}'中使用了 Any 类型注解,建议使用更具体的类型注解以提高类型安全性")
def _get_real_type(self, annotation: type[Any] | Any | None):
"""获取真实类型,处理 dict 等没有参数的情况"""
@@ -157,10 +206,14 @@ class ConfigBase(BaseModel, AttrDocBase):
elem = args[0]
if elem is Any:
self._discourage_any_usage(field_name)
if get_origin(elem) is not None:
elif get_origin(elem) is not None:
raise TypeError(
f"'{type(self).__name__}'字段'{field_name}'中不允许嵌套泛型类型: {annotation},请使用自定义类代替。"
)
elif inspect.isclass(elem) and issubclass(elem, ConfigBase) and origin in (set, Set):
raise TypeError(
f"'{type(self).__name__}'字段'{field_name}'中不允许使用 ConfigBase 子类作为 set 元素类型: {annotation}。ConfigBase is not Hashable。"
)
def _validate_dict_type(self, annotation: Any | None, field_name: str):
"""验证 dict 类型的使用"""
@@ -215,7 +268,7 @@ class ConfigBase(BaseModel, AttrDocBase):
if inspect.isclass(origin_type) and issubclass(origin_type, ConfigBase): # type: ignore
continue
# 只允许 list, set, dict 三类泛型
if origin_type not in (list, set, dict, List, Set, Dict):
if origin_type not in (list, set, dict, List, Set, Dict, Literal):
raise TypeError(
f"仅允许使用list, set, dict三种泛型类型注解'{type(self).__name__}'字段'{field_name}'中使用了: {annotation}"
)