feat: 增加网络安全功能,验证公共 URL 和适配器配置路径
This commit is contained in:
8
dashboard/.vite/deps/_metadata.json
Normal file
8
dashboard/.vite/deps/_metadata.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"hash": "1b5cd9d5",
|
||||
"configHash": "027a635a",
|
||||
"lockfileHash": "36800971",
|
||||
"browserHash": "e1e062e5",
|
||||
"optimized": {},
|
||||
"chunks": {}
|
||||
}
|
||||
3
dashboard/.vite/deps/package.json
Normal file
3
dashboard/.vite/deps/package.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"type": "module"
|
||||
}
|
||||
@@ -82,7 +82,6 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/ws/logs?token=xxx
|
||||
"""
|
||||
@@ -102,13 +101,6 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import copy
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
import tomlkit
|
||||
@@ -420,6 +421,32 @@ def _normalize_adapter_path(path: str) -> str:
|
||||
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
|
||||
|
||||
|
||||
def _get_allowed_adapter_config_roots() -> tuple[Path, ...]:
|
||||
project_root = Path(PROJECT_ROOT).resolve()
|
||||
return (
|
||||
project_root,
|
||||
(project_root.parent / "MaiBot-Napcat-Adapter").resolve(),
|
||||
Path("/MaiMBot/adapters-config").resolve(),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_safe_adapter_config_path(path: str) -> Path:
|
||||
normalized_path = _normalize_adapter_path(path)
|
||||
candidate_path = Path(normalized_path).expanduser().resolve()
|
||||
|
||||
if candidate_path.suffix.lower() != ".toml":
|
||||
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||
|
||||
for allowed_root in _get_allowed_adapter_config_roots():
|
||||
try:
|
||||
candidate_path.relative_to(allowed_root)
|
||||
return candidate_path
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
raise HTTPException(status_code=400, detail="适配器配置路径超出允许范围")
|
||||
|
||||
|
||||
def _to_relative_path(path: str) -> str:
|
||||
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
|
||||
if not path or not os.path.isabs(path):
|
||||
@@ -457,8 +484,11 @@ async def get_adapter_config_path():
|
||||
if not adapter_config_path:
|
||||
return {"success": True, "path": None}
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(adapter_config_path)
|
||||
try:
|
||||
abs_path = str(_resolve_safe_adapter_config_path(adapter_config_path))
|
||||
except HTTPException:
|
||||
logger.warning(f"已忽略不安全的适配器配置路径: {adapter_config_path}")
|
||||
return {"success": True, "path": None}
|
||||
|
||||
# 检查文件是否存在并返回最后修改时间
|
||||
if os.path.exists(abs_path):
|
||||
@@ -497,8 +527,7 @@ async def save_adapter_config_path(data: PathBody):
|
||||
else:
|
||||
webui_data = {}
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
abs_path = str(_resolve_safe_adapter_config_path(path))
|
||||
|
||||
# 尝试转换为相对路径保存(如果文件在项目目录内)
|
||||
save_path = _to_relative_path(abs_path)
|
||||
@@ -528,17 +557,12 @@ async def get_adapter_config(path: str):
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="路径参数不能为空")
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
abs_path = str(_resolve_safe_adapter_config_path(path))
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(abs_path):
|
||||
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
|
||||
|
||||
# 检查文件扩展名
|
||||
if not abs_path.endswith(".toml"):
|
||||
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||
|
||||
# 读取文件内容
|
||||
with open(abs_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
@@ -565,12 +589,7 @@ async def save_adapter_config(data: PathBody):
|
||||
if content is None:
|
||||
raise HTTPException(status_code=400, detail="配置内容不能为空")
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
# 检查文件扩展名
|
||||
if not abs_path.endswith(".toml"):
|
||||
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||
abs_path = str(_resolve_safe_adapter_config_path(path))
|
||||
|
||||
# 验证 TOML 格式
|
||||
try:
|
||||
|
||||
@@ -14,6 +14,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import CONFIG_DIR
|
||||
from src.webui.dependencies import require_auth
|
||||
from src.webui.utils.network_security import validate_public_url
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
@@ -102,7 +103,12 @@ async def _fetch_models_from_provider(
|
||||
Returns:
|
||||
模型列表
|
||||
"""
|
||||
url = f"{_normalize_url(base_url)}{endpoint}"
|
||||
try:
|
||||
base_url = validate_public_url(_normalize_url(base_url))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
url = f"{base_url}{endpoint}"
|
||||
|
||||
# 根据客户端类型设置请求头
|
||||
headers = {}
|
||||
@@ -266,6 +272,11 @@ async def test_provider_connection(
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="base_url 不能为空")
|
||||
|
||||
try:
|
||||
base_url = validate_public_url(base_url)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
result = {
|
||||
"network_ok": False,
|
||||
"api_key_valid": None,
|
||||
|
||||
@@ -106,6 +106,8 @@ async def update_mirror(
|
||||
if mirror is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}")
|
||||
return _mirror_to_response(mirror)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -16,6 +16,7 @@ from .support import (
|
||||
find_plugin_path_by_id,
|
||||
normalize_dotted_keys,
|
||||
require_plugin_token,
|
||||
resolve_plugin_file_path,
|
||||
)
|
||||
|
||||
logger = get_logger("webui.plugin_routes")
|
||||
@@ -133,7 +134,7 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
schema_json_path = plugin_path / "config_schema.json"
|
||||
schema_json_path = resolve_plugin_file_path(plugin_path, "config_schema.json")
|
||||
if schema_json_path.exists():
|
||||
try:
|
||||
with open(schema_json_path, "r", encoding="utf-8") as file_obj:
|
||||
@@ -142,7 +143,7 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
|
||||
logger.warning(f"读取 config_schema.json 失败,回退到自动推断: {e}")
|
||||
|
||||
current_config: Any = {}
|
||||
config_path = plugin_path / "config.toml"
|
||||
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
current_config = tomlkit.load(file_obj)
|
||||
@@ -165,7 +166,7 @@ async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] =
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
|
||||
if not config_path.exists():
|
||||
return {"success": True, "config": "", "message": "配置文件不存在"}
|
||||
|
||||
@@ -192,7 +193,7 @@ async def update_plugin_config_raw(
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
|
||||
try:
|
||||
tomlkit.loads(request.config)
|
||||
except Exception as e:
|
||||
@@ -224,7 +225,7 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
|
||||
if not config_path.exists():
|
||||
return {"success": True, "config": {}, "message": "配置文件不存在"}
|
||||
|
||||
@@ -259,7 +260,7 @@ async def update_plugin_config(
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
|
||||
backup_path = backup_file(config_path, "backup")
|
||||
if backup_path is not None:
|
||||
logger.info(f"已备份配置文件: {backup_path}")
|
||||
@@ -284,7 +285,7 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
|
||||
if not config_path.exists():
|
||||
return {"success": True, "message": "配置文件不存在,无需重置"}
|
||||
|
||||
@@ -308,7 +309,7 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
|
||||
config = tomlkit.document()
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
|
||||
@@ -13,11 +13,12 @@ from .schemas import InstallPluginRequest, UninstallPluginRequest, UpdatePluginR
|
||||
from .support import (
|
||||
find_plugin_path_by_id,
|
||||
get_plugin_candidate_paths,
|
||||
get_plugins_dir,
|
||||
iter_plugin_directories,
|
||||
load_manifest_json,
|
||||
parse_repository_url,
|
||||
remove_tree,
|
||||
require_plugin_token,
|
||||
resolve_plugin_file_path,
|
||||
resolve_installed_plugin_path,
|
||||
validate_plugin_id,
|
||||
)
|
||||
@@ -56,7 +57,8 @@ def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path:
|
||||
logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}")
|
||||
manifest["id"] = plugin_id
|
||||
try:
|
||||
with open(manifest_path, "w", encoding="utf-8") as file_obj:
|
||||
safe_manifest_path = resolve_plugin_file_path(manifest_path.parent, "_manifest.json")
|
||||
with open(safe_manifest_path, "w", encoding="utf-8") as file_obj:
|
||||
json.dump(manifest, file_obj, ensure_ascii=False, indent=2)
|
||||
except Exception as write_error:
|
||||
logger.warning(f"无法写入 ID 到 manifest: {write_error}")
|
||||
@@ -91,10 +93,10 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
||||
if not result.get("success"):
|
||||
error_msg = str(result.get("error", "克隆失败"))
|
||||
await update_progress(stage="error", progress=0, message="克隆仓库失败", operation="install", plugin_id=plugin_id, error=error_msg)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg)
|
||||
|
||||
await update_progress(stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id)
|
||||
manifest_path = target_path / "_manifest.json"
|
||||
manifest_path = resolve_plugin_file_path(target_path, "_manifest.json")
|
||||
if not manifest_path.exists():
|
||||
remove_tree(target_path)
|
||||
await update_progress(stage="error", progress=0, message="插件缺少 _manifest.json", operation="install", plugin_id=plugin_id, error="无效的插件格式")
|
||||
@@ -140,7 +142,7 @@ async def uninstall_plugin(request: UninstallPluginRequest, maibot_session: Opti
|
||||
raise HTTPException(status_code=404, detail="插件未安装")
|
||||
|
||||
await update_progress(stage="loading", progress=30, message=f"正在删除插件文件: {plugin_path}", operation="uninstall", plugin_id=plugin_id)
|
||||
manifest = load_manifest_json(plugin_path / "_manifest.json")
|
||||
manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json"))
|
||||
plugin_name = str(manifest.get("name", plugin_id)) if manifest is not None else plugin_id
|
||||
await update_progress(stage="loading", progress=50, message=f"正在删除 {plugin_name}...", operation="uninstall", plugin_id=plugin_id)
|
||||
remove_tree(plugin_path)
|
||||
@@ -173,7 +175,7 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
||||
await update_progress(stage="error", progress=0, message="插件不存在", operation="update", plugin_id=plugin_id, error="插件未安装,请先安装")
|
||||
raise HTTPException(status_code=404, detail="插件未安装")
|
||||
|
||||
manifest = load_manifest_json(plugin_path / "_manifest.json")
|
||||
manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json"))
|
||||
old_version = str(manifest.get("version", "unknown")) if manifest is not None else "unknown"
|
||||
await update_progress(stage="loading", progress=10, message=f"当前版本: {old_version},准备更新...", operation="update", plugin_id=plugin_id)
|
||||
await update_progress(stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id)
|
||||
@@ -190,10 +192,10 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
||||
if not result.get("success"):
|
||||
error_msg = str(result.get("error", "克隆失败"))
|
||||
await update_progress(stage="error", progress=0, message="下载新版本失败", operation="update", plugin_id=plugin_id, error=error_msg)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg)
|
||||
|
||||
await update_progress(stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id)
|
||||
new_manifest_path = plugin_path / "_manifest.json"
|
||||
new_manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
|
||||
if not new_manifest_path.exists():
|
||||
remove_tree(plugin_path)
|
||||
await update_progress(stage="error", progress=0, message="新版本缺少 _manifest.json", operation="update", plugin_id=plugin_id, error="无效的插件格式")
|
||||
@@ -225,23 +227,22 @@ async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) ->
|
||||
logger.info("收到获取已安装插件列表请求")
|
||||
|
||||
try:
|
||||
plugins_dir = get_plugins_dir()
|
||||
installed_plugins: list[dict[str, Any]] = []
|
||||
for plugin_path in plugins_dir.iterdir():
|
||||
if not plugin_path.is_dir():
|
||||
continue
|
||||
for plugin_path in iter_plugin_directories():
|
||||
folder_name = plugin_path.name
|
||||
if folder_name.startswith(".") or folder_name.startswith("__"):
|
||||
continue
|
||||
|
||||
manifest_path = plugin_path / "_manifest.json"
|
||||
manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
|
||||
if not manifest_path.exists():
|
||||
logger.warning(f"插件文件夹 {folder_name} 缺少 _manifest.json,跳过")
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as file_obj:
|
||||
manifest = json.load(file_obj)
|
||||
manifest = load_manifest_json(manifest_path)
|
||||
if manifest is None:
|
||||
logger.warning(f"插件文件夹 {folder_name} 的 _manifest.json 不安全或无效,跳过")
|
||||
continue
|
||||
if "name" not in manifest or "version" not in manifest:
|
||||
logger.warning(f"插件文件夹 {folder_name} 的 _manifest.json 格式无效,跳过")
|
||||
continue
|
||||
@@ -286,7 +287,7 @@ async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str]
|
||||
return {"success": False, "error": "插件未安装"}
|
||||
|
||||
for readme_name in ["README.md", "readme.md", "Readme.md", "README.MD"]:
|
||||
readme_path = plugin_path / readme_name
|
||||
readme_path = resolve_plugin_file_path(plugin_path, readme_name)
|
||||
if readme_path.exists():
|
||||
try:
|
||||
with open(readme_path, "r", encoding="utf-8") as file_obj:
|
||||
|
||||
@@ -89,12 +89,6 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
|
||||
@@ -44,6 +44,50 @@ def validate_safe_path(user_path: str, base_path: Path) -> Path:
|
||||
return target_path
|
||||
|
||||
|
||||
def _resolve_safe_plugin_directory(plugin_path: Path, plugins_dir: Path, strict: bool) -> Optional[Path]:
|
||||
try:
|
||||
if plugin_path.is_symlink():
|
||||
raise HTTPException(status_code=400, detail="插件目录不能是符号链接")
|
||||
|
||||
resolved_plugins_dir = plugins_dir.resolve()
|
||||
resolved_plugin_path = plugin_path.resolve()
|
||||
resolved_plugin_path.relative_to(resolved_plugins_dir)
|
||||
|
||||
if not resolved_plugin_path.is_dir():
|
||||
return None
|
||||
|
||||
return resolved_plugin_path
|
||||
except HTTPException:
|
||||
if strict:
|
||||
raise
|
||||
logger.warning(f"已跳过不安全的插件目录: {plugin_path}")
|
||||
return None
|
||||
except (OSError, RuntimeError, ValueError):
|
||||
if strict:
|
||||
raise HTTPException(status_code=400, detail="插件目录超出允许范围")
|
||||
logger.warning(f"已跳过越界的插件目录: {plugin_path}")
|
||||
return None
|
||||
|
||||
|
||||
def resolve_plugin_file_path(plugin_path: Path, relative_path: str, allow_missing: bool = True) -> Path:
|
||||
plugin_root = plugin_path.resolve()
|
||||
target_path = plugin_root / relative_path
|
||||
|
||||
if target_path.exists() and target_path.is_symlink():
|
||||
raise HTTPException(status_code=400, detail=f"插件文件不能是符号链接: {relative_path}")
|
||||
|
||||
try:
|
||||
resolved_target_path = target_path.resolve()
|
||||
resolved_target_path.relative_to(plugin_root)
|
||||
except (OSError, RuntimeError, ValueError) as e:
|
||||
raise HTTPException(status_code=400, detail=f"插件文件超出允许范围: {relative_path}") from e
|
||||
|
||||
if not allow_missing and not resolved_target_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"插件文件不存在: {relative_path}")
|
||||
|
||||
return resolved_target_path
|
||||
|
||||
|
||||
def validate_plugin_id(plugin_id: str) -> str:
|
||||
if not plugin_id or not plugin_id.strip():
|
||||
logger.warning("非法插件 ID: 空字符串")
|
||||
@@ -164,9 +208,13 @@ def get_plugin_candidate_paths(plugin_id: str) -> tuple[Path, Path]:
|
||||
|
||||
def resolve_installed_plugin_path(plugin_id: str) -> Optional[Path]:
|
||||
new_format_path, old_format_path = get_plugin_candidate_paths(plugin_id)
|
||||
plugins_dir = get_plugins_dir()
|
||||
|
||||
if new_format_path.exists():
|
||||
return new_format_path
|
||||
return old_format_path if old_format_path.exists() else None
|
||||
return _resolve_safe_plugin_directory(new_format_path, plugins_dir, strict=True)
|
||||
if old_format_path.exists():
|
||||
return _resolve_safe_plugin_directory(old_format_path, plugins_dir, strict=True)
|
||||
return None
|
||||
|
||||
|
||||
def parse_repository_url(repository_url: str) -> tuple[str, str, str]:
|
||||
@@ -181,6 +229,16 @@ def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
|
||||
if not manifest_path.exists():
|
||||
return None
|
||||
|
||||
if manifest_path.is_symlink():
|
||||
logger.warning(f"已拒绝读取符号链接 manifest: {manifest_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
manifest_path.resolve().relative_to(manifest_path.parent.resolve())
|
||||
except (OSError, RuntimeError, ValueError):
|
||||
logger.warning(f"已拒绝读取越界 manifest: {manifest_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as file_obj:
|
||||
return cast(dict[str, Any], json.load(file_obj))
|
||||
@@ -189,12 +247,19 @@ def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
|
||||
|
||||
|
||||
def iter_plugin_directories() -> list[Path]:
|
||||
return [path for path in get_plugins_dir().iterdir() if path.is_dir()]
|
||||
plugins_dir = get_plugins_dir()
|
||||
plugin_directories: list[Path] = []
|
||||
for path in plugins_dir.iterdir():
|
||||
safe_path = _resolve_safe_plugin_directory(path, plugins_dir, strict=False)
|
||||
if safe_path is not None:
|
||||
plugin_directories.append(safe_path)
|
||||
return plugin_directories
|
||||
|
||||
|
||||
def find_plugin_path_by_id(plugin_id: str) -> Optional[Path]:
|
||||
for plugin_path in iter_plugin_directories():
|
||||
manifest = load_manifest_json(plugin_path / "_manifest.json")
|
||||
manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
|
||||
manifest = load_manifest_json(manifest_path)
|
||||
if manifest is not None and (manifest.get("id") == plugin_id or plugin_path.name == plugin_id):
|
||||
return plugin_path
|
||||
return None
|
||||
@@ -214,6 +279,9 @@ def backup_file(file_path: Path, action: str, move_file: bool = False) -> Option
|
||||
|
||||
|
||||
def remove_tree(path: Path) -> None:
|
||||
if path.is_symlink():
|
||||
raise ValueError(f"拒绝删除符号链接路径: {path}")
|
||||
|
||||
def remove_readonly(func: Any, target_path: str, _: Any) -> None:
|
||||
os.chmod(target_path, stat.S_IWRITE)
|
||||
func(target_path)
|
||||
|
||||
@@ -10,6 +10,7 @@ import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.utils.network_security import validate_public_url
|
||||
|
||||
logger = get_logger("webui.git_mirror")
|
||||
|
||||
@@ -17,6 +18,20 @@ logger = get_logger("webui.git_mirror")
|
||||
_update_progress = None
|
||||
|
||||
|
||||
def _validate_mirror_prefix(url: str, field_name: str) -> str:
|
||||
try:
|
||||
return validate_public_url(url)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"{field_name} 非法: {e}") from e
|
||||
|
||||
|
||||
def _validate_custom_outbound_url(url: str) -> str:
|
||||
try:
|
||||
return validate_public_url(url)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"目标 URL 非法: {e}") from e
|
||||
|
||||
|
||||
def set_update_progress_callback(callback):
|
||||
"""设置进度更新回调函数"""
|
||||
global _update_progress
|
||||
@@ -200,6 +215,9 @@ class GitMirrorConfig:
|
||||
if self.get_mirror_by_id(mirror_id):
|
||||
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
|
||||
|
||||
raw_prefix = _validate_mirror_prefix(raw_prefix, "Raw 前缀")
|
||||
clone_prefix = _validate_mirror_prefix(clone_prefix, "克隆前缀")
|
||||
|
||||
# 如果未指定优先级,使用最大优先级 + 1
|
||||
if priority is None:
|
||||
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
|
||||
@@ -241,8 +259,10 @@ class GitMirrorConfig:
|
||||
if name is not None:
|
||||
mirror["name"] = name
|
||||
if raw_prefix is not None:
|
||||
raw_prefix = _validate_mirror_prefix(raw_prefix, "Raw 前缀")
|
||||
mirror["raw_prefix"] = raw_prefix
|
||||
if clone_prefix is not None:
|
||||
clone_prefix = _validate_mirror_prefix(clone_prefix, "克隆前缀")
|
||||
mirror["clone_prefix"] = clone_prefix
|
||||
if enabled is not None:
|
||||
mirror["enabled"] = enabled
|
||||
@@ -372,7 +392,18 @@ class GitMirrorService:
|
||||
logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}")
|
||||
|
||||
if custom_url:
|
||||
# 使用自定义 URL
|
||||
try:
|
||||
custom_url = _validate_custom_outbound_url(custom_url)
|
||||
except ValueError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"mirror_used": "custom",
|
||||
"attempts": 0,
|
||||
"url": custom_url,
|
||||
"status_code": 400,
|
||||
}
|
||||
|
||||
return await self._fetch_with_url(custom_url, "custom")
|
||||
|
||||
# 确定要使用的镜像源列表
|
||||
@@ -443,8 +474,11 @@ class GitMirrorService:
|
||||
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""从指定镜像源获取文件"""
|
||||
# 构建 URL
|
||||
raw_prefix = mirror["raw_prefix"]
|
||||
try:
|
||||
raw_prefix = _validate_mirror_prefix(mirror["raw_prefix"], "镜像 Raw 前缀")
|
||||
except ValueError as e:
|
||||
return {"success": False, "error": str(e), "mirror_used": mirror.get("id"), "attempts": 0, "status_code": 400}
|
||||
|
||||
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
|
||||
|
||||
return await self._fetch_with_url(url, mirror["id"])
|
||||
@@ -515,7 +549,18 @@ class GitMirrorService:
|
||||
logger.info(f"开始克隆仓库: {owner}/{repo} 到 {target_path}")
|
||||
|
||||
if custom_url:
|
||||
# 使用自定义 URL
|
||||
try:
|
||||
custom_url = _validate_custom_outbound_url(custom_url)
|
||||
except ValueError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"mirror_used": "custom",
|
||||
"attempts": 0,
|
||||
"url": custom_url,
|
||||
"status_code": 400,
|
||||
}
|
||||
|
||||
return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")
|
||||
|
||||
# 确定要使用的镜像源列表
|
||||
@@ -549,8 +594,11 @@ class GitMirrorService:
|
||||
mirror: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""从指定镜像源克隆仓库"""
|
||||
# 构建克隆 URL
|
||||
clone_prefix = mirror["clone_prefix"]
|
||||
try:
|
||||
clone_prefix = _validate_mirror_prefix(mirror["clone_prefix"], "镜像克隆前缀")
|
||||
except ValueError as e:
|
||||
return {"success": False, "error": str(e), "mirror_used": mirror.get("id"), "attempts": 0, "status_code": 400}
|
||||
|
||||
url = f"{clone_prefix}/{owner}/{repo}.git"
|
||||
|
||||
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
|
||||
|
||||
76
src/webui/utils/network_security.py
Normal file
76
src/webui/utils/network_security.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from typing import Iterable
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
def _resolve_ip_addresses(hostname: str, port: int) -> set[ipaddress.IPv4Address | ipaddress.IPv6Address]:
|
||||
try:
|
||||
address_infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
|
||||
except socket.gaierror as exc:
|
||||
raise ValueError(f"无法解析主机名: {hostname}") from exc
|
||||
|
||||
resolved_addresses: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
|
||||
for _, _, _, _, sockaddr in address_infos:
|
||||
host_address = sockaddr[0]
|
||||
if not isinstance(host_address, str):
|
||||
continue
|
||||
|
||||
raw_ip = host_address.split("%", 1)[0]
|
||||
resolved_addresses.add(ipaddress.ip_address(raw_ip))
|
||||
|
||||
if not resolved_addresses:
|
||||
raise ValueError(f"无法解析主机名: {hostname}")
|
||||
|
||||
return resolved_addresses
|
||||
|
||||
|
||||
def _is_forbidden_ip_address(address: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||
return any(
|
||||
(
|
||||
address.is_loopback,
|
||||
address.is_link_local,
|
||||
address.is_multicast,
|
||||
address.is_private,
|
||||
address.is_reserved,
|
||||
address.is_unspecified,
|
||||
getattr(address, "is_site_local", False),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def validate_public_url(url: str, allowed_schemes: Iterable[str] = ("https",)) -> str:
|
||||
normalized_url = url.strip()
|
||||
if not normalized_url:
|
||||
raise ValueError("URL 不能为空")
|
||||
|
||||
parsed = urlparse(normalized_url)
|
||||
allowed_scheme_set = {scheme.lower() for scheme in allowed_schemes}
|
||||
if parsed.scheme.lower() not in allowed_scheme_set:
|
||||
allowed = ", ".join(sorted(allowed_scheme_set))
|
||||
raise ValueError(f"仅允许以下协议: {allowed}")
|
||||
|
||||
if not parsed.hostname or not parsed.netloc:
|
||||
raise ValueError("URL 缺少有效的主机名")
|
||||
|
||||
if parsed.username or parsed.password:
|
||||
raise ValueError("URL 不允许内嵌认证信息")
|
||||
|
||||
if parsed.fragment:
|
||||
raise ValueError("URL 不允许包含片段")
|
||||
|
||||
if parsed.hostname.lower() in {"localhost", "localhost.localdomain"}:
|
||||
raise ValueError("不允许访问本地主机")
|
||||
|
||||
try:
|
||||
port = parsed.port or (443 if parsed.scheme.lower() == "https" else 80)
|
||||
except ValueError as exc:
|
||||
raise ValueError("URL 端口非法") from exc
|
||||
|
||||
for address in _resolve_ip_addresses(parsed.hostname, port):
|
||||
if _is_forbidden_ip_address(address):
|
||||
raise ValueError(f"禁止访问非公网地址: {address}")
|
||||
|
||||
return normalized_url
|
||||
Reference in New Issue
Block a user