Version:0.0.2.dev.260203
feat: implement redis-based logout and jwt middleware 🚀 feat: 实现基于 Redis 的登出机制与 JWT 中间件 🚀 Middleware Construction: Implemented JWTTokenAuth middleware for Gin, featuring structured claims parsing and active session validation. 🛡️ 中间件构建:为 Gin 框架实现了 JWTTokenAuth 中间件,支持结构化 Claims 解析与活跃会话验证。🛡️ Redis Integration: Introduced Redis for high-performance state management. Integrated CacheDAO into the Dependency Injection (DI) chain. ⚡ Redis 引入:引入 Redis 进行高性能状态管理。将 CacheDAO 成功集成至依赖注入 (DI) 调用链中。⚡ Secure Logout Module: Developed the logout functional module using a Redis Blacklist mechanism. 🔐 安全登出模块:开发了基于 Redis 黑名单 机制的登出功能模块。🔐 Marked invalidated tokens by storing jti (JWT ID) in Redis with automatic TTL expiration. 通过在 Redis 中存储 jti(JWT 唯一标识)并设置自动 TTL 过期,实现 Token 的主动失效。 Added blacklist checkpoints in both AuthMiddleware and RefreshToken logic to prevent session resurrection. 在认证中间件与 Token 刷新逻辑中同步增设黑名单检查点,杜绝登出后的“死灰复燃”。 Architecture Refinement: Upgraded ValidateRefreshToken and Service-layer handlers to use type-safe struct assertions instead of raw MapClaims. 🏗️ 架构精进:升级了 ValidateRefreshToken 与 Service 层处理器,改用类型安全的结构体断言取代原始的 MapClaims,提升了代码健壮性。🏗️
This commit is contained in:
@@ -6,10 +6,10 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
"github.com/LoveLosita/smartflow/backend/respond"
|
||||
"github.com/LoveLosita/smartflow/backend/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/smartflow/backend/model"
|
||||
"github.com/smartflow/backend/respond"
|
||||
"github.com/smartflow/backend/service"
|
||||
)
|
||||
|
||||
type UserHandler struct {
|
||||
@@ -86,7 +86,7 @@ func (api *UserHandler) RefreshTokenHandler(c *gin.Context) {
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, respond.InvalidRefreshToken), errors.Is(err, respond.InvalidClaims),
|
||||
errors.Is(err, respond.InvalidTokenSingingMethod): //如果是无效刷新令牌或者无效claims或者无效签名方法
|
||||
errors.Is(err, respond.InvalidTokenSingingMethod), errors.Is(err, respond.UserLoggedOut): //如果是无效刷新令牌或者无效claims或者无效签名方法
|
||||
c.JSON(http.StatusBadRequest, err)
|
||||
return
|
||||
default:
|
||||
@@ -95,3 +95,16 @@ func (api *UserHandler) RefreshTokenHandler(c *gin.Context) {
|
||||
}
|
||||
c.JSON(http.StatusOK, respond.OKWithData(respond.Ok, tokens))
|
||||
}
|
||||
|
||||
func (api *UserHandler) UserLogout(c *gin.Context) {
|
||||
//1.从上下文中获取 jti 和 expireTime
|
||||
claims, _ := c.Get("claims")
|
||||
cl := claims.(*model.MyCustomClaims)
|
||||
//2.调用 Service 层的 UserLogout 方法
|
||||
err := api.svc.UserLogout(cl.Jti, cl.ExpiresAt.Time)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, respond.InternalError(err))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, respond.Ok)
|
||||
}
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/smartflow/backend/respond"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var RefreshKey = []byte(viper.GetString("jwt.accessSecret")) // 用于签名和验证刷新Token的密钥
|
||||
var AccessKey = []byte(viper.GetString("jwt.refreshSecret")) // 用于签名和验证访问Token的密钥
|
||||
|
||||
// GenerateTokens 生成访问令牌和刷新令牌
|
||||
func GenerateTokens(userID int) (string, string, error) {
|
||||
// 创建访问令牌
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": userID, // 获取用户ID
|
||||
"exp": time.Now().Add(15 * time.Minute).Unix(), // 设置访问令牌过期时间为 15 分钟
|
||||
"token_type": "access_token", // 令牌类型为访问令牌
|
||||
})
|
||||
|
||||
// 使用密钥签名访问令牌
|
||||
accessTokenString, err := accessToken.SignedString(AccessKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// 创建刷新令牌
|
||||
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": userID, // 获取用户ID
|
||||
"exp": time.Now().Add(7 * 24 * time.Hour).Unix(), // 设置刷新令牌过期时间为 7 天
|
||||
"token_type": "refresh_token", // 令牌类型为刷新令牌
|
||||
})
|
||||
|
||||
// 使用密钥签名刷新令牌
|
||||
refreshTokenString, err := refreshToken.SignedString(RefreshKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return accessTokenString, refreshTokenString, nil
|
||||
}
|
||||
|
||||
func ValidateRefreshToken(tokenString string) (*jwt.Token, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// 检查签名方法是否为 HMAC
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, respond.InvalidTokenSingingMethod
|
||||
}
|
||||
// 返回用于验证的密钥
|
||||
return RefreshKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 进一步检查载荷中 token_type 是否正确
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, respond.InvalidClaims
|
||||
}
|
||||
// 检查 token_type 是否是 refresh_token
|
||||
if claimType, ok := claims["token_type"].(string); !ok || claimType != "refresh_token" {
|
||||
return nil, respond.WrongTokenType
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
121
backend/auth/jwt_handler.go
Normal file
121
backend/auth/jwt_handler.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/dao"
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
"github.com/LoveLosita/smartflow/backend/respond"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var RefreshKey = []byte(viper.GetString("jwt.refreshSecret")) // 用于签名和验证刷新Token的密钥
|
||||
var AccessKey = []byte(viper.GetString("jwt.accessSecret")) // 用于签名和验证访问Token的密钥
|
||||
|
||||
// generateJTI 生成唯一的 JWT ID
|
||||
func generateJTI() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// GenerateTokens 生成访问令牌和刷新令牌
|
||||
func GenerateTokens(userID int) (string, string, error) {
|
||||
// 创建访问令牌
|
||||
sid := generateJTI()
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": userID, // 获取用户ID
|
||||
"exp": time.Now().Add(15 * time.Minute).Unix(), // 设置访问令牌过期时间为 15 分钟
|
||||
"token_type": "access_token", // 令牌类型为访问令牌
|
||||
"jti": sid, // 亲子共用的 JWT ID
|
||||
})
|
||||
|
||||
// 使用密钥签名访问令牌
|
||||
accessTokenString, err := accessToken.SignedString(AccessKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// 创建刷新令牌
|
||||
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": userID, // 获取用户ID
|
||||
"exp": time.Now().Add(7 * 24 * time.Hour).Unix(), // 设置刷新令牌过期时间为 7 天
|
||||
"token_type": "refresh_token", // 令牌类型为刷新令牌
|
||||
"jti": sid, // 亲子共用的 JWT ID
|
||||
})
|
||||
|
||||
// 使用密钥签名刷新令牌
|
||||
refreshTokenString, err := refreshToken.SignedString(RefreshKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return accessTokenString, refreshTokenString, nil
|
||||
}
|
||||
|
||||
// ValidateRefreshToken 验证刷新令牌的有效性
|
||||
/*func ValidateRefreshToken(tokenString string) (*jwt.Token, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// 检查签名方法是否为 HMAC
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, respond.InvalidTokenSingingMethod
|
||||
}
|
||||
// 返回用于验证的密钥
|
||||
return RefreshKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 进一步检查载荷中 token_type 是否正确
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, respond.InvalidClaims
|
||||
}
|
||||
// 检查 token_type 是否是 refresh_token
|
||||
if claimType, ok := claims["token_type"].(string); !ok || claimType != "refresh_token" {
|
||||
return nil, respond.WrongTokenType
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
*/
|
||||
|
||||
// ValidateRefreshToken 验证刷新令牌的有效性,并增加 Redis 黑名单检查
|
||||
func ValidateRefreshToken(tokenString string, cache *dao.CacheDAO) (*jwt.Token, error) {
|
||||
// 1. 解析 Token 并直接绑定到你的自定义结构体
|
||||
token, err := jwt.ParseWithClaims(tokenString, &model.MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, respond.InvalidTokenSingingMethod
|
||||
}
|
||||
return RefreshKey, nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 断言获取 Claims
|
||||
claims, ok := token.Claims.(*model.MyCustomClaims)
|
||||
if !ok {
|
||||
return nil, respond.InvalidClaims
|
||||
}
|
||||
|
||||
// 3. 核心“设卡”:检查 token_type 是否是 refresh_token
|
||||
if claims.TokenType != "refresh_token" {
|
||||
return nil, respond.WrongTokenType
|
||||
}
|
||||
|
||||
// 4. --- 🛡️ 终极关卡:检查 Redis 黑名单 ---
|
||||
// 即使签名没过期,如果 jti 在黑名单里(用户已登出),也视为无效
|
||||
isBlack, err := cache.IsBlacklisted(claims.Jti)
|
||||
if err != nil {
|
||||
// Redis 出错时的处理逻辑,建议报错以防“漏网之鱼”
|
||||
return nil, respond.InternalError(errors.New("无法验证令牌状态"))
|
||||
}
|
||||
if isBlack {
|
||||
return nil, respond.UserLoggedOut // 返回你定义的“用户已登出”错误
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/smartflow/backend/api"
|
||||
"github.com/smartflow/backend/dao"
|
||||
"github.com/smartflow/backend/inits"
|
||||
"github.com/smartflow/backend/routers"
|
||||
"github.com/smartflow/backend/service"
|
||||
"github.com/LoveLosita/smartflow/backend/api"
|
||||
"github.com/LoveLosita/smartflow/backend/dao"
|
||||
"github.com/LoveLosita/smartflow/backend/inits"
|
||||
"github.com/LoveLosita/smartflow/backend/routers"
|
||||
"github.com/LoveLosita/smartflow/backend/service"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
@@ -38,12 +38,15 @@ func Start() {
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
rdb := inits.InitRedis()
|
||||
|
||||
userRepo := dao.NewUserDAO(db)
|
||||
userService := service.NewUserService(userRepo)
|
||||
cacheRepo := dao.NewCacheDAO(rdb)
|
||||
userService := service.NewUserService(userRepo, cacheRepo)
|
||||
userApi := api.NewUserHandler(userService)
|
||||
handlers := &api.ApiHandlers{
|
||||
UserHandler: userApi,
|
||||
}
|
||||
r := routers.RegisterRouters(handlers)
|
||||
r := routers.RegisterRouters(handlers, cacheRepo)
|
||||
routers.StartEngine(r)
|
||||
}
|
||||
|
||||
@@ -29,5 +29,5 @@ log:
|
||||
redis:
|
||||
host: localhost
|
||||
port: 6379
|
||||
password: "redis_password_789"
|
||||
password: ""
|
||||
db: 0
|
||||
|
||||
33
backend/dao/cache.go
Normal file
33
backend/dao/cache.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
type CacheDAO struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
func NewCacheDAO(client *redis.Client) *CacheDAO {
|
||||
return &CacheDAO{client: client}
|
||||
}
|
||||
|
||||
// SetBlacklist 把 Token 扔进黑名单
|
||||
func (dao *CacheDAO) SetBlacklist(jti string, expiration time.Duration) error {
|
||||
return dao.client.Set(context.Background(), "blacklist:"+jti, "1", expiration).Err()
|
||||
}
|
||||
|
||||
// IsBlacklisted 检查 Token 是否在黑名单中
|
||||
func (dao *CacheDAO) IsBlacklisted(jti string) (bool, error) {
|
||||
result, err := dao.client.Get(context.Background(), "blacklist:"+jti).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return false, nil // 不在黑名单
|
||||
} else if err != nil {
|
||||
return false, err // 其他错误
|
||||
}
|
||||
return result == "1", nil // 在黑名单
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/smartflow/backend/model"
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
module github.com/smartflow/backend
|
||||
module github.com/LoveLosita/smartflow/backend
|
||||
|
||||
go 1.23.4
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.11.0
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/spf13/viper v1.21.0
|
||||
golang.org/x/crypto v0.40.0
|
||||
gorm.io/driver/mysql v1.6.0
|
||||
@@ -15,7 +17,9 @@ require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/bytedance/sonic v1.14.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
|
||||
@@ -4,11 +4,15 @@ github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQ
|
||||
github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA=
|
||||
github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA=
|
||||
github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
@@ -27,6 +31,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4=
|
||||
github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||
@@ -40,6 +46,8 @@ github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
@@ -60,6 +68,12 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
@@ -123,6 +137,10 @@ google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXn
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
22
backend/inits/redis.go
Normal file
22
backend/inits/redis.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package inits
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func InitRedis() *redis.Client {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: viper.GetString("redis.host") + ":" + viper.GetString("redis.port"),
|
||||
Password: viper.GetString("redis.password"),
|
||||
DB: 0,
|
||||
})
|
||||
// 检查连接是否通畅
|
||||
if _, err := rdb.Ping(context.Background()).Result(); err != nil {
|
||||
log.Fatalf("Redis 连接失败: %v", err)
|
||||
}
|
||||
return rdb
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package main
|
||||
|
||||
import "github.com/smartflow/backend/cmd"
|
||||
import "github.com/LoveLosita/smartflow/backend/cmd"
|
||||
|
||||
func main() {
|
||||
cmd.Start()
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
// Package middleware 中间件层
|
||||
// 包含所有HTTP请求中间件
|
||||
package middleware
|
||||
72
backend/middleware/token_handler.go
Normal file
72
backend/middleware/token_handler.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/auth"
|
||||
"github.com/LoveLosita/smartflow/backend/dao"
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
"github.com/LoveLosita/smartflow/backend/respond"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
// JWTTokenAuth 接收 cache 实例,体现依赖注入
|
||||
func JWTTokenAuth(cache *dao.CacheDAO) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 1. 获取 Token (Gin 的 GetHeader 直接返回 string)
|
||||
tokenString := c.GetHeader("Authorization")
|
||||
if tokenString == "" {
|
||||
c.JSON(http.StatusUnauthorized, respond.MissingToken)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 改动:使用 ParseWithClaims 直接解析到你的结构体
|
||||
// 假设你的结构体叫 model.MyCustomClaims
|
||||
token, err := jwt.ParseWithClaims(tokenString, &model.MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return auth.AccessKey, nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
c.JSON(http.StatusUnauthorized, respond.InvalidToken)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 校验 Claims
|
||||
claims, ok := token.Claims.(*model.MyCustomClaims)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, respond.InvalidClaims)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// --- 🛡️ 核心改造:设卡检查 ---
|
||||
if claims.TokenType != "access_token" {
|
||||
c.JSON(http.StatusUnauthorized, respond.WrongTokenType)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 拿着 jti 去 Redis 查一下
|
||||
isBlack, err := cache.IsBlacklisted(claims.Jti)
|
||||
if err != nil {
|
||||
// 如果 Redis 挂了,为了安全通常选择报错,或者降级放行(取决于你的业务)
|
||||
c.JSON(http.StatusInternalServerError, respond.InternalError(errors.New("无法验证令牌状态")))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if isBlack {
|
||||
c.JSON(http.StatusUnauthorized, respond.UserLoggedOut)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 存入上下文
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("claims", claims)
|
||||
c.Next() // 只有所有关卡都过了,才放行
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,15 @@
|
||||
package model
|
||||
|
||||
import "github.com/golang-jwt/jwt/v4"
|
||||
|
||||
type Tokens struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
type MyCustomClaims struct {
|
||||
UserID int `json:"user_id"`
|
||||
TokenType string `json:"token_type"`
|
||||
Jti string `json:"jti"`
|
||||
jwt.RegisteredClaims // 包含 ExpiresAt, IssuedAt 等标准字段
|
||||
}
|
||||
|
||||
@@ -118,4 +118,9 @@ var ( //请求相关的响应
|
||||
Status: "40016",
|
||||
Info: "wrong token type",
|
||||
}
|
||||
|
||||
UserLoggedOut = Response{ //用户已登出
|
||||
Status: "40017",
|
||||
Info: "user logged out",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -5,8 +5,10 @@ package routers
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/api"
|
||||
"github.com/LoveLosita/smartflow/backend/dao"
|
||||
"github.com/LoveLosita/smartflow/backend/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/smartflow/backend/api"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
@@ -25,7 +27,7 @@ func StartEngine(r *gin.Engine) {
|
||||
}
|
||||
}
|
||||
|
||||
func RegisterRouters(handlers *api.ApiHandlers) *gin.Engine {
|
||||
func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine {
|
||||
// 初始化Gin引擎
|
||||
r := gin.Default()
|
||||
// 在这里注册所有的路由和路由组
|
||||
@@ -44,6 +46,7 @@ func RegisterRouters(handlers *api.ApiHandlers) *gin.Engine {
|
||||
userGroup.POST("/register", handlers.UserHandler.UserRegister)
|
||||
userGroup.POST("/login", handlers.UserHandler.UserLogin)
|
||||
userGroup.POST("/refresh-token", handlers.UserHandler.RefreshTokenHandler)
|
||||
userGroup.POST("/logout", middleware.JWTTokenAuth(cache), handlers.UserHandler.UserLogout)
|
||||
}
|
||||
}
|
||||
// 初始化Gin引擎
|
||||
|
||||
@@ -4,25 +4,25 @@ package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/smartflow/backend/auth"
|
||||
"github.com/smartflow/backend/dao"
|
||||
"github.com/smartflow/backend/model"
|
||||
"github.com/smartflow/backend/respond"
|
||||
"github.com/smartflow/backend/utils"
|
||||
"github.com/LoveLosita/smartflow/backend/auth"
|
||||
"github.com/LoveLosita/smartflow/backend/dao"
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
"github.com/LoveLosita/smartflow/backend/respond"
|
||||
"github.com/LoveLosita/smartflow/backend/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserService struct {
|
||||
// 伸出手:准备接住 DAO
|
||||
repo *dao.UserDAO
|
||||
userRepo *dao.UserDAO
|
||||
cacheRepo *dao.CacheDAO
|
||||
}
|
||||
|
||||
// NewUserService:组装 Service 的“工厂”
|
||||
func NewUserService(repo *dao.UserDAO) *UserService {
|
||||
func NewUserService(userRepo *dao.UserDAO, cacheRepo *dao.CacheDAO) *UserService {
|
||||
return &UserService{
|
||||
repo: repo, // 把传进来的 DAO 揣进口袋里
|
||||
userRepo: userRepo, // 把传进来的 DAO 揣进口袋里
|
||||
cacheRepo: cacheRepo,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ func (sv *UserService) UserRegister(user model.UserRegisterRequest) (*model.User
|
||||
return nil, respond.ParamTooLong
|
||||
}
|
||||
//检查用户名是否已存在
|
||||
result, err := sv.repo.IfUsernameExists(user.Username)
|
||||
result, err := sv.userRepo.IfUsernameExists(user.Username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func (sv *UserService) UserRegister(user model.UserRegisterRequest) (*model.User
|
||||
return nil, err
|
||||
}
|
||||
user.Password = hashedPwd //将user的密码字段改为加密后的密码
|
||||
newUser, err := sv.repo.Create(user.Username, user.PhoneNumber, user.Password)
|
||||
newUser, err := sv.userRepo.Create(user.Username, user.PhoneNumber, user.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func (sv *UserService) UserRegister(user model.UserRegisterRequest) (*model.User
|
||||
|
||||
func (sv *UserService) UserLogin(req *model.UserLoginRequest) (*model.Tokens, error) {
|
||||
var tokens model.Tokens
|
||||
hashedPwd, err := sv.repo.GetUserHashedPasswordByName(req.Username) //调用dao层的方法
|
||||
hashedPwd, err := sv.userRepo.GetUserHashedPasswordByName(req.Username) //调用dao层的方法
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, respond.WrongName
|
||||
@@ -72,7 +72,7 @@ func (sv *UserService) UserLogin(req *model.UserLoginRequest) (*model.Tokens, er
|
||||
} else if !result { //密码不匹配
|
||||
return nil, respond.WrongPwd
|
||||
}
|
||||
id, err := sv.repo.GetUserIDByName(req.Username)
|
||||
id, err := sv.userRepo.GetUserIDByName(req.Username)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, respond.WrongName
|
||||
@@ -86,9 +86,9 @@ func (sv *UserService) UserLogin(req *model.UserLoginRequest) (*model.Tokens, er
|
||||
return &tokens, nil
|
||||
}
|
||||
|
||||
func (sv *UserService) RefreshTokenHandler(refreshToken string) (*model.Tokens, error) {
|
||||
/*func (sv *UserService) RefreshTokenHandler(refreshToken string) (*model.Tokens, error) {
|
||||
// 验证刷新令牌
|
||||
token, err := auth.ValidateRefreshToken(refreshToken)
|
||||
token, err := auth.ValidateRefreshToken(refreshToken, sv.cacheRepo)
|
||||
if err != nil || !token.Valid { // 刷新令牌无效
|
||||
return nil, respond.InvalidRefreshToken
|
||||
}
|
||||
@@ -106,4 +106,38 @@ func (sv *UserService) RefreshTokenHandler(refreshToken string) (*model.Tokens,
|
||||
} else {
|
||||
return nil, respond.InvalidClaims
|
||||
}
|
||||
}*/
|
||||
|
||||
func (sv *UserService) RefreshTokenHandler(refreshToken string) (*model.Tokens, error) {
|
||||
// 1. 验证刷新令牌 (这里已经包含了 Redis 黑名单检查)
|
||||
token, err := auth.ValidateRefreshToken(refreshToken, sv.cacheRepo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 改动点:直接断言为你定义的结构体 model.MyCustomClaims
|
||||
if claims, ok := token.Claims.(*model.MyCustomClaims); ok {
|
||||
// 3. 这里的 userID 已经是 int 了,不再需要 (float64) 转换
|
||||
newAccessToken, newRefreshToken, err := auth.GenerateTokens(claims.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 返回新的双 Token
|
||||
return &model.Tokens{
|
||||
AccessToken: newAccessToken,
|
||||
RefreshToken: newRefreshToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, respond.InvalidClaims
|
||||
}
|
||||
|
||||
func (sv *UserService) UserLogout(jti string, expireTime time.Time) error {
|
||||
//1.直接把 jti 扔进黑名单
|
||||
expiration := time.Until(expireTime)
|
||||
err := sv.cacheRepo.SetBlacklist(jti, expiration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user