WebUI后端整体重构
This commit is contained in:
30
src/webui/core/__init__.py
Normal file
30
src/webui/core/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from .security import TokenManager, get_token_manager
|
||||
from .rate_limiter import (
|
||||
RateLimiter,
|
||||
get_rate_limiter,
|
||||
check_auth_rate_limit,
|
||||
check_api_rate_limit,
|
||||
)
|
||||
from .auth import (
|
||||
COOKIE_NAME,
|
||||
COOKIE_MAX_AGE,
|
||||
get_current_token,
|
||||
set_auth_cookie,
|
||||
clear_auth_cookie,
|
||||
verify_auth_token_from_cookie_or_header,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TokenManager",
|
||||
"get_token_manager",
|
||||
"RateLimiter",
|
||||
"get_rate_limiter",
|
||||
"check_auth_rate_limit",
|
||||
"check_api_rate_limit",
|
||||
"COOKIE_NAME",
|
||||
"COOKIE_MAX_AGE",
|
||||
"get_current_token",
|
||||
"set_auth_cookie",
|
||||
"clear_auth_cookie",
|
||||
"verify_auth_token_from_cookie_or_header",
|
||||
]
|
||||
185
src/webui/core/auth.py
Normal file
185
src/webui/core/auth.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
WebUI 认证模块
|
||||
提供统一的认证依赖,支持 Cookie 和 Header 两种方式
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Cookie, Header, Response, Request
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .security import get_token_manager
|
||||
|
||||
logger = get_logger("webui.auth")
|
||||
|
||||
# Cookie 配置
|
||||
COOKIE_NAME = "maibot_session"
|
||||
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
|
||||
|
||||
|
||||
def _is_secure_environment() -> bool:
|
||||
"""
|
||||
检测是否应该启用安全 Cookie(HTTPS)
|
||||
|
||||
Returns:
|
||||
bool: 如果应该使用 secure cookie 则返回 True
|
||||
"""
|
||||
# 从配置读取
|
||||
if global_config.webui.secure_cookie:
|
||||
logger.info("配置中启用了 secure_cookie")
|
||||
return True
|
||||
|
||||
# 检查是否是生产环境
|
||||
if global_config.webui.mode == "production":
|
||||
logger.info("WebUI运行在生产模式,启用 secure cookie")
|
||||
return True
|
||||
|
||||
# 默认:开发环境不启用(因为通常是 HTTP)
|
||||
logger.debug("WebUI运行在开发模式,禁用 secure cookie")
|
||||
return False
|
||||
|
||||
|
||||
def get_current_token(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> str:
|
||||
"""
|
||||
获取当前请求的 token,优先从 Cookie 获取,其次从 Header 获取
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization Header (Bearer token)
|
||||
|
||||
Returns:
|
||||
验证通过的 token
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def set_auth_cookie(response: Response, token: str, request: Optional[Request] = None) -> None:
|
||||
"""
|
||||
设置认证 Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
token: 要设置的 token
|
||||
request: FastAPI Request 对象(可选,用于检测协议)
|
||||
"""
|
||||
# 根据环境和实际请求协议决定安全设置
|
||||
is_secure = _is_secure_environment()
|
||||
|
||||
# 如果提供了 request,检测实际使用的协议
|
||||
if request:
|
||||
# 检查 X-Forwarded-Proto header(代理/负载均衡器)
|
||||
forwarded_proto = request.headers.get("x-forwarded-proto", "").lower()
|
||||
if forwarded_proto:
|
||||
is_https = forwarded_proto == "https"
|
||||
logger.debug(f"检测到 X-Forwarded-Proto: {forwarded_proto}, is_https={is_https}")
|
||||
else:
|
||||
# 检查 request.url.scheme
|
||||
is_https = request.url.scheme == "https"
|
||||
logger.debug(f"检测到 scheme: {request.url.scheme}, is_https={is_https}")
|
||||
|
||||
# 如果是 HTTP 连接,强制禁用 secure 标志
|
||||
if not is_https and is_secure:
|
||||
logger.warning("=" * 80)
|
||||
logger.warning("检测到 HTTP 连接但环境配置要求 HTTPS (secure cookie)")
|
||||
logger.warning("已自动禁用 secure 标志以允许登录,但建议修改配置:")
|
||||
logger.warning("1. 在配置文件中设置: webui.secure_cookie = false")
|
||||
logger.warning("2. 如果使用反向代理,请确保正确配置 X-Forwarded-Proto 头")
|
||||
logger.warning("=" * 80)
|
||||
is_secure = False
|
||||
|
||||
# 设置 Cookie
|
||||
response.set_cookie(
|
||||
key=COOKIE_NAME,
|
||||
value=token,
|
||||
max_age=COOKIE_MAX_AGE,
|
||||
httponly=True, # 防止 JS 读取,阻止 XSS 窃取
|
||||
samesite="lax", # 使用 lax 以兼容更多场景(开发和生产)
|
||||
secure=is_secure, # 根据实际协议决定
|
||||
path="/", # 确保 Cookie 在所有路径下可用
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"已设置认证 Cookie: {token[:8]}... (secure={is_secure}, samesite=lax, httponly=True, path=/, max_age={COOKIE_MAX_AGE})"
|
||||
)
|
||||
logger.debug(f"完整 token 前缀: {token[:20]}...")
|
||||
|
||||
|
||||
def clear_auth_cookie(response: Response) -> None:
|
||||
"""
|
||||
清除认证 Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
"""
|
||||
# 保持与 set_auth_cookie 相同的安全设置
|
||||
is_secure = _is_secure_environment()
|
||||
|
||||
response.delete_cookie(
|
||||
key=COOKIE_NAME,
|
||||
httponly=True,
|
||||
samesite="strict" if is_secure else "lax",
|
||||
secure=is_secure,
|
||||
path="/",
|
||||
)
|
||||
logger.debug("已清除认证 Cookie")
|
||||
|
||||
|
||||
def verify_auth_token_from_cookie_or_header(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
验证认证 Token,支持从 Cookie 或 Header 获取
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
验证成功返回 True
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return True
|
||||
245
src/webui/core/rate_limiter.py
Normal file
245
src/webui/core/rate_limiter.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
WebUI 请求频率限制模块
|
||||
防止暴力破解和 API 滥用
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Tuple, Optional
|
||||
from fastapi import Request, HTTPException
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.rate_limiter")
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
简单的内存请求频率限制器
|
||||
|
||||
使用滑动窗口算法实现
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 存储格式: {key: [(timestamp, count), ...]}
|
||||
self._requests: Dict[str, list] = defaultdict(list)
|
||||
# 被封禁的 IP: {ip: unblock_timestamp}
|
||||
self._blocked: Dict[str, float] = {}
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""获取客户端 IP 地址"""
|
||||
# 检查代理头
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
# 取第一个 IP(最原始的客户端)
|
||||
return forwarded.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# 直接连接的客户端
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _cleanup_old_requests(self, key: str, window_seconds: int):
|
||||
"""清理过期的请求记录"""
|
||||
now = time.time()
|
||||
cutoff = now - window_seconds
|
||||
self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]
|
||||
|
||||
def _cleanup_expired_blocks(self):
|
||||
"""清理过期的封禁"""
|
||||
now = time.time()
|
||||
expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time]
|
||||
for ip in expired:
|
||||
del self._blocked[ip]
|
||||
logger.info(f"🔓 IP {ip} 封禁已解除")
|
||||
|
||||
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
|
||||
"""
|
||||
检查 IP 是否被封禁
|
||||
|
||||
Returns:
|
||||
(是否被封禁, 剩余封禁秒数)
|
||||
"""
|
||||
self._cleanup_expired_blocks()
|
||||
ip = self._get_client_ip(request)
|
||||
|
||||
if ip in self._blocked:
|
||||
remaining = int(self._blocked[ip] - time.time())
|
||||
return True, max(0, remaining)
|
||||
|
||||
return False, None
|
||||
|
||||
def check_rate_limit(
|
||||
self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = ""
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
检查请求是否超过频率限制
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
max_requests: 窗口期内允许的最大请求数
|
||||
window_seconds: 窗口时间(秒)
|
||||
key_suffix: 键后缀,用于区分不同的限制规则
|
||||
|
||||
Returns:
|
||||
(是否允许, 剩余请求数)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:{key_suffix}" if key_suffix else ip
|
||||
|
||||
# 清理过期记录
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
|
||||
# 计算当前窗口内的请求数
|
||||
current_count = sum(count for _, count in self._requests[key])
|
||||
|
||||
if current_count >= max_requests:
|
||||
return False, 0
|
||||
|
||||
# 记录新请求
|
||||
now = time.time()
|
||||
self._requests[key].append((now, 1))
|
||||
|
||||
remaining = max_requests - current_count - 1
|
||||
return True, remaining
|
||||
|
||||
def block_ip(self, request: Request, duration_seconds: int):
|
||||
"""
|
||||
封禁 IP
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
duration_seconds: 封禁时长(秒)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
self._blocked[ip] = time.time() + duration_seconds
|
||||
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds} 秒")
|
||||
|
||||
def record_failed_attempt(
|
||||
self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
记录失败尝试(如登录失败)
|
||||
|
||||
如果在窗口期内失败次数过多,自动封禁 IP
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
max_failures: 允许的最大失败次数
|
||||
window_seconds: 统计窗口(秒)
|
||||
block_duration: 封禁时长(秒)
|
||||
|
||||
Returns:
|
||||
(是否被封禁, 剩余尝试次数)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:auth_failures"
|
||||
|
||||
# 清理过期记录
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
|
||||
# 计算当前失败次数
|
||||
current_failures = sum(count for _, count in self._requests[key])
|
||||
|
||||
# 记录本次失败
|
||||
now = time.time()
|
||||
self._requests[key].append((now, 1))
|
||||
current_failures += 1
|
||||
|
||||
remaining = max_failures - current_failures
|
||||
|
||||
# 检查是否需要封禁
|
||||
if current_failures >= max_failures:
|
||||
self.block_ip(request, block_duration)
|
||||
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
|
||||
return True, 0
|
||||
|
||||
if current_failures >= max_failures - 2:
|
||||
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures} 次")
|
||||
|
||||
return False, max(0, remaining)
|
||||
|
||||
def reset_failures(self, request: Request):
|
||||
"""
|
||||
重置失败计数(认证成功后调用)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:auth_failures"
|
||||
if key in self._requests:
|
||||
del self._requests[key]
|
||||
|
||||
|
||||
# 全局单例
|
||||
_rate_limiter: Optional[RateLimiter] = None
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""获取 RateLimiter 单例"""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter()
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
async def check_auth_rate_limit(request: Request):
|
||||
"""
|
||||
认证接口的频率限制依赖
|
||||
|
||||
规则:
|
||||
- 每个 IP 每分钟最多 10 次认证请求
|
||||
- 连续失败 5 次后封禁 10 分钟
|
||||
"""
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
# 检查是否被封禁
|
||||
blocked, remaining_block = limiter.is_blocked(request)
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||
headers={"Retry-After": str(remaining_block)},
|
||||
)
|
||||
|
||||
# 检查频率限制
|
||||
allowed, remaining = limiter.check_rate_limit(
|
||||
request,
|
||||
max_requests=10, # 每分钟 10 次
|
||||
window_seconds=60,
|
||||
key_suffix="auth",
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
|
||||
|
||||
|
||||
async def check_api_rate_limit(request: Request):
|
||||
"""
|
||||
普通 API 的频率限制依赖
|
||||
|
||||
规则:每个 IP 每分钟最多 100 次请求
|
||||
"""
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
# 检查是否被封禁
|
||||
blocked, remaining_block = limiter.is_blocked(request)
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||
headers={"Retry-After": str(remaining_block)},
|
||||
)
|
||||
|
||||
# 检查频率限制
|
||||
allowed, _ = limiter.check_rate_limit(
|
||||
request,
|
||||
max_requests=100, # 每分钟 100 次
|
||||
window_seconds=60,
|
||||
key_suffix="api",
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
|
||||
309
src/webui/core/security.py
Normal file
309
src/webui/core/security.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
WebUI Token 管理模块
|
||||
负责生成、保存、验证和更新访问令牌
|
||||
"""
|
||||
|
||||
import json
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""Token 管理器"""
|
||||
|
||||
def __init__(self, config_path: Optional[Path] = None):
|
||||
"""
|
||||
初始化 Token 管理器
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径,默认为项目根目录的 data/webui.json
|
||||
"""
|
||||
if config_path is None:
|
||||
# 获取项目根目录 (src/webui/core -> src/webui -> src -> 根目录)
|
||||
project_root = Path(__file__).parent.parent.parent.parent
|
||||
config_path = project_root / "data" / "webui.json"
|
||||
|
||||
self.config_path = config_path
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 确保配置文件存在并包含有效的 token
|
||||
self._ensure_config()
|
||||
|
||||
def _ensure_config(self):
|
||||
"""确保配置文件存在且包含有效的 token"""
|
||||
if not self.config_path.exists():
|
||||
logger.info(f"WebUI 配置文件不存在,正在创建: {self.config_path}")
|
||||
self._create_new_token()
|
||||
else:
|
||||
# 验证配置文件格式
|
||||
try:
|
||||
config = self._load_config()
|
||||
if not config.get("access_token"):
|
||||
logger.warning("WebUI 配置文件中缺少 access_token,正在重新生成")
|
||||
self._create_new_token()
|
||||
else:
|
||||
logger.info(f"WebUI Token 已加载: {config['access_token'][:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"读取 WebUI 配置文件失败: {e},正在重新创建")
|
||||
self._create_new_token()
|
||||
|
||||
def _load_config(self) -> dict:
|
||||
"""加载配置文件"""
|
||||
try:
|
||||
with open(self.config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载 WebUI 配置失败: {e}")
|
||||
return {}
|
||||
|
||||
def _save_config(self, config: dict):
|
||||
"""保存配置文件"""
|
||||
try:
|
||||
with open(self.config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"WebUI 配置已保存到: {self.config_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存 WebUI 配置失败: {e}")
|
||||
raise
|
||||
|
||||
def _create_new_token(self) -> str:
|
||||
"""生成新的 64 位随机 token"""
|
||||
# 生成 64 位十六进制字符串 (32 字节 = 64 hex 字符)
|
||||
token = secrets.token_hex(32)
|
||||
|
||||
config = {
|
||||
"access_token": token,
|
||||
"created_at": self._get_current_timestamp(),
|
||||
"updated_at": self._get_current_timestamp(),
|
||||
"first_setup_completed": False, # 标记首次配置未完成
|
||||
}
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"新的 WebUI Token 已生成: {token[:8]}...")
|
||||
|
||||
return token
|
||||
|
||||
def _get_current_timestamp(self) -> str:
|
||||
"""获取当前时间戳字符串"""
|
||||
from datetime import datetime
|
||||
|
||||
return datetime.now().isoformat()
|
||||
|
||||
def get_token(self) -> str:
|
||||
"""获取当前有效的 token"""
|
||||
config = self._load_config()
|
||||
return config.get("access_token", "")
|
||||
|
||||
def verify_token(self, token: str) -> bool:
|
||||
"""
|
||||
验证 token 是否有效
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
Returns:
|
||||
bool: token 是否有效
|
||||
"""
|
||||
if not token:
|
||||
return False
|
||||
|
||||
current_token = self.get_token()
|
||||
if not current_token:
|
||||
logger.error("系统中没有有效的 token")
|
||||
return False
|
||||
|
||||
# 使用 secrets.compare_digest 防止时序攻击
|
||||
is_valid = secrets.compare_digest(token, current_token)
|
||||
|
||||
if is_valid:
|
||||
logger.debug("Token 验证成功")
|
||||
else:
|
||||
logger.warning("Token 验证失败")
|
||||
|
||||
return is_valid
|
||||
|
||||
def update_token(self, new_token: str) -> tuple[bool, str]:
|
||||
"""
|
||||
更新 token
|
||||
|
||||
Args:
|
||||
new_token: 新的 token (最少 10 位,必须包含大小写字母和特殊符号)
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: (是否更新成功, 错误消息)
|
||||
"""
|
||||
# 验证新 token 格式
|
||||
is_valid, error_msg = self._validate_custom_token(new_token)
|
||||
if not is_valid:
|
||||
logger.error(f"Token 格式无效: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
try:
|
||||
config = self._load_config()
|
||||
old_token = config.get("access_token", "")[:8]
|
||||
|
||||
config["access_token"] = new_token
|
||||
config["updated_at"] = self._get_current_timestamp()
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"Token 已更新: {old_token}... -> {new_token[:8]}...")
|
||||
|
||||
return True, "Token 更新成功"
|
||||
except Exception as e:
|
||||
logger.error(f"更新 Token 失败: {e}")
|
||||
return False, f"更新失败: {str(e)}"
|
||||
|
||||
def regenerate_token(self) -> str:
|
||||
"""
|
||||
重新生成 token(保留 first_setup_completed 状态)
|
||||
|
||||
Returns:
|
||||
str: 新生成的 token
|
||||
"""
|
||||
logger.info("正在重新生成 WebUI Token...")
|
||||
|
||||
# 生成新的 64 位十六进制字符串
|
||||
new_token = secrets.token_hex(32)
|
||||
|
||||
# 加载现有配置,保留 first_setup_completed 状态
|
||||
config = self._load_config()
|
||||
old_token = config.get("access_token", "")[:8] if config.get("access_token") else "无"
|
||||
first_setup_completed = config.get("first_setup_completed", True) # 默认为 True,表示已完成配置
|
||||
|
||||
config["access_token"] = new_token
|
||||
config["updated_at"] = self._get_current_timestamp()
|
||||
config["first_setup_completed"] = first_setup_completed # 保留原来的状态
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"WebUI Token 已重新生成: {old_token}... -> {new_token[:8]}...")
|
||||
|
||||
return new_token
|
||||
|
||||
def _validate_token_format(self, token: str) -> bool:
|
||||
"""
|
||||
验证 token 格式是否正确(旧的 64 位十六进制验证,保留用于系统生成的 token)
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
Returns:
|
||||
bool: 格式是否正确
|
||||
"""
|
||||
if not token or not isinstance(token, str):
|
||||
return False
|
||||
|
||||
# 必须是 64 位十六进制字符串
|
||||
if len(token) != 64:
|
||||
return False
|
||||
|
||||
# 验证是否为有效的十六进制字符串
|
||||
try:
|
||||
int(token, 16)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _validate_custom_token(self, token: str) -> tuple[bool, str]:
|
||||
"""
|
||||
验证自定义 token 格式
|
||||
|
||||
要求:
|
||||
- 最少 10 位
|
||||
- 包含大写字母
|
||||
- 包含小写字母
|
||||
- 包含特殊符号
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: (是否有效, 错误消息)
|
||||
"""
|
||||
if not token or not isinstance(token, str):
|
||||
return False, "Token 不能为空"
|
||||
|
||||
# 检查长度
|
||||
if len(token) < 10:
|
||||
return False, "Token 长度至少为 10 位"
|
||||
|
||||
# 检查是否包含大写字母
|
||||
has_upper = any(c.isupper() for c in token)
|
||||
if not has_upper:
|
||||
return False, "Token 必须包含大写字母"
|
||||
|
||||
# 检查是否包含小写字母
|
||||
has_lower = any(c.islower() for c in token)
|
||||
if not has_lower:
|
||||
return False, "Token 必须包含小写字母"
|
||||
|
||||
# 检查是否包含特殊符号
|
||||
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?/"
|
||||
has_special = any(c in special_chars for c in token)
|
||||
if not has_special:
|
||||
return False, f"Token 必须包含特殊符号 ({special_chars})"
|
||||
|
||||
return True, "Token 格式正确"
|
||||
|
||||
def is_first_setup(self) -> bool:
|
||||
"""
|
||||
检查是否为首次配置
|
||||
|
||||
Returns:
|
||||
bool: 是否为首次配置
|
||||
"""
|
||||
config = self._load_config()
|
||||
return not config.get("first_setup_completed", False)
|
||||
|
||||
def mark_setup_completed(self) -> bool:
|
||||
"""
|
||||
标记首次配置已完成
|
||||
|
||||
Returns:
|
||||
bool: 是否标记成功
|
||||
"""
|
||||
try:
|
||||
config = self._load_config()
|
||||
config["first_setup_completed"] = True
|
||||
config["setup_completed_at"] = self._get_current_timestamp()
|
||||
self._save_config(config)
|
||||
logger.info("首次配置已标记为完成")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"标记首次配置完成失败: {e}")
|
||||
return False
|
||||
|
||||
def reset_setup_status(self) -> bool:
|
||||
"""
|
||||
重置首次配置状态,允许重新进入配置向导
|
||||
|
||||
Returns:
|
||||
bool: 是否重置成功
|
||||
"""
|
||||
try:
|
||||
config = self._load_config()
|
||||
config["first_setup_completed"] = False
|
||||
if "setup_completed_at" in config:
|
||||
del config["setup_completed_at"]
|
||||
self._save_config(config)
|
||||
logger.info("首次配置状态已重置")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"重置首次配置状态失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# 全局单例
|
||||
_token_manager_instance: Optional[TokenManager] = None
|
||||
|
||||
|
||||
def get_token_manager() -> TokenManager:
|
||||
"""获取 TokenManager 单例"""
|
||||
global _token_manager_instance
|
||||
if _token_manager_instance is None:
|
||||
_token_manager_instance = TokenManager()
|
||||
return _token_manager_instance
|
||||
Reference in New Issue
Block a user