From d5f0b8da63e8b579883688d24cd67539d9ec9229 Mon Sep 17 00:00:00 2001 From: LoveLosita <2810873701@qq.com> Date: Tue, 10 Feb 2026 20:52:06 +0800 Subject: [PATCH] =?UTF-8?q?Version:=200.2.1.dev.260210=20feat:=20?= =?UTF-8?q?=F0=9F=9A=A6=20=E6=96=B0=E5=A2=9E=E5=9F=BA=E4=BA=8E=20Redis=20?= =?UTF-8?q?=E4=BB=A4=E7=89=8C=E6=A1=B6=E7=9A=84=E9=99=90=E6=B5=81=E4=B8=AD?= =?UTF-8?q?=E9=97=B4=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 使用 Redis 实现令牌桶算法进行限流 🪣 - 覆盖除登录、注册、刷新 token 以外的所有接口 🔒 fix: 🐛 修复任务块添加到日程接口可修改已安排任务时间的问题 - 禁止通过该接口直接修改已安排任务块的时间 - 修正不合理的业务逻辑,保证数据一致性 ✅ --- backend/cmd/start.go | 6 ++- backend/dao/task-class.go | 12 ++++++ backend/middleware/rate_limiter.go | 38 ++++++++++++++++++ backend/pkg/rate_limiter.go | 64 ++++++++++++++++++++++++++++++ backend/respond/respond.go | 9 +++++ backend/routers/routers.go | 13 +++--- backend/service/task-class.go | 10 ++++- 7 files changed, 143 insertions(+), 9 deletions(-) create mode 100644 backend/middleware/rate_limiter.go create mode 100644 backend/pkg/rate_limiter.go diff --git a/backend/cmd/start.go b/backend/cmd/start.go index a5ce38e..0430751 100644 --- a/backend/cmd/start.go +++ b/backend/cmd/start.go @@ -7,6 +7,7 @@ import ( "github.com/LoveLosita/smartflow/backend/api" "github.com/LoveLosita/smartflow/backend/dao" "github.com/LoveLosita/smartflow/backend/inits" + "github.com/LoveLosita/smartflow/backend/pkg" "github.com/LoveLosita/smartflow/backend/routers" "github.com/LoveLosita/smartflow/backend/service" "github.com/spf13/viper" @@ -39,6 +40,8 @@ func Start() { log.Fatalf("Failed to connect to database: %v", err) } rdb := inits.InitRedis() + //工具包 + limiter := pkg.NewRateLimiter(rdb) //dao 层 userRepo := dao.NewUserDAO(db) cacheRepo := dao.NewCacheDAO(rdb) @@ -59,7 +62,6 @@ func Start() { courseApi := api.NewCourseHandler(courseService) taskClassApi := api.NewTaskClassHandler(taskClassService) scheduleApi := api.NewScheduleAPI(scheduleService) - handlers := &api.ApiHandlers{ UserHandler: userApi, TaskHandler: taskApi, @@ -67,6 +69,6 @@ func Start() { CourseHandler: courseApi, ScheduleHandler: scheduleApi, } - r := routers.RegisterRouters(handlers, cacheRepo) + r := routers.RegisterRouters(handlers, cacheRepo, limiter) routers.StartEngine(r) } diff --git a/backend/dao/task-class.go b/backend/dao/task-class.go index c4f65eb..8bd2de0 100644 --- a/backend/dao/task-class.go +++ b/backend/dao/task-class.go @@ -195,3 +195,15 @@ func (dao *TaskClassDAO) DeleteTaskClassItemEmbeddedTime(ctx context.Context, ta Update("embedded_time", nil).Error return err } + +func (dao *TaskClassDAO) IfTaskClassItemArranged(ctx context.Context, taskID int) (bool, error) { + var item model.TaskClassItem + err := dao.db.WithContext(ctx). + Select("embedded_time"). + Where("id = ?", taskID). + First(&item).Error + if err != nil { + return false, err + } + return item.EmbeddedTime != nil, nil +} diff --git a/backend/middleware/rate_limiter.go b/backend/middleware/rate_limiter.go new file mode 100644 index 0000000..dee868f --- /dev/null +++ b/backend/middleware/rate_limiter.go @@ -0,0 +1,38 @@ +package middleware + +import ( + "fmt" + "log" + + "github.com/LoveLosita/smartflow/backend/pkg" + "github.com/LoveLosita/smartflow/backend/respond" + "github.com/gin-gonic/gin" +) + +func RateLimitMiddleware(limiter *pkg.RateLimiter, capacity, rate int) gin.HandlerFunc { + return func(c *gin.Context) { + // 1. 确定限流对象:可以用 UserID,也可以用 IP + // 这里建议用 UserID,防止某个用户换 IP 疯狂刷 + userID := c.GetInt("user_id") // 假设你之前的 JWT 已经塞进去了 + key := fmt.Sprintf("rate_limit:user:%d", userID) + + // 2. 执行限流检查 + allowed, err := limiter.Allow(c.Request.Context(), key, capacity, rate) + if err != nil { + // 如果 Redis 挂了,为了保证业务可用,通常选择“放行”并记录日志 + log.Printf("Redis limiter error: %v", err) + c.Next() + return + } + + if !allowed { + // 3. 触发限流:直接调你写好的 DealWithError + // 可以在 respond 里定义一个新错误:TooManyRequests + respond.DealWithError(c, respond.TooManyRequests) + c.Abort() // 拦截,不执行后续 Handler + return + } + + c.Next() + } +} diff --git a/backend/pkg/rate_limiter.go b/backend/pkg/rate_limiter.go new file mode 100644 index 0000000..2a58dab --- /dev/null +++ b/backend/pkg/rate_limiter.go @@ -0,0 +1,64 @@ +package pkg + +import ( + "context" + "time" + + "github.com/go-redis/redis/v8" +) + +var tokenBucketScript = redis.NewScript(`-- KEYS[1]: 限流标识 (如 rate_limit:user_123) +-- ARGV[1]: 令牌桶最大容量 (Capacity) +-- ARGV[2]: 令牌填充速率 (Tokens per second) +-- ARGV[3]: 当前时间戳 (Current Unix timestamp in seconds) +-- ARGV[4]: 请求需要的令牌数 (通常为 1) + +local bucket_info = redis.call("HMGET", KEYS[1], "last_tokens", "last_refreshed") +local last_tokens = tonumber(bucket_info[1]) +local last_refreshed = tonumber(bucket_info[2]) + +local capacity = tonumber(ARGV[1]) +local rate = tonumber(ARGV[2]) +local now = tonumber(ARGV[3]) +local requested = tonumber(ARGV[4]) + +-- 如果是首次访问,初始化桶 +if last_tokens == nil then + last_tokens = capacity + last_refreshed = now +end + +-- 💡 核心逻辑:计算这段时间新产生的令牌 +local delta = math.max(0, now - last_refreshed) +local new_tokens = math.min(capacity, last_tokens + (delta * rate)) + +local allowed = false +if new_tokens >= requested then + new_tokens = new_tokens - requested + allowed = true +end + +-- 更新 Redis 状态 +redis.call("HMSET", KEYS[1], "last_tokens", new_tokens, "last_refreshed", now) +-- 设置过期时间(比如 1 小时没人访问就删掉,省内存) +redis.call("EXPIRE", KEYS[1], 3600) + +return allowed and 1 or 0`) + +type RateLimiter struct { + client *redis.Client +} + +func NewRateLimiter(client *redis.Client) *RateLimiter { + return &RateLimiter{client: client} +} + +func (r *RateLimiter) Allow(ctx context.Context, key string, capacity, rate int) (bool, error) { + // 传参:Key, 容量, 速率, 当前时间, 请求数 + res, err := tokenBucketScript.Run(ctx, r.client, []string{key}, + capacity, rate, time.Now().Unix(), 1).Int() + if err != nil { + return false, err + } + return res == 1, nil +} diff --git a/backend/respond/respond.go b/backend/respond/respond.go index 18c1539..82e1cac 100644 --- a/backend/respond/respond.go +++ b/backend/respond/respond.go @@ -221,4 +221,13 @@ var ( //请求相关的响应 Status: "40032", Info: "target schedule does not have embedded task", } + TooManyRequests = Response{ //请求过多 + Status: "40033", + Info: "too many requests", + } + + TaskClassItemAlreadyArranged = Response{ //任务类项目已安排 + Status: "40034", + Info: "task class item already arranged", + } ) diff --git a/backend/routers/routers.go b/backend/routers/routers.go index c1c3c9e..97d7d32 100644 --- a/backend/routers/routers.go +++ b/backend/routers/routers.go @@ -8,6 +8,7 @@ import ( "github.com/LoveLosita/smartflow/backend/api" "github.com/LoveLosita/smartflow/backend/dao" "github.com/LoveLosita/smartflow/backend/middleware" + "github.com/LoveLosita/smartflow/backend/pkg" "github.com/gin-gonic/gin" "github.com/spf13/viper" ) @@ -27,7 +28,7 @@ func StartEngine(r *gin.Engine) { } } -func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine { +func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO, limiter *pkg.RateLimiter) *gin.Engine { // 初始化Gin引擎 r := gin.Default() // 在这里注册所有的路由和路由组 @@ -46,23 +47,23 @@ func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine userGroup.POST("/register", handlers.UserHandler.UserRegister) userGroup.POST("/login", handlers.UserHandler.UserLogin) userGroup.POST("/refresh-token", handlers.UserHandler.RefreshTokenHandler) - userGroup.POST("/logout", middleware.JWTTokenAuth(cache), handlers.UserHandler.UserLogout) + userGroup.POST("/logout", middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1), handlers.UserHandler.UserLogout) } taskGroup := apiGroup.Group("/task") { - taskGroup.Use(middleware.JWTTokenAuth(cache)) + taskGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1)) taskGroup.POST("/create", handlers.TaskHandler.AddTask) taskGroup.GET("/get", handlers.TaskHandler.GetUserTasks) } courseGroup := apiGroup.Group("/course") { - courseGroup.Use(middleware.JWTTokenAuth(cache)) + courseGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1)) courseGroup.POST("/validate", handlers.CourseHandler.CheckUserCourse) courseGroup.POST("/import", handlers.CourseHandler.AddUserCourses) } taskClassGroup := apiGroup.Group("/task-class") { - taskClassGroup.Use(middleware.JWTTokenAuth(cache)) + taskClassGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1)) taskClassGroup.POST("/add", handlers.TaskClassHandler.UserAddTaskClass) taskClassGroup.GET("/list", handlers.TaskClassHandler.UserGetTaskClassInfos) taskClassGroup.GET("/get", handlers.TaskClassHandler.UserGetCompleteTaskClass) @@ -71,7 +72,7 @@ func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine } scheduleGroup := apiGroup.Group("/schedule") { - scheduleGroup.Use(middleware.JWTTokenAuth(cache)) + scheduleGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1)) scheduleGroup.GET("/today", handlers.ScheduleHandler.GetUserTodaySchedule) scheduleGroup.GET("/week", handlers.ScheduleHandler.GetUserWeeklySchedule) scheduleGroup.DELETE("/delete", handlers.ScheduleHandler.DeleteScheduleEvent) diff --git a/backend/service/task-class.go b/backend/service/task-class.go index 11de787..14f3a0e 100644 --- a/backend/service/task-class.go +++ b/backend/service/task-class.go @@ -115,7 +115,15 @@ func (sv *TaskClassService) AddTaskClassItemIntoSchedule(ctx context.Context, re if ownerID != userID { return respond.TaskClassItemNotBelongToUser } - //2.取出任务块信息 + //2.再检查任务块本身是否已经被安排 + result, err := sv.taskClassRepo.IfTaskClassItemArranged(ctx, taskID) + if err != nil { + return err + } + if result { + return respond.TaskClassItemAlreadyArranged + } + //3.取出任务块信息 taskItem, err := sv.taskClassRepo.GetTaskClassItemByID(ctx, taskID) //通过任务块ID获取任务块信息 if err != nil { return err