Version: 0.2.1.dev.260210

feat: 🚦 新增基于 Redis 令牌桶的限流中间件

- 使用 Redis 实现令牌桶算法进行限流 🪣
- 覆盖除登录、注册、刷新 token 以外的所有接口 🔒

fix: 🐛 修复任务块添加到日程接口可修改已安排任务时间的问题

- 禁止通过该接口直接修改已安排任务块的时间
- 修正不合理的业务逻辑,保证数据一致性 
This commit is contained in:
LoveLosita
2026-02-10 20:52:06 +08:00
parent d07234e183
commit d5f0b8da63
7 changed files with 143 additions and 9 deletions

View File

@@ -7,6 +7,7 @@ import (
"github.com/LoveLosita/smartflow/backend/api" "github.com/LoveLosita/smartflow/backend/api"
"github.com/LoveLosita/smartflow/backend/dao" "github.com/LoveLosita/smartflow/backend/dao"
"github.com/LoveLosita/smartflow/backend/inits" "github.com/LoveLosita/smartflow/backend/inits"
"github.com/LoveLosita/smartflow/backend/pkg"
"github.com/LoveLosita/smartflow/backend/routers" "github.com/LoveLosita/smartflow/backend/routers"
"github.com/LoveLosita/smartflow/backend/service" "github.com/LoveLosita/smartflow/backend/service"
"github.com/spf13/viper" "github.com/spf13/viper"
@@ -39,6 +40,8 @@ func Start() {
log.Fatalf("Failed to connect to database: %v", err) log.Fatalf("Failed to connect to database: %v", err)
} }
rdb := inits.InitRedis() rdb := inits.InitRedis()
//工具包
limiter := pkg.NewRateLimiter(rdb)
//dao 层 //dao 层
userRepo := dao.NewUserDAO(db) userRepo := dao.NewUserDAO(db)
cacheRepo := dao.NewCacheDAO(rdb) cacheRepo := dao.NewCacheDAO(rdb)
@@ -59,7 +62,6 @@ func Start() {
courseApi := api.NewCourseHandler(courseService) courseApi := api.NewCourseHandler(courseService)
taskClassApi := api.NewTaskClassHandler(taskClassService) taskClassApi := api.NewTaskClassHandler(taskClassService)
scheduleApi := api.NewScheduleAPI(scheduleService) scheduleApi := api.NewScheduleAPI(scheduleService)
handlers := &api.ApiHandlers{ handlers := &api.ApiHandlers{
UserHandler: userApi, UserHandler: userApi,
TaskHandler: taskApi, TaskHandler: taskApi,
@@ -67,6 +69,6 @@ func Start() {
CourseHandler: courseApi, CourseHandler: courseApi,
ScheduleHandler: scheduleApi, ScheduleHandler: scheduleApi,
} }
r := routers.RegisterRouters(handlers, cacheRepo) r := routers.RegisterRouters(handlers, cacheRepo, limiter)
routers.StartEngine(r) routers.StartEngine(r)
} }

View File

@@ -195,3 +195,15 @@ func (dao *TaskClassDAO) DeleteTaskClassItemEmbeddedTime(ctx context.Context, ta
Update("embedded_time", nil).Error Update("embedded_time", nil).Error
return err return err
} }
func (dao *TaskClassDAO) IfTaskClassItemArranged(ctx context.Context, taskID int) (bool, error) {
var item model.TaskClassItem
err := dao.db.WithContext(ctx).
Select("embedded_time").
Where("id = ?", taskID).
First(&item).Error
if err != nil {
return false, err
}
return item.EmbeddedTime != nil, nil
}

View File

