Version: 0.2.1.dev.260210
feat: 🚦 新增基于 Redis 令牌桶的限流中间件 - 使用 Redis 实现令牌桶算法进行限流 🪣 - 覆盖除登录、注册、刷新 token 以外的所有接口 🔒 fix: 🐛 修复任务块添加到日程接口可修改已安排任务时间的问题 - 禁止通过该接口直接修改已安排任务块的时间 - 修正不合理的业务逻辑,保证数据一致性 ✅
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/LoveLosita/smartflow/backend/api"
|
||||
"github.com/LoveLosita/smartflow/backend/dao"
|
||||
"github.com/LoveLosita/smartflow/backend/inits"
|
||||
"github.com/LoveLosita/smartflow/backend/pkg"
|
||||
"github.com/LoveLosita/smartflow/backend/routers"
|
||||
"github.com/LoveLosita/smartflow/backend/service"
|
||||
"github.com/spf13/viper"
|
||||
@@ -39,6 +40,8 @@ func Start() {
|
||||
log.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
rdb := inits.InitRedis()
|
||||
//工具包
|
||||
limiter := pkg.NewRateLimiter(rdb)
|
||||
//dao 层
|
||||
userRepo := dao.NewUserDAO(db)
|
||||
cacheRepo := dao.NewCacheDAO(rdb)
|
||||
@@ -59,7 +62,6 @@ func Start() {
|
||||
courseApi := api.NewCourseHandler(courseService)
|
||||
taskClassApi := api.NewTaskClassHandler(taskClassService)
|
||||
scheduleApi := api.NewScheduleAPI(scheduleService)
|
||||
|
||||
handlers := &api.ApiHandlers{
|
||||
UserHandler: userApi,
|
||||
TaskHandler: taskApi,
|
||||
@@ -67,6 +69,6 @@ func Start() {
|
||||
CourseHandler: courseApi,
|
||||
ScheduleHandler: scheduleApi,
|
||||
}
|
||||
r := routers.RegisterRouters(handlers, cacheRepo)
|
||||
r := routers.RegisterRouters(handlers, cacheRepo, limiter)
|
||||
routers.StartEngine(r)
|
||||
}
|
||||
|
||||
@@ -195,3 +195,15 @@ func (dao *TaskClassDAO) DeleteTaskClassItemEmbeddedTime(ctx context.Context, ta
|
||||
Update("embedded_time", nil).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (dao *TaskClassDAO) IfTaskClassItemArranged(ctx context.Context, taskID int) (bool, error) {
|
||||
var item model.TaskClassItem
|
||||
err := dao.db.WithContext(ctx).
|
||||
Select("embedded_time").
|
||||
Where("id = ?", taskID).
|
||||
First(&item).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return item.EmbeddedTime != nil, nil
|
||||
}
|
||||
|
||||
38
backend/middleware/rate_limiter.go
Normal file
38
backend/middleware/rate_limiter.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/pkg"
|
||||
"github.com/LoveLosita/smartflow/backend/respond"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RateLimitMiddleware(limiter *pkg.RateLimiter, capacity, rate int) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 1. 确定限流对象:可以用 UserID,也可以用 IP
|
||||
// 这里建议用 UserID,防止某个用户换 IP 疯狂刷
|
||||
userID := c.GetInt("user_id") // 假设你之前的 JWT 已经塞进去了
|
||||
key := fmt.Sprintf("rate_limit:user:%d", userID)
|
||||
|
||||
// 2. 执行限流检查
|
||||
allowed, err := limiter.Allow(c.Request.Context(), key, capacity, rate)
|
||||
if err != nil {
|
||||
// 如果 Redis 挂了,为了保证业务可用,通常选择“放行”并记录日志
|
||||
log.Printf("Redis limiter error: %v", err)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
// 3. 触发限流:直接调你写好的 DealWithError
|
||||
// 可以在 respond 里定义一个新错误:TooManyRequests
|
||||
respond.DealWithError(c, respond.TooManyRequests)
|
||||
c.Abort() // 拦截,不执行后续 Handler
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
64
backend/pkg/rate_limiter.go
Normal file
64
backend/pkg/rate_limiter.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package pkg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
var tokenBucketScript = redis.NewScript(`-- KEYS[1]: 限流标识 (如 rate_limit:user_123)
|
||||
-- ARGV[1]: 令牌桶最大容量 (Capacity)
|
||||
-- ARGV[2]: 令牌填充速率 (Tokens per second)
|
||||
-- ARGV[3]: 当前时间戳 (Current Unix timestamp in seconds)
|
||||
-- ARGV[4]: 请求需要的令牌数 (通常为 1)
|
||||
|
||||
local bucket_info = redis.call("HMGET", KEYS[1], "last_tokens", "last_refreshed")
|
||||
local last_tokens = tonumber(bucket_info[1])
|
||||
local last_refreshed = tonumber(bucket_info[2])
|
||||
|
||||
local capacity = tonumber(ARGV[1])
|
||||
local rate = tonumber(ARGV[2])
|
||||
local now = tonumber(ARGV[3])
|
||||
local requested = tonumber(ARGV[4])
|
||||
|
||||
-- 如果是首次访问,初始化桶
|
||||
if last_tokens == nil then
|
||||
last_tokens = capacity
|
||||
last_refreshed = now
|
||||
end
|
||||
|
||||
-- 💡 核心逻辑:计算这段时间新产生的令牌
|
||||
local delta = math.max(0, now - last_refreshed)
|
||||
local new_tokens = math.min(capacity, last_tokens + (delta * rate))
|
||||
|
||||
local allowed = false
|
||||
if new_tokens >= requested then
|
||||
new_tokens = new_tokens - requested
|
||||
allowed = true
|
||||
end
|
||||
|
||||
-- 更新 Redis 状态
|
||||
redis.call("HMSET", KEYS[1], "last_tokens", new_tokens, "last_refreshed", now)
|
||||
-- 设置过期时间(比如 1 小时没人访问就删掉,省内存)
|
||||
redis.call("EXPIRE", KEYS[1], 3600)
|
||||
|
||||
return allowed and 1 or 0`)
|
||||
|
||||
type RateLimiter struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
func NewRateLimiter(client *redis.Client) *RateLimiter {
|
||||
return &RateLimiter{client: client}
|
||||
}
|
||||
|
||||
func (r *RateLimiter) Allow(ctx context.Context, key string, capacity, rate int) (bool, error) {
|
||||
// 传参:Key, 容量, 速率, 当前时间, 请求数
|
||||
res, err := tokenBucketScript.Run(ctx, r.client, []string{key},
|
||||
capacity, rate, time.Now().Unix(), 1).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return res == 1, nil
|
||||
}
|
||||
@@ -221,4 +221,13 @@ var ( //请求相关的响应
|
||||
Status: "40032",
|
||||
Info: "target schedule does not have embedded task",
|
||||
}
|
||||
TooManyRequests = Response{ //请求过多
|
||||
Status: "40033",
|
||||
Info: "too many requests",
|
||||
}
|
||||
|
||||
TaskClassItemAlreadyArranged = Response{ //任务类项目已安排
|
||||
Status: "40034",
|
||||
Info: "task class item already arranged",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/LoveLosita/smartflow/backend/api"
|
||||
"github.com/LoveLosita/smartflow/backend/dao"
|
||||
"github.com/LoveLosita/smartflow/backend/middleware"
|
||||
"github.com/LoveLosita/smartflow/backend/pkg"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
@@ -27,7 +28,7 @@ func StartEngine(r *gin.Engine) {
|
||||
}
|
||||
}
|
||||
|
||||
func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine {
|
||||
func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO, limiter *pkg.RateLimiter) *gin.Engine {
|
||||
// 初始化Gin引擎
|
||||
r := gin.Default()
|
||||
// 在这里注册所有的路由和路由组
|
||||
@@ -46,23 +47,23 @@ func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine
|
||||
userGroup.POST("/register", handlers.UserHandler.UserRegister)
|
||||
userGroup.POST("/login", handlers.UserHandler.UserLogin)
|
||||
userGroup.POST("/refresh-token", handlers.UserHandler.RefreshTokenHandler)
|
||||
userGroup.POST("/logout", middleware.JWTTokenAuth(cache), handlers.UserHandler.UserLogout)
|
||||
userGroup.POST("/logout", middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1), handlers.UserHandler.UserLogout)
|
||||
}
|
||||
taskGroup := apiGroup.Group("/task")
|
||||
{
|
||||
taskGroup.Use(middleware.JWTTokenAuth(cache))
|
||||
taskGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1))
|
||||
taskGroup.POST("/create", handlers.TaskHandler.AddTask)
|
||||
taskGroup.GET("/get", handlers.TaskHandler.GetUserTasks)
|
||||
}
|
||||
courseGroup := apiGroup.Group("/course")
|
||||
{
|
||||
courseGroup.Use(middleware.JWTTokenAuth(cache))
|
||||
courseGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1))
|
||||
courseGroup.POST("/validate", handlers.CourseHandler.CheckUserCourse)
|
||||
courseGroup.POST("/import", handlers.CourseHandler.AddUserCourses)
|
||||
}
|
||||
taskClassGroup := apiGroup.Group("/task-class")
|
||||
{
|
||||
taskClassGroup.Use(middleware.JWTTokenAuth(cache))
|
||||
taskClassGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1))
|
||||
taskClassGroup.POST("/add", handlers.TaskClassHandler.UserAddTaskClass)
|
||||
taskClassGroup.GET("/list", handlers.TaskClassHandler.UserGetTaskClassInfos)
|
||||
taskClassGroup.GET("/get", handlers.TaskClassHandler.UserGetCompleteTaskClass)
|
||||
@@ -71,7 +72,7 @@ func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine
|
||||
}
|
||||
scheduleGroup := apiGroup.Group("/schedule")
|
||||
{
|
||||
scheduleGroup.Use(middleware.JWTTokenAuth(cache))
|
||||
scheduleGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1))
|
||||
scheduleGroup.GET("/today", handlers.ScheduleHandler.GetUserTodaySchedule)
|
||||
scheduleGroup.GET("/week", handlers.ScheduleHandler.GetUserWeeklySchedule)
|
||||
scheduleGroup.DELETE("/delete", handlers.ScheduleHandler.DeleteScheduleEvent)
|
||||
|
||||
@@ -115,7 +115,15 @@ func (sv *TaskClassService) AddTaskClassItemIntoSchedule(ctx context.Context, re
|
||||
if ownerID != userID {
|
||||
return respond.TaskClassItemNotBelongToUser
|
||||
}
|
||||
//2.取出任务块信息
|
||||
//2.再检查任务块本身是否已经被安排
|
||||
result, err := sv.taskClassRepo.IfTaskClassItemArranged(ctx, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if result {
|
||||
return respond.TaskClassItemAlreadyArranged
|
||||
}
|
||||
//3.取出任务块信息
|
||||
taskItem, err := sv.taskClassRepo.GetTaskClassItemByID(ctx, taskID) //通过任务块ID获取任务块信息
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
Reference in New Issue
Block a user