diff --git a/backend/api/user.go b/backend/api/user.go index f6cc959..52f1c2f 100644 --- a/backend/api/user.go +++ b/backend/api/user.go @@ -6,10 +6,10 @@ import ( "errors" "net/http" + "github.com/LoveLosita/smartflow/backend/model" + "github.com/LoveLosita/smartflow/backend/respond" + "github.com/LoveLosita/smartflow/backend/service" "github.com/gin-gonic/gin" - "github.com/smartflow/backend/model" - "github.com/smartflow/backend/respond" - "github.com/smartflow/backend/service" ) type UserHandler struct { @@ -86,7 +86,7 @@ func (api *UserHandler) RefreshTokenHandler(c *gin.Context) { if err != nil { switch { case errors.Is(err, respond.InvalidRefreshToken), errors.Is(err, respond.InvalidClaims), - errors.Is(err, respond.InvalidTokenSingingMethod): //如果是无效刷新令牌或者无效claims或者无效签名方法 + errors.Is(err, respond.InvalidTokenSingingMethod), errors.Is(err, respond.UserLoggedOut): //如果是无效刷新令牌或者无效claims或者无效签名方法 c.JSON(http.StatusBadRequest, err) return default: @@ -95,3 +95,16 @@ func (api *UserHandler) RefreshTokenHandler(c *gin.Context) { } c.JSON(http.StatusOK, respond.OKWithData(respond.Ok, tokens)) } + +func (api *UserHandler) UserLogout(c *gin.Context) { + //1.从上下文中获取 jti 和 expireTime + claims, _ := c.Get("claims") + cl := claims.(*model.MyCustomClaims) + //2.调用 Service 层的 UserLogout 方法 + err := api.svc.UserLogout(cl.Jti, cl.ExpiresAt.Time) + if err != nil { + c.JSON(http.StatusInternalServerError, respond.InternalError(err)) + return + } + c.JSON(http.StatusOK, respond.Ok) +} diff --git a/backend/auth/jwt_generater.go b/backend/auth/jwt_generater.go deleted file mode 100644 index 369f5a0..0000000 --- a/backend/auth/jwt_generater.go +++ /dev/null @@ -1,68 +0,0 @@ -package auth - -import ( - "time" - - "github.com/golang-jwt/jwt/v4" - "github.com/smartflow/backend/respond" - "github.com/spf13/viper" -) - -var RefreshKey = []byte(viper.GetString("jwt.accessSecret")) // 用于签名和验证刷新Token的密钥 -var AccessKey = []byte(viper.GetString("jwt.refreshSecret")) // 用于签名和验证访问Token的密钥 - -// GenerateTokens 生成访问令牌和刷新令牌 -func GenerateTokens(userID int) (string, string, error) { - // 创建访问令牌 - accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "user_id": userID, // 获取用户ID - "exp": time.Now().Add(15 * time.Minute).Unix(), // 设置访问令牌过期时间为 15 分钟 - "token_type": "access_token", // 令牌类型为访问令牌 - }) - - // 使用密钥签名访问令牌 - accessTokenString, err := accessToken.SignedString(AccessKey) - if err != nil { - return "", "", err - } - - // 创建刷新令牌 - refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "user_id": userID, // 获取用户ID - "exp": time.Now().Add(7 * 24 * time.Hour).Unix(), // 设置刷新令牌过期时间为 7 天 - "token_type": "refresh_token", // 令牌类型为刷新令牌 - }) - - // 使用密钥签名刷新令牌 - refreshTokenString, err := refreshToken.SignedString(RefreshKey) - if err != nil { - return "", "", err - } - - return accessTokenString, refreshTokenString, nil -} - -func ValidateRefreshToken(tokenString string) (*jwt.Token, error) { - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - // 检查签名方法是否为 HMAC - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, respond.InvalidTokenSingingMethod - } - // 返回用于验证的密钥 - return RefreshKey, nil - }) - if err != nil { - return nil, err - } - - // 进一步检查载荷中 token_type 是否正确 - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - return nil, respond.InvalidClaims - } - // 检查 token_type 是否是 refresh_token - if claimType, ok := claims["token_type"].(string); !ok || claimType != "refresh_token" { - return nil, respond.WrongTokenType - } - return token, nil -} diff --git a/backend/auth/jwt_handler.go b/backend/auth/jwt_handler.go new file mode 100644 index 0000000..f74535f --- /dev/null +++ b/backend/auth/jwt_handler.go @@ -0,0 +1,121 @@ +package auth + +import ( + "errors" + "time" + + "github.com/LoveLosita/smartflow/backend/dao" + "github.com/LoveLosita/smartflow/backend/model" + "github.com/LoveLosita/smartflow/backend/respond" + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/spf13/viper" +) + +var RefreshKey = []byte(viper.GetString("jwt.refreshSecret")) // 用于签名和验证刷新Token的密钥 +var AccessKey = []byte(viper.GetString("jwt.accessSecret")) // 用于签名和验证访问Token的密钥 + +// generateJTI 生成唯一的 JWT ID +func generateJTI() string { + return uuid.New().String() +} + +// GenerateTokens 生成访问令牌和刷新令牌 +func GenerateTokens(userID int) (string, string, error) { + // 创建访问令牌 + sid := generateJTI() + accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "user_id": userID, // 获取用户ID + "exp": time.Now().Add(15 * time.Minute).Unix(), // 设置访问令牌过期时间为 15 分钟 + "token_type": "access_token", // 令牌类型为访问令牌 + "jti": sid, // 亲子共用的 JWT ID + }) + + // 使用密钥签名访问令牌 + accessTokenString, err := accessToken.SignedString(AccessKey) + if err != nil { + return "", "", err + } + + // 创建刷新令牌 + refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "user_id": userID, // 获取用户ID + "exp": time.Now().Add(7 * 24 * time.Hour).Unix(), // 设置刷新令牌过期时间为 7 天 + "token_type": "refresh_token", // 令牌类型为刷新令牌 + "jti": sid, // 亲子共用的 JWT ID + }) + + // 使用密钥签名刷新令牌 + refreshTokenString, err := refreshToken.SignedString(RefreshKey) + if err != nil { + return "", "", err + } + + return accessTokenString, refreshTokenString, nil +} + +// ValidateRefreshToken 验证刷新令牌的有效性 +/*func ValidateRefreshToken(tokenString string) (*jwt.Token, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // 检查签名方法是否为 HMAC + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, respond.InvalidTokenSingingMethod + } + // 返回用于验证的密钥 + return RefreshKey, nil + }) + if err != nil { + return nil, err + } + + // 进一步检查载荷中 token_type 是否正确 + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, respond.InvalidClaims + } + // 检查 token_type 是否是 refresh_token + if claimType, ok := claims["token_type"].(string); !ok || claimType != "refresh_token" { + return nil, respond.WrongTokenType + } + return token, nil +} +*/ + +// ValidateRefreshToken 验证刷新令牌的有效性,并增加 Redis 黑名单检查 +func ValidateRefreshToken(tokenString string, cache *dao.CacheDAO) (*jwt.Token, error) { + // 1. 解析 Token 并直接绑定到你的自定义结构体 + token, err := jwt.ParseWithClaims(tokenString, &model.MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, respond.InvalidTokenSingingMethod + } + return RefreshKey, nil + }) + + if err != nil || !token.Valid { + return nil, err + } + + // 2. 断言获取 Claims + claims, ok := token.Claims.(*model.MyCustomClaims) + if !ok { + return nil, respond.InvalidClaims + } + + // 3. 核心“设卡”:检查 token_type 是否是 refresh_token + if claims.TokenType != "refresh_token" { + return nil, respond.WrongTokenType + } + + // 4. --- 🛡️ 终极关卡:检查 Redis 黑名单 --- + // 即使签名没过期,如果 jti 在黑名单里(用户已登出),也视为无效 + isBlack, err := cache.IsBlacklisted(claims.Jti) + if err != nil { + // Redis 出错时的处理逻辑,建议报错以防“漏网之鱼” + return nil, respond.InternalError(errors.New("无法验证令牌状态")) + } + if isBlack { + return nil, respond.UserLoggedOut // 返回你定义的“用户已登出”错误 + } + + return token, nil +} diff --git a/backend/cmd/start.go b/backend/cmd/start.go index bbf3886..7a0d454 100644 --- a/backend/cmd/start.go +++ b/backend/cmd/start.go @@ -4,11 +4,11 @@ import ( "fmt" "log" - "github.com/smartflow/backend/api" - "github.com/smartflow/backend/dao" - "github.com/smartflow/backend/inits" - "github.com/smartflow/backend/routers" - "github.com/smartflow/backend/service" + "github.com/LoveLosita/smartflow/backend/api" + "github.com/LoveLosita/smartflow/backend/dao" + "github.com/LoveLosita/smartflow/backend/inits" + "github.com/LoveLosita/smartflow/backend/routers" + "github.com/LoveLosita/smartflow/backend/service" "github.com/spf13/viper" ) @@ -38,12 +38,15 @@ func Start() { if err != nil { log.Fatalf("Failed to connect to database: %v", err) } + rdb := inits.InitRedis() + userRepo := dao.NewUserDAO(db) - userService := service.NewUserService(userRepo) + cacheRepo := dao.NewCacheDAO(rdb) + userService := service.NewUserService(userRepo, cacheRepo) userApi := api.NewUserHandler(userService) handlers := &api.ApiHandlers{ UserHandler: userApi, } - r := routers.RegisterRouters(handlers) + r := routers.RegisterRouters(handlers, cacheRepo) routers.StartEngine(r) } diff --git a/backend/config.yaml b/backend/config.yaml index 1f64801..f1e6b7b 100644 --- a/backend/config.yaml +++ b/backend/config.yaml @@ -29,5 +29,5 @@ log: redis: host: localhost port: 6379 - password: "redis_password_789" + password: "" db: 0 diff --git a/backend/dao/cache.go b/backend/dao/cache.go new file mode 100644 index 0000000..6ce0f27 --- /dev/null +++ b/backend/dao/cache.go @@ -0,0 +1,33 @@ +package dao + +import ( + "context" + "errors" + "time" + + "github.com/go-redis/redis/v8" +) + +type CacheDAO struct { + client *redis.Client +} + +func NewCacheDAO(client *redis.Client) *CacheDAO { + return &CacheDAO{client: client} +} + +// SetBlacklist 把 Token 扔进黑名单 +func (dao *CacheDAO) SetBlacklist(jti string, expiration time.Duration) error { + return dao.client.Set(context.Background(), "blacklist:"+jti, "1", expiration).Err() +} + +// IsBlacklisted 检查 Token 是否在黑名单中 +func (dao *CacheDAO) IsBlacklisted(jti string) (bool, error) { + result, err := dao.client.Get(context.Background(), "blacklist:"+jti).Result() + if errors.Is(err, redis.Nil) { + return false, nil // 不在黑名单 + } else if err != nil { + return false, err // 其他错误 + } + return result == "1", nil // 在黑名单 +} diff --git a/backend/dao/user.go b/backend/dao/user.go index 5a7c2a6..2bbe96e 100644 --- a/backend/dao/user.go +++ b/backend/dao/user.go @@ -4,7 +4,7 @@ import ( "errors" "time" - "github.com/smartflow/backend/model" + "github.com/LoveLosita/smartflow/backend/model" "gorm.io/gorm" ) diff --git a/backend/go.mod b/backend/go.mod index 4a486c6..8d34135 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -1,10 +1,12 @@ -module github.com/smartflow/backend +module github.com/LoveLosita/smartflow/backend go 1.23.4 require ( github.com/gin-gonic/gin v1.11.0 + github.com/go-redis/redis/v8 v8.11.5 github.com/golang-jwt/jwt/v4 v4.5.2 + github.com/google/uuid v1.6.0 github.com/spf13/viper v1.21.0 golang.org/x/crypto v0.40.0 gorm.io/driver/mysql v1.6.0 @@ -15,7 +17,9 @@ require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/bytedance/sonic v1.14.0 // indirect github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/gin-contrib/sse v1.1.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index b815518..f551107 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -4,11 +4,15 @@ github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQ github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= @@ -27,6 +31,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4= github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= @@ -40,6 +46,8 @@ github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= @@ -60,6 +68,12 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= +github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -123,6 +137,10 @@ google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXn gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/backend/inits/db.go b/backend/inits/mysql.go similarity index 100% rename from backend/inits/db.go rename to backend/inits/mysql.go diff --git a/backend/inits/redis.go b/backend/inits/redis.go new file mode 100644 index 0000000..46a4c99 --- /dev/null +++ b/backend/inits/redis.go @@ -0,0 +1,22 @@ +package inits + +import ( + "context" + "log" + + "github.com/go-redis/redis/v8" + "github.com/spf13/viper" +) + +func InitRedis() *redis.Client { + rdb := redis.NewClient(&redis.Options{ + Addr: viper.GetString("redis.host") + ":" + viper.GetString("redis.port"), + Password: viper.GetString("redis.password"), + DB: 0, + }) + // 检查连接是否通畅 + if _, err := rdb.Ping(context.Background()).Result(); err != nil { + log.Fatalf("Redis 连接失败: %v", err) + } + return rdb +} diff --git a/backend/main.go b/backend/main.go index 93c8628..04ffd29 100644 --- a/backend/main.go +++ b/backend/main.go @@ -1,6 +1,6 @@ package main -import "github.com/smartflow/backend/cmd" +import "github.com/LoveLosita/smartflow/backend/cmd" func main() { cmd.Start() diff --git a/backend/middleware/middleware.go b/backend/middleware/middleware.go deleted file mode 100644 index e64f4c2..0000000 --- a/backend/middleware/middleware.go +++ /dev/null @@ -1,3 +0,0 @@ -// Package middleware 中间件层 -// 包含所有HTTP请求中间件 -package middleware diff --git a/backend/middleware/token_handler.go b/backend/middleware/token_handler.go new file mode 100644 index 0000000..be89b70 --- /dev/null +++ b/backend/middleware/token_handler.go @@ -0,0 +1,72 @@ +package middleware + +import ( + "errors" + "net/http" + + "github.com/LoveLosita/smartflow/backend/auth" + "github.com/LoveLosita/smartflow/backend/dao" + "github.com/LoveLosita/smartflow/backend/model" + "github.com/LoveLosita/smartflow/backend/respond" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v4" +) + +// JWTTokenAuth 接收 cache 实例,体现依赖注入 +func JWTTokenAuth(cache *dao.CacheDAO) gin.HandlerFunc { + return func(c *gin.Context) { + // 1. 获取 Token (Gin 的 GetHeader 直接返回 string) + tokenString := c.GetHeader("Authorization") + if tokenString == "" { + c.JSON(http.StatusUnauthorized, respond.MissingToken) + c.Abort() + return + } + + // 2. 改动:使用 ParseWithClaims 直接解析到你的结构体 + // 假设你的结构体叫 model.MyCustomClaims + token, err := jwt.ParseWithClaims(tokenString, &model.MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { + return auth.AccessKey, nil + }) + + if err != nil || !token.Valid { + c.JSON(http.StatusUnauthorized, respond.InvalidToken) + c.Abort() + return + } + + // 3. 校验 Claims + claims, ok := token.Claims.(*model.MyCustomClaims) + if !ok { + c.JSON(http.StatusUnauthorized, respond.InvalidClaims) + c.Abort() + return + } + // --- 🛡️ 核心改造:设卡检查 --- + if claims.TokenType != "access_token" { + c.JSON(http.StatusUnauthorized, respond.WrongTokenType) + c.Abort() + return + } + + // 拿着 jti 去 Redis 查一下 + isBlack, err := cache.IsBlacklisted(claims.Jti) + if err != nil { + // 如果 Redis 挂了,为了安全通常选择报错,或者降级放行(取决于你的业务) + c.JSON(http.StatusInternalServerError, respond.InternalError(errors.New("无法验证令牌状态"))) + c.Abort() + return + } + if isBlack { + c.JSON(http.StatusUnauthorized, respond.UserLoggedOut) + c.Abort() + return + } + + // 4. 存入上下文 + c.Set("user_id", claims.UserID) + c.Set("claims", claims) + c.Next() // 只有所有关卡都过了,才放行 + } +} diff --git a/backend/model/auth.go b/backend/model/auth.go index ece0de0..5ada3bf 100644 --- a/backend/model/auth.go +++ b/backend/model/auth.go @@ -1,6 +1,15 @@ package model +import "github.com/golang-jwt/jwt/v4" + type Tokens struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` } + +type MyCustomClaims struct { + UserID int `json:"user_id"` + TokenType string `json:"token_type"` + Jti string `json:"jti"` + jwt.RegisteredClaims // 包含 ExpiresAt, IssuedAt 等标准字段 +} diff --git a/backend/respond/respond.go b/backend/respond/respond.go index cd638a8..00d556c 100644 --- a/backend/respond/respond.go +++ b/backend/respond/respond.go @@ -118,4 +118,9 @@ var ( //请求相关的响应 Status: "40016", Info: "wrong token type", } + + UserLoggedOut = Response{ //用户已登出 + Status: "40017", + Info: "user logged out", + } ) diff --git a/backend/routers/routers.go b/backend/routers/routers.go index 818b1ed..99a7ae9 100644 --- a/backend/routers/routers.go +++ b/backend/routers/routers.go @@ -5,8 +5,10 @@ package routers import ( "log" + "github.com/LoveLosita/smartflow/backend/api" + "github.com/LoveLosita/smartflow/backend/dao" + "github.com/LoveLosita/smartflow/backend/middleware" "github.com/gin-gonic/gin" - "github.com/smartflow/backend/api" "github.com/spf13/viper" ) @@ -25,7 +27,7 @@ func StartEngine(r *gin.Engine) { } } -func RegisterRouters(handlers *api.ApiHandlers) *gin.Engine { +func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine { // 初始化Gin引擎 r := gin.Default() // 在这里注册所有的路由和路由组 @@ -44,6 +46,7 @@ func RegisterRouters(handlers *api.ApiHandlers) *gin.Engine { userGroup.POST("/register", handlers.UserHandler.UserRegister) userGroup.POST("/login", handlers.UserHandler.UserLogin) userGroup.POST("/refresh-token", handlers.UserHandler.RefreshTokenHandler) + userGroup.POST("/logout", middleware.JWTTokenAuth(cache), handlers.UserHandler.UserLogout) } } // 初始化Gin引擎 diff --git a/backend/service/user.go b/backend/service/user.go index 2aa3736..858e4c4 100644 --- a/backend/service/user.go +++ b/backend/service/user.go @@ -4,25 +4,25 @@ package service import ( "errors" + "time" - "github.com/golang-jwt/jwt/v4" - "github.com/smartflow/backend/auth" - "github.com/smartflow/backend/dao" - "github.com/smartflow/backend/model" - "github.com/smartflow/backend/respond" - "github.com/smartflow/backend/utils" + "github.com/LoveLosita/smartflow/backend/auth" + "github.com/LoveLosita/smartflow/backend/dao" + "github.com/LoveLosita/smartflow/backend/model" + "github.com/LoveLosita/smartflow/backend/respond" + "github.com/LoveLosita/smartflow/backend/utils" "gorm.io/gorm" ) type UserService struct { - // 伸出手:准备接住 DAO - repo *dao.UserDAO + userRepo *dao.UserDAO + cacheRepo *dao.CacheDAO } -// NewUserService:组装 Service 的“工厂” -func NewUserService(repo *dao.UserDAO) *UserService { +func NewUserService(userRepo *dao.UserDAO, cacheRepo *dao.CacheDAO) *UserService { return &UserService{ - repo: repo, // 把传进来的 DAO 揣进口袋里 + userRepo: userRepo, // 把传进来的 DAO 揣进口袋里 + cacheRepo: cacheRepo, } } @@ -37,7 +37,7 @@ func (sv *UserService) UserRegister(user model.UserRegisterRequest) (*model.User return nil, respond.ParamTooLong } //检查用户名是否已存在 - result, err := sv.repo.IfUsernameExists(user.Username) + result, err := sv.userRepo.IfUsernameExists(user.Username) if err != nil { return nil, err } @@ -49,7 +49,7 @@ func (sv *UserService) UserRegister(user model.UserRegisterRequest) (*model.User return nil, err } user.Password = hashedPwd //将user的密码字段改为加密后的密码 - newUser, err := sv.repo.Create(user.Username, user.PhoneNumber, user.Password) + newUser, err := sv.userRepo.Create(user.Username, user.PhoneNumber, user.Password) if err != nil { return nil, err } @@ -59,7 +59,7 @@ func (sv *UserService) UserRegister(user model.UserRegisterRequest) (*model.User func (sv *UserService) UserLogin(req *model.UserLoginRequest) (*model.Tokens, error) { var tokens model.Tokens - hashedPwd, err := sv.repo.GetUserHashedPasswordByName(req.Username) //调用dao层的方法 + hashedPwd, err := sv.userRepo.GetUserHashedPasswordByName(req.Username) //调用dao层的方法 if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, respond.WrongName @@ -72,7 +72,7 @@ func (sv *UserService) UserLogin(req *model.UserLoginRequest) (*model.Tokens, er } else if !result { //密码不匹配 return nil, respond.WrongPwd } - id, err := sv.repo.GetUserIDByName(req.Username) + id, err := sv.userRepo.GetUserIDByName(req.Username) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, respond.WrongName @@ -86,9 +86,9 @@ func (sv *UserService) UserLogin(req *model.UserLoginRequest) (*model.Tokens, er return &tokens, nil } -func (sv *UserService) RefreshTokenHandler(refreshToken string) (*model.Tokens, error) { +/*func (sv *UserService) RefreshTokenHandler(refreshToken string) (*model.Tokens, error) { // 验证刷新令牌 - token, err := auth.ValidateRefreshToken(refreshToken) + token, err := auth.ValidateRefreshToken(refreshToken, sv.cacheRepo) if err != nil || !token.Valid { // 刷新令牌无效 return nil, respond.InvalidRefreshToken } @@ -106,4 +106,38 @@ func (sv *UserService) RefreshTokenHandler(refreshToken string) (*model.Tokens, } else { return nil, respond.InvalidClaims } +}*/ + +func (sv *UserService) RefreshTokenHandler(refreshToken string) (*model.Tokens, error) { + // 1. 验证刷新令牌 (这里已经包含了 Redis 黑名单检查) + token, err := auth.ValidateRefreshToken(refreshToken, sv.cacheRepo) + if err != nil { + return nil, err + } + + // 2. 改动点:直接断言为你定义的结构体 model.MyCustomClaims + if claims, ok := token.Claims.(*model.MyCustomClaims); ok { + // 3. 这里的 userID 已经是 int 了,不再需要 (float64) 转换 + newAccessToken, newRefreshToken, err := auth.GenerateTokens(claims.UserID) + if err != nil { + return nil, err + } + // 返回新的双 Token + return &model.Tokens{ + AccessToken: newAccessToken, + RefreshToken: newRefreshToken, + }, nil + } + + return nil, respond.InvalidClaims +} + +func (sv *UserService) UserLogout(jti string, expireTime time.Time) error { + //1.直接把 jti 扔进黑名单 + expiration := time.Until(expireTime) + err := sv.cacheRepo.SetBlacklist(jti, expiration) + if err != nil { + return err + } + return nil }