WebUI 后端类型注解补全,使用全 typing 库类型注解

This commit is contained in:
DrSmoothl
2026-03-16 13:09:12 +08:00
parent df088205dd
commit e7ac064a80
47 changed files with 572 additions and 365 deletions

View File

@@ -5,51 +5,51 @@
import copy
import os
from pathlib import Path
from typing import Any, Annotated, Optional
from typing import Annotated, Any, Dict, List, Tuple
import tomlkit
from fastapi import APIRouter, Body, Depends, HTTPException
from src.common.logger import get_logger
from src.webui.dependencies import require_auth
from src.webui.utils.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, ModelConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.config import CONFIG_DIR, PROJECT_ROOT, Config, ModelConfig
from src.config.config_base import AttributeData
from src.config.model_configs import (
APIProvider,
ModelInfo,
ModelTaskConfig,
)
from src.config.official_configs import (
BotConfig,
PersonalityConfig,
RelationshipConfig,
ChatConfig,
MessageReceiveConfig,
ChineseTypoConfig,
DebugConfig,
EmojiConfig,
ExperimentalConfig,
ExpressionConfig,
KeywordReactionConfig,
ChineseTypoConfig,
LPMMKnowledgeConfig,
MaimMessageConfig,
MemoryConfig,
MessageReceiveConfig,
PersonalityConfig,
RelationshipConfig,
ResponsePostProcessConfig,
ResponseSplitterConfig,
TelemetryConfig,
ExperimentalConfig,
MaimMessageConfig,
LPMMKnowledgeConfig,
ToolConfig,
MemoryConfig,
DebugConfig,
VoiceConfig,
)
from src.config.model_configs import (
ModelTaskConfig,
ModelInfo,
APIProvider,
)
from src.webui.config_schema import ConfigSchemaGenerator
from src.webui.dependencies import require_auth
from src.webui.utils.toml_utils import _update_toml_doc, save_toml_with_format
logger = get_logger("webui")
# 模块级别的类型别名(解决 B008 ruff 错误)
ConfigBody = Annotated[dict[str, Any], Body()]
ConfigBody = Annotated[Dict[str, Any], Body()]
SectionBody = Annotated[Any, Body()]
RawContentBody = Annotated[str, Body(embed=True)]
PathBody = Annotated[dict[str, str], Body()]
PathBody = Annotated[Dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"], dependencies=[Depends(require_auth)])
@@ -61,6 +61,8 @@ def _toml_to_plain_dict(obj: Any) -> Any:
if isinstance(obj, list):
return [_toml_to_plain_dict(v) for v in obj]
return obj
# ===== 架构获取接口 =====
@@ -385,8 +387,12 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
if section_name == "api_providers" and "api_provider" in str(e):
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
models = config_data.get("models", [])
orphaned_models = [
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
orphaned_models: List[str] = [
str(model_name)
for m in models
if isinstance(m, dict)
and m.get("api_provider") not in provider_names
and (model_name := m.get("name")) is not None
]
if orphaned_models:
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
@@ -421,7 +427,7 @@ def _normalize_adapter_path(path: str) -> str:
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
def _get_allowed_adapter_config_roots() -> tuple[Path, ...]:
def _get_allowed_adapter_config_roots() -> Tuple[Path, ...]:
project_root = Path(PROJECT_ROOT).resolve()
return (
project_root,