Version: 0.7.4.dev.260323

 feat(schedulerefine): 新增 refine 子路由,优先执行复合操作,失败后降级至禁复合 ReAct 兜底

ReAct 升级
- ♻️ 将原有链路升级为真正的 ReAct 执行模式,进一步增强整体调度过程的可靠性

Refine 子路由
- 🧭 在 refine 主链路中新增 `route` 节点,整体流程调整为 `contract -> plan -> slice -> route -> react -> hard_check -> summary`
-  当 `route` 命中全局复合目标时,优先尝试一次调用 `SpreadEven` / `MinContextSwitch`,失败后最多重试 2 次
- 🔀 `route` 成功后直接跳过 `ReAct`;若执行失败,则自动切换至 `fallback` 模式
- 🛡️ 在 `fallback` 模式下增加后端硬约束:禁用 `SpreadEven` / `MinContextSwitch` / `BatchMove`,仅允许使用 `Move` / `Swap` 逐任务处理
- 🧠 在 `ReAct` 的 prompt 与上下文中新增 `COMPOSITE_TOOLS_ALLOWED`,显式告知当前是否允许使用复合工具
- 🧩 扩展状态字段以承载路由与降级状态:`CompositeRetryMax` / `DisableCompositeTools` / `CompositeRouteTried` / `CompositeRouteSucceeded`
- 👀 增加 `route` 相关阶段日志,便于排查命中、重试、收口与降级原因

修复
- 🐛 修复 JWT Token 过期时间未按 `config.yaml` 配置生效的问题

备注
- 🚧 当前 ReAct 逐步微排链路已趋于稳定,但两个复合操作函数仍未恢复可用,后续将继续排查
This commit is contained in:
Losita
2026-03-23 23:14:19 +08:00
parent 525a8b32cb
commit e6941f98f2
13 changed files with 4924 additions and 1080 deletions

View File

