""" WebUI 请求频率限制模块 防止暴力破解和 API 滥用 """ import time from collections import defaultdict from typing import Dict, Tuple, Optional from fastapi import Request, HTTPException from src.common.logger import get_logger logger = get_logger("webui.rate_limiter") class RateLimiter: """ 简单的内存请求频率限制器 使用滑动窗口算法实现 """ def __init__(self): # 存储格式: {key: [(timestamp, count), ...]} self._requests: Dict[str, list] = defaultdict(list) # 被封禁的 IP: {ip: unblock_timestamp} self._blocked: Dict[str, float] = {} def _get_client_ip(self, request: Request) -> str: """获取客户端 IP 地址""" # 检查代理头 forwarded = request.headers.get("X-Forwarded-For") if forwarded: # 取第一个 IP(最原始的客户端) return forwarded.split(",")[0].strip() real_ip = request.headers.get("X-Real-IP") if real_ip: return real_ip # 直接连接的客户端 if request.client: return request.client.host return "unknown" def _cleanup_old_requests(self, key: str, window_seconds: int): """清理过期的请求记录""" now = time.time() cutoff = now - window_seconds self._requests[key] = [ (ts, count) for ts, count in self._requests[key] if ts > cutoff ] def _cleanup_expired_blocks(self): """清理过期的封禁""" now = time.time() expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time] for ip in expired: del self._blocked[ip] logger.info(f"🔓 IP {ip} 封禁已解除") def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]: """ 检查 IP 是否被封禁 Returns: (是否被封禁, 剩余封禁秒数) """ self._cleanup_expired_blocks() ip = self._get_client_ip(request) if ip in self._blocked: remaining = int(self._blocked[ip] - time.time()) return True, max(0, remaining) return False, None def check_rate_limit( self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = "" ) -> Tuple[bool, int]: """ 检查请求是否超过频率限制 Args: request: FastAPI Request 对象 max_requests: 窗口期内允许的最大请求数 window_seconds: 窗口时间(秒) key_suffix: 键后缀,用于区分不同的限制规则 Returns: (是否允许, 剩余请求数) """ ip = self._get_client_ip(request) key = f"{ip}:{key_suffix}" if key_suffix else ip # 清理过期记录 self._cleanup_old_requests(key, window_seconds) # 计算当前窗口内的请求数 current_count = sum(count for _, count in self._requests[key]) if current_count >= max_requests: return False, 0 # 记录新请求 now = time.time() self._requests[key].append((now, 1)) remaining = max_requests - current_count - 1 return True, remaining def block_ip(self, request: Request, duration_seconds: int): """ 封禁 IP Args: request: FastAPI Request 对象 duration_seconds: 封禁时长(秒) """ ip = self._get_client_ip(request) self._blocked[ip] = time.time() + duration_seconds logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds} 秒") def record_failed_attempt( self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600 ) -> Tuple[bool, int]: """ 记录失败尝试(如登录失败) 如果在窗口期内失败次数过多,自动封禁 IP Args: request: FastAPI Request 对象 max_failures: 允许的最大失败次数 window_seconds: 统计窗口(秒) block_duration: 封禁时长(秒) Returns: (是否被封禁, 剩余尝试次数) """ ip = self._get_client_ip(request) key = f"{ip}:auth_failures" # 清理过期记录 self._cleanup_old_requests(key, window_seconds) # 计算当前失败次数 current_failures = sum(count for _, count in self._requests[key]) # 记录本次失败 now = time.time() self._requests[key].append((now, 1)) current_failures += 1 remaining = max_failures - current_failures # 检查是否需要封禁 if current_failures >= max_failures: self.block_ip(request, block_duration) logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁") return True, 0 if current_failures >= max_failures - 2: logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures} 次") return False, max(0, remaining) def reset_failures(self, request: Request): """ 重置失败计数(认证成功后调用) """ ip = self._get_client_ip(request) key = f"{ip}:auth_failures" if key in self._requests: del self._requests[key] # 全局单例 _rate_limiter: Optional[RateLimiter] = None def get_rate_limiter() -> RateLimiter: """获取 RateLimiter 单例""" global _rate_limiter if _rate_limiter is None: _rate_limiter = RateLimiter() return _rate_limiter async def check_auth_rate_limit(request: Request): """ 认证接口的频率限制依赖 规则: - 每个 IP 每分钟最多 10 次认证请求 - 连续失败 5 次后封禁 10 分钟 """ limiter = get_rate_limiter() # 检查是否被封禁 blocked, remaining_block = limiter.is_blocked(request) if blocked: raise HTTPException( status_code=429, detail=f"请求过于频繁,请在 {remaining_block} 秒后重试", headers={"Retry-After": str(remaining_block)} ) # 检查频率限制 allowed, remaining = limiter.check_rate_limit( request, max_requests=10, # 每分钟 10 次 window_seconds=60, key_suffix="auth" ) if not allowed: raise HTTPException( status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"} ) async def check_api_rate_limit(request: Request): """ 普通 API 的频率限制依赖 规则:每个 IP 每分钟最多 100 次请求 """ limiter = get_rate_limiter() # 检查是否被封禁 blocked, remaining_block = limiter.is_blocked(request) if blocked: raise HTTPException( status_code=429, detail=f"请求过于频繁,请在 {remaining_block} 秒后重试", headers={"Retry-After": str(remaining_block)} ) # 检查频率限制 allowed, _ = limiter.check_rate_limit( request, max_requests=100, # 每分钟 100 次 window_seconds=60, key_suffix="api" ) if not allowed: raise HTTPException( status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"} )