feat: 增加网络安全功能,验证公共 URL 和适配器配置路径

This commit is contained in:
DrSmoothl
2026-03-14 22:55:51 +08:00
parent 1978b097e3
commit 292f0a1d7a
12 changed files with 288 additions and 65 deletions

View File

@@ -0,0 +1,8 @@
{
"hash": "1b5cd9d5",
"configHash": "027a635a",
"lockfileHash": "36800971",
"browserHash": "e1e062e5",
"optimized": {},
"chunks": {}
}

View File

@@ -0,0 +1,3 @@
{
"type": "module"
}

View File

@@ -82,7 +82,6 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
支持三种认证方式(按优先级):
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/logs?token=xxx
"""
@@ -102,13 +101,6 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
is_authenticated = True
logger.debug("WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")

View File

@@ -4,6 +4,7 @@
import copy
import os
from pathlib import Path
from typing import Any, Annotated, Optional
import tomlkit
@@ -420,6 +421,32 @@ def _normalize_adapter_path(path: str) -> str:
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
def _get_allowed_adapter_config_roots() -> tuple[Path, ...]:
project_root = Path(PROJECT_ROOT).resolve()
return (
project_root,
(project_root.parent / "MaiBot-Napcat-Adapter").resolve(),
Path("/MaiMBot/adapters-config").resolve(),
)
def _resolve_safe_adapter_config_path(path: str) -> Path:
normalized_path = _normalize_adapter_path(path)
candidate_path = Path(normalized_path).expanduser().resolve()
if candidate_path.suffix.lower() != ".toml":
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
for allowed_root in _get_allowed_adapter_config_roots():
try:
candidate_path.relative_to(allowed_root)
return candidate_path
except ValueError:
continue
raise HTTPException(status_code=400, detail="适配器配置路径超出允许范围")
def _to_relative_path(path: str) -> str:
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
if not path or not os.path.isabs(path):
@@ -457,8 +484,11 @@ async def get_adapter_config_path():
if not adapter_config_path:
return {"success": True, "path": None}
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(adapter_config_path)
try:
abs_path = str(_resolve_safe_adapter_config_path(adapter_config_path))
except HTTPException:
logger.warning(f"已忽略不安全的适配器配置路径: {adapter_config_path}")
return {"success": True, "path": None}
# 检查文件是否存在并返回最后修改时间
if os.path.exists(abs_path):
@@ -497,8 +527,7 @@ async def save_adapter_config_path(data: PathBody):
else:
webui_data = {}
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
abs_path = str(_resolve_safe_adapter_config_path(path))
# 尝试转换为相对路径保存(如果文件在项目目录内)
save_path = _to_relative_path(abs_path)
@@ -528,17 +557,12 @@ async def get_adapter_config(path: str):
if not path:
raise HTTPException(status_code=400, detail="路径参数不能为空")
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
abs_path = str(_resolve_safe_adapter_config_path(path))
# 检查文件是否存在
if not os.path.exists(abs_path):
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
# 检查文件扩展名
if not abs_path.endswith(".toml"):
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
# 读取文件内容
with open(abs_path, "r", encoding="utf-8") as f:
content = f.read()
@@ -565,12 +589,7 @@ async def save_adapter_config(data: PathBody):
if content is None:
raise HTTPException(status_code=400, detail="配置内容不能为空")
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
# 检查文件扩展名
if not abs_path.endswith(".toml"):
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
abs_path = str(_resolve_safe_adapter_config_path(path))
# 验证 TOML 格式
try:

View File

@@ -14,6 +14,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query
from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
from src.webui.dependencies import require_auth
from src.webui.utils.network_security import validate_public_url
logger = get_logger("webui")
@@ -102,7 +103,12 @@ async def _fetch_models_from_provider(
Returns:
模型列表
"""
url = f"{_normalize_url(base_url)}{endpoint}"
try:
base_url = validate_public_url(_normalize_url(base_url))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
url = f"{base_url}{endpoint}"
# 根据客户端类型设置请求头
headers = {}
@@ -266,6 +272,11 @@ async def test_provider_connection(
if not base_url:
raise HTTPException(status_code=400, detail="base_url 不能为空")
try:
base_url = validate_public_url(base_url)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
result = {
"network_ok": False,
"api_key_valid": None,

View File

@@ -106,6 +106,8 @@ async def update_mirror(
if mirror is None:
raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}")
return _mirror_to_response(mirror)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except HTTPException:
raise
except Exception as e:

View File

@@ -16,6 +16,7 @@ from .support import (
find_plugin_path_by_id,
normalize_dotted_keys,
require_plugin_token,
resolve_plugin_file_path,
)
logger = get_logger("webui.plugin_routes")
@@ -133,7 +134,7 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
schema_json_path = plugin_path / "config_schema.json"
schema_json_path = resolve_plugin_file_path(plugin_path, "config_schema.json")
if schema_json_path.exists():
try:
with open(schema_json_path, "r", encoding="utf-8") as file_obj:
@@ -142,7 +143,7 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
logger.warning(f"读取 config_schema.json 失败,回退到自动推断: {e}")
current_config: Any = {}
config_path = plugin_path / "config.toml"
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as file_obj:
current_config = tomlkit.load(file_obj)
@@ -165,7 +166,7 @@ async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] =
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
if not config_path.exists():
return {"success": True, "config": "", "message": "配置文件不存在"}
@@ -192,7 +193,7 @@ async def update_plugin_config_raw(
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
try:
tomlkit.loads(request.config)
except Exception as e:
@@ -224,7 +225,7 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
if not config_path.exists():
return {"success": True, "config": {}, "message": "配置文件不存在"}
@@ -259,7 +260,7 @@ async def update_plugin_config(
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
backup_path = backup_file(config_path, "backup")
if backup_path is not None:
logger.info(f"已备份配置文件: {backup_path}")
@@ -284,7 +285,7 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
if not config_path.exists():
return {"success": True, "message": "配置文件不存在,无需重置"}
@@ -308,7 +309,7 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
config = tomlkit.document()
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as file_obj:

View File

@@ -13,11 +13,12 @@ from .schemas import InstallPluginRequest, UninstallPluginRequest, UpdatePluginR
from .support import (
find_plugin_path_by_id,
get_plugin_candidate_paths,
get_plugins_dir,
iter_plugin_directories,
load_manifest_json,
parse_repository_url,
remove_tree,
require_plugin_token,
resolve_plugin_file_path,
resolve_installed_plugin_path,
validate_plugin_id,
)
@@ -56,7 +57,8 @@ def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path:
logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}")
manifest["id"] = plugin_id
try:
with open(manifest_path, "w", encoding="utf-8") as file_obj:
safe_manifest_path = resolve_plugin_file_path(manifest_path.parent, "_manifest.json")
with open(safe_manifest_path, "w", encoding="utf-8") as file_obj:
json.dump(manifest, file_obj, ensure_ascii=False, indent=2)
except Exception as write_error:
logger.warning(f"无法写入 ID 到 manifest: {write_error}")
@@ -91,10 +93,10 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
if not result.get("success"):
error_msg = str(result.get("error", "克隆失败"))
await update_progress(stage="error", progress=0, message="克隆仓库失败", operation="install", plugin_id=plugin_id, error=error_msg)
raise HTTPException(status_code=500, detail=error_msg)
raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg)
await update_progress(stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id)
manifest_path = target_path / "_manifest.json"
manifest_path = resolve_plugin_file_path(target_path, "_manifest.json")
if not manifest_path.exists():
remove_tree(target_path)
await update_progress(stage="error", progress=0, message="插件缺少 _manifest.json", operation="install", plugin_id=plugin_id, error="无效的插件格式")
@@ -140,7 +142,7 @@ async def uninstall_plugin(request: UninstallPluginRequest, maibot_session: Opti
raise HTTPException(status_code=404, detail="插件未安装")
await update_progress(stage="loading", progress=30, message=f"正在删除插件文件: {plugin_path}", operation="uninstall", plugin_id=plugin_id)
manifest = load_manifest_json(plugin_path / "_manifest.json")
manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json"))
plugin_name = str(manifest.get("name", plugin_id)) if manifest is not None else plugin_id
await update_progress(stage="loading", progress=50, message=f"正在删除 {plugin_name}...", operation="uninstall", plugin_id=plugin_id)
remove_tree(plugin_path)
@@ -173,7 +175,7 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
await update_progress(stage="error", progress=0, message="插件不存在", operation="update", plugin_id=plugin_id, error="插件未安装,请先安装")
raise HTTPException(status_code=404, detail="插件未安装")
manifest = load_manifest_json(plugin_path / "_manifest.json")
manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json"))
old_version = str(manifest.get("version", "unknown")) if manifest is not None else "unknown"
await update_progress(stage="loading", progress=10, message=f"当前版本: {old_version},准备更新...", operation="update", plugin_id=plugin_id)
await update_progress(stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id)
@@ -190,10 +192,10 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
if not result.get("success"):
error_msg = str(result.get("error", "克隆失败"))
await update_progress(stage="error", progress=0, message="下载新版本失败", operation="update", plugin_id=plugin_id, error=error_msg)
raise HTTPException(status_code=500, detail=error_msg)
raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg)
await update_progress(stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id)
new_manifest_path = plugin_path / "_manifest.json"
new_manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
if not new_manifest_path.exists():
remove_tree(plugin_path)
await update_progress(stage="error", progress=0, message="新版本缺少 _manifest.json", operation="update", plugin_id=plugin_id, error="无效的插件格式")
@@ -225,23 +227,22 @@ async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) ->
logger.info("收到获取已安装插件列表请求")
try:
plugins_dir = get_plugins_dir()
installed_plugins: list[dict[str, Any]] = []
for plugin_path in plugins_dir.iterdir():
if not plugin_path.is_dir():
continue
for plugin_path in iter_plugin_directories():
folder_name = plugin_path.name
if folder_name.startswith(".") or folder_name.startswith("__"):
continue
manifest_path = plugin_path / "_manifest.json"
manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
if not manifest_path.exists():
logger.warning(f"插件文件夹 {folder_name} 缺少 _manifest.json跳过")
continue
try:
with open(manifest_path, "r", encoding="utf-8") as file_obj:
manifest = json.load(file_obj)
manifest = load_manifest_json(manifest_path)
if manifest is None:
logger.warning(f"插件文件夹 {folder_name} 的 _manifest.json 不安全或无效,跳过")
continue
if "name" not in manifest or "version" not in manifest:
logger.warning(f"插件文件夹 {folder_name} 的 _manifest.json 格式无效,跳过")
continue
@@ -286,7 +287,7 @@ async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str]
return {"success": False, "error": "插件未安装"}
for readme_name in ["README.md", "readme.md", "Readme.md", "README.MD"]:
readme_path = plugin_path / readme_name
readme_path = resolve_plugin_file_path(plugin_path, readme_name)
if readme_path.exists():
try:
with open(readme_path, "r", encoding="utf-8") as file_obj:

View File

@@ -89,12 +89,6 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")

View File

@@ -44,6 +44,50 @@ def validate_safe_path(user_path: str, base_path: Path) -> Path:
return target_path
def _resolve_safe_plugin_directory(plugin_path: Path, plugins_dir: Path, strict: bool) -> Optional[Path]:
try:
if plugin_path.is_symlink():
raise HTTPException(status_code=400, detail="插件目录不能是符号链接")
resolved_plugins_dir = plugins_dir.resolve()
resolved_plugin_path = plugin_path.resolve()
resolved_plugin_path.relative_to(resolved_plugins_dir)
if not resolved_plugin_path.is_dir():
return None
return resolved_plugin_path
except HTTPException:
if strict:
raise
logger.warning(f"已跳过不安全的插件目录: {plugin_path}")
return None
except (OSError, RuntimeError, ValueError):
if strict:
raise HTTPException(status_code=400, detail="插件目录超出允许范围")
logger.warning(f"已跳过越界的插件目录: {plugin_path}")
return None
def resolve_plugin_file_path(plugin_path: Path, relative_path: str, allow_missing: bool = True) -> Path:
plugin_root = plugin_path.resolve()
target_path = plugin_root / relative_path
if target_path.exists() and target_path.is_symlink():
raise HTTPException(status_code=400, detail=f"插件文件不能是符号链接: {relative_path}")
try:
resolved_target_path = target_path.resolve()
resolved_target_path.relative_to(plugin_root)
except (OSError, RuntimeError, ValueError) as e:
raise HTTPException(status_code=400, detail=f"插件文件超出允许范围: {relative_path}") from e
if not allow_missing and not resolved_target_path.exists():
raise HTTPException(status_code=404, detail=f"插件文件不存在: {relative_path}")
return resolved_target_path
def validate_plugin_id(plugin_id: str) -> str:
if not plugin_id or not plugin_id.strip():
logger.warning("非法插件 ID: 空字符串")
@@ -164,9 +208,13 @@ def get_plugin_candidate_paths(plugin_id: str) -> tuple[Path, Path]:
def resolve_installed_plugin_path(plugin_id: str) -> Optional[Path]:
new_format_path, old_format_path = get_plugin_candidate_paths(plugin_id)
plugins_dir = get_plugins_dir()
if new_format_path.exists():
return new_format_path
return old_format_path if old_format_path.exists() else None
return _resolve_safe_plugin_directory(new_format_path, plugins_dir, strict=True)
if old_format_path.exists():
return _resolve_safe_plugin_directory(old_format_path, plugins_dir, strict=True)
return None
def parse_repository_url(repository_url: str) -> tuple[str, str, str]:
@@ -181,6 +229,16 @@ def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
if not manifest_path.exists():
return None
if manifest_path.is_symlink():
logger.warning(f"已拒绝读取符号链接 manifest: {manifest_path}")
return None
try:
manifest_path.resolve().relative_to(manifest_path.parent.resolve())
except (OSError, RuntimeError, ValueError):
logger.warning(f"已拒绝读取越界 manifest: {manifest_path}")
return None
try:
with open(manifest_path, "r", encoding="utf-8") as file_obj:
return cast(dict[str, Any], json.load(file_obj))
@@ -189,12 +247,19 @@ def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
def iter_plugin_directories() -> list[Path]:
return [path for path in get_plugins_dir().iterdir() if path.is_dir()]
plugins_dir = get_plugins_dir()
plugin_directories: list[Path] = []
for path in plugins_dir.iterdir():
safe_path = _resolve_safe_plugin_directory(path, plugins_dir, strict=False)
if safe_path is not None:
plugin_directories.append(safe_path)
return plugin_directories
def find_plugin_path_by_id(plugin_id: str) -> Optional[Path]:
for plugin_path in iter_plugin_directories():
manifest = load_manifest_json(plugin_path / "_manifest.json")
manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
manifest = load_manifest_json(manifest_path)
if manifest is not None and (manifest.get("id") == plugin_id or plugin_path.name == plugin_id):
return plugin_path
return None
@@ -214,6 +279,9 @@ def backup_file(file_path: Path, action: str, move_file: bool = False) -> Option
def remove_tree(path: Path) -> None:
if path.is_symlink():
raise ValueError(f"拒绝删除符号链接路径: {path}")
def remove_readonly(func: Any, target_path: str, _: Any) -> None:
os.chmod(target_path, stat.S_IWRITE)
func(target_path)