@@ -2,6 +2,9 @@ package auth
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/dao"
@@ -12,41 +15,204 @@ import (
"github.com/spf13/viper"
)
var RefreshKey = []byte(viper.GetString("jwt.refreshSecret")) // 用于签名和验证刷新Token的密钥
var AccessKey = []byte(viper.GetString("jwt.accessSecret")) // 用于签名和验证访问Token的密钥
const (
accessSecretConfigKey = "jwt.accessSecret"
refreshSecretConfigKey = "jwt.refreshSecret"
accessExpireConfigKey = "jwt.accessTokenExpire"
refreshExpireConfigKey = "jwt.refreshTokenExpire"
// generateJTI 生成唯一的 JWT ID
defaultAccessTokenExpire = 15 * time.Minute
defaultRefreshTokenExpire = 7 * 24 * time.Hour
)
type jwtRuntimeConfig struct {
AccessKey []byte
RefreshKey []byte
AccessExpire time.Duration
RefreshExpire time.Duration
}
// AccessSigningKey 负责提供访问令牌签名/验签密钥。
// 职责边界:
// 1. 负责从运行时配置读取 accessSecret 并做空值校验。
// 2. 不负责 token 解析、业务鉴权与错误码映射。
// 3. 返回值语义:[]byte 为签名密钥error 非空表示配置不可用。
func AccessSigningKey() ([]byte, error) {
cfg, err := loadJWTConfig()
if err != nil {
return nil, err
}
return cfg.AccessKey, nil
}
// generateJTI 生成唯一的 JWT ID。
func generateJTI() string {
return uuid.New().String()
}
// GenerateTokens 生成访问令牌和刷新令牌
func GenerateTokens(userID int) (string, string, error) {
// 创建访问令牌
sid := generateJTI()
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": userID, // 获取用户ID
"exp": time.Now().Add(15 * time.Minute).Unix(), // 设置访问令牌过期时间为 15 分钟
"token_type": "access_token", // 令牌类型为访问令牌
"jti": sid, // 亲子共用的 JWT ID
})
// loadJWTConfig 负责聚合 JWT 运行时配置。
// 职责边界:
// 1. 负责读取密钥与过期时间配置,并转换为可直接使用的结构。
// 2. 不负责持久化配置,也不负责降级到“不安全默认密钥”。
// 3. 返回值语义cfg 可直接用于签发/校验error 非空表示配置不合法。
func loadJWTConfig() (*jwtRuntimeConfig, error) {
accessKey, err := readJWTSecret(accessSecretConfigKey)
if err != nil {
return nil, err
}
refreshKey, err := readJWTSecret(refreshSecretConfigKey)
if err != nil {
return nil, err
}
// 使用密钥签名访问令牌
accessTokenString, err := accessToken.SignedString(AccessKey)
accessExpire, err := readJWTExpireDuration(accessExpireConfigKey, defaultAccessTokenExpire)
if err != nil {
return nil, err
}
refreshExpire, err := readJWTExpireDuration(refreshExpireConfigKey, defaultRefreshTokenExpire)
if err != nil {
return nil, err
}
return &jwtRuntimeConfig{
AccessKey: accessKey,
RefreshKey: refreshKey,
AccessExpire: accessExpire,
RefreshExpire: refreshExpire,
}, nil
}
// readJWTSecret 负责读取并校验 JWT 密钥配置。
// 职责边界:
// 1. 负责“读配置 + 去空白 + 空值校验”。
// 2. 不负责任何默认值回退,避免静默使用弱配置。
// 3. 返回值语义:[]byte 为密钥error 非空表示该配置项不可用。
func readJWTSecret(configKey string) ([]byte, error) {
secret := strings.TrimSpace(viper.GetString(configKey))
if secret == "" {
return nil, fmt.Errorf("jwt 配置缺失: %s", configKey)
}
return []byte(secret), nil
}
// readJWTExpireDuration 负责读取并解析 JWT 过期时间配置。
// 职责边界:
// 1. 负责把字符串配置解析成 time.Duration并保证结果大于 0。
// 2. 不负责签发 token仅提供“可计算”的过期时长。
// 3. 返回值语义duration 为最终时长error 非空表示格式非法。
func readJWTExpireDuration(configKey string, fallback time.Duration) (time.Duration, error) {
raw := strings.TrimSpace(viper.GetString(configKey))
if raw == "" {
return fallback, nil
}
d, err := parseFlexibleDuration(raw)
if err != nil {
return 0, fmt.Errorf("jwt 配置项 %s 非法: %w", configKey, err)
}
if d <= 0 {
return 0, fmt.Errorf("jwt 配置项 %s 必须大于 0", configKey)
}
return d, nil
}
// parseFlexibleDuration 负责解析项目内常见时长格式。
// 职责边界:
// 1. 负责兼容 Go 标准格式(如 15m、168h与项目常见格式如 15min、7d
// 2. 不负责读取配置键名;仅解析输入字符串。
// 3. 输入输出语义raw 为原始时长文本;返回解析后的正时长或错误。
func parseFlexibleDuration(raw string) (time.Duration, error) {
normalized := strings.ToLower(strings.TrimSpace(raw))
if normalized == "" {
return 0, errors.New("时长不能为空")
}
// 1. 先走 Go 原生解析,优先兼容标准写法(如 15m/168h
if d, err := time.ParseDuration(normalized); err == nil {
return d, nil
}
// 2. 原生解析失败后,兼容项目常见简写(如 15min、7d
type unitDef struct {
Suffix string
Multiplier time.Duration
}
unitDefs := []unitDef{
{Suffix: "minutes", Multiplier: time.Minute},
{Suffix: "minute", Multiplier: time.Minute},
{Suffix: "mins", Multiplier: time.Minute},
{Suffix: "min", Multiplier: time.Minute},
{Suffix: "days", Multiplier: 24 * time.Hour},
{Suffix: "day", Multiplier: 24 * time.Hour},
{Suffix: "d", Multiplier: 24 * time.Hour},
{Suffix: "hours", Multiplier: time.Hour},
{Suffix: "hour", Multiplier: time.Hour},
{Suffix: "h", Multiplier: time.Hour},
{Suffix: "seconds", Multiplier: time.Second},
{Suffix: "second", Multiplier: time.Second},
{Suffix: "secs", Multiplier: time.Second},
{Suffix: "sec", Multiplier: time.Second},
{Suffix: "m", Multiplier: time.Minute},
{Suffix: "s", Multiplier: time.Second},
}
for _, unit := range unitDefs {
if !strings.HasSuffix(normalized, unit.Suffix) {
continue
}
numberPart := strings.TrimSpace(strings.TrimSuffix(normalized, unit.Suffix))
value, err := strconv.Atoi(numberPart)
if err != nil {
return 0, fmt.Errorf("时长数值非法: %q", numberPart)
}
if value <= 0 {
return 0, fmt.Errorf("时长数值必须大于 0: %d", value)
}
return time.Duration(value) * unit.Multiplier, nil
}
return 0, fmt.Errorf("不支持的时长格式: %s", raw)
}
// GenerateTokens 负责按配置签发访问令牌与刷新令牌。
// 职责边界:
// 1. 负责根据配置生成 exp并签发 access/refresh 双 token。
// 2. 不负责登录鉴权(用户名/密码验证在 service 层处理)。
// 3. 返回值语义:第一个为 access token第二个为 refresh tokenerror 非空表示签发失败。
func GenerateTokens(userID int) (string, string, error) {
cfg, err := loadJWTConfig()
if err != nil {
return "", "", err
}
// 创建刷新令牌
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": userID, // 获取用户ID
"exp": time.Now().Add(7 * 24 * time.Hour).Unix(), // 设置刷新令牌过期时间为 7 天
"token_type": "refresh_token", // 令牌类型为刷新令牌
"jti": sid, // 亲子共用的 JWT ID
})
now := time.Now()
sid := generateJTI()
// 使用密钥签名刷新令牌
refreshTokenString, err := refreshToken.SignedString(RefreshKey)
// 1. 先签 access token短期有效面向接口访问。
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, model.MyCustomClaims{
UserID: userID,
TokenType: "access_token",
Jti: sid,
RegisteredClaims: jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(cfg.AccessExpire)),
},
})
accessTokenString, err := accessToken.SignedString(cfg.AccessKey)
if err != nil {
return "", "", err
}
// 2. 再签 refresh token长期有效仅用于换发新 token。
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, model.MyCustomClaims{
UserID: userID,
TokenType: "refresh_token",
Jti: sid,
RegisteredClaims: jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(cfg.RefreshExpire)),
},
})
refreshTokenString, err := refreshToken.SignedString(cfg.RefreshKey)
if err != nil {
return "", "", err
}
@@ -54,71 +220,45 @@ func GenerateTokens(userID int) (string, string, error) {
return accessTokenString, refreshTokenString, nil
}
// ValidateRefreshToken 验证刷新令牌的有效性
/*func ValidateRefreshToken(tokenString string) (*jwt.Token, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// 检查签名方法是否为 HMAC
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, respond.InvalidTokenSingingMethod
}
// 返回用于验证的密钥
return RefreshKey, nil
})
// ValidateRefreshToken 验证刷新令牌的有效性,并增加 Redis 黑名单检查。
func ValidateRefreshToken(tokenString string, cache *dao.CacheDAO) (*jwt.Token, error) {
cfg, err := loadJWTConfig()
if err != nil {
return nil, err
}
// 进一步检查载荷中 token_type 是否正确
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, respond.InvalidClaims
}
// 检查 token_type 是否是 refresh_token
if claimType, ok := claims["token_type"].(string); !ok || claimType != "refresh_token" {
return nil, respond.WrongTokenType
}
return token, nil
}
*/
// ValidateRefreshToken 验证刷新令牌的有效性,并增加 Redis 黑名单检查
func ValidateRefreshToken(tokenString string, cache *dao.CacheDAO) (*jwt.Token, error) {
// 1. 解析 Token 并直接绑定到你的自定义结构体
// 1. 解析 refresh token并强制校验签名算法与密钥来源。
token, err := jwt.ParseWithClaims(tokenString, &model.MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, respond.InvalidTokenSingingMethod
}
return RefreshKey, nil
return cfg.RefreshKey, nil
})
if err != nil {
return nil, respond.InvalidRefreshToken
}
if !token.Valid {
return nil, respond.InvalidRefreshToken
}
// 2. 断言获取 Claims
// 2. 断言 claims 类型,后续业务字段都从结构体读取。
claims, ok := token.Claims.(*model.MyCustomClaims)
if !ok {
return nil, respond.InvalidClaims
}
// 3. 核心“设卡”:检查 token_type 是否是 refresh_token
// 3. 校验 token_type,防止把 access token 当 refresh token 用。
if claims.TokenType != "refresh_token" {
return nil, respond.WrongTokenType
}
// 4. --- 🛡️ 终极关卡:检查 Redis 黑名单 ---
// 即使签名没过期,如果 jti 在黑名单里(用户已登出),也视为无效
// 4. 黑名单校验:签名合法也要确认 jti 未被主动注销。
isBlack, err := cache.IsBlacklisted(claims.Jti)
if err != nil {
// Redis 出错时的处理逻辑,建议报错以防“漏网之鱼”
return nil, errors.New("无法验证令牌状态")
}
if isBlack {
return nil, respond.UserLoggedOut // 返回你定义的“用户已登出”错误
return nil, respond.UserLoggedOut
}
return token, nil

