Ruff Format

This commit is contained in:
墨梓柒
2025-12-14 21:39:09 +08:00
parent ee4cb3dc67
commit 74a2f4346a
11 changed files with 176 additions and 197 deletions

View File

@@ -15,16 +15,16 @@ logger = get_logger("webui.rate_limiter")
class RateLimiter:
"""
简单的内存请求频率限制器
使用滑动窗口算法实现
"""
def __init__(self):
# 存储格式: {key: [(timestamp, count), ...]}
self._requests: Dict[str, list] = defaultdict(list)
# 被封禁的 IP: {ip: unblock_timestamp}
self._blocked: Dict[str, float] = {}
def _get_client_ip(self, request: Request) -> str:
"""获取客户端 IP 地址"""
# 检查代理头
@@ -32,26 +32,23 @@ class RateLimiter:
if forwarded:
# 取第一个 IP最原始的客户端
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# 直接连接的客户端
if request.client:
return request.client.host
return "unknown"
def _cleanup_old_requests(self, key: str, window_seconds: int):
"""清理过期的请求记录"""
now = time.time()
cutoff = now - window_seconds
self._requests[key] = [
(ts, count) for ts, count in self._requests[key]
if ts > cutoff
]
self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]
def _cleanup_expired_blocks(self):
"""清理过期的封禁"""
now = time.time()
@@ -59,65 +56,61 @@ class RateLimiter:
for ip in expired:
del self._blocked[ip]
logger.info(f"🔓 IP {ip} 封禁已解除")
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
"""
检查 IP 是否被封禁
Returns:
(是否被封禁, 剩余封禁秒数)
"""
self._cleanup_expired_blocks()
ip = self._get_client_ip(request)
if ip in self._blocked:
remaining = int(self._blocked[ip] - time.time())
return True, max(0, remaining)
return False, None
def check_rate_limit(
self,
request: Request,
max_requests: int,
window_seconds: int,
key_suffix: str = ""
self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = ""
) -> Tuple[bool, int]:
"""
检查请求是否超过频率限制
Args:
request: FastAPI Request 对象
max_requests: 窗口期内允许的最大请求数
window_seconds: 窗口时间(秒)
key_suffix: 键后缀,用于区分不同的限制规则
Returns:
(是否允许, 剩余请求数)
"""
ip = self._get_client_ip(request)
key = f"{ip}:{key_suffix}" if key_suffix else ip
# 清理过期记录
self._cleanup_old_requests(key, window_seconds)
# 计算当前窗口内的请求数
current_count = sum(count for _, count in self._requests[key])
if current_count >= max_requests:
return False, 0
# 记录新请求
now = time.time()
self._requests[key].append((now, 1))
remaining = max_requests - current_count - 1
return True, remaining
def block_ip(self, request: Request, duration_seconds: int):
"""
封禁 IP
Args:
request: FastAPI Request 对象
duration_seconds: 封禁时长(秒)
@@ -125,55 +118,51 @@ class RateLimiter:
ip = self._get_client_ip(request)
self._blocked[ip] = time.time() + duration_seconds
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds}")
def record_failed_attempt(
self,
request: Request,
max_failures: int = 5,
window_seconds: int = 300,
block_duration: int = 600
self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
) -> Tuple[bool, int]:
"""
记录失败尝试(如登录失败)
如果在窗口期内失败次数过多,自动封禁 IP
Args:
request: FastAPI Request 对象
max_failures: 允许的最大失败次数
window_seconds: 统计窗口(秒)
block_duration: 封禁时长(秒)
Returns:
(是否被封禁, 剩余尝试次数)
"""
ip = self._get_client_ip(request)
key = f"{ip}:auth_failures"
# 清理过期记录
self._cleanup_old_requests(key, window_seconds)
# 计算当前失败次数
current_failures = sum(count for _, count in self._requests[key])
# 记录本次失败
now = time.time()
self._requests[key].append((now, 1))
current_failures += 1
remaining = max_failures - current_failures
# 检查是否需要封禁
if current_failures >= max_failures:
self.block_ip(request, block_duration)
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
return True, 0
if current_failures >= max_failures - 2:
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures}")
return False, max(0, remaining)
def reset_failures(self, request: Request):
"""
重置失败计数(认证成功后调用)
@@ -199,66 +188,58 @@ def get_rate_limiter() -> RateLimiter:
async def check_auth_rate_limit(request: Request):
"""
认证接口的频率限制依赖
规则:
- 每个 IP 每分钟最多 10 次认证请求
- 连续失败 5 次后封禁 10 分钟
"""
limiter = get_rate_limiter()
# 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request)
if blocked:
raise HTTPException(
status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)}
headers={"Retry-After": str(remaining_block)},
)
# 检查频率限制
allowed, remaining = limiter.check_rate_limit(
request,
request,
max_requests=10, # 每分钟 10 次
window_seconds=60,
key_suffix="auth"
key_suffix="auth",
)
if not allowed:
raise HTTPException(
status_code=429,
detail="认证请求过于频繁,请稍后重试",
headers={"Retry-After": "60"}
)
raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
async def check_api_rate_limit(request: Request):
"""
普通 API 的频率限制依赖
规则:每个 IP 每分钟最多 100 次请求
"""
limiter = get_rate_limiter()
# 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request)
if blocked:
raise HTTPException(
status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)}
headers={"Retry-After": str(remaining_block)},
)
# 检查频率限制
allowed, _ = limiter.check_rate_limit(
request,
request,
max_requests=100, # 每分钟 100 次
window_seconds=60,
key_suffix="api"
key_suffix="api",
)
if not allowed:
raise HTTPException(
status_code=429,
detail="请求过于频繁,请稍后重试",
headers={"Retry-After": "60"}
)
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"})