添加 WebSocket 认证模块,支持临时 token 认证机制,增强安全性并解决 Cookie 不可用问题
This commit is contained in:
@@ -3,6 +3,7 @@ from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Dict, Any, get_origin
|
||||
from pathlib import Path
|
||||
import json
|
||||
import re
|
||||
from src.common.logger import get_logger
|
||||
from src.common.toml_utils import save_toml_with_format
|
||||
from src.config.config import MMC_VERSION
|
||||
@@ -34,6 +35,85 @@ def get_token_from_cookie_or_header(
|
||||
return None
|
||||
|
||||
|
||||
def validate_safe_path(user_path: str, base_path: Path) -> Path:
|
||||
"""
|
||||
验证用户提供的路径是否安全,防止路径遍历攻击
|
||||
|
||||
Args:
|
||||
user_path: 用户输入的路径(相对路径)
|
||||
base_path: 允许的基础目录
|
||||
|
||||
Returns:
|
||||
安全的绝对路径
|
||||
|
||||
Raises:
|
||||
HTTPException: 如果检测到路径遍历攻击
|
||||
"""
|
||||
# 规范化基础路径
|
||||
base_resolved = base_path.resolve()
|
||||
|
||||
# 检查用户路径是否包含可疑字符
|
||||
# 禁止: .., 绝对路径开头, 空字节等
|
||||
if any(pattern in user_path for pattern in ["..", "\x00"]):
|
||||
logger.warning(f"检测到可疑路径: {user_path}")
|
||||
raise HTTPException(status_code=400, detail="路径包含非法字符")
|
||||
|
||||
# 检查是否为绝对路径(Windows 和 Unix)
|
||||
if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"):
|
||||
logger.warning(f"检测到绝对路径: {user_path}")
|
||||
raise HTTPException(status_code=400, detail="不允许使用绝对路径")
|
||||
|
||||
# 构建目标路径并解析
|
||||
target_path = (base_path / user_path).resolve()
|
||||
|
||||
# 验证解析后的路径仍在基础目录内
|
||||
try:
|
||||
target_path.relative_to(base_resolved)
|
||||
except ValueError as e:
|
||||
logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}")
|
||||
raise HTTPException(status_code=400, detail="路径超出允许范围") from e
|
||||
|
||||
return target_path
|
||||
|
||||
|
||||
def validate_plugin_id(plugin_id: str) -> str:
|
||||
"""
|
||||
验证插件 ID 格式是否安全
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID (支持 author.name 格式,允许中文)
|
||||
|
||||
Returns:
|
||||
验证通过的插件 ID
|
||||
|
||||
Raises:
|
||||
HTTPException: 如果插件 ID 格式不安全
|
||||
"""
|
||||
# 禁止空字符串
|
||||
if not plugin_id or not plugin_id.strip():
|
||||
logger.warning("非法插件 ID: 空字符串")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能为空")
|
||||
|
||||
# 禁止危险字符: 路径分隔符、空字节、控制字符等
|
||||
dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"]
|
||||
for pattern in dangerous_patterns:
|
||||
if pattern in plugin_id:
|
||||
logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
|
||||
|
||||
# 禁止以点开头或结尾(防止隐藏文件和路径问题)
|
||||
if plugin_id.startswith(".") or plugin_id.endswith("."):
|
||||
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
|
||||
|
||||
# 禁止特殊名称
|
||||
if plugin_id in (".", ".."):
|
||||
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名")
|
||||
|
||||
return plugin_id
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||
"""
|
||||
解析版本号字符串
|
||||
@@ -468,17 +548,16 @@ async def fetch_raw_file(
|
||||
|
||||
支持多镜像源自动切换和错误重试
|
||||
|
||||
注意:此接口可公开访问,用于获取插件仓库等公开资源
|
||||
需要认证才能访问,防止被滥用作为 SSRF 跳板
|
||||
"""
|
||||
# Token 验证(可选,用于日志记录)
|
||||
# Token 验证(强制)
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
is_authenticated = token and token_manager.verify_token(token)
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
# 对于公开仓库的访问,不强制要求认证
|
||||
# 只在日志中记录是否认证
|
||||
logger.info(
|
||||
f"收到获取 Raw 文件请求 (认证: {is_authenticated}): "
|
||||
f"收到获取 Raw 文件请求: "
|
||||
f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}"
|
||||
)
|
||||
|
||||
@@ -564,10 +643,10 @@ async def clone_repository(
|
||||
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
|
||||
|
||||
try:
|
||||
# TODO: 验证 target_path 的安全性,防止路径遍历攻击
|
||||
# TODO: 确定实际的插件目录基路径
|
||||
base_plugin_path = Path("./plugins") # 临时路径
|
||||
target_path = base_plugin_path / request.target_path
|
||||
# 验证 target_path 的安全性,防止路径遍历攻击
|
||||
base_plugin_path = Path("./plugins").resolve()
|
||||
base_plugin_path.mkdir(exist_ok=True)
|
||||
target_path = validate_safe_path(request.target_path, base_plugin_path)
|
||||
|
||||
service = get_git_mirror_service()
|
||||
result = await service.clone_repository(
|
||||
@@ -607,13 +686,16 @@ async def install_plugin(
|
||||
logger.info(f"收到安装插件请求: {request.plugin_id}")
|
||||
|
||||
try:
|
||||
# 验证插件 ID 格式安全性
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
|
||||
# 推送进度:开始安装
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=5,
|
||||
message=f"开始安装插件: {request.plugin_id}",
|
||||
message=f"开始安装插件: {plugin_id}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 1. 解析仓库 URL
|
||||
@@ -634,27 +716,28 @@ async def install_plugin(
|
||||
progress=10,
|
||||
message=f"解析仓库信息: {owner}/{repo}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 2. 确定插件安装路径
|
||||
plugins_dir = Path("plugins")
|
||||
plugins_dir = Path("plugins").resolve()
|
||||
plugins_dir.mkdir(exist_ok=True)
|
||||
|
||||
# 将插件 ID 中的点替换为下划线作为文件夹名称(避免文件系统问题)
|
||||
# 例如: SengokuCola.Mute-Plugin -> SengokuCola_Mute-Plugin
|
||||
folder_name = request.plugin_id.replace(".", "_")
|
||||
target_path = plugins_dir / folder_name
|
||||
folder_name = plugin_id.replace(".", "_")
|
||||
# 使用安全路径验证,防止路径遍历
|
||||
target_path = validate_safe_path(folder_name, plugins_dir)
|
||||
|
||||
# 检查插件是否已安装(需要检查两种格式:新格式下划线和旧格式点)
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
old_format_path = plugins_dir / plugin_id
|
||||
if target_path.exists() or old_format_path.exists():
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="插件已存在",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="插件已安装,请先卸载",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="插件已安装")
|
||||
@@ -664,7 +747,7 @@ async def install_plugin(
|
||||
progress=15,
|
||||
message=f"准备克隆到: {target_path}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
|
||||
@@ -693,14 +776,14 @@ async def install_plugin(
|
||||
progress=0,
|
||||
message="克隆仓库失败",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
# 4. 验证插件完整性
|
||||
await update_progress(
|
||||
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=request.plugin_id
|
||||
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id
|
||||
)
|
||||
|
||||
manifest_path = target_path / "_manifest.json"
|
||||
@@ -715,14 +798,14 @@ async def install_plugin(
|
||||
progress=0,
|
||||
message="插件缺少 _manifest.json",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="无效的插件格式",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||
|
||||
# 5. 读取并验证 manifest
|
||||
await update_progress(
|
||||
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=request.plugin_id
|
||||
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -739,7 +822,7 @@ async def install_plugin(
|
||||
|
||||
# 将插件 ID 写入 manifest(用于后续准确识别)
|
||||
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
|
||||
manifest["id"] = request.plugin_id
|
||||
manifest["id"] = plugin_id
|
||||
with open(manifest_path, "w", encoding="utf-8") as f:
|
||||
json_module.dump(manifest, f, ensure_ascii=False, indent=2)
|
||||
|
||||
@@ -754,7 +837,7 @@ async def install_plugin(
|
||||
progress=0,
|
||||
message="_manifest.json 无效",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=str(e),
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||
@@ -765,13 +848,13 @@ async def install_plugin(
|
||||
progress=100,
|
||||
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "插件安装成功",
|
||||
"plugin_id": request.plugin_id,
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": manifest["name"],
|
||||
"version": manifest["version"],
|
||||
"path": str(target_path),
|
||||
@@ -787,7 +870,7 @@ async def install_plugin(
|
||||
progress=0,
|
||||
message="安装失败",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
@@ -814,22 +897,26 @@ async def uninstall_plugin(
|
||||
logger.info(f"收到卸载插件请求: {request.plugin_id}")
|
||||
|
||||
try:
|
||||
# 验证插件 ID 格式安全性
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
|
||||
# 推送进度:开始卸载
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=10,
|
||||
message=f"开始卸载插件: {request.plugin_id}",
|
||||
message=f"开始卸载插件: {plugin_id}",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 1. 检查插件是否存在(支持新旧两种格式)
|
||||
plugins_dir = Path("plugins")
|
||||
plugins_dir = Path("plugins").resolve()
|
||||
# 新格式:下划线
|
||||
folder_name = request.plugin_id.replace(".", "_")
|
||||
plugin_path = plugins_dir / folder_name
|
||||
folder_name = plugin_id.replace(".", "_")
|
||||
# 使用安全路径验证
|
||||
plugin_path = validate_safe_path(folder_name, plugins_dir)
|
||||
# 旧格式:点
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
old_format_path = validate_safe_path(plugin_id, plugins_dir)
|
||||
|
||||
# 优先使用新格式,如果不存在则尝试旧格式
|
||||
if not plugin_path.exists():
|
||||
@@ -841,7 +928,7 @@ async def uninstall_plugin(
|
||||
progress=0,
|
||||
message="插件不存在",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="插件未安装或已被删除",
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="插件未安装")
|
||||
@@ -851,12 +938,12 @@ async def uninstall_plugin(
|
||||
progress=30,
|
||||
message=f"正在删除插件文件: {plugin_path}",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 2. 读取插件信息(用于日志)
|
||||
manifest_path = plugin_path / "_manifest.json"
|
||||
plugin_name = request.plugin_id
|
||||
plugin_name = plugin_id
|
||||
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
@@ -864,7 +951,7 @@ async def uninstall_plugin(
|
||||
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json_module.load(f)
|
||||
plugin_name = manifest.get("name", request.plugin_id)
|
||||
plugin_name = manifest.get("name", plugin_id)
|
||||
except Exception:
|
||||
pass # 如果读取失败,使用插件 ID 作为名称
|
||||
|
||||
@@ -873,7 +960,7 @@ async def uninstall_plugin(
|
||||
progress=50,
|
||||
message=f"正在删除 {plugin_name}...",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 3. 删除插件目录
|
||||
@@ -889,7 +976,7 @@ async def uninstall_plugin(
|
||||
|
||||
shutil.rmtree(plugin_path, onerror=remove_readonly)
|
||||
|
||||
logger.info(f"成功卸载插件: {request.plugin_id} ({plugin_name})")
|
||||
logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})")
|
||||
|
||||
# 4. 推送成功状态
|
||||
await update_progress(
|
||||
@@ -897,10 +984,10 @@ async def uninstall_plugin(
|
||||
progress=100,
|
||||
message=f"成功卸载插件: {plugin_name}",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
return {"success": True, "message": "插件卸载成功", "plugin_id": request.plugin_id, "plugin_name": plugin_name}
|
||||
return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -912,7 +999,7 @@ async def uninstall_plugin(
|
||||
progress=0,
|
||||
message="卸载失败",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="权限不足,无法删除插件文件",
|
||||
)
|
||||
|
||||
@@ -925,7 +1012,7 @@ async def uninstall_plugin(
|
||||
progress=0,
|
||||
message="卸载失败",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
@@ -952,22 +1039,26 @@ async def update_plugin(
|
||||
logger.info(f"收到更新插件请求: {request.plugin_id}")
|
||||
|
||||
try:
|
||||
# 验证插件 ID 格式安全性
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
|
||||
# 推送进度:开始更新
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=5,
|
||||
message=f"开始更新插件: {request.plugin_id}",
|
||||
message=f"开始更新插件: {plugin_id}",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 1. 检查插件是否已安装(支持新旧两种格式)
|
||||
plugins_dir = Path("plugins")
|
||||
plugins_dir = Path("plugins").resolve()
|
||||
# 新格式:下划线
|
||||
folder_name = request.plugin_id.replace(".", "_")
|
||||
plugin_path = plugins_dir / folder_name
|
||||
folder_name = plugin_id.replace(".", "_")
|
||||
# 使用安全路径验证
|
||||
plugin_path = validate_safe_path(folder_name, plugins_dir)
|
||||
# 旧格式:点
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
old_format_path = validate_safe_path(plugin_id, plugins_dir)
|
||||
|
||||
# 优先使用新格式,如果不存在则尝试旧格式
|
||||
if not plugin_path.exists():
|
||||
@@ -979,7 +1070,7 @@ async def update_plugin(
|
||||
progress=0,
|
||||
message="插件不存在",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="插件未安装,请先安装",
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="插件未安装")
|
||||
@@ -1003,12 +1094,12 @@ async def update_plugin(
|
||||
progress=10,
|
||||
message=f"当前版本: {old_version},准备更新...",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 3. 删除旧版本
|
||||
await update_progress(
|
||||
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=request.plugin_id
|
||||
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id
|
||||
)
|
||||
|
||||
import shutil
|
||||
@@ -1023,7 +1114,7 @@ async def update_plugin(
|
||||
|
||||
shutil.rmtree(plugin_path, onerror=remove_readonly)
|
||||
|
||||
logger.info(f"已删除旧版本: {request.plugin_id} v{old_version}")
|
||||
logger.info(f"已删除旧版本: {plugin_id} v{old_version}")
|
||||
|
||||
# 4. 解析仓库 URL
|
||||
await update_progress(
|
||||
@@ -1031,7 +1122,7 @@ async def update_plugin(
|
||||
progress=30,
|
||||
message="正在准备下载新版本...",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
repo_url = request.repository_url.rstrip("/")
|
||||
@@ -1069,14 +1160,14 @@ async def update_plugin(
|
||||
progress=0,
|
||||
message="下载新版本失败",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
# 6. 验证新版本
|
||||
await update_progress(
|
||||
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=request.plugin_id
|
||||
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id
|
||||
)
|
||||
|
||||
new_manifest_path = plugin_path / "_manifest.json"
|
||||
@@ -1096,7 +1187,7 @@ async def update_plugin(
|
||||
progress=0,
|
||||
message="新版本缺少 _manifest.json",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="无效的插件格式",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||
@@ -1107,9 +1198,9 @@ async def update_plugin(
|
||||
new_manifest = json_module.load(f)
|
||||
|
||||
new_version = new_manifest.get("version", "unknown")
|
||||
new_name = new_manifest.get("name", request.plugin_id)
|
||||
new_name = new_manifest.get("name", plugin_id)
|
||||
|
||||
logger.info(f"成功更新插件: {request.plugin_id} {old_version} → {new_version}")
|
||||
logger.info(f"成功更新插件: {plugin_id} {old_version} → {new_version}")
|
||||
|
||||
# 8. 推送成功状态
|
||||
await update_progress(
|
||||
@@ -1117,13 +1208,13 @@ async def update_plugin(
|
||||
progress=100,
|
||||
message=f"成功更新 {new_name}: {old_version} → {new_version}",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "插件更新成功",
|
||||
"plugin_id": request.plugin_id,
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": new_name,
|
||||
"old_version": old_version,
|
||||
"new_version": new_version,
|
||||
@@ -1138,7 +1229,7 @@ async def update_plugin(
|
||||
progress=0,
|
||||
message="_manifest.json 无效",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=str(e),
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||
@@ -1149,7 +1240,7 @@ async def update_plugin(
|
||||
logger.error(f"更新插件失败: {e}", exc_info=True)
|
||||
|
||||
await update_progress(
|
||||
stage="error", progress=0, message="更新失败", operation="update", plugin_id=request.plugin_id, error=str(e)
|
||||
stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e)
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
Reference in New Issue
Block a user