View File

@@ -0,0 +1,128 @@
package auth
import (
"testing"
"time"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/golang-jwt/jwt/v4"
"github.com/spf13/viper"
)
func TestParseFlexibleDuration(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
raw string
want time.Duration
wantFail bool
}{
{name: "标准格式", raw: "15m", want: 15 * time.Minute},
{name: "项目分钟简写", raw: "15min", want: 15 * time.Minute},
{name: "项目天简写", raw: "7d", want: 7 * 24 * time.Hour},
{name: "非法格式", raw: "abc", wantFail: true},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := parseFlexibleDuration(tc.raw)
if tc.wantFail {
if err == nil {
t.Fatalf("期望解析失败,但得到成功: %s", tc.raw)
}
return
}
if err != nil {
t.Fatalf("解析失败: %v", err)
}
if got != tc.want {
t.Fatalf("解析结果不符合预期got=%v want=%v", got, tc.want)
}
})
}
}
func TestGenerateTokens_UseConfigExpire(t *testing.T) {
const (
accessSecret = "unit-test-access-secret"
refreshSecret = "unit-test-refresh-secret"
accessExpire = "2h"
refreshExpire = "3d"
)
originAccessSecret := viper.GetString(accessSecretConfigKey)
originRefreshSecret := viper.GetString(refreshSecretConfigKey)
originAccessExpire := viper.GetString(accessExpireConfigKey)
originRefreshExpire := viper.GetString(refreshExpireConfigKey)
viper.Set(accessSecretConfigKey, accessSecret)
viper.Set(refreshSecretConfigKey, refreshSecret)
viper.Set(accessExpireConfigKey, accessExpire)
viper.Set(refreshExpireConfigKey, refreshExpire)
t.Cleanup(func() {
viper.Set(accessSecretConfigKey, originAccessSecret)
viper.Set(refreshSecretConfigKey, originRefreshSecret)
viper.Set(accessExpireConfigKey, originAccessExpire)
viper.Set(refreshExpireConfigKey, originRefreshExpire)
})
start := time.Now()
accessTokenString, refreshTokenString, err := GenerateTokens(9527)
if err != nil {
t.Fatalf("签发 token 失败: %v", err)
}
accessClaims := parseTokenClaimsForTest(t, accessTokenString, []byte(accessSecret))
refreshClaims := parseTokenClaimsForTest(t, refreshTokenString, []byte(refreshSecret))
if accessClaims.TokenType != "access_token" {
t.Fatalf("access token_type 不符合预期: %s", accessClaims.TokenType)
}
if refreshClaims.TokenType != "refresh_token" {
t.Fatalf("refresh token_type 不符合预期: %s", refreshClaims.TokenType)
}
if accessClaims.Jti == "" || refreshClaims.Jti == "" {
t.Fatalf("jti 不能为空")
}
if accessClaims.Jti != refreshClaims.Jti {
t.Fatalf("access/refresh 应共享同一个 jti")
}
assertExpireNear(t, accessClaims.ExpiresAt.Time, start.Add(2*time.Hour), 3*time.Second)
assertExpireNear(t, refreshClaims.ExpiresAt.Time, start.Add(3*24*time.Hour), 3*time.Second)
}
func parseTokenClaimsForTest(t *testing.T, tokenString string, key []byte) *model.MyCustomClaims {
t.Helper()
token, err := jwt.ParseWithClaims(tokenString, &model.MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
return key, nil
})
if err != nil {
t.Fatalf("解析 token 失败: %v", err)
}
if !token.Valid {
t.Fatalf("token 无效")
}
claims, ok := token.Claims.(*model.MyCustomClaims)
if !ok {
t.Fatalf("claims 类型断言失败")
}
return claims
}
func assertExpireNear(t *testing.T, got time.Time, want time.Time, tolerance time.Duration) {
t.Helper()
delta := got.Sub(want)
if delta < 0 {
delta = -delta
}
if delta > tolerance {
t.Fatalf("exp 偏差超出容忍范围got=%s want=%s delta=%s tolerance=%s", got.Format(time.RFC3339), want.Format(time.RFC3339), delta, tolerance)
}
}