Version: 0.9.66.dev.260504

后端:
1. 阶段 2 user/auth 服务边界落地,新增 `cmd/userauth` go-zero zrpc 服务、`services/userauth` 核心实现、gateway user API/zrpc client 与 shared contracts/ports,迁移注册、登录、刷新 token、登出、JWT、黑名单和 token 额度治理
2. gateway 与启动装配切流,`cmd/all` 只保留边缘路由、鉴权和轻量组合,通过 userauth zrpc 访问核心用户能力;拆分 MySQL/Redis 初始化与 AutoMigrate 边界,`userauth` 自迁 `users` 和 token 记账幂等表,`all` 不再迁用户表
3. 清退 Gin 单体旧 user/auth DAO、model、service、router、middleware 和 JWT handler,并同步调整 agent/schedule/cache/outbox 相关调用依赖
4. 补齐 refresh token 防并发重放、MySQL 幂等 token 记账、额度 `>=` 拦截和 RPC 错误映射,避免重复记账与内部错误透出

文档:
1. 新增《学习计划论坛与Token商店PRD》
This commit is contained in:
Losita
2026-05-04 15:20:47 +08:00
parent 9902ca3563
commit b08ee17893
58 changed files with 3754 additions and 1510 deletions

View File

@@ -0,0 +1,130 @@
package dao
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/go-redis/redis/v8"
)
// TokenQuotaSnapshot 是 user/auth 服务内部的额度快照缓存结构。
type TokenQuotaSnapshot struct {
TokenLimit int `json:"token_limit"`
TokenUsage int `json:"token_usage"`
LastResetAt time.Time `json:"last_reset_at"`
}
// CacheDAO 只承载 user/auth 领域需要的 Redis 能力。
type CacheDAO struct {
client *redis.Client
}
func NewCacheDAO(client *redis.Client) *CacheDAO {
return &CacheDAO{client: client}
}
func blacklistKey(jti string) string {
return "blacklist:" + jti
}
func sessionBlacklistKey(sessionID string) string {
return "session_blacklist:" + sessionID
}
func userTokenQuotaSnapshotKey(userID int) string {
return fmt.Sprintf("smartflow:user_token_quota_snapshot:%d", userID)
}
func userTokenBlockedKey(userID int) string {
return fmt.Sprintf("smartflow:user_token_blocked:%d", userID)
}
func (d *CacheDAO) SetBlacklist(jti string, expiration time.Duration) error {
return d.client.Set(context.Background(), blacklistKey(jti), "1", expiration).Err()
}
// SetBlacklistIfAbsent 使用 Redis SET NX 原子抢占某个 JTI。
//
// 职责边界:
// 1. 用于 refresh token 轮转时保证旧 refresh 只能被消费一次;
// 2. 返回 ok=false 表示该 JTI 已经被其它请求消费过;
// 3. 不负责解析 JWT也不负责判断 token 类型。
func (d *CacheDAO) SetBlacklistIfAbsent(jti string, expiration time.Duration) (bool, error) {
return d.client.SetNX(context.Background(), blacklistKey(jti), "1", expiration).Result()
}
func (d *CacheDAO) IsBlacklisted(jti string) (bool, error) {
result, err := d.client.Get(context.Background(), blacklistKey(jti)).Result()
if errors.Is(err, redis.Nil) {
return false, nil
}
if err != nil {
return false, err
}
return result == "1", nil
}
func (d *CacheDAO) SetSessionBlacklist(sessionID string, expiration time.Duration) error {
return d.client.Set(context.Background(), sessionBlacklistKey(sessionID), "1", expiration).Err()
}
func (d *CacheDAO) IsSessionBlacklisted(sessionID string) (bool, error) {
result, err := d.client.Get(context.Background(), sessionBlacklistKey(sessionID)).Result()
if errors.Is(err, redis.Nil) {
return false, nil
}
if err != nil {
return false, err
}
return result == "1", nil
}
func (d *CacheDAO) GetUserTokenQuotaSnapshot(ctx context.Context, userID int) (*TokenQuotaSnapshot, bool, error) {
val, err := d.client.Get(ctx, userTokenQuotaSnapshotKey(userID)).Result()
if errors.Is(err, redis.Nil) {
return nil, false, nil
}
if err != nil {
return nil, false, err
}
var snapshot TokenQuotaSnapshot
if err = json.Unmarshal([]byte(val), &snapshot); err != nil {
return nil, false, err
}
return &snapshot, true, nil
}
func (d *CacheDAO) SetUserTokenQuotaSnapshot(ctx context.Context, userID int, snapshot TokenQuotaSnapshot, ttl time.Duration) error {
data, err := json.Marshal(snapshot)
if err != nil {
return err
}
return d.client.Set(ctx, userTokenQuotaSnapshotKey(userID), data, ttl).Err()
}
func (d *CacheDAO) DeleteUserTokenQuotaSnapshot(ctx context.Context, userID int) error {
return d.client.Del(ctx, userTokenQuotaSnapshotKey(userID)).Err()
}
func (d *CacheDAO) IsUserTokenBlocked(ctx context.Context, userID int) (bool, error) {
result, err := d.client.Get(ctx, userTokenBlockedKey(userID)).Result()
if errors.Is(err, redis.Nil) {
return false, nil
}
if err != nil {
return false, err
}
return result == "1", nil
}
func (d *CacheDAO) SetUserTokenBlocked(ctx context.Context, userID int, ttl time.Duration) error {
return d.client.Set(ctx, userTokenBlockedKey(userID), "1", ttl).Err()
}
func (d *CacheDAO) DeleteUserTokenBlocked(ctx context.Context, userID int) error {
return d.client.Del(ctx, userTokenBlockedKey(userID)).Err()
}

View File

