feat: 增加网络安全功能,验证公共 URL 和适配器配置路径

This commit is contained in:
DrSmoothl
2026-03-14 22:55:51 +08:00
parent 1978b097e3
commit 292f0a1d7a
12 changed files with 288 additions and 65 deletions

View File

@@ -4,6 +4,7 @@
import copy
import os
from pathlib import Path
from typing import Any, Annotated, Optional
import tomlkit
@@ -420,6 +421,32 @@ def _normalize_adapter_path(path: str) -> str:
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
def _get_allowed_adapter_config_roots() -> tuple[Path, ...]:
project_root = Path(PROJECT_ROOT).resolve()
return (
project_root,
(project_root.parent / "MaiBot-Napcat-Adapter").resolve(),
Path("/MaiMBot/adapters-config").resolve(),
)
def _resolve_safe_adapter_config_path(path: str) -> Path:
normalized_path = _normalize_adapter_path(path)
candidate_path = Path(normalized_path).expanduser().resolve()
if candidate_path.suffix.lower() != ".toml":
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
for allowed_root in _get_allowed_adapter_config_roots():
try:
candidate_path.relative_to(allowed_root)
return candidate_path
except ValueError:
continue
raise HTTPException(status_code=400, detail="适配器配置路径超出允许范围")
def _to_relative_path(path: str) -> str:
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
if not path or not os.path.isabs(path):
@@ -457,8 +484,11 @@ async def get_adapter_config_path():
if not adapter_config_path:
return {"success": True, "path": None}
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(adapter_config_path)
try:
abs_path = str(_resolve_safe_adapter_config_path(adapter_config_path))
except HTTPException:
logger.warning(f"已忽略不安全的适配器配置路径: {adapter_config_path}")
return {"success": True, "path": None}
# 检查文件是否存在并返回最后修改时间
if os.path.exists(abs_path):
@@ -497,8 +527,7 @@ async def save_adapter_config_path(data: PathBody):
else:
webui_data = {}
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
abs_path = str(_resolve_safe_adapter_config_path(path))
# 尝试转换为相对路径保存(如果文件在项目目录内)
save_path = _to_relative_path(abs_path)
@@ -528,17 +557,12 @@ async def get_adapter_config(path: str):
if not path:
raise HTTPException(status_code=400, detail="路径参数不能为空")
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
abs_path = str(_resolve_safe_adapter_config_path(path))
# 检查文件是否存在
if not os.path.exists(abs_path):
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
# 检查文件扩展名
if not abs_path.endswith(".toml"):
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
# 读取文件内容
with open(abs_path, "r", encoding="utf-8") as f:
content = f.read()
@@ -565,12 +589,7 @@ async def save_adapter_config(data: PathBody):
if content is None:
raise HTTPException(status_code=400, detail="配置内容不能为空")
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
# 检查文件扩展名
if not abs_path.endswith(".toml"):
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
abs_path = str(_resolve_safe_adapter_config_path(path))
# 验证 TOML 格式
try:

View File