@@ -0,0 +1,38 @@
package middleware
import (
"fmt"
"log"
"github.com/LoveLosita/smartflow/backend/pkg"
"github.com/LoveLosita/smartflow/backend/respond"
"github.com/gin-gonic/gin"
)
func RateLimitMiddleware(limiter *pkg.RateLimiter, capacity, rate int) gin.HandlerFunc {
return func(c *gin.Context) {
// 1. 确定限流对象:可以用 UserID也可以用 IP
// 这里建议用 UserID防止某个用户换 IP 疯狂刷
userID := c.GetInt("user_id") // 假设你之前的 JWT 已经塞进去了
key := fmt.Sprintf("rate_limit:user:%d", userID)
// 2. 执行限流检查
allowed, err := limiter.Allow(c.Request.Context(), key, capacity, rate)
if err != nil {
// 如果 Redis 挂了,为了保证业务可用,通常选择“放行”并记录日志
log.Printf("Redis limiter error: %v", err)
c.Next()
return
}
if !allowed {
// 3. 触发限流:直接调你写好的 DealWithError
// 可以在 respond 里定义一个新错误TooManyRequests
respond.DealWithError(c, respond.TooManyRequests)
c.Abort() // 拦截,不执行后续 Handler
return
}
c.Next()
}
}

View File

@@ -0,0 +1,64 @@
package pkg
import (
"context"
"time"
"github.com/go-redis/redis/v8"
)
var tokenBucketScript = redis.NewScript(`-- KEYS[1]: 限流标识 (如 rate_limit:user_123)
-- ARGV[1]: 令牌桶最大容量 (Capacity)
-- ARGV[2]: 令牌填充速率 (Tokens per second)
-- ARGV[3]: 当前时间戳 (Current Unix timestamp in seconds)
-- ARGV[4]: 请求需要的令牌数 (通常为 1)
local bucket_info = redis.call("HMGET", KEYS[1], "last_tokens", "last_refreshed")
local last_tokens = tonumber(bucket_info[1])
local last_refreshed = tonumber(bucket_info[2])
local capacity = tonumber(ARGV[1])
local rate = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
-- 如果是首次访问,初始化桶
if last_tokens == nil then
last_tokens = capacity
last_refreshed = now
end
-- 💡 核心逻辑:计算这段时间新产生的令牌
local delta = math.max(0, now - last_refreshed)
local new_tokens = math.min(capacity, last_tokens + (delta * rate))
local allowed = false
if new_tokens >= requested then
new_tokens = new_tokens - requested
allowed = true
end
-- 更新 Redis 状态
redis.call("HMSET", KEYS[1], "last_tokens", new_tokens, "last_refreshed", now)
-- 设置过期时间(比如 1 小时没人访问就删掉,省内存)
redis.call("EXPIRE", KEYS[1], 3600)
return allowed and 1 or 0`)
type RateLimiter struct {
client *redis.Client
}
func NewRateLimiter(client *redis.Client) *RateLimiter {
return &RateLimiter{client: client}
}
func (r *RateLimiter) Allow(ctx context.Context, key string, capacity, rate int) (bool, error) {
// 传参Key, 容量, 速率, 当前时间, 请求数
res, err := tokenBucketScript.Run(ctx, r.client, []string{key},
capacity, rate, time.Now().Unix(), 1).Int()
if err != nil {
return false, err
}
return res == 1, nil
}

View File

@@ -221,4 +221,13 @@ var ( //请求相关的响应
Status: "40032", Status: "40032",
Info: "target schedule does not have embedded task", Info: "target schedule does not have embedded task",
} }
TooManyRequests = Response{ //请求过多
Status: "40033",
Info: "too many requests",
}
TaskClassItemAlreadyArranged = Response{ //任务类项目已安排
Status: "40034",
Info: "task class item already arranged",
}
) )

View File

