Version: 0.6.7.dev.260317
✨ feat(agent): 新增 Token 配额门禁中间件(Redis 快照 + 封禁键 + 7 天懒重置) - 🚪 在 `POST /api/v1/agent/chat` 挂载 `TokenQuotaGuard`,在请求进入业务逻辑前完成额度校验 - ⚡ 新增 Redis 配额快照与封禁键机制:超额用户命中封禁键后可快速拦截,降低重复查库带来的开销 - 🗃️ 新增用户配额 DAO 能力:按需读取 `token_limit`、`token_usage`、`last_reset_at`,并支持基于“到期条件更新”的懒重置 - 🔄 实现 7 天懒重置策略:用户访问时若检测到配额周期已到期,则重置 `token_usage` 并清理封禁状态 - 🚫 新增超额响应码 `40051`,用于标识 `token usage exceeds limit`
This commit is contained in:
@@ -116,6 +116,6 @@ func Start() {
|
||||
AgentHandler: agentApi,
|
||||
}
|
||||
|
||||
r := routers.RegisterRouters(handlers, cacheRepo, limiter)
|
||||
r := routers.RegisterRouters(handlers, cacheRepo, userRepo, limiter)
|
||||
routers.StartEngine(r)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,17 @@ type CacheDAO struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
// UserTokenQuotaSnapshot 是“用户额度判断”的 Redis 快照结构。
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. 只保留额度判断必要字段,避免把 users 全字段塞进缓存;
|
||||
// 2. 该结构仅用于“快速门禁判断”,权威账本仍以 MySQL 为准。
|
||||
type UserTokenQuotaSnapshot struct {
|
||||
TokenLimit int `json:"token_limit"`
|
||||
TokenUsage int `json:"token_usage"`
|
||||
LastResetAt time.Time `json:"last_reset_at"`
|
||||
}
|
||||
|
||||
func NewCacheDAO(client *redis.Client) *CacheDAO {
|
||||
return &CacheDAO{client: client}
|
||||
}
|
||||
@@ -266,3 +277,79 @@ func (d *CacheDAO) DeleteUserOngoingScheduleFromCache(ctx context.Context, userI
|
||||
key := fmt.Sprintf("smartflow:ongoing_schedule:%d", userID)
|
||||
return d.client.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func userTokenQuotaSnapshotKey(userID int) string {
|
||||
return fmt.Sprintf("smartflow:user_token_quota_snapshot:%d", userID)
|
||||
}
|
||||
|
||||
func userTokenBlockedKey(userID int) string {
|
||||
return fmt.Sprintf("smartflow:user_token_blocked:%d", userID)
|
||||
}
|
||||
|
||||
// GetUserTokenQuotaSnapshot 读取用户 token 配额快照。
|
||||
//
|
||||
// 输入输出语义:
|
||||
// 1. 命中返回 (*UserTokenQuotaSnapshot, true, nil);
|
||||
// 2. 未命中返回 (nil, false, nil);
|
||||
// 3. Redis/反序列化错误返回 (nil, false, err)。
|
||||
func (d *CacheDAO) GetUserTokenQuotaSnapshot(ctx context.Context, userID int) (*UserTokenQuotaSnapshot, bool, error) {
|
||||
key := userTokenQuotaSnapshotKey(userID)
|
||||
val, err := d.client.Get(ctx, key).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
var snapshot UserTokenQuotaSnapshot
|
||||
if err = json.Unmarshal([]byte(val), &snapshot); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return &snapshot, true, nil
|
||||
}
|
||||
|
||||
// SetUserTokenQuotaSnapshot 写入用户 token 配额快照。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只做缓存写入,不做额度判断;
|
||||
// 2. ttl 由上层策略控制,便于按场景调优“性能 vs 一致性”。
|
||||
func (d *CacheDAO) SetUserTokenQuotaSnapshot(ctx context.Context, userID int, snapshot UserTokenQuotaSnapshot, ttl time.Duration) error {
|
||||
key := userTokenQuotaSnapshotKey(userID)
|
||||
data, err := json.Marshal(snapshot)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return d.client.Set(ctx, key, data, ttl).Err()
|
||||
}
|
||||
|
||||
// DeleteUserTokenQuotaSnapshot 删除用户 token 快照缓存。
|
||||
func (d *CacheDAO) DeleteUserTokenQuotaSnapshot(ctx context.Context, userID int) error {
|
||||
return d.client.Del(ctx, userTokenQuotaSnapshotKey(userID)).Err()
|
||||
}
|
||||
|
||||
// IsUserTokenBlocked 检查用户是否被“额度封禁键”命中。
|
||||
func (d *CacheDAO) IsUserTokenBlocked(ctx context.Context, userID int) (bool, error) {
|
||||
result, err := d.client.Get(ctx, userTokenBlockedKey(userID)).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == "1", nil
|
||||
}
|
||||
|
||||
// SetUserTokenBlocked 设置用户“额度封禁键”。
|
||||
//
|
||||
// 说明:
|
||||
// 1. 该键是快速拦截层,不是权威账本;
|
||||
// 2. ttl 建议设置到“下一次重置时间”,到期自动解封。
|
||||
func (d *CacheDAO) SetUserTokenBlocked(ctx context.Context, userID int, ttl time.Duration) error {
|
||||
return d.client.Set(ctx, userTokenBlockedKey(userID), "1", ttl).Err()
|
||||
}
|
||||
|
||||
// DeleteUserTokenBlocked 清理用户“额度封禁键”。
|
||||
func (d *CacheDAO) DeleteUserTokenBlocked(ctx context.Context, userID int) error {
|
||||
return d.client.Del(ctx, userTokenBlockedKey(userID)).Err()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
@@ -85,3 +86,47 @@ func (r *UserDAO) GetUserByID(id int) (*model.User, error) {
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserTokenQuotaByID 查询用户 token 配额快照(仅查询配额相关字段)。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只返回 token_limit / token_usage / last_reset_at 等“额度判断必需字段”;
|
||||
// 2. 不负责做超额判断与重置判断(由中间件统一决策);
|
||||
// 3. 不返回密码等敏感字段,避免把无关信息带入鉴权链路。
|
||||
func (r *UserDAO) GetUserTokenQuotaByID(ctx context.Context, id int) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.WithContext(ctx).
|
||||
Select("id", "token_limit", "token_usage", "last_reset_at").
|
||||
Where("id = ?", id).
|
||||
First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// ResetUserTokenUsageIfDue 在“已到重置窗口”时执行懒重置。
|
||||
//
|
||||
// 输入输出语义:
|
||||
// 1. dueBefore:判定“到期可重置”的截止时间(通常是 now-7d);
|
||||
// 2. resetAt:本次重置写入的时间戳;
|
||||
// 3. 返回值 bool:
|
||||
// - true 表示本次调用实际执行了重置;
|
||||
// - false 表示条件未命中(尚未到期或记录不存在)。
|
||||
//
|
||||
// 并发与幂等说明:
|
||||
// 1. 使用条件更新(WHERE last_reset_at <= dueBefore)保证并发下最多一次成功重置;
|
||||
// 2. 重复调用是安全的,未命中条件时不会破坏现有统计。
|
||||
func (r *UserDAO) ResetUserTokenUsageIfDue(ctx context.Context, id int, dueBefore time.Time, resetAt time.Time) (bool, error) {
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&model.User{}).
|
||||
Where("id = ? AND (last_reset_at IS NULL OR last_reset_at <= ?)", id, dueBefore).
|
||||
Updates(map[string]interface{}{
|
||||
"token_usage": 0,
|
||||
"last_reset_at": resetAt,
|
||||
})
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
184
backend/middleware/token_quota_guard.go
Normal file
184
backend/middleware/token_quota_guard.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/dao"
|
||||
"github.com/LoveLosita/smartflow/backend/respond"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// userTokenResetInterval 是“用户 token 周期重置”的窗口长度。
|
||||
// 当前按需求设置为 7 天。
|
||||
userTokenResetInterval = 7 * 24 * time.Hour
|
||||
|
||||
// userTokenQuotaSnapshotTTL 是额度快照缓存时长。
|
||||
// 说明:该值越大,读 DB 越少;但“超额生效”的最坏延迟会变长。
|
||||
userTokenQuotaSnapshotTTL = 60 * time.Second
|
||||
|
||||
// minUserTokenBlockTTL 是封禁键的最小 TTL,避免出现 0/负数导致“刚封就失效”。
|
||||
minUserTokenBlockTTL = 30 * time.Second
|
||||
)
|
||||
|
||||
// TokenQuotaGuard 在请求入口做“token 额度门禁 + 懒重置”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责在进入业务 Handler 前判断“该用户是否还能继续消费 token”;
|
||||
// 2. 负责按 7 天窗口执行懒重置(只有访问时才判断是否重置);
|
||||
// 3. 负责维护 Redis 快照与封禁键,降低每次请求都查库的成本;
|
||||
// 4. 不负责 token 累加记账(记账由聊天持久化链路负责)。
|
||||
func TokenQuotaGuard(cache *dao.CacheDAO, userRepo *dao.UserDAO) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 1. 基础依赖判空:
|
||||
// 1.1 若中间件依赖未初始化,直接返回 500,避免出现“无门禁放行”的安全漏洞;
|
||||
// 1.2 这里选择 fail-close(拒绝),因为该中间件是额度治理主入口。
|
||||
if cache == nil || userRepo == nil {
|
||||
c.JSON(http.StatusInternalServerError, respond.InternalError(errors.New("token quota guard dependencies not initialized")))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 从 JWT 中间件上下文获取 user_id:
|
||||
// 2.1 若 user_id 非法,说明鉴权链路异常,直接按未授权拦截;
|
||||
// 2.2 这里不尝试兜底查 token,避免重复实现鉴权逻辑。
|
||||
userID := c.GetInt("user_id")
|
||||
if userID <= 0 {
|
||||
c.JSON(http.StatusUnauthorized, respond.ErrUnauthorized)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
now := time.Now()
|
||||
|
||||
// 3. 快速封禁检查(Redis):
|
||||
// 3.1 命中封禁键直接拒绝,避免每次都查 DB;
|
||||
// 3.2 Redis 查询失败时不立即放行,而是继续走 DB 严格校验,保证安全性。
|
||||
blocked, blockedErr := cache.IsUserTokenBlocked(ctx, userID)
|
||||
if blockedErr != nil {
|
||||
log.Printf("TokenQuotaGuard: 查询封禁键失败 user_id=%d err=%v,回退 DB 校验", userID, blockedErr)
|
||||
} else if blocked {
|
||||
c.JSON(http.StatusBadRequest, respond.TokenUsageExceedsLimit)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 优先尝试走快照快速路径:
|
||||
// 4.1 命中快照且未到重置窗口时,直接用快照判断;
|
||||
// 4.2 快照未命中/已到重置窗口/读取失败,则回源 DB 做权威判断。
|
||||
snapshot, hit, snapshotErr := cache.GetUserTokenQuotaSnapshot(ctx, userID)
|
||||
if snapshotErr != nil {
|
||||
log.Printf("TokenQuotaGuard: 读取额度快照失败 user_id=%d err=%v,回退 DB 校验", userID, snapshotErr)
|
||||
}
|
||||
if hit && snapshot != nil && !isResetDue(snapshot.LastResetAt, now) {
|
||||
if snapshot.TokenUsage > snapshot.TokenLimit {
|
||||
// 4.3 快照判断超额时,顺手写入封禁键,后续请求可 O(1) 拦截。
|
||||
ttl := calcBlockTTL(snapshot.LastResetAt, now)
|
||||
if err := cache.SetUserTokenBlocked(ctx, userID, ttl); err != nil {
|
||||
log.Printf("TokenQuotaGuard: 写入封禁键失败 user_id=%d err=%v", userID, err)
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, respond.TokenUsageExceedsLimit)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 4.4 快照命中且未超额,直接放行,避免本次请求访问 DB。
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 5. 回源 DB(权威判断路径):
|
||||
// 5.1 先读取用户额度字段;
|
||||
// 5.2 若已到重置窗口,执行条件更新懒重置,再回读最新值;
|
||||
// 5.3 最后依据“token_usage > token_limit”判断是否拦截。
|
||||
quota, err := userRepo.GetUserTokenQuotaByID(ctx, userID)
|
||||
if err != nil {
|
||||
log.Printf("TokenQuotaGuard: 查询用户额度失败 user_id=%d err=%v", userID, err)
|
||||
c.JSON(http.StatusInternalServerError, respond.InternalError(err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if isResetDue(quota.LastResetAt, now) {
|
||||
_, resetErr := userRepo.ResetUserTokenUsageIfDue(ctx, userID, now.Add(-userTokenResetInterval), now)
|
||||
if resetErr != nil {
|
||||
log.Printf("TokenQuotaGuard: 懒重置失败 user_id=%d err=%v", userID, resetErr)
|
||||
c.JSON(http.StatusInternalServerError, respond.InternalError(resetErr))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 5.2.1 重置后回读一次最新额度,避免使用旧值继续判断;
|
||||
// 5.2.2 同时主动清理封禁键,防止“已重置仍被封”的残留状态。
|
||||
quota, err = userRepo.GetUserTokenQuotaByID(ctx, userID)
|
||||
if err != nil {
|
||||
log.Printf("TokenQuotaGuard: 重置后回读失败 user_id=%d err=%v", userID, err)
|
||||
c.JSON(http.StatusInternalServerError, respond.InternalError(err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if delErr := cache.DeleteUserTokenBlocked(ctx, userID); delErr != nil {
|
||||
log.Printf("TokenQuotaGuard: 清理封禁键失败 user_id=%d err=%v", userID, delErr)
|
||||
}
|
||||
}
|
||||
|
||||
// 6. 把最新权威值回填快照:
|
||||
// 6.1 回填失败不影响主流程(仅影响性能,不影响正确性);
|
||||
// 6.2 这样后续同用户短时间内请求可直接走快照快速路径。
|
||||
if setErr := cache.SetUserTokenQuotaSnapshot(ctx, userID, dao.UserTokenQuotaSnapshot{
|
||||
TokenLimit: quota.TokenLimit,
|
||||
TokenUsage: quota.TokenUsage,
|
||||
LastResetAt: quota.LastResetAt,
|
||||
}, userTokenQuotaSnapshotTTL); setErr != nil {
|
||||
log.Printf("TokenQuotaGuard: 回填额度快照失败 user_id=%d err=%v", userID, setErr)
|
||||
}
|
||||
|
||||
// 7. 最终判定:
|
||||
// 7.1 按你的规则使用“>”判断超额(等于不拦截);
|
||||
// 7.2 超额时写封禁键并拒绝;未超额则继续放行。
|
||||
if quota.TokenUsage > quota.TokenLimit {
|
||||
ttl := calcBlockTTL(quota.LastResetAt, now)
|
||||
if err = cache.SetUserTokenBlocked(ctx, userID, ttl); err != nil {
|
||||
log.Printf("TokenQuotaGuard: 写入封禁键失败 user_id=%d err=%v", userID, err)
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, respond.TokenUsageExceedsLimit)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// isResetDue 判断“是否到达 7 天懒重置窗口”。
|
||||
//
|
||||
// 说明:
|
||||
// 1. lastResetAt 为零值时,视为到期(首次迁移数据时兜底);
|
||||
// 2. now - lastResetAt >= 7 天 时返回 true。
|
||||
func isResetDue(lastResetAt time.Time, now time.Time) bool {
|
||||
if lastResetAt.IsZero() {
|
||||
return true
|
||||
}
|
||||
return !lastResetAt.Add(userTokenResetInterval).After(now)
|
||||
}
|
||||
|
||||
// calcBlockTTL 计算封禁键 TTL。
|
||||
//
|
||||
// 规则:
|
||||
// 1. 目标是封到“下一次重置时间”;
|
||||
// 2. 若计算结果非正数,回退到最小 TTL,避免封禁键瞬时失效。
|
||||
func calcBlockTTL(lastResetAt time.Time, now time.Time) time.Duration {
|
||||
if lastResetAt.IsZero() {
|
||||
return minUserTokenBlockTTL
|
||||
}
|
||||
nextResetAt := lastResetAt.Add(userTokenResetInterval)
|
||||
ttl := nextResetAt.Sub(now)
|
||||
if ttl <= 0 {
|
||||
return minUserTokenBlockTTL
|
||||
}
|
||||
return ttl
|
||||
}
|
||||
@@ -318,4 +318,9 @@ var ( //请求相关的响应
|
||||
Status: "40050",
|
||||
Info: "wrong task id",
|
||||
}
|
||||
|
||||
TokenUsageExceedsLimit = Response{ //token 使用量超过限额
|
||||
Status: "40051",
|
||||
Info: "token usage exceeds limit",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ func StartEngine(r *gin.Engine) {
|
||||
}
|
||||
}
|
||||
|
||||
func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO, limiter *pkg.RateLimiter) *gin.Engine {
|
||||
func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO, userRepo *dao.UserDAO, limiter *pkg.RateLimiter) *gin.Engine {
|
||||
// 初始化Gin引擎
|
||||
r := gin.Default()
|
||||
// 在这里注册所有的路由和路由组
|
||||
@@ -88,7 +88,7 @@ func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO, limiter *pk
|
||||
agentGroup := apiGroup.Group("/agent")
|
||||
{
|
||||
agentGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1))
|
||||
agentGroup.POST("/chat", handlers.AgentHandler.ChatAgent)
|
||||
agentGroup.POST("/chat", middleware.TokenQuotaGuard(cache, userRepo), handlers.AgentHandler.ChatAgent)
|
||||
agentGroup.GET("/conversation-meta", handlers.AgentHandler.GetConversationMeta)
|
||||
agentGroup.GET("/conversation-list", handlers.AgentHandler.GetConversationList)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user