@@ -0,0 +1,55 @@
package dao
import (
"context"
"fmt"
userauthmodel "github.com/LoveLosita/smartflow/backend/services/userauth/model"
"github.com/go-redis/redis/v8"
"github.com/spf13/viper"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
// OpenDBFromConfig 创建 user/auth 服务自己的数据库句柄。
//
// 职责边界:
// 1. 只迁移 users 以及 user/auth 自己拥有的辅助表,避免独立 userauth 进程顺手迁移其它服务表;
// 2. 不负责读取业务配置之外的外部依赖,配置来源仍由 bootstrap.LoadConfig 统一注入;
// 3. 返回 *gorm.DB 供服务内 DAO 复用,调用方负责进程生命周期。
func OpenDBFromConfig() (*gorm.DB, error) {
host := viper.GetString("database.host")
port := viper.GetString("database.port")
user := viper.GetString("database.user")
password := viper.GetString("database.password")
dbname := viper.GetString("database.dbname")
dsn := fmt.Sprintf(
"%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
user, password, host, port, dbname,
)
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err != nil {
return nil, err
}
if err = db.AutoMigrate(&userauthmodel.User{}, &userauthmodel.TokenUsageAdjustment{}); err != nil {
return nil, fmt.Errorf("auto migrate userauth tables failed: %w", err)
}
return db, nil
}
// OpenRedisFromConfig 创建 user/auth 服务自己的 Redis 句柄。
//
// 失败时返回 error让独立进程入口 fail-fast避免黑名单和额度门禁静默失效。
func OpenRedisFromConfig() (*redis.Client, error) {
client := redis.NewClient(&redis.Options{
Addr: viper.GetString("redis.host") + ":" + viper.GetString("redis.port"),
Password: viper.GetString("redis.password"),
DB: 0,
})
if _, err := client.Ping(context.Background()).Result(); err != nil {
return nil, err
}
return client, nil
}

View File

@@ -0,0 +1,173 @@
package dao
import (
"context"
"errors"
"strings"
"time"
userauthmodel "github.com/LoveLosita/smartflow/backend/services/userauth/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// UserDAO 是 user/auth 服务内部的 users 表访问层。
// 职责边界:只提供注册、登录和额度治理需要的最小读写能力,不暴露整张 users 表给 gateway。
type UserDAO struct {
db *gorm.DB
}
func NewUserDAO(db *gorm.DB) *UserDAO {
return &UserDAO{db: db}
}
// Create 创建新用户并初始化 token 额度字段。
func (r *UserDAO) Create(ctx context.Context, username, phoneNumber, password string) (*userauthmodel.User, error) {
user := &userauthmodel.User{
Username: username,
PhoneNumber: phoneNumber,
Password: password,
TokenLimit: 100000,
TokenUsage: 0,
LastResetAt: time.Now(),
}
if err := r.db.WithContext(ctx).Create(user).Error; err != nil {
return nil, err
}
return user, nil
}
func (r *UserDAO) IfUsernameExists(ctx context.Context, name string) (bool, error) {
err := r.db.WithContext(ctx).Where("username = ?", name).First(&userauthmodel.User{}).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
}
return true, err
}
return true, nil
}
func (r *UserDAO) GetUserHashedPasswordByName(ctx context.Context, name string) (string, error) {
var user userauthmodel.User
if err := r.db.WithContext(ctx).Where("username = ?", name).First(&user).Error; err != nil {
return "", err
}
return user.Password, nil
}
func (r *UserDAO) GetUserIDByName(ctx context.Context, name string) (int, error) {
var user userauthmodel.User
if err := r.db.WithContext(ctx).Where("username = ?", name).First(&user).Error; err != nil {
return -1, err
}
return int(user.ID), nil
}
// GetUserTokenQuotaByID 只读取额度判断需要的字段。
func (r *UserDAO) GetUserTokenQuotaByID(ctx context.Context, id int) (*userauthmodel.User, error) {
var user userauthmodel.User
err := r.db.WithContext(ctx).
Select("id", "token_limit", "token_usage", "last_reset_at").
Where("id = ?", id).
First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
// ResetUserTokenUsageIfDue 使用条件更新实现幂等懒重置。
func (r *UserDAO) ResetUserTokenUsageIfDue(ctx context.Context, id int, dueBefore time.Time, resetAt time.Time) (bool, error) {
result := r.db.WithContext(ctx).
Model(&userauthmodel.User{}).
Where("id = ? AND (last_reset_at IS NULL OR last_reset_at <= ?)", id, dueBefore).
Updates(map[string]interface{}{
"token_usage": 0,
"last_reset_at": resetAt,
})
if result.Error != nil {
return false, result.Error
}
return result.RowsAffected > 0, nil
}
// AddTokenUsage 为用户 token 账本做增量累加。
// 职责边界:
// 1. 只做数据库累加,不负责额度判断与缓存刷新;
// 2. delta<=0 视为无操作,直接返回成功;
// 3. 由 service 层决定是否需要先做懒重置和后续 cache 回填。
func (r *UserDAO) AddTokenUsage(ctx context.Context, id int, delta int) (bool, error) {
if delta <= 0 {
return true, nil
}
result := r.db.WithContext(ctx).
Model(&userauthmodel.User{}).
Where("id = ?", id).
Update("token_usage", gorm.Expr("token_usage + ?", delta))
if result.Error != nil {
return false, result.Error
}
return result.RowsAffected > 0, nil
}
// AdjustTokenUsageOnce 在同一个 MySQL 事务里完成“幂等占位 + token 用量增量”。
//
// 职责边界:
// 1. eventID 非空时先写入 user_token_usage_adjustments依赖主键冲突判断是否重复事件
// 2. 只有幂等占位写入成功后才更新 users.token_usage保证并发重放不会重复记账
// 3. 不负责 Redis 快照和封禁键维护,这些缓存语义仍由 service 层在事务成功后刷新。
func (r *UserDAO) AdjustTokenUsageOnce(ctx context.Context, eventID string, id int, delta int, dueBefore time.Time, resetAt time.Time) (*userauthmodel.User, bool, error) {
var quota userauthmodel.User
duplicated := false
trimmedEventID := strings.TrimSpace(eventID)
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if trimmedEventID != "" {
marker := userauthmodel.TokenUsageAdjustment{
EventID: trimmedEventID,
UserID: id,
TokenDelta: delta,
}
result := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&marker)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
duplicated = true
return nil
}
}
resetResult := tx.Model(&userauthmodel.User{}).
Where("id = ? AND (last_reset_at IS NULL OR last_reset_at <= ?)", id, dueBefore).
Updates(map[string]interface{}{
"token_usage": 0,
"last_reset_at": resetAt,
})
if resetResult.Error != nil {
return resetResult.Error
}
updateResult := tx.Model(&userauthmodel.User{}).
Where("id = ?", id).
Update("token_usage", gorm.Expr("token_usage + ?", delta))
if updateResult.Error != nil {
return updateResult.Error
}
if updateResult.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return tx.Select("id", "token_limit", "token_usage", "last_reset_at").
Where("id = ?", id).
First(&quota).Error
})
if err != nil {
return nil, false, err
}
if duplicated {
return nil, true, nil
}
return &quota, false, nil
}

View File

