diff --git a/dashboard/.vite/deps/_metadata.json b/dashboard/.vite/deps/_metadata.json new file mode 100644 index 00000000..49c6edbc --- /dev/null +++ b/dashboard/.vite/deps/_metadata.json @@ -0,0 +1,8 @@ +{ + "hash": "1b5cd9d5", + "configHash": "027a635a", + "lockfileHash": "36800971", + "browserHash": "e1e062e5", + "optimized": {}, + "chunks": {} +} \ No newline at end of file diff --git a/dashboard/.vite/deps/package.json b/dashboard/.vite/deps/package.json new file mode 100644 index 00000000..3dbc1ca5 --- /dev/null +++ b/dashboard/.vite/deps/package.json @@ -0,0 +1,3 @@ +{ + "type": "module" +} diff --git a/src/webui/logs_ws.py b/src/webui/logs_ws.py index 1d43f306..0e707803 100644 --- a/src/webui/logs_ws.py +++ b/src/webui/logs_ws.py @@ -82,7 +82,6 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None 支持三种认证方式(按优先级): 1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token) 2. Cookie 中的 maibot_session - 3. 直接使用 session token(兼容) 示例:ws://host/ws/logs?token=xxx """ @@ -102,13 +101,6 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None is_authenticated = True logger.debug("WebSocket 使用 Cookie 认证成功") - # 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式) - if not is_authenticated and token: - token_manager = get_token_manager() - if token_manager.verify_token(token): - is_authenticated = True - logger.debug("WebSocket 使用 session token 认证成功") - if not is_authenticated: logger.warning("WebSocket 连接被拒绝:认证失败") await websocket.close(code=4001, reason="认证失败,请重新登录") diff --git a/src/webui/routers/config.py b/src/webui/routers/config.py index fa196af3..78d2f1d2 100644 --- a/src/webui/routers/config.py +++ b/src/webui/routers/config.py @@ -4,6 +4,7 @@ import copy import os +from pathlib import Path from typing import Any, Annotated, Optional import tomlkit @@ -420,6 +421,32 @@ def _normalize_adapter_path(path: str) -> str: return os.path.normpath(os.path.join(PROJECT_ROOT, path)) +def _get_allowed_adapter_config_roots() -> tuple[Path, ...]: + project_root = Path(PROJECT_ROOT).resolve() + return ( + project_root, + (project_root.parent / "MaiBot-Napcat-Adapter").resolve(), + Path("/MaiMBot/adapters-config").resolve(), + ) + + +def _resolve_safe_adapter_config_path(path: str) -> Path: + normalized_path = _normalize_adapter_path(path) + candidate_path = Path(normalized_path).expanduser().resolve() + + if candidate_path.suffix.lower() != ".toml": + raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件") + + for allowed_root in _get_allowed_adapter_config_roots(): + try: + candidate_path.relative_to(allowed_root) + return candidate_path + except ValueError: + continue + + raise HTTPException(status_code=400, detail="适配器配置路径超出允许范围") + + def _to_relative_path(path: str) -> str: """尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径""" if not path or not os.path.isabs(path): @@ -457,8 +484,11 @@ async def get_adapter_config_path(): if not adapter_config_path: return {"success": True, "path": None} - # 将路径规范化为绝对路径 - abs_path = _normalize_adapter_path(adapter_config_path) + try: + abs_path = str(_resolve_safe_adapter_config_path(adapter_config_path)) + except HTTPException: + logger.warning(f"已忽略不安全的适配器配置路径: {adapter_config_path}") + return {"success": True, "path": None} # 检查文件是否存在并返回最后修改时间 if os.path.exists(abs_path): @@ -497,8 +527,7 @@ async def save_adapter_config_path(data: PathBody): else: webui_data = {} - # 将路径规范化为绝对路径 - abs_path = _normalize_adapter_path(path) + abs_path = str(_resolve_safe_adapter_config_path(path)) # 尝试转换为相对路径保存(如果文件在项目目录内) save_path = _to_relative_path(abs_path) @@ -528,17 +557,12 @@ async def get_adapter_config(path: str): if not path: raise HTTPException(status_code=400, detail="路径参数不能为空") - # 将路径规范化为绝对路径 - abs_path = _normalize_adapter_path(path) + abs_path = str(_resolve_safe_adapter_config_path(path)) # 检查文件是否存在 if not os.path.exists(abs_path): raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}") - # 检查文件扩展名 - if not abs_path.endswith(".toml"): - raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件") - # 读取文件内容 with open(abs_path, "r", encoding="utf-8") as f: content = f.read() @@ -565,12 +589,7 @@ async def save_adapter_config(data: PathBody): if content is None: raise HTTPException(status_code=400, detail="配置内容不能为空") - # 将路径规范化为绝对路径 - abs_path = _normalize_adapter_path(path) - - # 检查文件扩展名 - if not abs_path.endswith(".toml"): - raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件") + abs_path = str(_resolve_safe_adapter_config_path(path)) # 验证 TOML 格式 try: diff --git a/src/webui/routers/model.py b/src/webui/routers/model.py index a1382a3a..1e16ac4d 100644 --- a/src/webui/routers/model.py +++ b/src/webui/routers/model.py @@ -14,6 +14,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query from src.common.logger import get_logger from src.config.config import CONFIG_DIR from src.webui.dependencies import require_auth +from src.webui.utils.network_security import validate_public_url logger = get_logger("webui") @@ -102,7 +103,12 @@ async def _fetch_models_from_provider( Returns: 模型列表 """ - url = f"{_normalize_url(base_url)}{endpoint}" + try: + base_url = validate_public_url(_normalize_url(base_url)) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + + url = f"{base_url}{endpoint}" # 根据客户端类型设置请求头 headers = {} @@ -266,6 +272,11 @@ async def test_provider_connection( if not base_url: raise HTTPException(status_code=400, detail="base_url 不能为空") + try: + base_url = validate_public_url(base_url) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + result = { "network_ok": False, "api_key_valid": None, diff --git a/src/webui/routers/plugin/catalog.py b/src/webui/routers/plugin/catalog.py index 124c6fac..4a49d048 100644 --- a/src/webui/routers/plugin/catalog.py +++ b/src/webui/routers/plugin/catalog.py @@ -106,6 +106,8 @@ async def update_mirror( if mirror is None: raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}") return _mirror_to_response(mirror) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e except HTTPException: raise except Exception as e: diff --git a/src/webui/routers/plugin/config_routes.py b/src/webui/routers/plugin/config_routes.py index b2f4e254..b9c32f2e 100644 --- a/src/webui/routers/plugin/config_routes.py +++ b/src/webui/routers/plugin/config_routes.py @@ -16,6 +16,7 @@ from .support import ( find_plugin_path_by_id, normalize_dotted_keys, require_plugin_token, + resolve_plugin_file_path, ) logger = get_logger("webui.plugin_routes") @@ -133,7 +134,7 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] if plugin_path is None: raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - schema_json_path = plugin_path / "config_schema.json" + schema_json_path = resolve_plugin_file_path(plugin_path, "config_schema.json") if schema_json_path.exists(): try: with open(schema_json_path, "r", encoding="utf-8") as file_obj: @@ -142,7 +143,7 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] logger.warning(f"读取 config_schema.json 失败,回退到自动推断: {e}") current_config: Any = {} - config_path = plugin_path / "config.toml" + config_path = resolve_plugin_file_path(plugin_path, "config.toml") if config_path.exists(): with open(config_path, "r", encoding="utf-8") as file_obj: current_config = tomlkit.load(file_obj) @@ -165,7 +166,7 @@ async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] = if plugin_path is None: raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - config_path = plugin_path / "config.toml" + config_path = resolve_plugin_file_path(plugin_path, "config.toml") if not config_path.exists(): return {"success": True, "config": "", "message": "配置文件不存在"} @@ -192,7 +193,7 @@ async def update_plugin_config_raw( if plugin_path is None: raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - config_path = plugin_path / "config.toml" + config_path = resolve_plugin_file_path(plugin_path, "config.toml") try: tomlkit.loads(request.config) except Exception as e: @@ -224,7 +225,7 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook if plugin_path is None: raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - config_path = plugin_path / "config.toml" + config_path = resolve_plugin_file_path(plugin_path, "config.toml") if not config_path.exists(): return {"success": True, "config": {}, "message": "配置文件不存在"} @@ -259,7 +260,7 @@ async def update_plugin_config( if plugin_path is None: raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - config_path = plugin_path / "config.toml" + config_path = resolve_plugin_file_path(plugin_path, "config.toml") backup_path = backup_file(config_path, "backup") if backup_path is not None: logger.info(f"已备份配置文件: {backup_path}") @@ -284,7 +285,7 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co if plugin_path is None: raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - config_path = plugin_path / "config.toml" + config_path = resolve_plugin_file_path(plugin_path, "config.toml") if not config_path.exists(): return {"success": True, "message": "配置文件不存在,无需重置"} @@ -308,7 +309,7 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N if plugin_path is None: raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - config_path = plugin_path / "config.toml" + config_path = resolve_plugin_file_path(plugin_path, "config.toml") config = tomlkit.document() if config_path.exists(): with open(config_path, "r", encoding="utf-8") as file_obj: diff --git a/src/webui/routers/plugin/management.py b/src/webui/routers/plugin/management.py index cbf2920b..7a725c94 100644 --- a/src/webui/routers/plugin/management.py +++ b/src/webui/routers/plugin/management.py @@ -13,11 +13,12 @@ from .schemas import InstallPluginRequest, UninstallPluginRequest, UpdatePluginR from .support import ( find_plugin_path_by_id, get_plugin_candidate_paths, - get_plugins_dir, + iter_plugin_directories, load_manifest_json, parse_repository_url, remove_tree, require_plugin_token, + resolve_plugin_file_path, resolve_installed_plugin_path, validate_plugin_id, ) @@ -56,7 +57,8 @@ def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path: logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}") manifest["id"] = plugin_id try: - with open(manifest_path, "w", encoding="utf-8") as file_obj: + safe_manifest_path = resolve_plugin_file_path(manifest_path.parent, "_manifest.json") + with open(safe_manifest_path, "w", encoding="utf-8") as file_obj: json.dump(manifest, file_obj, ensure_ascii=False, indent=2) except Exception as write_error: logger.warning(f"无法写入 ID 到 manifest: {write_error}") @@ -91,10 +93,10 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional if not result.get("success"): error_msg = str(result.get("error", "克隆失败")) await update_progress(stage="error", progress=0, message="克隆仓库失败", operation="install", plugin_id=plugin_id, error=error_msg) - raise HTTPException(status_code=500, detail=error_msg) + raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg) await update_progress(stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id) - manifest_path = target_path / "_manifest.json" + manifest_path = resolve_plugin_file_path(target_path, "_manifest.json") if not manifest_path.exists(): remove_tree(target_path) await update_progress(stage="error", progress=0, message="插件缺少 _manifest.json", operation="install", plugin_id=plugin_id, error="无效的插件格式") @@ -140,7 +142,7 @@ async def uninstall_plugin(request: UninstallPluginRequest, maibot_session: Opti raise HTTPException(status_code=404, detail="插件未安装") await update_progress(stage="loading", progress=30, message=f"正在删除插件文件: {plugin_path}", operation="uninstall", plugin_id=plugin_id) - manifest = load_manifest_json(plugin_path / "_manifest.json") + manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json")) plugin_name = str(manifest.get("name", plugin_id)) if manifest is not None else plugin_id await update_progress(stage="loading", progress=50, message=f"正在删除 {plugin_name}...", operation="uninstall", plugin_id=plugin_id) remove_tree(plugin_path) @@ -173,7 +175,7 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s await update_progress(stage="error", progress=0, message="插件不存在", operation="update", plugin_id=plugin_id, error="插件未安装,请先安装") raise HTTPException(status_code=404, detail="插件未安装") - manifest = load_manifest_json(plugin_path / "_manifest.json") + manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json")) old_version = str(manifest.get("version", "unknown")) if manifest is not None else "unknown" await update_progress(stage="loading", progress=10, message=f"当前版本: {old_version},准备更新...", operation="update", plugin_id=plugin_id) await update_progress(stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id) @@ -190,10 +192,10 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s if not result.get("success"): error_msg = str(result.get("error", "克隆失败")) await update_progress(stage="error", progress=0, message="下载新版本失败", operation="update", plugin_id=plugin_id, error=error_msg) - raise HTTPException(status_code=500, detail=error_msg) + raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg) await update_progress(stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id) - new_manifest_path = plugin_path / "_manifest.json" + new_manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json") if not new_manifest_path.exists(): remove_tree(plugin_path) await update_progress(stage="error", progress=0, message="新版本缺少 _manifest.json", operation="update", plugin_id=plugin_id, error="无效的插件格式") @@ -225,23 +227,22 @@ async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) -> logger.info("收到获取已安装插件列表请求") try: - plugins_dir = get_plugins_dir() installed_plugins: list[dict[str, Any]] = [] - for plugin_path in plugins_dir.iterdir(): - if not plugin_path.is_dir(): - continue + for plugin_path in iter_plugin_directories(): folder_name = plugin_path.name if folder_name.startswith(".") or folder_name.startswith("__"): continue - manifest_path = plugin_path / "_manifest.json" + manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json") if not manifest_path.exists(): logger.warning(f"插件文件夹 {folder_name} 缺少 _manifest.json,跳过") continue try: - with open(manifest_path, "r", encoding="utf-8") as file_obj: - manifest = json.load(file_obj) + manifest = load_manifest_json(manifest_path) + if manifest is None: + logger.warning(f"插件文件夹 {folder_name} 的 _manifest.json 不安全或无效,跳过") + continue if "name" not in manifest or "version" not in manifest: logger.warning(f"插件文件夹 {folder_name} 的 _manifest.json 格式无效,跳过") continue @@ -286,7 +287,7 @@ async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str] return {"success": False, "error": "插件未安装"} for readme_name in ["README.md", "readme.md", "Readme.md", "README.MD"]: - readme_path = plugin_path / readme_name + readme_path = resolve_plugin_file_path(plugin_path, readme_name) if readme_path.exists(): try: with open(readme_path, "r", encoding="utf-8") as file_obj: diff --git a/src/webui/routers/plugin/progress.py b/src/webui/routers/plugin/progress.py index 2d945df5..f12406b7 100644 --- a/src/webui/routers/plugin/progress.py +++ b/src/webui/routers/plugin/progress.py @@ -89,12 +89,6 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = is_authenticated = True logger.debug("插件进度 WebSocket 使用 Cookie 认证成功") - if not is_authenticated and token: - token_manager = get_token_manager() - if token_manager.verify_token(token): - is_authenticated = True - logger.debug("插件进度 WebSocket 使用 session token 认证成功") - if not is_authenticated: logger.warning("插件进度 WebSocket 连接被拒绝:认证失败") await websocket.close(code=4001, reason="认证失败,请重新登录") diff --git a/src/webui/routers/plugin/support.py b/src/webui/routers/plugin/support.py index f0078736..46814709 100644 --- a/src/webui/routers/plugin/support.py +++ b/src/webui/routers/plugin/support.py @@ -44,6 +44,50 @@ def validate_safe_path(user_path: str, base_path: Path) -> Path: return target_path +def _resolve_safe_plugin_directory(plugin_path: Path, plugins_dir: Path, strict: bool) -> Optional[Path]: + try: + if plugin_path.is_symlink(): + raise HTTPException(status_code=400, detail="插件目录不能是符号链接") + + resolved_plugins_dir = plugins_dir.resolve() + resolved_plugin_path = plugin_path.resolve() + resolved_plugin_path.relative_to(resolved_plugins_dir) + + if not resolved_plugin_path.is_dir(): + return None + + return resolved_plugin_path + except HTTPException: + if strict: + raise + logger.warning(f"已跳过不安全的插件目录: {plugin_path}") + return None + except (OSError, RuntimeError, ValueError): + if strict: + raise HTTPException(status_code=400, detail="插件目录超出允许范围") + logger.warning(f"已跳过越界的插件目录: {plugin_path}") + return None + + +def resolve_plugin_file_path(plugin_path: Path, relative_path: str, allow_missing: bool = True) -> Path: + plugin_root = plugin_path.resolve() + target_path = plugin_root / relative_path + + if target_path.exists() and target_path.is_symlink(): + raise HTTPException(status_code=400, detail=f"插件文件不能是符号链接: {relative_path}") + + try: + resolved_target_path = target_path.resolve() + resolved_target_path.relative_to(plugin_root) + except (OSError, RuntimeError, ValueError) as e: + raise HTTPException(status_code=400, detail=f"插件文件超出允许范围: {relative_path}") from e + + if not allow_missing and not resolved_target_path.exists(): + raise HTTPException(status_code=404, detail=f"插件文件不存在: {relative_path}") + + return resolved_target_path + + def validate_plugin_id(plugin_id: str) -> str: if not plugin_id or not plugin_id.strip(): logger.warning("非法插件 ID: 空字符串") @@ -164,9 +208,13 @@ def get_plugin_candidate_paths(plugin_id: str) -> tuple[Path, Path]: def resolve_installed_plugin_path(plugin_id: str) -> Optional[Path]: new_format_path, old_format_path = get_plugin_candidate_paths(plugin_id) + plugins_dir = get_plugins_dir() + if new_format_path.exists(): - return new_format_path - return old_format_path if old_format_path.exists() else None + return _resolve_safe_plugin_directory(new_format_path, plugins_dir, strict=True) + if old_format_path.exists(): + return _resolve_safe_plugin_directory(old_format_path, plugins_dir, strict=True) + return None def parse_repository_url(repository_url: str) -> tuple[str, str, str]: @@ -181,6 +229,16 @@ def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]: if not manifest_path.exists(): return None + if manifest_path.is_symlink(): + logger.warning(f"已拒绝读取符号链接 manifest: {manifest_path}") + return None + + try: + manifest_path.resolve().relative_to(manifest_path.parent.resolve()) + except (OSError, RuntimeError, ValueError): + logger.warning(f"已拒绝读取越界 manifest: {manifest_path}") + return None + try: with open(manifest_path, "r", encoding="utf-8") as file_obj: return cast(dict[str, Any], json.load(file_obj)) @@ -189,12 +247,19 @@ def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]: def iter_plugin_directories() -> list[Path]: - return [path for path in get_plugins_dir().iterdir() if path.is_dir()] + plugins_dir = get_plugins_dir() + plugin_directories: list[Path] = [] + for path in plugins_dir.iterdir(): + safe_path = _resolve_safe_plugin_directory(path, plugins_dir, strict=False) + if safe_path is not None: + plugin_directories.append(safe_path) + return plugin_directories def find_plugin_path_by_id(plugin_id: str) -> Optional[Path]: for plugin_path in iter_plugin_directories(): - manifest = load_manifest_json(plugin_path / "_manifest.json") + manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json") + manifest = load_manifest_json(manifest_path) if manifest is not None and (manifest.get("id") == plugin_id or plugin_path.name == plugin_id): return plugin_path return None @@ -214,6 +279,9 @@ def backup_file(file_path: Path, action: str, move_file: bool = False) -> Option def remove_tree(path: Path) -> None: + if path.is_symlink(): + raise ValueError(f"拒绝删除符号链接路径: {path}") + def remove_readonly(func: Any, target_path: str, _: Any) -> None: os.chmod(target_path, stat.S_IWRITE) func(target_path) diff --git a/src/webui/services/git_mirror_service.py b/src/webui/services/git_mirror_service.py index a6a9b1bc..5052869d 100644 --- a/src/webui/services/git_mirror_service.py +++ b/src/webui/services/git_mirror_service.py @@ -10,6 +10,7 @@ import shutil from pathlib import Path from datetime import datetime from src.common.logger import get_logger +from src.webui.utils.network_security import validate_public_url logger = get_logger("webui.git_mirror") @@ -17,6 +18,20 @@ logger = get_logger("webui.git_mirror") _update_progress = None +def _validate_mirror_prefix(url: str, field_name: str) -> str: + try: + return validate_public_url(url) + except ValueError as e: + raise ValueError(f"{field_name} 非法: {e}") from e + + +def _validate_custom_outbound_url(url: str) -> str: + try: + return validate_public_url(url) + except ValueError as e: + raise ValueError(f"目标 URL 非法: {e}") from e + + def set_update_progress_callback(callback): """设置进度更新回调函数""" global _update_progress @@ -200,6 +215,9 @@ class GitMirrorConfig: if self.get_mirror_by_id(mirror_id): raise ValueError(f"镜像源 ID 已存在: {mirror_id}") + raw_prefix = _validate_mirror_prefix(raw_prefix, "Raw 前缀") + clone_prefix = _validate_mirror_prefix(clone_prefix, "克隆前缀") + # 如果未指定优先级,使用最大优先级 + 1 if priority is None: max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0) @@ -241,8 +259,10 @@ class GitMirrorConfig: if name is not None: mirror["name"] = name if raw_prefix is not None: + raw_prefix = _validate_mirror_prefix(raw_prefix, "Raw 前缀") mirror["raw_prefix"] = raw_prefix if clone_prefix is not None: + clone_prefix = _validate_mirror_prefix(clone_prefix, "克隆前缀") mirror["clone_prefix"] = clone_prefix if enabled is not None: mirror["enabled"] = enabled @@ -372,7 +392,18 @@ class GitMirrorService: logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}") if custom_url: - # 使用自定义 URL + try: + custom_url = _validate_custom_outbound_url(custom_url) + except ValueError as e: + return { + "success": False, + "error": str(e), + "mirror_used": "custom", + "attempts": 0, + "url": custom_url, + "status_code": 400, + } + return await self._fetch_with_url(custom_url, "custom") # 确定要使用的镜像源列表 @@ -443,8 +474,11 @@ class GitMirrorService: self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any] ) -> Dict[str, Any]: """从指定镜像源获取文件""" - # 构建 URL - raw_prefix = mirror["raw_prefix"] + try: + raw_prefix = _validate_mirror_prefix(mirror["raw_prefix"], "镜像 Raw 前缀") + except ValueError as e: + return {"success": False, "error": str(e), "mirror_used": mirror.get("id"), "attempts": 0, "status_code": 400} + url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}" return await self._fetch_with_url(url, mirror["id"]) @@ -515,7 +549,18 @@ class GitMirrorService: logger.info(f"开始克隆仓库: {owner}/{repo} 到 {target_path}") if custom_url: - # 使用自定义 URL + try: + custom_url = _validate_custom_outbound_url(custom_url) + except ValueError as e: + return { + "success": False, + "error": str(e), + "mirror_used": "custom", + "attempts": 0, + "url": custom_url, + "status_code": 400, + } + return await self._clone_with_url(custom_url, target_path, branch, depth, "custom") # 确定要使用的镜像源列表 @@ -549,8 +594,11 @@ class GitMirrorService: mirror: Dict[str, Any], ) -> Dict[str, Any]: """从指定镜像源克隆仓库""" - # 构建克隆 URL - clone_prefix = mirror["clone_prefix"] + try: + clone_prefix = _validate_mirror_prefix(mirror["clone_prefix"], "镜像克隆前缀") + except ValueError as e: + return {"success": False, "error": str(e), "mirror_used": mirror.get("id"), "attempts": 0, "status_code": 400} + url = f"{clone_prefix}/{owner}/{repo}.git" return await self._clone_with_url(url, target_path, branch, depth, mirror["id"]) diff --git a/src/webui/utils/network_security.py b/src/webui/utils/network_security.py new file mode 100644 index 00000000..1cb7e4ea --- /dev/null +++ b/src/webui/utils/network_security.py @@ -0,0 +1,76 @@ +from typing import Iterable + +import ipaddress +import socket + +from urllib.parse import urlparse + + +def _resolve_ip_addresses(hostname: str, port: int) -> set[ipaddress.IPv4Address | ipaddress.IPv6Address]: + try: + address_infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM) + except socket.gaierror as exc: + raise ValueError(f"无法解析主机名: {hostname}") from exc + + resolved_addresses: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set() + for _, _, _, _, sockaddr in address_infos: + host_address = sockaddr[0] + if not isinstance(host_address, str): + continue + + raw_ip = host_address.split("%", 1)[0] + resolved_addresses.add(ipaddress.ip_address(raw_ip)) + + if not resolved_addresses: + raise ValueError(f"无法解析主机名: {hostname}") + + return resolved_addresses + + +def _is_forbidden_ip_address(address: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: + return any( + ( + address.is_loopback, + address.is_link_local, + address.is_multicast, + address.is_private, + address.is_reserved, + address.is_unspecified, + getattr(address, "is_site_local", False), + ) + ) + + +def validate_public_url(url: str, allowed_schemes: Iterable[str] = ("https",)) -> str: + normalized_url = url.strip() + if not normalized_url: + raise ValueError("URL 不能为空") + + parsed = urlparse(normalized_url) + allowed_scheme_set = {scheme.lower() for scheme in allowed_schemes} + if parsed.scheme.lower() not in allowed_scheme_set: + allowed = ", ".join(sorted(allowed_scheme_set)) + raise ValueError(f"仅允许以下协议: {allowed}") + + if not parsed.hostname or not parsed.netloc: + raise ValueError("URL 缺少有效的主机名") + + if parsed.username or parsed.password: + raise ValueError("URL 不允许内嵌认证信息") + + if parsed.fragment: + raise ValueError("URL 不允许包含片段") + + if parsed.hostname.lower() in {"localhost", "localhost.localdomain"}: + raise ValueError("不允许访问本地主机") + + try: + port = parsed.port or (443 if parsed.scheme.lower() == "https" else 80) + except ValueError as exc: + raise ValueError("URL 端口非法") from exc + + for address in _resolve_ip_addresses(parsed.hostname, port): + if _is_forbidden_ip_address(address): + raise ValueError(f"禁止访问非公网地址: {address}") + + return normalized_url \ No newline at end of file