From bc56d471a8b449cbbe8f2ac43b581573ef035144 Mon Sep 17 00:00:00 2001 From: LoveLosita <2810873701@qq.com> Date: Tue, 17 Mar 2026 19:46:08 +0800 Subject: [PATCH] =?UTF-8?q?Version:=200.6.7.dev.260317=20=E2=9C=A8=20feat(?= =?UTF-8?q?agent):=20=E6=96=B0=E5=A2=9E=20Token=20=E9=85=8D=E9=A2=9D?= =?UTF-8?q?=E9=97=A8=E7=A6=81=E4=B8=AD=E9=97=B4=E4=BB=B6=EF=BC=88Redis=20?= =?UTF-8?q?=E5=BF=AB=E7=85=A7=20+=20=E5=B0=81=E7=A6=81=E9=94=AE=20+=207=20?= =?UTF-8?q?=E5=A4=A9=E6=87=92=E9=87=8D=E7=BD=AE=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 🚪 在 `POST /api/v1/agent/chat` 挂载 `TokenQuotaGuard`,在请求进入业务逻辑前完成额度校验 - ⚡ 新增 Redis 配额快照与封禁键机制:超额用户命中封禁键后可快速拦截,降低重复查库带来的开销 - 🗃️ 新增用户配额 DAO 能力:按需读取 `token_limit`、`token_usage`、`last_reset_at`,并支持基于“到期条件更新”的懒重置 - 🔄 实现 7 天懒重置策略:用户访问时若检测到配额周期已到期,则重置 `token_usage` 并清理封禁状态 - 🚫 新增超额响应码 `40051`,用于标识 `token usage exceeds limit` --- backend/cmd/start.go | 2 +- backend/dao/cache.go | 87 +++++++++++ backend/dao/user.go | 45 ++++++ backend/middleware/token_quota_guard.go | 184 ++++++++++++++++++++++++ backend/respond/respond.go | 5 + backend/routers/routers.go | 4 +- 6 files changed, 324 insertions(+), 3 deletions(-) create mode 100644 backend/middleware/token_quota_guard.go diff --git a/backend/cmd/start.go b/backend/cmd/start.go index 87b9ee2..ecfda46 100644 --- a/backend/cmd/start.go +++ b/backend/cmd/start.go @@ -116,6 +116,6 @@ func Start() { AgentHandler: agentApi, } - r := routers.RegisterRouters(handlers, cacheRepo, limiter) + r := routers.RegisterRouters(handlers, cacheRepo, userRepo, limiter) routers.StartEngine(r) } diff --git a/backend/dao/cache.go b/backend/dao/cache.go index d4b3fcf..91b834d 100644 --- a/backend/dao/cache.go +++ b/backend/dao/cache.go @@ -15,6 +15,17 @@ type CacheDAO struct { client *redis.Client } +// UserTokenQuotaSnapshot 是“用户额度判断”的 Redis 快照结构。 +// +// 设计说明: +// 1. 只保留额度判断必要字段,避免把 users 全字段塞进缓存; +// 2. 该结构仅用于“快速门禁判断”,权威账本仍以 MySQL 为准。 +type UserTokenQuotaSnapshot struct { + TokenLimit int `json:"token_limit"` + TokenUsage int `json:"token_usage"` + LastResetAt time.Time `json:"last_reset_at"` +} + func NewCacheDAO(client *redis.Client) *CacheDAO { return &CacheDAO{client: client} } @@ -266,3 +277,79 @@ func (d *CacheDAO) DeleteUserOngoingScheduleFromCache(ctx context.Context, userI key := fmt.Sprintf("smartflow:ongoing_schedule:%d", userID) return d.client.Del(ctx, key).Err() } + +func userTokenQuotaSnapshotKey(userID int) string { + return fmt.Sprintf("smartflow:user_token_quota_snapshot:%d", userID) +} + +func userTokenBlockedKey(userID int) string { + return fmt.Sprintf("smartflow:user_token_blocked:%d", userID) +} + +// GetUserTokenQuotaSnapshot 读取用户 token 配额快照。 +// +// 输入输出语义: +// 1. 命中返回 (*UserTokenQuotaSnapshot, true, nil); +// 2. 未命中返回 (nil, false, nil); +// 3. Redis/反序列化错误返回 (nil, false, err)。 +func (d *CacheDAO) GetUserTokenQuotaSnapshot(ctx context.Context, userID int) (*UserTokenQuotaSnapshot, bool, error) { + key := userTokenQuotaSnapshotKey(userID) + val, err := d.client.Get(ctx, key).Result() + if errors.Is(err, redis.Nil) { + return nil, false, nil + } + if err != nil { + return nil, false, err + } + + var snapshot UserTokenQuotaSnapshot + if err = json.Unmarshal([]byte(val), &snapshot); err != nil { + return nil, false, err + } + return &snapshot, true, nil +} + +// SetUserTokenQuotaSnapshot 写入用户 token 配额快照。 +// +// 职责边界: +// 1. 只做缓存写入,不做额度判断; +// 2. ttl 由上层策略控制,便于按场景调优“性能 vs 一致性”。 +func (d *CacheDAO) SetUserTokenQuotaSnapshot(ctx context.Context, userID int, snapshot UserTokenQuotaSnapshot, ttl time.Duration) error { + key := userTokenQuotaSnapshotKey(userID) + data, err := json.Marshal(snapshot) + if err != nil { + return err + } + return d.client.Set(ctx, key, data, ttl).Err() +} + +// DeleteUserTokenQuotaSnapshot 删除用户 token 快照缓存。 +func (d *CacheDAO) DeleteUserTokenQuotaSnapshot(ctx context.Context, userID int) error { + return d.client.Del(ctx, userTokenQuotaSnapshotKey(userID)).Err() +} + +// IsUserTokenBlocked 检查用户是否被“额度封禁键”命中。 +func (d *CacheDAO) IsUserTokenBlocked(ctx context.Context, userID int) (bool, error) { + result, err := d.client.Get(ctx, userTokenBlockedKey(userID)).Result() + if errors.Is(err, redis.Nil) { + return false, nil + } + if err != nil { + return false, err + } + return result == "1", nil +} + +// SetUserTokenBlocked 设置用户“额度封禁键”。 +// +// 说明: +// 1. 该键是快速拦截层,不是权威账本; +// 2. ttl 建议设置到“下一次重置时间”,到期自动解封。 +func (d *CacheDAO) SetUserTokenBlocked(ctx context.Context, userID int, ttl time.Duration) error { + return d.client.Set(ctx, userTokenBlockedKey(userID), "1", ttl).Err() +} + +// DeleteUserTokenBlocked 清理用户“额度封禁键”。 +func (d *CacheDAO) DeleteUserTokenBlocked(ctx context.Context, userID int) error { + return d.client.Del(ctx, userTokenBlockedKey(userID)).Err() +} diff --git a/backend/dao/user.go b/backend/dao/user.go index 31505fe..9e9b88d 100644 --- a/backend/dao/user.go +++ b/backend/dao/user.go @@ -1,6 +1,7 @@ package dao import ( + "context" "errors" "time" @@ -85,3 +86,47 @@ func (r *UserDAO) GetUserByID(id int) (*model.User, error) { } return &user, nil } + +// GetUserTokenQuotaByID 查询用户 token 配额快照(仅查询配额相关字段)。 +// +// 职责边界: +// 1. 只返回 token_limit / token_usage / last_reset_at 等“额度判断必需字段”; +// 2. 不负责做超额判断与重置判断(由中间件统一决策); +// 3. 不返回密码等敏感字段,避免把无关信息带入鉴权链路。 +func (r *UserDAO) GetUserTokenQuotaByID(ctx context.Context, id int) (*model.User, error) { + var user model.User + err := r.db.WithContext(ctx). + Select("id", "token_limit", "token_usage", "last_reset_at"). + Where("id = ?", id). + First(&user).Error + if err != nil { + return nil, err + } + return &user, nil +} + +// ResetUserTokenUsageIfDue 在“已到重置窗口”时执行懒重置。 +// +// 输入输出语义: +// 1. dueBefore:判定“到期可重置”的截止时间(通常是 now-7d); +// 2. resetAt:本次重置写入的时间戳; +// 3. 返回值 bool: +// - true 表示本次调用实际执行了重置; +// - false 表示条件未命中(尚未到期或记录不存在)。 +// +// 并发与幂等说明: +// 1. 使用条件更新(WHERE last_reset_at <= dueBefore)保证并发下最多一次成功重置; +// 2. 重复调用是安全的,未命中条件时不会破坏现有统计。 +func (r *UserDAO) ResetUserTokenUsageIfDue(ctx context.Context, id int, dueBefore time.Time, resetAt time.Time) (bool, error) { + result := r.db.WithContext(ctx). + Model(&model.User{}). + Where("id = ? AND (last_reset_at IS NULL OR last_reset_at <= ?)", id, dueBefore). + Updates(map[string]interface{}{ + "token_usage": 0, + "last_reset_at": resetAt, + }) + if result.Error != nil { + return false, result.Error + } + return result.RowsAffected > 0, nil +} diff --git a/backend/middleware/token_quota_guard.go b/backend/middleware/token_quota_guard.go new file mode 100644 index 0000000..d920766 --- /dev/null +++ b/backend/middleware/token_quota_guard.go @@ -0,0 +1,184 @@ +package middleware + +import ( + "errors" + "log" + "net/http" + "time" + + "github.com/LoveLosita/smartflow/backend/dao" + "github.com/LoveLosita/smartflow/backend/respond" + "github.com/gin-gonic/gin" +) + +const ( + // userTokenResetInterval 是“用户 token 周期重置”的窗口长度。 + // 当前按需求设置为 7 天。 + userTokenResetInterval = 7 * 24 * time.Hour + + // userTokenQuotaSnapshotTTL 是额度快照缓存时长。 + // 说明:该值越大,读 DB 越少;但“超额生效”的最坏延迟会变长。 + userTokenQuotaSnapshotTTL = 60 * time.Second + + // minUserTokenBlockTTL 是封禁键的最小 TTL,避免出现 0/负数导致“刚封就失效”。 + minUserTokenBlockTTL = 30 * time.Second +) + +// TokenQuotaGuard 在请求入口做“token 额度门禁 + 懒重置”。 +// +// 职责边界: +// 1. 负责在进入业务 Handler 前判断“该用户是否还能继续消费 token”; +// 2. 负责按 7 天窗口执行懒重置(只有访问时才判断是否重置); +// 3. 负责维护 Redis 快照与封禁键,降低每次请求都查库的成本; +// 4. 不负责 token 累加记账(记账由聊天持久化链路负责)。 +func TokenQuotaGuard(cache *dao.CacheDAO, userRepo *dao.UserDAO) gin.HandlerFunc { + return func(c *gin.Context) { + // 1. 基础依赖判空: + // 1.1 若中间件依赖未初始化,直接返回 500,避免出现“无门禁放行”的安全漏洞; + // 1.2 这里选择 fail-close(拒绝),因为该中间件是额度治理主入口。 + if cache == nil || userRepo == nil { + c.JSON(http.StatusInternalServerError, respond.InternalError(errors.New("token quota guard dependencies not initialized"))) + c.Abort() + return + } + + // 2. 从 JWT 中间件上下文获取 user_id: + // 2.1 若 user_id 非法,说明鉴权链路异常,直接按未授权拦截; + // 2.2 这里不尝试兜底查 token,避免重复实现鉴权逻辑。 + userID := c.GetInt("user_id") + if userID <= 0 { + c.JSON(http.StatusUnauthorized, respond.ErrUnauthorized) + c.Abort() + return + } + + ctx := c.Request.Context() + now := time.Now() + + // 3. 快速封禁检查(Redis): + // 3.1 命中封禁键直接拒绝,避免每次都查 DB; + // 3.2 Redis 查询失败时不立即放行,而是继续走 DB 严格校验,保证安全性。 + blocked, blockedErr := cache.IsUserTokenBlocked(ctx, userID) + if blockedErr != nil { + log.Printf("TokenQuotaGuard: 查询封禁键失败 user_id=%d err=%v,回退 DB 校验", userID, blockedErr) + } else if blocked { + c.JSON(http.StatusBadRequest, respond.TokenUsageExceedsLimit) + c.Abort() + return + } + + // 4. 优先尝试走快照快速路径: + // 4.1 命中快照且未到重置窗口时,直接用快照判断; + // 4.2 快照未命中/已到重置窗口/读取失败,则回源 DB 做权威判断。 + snapshot, hit, snapshotErr := cache.GetUserTokenQuotaSnapshot(ctx, userID) + if snapshotErr != nil { + log.Printf("TokenQuotaGuard: 读取额度快照失败 user_id=%d err=%v,回退 DB 校验", userID, snapshotErr) + } + if hit && snapshot != nil && !isResetDue(snapshot.LastResetAt, now) { + if snapshot.TokenUsage > snapshot.TokenLimit { + // 4.3 快照判断超额时,顺手写入封禁键,后续请求可 O(1) 拦截。 + ttl := calcBlockTTL(snapshot.LastResetAt, now) + if err := cache.SetUserTokenBlocked(ctx, userID, ttl); err != nil { + log.Printf("TokenQuotaGuard: 写入封禁键失败 user_id=%d err=%v", userID, err) + } + c.JSON(http.StatusBadRequest, respond.TokenUsageExceedsLimit) + c.Abort() + return + } + + // 4.4 快照命中且未超额,直接放行,避免本次请求访问 DB。 + c.Next() + return + } + + // 5. 回源 DB(权威判断路径): + // 5.1 先读取用户额度字段; + // 5.2 若已到重置窗口,执行条件更新懒重置,再回读最新值; + // 5.3 最后依据“token_usage > token_limit”判断是否拦截。 + quota, err := userRepo.GetUserTokenQuotaByID(ctx, userID) + if err != nil { + log.Printf("TokenQuotaGuard: 查询用户额度失败 user_id=%d err=%v", userID, err) + c.JSON(http.StatusInternalServerError, respond.InternalError(err)) + c.Abort() + return + } + + if isResetDue(quota.LastResetAt, now) { + _, resetErr := userRepo.ResetUserTokenUsageIfDue(ctx, userID, now.Add(-userTokenResetInterval), now) + if resetErr != nil { + log.Printf("TokenQuotaGuard: 懒重置失败 user_id=%d err=%v", userID, resetErr) + c.JSON(http.StatusInternalServerError, respond.InternalError(resetErr)) + c.Abort() + return + } + + // 5.2.1 重置后回读一次最新额度,避免使用旧值继续判断; + // 5.2.2 同时主动清理封禁键,防止“已重置仍被封”的残留状态。 + quota, err = userRepo.GetUserTokenQuotaByID(ctx, userID) + if err != nil { + log.Printf("TokenQuotaGuard: 重置后回读失败 user_id=%d err=%v", userID, err) + c.JSON(http.StatusInternalServerError, respond.InternalError(err)) + c.Abort() + return + } + if delErr := cache.DeleteUserTokenBlocked(ctx, userID); delErr != nil { + log.Printf("TokenQuotaGuard: 清理封禁键失败 user_id=%d err=%v", userID, delErr) + } + } + + // 6. 把最新权威值回填快照: + // 6.1 回填失败不影响主流程(仅影响性能,不影响正确性); + // 6.2 这样后续同用户短时间内请求可直接走快照快速路径。 + if setErr := cache.SetUserTokenQuotaSnapshot(ctx, userID, dao.UserTokenQuotaSnapshot{ + TokenLimit: quota.TokenLimit, + TokenUsage: quota.TokenUsage, + LastResetAt: quota.LastResetAt, + }, userTokenQuotaSnapshotTTL); setErr != nil { + log.Printf("TokenQuotaGuard: 回填额度快照失败 user_id=%d err=%v", userID, setErr) + } + + // 7. 最终判定: + // 7.1 按你的规则使用“>”判断超额(等于不拦截); + // 7.2 超额时写封禁键并拒绝;未超额则继续放行。 + if quota.TokenUsage > quota.TokenLimit { + ttl := calcBlockTTL(quota.LastResetAt, now) + if err = cache.SetUserTokenBlocked(ctx, userID, ttl); err != nil { + log.Printf("TokenQuotaGuard: 写入封禁键失败 user_id=%d err=%v", userID, err) + } + c.JSON(http.StatusBadRequest, respond.TokenUsageExceedsLimit) + c.Abort() + return + } + + c.Next() + } +} + +// isResetDue 判断“是否到达 7 天懒重置窗口”。 +// +// 说明: +// 1. lastResetAt 为零值时,视为到期(首次迁移数据时兜底); +// 2. now - lastResetAt >= 7 天 时返回 true。 +func isResetDue(lastResetAt time.Time, now time.Time) bool { + if lastResetAt.IsZero() { + return true + } + return !lastResetAt.Add(userTokenResetInterval).After(now) +} + +// calcBlockTTL 计算封禁键 TTL。 +// +// 规则: +// 1. 目标是封到“下一次重置时间”; +// 2. 若计算结果非正数,回退到最小 TTL,避免封禁键瞬时失效。 +func calcBlockTTL(lastResetAt time.Time, now time.Time) time.Duration { + if lastResetAt.IsZero() { + return minUserTokenBlockTTL + } + nextResetAt := lastResetAt.Add(userTokenResetInterval) + ttl := nextResetAt.Sub(now) + if ttl <= 0 { + return minUserTokenBlockTTL + } + return ttl +} diff --git a/backend/respond/respond.go b/backend/respond/respond.go index da7d34e..b793658 100644 --- a/backend/respond/respond.go +++ b/backend/respond/respond.go @@ -318,4 +318,9 @@ var ( //请求相关的响应 Status: "40050", Info: "wrong task id", } + + TokenUsageExceedsLimit = Response{ //token 使用量超过限额 + Status: "40051", + Info: "token usage exceeds limit", + } ) diff --git a/backend/routers/routers.go b/backend/routers/routers.go index 3ac667c..20ebfa3 100644 --- a/backend/routers/routers.go +++ b/backend/routers/routers.go @@ -28,7 +28,7 @@ func StartEngine(r *gin.Engine) { } } -func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO, limiter *pkg.RateLimiter) *gin.Engine { +func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO, userRepo *dao.UserDAO, limiter *pkg.RateLimiter) *gin.Engine { // 初始化Gin引擎 r := gin.Default() // 在这里注册所有的路由和路由组 @@ -88,7 +88,7 @@ func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO, limiter *pk agentGroup := apiGroup.Group("/agent") { agentGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1)) - agentGroup.POST("/chat", handlers.AgentHandler.ChatAgent) + agentGroup.POST("/chat", middleware.TokenQuotaGuard(cache, userRepo), handlers.AgentHandler.ChatAgent) agentGroup.GET("/conversation-meta", handlers.AgentHandler.GetConversationMeta) agentGroup.GET("/conversation-list", handlers.AgentHandler.GetConversationList) }