Version: 0.2.1.dev.260210
feat: 🚦 新增基于 Redis 令牌桶的限流中间件 - 使用 Redis 实现令牌桶算法进行限流 🪣 - 覆盖除登录、注册、刷新 token 以外的所有接口 🔒 fix: 🐛 修复任务块添加到日程接口可修改已安排任务时间的问题 - 禁止通过该接口直接修改已安排任务块的时间 - 修正不合理的业务逻辑,保证数据一致性 ✅
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/LoveLosita/smartflow/backend/api"
|
"github.com/LoveLosita/smartflow/backend/api"
|
||||||
"github.com/LoveLosita/smartflow/backend/dao"
|
"github.com/LoveLosita/smartflow/backend/dao"
|
||||||
"github.com/LoveLosita/smartflow/backend/inits"
|
"github.com/LoveLosita/smartflow/backend/inits"
|
||||||
|
"github.com/LoveLosita/smartflow/backend/pkg"
|
||||||
"github.com/LoveLosita/smartflow/backend/routers"
|
"github.com/LoveLosita/smartflow/backend/routers"
|
||||||
"github.com/LoveLosita/smartflow/backend/service"
|
"github.com/LoveLosita/smartflow/backend/service"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
@@ -39,6 +40,8 @@ func Start() {
|
|||||||
log.Fatalf("Failed to connect to database: %v", err)
|
log.Fatalf("Failed to connect to database: %v", err)
|
||||||
}
|
}
|
||||||
rdb := inits.InitRedis()
|
rdb := inits.InitRedis()
|
||||||
|
//工具包
|
||||||
|
limiter := pkg.NewRateLimiter(rdb)
|
||||||
//dao 层
|
//dao 层
|
||||||
userRepo := dao.NewUserDAO(db)
|
userRepo := dao.NewUserDAO(db)
|
||||||
cacheRepo := dao.NewCacheDAO(rdb)
|
cacheRepo := dao.NewCacheDAO(rdb)
|
||||||
@@ -59,7 +62,6 @@ func Start() {
|
|||||||
courseApi := api.NewCourseHandler(courseService)
|
courseApi := api.NewCourseHandler(courseService)
|
||||||
taskClassApi := api.NewTaskClassHandler(taskClassService)
|
taskClassApi := api.NewTaskClassHandler(taskClassService)
|
||||||
scheduleApi := api.NewScheduleAPI(scheduleService)
|
scheduleApi := api.NewScheduleAPI(scheduleService)
|
||||||
|
|
||||||
handlers := &api.ApiHandlers{
|
handlers := &api.ApiHandlers{
|
||||||
UserHandler: userApi,
|
UserHandler: userApi,
|
||||||
TaskHandler: taskApi,
|
TaskHandler: taskApi,
|
||||||
@@ -67,6 +69,6 @@ func Start() {
|
|||||||
CourseHandler: courseApi,
|
CourseHandler: courseApi,
|
||||||
ScheduleHandler: scheduleApi,
|
ScheduleHandler: scheduleApi,
|
||||||
}
|
}
|
||||||
r := routers.RegisterRouters(handlers, cacheRepo)
|
r := routers.RegisterRouters(handlers, cacheRepo, limiter)
|
||||||
routers.StartEngine(r)
|
routers.StartEngine(r)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -195,3 +195,15 @@ func (dao *TaskClassDAO) DeleteTaskClassItemEmbeddedTime(ctx context.Context, ta
|
|||||||
Update("embedded_time", nil).Error
|
Update("embedded_time", nil).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dao *TaskClassDAO) IfTaskClassItemArranged(ctx context.Context, taskID int) (bool, error) {
|
||||||
|
var item model.TaskClassItem
|
||||||
|
err := dao.db.WithContext(ctx).
|
||||||
|
Select("embedded_time").
|
||||||
|
Where("id = ?", taskID).
|
||||||
|
First(&item).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return item.EmbeddedTime != nil, nil
|
||||||
|
}
|
||||||
|
|||||||
38
backend/middleware/rate_limiter.go
Normal file
38
backend/middleware/rate_limiter.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/LoveLosita/smartflow/backend/pkg"
|
||||||
|
"github.com/LoveLosita/smartflow/backend/respond"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RateLimitMiddleware(limiter *pkg.RateLimiter, capacity, rate int) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// 1. 确定限流对象:可以用 UserID,也可以用 IP
|
||||||
|
// 这里建议用 UserID,防止某个用户换 IP 疯狂刷
|
||||||
|
userID := c.GetInt("user_id") // 假设你之前的 JWT 已经塞进去了
|
||||||
|
key := fmt.Sprintf("rate_limit:user:%d", userID)
|
||||||
|
|
||||||
|
// 2. 执行限流检查
|
||||||
|
allowed, err := limiter.Allow(c.Request.Context(), key, capacity, rate)
|
||||||
|
if err != nil {
|
||||||
|
// 如果 Redis 挂了,为了保证业务可用,通常选择“放行”并记录日志
|
||||||
|
log.Printf("Redis limiter error: %v", err)
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !allowed {
|
||||||
|
// 3. 触发限流:直接调你写好的 DealWithError
|
||||||
|
// 可以在 respond 里定义一个新错误:TooManyRequests
|
||||||
|
respond.DealWithError(c, respond.TooManyRequests)
|
||||||
|
c.Abort() // 拦截,不执行后续 Handler
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
64
backend/pkg/rate_limiter.go
Normal file
64
backend/pkg/rate_limiter.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package pkg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
)
|
||||||
|
|
||||||
|
var tokenBucketScript = redis.NewScript(`-- KEYS[1]: 限流标识 (如 rate_limit:user_123)
|
||||||
|
-- ARGV[1]: 令牌桶最大容量 (Capacity)
|
||||||
|
-- ARGV[2]: 令牌填充速率 (Tokens per second)
|
||||||
|
-- ARGV[3]: 当前时间戳 (Current Unix timestamp in seconds)
|
||||||
|
-- ARGV[4]: 请求需要的令牌数 (通常为 1)
|
||||||
|
|
||||||
|
local bucket_info = redis.call("HMGET", KEYS[1], "last_tokens", "last_refreshed")
|
||||||
|
local last_tokens = tonumber(bucket_info[1])
|
||||||
|
local last_refreshed = tonumber(bucket_info[2])
|
||||||
|
|
||||||
|
local capacity = tonumber(ARGV[1])
|
||||||
|
local rate = tonumber(ARGV[2])
|
||||||
|
local now = tonumber(ARGV[3])
|
||||||
|
local requested = tonumber(ARGV[4])
|
||||||
|
|
||||||
|
-- 如果是首次访问,初始化桶
|
||||||
|
if last_tokens == nil then
|
||||||
|
last_tokens = capacity
|
||||||
|
last_refreshed = now
|
||||||
|
end
|
||||||
|
|
||||||
|
-- 💡 核心逻辑:计算这段时间新产生的令牌
|
||||||
|
local delta = math.max(0, now - last_refreshed)
|
||||||
|
local new_tokens = math.min(capacity, last_tokens + (delta * rate))
|
||||||
|
|
||||||
|
local allowed = false
|
||||||
|
if new_tokens >= requested then
|
||||||
|
new_tokens = new_tokens - requested
|
||||||
|
allowed = true
|
||||||
|
end
|
||||||
|
|
||||||
|
-- 更新 Redis 状态
|
||||||
|
redis.call("HMSET", KEYS[1], "last_tokens", new_tokens, "last_refreshed", now)
|
||||||
|
-- 设置过期时间(比如 1 小时没人访问就删掉,省内存)
|
||||||
|
redis.call("EXPIRE", KEYS[1], 3600)
|
||||||
|
|
||||||
|
return allowed and 1 or 0`)
|
||||||
|
|
||||||
|
type RateLimiter struct {
|
||||||
|
client *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRateLimiter(client *redis.Client) *RateLimiter {
|
||||||
|
return &RateLimiter{client: client}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RateLimiter) Allow(ctx context.Context, key string, capacity, rate int) (bool, error) {
|
||||||
|
// 传参:Key, 容量, 速率, 当前时间, 请求数
|
||||||
|
res, err := tokenBucketScript.Run(ctx, r.client, []string{key},
|
||||||
|
capacity, rate, time.Now().Unix(), 1).Int()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return res == 1, nil
|
||||||
|
}
|
||||||
@@ -221,4 +221,13 @@ var ( //请求相关的响应
|
|||||||
Status: "40032",
|
Status: "40032",
|
||||||
Info: "target schedule does not have embedded task",
|
Info: "target schedule does not have embedded task",
|
||||||
}
|
}
|
||||||
|
TooManyRequests = Response{ //请求过多
|
||||||
|
Status: "40033",
|
||||||
|
Info: "too many requests",
|
||||||
|
}
|
||||||
|
|
||||||
|
TaskClassItemAlreadyArranged = Response{ //任务类项目已安排
|
||||||
|
Status: "40034",
|
||||||
|
Info: "task class item already arranged",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/LoveLosita/smartflow/backend/api"
|
"github.com/LoveLosita/smartflow/backend/api"
|
||||||
"github.com/LoveLosita/smartflow/backend/dao"
|
"github.com/LoveLosita/smartflow/backend/dao"
|
||||||
"github.com/LoveLosita/smartflow/backend/middleware"
|
"github.com/LoveLosita/smartflow/backend/middleware"
|
||||||
|
"github.com/LoveLosita/smartflow/backend/pkg"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
@@ -27,7 +28,7 @@ func StartEngine(r *gin.Engine) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine {
|
func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO, limiter *pkg.RateLimiter) *gin.Engine {
|
||||||
// 初始化Gin引擎
|
// 初始化Gin引擎
|
||||||
r := gin.Default()
|
r := gin.Default()
|
||||||
// 在这里注册所有的路由和路由组
|
// 在这里注册所有的路由和路由组
|
||||||
@@ -46,23 +47,23 @@ func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine
|
|||||||
userGroup.POST("/register", handlers.UserHandler.UserRegister)
|
userGroup.POST("/register", handlers.UserHandler.UserRegister)
|
||||||
userGroup.POST("/login", handlers.UserHandler.UserLogin)
|
userGroup.POST("/login", handlers.UserHandler.UserLogin)
|
||||||
userGroup.POST("/refresh-token", handlers.UserHandler.RefreshTokenHandler)
|
userGroup.POST("/refresh-token", handlers.UserHandler.RefreshTokenHandler)
|
||||||
userGroup.POST("/logout", middleware.JWTTokenAuth(cache), handlers.UserHandler.UserLogout)
|
userGroup.POST("/logout", middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1), handlers.UserHandler.UserLogout)
|
||||||
}
|
}
|
||||||
taskGroup := apiGroup.Group("/task")
|
taskGroup := apiGroup.Group("/task")
|
||||||
{
|
{
|
||||||
taskGroup.Use(middleware.JWTTokenAuth(cache))
|
taskGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1))
|
||||||
taskGroup.POST("/create", handlers.TaskHandler.AddTask)
|
taskGroup.POST("/create", handlers.TaskHandler.AddTask)
|
||||||
taskGroup.GET("/get", handlers.TaskHandler.GetUserTasks)
|
taskGroup.GET("/get", handlers.TaskHandler.GetUserTasks)
|
||||||
}
|
}
|
||||||
courseGroup := apiGroup.Group("/course")
|
courseGroup := apiGroup.Group("/course")
|
||||||
{
|
{
|
||||||
courseGroup.Use(middleware.JWTTokenAuth(cache))
|
courseGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1))
|
||||||
courseGroup.POST("/validate", handlers.CourseHandler.CheckUserCourse)
|
courseGroup.POST("/validate", handlers.CourseHandler.CheckUserCourse)
|
||||||
courseGroup.POST("/import", handlers.CourseHandler.AddUserCourses)
|
courseGroup.POST("/import", handlers.CourseHandler.AddUserCourses)
|
||||||
}
|
}
|
||||||
taskClassGroup := apiGroup.Group("/task-class")
|
taskClassGroup := apiGroup.Group("/task-class")
|
||||||
{
|
{
|
||||||
taskClassGroup.Use(middleware.JWTTokenAuth(cache))
|
taskClassGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1))
|
||||||
taskClassGroup.POST("/add", handlers.TaskClassHandler.UserAddTaskClass)
|
taskClassGroup.POST("/add", handlers.TaskClassHandler.UserAddTaskClass)
|
||||||
taskClassGroup.GET("/list", handlers.TaskClassHandler.UserGetTaskClassInfos)
|
taskClassGroup.GET("/list", handlers.TaskClassHandler.UserGetTaskClassInfos)
|
||||||
taskClassGroup.GET("/get", handlers.TaskClassHandler.UserGetCompleteTaskClass)
|
taskClassGroup.GET("/get", handlers.TaskClassHandler.UserGetCompleteTaskClass)
|
||||||
@@ -71,7 +72,7 @@ func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine
|
|||||||
}
|
}
|
||||||
scheduleGroup := apiGroup.Group("/schedule")
|
scheduleGroup := apiGroup.Group("/schedule")
|
||||||
{
|
{
|
||||||
scheduleGroup.Use(middleware.JWTTokenAuth(cache))
|
scheduleGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1))
|
||||||
scheduleGroup.GET("/today", handlers.ScheduleHandler.GetUserTodaySchedule)
|
scheduleGroup.GET("/today", handlers.ScheduleHandler.GetUserTodaySchedule)
|
||||||
scheduleGroup.GET("/week", handlers.ScheduleHandler.GetUserWeeklySchedule)
|
scheduleGroup.GET("/week", handlers.ScheduleHandler.GetUserWeeklySchedule)
|
||||||
scheduleGroup.DELETE("/delete", handlers.ScheduleHandler.DeleteScheduleEvent)
|
scheduleGroup.DELETE("/delete", handlers.ScheduleHandler.DeleteScheduleEvent)
|
||||||
|
|||||||
@@ -115,7 +115,15 @@ func (sv *TaskClassService) AddTaskClassItemIntoSchedule(ctx context.Context, re
|
|||||||
if ownerID != userID {
|
if ownerID != userID {
|
||||||
return respond.TaskClassItemNotBelongToUser
|
return respond.TaskClassItemNotBelongToUser
|
||||||
}
|
}
|
||||||
//2.取出任务块信息
|
//2.再检查任务块本身是否已经被安排
|
||||||
|
result, err := sv.taskClassRepo.IfTaskClassItemArranged(ctx, taskID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if result {
|
||||||
|
return respond.TaskClassItemAlreadyArranged
|
||||||
|
}
|
||||||
|
//3.取出任务块信息
|
||||||
taskItem, err := sv.taskClassRepo.GetTaskClassItemByID(ctx, taskID) //通过任务块ID获取任务块信息
|
taskItem, err := sv.taskClassRepo.GetTaskClassItemByID(ctx, taskID) //通过任务块ID获取任务块信息
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
Reference in New Issue
Block a user