View File

@@ -10,6 +10,7 @@ import shutil
from pathlib import Path
from datetime import datetime
from src.common.logger import get_logger
from src.webui.utils.network_security import validate_public_url
logger = get_logger("webui.git_mirror")
@@ -17,6 +18,20 @@ logger = get_logger("webui.git_mirror")
_update_progress = None
def _validate_mirror_prefix(url: str, field_name: str) -> str:
try:
return validate_public_url(url)
except ValueError as e:
raise ValueError(f"{field_name} 非法: {e}") from e
def _validate_custom_outbound_url(url: str) -> str:
try:
return validate_public_url(url)
except ValueError as e:
raise ValueError(f"目标 URL 非法: {e}") from e
def set_update_progress_callback(callback):
"""设置进度更新回调函数"""
global _update_progress
@@ -200,6 +215,9 @@ class GitMirrorConfig:
if self.get_mirror_by_id(mirror_id):
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
raw_prefix = _validate_mirror_prefix(raw_prefix, "Raw 前缀")
clone_prefix = _validate_mirror_prefix(clone_prefix, "克隆前缀")
# 如果未指定优先级,使用最大优先级 + 1
if priority is None:
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
@@ -241,8 +259,10 @@ class GitMirrorConfig:
if name is not None:
mirror["name"] = name
if raw_prefix is not None:
raw_prefix = _validate_mirror_prefix(raw_prefix, "Raw 前缀")
mirror["raw_prefix"] = raw_prefix
if clone_prefix is not None:
clone_prefix = _validate_mirror_prefix(clone_prefix, "克隆前缀")
mirror["clone_prefix"] = clone_prefix
if enabled is not None:
mirror["enabled"] = enabled
@@ -372,7 +392,18 @@ class GitMirrorService:
logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}")
if custom_url:
# 使用自定义 URL
try:
custom_url = _validate_custom_outbound_url(custom_url)
except ValueError as e:
return {
"success": False,
"error": str(e),
"mirror_used": "custom",
"attempts": 0,
"url": custom_url,
"status_code": 400,
}
return await self._fetch_with_url(custom_url, "custom")
# 确定要使用的镜像源列表
@@ -443,8 +474,11 @@ class GitMirrorService:
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
) -> Dict[str, Any]:
"""从指定镜像源获取文件"""
# 构建 URL
raw_prefix = mirror["raw_prefix"]
try:
raw_prefix = _validate_mirror_prefix(mirror["raw_prefix"], "镜像 Raw 前缀")
except ValueError as e:
return {"success": False, "error": str(e), "mirror_used": mirror.get("id"), "attempts": 0, "status_code": 400}
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
return await self._fetch_with_url(url, mirror["id"])
@@ -515,7 +549,18 @@ class GitMirrorService:
logger.info(f"开始克隆仓库: {owner}/{repo}{target_path}")
if custom_url:
# 使用自定义 URL
try:
custom_url = _validate_custom_outbound_url(custom_url)
except ValueError as e:
return {
"success": False,
"error": str(e),
"mirror_used": "custom",
"attempts": 0,
"url": custom_url,
"status_code": 400,
}
return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")
# 确定要使用的镜像源列表
@@ -549,8 +594,11 @@ class GitMirrorService:
mirror: Dict[str, Any],
) -> Dict[str, Any]:
"""从指定镜像源克隆仓库"""
# 构建克隆 URL
clone_prefix = mirror["clone_prefix"]
try:
clone_prefix = _validate_mirror_prefix(mirror["clone_prefix"], "镜像克隆前缀")
except ValueError as e:
return {"success": False, "error": str(e), "mirror_used": mirror.get("id"), "attempts": 0, "status_code": 400}
url = f"{clone_prefix}/{owner}/{repo}.git"
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])