@@ -0,0 +1,330 @@
package auth
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/respond"
contracts "github.com/LoveLosita/smartflow/backend/shared/contracts/userauth"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/spf13/viper"
)
const (
accessSecretConfigKey = "jwt.accessSecret"
refreshSecretConfigKey = "jwt.refreshSecret"
accessExpireConfigKey = "jwt.accessTokenExpire"
refreshExpireConfigKey = "jwt.refreshTokenExpire"
defaultAccessTokenExpire = 15 * time.Minute
defaultRefreshTokenExpire = 7 * 24 * time.Hour
)
type BlacklistReader interface {
IsBlacklisted(jti string) (bool, error)
IsSessionBlacklisted(sessionID string) (bool, error)
}
type runtimeConfig struct {
AccessKey []byte
RefreshKey []byte
AccessExpire time.Duration
RefreshExpire time.Duration
}
// Claims 是 user/auth 服务内部使用的 JWT 声明结构。
type Claims struct {
UserID int `json:"user_id"`
SessionID string `json:"sid"`
TokenType string `json:"token_type"`
JTI string `json:"jti"`
jwt.RegisteredClaims // 标准字段包含 exp/iat 等时间声明。
}
func generateJTI() string {
return uuid.New().String()
}
func loadConfig() (*runtimeConfig, error) {
accessKey, err := readSecret(accessSecretConfigKey)
if err != nil {
return nil, err
}
refreshKey, err := readSecret(refreshSecretConfigKey)
if err != nil {
return nil, err
}
accessExpire, err := readExpireDuration(accessExpireConfigKey, defaultAccessTokenExpire)
if err != nil {
return nil, err
}
refreshExpire, err := readExpireDuration(refreshExpireConfigKey, defaultRefreshTokenExpire)
if err != nil {
return nil, err
}
return &runtimeConfig{
AccessKey: accessKey,
RefreshKey: refreshKey,
AccessExpire: accessExpire,
RefreshExpire: refreshExpire,
}, nil
}
func readSecret(configKey string) ([]byte, error) {
secret := strings.TrimSpace(viper.GetString(configKey))
if secret == "" {
return nil, fmt.Errorf("jwt 配置缺失: %s", configKey)
}
return []byte(secret), nil
}
func readExpireDuration(configKey string, fallback time.Duration) (time.Duration, error) {
raw := strings.TrimSpace(viper.GetString(configKey))
if raw == "" {
return fallback, nil
}
d, err := parseFlexibleDuration(raw)
if err != nil {
return 0, fmt.Errorf("jwt 配置项 %s 非法: %w", configKey, err)
}
if d <= 0 {
return 0, fmt.Errorf("jwt 配置项 %s 必须大于 0", configKey)
}
return d, nil
}
// SessionBlacklistTTL 返回 logout 后会话黑名单需要保留的时长。
//
// 职责边界:
// 1. 只负责从配置推导“会话级黑名单”保留多久;
// 2. 不负责写 Redis也不负责判断具体 token 是否过期;
// 3. 取 access / refresh 中更长的有效期,避免旧 access 在 refresh 轮转后把整段会话放掉。
func SessionBlacklistTTL() (time.Duration, error) {
cfg, err := loadConfig()
if err != nil {
return 0, err
}
if cfg.RefreshExpire >= cfg.AccessExpire {
return cfg.RefreshExpire, nil
}
return cfg.AccessExpire, nil
}
// parseFlexibleDuration 兼容 Go 原生时长和项目历史配置中的 7d / 15min。
func parseFlexibleDuration(raw string) (time.Duration, error) {
normalized := strings.ToLower(strings.TrimSpace(raw))
if normalized == "" {
return 0, errors.New("时长不能为空")
}
if d, err := time.ParseDuration(normalized); err == nil {
return d, nil
}
type unitDef struct {
Suffix string
Multiplier time.Duration
}
unitDefs := []unitDef{
{Suffix: "minutes", Multiplier: time.Minute},
{Suffix: "minute", Multiplier: time.Minute},
{Suffix: "mins", Multiplier: time.Minute},
{Suffix: "min", Multiplier: time.Minute},
{Suffix: "days", Multiplier: 24 * time.Hour},
{Suffix: "day", Multiplier: 24 * time.Hour},
{Suffix: "d", Multiplier: 24 * time.Hour},
{Suffix: "hours", Multiplier: time.Hour},
{Suffix: "hour", Multiplier: time.Hour},
{Suffix: "h", Multiplier: time.Hour},
{Suffix: "seconds", Multiplier: time.Second},
{Suffix: "second", Multiplier: time.Second},
{Suffix: "secs", Multiplier: time.Second},
{Suffix: "sec", Multiplier: time.Second},
{Suffix: "m", Multiplier: time.Minute},
{Suffix: "s", Multiplier: time.Second},
}
for _, unit := range unitDefs {
if !strings.HasSuffix(normalized, unit.Suffix) {
continue
}
numberPart := strings.TrimSpace(strings.TrimSuffix(normalized, unit.Suffix))
value, err := strconv.Atoi(numberPart)
if err != nil {
return 0, fmt.Errorf("时长数值非法: %q", numberPart)
}
if value <= 0 {
return 0, fmt.Errorf("时长数值必须大于 0: %d", value)
}
return time.Duration(value) * unit.Multiplier, nil
}
return 0, fmt.Errorf("不支持的时长格式: %s", raw)
}
// GenerateTokens 签发访问令牌与刷新令牌。
func GenerateTokens(userID int) (*contracts.Tokens, error) {
return GenerateTokensWithSession(userID, "")
}
// GenerateTokensWithSession 为同一个登录会话签发一对 access / refresh token。
//
// 职责边界:
// 1. 负责生成新的会话标识,或复用传入的会话标识;
// 2. access / refresh 各自使用独立 JTI避免 refresh 轮转时误伤新 access
// 3. 不负责黑名单写入,黑名单由 logout / refresh 重放防护链路处理。
func GenerateTokensWithSession(userID int, sessionID string) (*contracts.Tokens, error) {
cfg, err := loadConfig()
if err != nil {
return nil, err
}
now := time.Now()
sessionID = strings.TrimSpace(sessionID)
if sessionID == "" {
sessionID = generateJTI()
}
accessJTI := generateJTI()
refreshJTI := generateJTI()
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{
UserID: userID,
SessionID: sessionID,
TokenType: "access_token",
JTI: accessJTI,
RegisteredClaims: jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(cfg.AccessExpire)),
},
})
accessTokenString, err := accessToken.SignedString(cfg.AccessKey)
if err != nil {
return nil, err
}
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{
UserID: userID,
SessionID: sessionID,
TokenType: "refresh_token",
JTI: refreshJTI,
RegisteredClaims: jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(cfg.RefreshExpire)),
},
})
refreshTokenString, err := refreshToken.SignedString(cfg.RefreshKey)
if err != nil {
return nil, err
}
return &contracts.Tokens{
AccessToken: accessTokenString,
RefreshToken: refreshTokenString,
}, nil
}
func parseToken(tokenString string, signingKey []byte) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, respond.InvalidTokenSingingMethod
}
return signingKey, nil
})
if err != nil || !token.Valid {
return nil, respond.InvalidToken
}
claims, ok := token.Claims.(*Claims)
if !ok || claims.ExpiresAt == nil || claims.UserID <= 0 || strings.TrimSpace(claims.JTI) == "" {
return nil, respond.InvalidClaims
}
return claims, nil
}
func ensureAccessActive(cache BlacklistReader, sessionID, jti string) error {
if cache == nil {
return errors.New("token blacklist dependency is not initialized")
}
sessionID = strings.TrimSpace(sessionID)
if sessionID != "" {
isBlack, err := cache.IsSessionBlacklisted(sessionID)
if err != nil {
return errors.New("无法验证令牌状态")
}
if isBlack {
return respond.UserLoggedOut
}
return nil
}
isBlack, err := cache.IsBlacklisted(jti)
if err != nil {
return errors.New("无法验证令牌状态")
}
if isBlack {
return respond.UserLoggedOut
}
return nil
}
func ensureRefreshActive(cache BlacklistReader, sessionID, jti string) error {
if cache == nil {
return errors.New("token blacklist dependency is not initialized")
}
sessionID = strings.TrimSpace(sessionID)
if sessionID != "" {
isBlack, err := cache.IsSessionBlacklisted(sessionID)
if err != nil {
return errors.New("无法验证令牌状态")
}
if isBlack {
return respond.UserLoggedOut
}
}
isBlack, err := cache.IsBlacklisted(jti)
if err != nil {
return errors.New("无法验证令牌状态")
}
if isBlack {
return respond.InvalidRefreshToken
}
return nil
}
// ValidateAccessToken 校验 access token并统一检查黑名单。
func ValidateAccessToken(tokenString string, cache BlacklistReader) (*Claims, error) {
cfg, err := loadConfig()
if err != nil {
return nil, err
}
claims, err := parseToken(tokenString, cfg.AccessKey)
if err != nil {
return nil, err
}
if claims.TokenType != "access_token" {
return nil, respond.WrongTokenType
}
if err = ensureAccessActive(cache, claims.SessionID, claims.JTI); err != nil {
return nil, err
}
return claims, nil
}
// ValidateRefreshToken 校验 refresh token并统一检查黑名单。
func ValidateRefreshToken(tokenString string, cache BlacklistReader) (*Claims, error) {
cfg, err := loadConfig()
if err != nil {
return nil, err
}
claims, err := parseToken(tokenString, cfg.RefreshKey)
if err != nil {
return nil, respond.InvalidRefreshToken
}
if claims.TokenType != "refresh_token" {
return nil, respond.WrongTokenType
}
if err = ensureRefreshActive(cache, claims.SessionID, claims.JTI); err != nil {
return nil, err
}
return claims, nil
}

View File

@@ -0,0 +1,20 @@
package model
import "time"
// TokenUsageAdjustment 是 user/auth 服务内的 token 账本幂等表。
//
// 职责边界:
// 1. 只记录“某个 outbox/event_id 是否已经调整过 users.token_usage”
// 2. 不保存 agent 会话 token_total那个统计仍属于 agent 领域;
// 3. event_id 作为主键,配合 users.token_usage 更新放在同一个 MySQL 事务里,避免并发重放重复记账。
type TokenUsageAdjustment struct {
EventID string `gorm:"column:event_id;type:varchar(64);primaryKey;comment:来源事件ID"`
UserID int `gorm:"column:user_id;not null;index:idx_userauth_token_adjust_user;comment:用户ID"`
TokenDelta int `gorm:"column:token_delta;not null;comment:本次增加的 token 用量"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime;comment:创建时间"`
}
func (TokenUsageAdjustment) TableName() string {
return "user_token_usage_adjustments"
}

View File

@@ -0,0 +1,19 @@
package model
import "time"
// User 是 user/auth 服务内部拥有的 users 表模型。
// 职责边界:只覆盖 user/auth 需要维护的字段,不承载 gateway 或其他领域规则。
type User struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Username string `gorm:"type:varchar(255);not null;unique" json:"username"`
Password string `gorm:"type:varchar(255);not null" json:"-"`
PhoneNumber string `gorm:"type:varchar(255)" json:"phone_number"`
TokenLimit int `gorm:"default:100000" json:"token_limit"`
TokenUsage int `gorm:"default:0" json:"token_usage"`
LastResetAt time.Time `json:"last_reset_at"`
}
func (User) TableName() string {
return "users"
}

View File

