Ruff Format
This commit is contained in:
@@ -15,16 +15,16 @@ logger = get_logger("webui.rate_limiter")
|
||||
class RateLimiter:
|
||||
"""
|
||||
简单的内存请求频率限制器
|
||||
|
||||
|
||||
使用滑动窗口算法实现
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
# 存储格式: {key: [(timestamp, count), ...]}
|
||||
self._requests: Dict[str, list] = defaultdict(list)
|
||||
# 被封禁的 IP: {ip: unblock_timestamp}
|
||||
self._blocked: Dict[str, float] = {}
|
||||
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""获取客户端 IP 地址"""
|
||||
# 检查代理头
|
||||
@@ -32,26 +32,23 @@ class RateLimiter:
|
||||
if forwarded:
|
||||
# 取第一个 IP(最原始的客户端)
|
||||
return forwarded.split(",")[0].strip()
|
||||
|
||||
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
|
||||
# 直接连接的客户端
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _cleanup_old_requests(self, key: str, window_seconds: int):
|
||||
"""清理过期的请求记录"""
|
||||
now = time.time()
|
||||
cutoff = now - window_seconds
|
||||
self._requests[key] = [
|
||||
(ts, count) for ts, count in self._requests[key]
|
||||
if ts > cutoff
|
||||
]
|
||||
|
||||
self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]
|
||||
|
||||
def _cleanup_expired_blocks(self):
|
||||
"""清理过期的封禁"""
|
||||
now = time.time()
|
||||
@@ -59,65 +56,61 @@ class RateLimiter:
|
||||
for ip in expired:
|
||||
del self._blocked[ip]
|
||||
logger.info(f"🔓 IP {ip} 封禁已解除")
|
||||
|
||||
|
||||
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
|
||||
"""
|
||||
检查 IP 是否被封禁
|
||||
|
||||
|
||||
Returns:
|
||||
(是否被封禁, 剩余封禁秒数)
|
||||
"""
|
||||
self._cleanup_expired_blocks()
|
||||
ip = self._get_client_ip(request)
|
||||
|
||||
|
||||
if ip in self._blocked:
|
||||
remaining = int(self._blocked[ip] - time.time())
|
||||
return True, max(0, remaining)
|
||||
|
||||
|
||||
return False, None
|
||||
|
||||
|
||||
def check_rate_limit(
|
||||
self,
|
||||
request: Request,
|
||||
max_requests: int,
|
||||
window_seconds: int,
|
||||
key_suffix: str = ""
|
||||
self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = ""
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
检查请求是否超过频率限制
|
||||
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
max_requests: 窗口期内允许的最大请求数
|
||||
window_seconds: 窗口时间(秒)
|
||||
key_suffix: 键后缀,用于区分不同的限制规则
|
||||
|
||||
|
||||
Returns:
|
||||
(是否允许, 剩余请求数)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:{key_suffix}" if key_suffix else ip
|
||||
|
||||
|
||||
# 清理过期记录
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
|
||||
|
||||
# 计算当前窗口内的请求数
|
||||
current_count = sum(count for _, count in self._requests[key])
|
||||
|
||||
|
||||
if current_count >= max_requests:
|
||||
return False, 0
|
||||
|
||||
|
||||
# 记录新请求
|
||||
now = time.time()
|
||||
self._requests[key].append((now, 1))
|
||||
|
||||
|
||||
remaining = max_requests - current_count - 1
|
||||
return True, remaining
|
||||
|
||||
|
||||
def block_ip(self, request: Request, duration_seconds: int):
|
||||
"""
|
||||
封禁 IP
|
||||
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
duration_seconds: 封禁时长(秒)
|
||||
@@ -125,55 +118,51 @@ class RateLimiter:
|
||||
ip = self._get_client_ip(request)
|
||||
self._blocked[ip] = time.time() + duration_seconds
|
||||
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds} 秒")
|
||||
|
||||
|
||||
def record_failed_attempt(
|
||||
self,
|
||||
request: Request,
|
||||
max_failures: int = 5,
|
||||
window_seconds: int = 300,
|
||||
block_duration: int = 600
|
||||
self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
记录失败尝试(如登录失败)
|
||||
|
||||
|
||||
如果在窗口期内失败次数过多,自动封禁 IP
|
||||
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
max_failures: 允许的最大失败次数
|
||||
window_seconds: 统计窗口(秒)
|
||||
block_duration: 封禁时长(秒)
|
||||
|
||||
|
||||
Returns:
|
||||
(是否被封禁, 剩余尝试次数)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:auth_failures"
|
||||
|
||||
|
||||
# 清理过期记录
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
|
||||
|
||||
# 计算当前失败次数
|
||||
current_failures = sum(count for _, count in self._requests[key])
|
||||
|
||||
|
||||
# 记录本次失败
|
||||
now = time.time()
|
||||
self._requests[key].append((now, 1))
|
||||
current_failures += 1
|
||||
|
||||
|
||||
remaining = max_failures - current_failures
|
||||
|
||||
|
||||
# 检查是否需要封禁
|
||||
if current_failures >= max_failures:
|
||||
self.block_ip(request, block_duration)
|
||||
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
|
||||
return True, 0
|
||||
|
||||
|
||||
if current_failures >= max_failures - 2:
|
||||
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures} 次")
|
||||
|
||||
|
||||
return False, max(0, remaining)
|
||||
|
||||
|
||||
def reset_failures(self, request: Request):
|
||||
"""
|
||||
重置失败计数(认证成功后调用)
|
||||
@@ -199,66 +188,58 @@ def get_rate_limiter() -> RateLimiter:
|
||||
async def check_auth_rate_limit(request: Request):
|
||||
"""
|
||||
认证接口的频率限制依赖
|
||||
|
||||
|
||||
规则:
|
||||
- 每个 IP 每分钟最多 10 次认证请求
|
||||
- 连续失败 5 次后封禁 10 分钟
|
||||
"""
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
|
||||
# 检查是否被封禁
|
||||
blocked, remaining_block = limiter.is_blocked(request)
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||
headers={"Retry-After": str(remaining_block)}
|
||||
headers={"Retry-After": str(remaining_block)},
|
||||
)
|
||||
|
||||
|
||||
# 检查频率限制
|
||||
allowed, remaining = limiter.check_rate_limit(
|
||||
request,
|
||||
request,
|
||||
max_requests=10, # 每分钟 10 次
|
||||
window_seconds=60,
|
||||
key_suffix="auth"
|
||||
key_suffix="auth",
|
||||
)
|
||||
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="认证请求过于频繁,请稍后重试",
|
||||
headers={"Retry-After": "60"}
|
||||
)
|
||||
raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
|
||||
|
||||
|
||||
async def check_api_rate_limit(request: Request):
|
||||
"""
|
||||
普通 API 的频率限制依赖
|
||||
|
||||
|
||||
规则:每个 IP 每分钟最多 100 次请求
|
||||
"""
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
|
||||
# 检查是否被封禁
|
||||
blocked, remaining_block = limiter.is_blocked(request)
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||
headers={"Retry-After": str(remaining_block)}
|
||||
headers={"Retry-After": str(remaining_block)},
|
||||
)
|
||||
|
||||
|
||||
# 检查频率限制
|
||||
allowed, _ = limiter.check_rate_limit(
|
||||
request,
|
||||
request,
|
||||
max_requests=100, # 每分钟 100 次
|
||||
window_seconds=60,
|
||||
key_suffix="api"
|
||||
key_suffix="api",
|
||||
)
|
||||
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="请求过于频繁,请稍后重试",
|
||||
headers={"Retry-After": "60"}
|
||||
)
|
||||
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
|
||||
|
||||
Reference in New Issue
Block a user