View File

@@ -0,0 +1,76 @@
from typing import Iterable
import ipaddress
import socket
from urllib.parse import urlparse
def _resolve_ip_addresses(hostname: str, port: int) -> set[ipaddress.IPv4Address | ipaddress.IPv6Address]:
try:
address_infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
except socket.gaierror as exc:
raise ValueError(f"无法解析主机名: {hostname}") from exc
resolved_addresses: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
for _, _, _, _, sockaddr in address_infos:
host_address = sockaddr[0]
if not isinstance(host_address, str):
continue
raw_ip = host_address.split("%", 1)[0]
resolved_addresses.add(ipaddress.ip_address(raw_ip))
if not resolved_addresses:
raise ValueError(f"无法解析主机名: {hostname}")
return resolved_addresses
def _is_forbidden_ip_address(address: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
return any(
(
address.is_loopback,
address.is_link_local,
address.is_multicast,
address.is_private,
address.is_reserved,
address.is_unspecified,
getattr(address, "is_site_local", False),
)
)
def validate_public_url(url: str, allowed_schemes: Iterable[str] = ("https",)) -> str:
normalized_url = url.strip()
if not normalized_url:
raise ValueError("URL 不能为空")
parsed = urlparse(normalized_url)
allowed_scheme_set = {scheme.lower() for scheme in allowed_schemes}
if parsed.scheme.lower() not in allowed_scheme_set:
allowed = ", ".join(sorted(allowed_scheme_set))
raise ValueError(f"仅允许以下协议: {allowed}")
if not parsed.hostname or not parsed.netloc:
raise ValueError("URL 缺少有效的主机名")
if parsed.username or parsed.password:
raise ValueError("URL 不允许内嵌认证信息")
if parsed.fragment:
raise ValueError("URL 不允许包含片段")
if parsed.hostname.lower() in {"localhost", "localhost.localdomain"}:
raise ValueError("不允许访问本地主机")
try:
port = parsed.port or (443 if parsed.scheme.lower() == "https" else 80)
except ValueError as exc:
raise ValueError("URL 端口非法") from exc
for address in _resolve_ip_addresses(parsed.hostname, port):
if _is_forbidden_ip_address(address):
raise ValueError(f"禁止访问非公网地址: {address}")
return normalized_url