@@ -0,0 +1,86 @@
package rpc
import (
"errors"
"log"
"strings"
"github.com/LoveLosita/smartflow/backend/respond"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const userAuthErrorDomain = "smartflow.userauth"
// grpcErrorFromServiceError 负责把 user/auth 内部错误收口成 gRPC status。
//
// 职责边界:
// 1. 只负责把本服务内部的 respond.Response / 普通 error 转成 gRPC 可传输错误;
// 2. 不负责决定 HTTP 语义,也不负责写回前端响应体;
// 3. 上层 handler 只要直接 return 这个结果,就能让 client 侧按 `res, err :=` 的方式接收。
func grpcErrorFromServiceError(err error) error {
if err == nil {
return nil
}
var resp respond.Response
if errors.As(err, &resp) {
return grpcErrorFromResponse(resp)
}
log.Printf("userauth rpc internal error: %v", err)
return status.Error(codes.Internal, "userauth service internal error")
}
// grpcErrorFromResponse 负责把项目内业务响应映射成 gRPC status。
//
// 职责边界:
// 1. 只处理 user/auth 这组响应码到 gRPC code 的映射;
// 2. 业务码和业务文案通过 ErrorInfo 附带,方便 gateway 再反解回 respond.Response
// 3. 失败时退化为普通 gRPC status不阻断请求链路。
func grpcErrorFromResponse(resp respond.Response) error {
code := grpcCodeFromRespondStatus(resp.Status)
message := strings.TrimSpace(resp.Info)
if message == "" {
message = strings.TrimSpace(resp.Status)
}
st := status.New(code, message)
detail := &errdetails.ErrorInfo{
Domain: userAuthErrorDomain,
Reason: resp.Status,
Metadata: map[string]string{
"info": resp.Info,
},
}
withDetails, err := st.WithDetails(detail)
if err != nil {
return st.Err()
}
return withDetails.Err()
}
func grpcCodeFromRespondStatus(statusValue string) codes.Code {
switch strings.TrimSpace(statusValue) {
case respond.InvalidName.Status:
return codes.AlreadyExists
case respond.WrongName.Status:
return codes.NotFound
case respond.WrongPwd.Status, respond.WrongUsernameOrPwd.Status:
return codes.Unauthenticated
case respond.MissingToken.Status, respond.InvalidTokenSingingMethod.Status, respond.InvalidToken.Status,
respond.InvalidClaims.Status, respond.ErrUnauthorized.Status, respond.InvalidRefreshToken.Status,
respond.WrongTokenType.Status, respond.UserLoggedOut.Status:
return codes.Unauthenticated
case respond.TokenUsageExceedsLimit.Status:
return codes.ResourceExhausted
case respond.MissingParam.Status, respond.WrongParamType.Status, respond.ParamTooLong.Status,
respond.WrongGender.Status, respond.WrongUserID.Status:
return codes.InvalidArgument
}
if strings.HasPrefix(strings.TrimSpace(statusValue), "5") {
return codes.Internal
}
return codes.InvalidArgument
}

View File

@@ -0,0 +1,177 @@
package rpc
import (
"context"
"errors"
"time"
"github.com/LoveLosita/smartflow/backend/respond"
"github.com/LoveLosita/smartflow/backend/services/userauth/rpc/pb"
userauthsv "github.com/LoveLosita/smartflow/backend/services/userauth/sv"
contracts "github.com/LoveLosita/smartflow/backend/shared/contracts/userauth"
)
type Handler struct {
pb.UnimplementedUserAuthServer
svc *userauthsv.Service
}
func NewHandler(svc *userauthsv.Service) *Handler {
return &Handler{svc: svc}
}
// Register 负责把 user/auth 的注册请求从 gRPC 协议转成内部服务调用。
//
// 职责边界:
// 1. 只做 transport -> service 的参数搬运,不碰 DAO/Redis/JWT 细节;
// 2. 业务错误统一转成 gRPC status让 client 侧继续使用 `res, err :=`
// 3. 成功时只回传业务数据,不再在 payload 里塞 status/info。
func (h *Handler) Register(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) {
if h == nil || h.svc == nil {
return nil, grpcErrorFromServiceError(errors.New("userauth service dependency not initialized"))
}
if req == nil {
return nil, grpcErrorFromServiceError(respond.MissingParam)
}
resp, err := h.svc.Register(ctx, contracts.RegisterRequest{
Username: req.Username,
Password: req.Password,
PhoneNumber: req.PhoneNumber,
})
if err != nil {
return nil, grpcErrorFromServiceError(err)
}
return &pb.RegisterResponse{Id: uint64(resp.ID)}, nil
}
func (h *Handler) Login(ctx context.Context, req *pb.LoginRequest) (*pb.TokensResponse, error) {
if h == nil || h.svc == nil {
return nil, grpcErrorFromServiceError(errors.New("userauth service dependency not initialized"))
}
if req == nil {
return nil, grpcErrorFromServiceError(respond.MissingParam)
}
resp, err := h.svc.Login(ctx, contracts.LoginRequest{
Username: req.Username,
Password: req.Password,
})
if err != nil {
return nil, grpcErrorFromServiceError(err)
}
return &pb.TokensResponse{
AccessToken: resp.AccessToken,
RefreshToken: resp.RefreshToken,
}, nil
}
func (h *Handler) RefreshToken(ctx context.Context, req *pb.RefreshTokenRequest) (*pb.TokensResponse, error) {
if h == nil || h.svc == nil {
return nil, grpcErrorFromServiceError(errors.New("userauth service dependency not initialized"))
}
if req == nil {
return nil, grpcErrorFromServiceError(respond.MissingParam)
}
resp, err := h.svc.RefreshToken(ctx, contracts.RefreshTokenRequest{
RefreshToken: req.RefreshToken,
})
if err != nil {
return nil, grpcErrorFromServiceError(err)
}
return &pb.TokensResponse{
AccessToken: resp.AccessToken,
RefreshToken: resp.RefreshToken,
}, nil
}
func (h *Handler) Logout(ctx context.Context, req *pb.LogoutRequest) (*pb.StatusResponse, error) {
if h == nil || h.svc == nil {
return nil, grpcErrorFromServiceError(errors.New("userauth service dependency not initialized"))
}
if req == nil {
return nil, grpcErrorFromServiceError(respond.MissingToken)
}
if err := h.svc.LogoutByAccessToken(ctx, req.AccessToken); err != nil {
return nil, grpcErrorFromServiceError(err)
}
return &pb.StatusResponse{}, nil
}
func (h *Handler) ValidateAccessToken(ctx context.Context, req *pb.ValidateAccessTokenRequest) (*pb.ValidateAccessTokenResponse, error) {
if h == nil || h.svc == nil {
return nil, grpcErrorFromServiceError(errors.New("userauth service dependency not initialized"))
}
if req == nil {
return nil, grpcErrorFromServiceError(respond.MissingToken)
}
resp, err := h.svc.ValidateAccessToken(ctx, contracts.ValidateAccessTokenRequest{
AccessToken: req.AccessToken,
})
if err != nil {
return nil, grpcErrorFromServiceError(err)
}
return &pb.ValidateAccessTokenResponse{
Valid: resp.Valid,
UserId: int64(resp.UserID),
TokenType: resp.TokenType,
Jti: resp.JTI,
ExpiresAtUnixNano: timeToUnixNano(resp.ExpiresAt),
}, nil
}
func (h *Handler) CheckTokenQuota(ctx context.Context, req *pb.CheckTokenQuotaRequest) (*pb.CheckTokenQuotaResponse, error) {
if h == nil || h.svc == nil {
return nil, grpcErrorFromServiceError(errors.New("userauth service dependency not initialized"))
}
if req == nil {
return nil, grpcErrorFromServiceError(respond.ErrUnauthorized)
}
resp, err := h.svc.CheckTokenQuota(ctx, contracts.CheckTokenQuotaRequest{
UserID: int(req.UserId),
})
if err != nil {
return nil, grpcErrorFromServiceError(err)
}
return &pb.CheckTokenQuotaResponse{
Allowed: resp.Allowed,
TokenLimit: int64(resp.TokenLimit),
TokenUsage: int64(resp.TokenUsage),
LastResetAtUnixNano: timeToUnixNano(resp.LastResetAt),
}, nil
}
func (h *Handler) AdjustTokenUsage(ctx context.Context, req *pb.AdjustTokenUsageRequest) (*pb.CheckTokenQuotaResponse, error) {
if h == nil || h.svc == nil {
return nil, grpcErrorFromServiceError(errors.New("userauth service dependency not initialized"))
}
if req == nil {
return nil, grpcErrorFromServiceError(respond.MissingParam)
}
resp, err := h.svc.AdjustTokenUsage(ctx, contracts.AdjustTokenUsageRequest{
EventID: req.EventId,
UserID: int(req.UserId),
TokenDelta: int(req.TokenDelta),
})
if err != nil {
return nil, grpcErrorFromServiceError(err)
}
return &pb.CheckTokenQuotaResponse{
Allowed: resp.Allowed,
TokenLimit: int64(resp.TokenLimit),
TokenUsage: int64(resp.TokenUsage),
LastResetAtUnixNano: timeToUnixNano(resp.LastResetAt),
}, nil
}
func timeToUnixNano(value time.Time) int64 {
if value.IsZero() {
return 0
}
return value.UnixNano()
}

View File

@@ -0,0 +1,151 @@
package pb
import proto "github.com/golang/protobuf/proto"
var _ = proto.Marshal
const _ = proto.ProtoPackageIsVersion3
type RegisterRequest struct {
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"`
PhoneNumber string `protobuf:"bytes,3,opt,name=phone_number,json=phoneNumber,proto3" json:"phone_number,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *RegisterRequest) Reset() { *m = RegisterRequest{} }
func (m *RegisterRequest) String() string { return proto.CompactTextString(m) }
func (*RegisterRequest) ProtoMessage() {}
type RegisterResponse struct {
Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *RegisterResponse) Reset() { *m = RegisterResponse{} }
func (m *RegisterResponse) String() string { return proto.CompactTextString(m) }
func (*RegisterResponse) ProtoMessage() {}
type LoginRequest struct {
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *LoginRequest) Reset() { *m = LoginRequest{} }
func (m *LoginRequest) String() string { return proto.CompactTextString(m) }
func (*LoginRequest) ProtoMessage() {}
type TokensResponse struct {
AccessToken string `protobuf:"bytes,1,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"`
RefreshToken string `protobuf:"bytes,2,opt,name=refresh_token,json=refreshToken,proto3" json:"refresh_token,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *TokensResponse) Reset() { *m = TokensResponse{} }
func (m *TokensResponse) String() string { return proto.CompactTextString(m) }
func (*TokensResponse) ProtoMessage() {}
type RefreshTokenRequest struct {
RefreshToken string `protobuf:"bytes,1,opt,name=refresh_token,json=refreshToken,proto3" json:"refresh_token,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *RefreshTokenRequest) Reset() { *m = RefreshTokenRequest{} }
func (m *RefreshTokenRequest) String() string { return proto.CompactTextString(m) }
func (*RefreshTokenRequest) ProtoMessage() {}
type LogoutRequest struct {
AccessToken string `protobuf:"bytes,1,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *LogoutRequest) Reset() { *m = LogoutRequest{} }
func (m *LogoutRequest) String() string { return proto.CompactTextString(m) }
func (*LogoutRequest) ProtoMessage() {}
type StatusResponse struct {
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *StatusResponse) Reset() { *m = StatusResponse{} }
func (m *StatusResponse) String() string { return proto.CompactTextString(m) }
func (*StatusResponse) ProtoMessage() {}
type ValidateAccessTokenRequest struct {
AccessToken string `protobuf:"bytes,1,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ValidateAccessTokenRequest) Reset() { *m = ValidateAccessTokenRequest{} }
func (m *ValidateAccessTokenRequest) String() string { return proto.CompactTextString(m) }
func (*ValidateAccessTokenRequest) ProtoMessage() {}
type ValidateAccessTokenResponse struct {
Valid bool `protobuf:"varint,1,opt,name=valid,proto3" json:"valid,omitempty"`
UserId int64 `protobuf:"varint,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
TokenType string `protobuf:"bytes,3,opt,name=token_type,json=tokenType,proto3" json:"token_type,omitempty"`
Jti string `protobuf:"bytes,4,opt,name=jti,proto3" json:"jti,omitempty"`
ExpiresAtUnixNano int64 `protobuf:"varint,5,opt,name=expires_at_unix_nano,json=expiresAtUnixNano,proto3" json:"expires_at_unix_nano,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ValidateAccessTokenResponse) Reset() { *m = ValidateAccessTokenResponse{} }
func (m *ValidateAccessTokenResponse) String() string { return proto.CompactTextString(m) }
func (*ValidateAccessTokenResponse) ProtoMessage() {}
type CheckTokenQuotaRequest struct {
UserId int64 `protobuf:"varint,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *CheckTokenQuotaRequest) Reset() { *m = CheckTokenQuotaRequest{} }
func (m *CheckTokenQuotaRequest) String() string { return proto.CompactTextString(m) }
func (*CheckTokenQuotaRequest) ProtoMessage() {}
type AdjustTokenUsageRequest struct {
EventId string `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
UserId int64 `protobuf:"varint,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
TokenDelta int64 `protobuf:"varint,3,opt,name=token_delta,json=tokenDelta,proto3" json:"token_delta,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *AdjustTokenUsageRequest) Reset() { *m = AdjustTokenUsageRequest{} }
func (m *AdjustTokenUsageRequest) String() string { return proto.CompactTextString(m) }
func (*AdjustTokenUsageRequest) ProtoMessage() {}
type CheckTokenQuotaResponse struct {
Allowed bool `protobuf:"varint,1,opt,name=allowed,proto3" json:"allowed,omitempty"`
TokenLimit int64 `protobuf:"varint,2,opt,name=token_limit,json=tokenLimit,proto3" json:"token_limit,omitempty"`
TokenUsage int64 `protobuf:"varint,3,opt,name=token_usage,json=tokenUsage,proto3" json:"token_usage,omitempty"`
LastResetAtUnixNano int64 `protobuf:"varint,4,opt,name=last_reset_at_unix_nano,json=lastResetAtUnixNano,proto3" json:"last_reset_at_unix_nano,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *CheckTokenQuotaResponse) Reset() { *m = CheckTokenQuotaResponse{} }
func (m *CheckTokenQuotaResponse) String() string { return proto.CompactTextString(m) }
func (*CheckTokenQuotaResponse) ProtoMessage() {}

View File

@@ -0,0 +1,307 @@
package pb
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
const (
UserAuth_Register_FullMethodName = "/smartflow.userauth.UserAuth/Register"
UserAuth_Login_FullMethodName = "/smartflow.userauth.UserAuth/Login"
UserAuth_RefreshToken_FullMethodName = "/smartflow.userauth.UserAuth/RefreshToken"
UserAuth_Logout_FullMethodName = "/smartflow.userauth.UserAuth/Logout"
UserAuth_ValidateAccessToken_FullMethodName = "/smartflow.userauth.UserAuth/ValidateAccessToken"
UserAuth_CheckTokenQuota_FullMethodName = "/smartflow.userauth.UserAuth/CheckTokenQuota"
UserAuth_AdjustTokenUsage_FullMethodName = "/smartflow.userauth.UserAuth/AdjustTokenUsage"
)
type UserAuthClient interface {
Register(ctx context.Context, in *RegisterRequest, opts ...grpc.CallOption) (*RegisterResponse, error)
Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*TokensResponse, error)
RefreshToken(ctx context.Context, in *RefreshTokenRequest, opts ...grpc.CallOption) (*TokensResponse, error)
Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*StatusResponse, error)
ValidateAccessToken(ctx context.Context, in *ValidateAccessTokenRequest, opts ...grpc.CallOption) (*ValidateAccessTokenResponse, error)
CheckTokenQuota(ctx context.Context, in *CheckTokenQuotaRequest, opts ...grpc.CallOption) (*CheckTokenQuotaResponse, error)
AdjustTokenUsage(ctx context.Context, in *AdjustTokenUsageRequest, opts ...grpc.CallOption) (*CheckTokenQuotaResponse, error)
}
type userAuthClient struct {
cc grpc.ClientConnInterface
}
func NewUserAuthClient(cc grpc.ClientConnInterface) UserAuthClient {
return &userAuthClient{cc}
}
func (c *userAuthClient) Register(ctx context.Context, in *RegisterRequest, opts ...grpc.CallOption) (*RegisterResponse, error) {
out := new(RegisterResponse)
err := c.cc.Invoke(ctx, UserAuth_Register_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *userAuthClient) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*TokensResponse, error) {
out := new(TokensResponse)
err := c.cc.Invoke(ctx, UserAuth_Login_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *userAuthClient) RefreshToken(ctx context.Context, in *RefreshTokenRequest, opts ...grpc.CallOption) (*TokensResponse, error) {
out := new(TokensResponse)
err := c.cc.Invoke(ctx, UserAuth_RefreshToken_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *userAuthClient) Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*StatusResponse, error) {
out := new(StatusResponse)
err := c.cc.Invoke(ctx, UserAuth_Logout_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *userAuthClient) ValidateAccessToken(ctx context.Context, in *ValidateAccessTokenRequest, opts ...grpc.CallOption) (*ValidateAccessTokenResponse, error) {
out := new(ValidateAccessTokenResponse)
err := c.cc.Invoke(ctx, UserAuth_ValidateAccessToken_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *userAuthClient) CheckTokenQuota(ctx context.Context, in *CheckTokenQuotaRequest, opts ...grpc.CallOption) (*CheckTokenQuotaResponse, error) {
out := new(CheckTokenQuotaResponse)
err := c.cc.Invoke(ctx, UserAuth_CheckTokenQuota_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *userAuthClient) AdjustTokenUsage(ctx context.Context, in *AdjustTokenUsageRequest, opts ...grpc.CallOption) (*CheckTokenQuotaResponse, error) {
out := new(CheckTokenQuotaResponse)
err := c.cc.Invoke(ctx, UserAuth_AdjustTokenUsage_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
type UserAuthServer interface {
Register(context.Context, *RegisterRequest) (*RegisterResponse, error)
Login(context.Context, *LoginRequest) (*TokensResponse, error)
RefreshToken(context.Context, *RefreshTokenRequest) (*TokensResponse, error)
Logout(context.Context, *LogoutRequest) (*StatusResponse, error)
ValidateAccessToken(context.Context, *ValidateAccessTokenRequest) (*ValidateAccessTokenResponse, error)
CheckTokenQuota(context.Context, *CheckTokenQuotaRequest) (*CheckTokenQuotaResponse, error)
AdjustTokenUsage(context.Context, *AdjustTokenUsageRequest) (*CheckTokenQuotaResponse, error)
}
type UnimplementedUserAuthServer struct{}
func (UnimplementedUserAuthServer) Register(context.Context, *RegisterRequest) (*RegisterResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Register not implemented")
}
func (UnimplementedUserAuthServer) Login(context.Context, *LoginRequest) (*TokensResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Login not implemented")
}
func (UnimplementedUserAuthServer) RefreshToken(context.Context, *RefreshTokenRequest) (*TokensResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method RefreshToken not implemented")
}
func (UnimplementedUserAuthServer) Logout(context.Context, *LogoutRequest) (*StatusResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented")
}
func (UnimplementedUserAuthServer) ValidateAccessToken(context.Context, *ValidateAccessTokenRequest) (*ValidateAccessTokenResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ValidateAccessToken not implemented")
}
func (UnimplementedUserAuthServer) CheckTokenQuota(context.Context, *CheckTokenQuotaRequest) (*CheckTokenQuotaResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method CheckTokenQuota not implemented")
}
func (UnimplementedUserAuthServer) AdjustTokenUsage(context.Context, *AdjustTokenUsageRequest) (*CheckTokenQuotaResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method AdjustTokenUsage not implemented")
}
func RegisterUserAuthServer(s grpc.ServiceRegistrar, srv UserAuthServer) {
s.RegisterService(&UserAuth_ServiceDesc, srv)
}
func _UserAuth_Register_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RegisterRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(UserAuthServer).Register(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: UserAuth_Register_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(UserAuthServer).Register(ctx, req.(*RegisterRequest))
}
return interceptor(ctx, in, info, handler)
}
func _UserAuth_Login_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(LoginRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(UserAuthServer).Login(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: UserAuth_Login_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(UserAuthServer).Login(ctx, req.(*LoginRequest))
}
return interceptor(ctx, in, info, handler)
}
func _UserAuth_RefreshToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RefreshTokenRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(UserAuthServer).RefreshToken(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: UserAuth_RefreshToken_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(UserAuthServer).RefreshToken(ctx, req.(*RefreshTokenRequest))
}
return interceptor(ctx, in, info, handler)
}
func _UserAuth_Logout_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(LogoutRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(UserAuthServer).Logout(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: UserAuth_Logout_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(UserAuthServer).Logout(ctx, req.(*LogoutRequest))
}
return interceptor(ctx, in, info, handler)
}
func _UserAuth_ValidateAccessToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ValidateAccessTokenRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(UserAuthServer).ValidateAccessToken(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: UserAuth_ValidateAccessToken_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(UserAuthServer).ValidateAccessToken(ctx, req.(*ValidateAccessTokenRequest))
}
return interceptor(ctx, in, info, handler)
}
func _UserAuth_CheckTokenQuota_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CheckTokenQuotaRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(UserAuthServer).CheckTokenQuota(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: UserAuth_CheckTokenQuota_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(UserAuthServer).CheckTokenQuota(ctx, req.(*CheckTokenQuotaRequest))
}
return interceptor(ctx, in, info, handler)
}
func _UserAuth_AdjustTokenUsage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AdjustTokenUsageRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(UserAuthServer).AdjustTokenUsage(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: UserAuth_AdjustTokenUsage_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(UserAuthServer).AdjustTokenUsage(ctx, req.(*AdjustTokenUsageRequest))
}
return interceptor(ctx, in, info, handler)
}
var UserAuth_ServiceDesc = grpc.ServiceDesc{
ServiceName: "smartflow.userauth.UserAuth",
HandlerType: (*UserAuthServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Register",
Handler: _UserAuth_Register_Handler,
},
{
MethodName: "Login",
Handler: _UserAuth_Login_Handler,
},
{
MethodName: "RefreshToken",
Handler: _UserAuth_RefreshToken_Handler,
},
{
MethodName: "Logout",
Handler: _UserAuth_Logout_Handler,
},
{
MethodName: "ValidateAccessToken",
Handler: _UserAuth_ValidateAccessToken_Handler,
},
{
MethodName: "CheckTokenQuota",
Handler: _UserAuth_CheckTokenQuota_Handler,
},
{
MethodName: "AdjustTokenUsage",
Handler: _UserAuth_AdjustTokenUsage_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "services/userauth/rpc/userauth.proto",
}

View File

@@ -0,0 +1,72 @@
package rpc
import (
"errors"
"log"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/services/userauth/rpc/pb"
userauthsv "github.com/LoveLosita/smartflow/backend/services/userauth/sv"
"github.com/zeromicro/go-zero/core/service"
"github.com/zeromicro/go-zero/zrpc"
"google.golang.org/grpc"
)
const (
defaultListenOn = "0.0.0.0:9081"
defaultTimeout = 2 * time.Second
)
type ServerOptions struct {
ListenOn string
Timeout time.Duration
Service *userauthsv.Service
}
// Start 启动 user/auth zrpc 服务。
//
// 职责边界:
// 1. 只负责装配 gozero zrpc server 和注册 protobuf service
// 2. 不创建 DB/Redis 连接,这些依赖由 cmd/userauth 入口注入;
// 3. 阻塞直到进程收到退出信号,保持一个服务一个独立进程的迁移方向。
func Start(opts ServerOptions) {
server, listenOn, err := NewServer(opts)
if err != nil {
log.Fatalf("failed to build userauth zrpc server: %v", err)
}
defer server.Stop()
log.Printf("userauth zrpc service starting on %s", listenOn)
server.Start()
}
func NewServer(opts ServerOptions) (*zrpc.RpcServer, string, error) {
if opts.Service == nil {
return nil, "", errors.New("userauth service dependency not initialized")
}
listenOn := strings.TrimSpace(opts.ListenOn)
if listenOn == "" {
listenOn = defaultListenOn
}
timeout := opts.Timeout
if timeout <= 0 {
timeout = defaultTimeout
}
server, err := zrpc.NewServer(zrpc.RpcServerConf{
ServiceConf: service.ServiceConf{
Name: "userauth.rpc",
Mode: service.DevMode,
},
ListenOn: listenOn,
Timeout: int64(timeout / time.Millisecond),
}, func(grpcServer *grpc.Server) {
pb.RegisterUserAuthServer(grpcServer, NewHandler(opts.Service))
})
if err != nil {
return nil, "", err
}
return server, listenOn, nil
}

View File

@@ -0,0 +1,75 @@
syntax = "proto3";
package smartflow.userauth;
option go_package = "github.com/LoveLosita/smartflow/backend/services/userauth/rpc/pb";
service UserAuth {
rpc Register(RegisterRequest) returns (RegisterResponse);
rpc Login(LoginRequest) returns (TokensResponse);
rpc RefreshToken(RefreshTokenRequest) returns (TokensResponse);
rpc Logout(LogoutRequest) returns (StatusResponse);
rpc ValidateAccessToken(ValidateAccessTokenRequest) returns (ValidateAccessTokenResponse);
rpc CheckTokenQuota(CheckTokenQuotaRequest) returns (CheckTokenQuotaResponse);
rpc AdjustTokenUsage(AdjustTokenUsageRequest) returns (CheckTokenQuotaResponse);
}
message RegisterRequest {
string username = 1;
string password = 2;
string phone_number = 3;
}
message RegisterResponse {
uint64 id = 1;
}
message LoginRequest {
string username = 1;
string password = 2;
}
message TokensResponse {
string access_token = 1;
string refresh_token = 2;
}
message RefreshTokenRequest {
string refresh_token = 1;
}
message LogoutRequest {
string access_token = 1;
}
message StatusResponse {
}
message ValidateAccessTokenRequest {
string access_token = 1;
}
message ValidateAccessTokenResponse {
bool valid = 1;
int64 user_id = 2;
string token_type = 3;
string jti = 4;
int64 expires_at_unix_nano = 5;
}
message CheckTokenQuotaRequest {
int64 user_id = 1;
}
message AdjustTokenUsageRequest {
string event_id = 1;
int64 user_id = 2;
int64 token_delta = 3;
}
message CheckTokenQuotaResponse {
bool allowed = 1;
int64 token_limit = 2;
int64 token_usage = 3;
int64 last_reset_at_unix_nano = 4;
}

View File

@@ -0,0 +1,192 @@
package sv
import (
"context"
"errors"
"log"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/respond"
userauthdao "github.com/LoveLosita/smartflow/backend/services/userauth/dao"
userauthmodel "github.com/LoveLosita/smartflow/backend/services/userauth/model"
contracts "github.com/LoveLosita/smartflow/backend/shared/contracts/userauth"
)
const (
userTokenResetInterval = 7 * 24 * time.Hour
userTokenQuotaSnapshotTTL = 60 * time.Second
minUserTokenBlockTTL = 30 * time.Second
)
// CheckTokenQuota 是 user/auth 服务内的 token 额度门禁。
//
// 职责边界:
// 1. 判断用户是否还能继续发起高消耗 agent/chat 请求;
// 2. 维护额度周期懒重置、Redis 快照和封禁键;
// 3. 不负责本轮对话完成后的 token 记账,记账由 AdjustTokenUsage 处理。
func (s *Service) CheckTokenQuota(ctx context.Context, req contracts.CheckTokenQuotaRequest) (*contracts.CheckTokenQuotaResponse, error) {
if s == nil || s.userRepo == nil || s.cacheRepo == nil {
return nil, errors.New("userauth quota dependencies not initialized")
}
if req.UserID <= 0 {
return nil, respond.ErrUnauthorized
}
now := time.Now()
// 1. 先查封禁键。封禁键的 TTL 按重置窗口计算,命中时可以避免每次回源 DB。
blocked, blockedErr := s.cacheRepo.IsUserTokenBlocked(ctx, req.UserID)
if blockedErr != nil {
log.Printf("userauth quota: 查询封禁键失败 user_id=%d err=%v回源 DB 校验", req.UserID, blockedErr)
} else if blocked {
return &contracts.CheckTokenQuotaResponse{Allowed: false}, nil
}
// 2. 快照未到重置窗口时直接判断;快照损坏或过期则回源 DB。
snapshot, hit, snapshotErr := s.cacheRepo.GetUserTokenQuotaSnapshot(ctx, req.UserID)
if snapshotErr != nil {
log.Printf("userauth quota: 读取额度快照失败 user_id=%d err=%v回源 DB 校验", req.UserID, snapshotErr)
}
if hit && snapshot != nil && !isResetDue(snapshot.LastResetAt, now) {
if isQuotaExceeded(snapshot.TokenLimit, snapshot.TokenUsage) {
ttl := calcBlockTTL(snapshot.LastResetAt, now)
if err := s.cacheRepo.SetUserTokenBlocked(ctx, req.UserID, ttl); err != nil {
log.Printf("userauth quota: 写入封禁键失败 user_id=%d err=%v", req.UserID, err)
}
return quotaResponse(false, snapshot.TokenLimit, snapshot.TokenUsage, snapshot.LastResetAt), nil
}
return quotaResponse(true, snapshot.TokenLimit, snapshot.TokenUsage, snapshot.LastResetAt), nil
}
// 3. 回源 DB 做权威判断;到 7 天窗口则先懒重置,再回读最新额度。
quota, err := s.userRepo.GetUserTokenQuotaByID(ctx, req.UserID)
if err != nil {
return nil, err
}
if isResetDue(quota.LastResetAt, now) {
if _, err = s.userRepo.ResetUserTokenUsageIfDue(ctx, req.UserID, now.Add(-userTokenResetInterval), now); err != nil {
return nil, err
}
quota, err = s.userRepo.GetUserTokenQuotaByID(ctx, req.UserID)
if err != nil {
return nil, err
}
if delErr := s.cacheRepo.DeleteUserTokenBlocked(ctx, req.UserID); delErr != nil {
log.Printf("userauth quota: 清理封禁键失败 user_id=%d err=%v", req.UserID, delErr)
}
}
return s.cacheQuotaAndBuildResponse(ctx, req.UserID, quota, now, "quota")
}
// AdjustTokenUsage 在 user/auth 服务内回写用户 token 账本。
//
// 职责边界:
// 1. 只负责 users.token_usage 的增量调整与 quota 缓存刷新;
// 2. 不负责 agent 会话 token_total调用方仍需在各自领域内维护会话统计
// 3. event_id 非空时通过 MySQL 幂等表和 users 更新同事务提交,避免 outbox 重试或并发重放重复记账。
func (s *Service) AdjustTokenUsage(ctx context.Context, req contracts.AdjustTokenUsageRequest) (*contracts.CheckTokenQuotaResponse, error) {
if s == nil || s.userRepo == nil || s.cacheRepo == nil {
return nil, errors.New("userauth adjust dependencies not initialized")
}
if req.UserID <= 0 || req.TokenDelta <= 0 {
return nil, respond.MissingParam
}
now := time.Now()
eventID := strings.TrimSpace(req.EventID)
var currentQuota *userauthmodel.User
var err error
if eventID != "" {
var duplicated bool
currentQuota, duplicated, err = s.userRepo.AdjustTokenUsageOnce(ctx, eventID, req.UserID, req.TokenDelta, now.Add(-userTokenResetInterval), now)
if err != nil {
return nil, err
}
if duplicated {
return s.CheckTokenQuota(ctx, contracts.CheckTokenQuotaRequest{UserID: req.UserID})
}
} else {
currentQuota, err = s.userRepo.GetUserTokenQuotaByID(ctx, req.UserID)
if err != nil {
return nil, err
}
if isResetDue(currentQuota.LastResetAt, now) {
if _, err = s.userRepo.ResetUserTokenUsageIfDue(ctx, req.UserID, now.Add(-userTokenResetInterval), now); err != nil {
return nil, err
}
}
if _, err = s.userRepo.AddTokenUsage(ctx, req.UserID, req.TokenDelta); err != nil {
return nil, err
}
currentQuota, err = s.userRepo.GetUserTokenQuotaByID(ctx, req.UserID)
if err != nil {
return nil, err
}
}
return s.cacheQuotaAndBuildResponse(ctx, req.UserID, currentQuota, now, "adjust")
}
func (s *Service) cacheQuotaAndBuildResponse(ctx context.Context, userID int, quota *userauthmodel.User, now time.Time, source string) (*contracts.CheckTokenQuotaResponse, error) {
if quota == nil {
return nil, errors.New("userauth quota is nil")
}
snapshot := userauthdao.TokenQuotaSnapshot{
TokenLimit: quota.TokenLimit,
TokenUsage: quota.TokenUsage,
LastResetAt: quota.LastResetAt,
}
if setErr := s.cacheRepo.SetUserTokenQuotaSnapshot(ctx, userID, snapshot, userTokenQuotaSnapshotTTL); setErr != nil {
log.Printf("userauth %s: 回填额度快照失败 user_id=%d err=%v", source, userID, setErr)
if delErr := s.cacheRepo.DeleteUserTokenQuotaSnapshot(ctx, userID); delErr != nil {
log.Printf("userauth %s: 清理失效额度快照失败 user_id=%d err=%v", source, userID, delErr)
}
}
if isQuotaExceeded(quota.TokenLimit, quota.TokenUsage) {
ttl := calcBlockTTL(quota.LastResetAt, now)
if err := s.cacheRepo.SetUserTokenBlocked(ctx, userID, ttl); err != nil {
log.Printf("userauth %s: 写入封禁标记失败 user_id=%d err=%v", source, userID, err)
}
return quotaResponse(false, quota.TokenLimit, quota.TokenUsage, quota.LastResetAt), nil
}
if delErr := s.cacheRepo.DeleteUserTokenBlocked(ctx, userID); delErr != nil {
log.Printf("userauth %s: 清理封禁标记失败 user_id=%d err=%v", source, userID, delErr)
}
return quotaResponse(true, quota.TokenLimit, quota.TokenUsage, quota.LastResetAt), nil
}
func quotaResponse(allowed bool, tokenLimit int, tokenUsage int, lastResetAt time.Time) *contracts.CheckTokenQuotaResponse {
return &contracts.CheckTokenQuotaResponse{
Allowed: allowed,
TokenLimit: tokenLimit,
TokenUsage: tokenUsage,
LastResetAt: lastResetAt,
}
}
func isQuotaExceeded(tokenLimit int, tokenUsage int) bool {
return tokenUsage >= tokenLimit
}
func isResetDue(lastResetAt time.Time, now time.Time) bool {
if lastResetAt.IsZero() {
return true
}
return !lastResetAt.Add(userTokenResetInterval).After(now)
}
func calcBlockTTL(lastResetAt time.Time, now time.Time) time.Duration {
if lastResetAt.IsZero() {
return minUserTokenBlockTTL
}
ttl := lastResetAt.Add(userTokenResetInterval).Sub(now)
if ttl <= 0 {
return minUserTokenBlockTTL
}
return ttl
}

View File

@@ -0,0 +1,176 @@
package sv
import (
"context"
"errors"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/respond"
userauthdao "github.com/LoveLosita/smartflow/backend/services/userauth/dao"
userauthauth "github.com/LoveLosita/smartflow/backend/services/userauth/internal/auth"
userauthmodel "github.com/LoveLosita/smartflow/backend/services/userauth/model"
contracts "github.com/LoveLosita/smartflow/backend/shared/contracts/userauth"
"github.com/LoveLosita/smartflow/backend/utils"
"gorm.io/gorm"
)
type UserRepo interface {
Create(ctx context.Context, username, phoneNumber, password string) (*userauthmodel.User, error)
IfUsernameExists(ctx context.Context, name string) (bool, error)
GetUserHashedPasswordByName(ctx context.Context, name string) (string, error)
GetUserIDByName(ctx context.Context, name string) (int, error)
GetUserTokenQuotaByID(ctx context.Context, id int) (*userauthmodel.User, error)
ResetUserTokenUsageIfDue(ctx context.Context, id int, dueBefore time.Time, resetAt time.Time) (bool, error)
AddTokenUsage(ctx context.Context, id int, delta int) (bool, error)
AdjustTokenUsageOnce(ctx context.Context, eventID string, id int, delta int, dueBefore time.Time, resetAt time.Time) (*userauthmodel.User, bool, error)
}
type CacheRepo interface {
IsBlacklisted(jti string) (bool, error)
SetBlacklist(jti string, expiration time.Duration) error
SetBlacklistIfAbsent(jti string, expiration time.Duration) (bool, error)
IsSessionBlacklisted(sessionID string) (bool, error)
SetSessionBlacklist(sessionID string, expiration time.Duration) error
IsUserTokenBlocked(ctx context.Context, userID int) (bool, error)
GetUserTokenQuotaSnapshot(ctx context.Context, userID int) (*userauthdao.TokenQuotaSnapshot, bool, error)
SetUserTokenQuotaSnapshot(ctx context.Context, userID int, snapshot userauthdao.TokenQuotaSnapshot, ttl time.Duration) error
DeleteUserTokenQuotaSnapshot(ctx context.Context, userID int) error
SetUserTokenBlocked(ctx context.Context, userID int, ttl time.Duration) error
DeleteUserTokenBlocked(ctx context.Context, userID int) error
}
// Service 承载 user/auth 服务内部业务规则。
//
// 职责边界:
// 1. 负责注册、登录、刷新、登出、JWT 签发/校验、黑名单和 token 额度门禁;
// 2. 不负责 Gin gateway 的响应适配、路由聚合和 SSE 等边缘职责;
// 3. 不负责 agent 会话 token 统计,迁移期该链路仍由 agent 持久化事件触发 userauth 账本调整。
type Service struct {
userRepo UserRepo
cacheRepo CacheRepo
}
func New(userRepo UserRepo, cacheRepo CacheRepo) *Service {
return &Service{
userRepo: userRepo,
cacheRepo: cacheRepo,
}
}
func (s *Service) Register(ctx context.Context, req contracts.RegisterRequest) (*contracts.RegisterResponse, error) {
if strings.TrimSpace(req.Username) == "" || strings.TrimSpace(req.Password) == "" || strings.TrimSpace(req.PhoneNumber) == "" {
return nil, respond.MissingParam
}
if len(req.Username) > 45 || len(req.Password) > 229 || len(req.PhoneNumber) > 18 {
return nil, respond.ParamTooLong
}
exists, err := s.userRepo.IfUsernameExists(ctx, req.Username)
if err != nil {
return nil, err
}
if exists {
return nil, respond.InvalidName
}
hashedPwd, err := utils.HashPassword(req.Password)
if err != nil {
return nil, err
}
newUser, err := s.userRepo.Create(ctx, req.Username, req.PhoneNumber, hashedPwd)
if err != nil {
return nil, err
}
return &contracts.RegisterResponse{ID: newUser.ID}, nil
}
func (s *Service) Login(ctx context.Context, req contracts.LoginRequest) (*contracts.Tokens, error) {
hashedPwd, err := s.userRepo.GetUserHashedPasswordByName(ctx, req.Username)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, respond.WrongName
}
return nil, err
}
matched, err := utils.CompareHashPwdAndPwd(hashedPwd, req.Password)
if err != nil {
return nil, err
}
if !matched {
return nil, respond.WrongPwd
}
userID, err := s.userRepo.GetUserIDByName(ctx, req.Username)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, respond.WrongName
}
return nil, err
}
return userauthauth.GenerateTokens(userID)
}
func (s *Service) RefreshToken(ctx context.Context, req contracts.RefreshTokenRequest) (*contracts.Tokens, error) {
if strings.TrimSpace(req.RefreshToken) == "" {
return nil, respond.MissingParam
}
claims, err := userauthauth.ValidateRefreshToken(req.RefreshToken, s.cacheRepo)
if err != nil {
return nil, err
}
ttl := time.Until(claims.ExpiresAt.Time)
if ttl <= 0 {
return nil, respond.InvalidRefreshToken
}
// 1. 先用 SET NX 抢占旧 refresh 的 JTI确保并发刷新时只有一个请求能继续签发新 token。
// 2. 这里只黑掉旧 refresh不黑掉整个 session避免误伤同一会话下新签发的 access token。
consumed, err := s.cacheRepo.SetBlacklistIfAbsent(claims.JTI, ttl)
if err != nil {
return nil, err
}
if !consumed {
return nil, respond.InvalidRefreshToken
}
return userauthauth.GenerateTokensWithSession(claims.UserID, claims.SessionID)
}
func (s *Service) LogoutByAccessToken(ctx context.Context, accessToken string) error {
if strings.TrimSpace(accessToken) == "" {
return respond.MissingToken
}
claims, err := userauthauth.ValidateAccessToken(accessToken, s.cacheRepo)
if err != nil {
return err
}
// 1. logout 的目标是整段会话,而不是单个 access token。
// 2. 先按会话维度拉黑,再让 access / refresh 各自的 validate 流程拒绝后续请求。
if strings.TrimSpace(claims.SessionID) == "" {
return s.cacheRepo.SetBlacklist(claims.JTI, time.Until(claims.ExpiresAt.Time))
}
sessionTTL, err := userauthauth.SessionBlacklistTTL()
if err != nil {
return err
}
return s.cacheRepo.SetSessionBlacklist(claims.SessionID, sessionTTL)
}
func (s *Service) ValidateAccessToken(ctx context.Context, req contracts.ValidateAccessTokenRequest) (*contracts.ValidateAccessTokenResponse, error) {
if strings.TrimSpace(req.AccessToken) == "" {
return nil, respond.MissingToken
}
claims, err := userauthauth.ValidateAccessToken(req.AccessToken, s.cacheRepo)
if err != nil {
return nil, err
}
return &contracts.ValidateAccessTokenResponse{
Valid: true,
UserID: claims.UserID,
TokenType: claims.TokenType,
JTI: claims.JTI,
ExpiresAt: claims.ExpiresAt.Time,
}, nil
}