@@ -8,6 +8,7 @@ import (
"github.com/LoveLosita/smartflow/backend/api" "github.com/LoveLosita/smartflow/backend/api"
"github.com/LoveLosita/smartflow/backend/dao" "github.com/LoveLosita/smartflow/backend/dao"
"github.com/LoveLosita/smartflow/backend/middleware" "github.com/LoveLosita/smartflow/backend/middleware"
"github.com/LoveLosita/smartflow/backend/pkg"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
@@ -27,7 +28,7 @@ func StartEngine(r *gin.Engine) {
} }
} }
func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine { func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO, limiter *pkg.RateLimiter) *gin.Engine {
// 初始化Gin引擎 // 初始化Gin引擎
r := gin.Default() r := gin.Default()
// 在这里注册所有的路由和路由组 // 在这里注册所有的路由和路由组
@@ -46,23 +47,23 @@ func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine
userGroup.POST("/register", handlers.UserHandler.UserRegister) userGroup.POST("/register", handlers.UserHandler.UserRegister)
userGroup.POST("/login", handlers.UserHandler.UserLogin) userGroup.POST("/login", handlers.UserHandler.UserLogin)
userGroup.POST("/refresh-token", handlers.UserHandler.RefreshTokenHandler) userGroup.POST("/refresh-token", handlers.UserHandler.RefreshTokenHandler)
userGroup.POST("/logout", middleware.JWTTokenAuth(cache), handlers.UserHandler.UserLogout) userGroup.POST("/logout", middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1), handlers.UserHandler.UserLogout)
} }
taskGroup := apiGroup.Group("/task") taskGroup := apiGroup.Group("/task")
{ {
taskGroup.Use(middleware.JWTTokenAuth(cache)) taskGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1))
taskGroup.POST("/create", handlers.TaskHandler.AddTask) taskGroup.POST("/create", handlers.TaskHandler.AddTask)
taskGroup.GET("/get", handlers.TaskHandler.GetUserTasks) taskGroup.GET("/get", handlers.TaskHandler.GetUserTasks)
} }
courseGroup := apiGroup.Group("/course") courseGroup := apiGroup.Group("/course")
{ {
courseGroup.Use(middleware.JWTTokenAuth(cache)) courseGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1))
courseGroup.POST("/validate", handlers.CourseHandler.CheckUserCourse) courseGroup.POST("/validate", handlers.CourseHandler.CheckUserCourse)
courseGroup.POST("/import", handlers.CourseHandler.AddUserCourses) courseGroup.POST("/import", handlers.CourseHandler.AddUserCourses)
} }
taskClassGroup := apiGroup.Group("/task-class") taskClassGroup := apiGroup.Group("/task-class")
{ {
taskClassGroup.Use(middleware.JWTTokenAuth(cache)) taskClassGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1))
taskClassGroup.POST("/add", handlers.TaskClassHandler.UserAddTaskClass) taskClassGroup.POST("/add", handlers.TaskClassHandler.UserAddTaskClass)
taskClassGroup.GET("/list", handlers.TaskClassHandler.UserGetTaskClassInfos) taskClassGroup.GET("/list", handlers.TaskClassHandler.UserGetTaskClassInfos)
taskClassGroup.GET("/get", handlers.TaskClassHandler.UserGetCompleteTaskClass) taskClassGroup.GET("/get", handlers.TaskClassHandler.UserGetCompleteTaskClass)
@@ -71,7 +72,7 @@ func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine
} }
scheduleGroup := apiGroup.Group("/schedule") scheduleGroup := apiGroup.Group("/schedule")
{ {
scheduleGroup.Use(middleware.JWTTokenAuth(cache)) scheduleGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1))
scheduleGroup.GET("/today", handlers.ScheduleHandler.GetUserTodaySchedule) scheduleGroup.GET("/today", handlers.ScheduleHandler.GetUserTodaySchedule)
scheduleGroup.GET("/week", handlers.ScheduleHandler.GetUserWeeklySchedule) scheduleGroup.GET("/week", handlers.ScheduleHandler.GetUserWeeklySchedule)
scheduleGroup.DELETE("/delete", handlers.ScheduleHandler.DeleteScheduleEvent) scheduleGroup.DELETE("/delete", handlers.ScheduleHandler.DeleteScheduleEvent)

View File

@@ -115,7 +115,15 @@ func (sv *TaskClassService) AddTaskClassItemIntoSchedule(ctx context.Context, re
if ownerID != userID { if ownerID != userID {
return respond.TaskClassItemNotBelongToUser return respond.TaskClassItemNotBelongToUser
} }
//2.取出任务块信息 //2.再检查任务块本身是否已经被安排
result, err := sv.taskClassRepo.IfTaskClassItemArranged(ctx, taskID)
if err != nil {
return err
}
if result {
return respond.TaskClassItemAlreadyArranged
}
//3.取出任务块信息
taskItem, err := sv.taskClassRepo.GetTaskClassItemByID(ctx, taskID) //通过任务块ID获取任务块信息 taskItem, err := sv.taskClassRepo.GetTaskClassItemByID(ctx, taskID) //通过任务块ID获取任务块信息
if err != nil { if err != nil {
return err return err