@@ -14,6 +14,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query
from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
from src.webui.dependencies import require_auth
from src.webui.utils.network_security import validate_public_url
logger = get_logger("webui")
@@ -102,7 +103,12 @@ async def _fetch_models_from_provider(
Returns:
模型列表
"""
url = f"{_normalize_url(base_url)}{endpoint}"
try:
base_url = validate_public_url(_normalize_url(base_url))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
url = f"{base_url}{endpoint}"
# 根据客户端类型设置请求头
headers = {}
@@ -266,6 +272,11 @@ async def test_provider_connection(
if not base_url:
raise HTTPException(status_code=400, detail="base_url 不能为空")
try:
base_url = validate_public_url(base_url)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
result = {
"network_ok": False,
"api_key_valid": None,

View File

@@ -106,6 +106,8 @@ async def update_mirror(
if mirror is None:
raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}")
return _mirror_to_response(mirror)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except HTTPException:
raise
except Exception as e:

View File

@@ -16,6 +16,7 @@ from .support import (
find_plugin_path_by_id,
normalize_dotted_keys,
require_plugin_token,
resolve_plugin_file_path,
)
logger = get_logger("webui.plugin_routes")
@@ -133,7 +134,7 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
schema_json_path = plugin_path / "config_schema.json"
schema_json_path = resolve_plugin_file_path(plugin_path, "config_schema.json")
if schema_json_path.exists():
try:
with open(schema_json_path, "r", encoding="utf-8") as file_obj:
@@ -142,7 +143,7 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
logger.warning(f"读取 config_schema.json 失败,回退到自动推断: {e}")
current_config: Any = {}
config_path = plugin_path / "config.toml"
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as file_obj:
current_config = tomlkit.load(file_obj)
@@ -165,7 +166,7 @@ async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] =
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
if not config_path.exists():
return {"success": True, "config": "", "message": "配置文件不存在"}
@@ -192,7 +193,7 @@ async def update_plugin_config_raw(
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
try:
tomlkit.loads(request.config)
except Exception as e:
@@ -224,7 +225,7 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
if not config_path.exists():
return {"success": True, "config": {}, "message": "配置文件不存在"}
@@ -259,7 +260,7 @@ async def update_plugin_config(
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
backup_path = backup_file(config_path, "backup")
if backup_path is not None:
logger.info(f"已备份配置文件: {backup_path}")
@@ -284,7 +285,7 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
if not config_path.exists():
return {"success": True, "message": "配置文件不存在,无需重置"}
@@ -308,7 +309,7 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
config = tomlkit.document()
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as file_obj:

View File

@@ -13,11 +13,12 @@ from .schemas import InstallPluginRequest, UninstallPluginRequest, UpdatePluginR
from .support import (
find_plugin_path_by_id,
get_plugin_candidate_paths,
get_plugins_dir,
iter_plugin_directories,
load_manifest_json,
parse_repository_url,
remove_tree,
require_plugin_token,
resolve_plugin_file_path,
resolve_installed_plugin_path,
validate_plugin_id,
)
@@ -56,7 +57,8 @@ def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path:
logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}")
manifest["id"] = plugin_id
try:
with open(manifest_path, "w", encoding="utf-8") as file_obj:
safe_manifest_path = resolve_plugin_file_path(manifest_path.parent, "_manifest.json")
with open(safe_manifest_path, "w", encoding="utf-8") as file_obj:
json.dump(manifest, file_obj, ensure_ascii=False, indent=2)
except Exception as write_error:
logger.warning(f"无法写入 ID 到 manifest: {write_error}")
@@ -91,10 +93,10 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
if not result.get("success"):
error_msg = str(result.get("error", "克隆失败"))
await update_progress(stage="error", progress=0, message="克隆仓库失败", operation="install", plugin_id=plugin_id, error=error_msg)
raise HTTPException(status_code=500, detail=error_msg)
raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg)
await update_progress(stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id)
manifest_path = target_path / "_manifest.json"
manifest_path = resolve_plugin_file_path(target_path, "_manifest.json")
if not manifest_path.exists():
remove_tree(target_path)
await update_progress(stage="error", progress=0, message="插件缺少 _manifest.json", operation="install", plugin_id=plugin_id, error="无效的插件格式")
@@ -140,7 +142,7 @@ async def uninstall_plugin(request: UninstallPluginRequest, maibot_session: Opti
raise HTTPException(status_code=404, detail="插件未安装")
await update_progress(stage="loading", progress=30, message=f"正在删除插件文件: {plugin_path}", operation="uninstall", plugin_id=plugin_id)
manifest = load_manifest_json(plugin_path / "_manifest.json")
manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json"))
plugin_name = str(manifest.get("name", plugin_id)) if manifest is not None else plugin_id
await update_progress(stage="loading", progress=50, message=f"正在删除 {plugin_name}...", operation="uninstall", plugin_id=plugin_id)
remove_tree(plugin_path)
@@ -173,7 +175,7 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
await update_progress(stage="error", progress=0, message="插件不存在", operation="update", plugin_id=plugin_id, error="插件未安装,请先安装")
raise HTTPException(status_code=404, detail="插件未安装")
manifest = load_manifest_json(plugin_path / "_manifest.json")
manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json"))
old_version = str(manifest.get("version", "unknown")) if manifest is not None else "unknown"
await update_progress(stage="loading", progress=10, message=f"当前版本: {old_version},准备更新...", operation="update", plugin_id=plugin_id)
await update_progress(stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id)
@@ -190,10 +192,10 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
if not result.get("success"):
error_msg = str(result.get("error", "克隆失败"))
await update_progress(stage="error", progress=0, message="下载新版本失败", operation="update", plugin_id=plugin_id, error=error_msg)
raise HTTPException(status_code=500, detail=error_msg)
raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg)
await update_progress(stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id)
new_manifest_path = plugin_path / "_manifest.json"
new_manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
if not new_manifest_path.exists():
remove_tree(plugin_path)
await update_progress(stage="error", progress=0, message="新版本缺少 _manifest.json", operation="update", plugin_id=plugin_id, error="无效的插件格式")
@@ -225,23 +227,22 @@ async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) ->
logger.info("收到获取已安装插件列表请求")
try:
plugins_dir = get_plugins_dir()
installed_plugins: list[dict[str, Any]] = []
for plugin_path in plugins_dir.iterdir():
if not plugin_path.is_dir():
continue
for plugin_path in iter_plugin_directories():
folder_name = plugin_path.name
if folder_name.startswith(".") or folder_name.startswith("__"):
continue
manifest_path = plugin_path / "_manifest.json"
manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
if not manifest_path.exists():
logger.warning(f"插件文件夹 {folder_name} 缺少 _manifest.json跳过")
continue
try:
with open(manifest_path, "r", encoding="utf-8") as file_obj:
manifest = json.load(file_obj)
manifest = load_manifest_json(manifest_path)
if manifest is None:
logger.warning(f"插件文件夹 {folder_name} 的 _manifest.json 不安全或无效,跳过")
continue
if "name" not in manifest or "version" not in manifest:
logger.warning(f"插件文件夹 {folder_name} 的 _manifest.json 格式无效,跳过")
continue
@@ -286,7 +287,7 @@ async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str]
return {"success": False, "error": "插件未安装"}
for readme_name in ["README.md", "readme.md", "Readme.md", "README.MD"]:
readme_path = plugin_path / readme_name
readme_path = resolve_plugin_file_path(plugin_path, readme_name)
if readme_path.exists():
try:
with open(readme_path, "r", encoding="utf-8") as file_obj:

View File

@@ -89,12 +89,6 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")

View File

@@ -44,6 +44,50 @@ def validate_safe_path(user_path: str, base_path: Path) -> Path:
return target_path
def _resolve_safe_plugin_directory(plugin_path: Path, plugins_dir: Path, strict: bool) -> Optional[Path]:
try:
if plugin_path.is_symlink():
raise HTTPException(status_code=400, detail="插件目录不能是符号链接")
resolved_plugins_dir = plugins_dir.resolve()
resolved_plugin_path = plugin_path.resolve()
resolved_plugin_path.relative_to(resolved_plugins_dir)
if not resolved_plugin_path.is_dir():
return None
return resolved_plugin_path
except HTTPException:
if strict:
raise
logger.warning(f"已跳过不安全的插件目录: {plugin_path}")
return None
except (OSError, RuntimeError, ValueError):
if strict:
raise HTTPException(status_code=400, detail="插件目录超出允许范围")
logger.warning(f"已跳过越界的插件目录: {plugin_path}")
return None
def resolve_plugin_file_path(plugin_path: Path, relative_path: str, allow_missing: bool = True) -> Path:
plugin_root = plugin_path.resolve()
target_path = plugin_root / relative_path
if target_path.exists() and target_path.is_symlink():
raise HTTPException(status_code=400, detail=f"插件文件不能是符号链接: {relative_path}")
try:
resolved_target_path = target_path.resolve()
resolved_target_path.relative_to(plugin_root)
except (OSError, RuntimeError, ValueError) as e:
raise HTTPException(status_code=400, detail=f"插件文件超出允许范围: {relative_path}") from e
if not allow_missing and not resolved_target_path.exists():
raise HTTPException(status_code=404, detail=f"插件文件不存在: {relative_path}")
return resolved_target_path
def validate_plugin_id(plugin_id: str) -> str:
if not plugin_id or not plugin_id.strip():
logger.warning("非法插件 ID: 空字符串")
@@ -164,9 +208,13 @@ def get_plugin_candidate_paths(plugin_id: str) -> tuple[Path, Path]:
def resolve_installed_plugin_path(plugin_id: str) -> Optional[Path]:
new_format_path, old_format_path = get_plugin_candidate_paths(plugin_id)
plugins_dir = get_plugins_dir()
if new_format_path.exists():
return new_format_path
return old_format_path if old_format_path.exists() else None
return _resolve_safe_plugin_directory(new_format_path, plugins_dir, strict=True)
if old_format_path.exists():
return _resolve_safe_plugin_directory(old_format_path, plugins_dir, strict=True)
return None
def parse_repository_url(repository_url: str) -> tuple[str, str, str]:
@@ -181,6 +229,16 @@ def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
if not manifest_path.exists():
return None
if manifest_path.is_symlink():
logger.warning(f"已拒绝读取符号链接 manifest: {manifest_path}")
return None
try:
manifest_path.resolve().relative_to(manifest_path.parent.resolve())
except (OSError, RuntimeError, ValueError):
logger.warning(f"已拒绝读取越界 manifest: {manifest_path}")
return None
try:
with open(manifest_path, "r", encoding="utf-8") as file_obj:
return cast(dict[str, Any], json.load(file_obj))
@@ -189,12 +247,19 @@ def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
def iter_plugin_directories() -> list[Path]:
return [path for path in get_plugins_dir().iterdir() if path.is_dir()]
plugins_dir = get_plugins_dir()
plugin_directories: list[Path] = []
for path in plugins_dir.iterdir():
safe_path = _resolve_safe_plugin_directory(path, plugins_dir, strict=False)
if safe_path is not None:
plugin_directories.append(safe_path)
return plugin_directories
def find_plugin_path_by_id(plugin_id: str) -> Optional[Path]:
for plugin_path in iter_plugin_directories():
manifest = load_manifest_json(plugin_path / "_manifest.json")
manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
manifest = load_manifest_json(manifest_path)
if manifest is not None and (manifest.get("id") == plugin_id or plugin_path.name == plugin_id):
return plugin_path
return None
@@ -214,6 +279,9 @@ def backup_file(file_path: Path, action: str, move_file: bool = False) -> Option
def remove_tree(path: Path) -> None:
if path.is_symlink():
raise ValueError(f"拒绝删除符号链接路径: {path}")
def remove_readonly(func: Any, target_path: str, _: Any) -> None:
os.chmod(target_path, stat.S_IWRITE)
func(target_path)