Version: 0.7.4.dev.260323
✨ feat(schedulerefine): 新增 refine 子路由,优先执行复合操作,失败后降级至禁复合 ReAct 兜底 ReAct 升级 - ♻️ 将原有链路升级为真正的 ReAct 执行模式,进一步增强整体调度过程的可靠性 Refine 子路由 - 🧭 在 refine 主链路中新增 `route` 节点,整体流程调整为 `contract -> plan -> slice -> route -> react -> hard_check -> summary` - ⚡ 当 `route` 命中全局复合目标时,优先尝试一次调用 `SpreadEven` / `MinContextSwitch`,失败后最多重试 2 次 - 🔀 `route` 成功后直接跳过 `ReAct`;若执行失败,则自动切换至 `fallback` 模式 - 🛡️ 在 `fallback` 模式下增加后端硬约束:禁用 `SpreadEven` / `MinContextSwitch` / `BatchMove`,仅允许使用 `Move` / `Swap` 逐任务处理 - 🧠 在 `ReAct` 的 prompt 与上下文中新增 `COMPOSITE_TOOLS_ALLOWED`,显式告知当前是否允许使用复合工具 - 🧩 扩展状态字段以承载路由与降级状态:`CompositeRetryMax` / `DisableCompositeTools` / `CompositeRouteTried` / `CompositeRouteSucceeded` - 👀 增加 `route` 相关阶段日志,便于排查命中、重试、收口与降级原因 修复 - 🐛 修复 JWT Token 过期时间未按 `config.yaml` 配置生效的问题 备注 - 🚧 当前 ReAct 逐步微排链路已趋于稳定,但两个复合操作函数仍未恢复可用,后续将继续排查
This commit is contained in:
117
backend/agent/schedulerefine/composite_tools_test.go
Normal file
117
backend/agent/schedulerefine/composite_tools_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package schedulerefine
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
)
|
||||
|
||||
func TestRefineToolSpreadEvenSuccess(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 1, Name: "任务1", Type: "task", Status: "suggested", Week: 16, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, ContextTag: "A"},
|
||||
{TaskItemID: 2, Name: "任务2", Type: "task", Status: "suggested", Week: 16, DayOfWeek: 1, SectionFrom: 3, SectionTo: 4, ContextTag: "B"},
|
||||
{TaskItemID: 99, Name: "课程", Type: "course", Status: "existing", Week: 12, DayOfWeek: 1, SectionFrom: 5, SectionTo: 6, BlockForSuggested: true},
|
||||
}
|
||||
params := map[string]any{
|
||||
"task_item_ids": []any{1.0, 2.0},
|
||||
"week": 12,
|
||||
"day_of_week": []any{1.0, 2.0, 3.0},
|
||||
"allow_embed": false,
|
||||
}
|
||||
policy := refineToolPolicy{OriginOrderMap: map[int]int{1: 1, 2: 2}}
|
||||
|
||||
nextEntries, result := refineToolSpreadEven(entries, params, planningWindow{Enabled: false}, policy)
|
||||
if !result.Success {
|
||||
t.Fatalf("SpreadEven 执行失败: %s", result.Result)
|
||||
}
|
||||
if result.Tool != "SpreadEven" {
|
||||
t.Fatalf("工具名错误,期望 SpreadEven,实际=%s", result.Tool)
|
||||
}
|
||||
|
||||
idx1 := findSuggestedByID(nextEntries, 1)
|
||||
idx2 := findSuggestedByID(nextEntries, 2)
|
||||
if idx1 < 0 || idx2 < 0 {
|
||||
t.Fatalf("移动后未找到目标任务: idx1=%d idx2=%d", idx1, idx2)
|
||||
}
|
||||
task1 := nextEntries[idx1]
|
||||
task2 := nextEntries[idx2]
|
||||
if task1.Week != 12 || task2.Week != 12 {
|
||||
t.Fatalf("期望任务被移动到 W12,实际 task1=%d task2=%d", task1.Week, task2.Week)
|
||||
}
|
||||
if task1.DayOfWeek < 1 || task1.DayOfWeek > 3 || task2.DayOfWeek < 1 || task2.DayOfWeek > 3 {
|
||||
t.Fatalf("期望任务被移动到周一到周三,实际 task1=%d task2=%d", task1.DayOfWeek, task2.DayOfWeek)
|
||||
}
|
||||
if task1.DayOfWeek == task2.DayOfWeek && sectionsOverlap(task1.SectionFrom, task1.SectionTo, task2.SectionFrom, task2.SectionTo) {
|
||||
t.Fatalf("复合工具不应产出重叠坑位: task1=%+v task2=%+v", task1, task2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefineToolMinContextSwitchGroupsContext(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 11, Name: "任务11", Type: "task", Status: "suggested", Week: 16, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, ContextTag: "数学"},
|
||||
{TaskItemID: 12, Name: "任务12", Type: "task", Status: "suggested", Week: 16, DayOfWeek: 1, SectionFrom: 3, SectionTo: 4, ContextTag: "算法"},
|
||||
{TaskItemID: 13, Name: "任务13", Type: "task", Status: "suggested", Week: 16, DayOfWeek: 1, SectionFrom: 5, SectionTo: 6, ContextTag: "数学"},
|
||||
{TaskItemID: 99, Name: "课程", Type: "course", Status: "existing", Week: 12, DayOfWeek: 1, SectionFrom: 11, SectionTo: 12, BlockForSuggested: true},
|
||||
}
|
||||
params := map[string]any{
|
||||
"task_item_ids": []any{11.0, 12.0, 13.0},
|
||||
"week": 12,
|
||||
"day_of_week": []any{1.0},
|
||||
}
|
||||
policy := refineToolPolicy{OriginOrderMap: map[int]int{11: 1, 12: 2, 13: 3}}
|
||||
|
||||
nextEntries, result := refineToolMinContextSwitch(entries, params, planningWindow{Enabled: false}, policy)
|
||||
if !result.Success {
|
||||
t.Fatalf("MinContextSwitch 执行失败: %s", result.Result)
|
||||
}
|
||||
if result.Tool != "MinContextSwitch" {
|
||||
t.Fatalf("工具名错误,期望 MinContextSwitch,实际=%s", result.Tool)
|
||||
}
|
||||
|
||||
selected := make([]model.HybridScheduleEntry, 0, 3)
|
||||
for _, id := range []int{11, 12, 13} {
|
||||
idx := findSuggestedByID(nextEntries, id)
|
||||
if idx < 0 {
|
||||
t.Fatalf("未找到任务 id=%d", id)
|
||||
}
|
||||
selected = append(selected, nextEntries[idx])
|
||||
}
|
||||
sort.SliceStable(selected, func(i, j int) bool {
|
||||
if selected[i].Week != selected[j].Week {
|
||||
return selected[i].Week < selected[j].Week
|
||||
}
|
||||
if selected[i].DayOfWeek != selected[j].DayOfWeek {
|
||||
return selected[i].DayOfWeek < selected[j].DayOfWeek
|
||||
}
|
||||
return selected[i].SectionFrom < selected[j].SectionFrom
|
||||
})
|
||||
|
||||
switches := 0
|
||||
for i := 1; i < len(selected); i++ {
|
||||
if selected[i].ContextTag != selected[i-1].ContextTag {
|
||||
switches++
|
||||
}
|
||||
}
|
||||
if switches > 1 {
|
||||
t.Fatalf("期望最少上下文切换(<=1),实际 switches=%d, tasks=%+v", switches, selected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListTaskIDsFromToolCallComposite(t *testing.T) {
|
||||
call := reactToolCall{
|
||||
Tool: "SpreadEven",
|
||||
Params: map[string]any{
|
||||
"task_item_ids": []any{1.0, 2.0, 2.0},
|
||||
"task_item_id": 3,
|
||||
},
|
||||
}
|
||||
ids := listTaskIDsFromToolCall(call)
|
||||
if len(ids) != 3 {
|
||||
t.Fatalf("期望提取 3 个去重 ID,实际=%v", ids)
|
||||
}
|
||||
sort.Ints(ids)
|
||||
if ids[0] != 1 || ids[1] != 2 || ids[2] != 3 {
|
||||
t.Fatalf("提取结果错误,实际=%v", ids)
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,9 @@ import (
|
||||
|
||||
const (
|
||||
graphNodeContract = "schedule_refine_contract"
|
||||
graphNodePlan = "schedule_refine_plan"
|
||||
graphNodeSlice = "schedule_refine_slice"
|
||||
graphNodeRoute = "schedule_refine_route"
|
||||
graphNodeReact = "schedule_refine_react"
|
||||
graphNodeHardCheck = "schedule_refine_hard_check"
|
||||
graphNodeSummary = "schedule_refine_summary"
|
||||
@@ -30,7 +33,7 @@ type ScheduleRefineGraphRunInput struct {
|
||||
// RunScheduleRefineGraph 执行“连续微调”独立图链路。
|
||||
//
|
||||
// 链路顺序:
|
||||
// START -> contract -> react -> hard_check -> summary -> END
|
||||
// START -> contract -> plan -> slice -> route -> react -> hard_check -> summary -> END
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. 当前链路采用线性图,确保可读性优先;
|
||||
@@ -55,6 +58,15 @@ func RunScheduleRefineGraph(ctx context.Context, input ScheduleRefineGraphRunInp
|
||||
if err := graph.AddLambdaNode(graphNodeContract, compose.InvokableLambda(runner.contractNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(graphNodePlan, compose.InvokableLambda(runner.planNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(graphNodeSlice, compose.InvokableLambda(runner.sliceNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(graphNodeRoute, compose.InvokableLambda(runner.routeNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(graphNodeReact, compose.InvokableLambda(runner.reactNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -68,7 +80,16 @@ func RunScheduleRefineGraph(ctx context.Context, input ScheduleRefineGraphRunInp
|
||||
if err := graph.AddEdge(compose.START, graphNodeContract); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(graphNodeContract, graphNodeReact); err != nil {
|
||||
if err := graph.AddEdge(graphNodeContract, graphNodePlan); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(graphNodePlan, graphNodeSlice); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(graphNodeSlice, graphNodeRoute); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(graphNodeRoute, graphNodeReact); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(graphNodeReact, graphNodeHardCheck); err != nil {
|
||||
@@ -83,7 +104,7 @@ func RunScheduleRefineGraph(ctx context.Context, input ScheduleRefineGraphRunInp
|
||||
|
||||
runnable, err := graph.Compile(ctx,
|
||||
compose.WithGraphName("ScheduleRefineGraph"),
|
||||
compose.WithMaxRunSteps(12),
|
||||
compose.WithMaxRunSteps(20),
|
||||
compose.WithNodeTriggerMode(compose.AnyPredecessor),
|
||||
)
|
||||
if err != nil {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,173 +1,164 @@
|
||||
package schedulerefine
|
||||
|
||||
const (
|
||||
// contractPrompt 用于“微调契约抽取”节点。
|
||||
//
|
||||
// 目标:
|
||||
// 1. 把用户自然语言微调请求收敛成结构化契约;
|
||||
// 2. 明确是否需要“保持相对顺序不变”;
|
||||
// 3. 严格输出 JSON,降低解析抖动。
|
||||
// contractPrompt 负责把用户自然语言微调请求抽取为结构化契约。
|
||||
contractPrompt = `你是 SmartFlow 的排程微调契约分析器。
|
||||
你会收到:当前时间、用户本轮微调请求、已有排程摘要。
|
||||
你的任务是把“用户真正想改什么”转成结构化契约。
|
||||
|
||||
请只输出 JSON,不要 markdown,不要解释,字段如下:
|
||||
你会收到:当前时间、用户请求、已有排程摘要。
|
||||
请只输出 JSON,不要 Markdown,不要解释,不要代码块:
|
||||
{
|
||||
"intent": "一句话概括用户本轮微调目标",
|
||||
"intent": "一句话概括本轮微调目标",
|
||||
"strategy": "local_adjust|keep",
|
||||
"hard_requirements": ["必须满足的硬性要求1","硬性要求2"],
|
||||
"hard_requirements": ["必须满足的硬性要求1","必须满足的硬性要求2"],
|
||||
"hard_assertions": [
|
||||
{
|
||||
"metric": "source_move_ratio_percent|all_source_tasks_in_target_scope|source_remaining_count",
|
||||
"operator": "==|<=|>=|between",
|
||||
"value": 50,
|
||||
"min": 50,
|
||||
"max": 50,
|
||||
"week": 17,
|
||||
"target_week": 16
|
||||
}
|
||||
],
|
||||
"keep_relative_order": true,
|
||||
"order_scope": "global|week",
|
||||
"reason": "简短中文原因,<=40字"
|
||||
"order_scope": "global|week"
|
||||
}
|
||||
|
||||
规则:
|
||||
1) 当用户表达“保持原顺序/不打乱顺序/按原顺序推进”时,keep_relative_order=true。
|
||||
2) 若用户没有提顺序要求,keep_relative_order=false,order_scope 固定输出 "global"。
|
||||
3) strategy=keep 仅用于“无需改动”的情况;只要要移动任务,就输出 local_adjust。
|
||||
4) hard_requirements 要可验证,避免空话。`
|
||||
1. 除非用户明确表达“允许打乱顺序/顺序无所谓”,keep_relative_order 默认 true。
|
||||
2. 仅当用户明确放宽顺序时,keep_relative_order 才允许为 false;order_scope 默认 "global"。
|
||||
3. 只要涉及移动任务,strategy 必须是 local_adjust;仅在无需改动时才用 keep。
|
||||
4. hard_requirements 必须可验证,避免空泛描述。
|
||||
5. hard_assertions 必须尽量结构化,避免只给自然语言目标。`
|
||||
|
||||
// plannerPrompt 用于“Plan-and-Execute”的规划阶段。
|
||||
//
|
||||
// 目标:
|
||||
// 1. 让模型按当前请求自动规划“先取证再动作”的执行路径;
|
||||
// 2. 规划结果要求结构化,便于执行阶段直接引用;
|
||||
// 3. 不在 Planner 阶段执行工具,只负责产出计划。
|
||||
plannerPrompt = `你是 SmartFlow 的排程微调规划器(Planner)。
|
||||
你会收到:用户请求、契约、最近动作日志与观察。
|
||||
你的职责是生成“下一阶段的执行计划”,而不是直接执行工具。
|
||||
|
||||
只输出 JSON:
|
||||
// plannerPrompt 只负责生成“执行路径”,不直接执行动作。
|
||||
plannerPrompt = `你是 SmartFlow 的排程微调 Planner。
|
||||
你会收到:用户请求、契约、最近动作观察。
|
||||
请只输出 JSON,不要 Markdown,不要解释,不要代码块:
|
||||
{
|
||||
"summary": "本轮计划一句话",
|
||||
"steps": ["步骤1","步骤2","步骤3"],
|
||||
"success_signals": ["满足什么算成功1","成功2"],
|
||||
"fallback": "若连续失败,准备怎么改道"
|
||||
"summary": "本阶段执行策略一句话",
|
||||
"steps": ["步骤1","步骤2","步骤3"]
|
||||
}
|
||||
|
||||
规则:
|
||||
1. steps 请优先采用“先取证后动作”的路径:例如 QueryTargetTasks / QueryAvailableSlots / BatchMove / Move / Swap / Verify。
|
||||
2. steps 保持 3~4 条,单条不超过 26 字。
|
||||
3. summary 不超过 36 字;fallback 不超过 30 字;success_signals 最多 3 条。
|
||||
4. 严禁输出半截 JSON;若信息过多,请精简而不是展开解释。
|
||||
5. 不要输出 markdown,不要输出额外文本。`
|
||||
1. steps 保持 3~4 条,优先“先取证再动作”。
|
||||
2. summary <= 36 字,单步 <= 28 字。
|
||||
3. 若目标是“均匀分散”,steps 必须体现 SpreadEven 且包含“成功后才收口”的硬条件。
|
||||
4. 若目标是“上下文切换最少/同科目连续”,steps 必须体现 MinContextSwitch 且包含“成功后才收口”的硬条件。
|
||||
5. 不要输出半截 JSON。`
|
||||
|
||||
// reactPrompt 用于“强 ReAct 微调循环”节点。
|
||||
//
|
||||
// 目标:
|
||||
// 1. 每轮先输出“计划 -> 缺口 -> 工具动作”(不承担执行后反思);
|
||||
// 2. 每轮最多一个 tool_call,但支持 BatchMove 在一个调用里原子执行多步;
|
||||
// 3. 明确遵守顺序硬约束与 existing 不可改约束。
|
||||
reactPrompt = `你是 SmartFlow 的排程微调执行器,采用“走一步看一步”的 ReAct 风格。
|
||||
本轮你只允许做两件事之一:
|
||||
1) 调用一个工具(QueryTargetTasks / QueryAvailableSlots / Move / Swap / BatchMove / Verify);
|
||||
2) 输出 done=true 结束。
|
||||
// reactPrompt 用于“单任务微步 ReAct”执行器。
|
||||
reactPrompt = `你是 SmartFlow 的单任务微步 ReAct 执行器。
|
||||
当前只处理一个任务(CURRENT_TASK),不能发散到其它任务的主动改动。
|
||||
你每轮只能做两件事之一:
|
||||
1) 调用一个工具(基础工具或复合工具)
|
||||
2) 输出 done=true 结束当前任务
|
||||
|
||||
你将收到 3 个关键输入:
|
||||
1) LAST_TOOL_RESULT:上一轮工具结果(结构化 JSON);
|
||||
2) LAST_TOOL_OBSERVATION:上一轮完整观察(包含 tool_name/tool_params/tool_success/tool_error_code/tool_result);
|
||||
3) LAST_FAILED_CALL_SIGNATURE:上一轮失败动作签名(tool+params)。
|
||||
工具分组:
|
||||
- 基础工具:QueryTargetTasks / QueryAvailableSlots / Move / Swap / BatchMove / Verify
|
||||
- 复合工具:SpreadEven / MinContextSwitch
|
||||
|
||||
硬约束:
|
||||
1. 每轮最多 1 个 tool_call。
|
||||
2. 只能修改 status="suggested" 的任务,禁止修改 existing。
|
||||
3. 如果合同中 keep_relative_order=true,任何动作都不能打乱任务原始相对顺序。
|
||||
4. 如果当前方案已满足目标,直接 done=true,不要多余动作。
|
||||
5. day_of_week 数值映射必须严格按:1周一,2周二,3周三,4周四,5周五,6周六,7周日。
|
||||
6. 若上一轮 tool_success=false,你必须先根据 tool_error_code 调整策略,再给新动作。
|
||||
7. 禁止重复上一轮失败动作(tool 与 params 完全一致);若重复会被后端拒绝执行并记为失败轮次。
|
||||
工具说明(按职责):
|
||||
1. QueryTargetTasks:查询候选任务集合(只读)。
|
||||
常用参数:week/week_filter/day_of_week/task_item_ids/status。
|
||||
适用:先摸清“有哪些任务可动、当前在哪”。
|
||||
2. QueryAvailableSlots:查询可放置坑位(只读,默认先纯空位,必要时补可嵌入位)。
|
||||
常用参数:week/week_filter/day_of_week/span/limit/allow_embed/exclude_sections。
|
||||
适用:Move 前先拿可落点清单。
|
||||
3. Move:移动单个任务到目标坑位(写操作)。
|
||||
必要参数:task_item_id,to_week,to_day,to_section_from,to_section_to。
|
||||
适用:单任务精确挪动。
|
||||
4. Swap:交换两个任务坑位(写操作)。
|
||||
必要参数:task_a,task_b。
|
||||
适用:两个任务互换位置比单独 Move 更稳时。
|
||||
5. BatchMove:批量原子移动(写操作)。
|
||||
必要参数:{"moves":[{Move参数...},{Move参数...}]}。
|
||||
适用:一轮要改多个任务且要求“要么全成要么全回滚”。
|
||||
6. Verify:执行确定性校验(只读)。
|
||||
常用参数:可空;也可传 task_item_id + 目标坐标做定点核验。
|
||||
适用:收尾前快速自检是否符合确定性约束。
|
||||
7. SpreadEven(复合):按“均匀铺开”目标一次规划并执行多任务移动(写操作)。
|
||||
必要参数:task_item_ids(必须包含 CURRENT_TASK.task_item_id)。
|
||||
可选参数:week/week_filter/day_of_week/allow_embed/limit。
|
||||
适用:目标是“把任务在时间上分散开,避免扎堆”。
|
||||
8. MinContextSwitch(复合):按“最少上下文切换”一次规划并执行多任务移动(写操作)。
|
||||
必要参数:task_item_ids(必须包含 CURRENT_TASK.task_item_id)。
|
||||
可选参数:week/week_filter/day_of_week/allow_embed/limit。
|
||||
适用:目标是“同科目/同认知标签尽量连续,减少切换成本”。
|
||||
|
||||
你必须只输出 JSON,字段如下:
|
||||
请严格输出 JSON,不要 Markdown,不要解释:
|
||||
{
|
||||
"done": false,
|
||||
"summary": "",
|
||||
"goal_check": "本轮先检查什么",
|
||||
"decision": "本轮为什么这样决策",
|
||||
"missing_info": ["如果缺信息就在这里写;不缺则返回空数组"],
|
||||
"reflect": "本轮计划备注(动作前,不是执行后复盘)",
|
||||
"decision": "本轮为何这么做",
|
||||
"missing_info": ["缺口信息1","缺口信息2"],
|
||||
"tool_calls": [
|
||||
{
|
||||
"tool": "QueryTargetTasks|QueryAvailableSlots|Move|Swap|BatchMove|Verify",
|
||||
"tool": "QueryTargetTasks|QueryAvailableSlots|Move|Swap|BatchMove|SpreadEven|MinContextSwitch|Verify",
|
||||
"params": {}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
补充规则:
|
||||
1. 若 done=true,则 tool_calls 必须是空数组。
|
||||
2. 若 done=false 且有动作,tool_calls 必须只有一个元素。
|
||||
3. QueryTargetTasks 用于“先定位要改哪些任务”,禁止直接猜。
|
||||
4. QueryAvailableSlots 用于“先看可用空位”,禁止凭直觉盲移。
|
||||
5. Move 参数优先使用标准字段:task_item_id,to_week,to_day,to_section_from,to_section_to。
|
||||
6. BatchMove 参数格式必须是:{"moves":[{Move参数1},{Move参数2},...]},后端会按顺序原子执行;任一步失败则整批回滚。
|
||||
7. Verify 是终止前自检工具:done=true 前建议先执行一次 Verify。
|
||||
8. reflect 只描述“本轮计划备注”,不要把未执行的动作写成已完成事实。
|
||||
9. 为保证 JSON 稳定可解析,请控制长度:goal_check<=50字、decision<=90字、reflect<=80字、summary<=60字、missing_info 最多3条。
|
||||
10. 你必须显式说明“上一轮失败原因如何影响本轮决策”(写在 decision 里)。
|
||||
11. 不要输出代码块,不要输出额外文本。`
|
||||
硬规则:
|
||||
1. 每轮最多 1 个 tool_call。
|
||||
2. done=true 时,tool_calls 必须为空数组。
|
||||
3. done=false 时,tool_calls 必须恰好 1 条。
|
||||
4. 只能修改 status="suggested" 的任务,禁止修改 existing。
|
||||
5. 不要把“顺序约束”当作执行期阻塞条件;你只需把坑位分布排好,顺序由后端统一收口。
|
||||
6. 若上轮失败,必须依据 LAST_TOOL_OBSERVATION.error_code 调整策略,不能重复上轮失败动作。
|
||||
7. Move 参数优先使用:task_item_id,to_week,to_day,to_section_from,to_section_to。
|
||||
8. BatchMove 参数格式必须是:{"moves":[{...},{...}]};任一步失败会整批回滚。
|
||||
9. day_of_week 映射固定:1周一,2周二,3周三,4周四,5周五,6周六,7周日。
|
||||
10. 优先使用“纯空位”;仅在空位不足时再考虑可嵌入课程位(第二优先级)。
|
||||
11. 如果 SOURCE_WEEK_FILTER 非空,只允许改写这些来源周里的任务,禁止主动改写其它周任务。
|
||||
12. CURRENT_TASK 是本轮唯一可改写任务;如果它已满足目标,立刻 done=true,不要提前处理下一个任务。
|
||||
13. 禁止发明工具名(如 GetCurrentTask、AdjustTaskTime),只能用白名单工具。
|
||||
14. 优先使用后端注入的 ENV_SLOT_HINT 进行落点决策,非必要不要重复 QueryAvailableSlots。
|
||||
15. 若 REQUIRED_COMPOSITE_TOOL 非空且 COMPOSITE_REQUIRED_SUCCESS=false,本轮必须优先调用 REQUIRED_COMPOSITE_TOOL,禁止先调用 Move/Swap/BatchMove。
|
||||
16. 若使用 SpreadEven/MinContextSwitch,必须在参数中提供 task_item_ids(且包含 CURRENT_TASK.task_item_id)。
|
||||
17. 若 COMPOSITE_TOOLS_ALLOWED=false,禁止调用 SpreadEven/MinContextSwitch,只能使用基础工具逐步处理。
|
||||
18. 为保证解析稳定:goal_check<=50字,decision<=90字,summary<=60字。`
|
||||
|
||||
// postReflectPrompt 用于“动作执行后真反思”节点。
|
||||
//
|
||||
// 目标:
|
||||
// 1. 基于后端返回的真实工具结果做复盘,而不是动作前预期;
|
||||
// 2. 输出下一轮可执行的改进策略,驱动真正的 Observe -> Think;
|
||||
// 3. 严格输出 JSON,供后端稳定解析并透传 stage。
|
||||
// postReflectPrompt 要求模型基于真实工具结果做复盘,不允许“脑补成功”。
|
||||
postReflectPrompt = `你是 SmartFlow 的 ReAct 复盘器。
|
||||
你会收到:本轮工具调用参数、后端真实执行结果、上一轮上下文。
|
||||
请基于“真实结果”复盘,不要把失败说成成功。
|
||||
|
||||
只输出 JSON:
|
||||
你会收到:本轮工具参数、后端真实执行结果、上一轮上下文。
|
||||
请只输出 JSON,不要 Markdown,不要解释:
|
||||
{
|
||||
"reflection": "本轮发生了什么(基于真实结果)",
|
||||
"next_strategy": "下一轮建议如何改(具体到换时段/换工具/保持)",
|
||||
"should_stop": false,
|
||||
"stop_reason": "若应结束,给简短原因"
|
||||
"reflection": "基于真实结果的复盘",
|
||||
"next_strategy": "下一轮建议动作",
|
||||
"should_stop": false
|
||||
}
|
||||
|
||||
规则:
|
||||
1. tool_success=false 时,reflection 必须明确失败原因(优先引用 error_code)。
|
||||
2. 若 error_code=ORDER_VIOLATION/SLOT_CONFLICT/REPEAT_FAILED_ACTION,next_strategy 必须给出“如何避开同类失败”。
|
||||
3. should_stop=true 仅在“目标已满足”或“继续动作收益很低”时使用。
|
||||
4. next_strategy 只能引用这些工具名:QueryTargetTasks/QueryAvailableSlots/Move/Swap/BatchMove/Verify。
|
||||
5. 不要输出 markdown,不要输出额外文本。`
|
||||
1. 若 tool_success=false,reflection 必须明确失败原因(优先引用 error_code)。
|
||||
2. 若 error_code 属于 ORDER_VIOLATION/SLOT_CONFLICT/REPEAT_FAILED_ACTION,next_strategy 必须给出规避方法。
|
||||
3. should_stop=true 仅用于“目标已满足”或“继续收益很低”。`
|
||||
|
||||
// reviewPrompt 用于“终审语义校验”节点。
|
||||
//
|
||||
// 目标:
|
||||
// 1. 检查方案是否满足用户本轮请求;
|
||||
// 2. 给出未满足项列表,供一次修复动作使用;
|
||||
// 3. 输出结构化 JSON,避免校验结果歧义。
|
||||
// reviewPrompt 用于终审语义校验。
|
||||
reviewPrompt = `你是 SmartFlow 的终审校验器。
|
||||
请判断“当前排程”是否满足“本轮用户微调请求 + 契约硬要求”。
|
||||
只输出 JSON:
|
||||
{
|
||||
"pass": true,
|
||||
"reason": "中文简短结论",
|
||||
"unmet": ["若不满足,这里列未满足点"]
|
||||
"unmet": []
|
||||
}
|
||||
|
||||
要求:
|
||||
1. pass=true 时,unmet 必须为空数组。
|
||||
2. pass=false 时,reason 必须给出核心差距。`
|
||||
规则:
|
||||
1. pass=true 时 unmet 必须为空数组。
|
||||
2. pass=false 时 reason 必须给出核心差距。`
|
||||
|
||||
// summaryPrompt 用于“最终回复润色”节点。
|
||||
//
|
||||
// 目标:
|
||||
// 1. 给用户返回自然语言总结;
|
||||
// 2. 体现“做了什么调整 + 为什么这样改”;
|
||||
// 3. 若终审仍有缺口,也要诚实说明。
|
||||
// summaryPrompt 用于最终面向用户的自然语言总结。
|
||||
summaryPrompt = `你是 SmartFlow 的排程结果解读助手。
|
||||
请基于输入输出 2~4 句自然中文总结:
|
||||
1) 先说本轮改了什么;
|
||||
2) 再说这样改的收益;
|
||||
3) 如果终审未完全通过,要明确说明还差什么。
|
||||
请基于输入输出 2~4 句中文总结:
|
||||
1) 先说明本轮改了什么;
|
||||
2) 再说明改动收益;
|
||||
3) 若终审未完全通过,明确还差什么。
|
||||
不要输出 JSON。`
|
||||
|
||||
// repairPrompt 用于“终审失败后的单次修复”节点。
|
||||
//
|
||||
// 目标:
|
||||
// 1. 在不重跑全链路的前提下做一次局部补救;
|
||||
// 2. 强制只输出一个工具调用,避免再次拉长思考。
|
||||
// repairPrompt 用于终审失败后的单次修复动作。
|
||||
repairPrompt = `你是 SmartFlow 的修复执行器。
|
||||
当前方案未通过终审,请根据“未满足点”只做一次修复动作。
|
||||
只允许输出一个 tool_call(Move 或 Swap),不允许 done。
|
||||
@@ -179,12 +170,19 @@ const (
|
||||
"goal_check": "本轮修复目标",
|
||||
"decision": "修复决策依据",
|
||||
"missing_info": [],
|
||||
"reflect": "修复动作后的预期",
|
||||
"tool_calls": [
|
||||
{
|
||||
"tool": "Move|Swap",
|
||||
"params": {}
|
||||
}
|
||||
]
|
||||
}`
|
||||
}
|
||||
|
||||
Move 参数必须使用标准键:
|
||||
- task_item_id
|
||||
- to_week
|
||||
- to_day
|
||||
- to_section_from
|
||||
- to_section_to
|
||||
禁止使用 new_week/new_day/section_from 等别名。`
|
||||
)
|
||||
|
||||
573
backend/agent/schedulerefine/refine_filters_test.go
Normal file
573
backend/agent/schedulerefine/refine_filters_test.go
Normal file
@@ -0,0 +1,573 @@
|
||||
package schedulerefine
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
)
|
||||
|
||||
func TestQueryTargetTasksWeekFilterAndTaskID(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 1, Name: "task-w12", Week: 12, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, Status: "suggested", Type: "task"},
|
||||
{TaskItemID: 2, Name: "task-w13", Week: 13, DayOfWeek: 1, SectionFrom: 3, SectionTo: 4, Status: "suggested", Type: "task"},
|
||||
{TaskItemID: 3, Name: "task-w14", Week: 14, DayOfWeek: 1, SectionFrom: 5, SectionTo: 6, Status: "suggested", Type: "task"},
|
||||
}
|
||||
policy := refineToolPolicy{OriginOrderMap: map[int]int{1: 1, 2: 2, 3: 3}}
|
||||
|
||||
paramsWeek := map[string]any{
|
||||
"week_filter": []any{13.0, 14.0},
|
||||
}
|
||||
_, resultWeek := refineToolQueryTargetTasks(entries, paramsWeek, policy)
|
||||
if !resultWeek.Success {
|
||||
t.Fatalf("week_filter 查询失败: %s", resultWeek.Result)
|
||||
}
|
||||
var payloadWeek struct {
|
||||
Count int `json:"count"`
|
||||
Items []struct {
|
||||
TaskItemID int `json:"task_item_id"`
|
||||
Week int `json:"week"`
|
||||
} `json:"items"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(resultWeek.Result), &payloadWeek); err != nil {
|
||||
t.Fatalf("解析 week_filter 结果失败: %v", err)
|
||||
}
|
||||
if payloadWeek.Count != 2 {
|
||||
t.Fatalf("week_filter 期望返回 2 条,实际=%d", payloadWeek.Count)
|
||||
}
|
||||
for _, item := range payloadWeek.Items {
|
||||
if item.Week != 13 && item.Week != 14 {
|
||||
t.Fatalf("week_filter 过滤失败,出现非法周次=%d", item.Week)
|
||||
}
|
||||
}
|
||||
|
||||
paramsTaskID := map[string]any{
|
||||
"week_filter": []any{13.0, 14.0},
|
||||
"task_item_id": 2,
|
||||
}
|
||||
_, resultTaskID := refineToolQueryTargetTasks(entries, paramsTaskID, policy)
|
||||
if !resultTaskID.Success {
|
||||
t.Fatalf("task_item_id 查询失败: %s", resultTaskID.Result)
|
||||
}
|
||||
var payloadTaskID struct {
|
||||
Count int `json:"count"`
|
||||
Items []struct {
|
||||
TaskItemID int `json:"task_item_id"`
|
||||
Week int `json:"week"`
|
||||
} `json:"items"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(resultTaskID.Result), &payloadTaskID); err != nil {
|
||||
t.Fatalf("解析 task_item_id 结果失败: %v", err)
|
||||
}
|
||||
if payloadTaskID.Count != 1 {
|
||||
t.Fatalf("task_item_id 期望返回 1 条,实际=%d", payloadTaskID.Count)
|
||||
}
|
||||
if payloadTaskID.Items[0].TaskItemID != 2 || payloadTaskID.Items[0].Week != 13 {
|
||||
t.Fatalf("task_item_id 过滤错误: %+v", payloadTaskID.Items[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryAvailableSlotsExactSectionAlias(t *testing.T) {
|
||||
params := map[string]any{
|
||||
"week": 13,
|
||||
"section_duration": 2,
|
||||
"section_from": 1,
|
||||
"section_to": 2,
|
||||
"limit": 5,
|
||||
}
|
||||
_, result := refineToolQueryAvailableSlots(nil, params, planningWindow{Enabled: false})
|
||||
if !result.Success {
|
||||
t.Fatalf("QueryAvailableSlots 失败: %s", result.Result)
|
||||
}
|
||||
var payload struct {
|
||||
Count int `json:"count"`
|
||||
Slots []struct {
|
||||
Week int `json:"week"`
|
||||
SectionFrom int `json:"section_from"`
|
||||
SectionTo int `json:"section_to"`
|
||||
} `json:"slots"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(result.Result), &payload); err != nil {
|
||||
t.Fatalf("解析 QueryAvailableSlots 结果失败: %v", err)
|
||||
}
|
||||
if payload.Count == 0 {
|
||||
t.Fatalf("期望至少返回一个可用时段,实际=0")
|
||||
}
|
||||
for _, slot := range payload.Slots {
|
||||
if slot.Week != 13 {
|
||||
t.Fatalf("返回了错误周次: %+v", slot)
|
||||
}
|
||||
if slot.SectionFrom != 1 || slot.SectionTo != 2 {
|
||||
t.Fatalf("精确节次过滤失败: %+v", slot)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryAvailableSlotsWeekFilterDayFilterAlias(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 1, Name: "task-w12", Week: 12, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, Status: "suggested", Type: "task"},
|
||||
{TaskItemID: 2, Name: "task-w17", Week: 17, DayOfWeek: 4, SectionFrom: 3, SectionTo: 4, Status: "suggested", Type: "task"},
|
||||
}
|
||||
params := map[string]any{
|
||||
"week_filter": []any{17.0},
|
||||
"day_filter": []any{1.0, 2.0, 3.0},
|
||||
"limit": 20,
|
||||
}
|
||||
|
||||
_, result := refineToolQueryAvailableSlots(entries, params, planningWindow{Enabled: false})
|
||||
if !result.Success {
|
||||
t.Fatalf("QueryAvailableSlots 别名查询失败: %s", result.Result)
|
||||
}
|
||||
var payload struct {
|
||||
Count int `json:"count"`
|
||||
Slots []struct {
|
||||
Week int `json:"week"`
|
||||
DayOfWeek int `json:"day_of_week"`
|
||||
} `json:"slots"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(result.Result), &payload); err != nil {
|
||||
t.Fatalf("解析 week/day 过滤结果失败: %v", err)
|
||||
}
|
||||
if payload.Count == 0 {
|
||||
t.Fatalf("week_filter/day_filter 查询应返回 W17 周一到周三空位,实际为空")
|
||||
}
|
||||
for _, slot := range payload.Slots {
|
||||
if slot.Week != 17 {
|
||||
t.Fatalf("week_filter 失效,出现 week=%d", slot.Week)
|
||||
}
|
||||
if slot.DayOfWeek < 1 || slot.DayOfWeek > 3 {
|
||||
t.Fatalf("day_filter 失效,出现 day_of_week=%d", slot.DayOfWeek)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectWorksetTaskIDsSourceWeekOnly(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 1, Week: 12, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, Status: "suggested", Type: "task"},
|
||||
{TaskItemID: 2, Week: 14, DayOfWeek: 1, SectionFrom: 3, SectionTo: 4, Status: "suggested", Type: "task"},
|
||||
{TaskItemID: 3, Week: 13, DayOfWeek: 1, SectionFrom: 5, SectionTo: 6, Status: "suggested", Type: "task"},
|
||||
{TaskItemID: 4, Week: 14, DayOfWeek: 2, SectionFrom: 7, SectionTo: 8, Status: "suggested", Type: "task"},
|
||||
}
|
||||
slice := RefineSlicePlan{WeekFilter: []int{14, 13}}
|
||||
originOrder := map[int]int{1: 1, 2: 2, 3: 3, 4: 4}
|
||||
|
||||
got := collectWorksetTaskIDs(entries, slice, originOrder)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("来源周收敛失败,期望 2 条,实际=%d, got=%v", len(got), got)
|
||||
}
|
||||
if got[0] != 2 || got[1] != 4 {
|
||||
t.Fatalf("来源周结果错误,期望 [2 4],实际=%v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSlicePlanDirectionalSourceTarget(t *testing.T) {
|
||||
st := &ScheduleRefineState{
|
||||
UserMessage: "帮我把第17周周四到周五的任务都收敛到17周的周一到周三,优先放空位,空位不够了再嵌入",
|
||||
}
|
||||
plan := buildSlicePlan(st)
|
||||
if len(plan.WeekFilter) == 0 || plan.WeekFilter[0] != 17 {
|
||||
t.Fatalf("week_filter 解析错误: %+v", plan.WeekFilter)
|
||||
}
|
||||
expectSource := []int{4, 5}
|
||||
expectTarget := []int{1, 2, 3}
|
||||
if len(plan.SourceDays) != len(expectSource) {
|
||||
t.Fatalf("source_days 长度错误: got=%v", plan.SourceDays)
|
||||
}
|
||||
for i := range expectSource {
|
||||
if plan.SourceDays[i] != expectSource[i] {
|
||||
t.Fatalf("source_days 错误: got=%v", plan.SourceDays)
|
||||
}
|
||||
}
|
||||
if len(plan.TargetDays) != len(expectTarget) {
|
||||
t.Fatalf("target_days 长度错误: got=%v", plan.TargetDays)
|
||||
}
|
||||
for i := range expectTarget {
|
||||
if plan.TargetDays[i] != expectTarget[i] {
|
||||
t.Fatalf("target_days 错误: got=%v", plan.TargetDays)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyTaskCoordinateMismatch(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 28, Name: "task-w17-d4", Week: 17, DayOfWeek: 4, SectionFrom: 5, SectionTo: 6, Status: "suggested", Type: "task"},
|
||||
}
|
||||
policy := refineToolPolicy{OriginOrderMap: map[int]int{28: 1}}
|
||||
params := map[string]any{
|
||||
"task_item_id": 28,
|
||||
"week": 17,
|
||||
"day_of_week": 1,
|
||||
"section_from": 1,
|
||||
"section_to": 2,
|
||||
}
|
||||
|
||||
_, result := refineToolVerify(entries, params, policy)
|
||||
if result.Success {
|
||||
t.Fatalf("期望 Verify 在任务坐标不匹配时失败,实际 success=true, result=%s", result.Result)
|
||||
}
|
||||
if result.ErrorCode != "VERIFY_FAILED" {
|
||||
t.Fatalf("期望错误码 VERIFY_FAILED,实际=%s", result.ErrorCode)
|
||||
}
|
||||
if !strings.Contains(result.Result, "不匹配") {
|
||||
t.Fatalf("期望结果包含“不匹配”提示,实际=%s", result.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveRejectsSuggestedCourseEntry(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{
|
||||
TaskItemID: 39,
|
||||
Name: "面向对象程序设计-C++",
|
||||
Type: "course",
|
||||
Status: "suggested",
|
||||
Week: 17,
|
||||
DayOfWeek: 4,
|
||||
SectionFrom: 7,
|
||||
SectionTo: 8,
|
||||
},
|
||||
}
|
||||
params := map[string]any{
|
||||
"task_item_id": 39,
|
||||
"to_week": 17,
|
||||
"to_day": 1,
|
||||
"to_section_from": 7,
|
||||
"to_section_to": 8,
|
||||
}
|
||||
_, result := refineToolMove(entries, params, planningWindow{Enabled: false}, refineToolPolicy{OriginOrderMap: map[int]int{39: 1}})
|
||||
if result.Success {
|
||||
t.Fatalf("期望 course 类型的 suggested 条目不可移动,实际 success=true, result=%s", result.Result)
|
||||
}
|
||||
if !strings.Contains(result.Result, "可移动 suggested 任务") {
|
||||
t.Fatalf("期望返回不可移动提示,实际=%s", result.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryAvailableSlotsSlotTypePureDisablesEmbed(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{
|
||||
Name: "可嵌入课程",
|
||||
Type: "course",
|
||||
Status: "existing",
|
||||
Week: 17,
|
||||
DayOfWeek: 1,
|
||||
SectionFrom: 1,
|
||||
SectionTo: 2,
|
||||
BlockForSuggested: false,
|
||||
},
|
||||
}
|
||||
|
||||
pureParams := map[string]any{
|
||||
"week": 17,
|
||||
"day_of_week": 1,
|
||||
"section_from": 1,
|
||||
"section_to": 2,
|
||||
"slot_type": "pure",
|
||||
}
|
||||
_, pureResult := refineToolQueryAvailableSlots(entries, pureParams, planningWindow{Enabled: false})
|
||||
if !pureResult.Success {
|
||||
t.Fatalf("pure 查询失败: %s", pureResult.Result)
|
||||
}
|
||||
var purePayload struct {
|
||||
Count int `json:"count"`
|
||||
EmbeddedCount int `json:"embedded_count"`
|
||||
FallbackUsed bool `json:"fallback_used"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(pureResult.Result), &purePayload); err != nil {
|
||||
t.Fatalf("解析 pure 查询结果失败: %v", err)
|
||||
}
|
||||
if purePayload.Count != 0 || purePayload.EmbeddedCount != 0 || purePayload.FallbackUsed {
|
||||
t.Fatalf("slot_type=pure 应禁用嵌入兜底,实际 payload=%+v", purePayload)
|
||||
}
|
||||
|
||||
defaultParams := map[string]any{
|
||||
"week": 17,
|
||||
"day_of_week": 1,
|
||||
"section_from": 1,
|
||||
"section_to": 2,
|
||||
}
|
||||
_, defaultResult := refineToolQueryAvailableSlots(entries, defaultParams, planningWindow{Enabled: false})
|
||||
if !defaultResult.Success {
|
||||
t.Fatalf("default 查询失败: %s", defaultResult.Result)
|
||||
}
|
||||
var defaultPayload struct {
|
||||
Count int `json:"count"`
|
||||
EmbeddedCount int `json:"embedded_count"`
|
||||
FallbackUsed bool `json:"fallback_used"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(defaultResult.Result), &defaultPayload); err != nil {
|
||||
t.Fatalf("解析 default 查询结果失败: %v", err)
|
||||
}
|
||||
if defaultPayload.Count == 0 || defaultPayload.EmbeddedCount == 0 || !defaultPayload.FallbackUsed {
|
||||
t.Fatalf("默认查询应允许嵌入候选,实际 payload=%+v", defaultPayload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileObjectiveAndEvaluateMoveAllPass(t *testing.T) {
|
||||
initial := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 39, Name: "任务39", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 4, SectionFrom: 7, SectionTo: 8},
|
||||
{TaskItemID: 51, Name: "任务51", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 5, SectionFrom: 9, SectionTo: 10},
|
||||
}
|
||||
final := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 39, Name: "任务39", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 1, SectionFrom: 7, SectionTo: 8},
|
||||
{TaskItemID: 51, Name: "任务51", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 2, SectionFrom: 9, SectionTo: 10},
|
||||
}
|
||||
st := &ScheduleRefineState{
|
||||
UserMessage: "把17周周四到周五任务收敛到周一到周三",
|
||||
InitialHybridEntries: initial,
|
||||
HybridEntries: final,
|
||||
SlicePlan: RefineSlicePlan{
|
||||
WeekFilter: []int{17},
|
||||
SourceDays: []int{4, 5},
|
||||
TargetDays: []int{1, 2, 3},
|
||||
},
|
||||
}
|
||||
st.Objective = compileRefineObjective(st, st.SlicePlan)
|
||||
if st.Objective.Mode != "move_all" {
|
||||
t.Fatalf("期望目标模式 move_all,实际=%s", st.Objective.Mode)
|
||||
}
|
||||
|
||||
pass, _, unmet, applied := evaluateObjectiveDeterministic(st)
|
||||
if !applied {
|
||||
t.Fatalf("期望命中确定性终审")
|
||||
}
|
||||
if !pass {
|
||||
t.Fatalf("期望确定性终审通过,unmet=%v", unmet)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileObjectiveAndEvaluateMoveAllFail(t *testing.T) {
|
||||
initial := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 26, Name: "任务26", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 5, SectionFrom: 7, SectionTo: 8},
|
||||
}
|
||||
final := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 26, Name: "任务26", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 5, SectionFrom: 7, SectionTo: 8},
|
||||
}
|
||||
st := &ScheduleRefineState{
|
||||
UserMessage: "把17周周四到周五任务收敛到周一到周三",
|
||||
InitialHybridEntries: initial,
|
||||
HybridEntries: final,
|
||||
SlicePlan: RefineSlicePlan{
|
||||
WeekFilter: []int{17},
|
||||
SourceDays: []int{4, 5},
|
||||
TargetDays: []int{1, 2, 3},
|
||||
},
|
||||
}
|
||||
st.Objective = compileRefineObjective(st, st.SlicePlan)
|
||||
|
||||
pass, _, unmet, applied := evaluateObjectiveDeterministic(st)
|
||||
if !applied {
|
||||
t.Fatalf("期望命中确定性终审")
|
||||
}
|
||||
if pass {
|
||||
t.Fatalf("期望确定性终审失败")
|
||||
}
|
||||
if len(unmet) == 0 {
|
||||
t.Fatalf("期望返回未满足项")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileObjectiveMoveRatioFromContractAndEvaluatePass(t *testing.T) {
|
||||
initial, final := buildHalfTransferEntries(10, 5)
|
||||
st := &ScheduleRefineState{
|
||||
UserMessage: "17周任务太多,帮我调整到16周",
|
||||
InitialHybridEntries: initial,
|
||||
HybridEntries: final,
|
||||
SlicePlan: RefineSlicePlan{
|
||||
WeekFilter: []int{17, 16},
|
||||
},
|
||||
Contract: RefineContract{
|
||||
Intent: "将第17周任务匀一半到第16周",
|
||||
HardRequirements: []string{"原第17周任务数调整为原来的一半", "调整到第16周的任务数为原第17周任务数的一半"},
|
||||
},
|
||||
}
|
||||
st.Objective = compileRefineObjective(st, st.SlicePlan)
|
||||
if st.Objective.Mode != "move_ratio" {
|
||||
t.Fatalf("期望目标模式 move_ratio,实际=%s", st.Objective.Mode)
|
||||
}
|
||||
if st.Objective.RequiredMoveMin != 5 || st.Objective.RequiredMoveMax != 5 {
|
||||
t.Fatalf("半数迁移阈值错误: min=%d max=%d", st.Objective.RequiredMoveMin, st.Objective.RequiredMoveMax)
|
||||
}
|
||||
|
||||
pass, _, unmet, applied := evaluateObjectiveDeterministic(st)
|
||||
if !applied {
|
||||
t.Fatalf("期望命中确定性终审")
|
||||
}
|
||||
if !pass {
|
||||
t.Fatalf("期望半数迁移通过,unmet=%v", unmet)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileObjectiveMoveRatioFromContractAndEvaluateFail(t *testing.T) {
|
||||
initial, final := buildHalfTransferEntries(10, 4)
|
||||
st := &ScheduleRefineState{
|
||||
UserMessage: "17周任务太多,帮我调整到16周",
|
||||
InitialHybridEntries: initial,
|
||||
HybridEntries: final,
|
||||
SlicePlan: RefineSlicePlan{
|
||||
WeekFilter: []int{17, 16},
|
||||
},
|
||||
Contract: RefineContract{
|
||||
Intent: "将第17周任务匀一半到第16周",
|
||||
HardRequirements: []string{"原第17周任务数调整为原来的一半", "调整到第16周的任务数为原第17周任务数的一半"},
|
||||
},
|
||||
}
|
||||
st.Objective = compileRefineObjective(st, st.SlicePlan)
|
||||
|
||||
pass, _, unmet, applied := evaluateObjectiveDeterministic(st)
|
||||
if !applied {
|
||||
t.Fatalf("期望命中确定性终审")
|
||||
}
|
||||
if pass {
|
||||
t.Fatalf("期望半数迁移失败")
|
||||
}
|
||||
if len(unmet) == 0 {
|
||||
t.Fatalf("期望返回未满足项")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileObjectiveMoveRatioFromStructuredAssertion(t *testing.T) {
|
||||
initial, final := buildHalfTransferEntries(10, 5)
|
||||
st := &ScheduleRefineState{
|
||||
UserMessage: "请把任务重新分配",
|
||||
InitialHybridEntries: initial,
|
||||
HybridEntries: final,
|
||||
SlicePlan: RefineSlicePlan{
|
||||
WeekFilter: []int{17, 16},
|
||||
},
|
||||
Contract: RefineContract{
|
||||
Intent: "任务重新分配",
|
||||
HardAssertions: []RefineAssertion{
|
||||
{
|
||||
Metric: "source_move_ratio_percent",
|
||||
Operator: "==",
|
||||
Value: 50,
|
||||
Week: 17,
|
||||
TargetWeek: 16,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
st.Objective = compileRefineObjective(st, st.SlicePlan)
|
||||
if st.Objective.Mode != "move_ratio" {
|
||||
t.Fatalf("结构化断言未生效,期望 move_ratio,实际=%s", st.Objective.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func buildHalfTransferEntries(total int, moved int) ([]model.HybridScheduleEntry, []model.HybridScheduleEntry) {
|
||||
initial := make([]model.HybridScheduleEntry, 0, total)
|
||||
final := make([]model.HybridScheduleEntry, 0, total)
|
||||
for i := 1; i <= total; i++ {
|
||||
initial = append(initial, model.HybridScheduleEntry{
|
||||
TaskItemID: i,
|
||||
Name: "task",
|
||||
Type: "task",
|
||||
Status: "suggested",
|
||||
Week: 17,
|
||||
DayOfWeek: 1,
|
||||
SectionFrom: 1,
|
||||
SectionTo: 2,
|
||||
})
|
||||
week := 17
|
||||
if i <= moved {
|
||||
week = 16
|
||||
}
|
||||
final = append(final, model.HybridScheduleEntry{
|
||||
TaskItemID: i,
|
||||
Name: "task",
|
||||
Type: "task",
|
||||
Status: "suggested",
|
||||
Week: week,
|
||||
DayOfWeek: 1,
|
||||
SectionFrom: 1,
|
||||
SectionTo: 2,
|
||||
})
|
||||
}
|
||||
return initial, final
|
||||
}
|
||||
|
||||
func TestNormalizeMovableTaskOrderByOrigin(t *testing.T) {
|
||||
st := &ScheduleRefineState{
|
||||
OriginOrderMap: map[int]int{
|
||||
101: 1,
|
||||
202: 2,
|
||||
},
|
||||
HybridEntries: []model.HybridScheduleEntry{
|
||||
{TaskItemID: 202, Name: "task-202", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2},
|
||||
{TaskItemID: 101, Name: "task-101", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 3, SectionFrom: 1, SectionTo: 2},
|
||||
},
|
||||
}
|
||||
changed := normalizeMovableTaskOrderByOrigin(st)
|
||||
if !changed {
|
||||
t.Fatalf("期望发生顺序归位")
|
||||
}
|
||||
sortHybridEntries(st.HybridEntries)
|
||||
if st.HybridEntries[0].TaskItemID != 101 || st.HybridEntries[1].TaskItemID != 202 {
|
||||
t.Fatalf("顺序归位失败: %+v", st.HybridEntries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrecheckToolCallPolicyRejectsRedundantSlotQuery(t *testing.T) {
|
||||
st := &ScheduleRefineState{
|
||||
SeenSlotQueries: make(map[string]struct{}),
|
||||
EntriesVersion: 0,
|
||||
}
|
||||
call := reactToolCall{
|
||||
Tool: "QueryAvailableSlots",
|
||||
Params: map[string]any{
|
||||
"week": 16,
|
||||
"day_of_week": 1,
|
||||
},
|
||||
}
|
||||
|
||||
if blockedResult, blocked := precheckToolCallPolicy(st, call, nil); blocked {
|
||||
t.Fatalf("首次查询不应被拒绝: %+v", blockedResult)
|
||||
}
|
||||
if blockedResult, blocked := precheckToolCallPolicy(st, call, nil); !blocked {
|
||||
t.Fatalf("重复查询应被拒绝")
|
||||
} else if blockedResult.ErrorCode != "QUERY_REDUNDANT" {
|
||||
t.Fatalf("错误码不符合预期: %+v", blockedResult)
|
||||
}
|
||||
st.EntriesVersion++
|
||||
if blockedResult, blocked := precheckToolCallPolicy(st, call, nil); blocked {
|
||||
t.Fatalf("排程版本变化后应允许再次查询: %+v", blockedResult)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanonicalizeMoveParamsFromRepairAliases(t *testing.T) {
|
||||
call := reactToolCall{
|
||||
Tool: "Move",
|
||||
Params: map[string]any{
|
||||
"task_item_id": 16,
|
||||
"new_week": 16,
|
||||
"day_of_week": 1,
|
||||
"section_from": 1,
|
||||
"section_to": 2,
|
||||
},
|
||||
}
|
||||
normalized := canonicalizeToolCall(call)
|
||||
if _, ok := paramIntAny(normalized.Params, "to_week"); !ok {
|
||||
t.Fatalf("to_week 规范化失败: %+v", normalized.Params)
|
||||
}
|
||||
if _, ok := paramIntAny(normalized.Params, "to_day"); !ok {
|
||||
t.Fatalf("to_day 规范化失败: %+v", normalized.Params)
|
||||
}
|
||||
if _, ok := paramIntAny(normalized.Params, "to_section_from"); !ok {
|
||||
t.Fatalf("to_section_from 规范化失败: %+v", normalized.Params)
|
||||
}
|
||||
if _, ok := paramIntAny(normalized.Params, "to_section_to"); !ok {
|
||||
t.Fatalf("to_section_to 规范化失败: %+v", normalized.Params)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectOrderIntentDefaultsToKeep(t *testing.T) {
|
||||
if !detectOrderIntent("16周总体任务太多了,帮我移动一半到12周") {
|
||||
t.Fatalf("未显式放宽顺序时,默认应保持顺序")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectOrderIntentExplicitAllowReorder(t *testing.T) {
|
||||
if detectOrderIntent("这次顺序无所谓,可以打乱顺序") {
|
||||
t.Fatalf("用户明确允许乱序时,应关闭顺序约束")
|
||||
}
|
||||
}
|
||||
@@ -28,6 +28,18 @@ func (r *scheduleRefineRunner) contractNode(ctx context.Context, st *ScheduleRef
|
||||
return runContractNode(ctx, r.chatModel, st, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *scheduleRefineRunner) planNode(ctx context.Context, st *ScheduleRefineState) (*ScheduleRefineState, error) {
|
||||
return runPlanNode(ctx, r.chatModel, st, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *scheduleRefineRunner) sliceNode(ctx context.Context, st *ScheduleRefineState) (*ScheduleRefineState, error) {
|
||||
return runSliceNode(ctx, st, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *scheduleRefineRunner) routeNode(ctx context.Context, st *ScheduleRefineState) (*ScheduleRefineState, error) {
|
||||
return runCompositeRouteNode(ctx, st, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *scheduleRefineRunner) reactNode(ctx context.Context, st *ScheduleRefineState) (*ScheduleRefineState, error) {
|
||||
return runReactLoopNode(ctx, r.chatModel, st, r.emitStage)
|
||||
}
|
||||
|
||||
@@ -9,43 +9,48 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// timezoneName 固定排程链路使用的业务时区,避免容器默认时区导致“明天/今晚”偏移。
|
||||
// 固定业务时区,避免“今天/明天”在容器默认时区下偏移。
|
||||
timezoneName = "Asia/Shanghai"
|
||||
// datetimeLayout 统一使用分钟级时间文本,方便模型理解与日志比对。
|
||||
// 统一分钟级时间文本格式。
|
||||
datetimeLayout = "2006-01-02 15:04"
|
||||
// defaultPlanMax 是 Planner 最大调用次数(包含首次规划 + 重规划)。
|
||||
defaultPlanMax = 2
|
||||
// defaultExecuteMax 是执行阶段最大工具动作轮次。
|
||||
defaultExecuteMax = 16
|
||||
// defaultReplanMax 是执行阶段允许触发的重规划次数上限。
|
||||
defaultReplanMax = 2
|
||||
// defaultRepairReserve 表示为“终审修复”保留的最小动作预算。
|
||||
defaultRepairReserve = 1
|
||||
|
||||
// 预算默认值。
|
||||
defaultPlanMax = 2
|
||||
defaultExecuteMax = 24
|
||||
defaultPerTaskBudget = 4
|
||||
defaultReplanMax = 2
|
||||
defaultCompositeRetry = 2
|
||||
defaultRepairReserve = 1
|
||||
)
|
||||
|
||||
// RefineContract 表示“微调意图契约”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责承载“本轮微调到底要满足什么”的结构化目标;
|
||||
// 2. 负责给后续 ReAct 动作与终审硬校验提供统一语义;
|
||||
// 3. 不负责实际排程修改动作执行(动作由工具层负责)。
|
||||
// RefineContract 表示本轮微调意图契约。
|
||||
type RefineContract struct {
|
||||
Intent string `json:"intent"`
|
||||
Strategy string `json:"strategy"`
|
||||
HardRequirements []string `json:"hard_requirements"`
|
||||
KeepRelativeOrder bool `json:"keep_relative_order"`
|
||||
OrderScope string `json:"order_scope"`
|
||||
Reason string `json:"reason"`
|
||||
Intent string `json:"intent"`
|
||||
Strategy string `json:"strategy"`
|
||||
HardRequirements []string `json:"hard_requirements"`
|
||||
HardAssertions []RefineAssertion `json:"hard_assertions,omitempty"`
|
||||
KeepRelativeOrder bool `json:"keep_relative_order"`
|
||||
OrderScope string `json:"order_scope"`
|
||||
}
|
||||
|
||||
// HardCheckReport 表示“终审硬校验报告”。
|
||||
// RefineAssertion 表示可由后端直接判定的结构化硬断言。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 记录规则层(物理冲突)是否通过;
|
||||
// 2. 记录语义层(是否满足用户要求)是否通过;
|
||||
// 3. 记录顺序层(是否保持相对顺序)是否通过;
|
||||
// 4. 记录失败原因与修复尝试信息,便于后续持续优化 prompt;
|
||||
// 5. 不负责直接决定是否落库(落库决策仍由服务层控制)。
|
||||
// 字段说明:
|
||||
// 1. Metric:断言指标名,例如 source_move_ratio_percent;
|
||||
// 2. Operator:比较操作符,支持 == / <= / >= / between;
|
||||
// 3. Value/Min/Max:阈值;
|
||||
// 4. Week/TargetWeek:可选周次上下文。
|
||||
type RefineAssertion struct {
|
||||
Metric string `json:"metric"`
|
||||
Operator string `json:"operator"`
|
||||
Value int `json:"value,omitempty"`
|
||||
Min int `json:"min,omitempty"`
|
||||
Max int `json:"max,omitempty"`
|
||||
Week int `json:"week,omitempty"`
|
||||
TargetWeek int `json:"target_week,omitempty"`
|
||||
}
|
||||
|
||||
// HardCheckReport 表示终审硬校验结果。
|
||||
type HardCheckReport struct {
|
||||
PhysicsPassed bool `json:"physics_passed"`
|
||||
PhysicsIssues []string `json:"physics_issues,omitempty"`
|
||||
@@ -60,17 +65,11 @@ type HardCheckReport struct {
|
||||
RepairTried bool `json:"repair_tried"`
|
||||
}
|
||||
|
||||
// ReactRoundObservation 用于沉淀“每轮 ReAct 的可见观测信息”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责记录每轮“计划 -> 动作 -> 观察 -> 反思”的关键信息;
|
||||
// 2. 既用于 SSE 透传,也用于下一轮 prompt 的上下文回灌;
|
||||
// 3. 不承担排程真实数据存储职责(真实排程仍在 HybridEntries)。
|
||||
// ReactRoundObservation 记录每轮 ReAct 的关键观察。
|
||||
type ReactRoundObservation struct {
|
||||
Round int `json:"round"`
|
||||
GoalCheck string `json:"goal_check,omitempty"`
|
||||
Decision string `json:"decision,omitempty"`
|
||||
MissingInfo []string `json:"missing_info,omitempty"`
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
ToolParams map[string]any `json:"tool_params,omitempty"`
|
||||
ToolSuccess bool `json:"tool_success"`
|
||||
@@ -79,27 +78,47 @@ type ReactRoundObservation struct {
|
||||
Reflect string `json:"reflect,omitempty"`
|
||||
}
|
||||
|
||||
// PlannerPlan 表示“本轮执行前的结构化计划”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责记录模型当前建议的执行路径(先查什么、再做什么);
|
||||
// 2. 负责在失败重规划后替换为新版本,供执行器下一轮参考;
|
||||
// 3. 不直接约束工具执行结果(执行合法性仍由工具层硬校验负责)。
|
||||
// PlannerPlan 表示 Planner 生成的阶段执行计划。
|
||||
type PlannerPlan struct {
|
||||
Summary string `json:"summary"`
|
||||
Steps []string `json:"steps,omitempty"`
|
||||
SuccessSignals []string `json:"success_signals,omitempty"`
|
||||
Fallback string `json:"fallback,omitempty"`
|
||||
Summary string `json:"summary"`
|
||||
Steps []string `json:"steps,omitempty"`
|
||||
}
|
||||
|
||||
// ScheduleRefineState 是“连续微调图”的统一状态容器。
|
||||
// RefineSlicePlan 表示切片节点输出。
|
||||
type RefineSlicePlan struct {
|
||||
WeekFilter []int `json:"week_filter,omitempty"`
|
||||
SourceDays []int `json:"source_days,omitempty"`
|
||||
TargetDays []int `json:"target_days,omitempty"`
|
||||
ExcludeSections []int `json:"exclude_sections,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// RefineObjective 表示“可执行且可校验”的目标约束。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责在图节点间传递“上一版排程快照 + 本轮用户微调请求 + 动作日志 + 终审报告”;
|
||||
// 2. 负责承载最终对用户可见的 summary 与结构化 candidate_plans;
|
||||
// 3. 不负责 Redis/MySQL 读写(持久化由 service 层负责)。
|
||||
// 设计说明:
|
||||
// 1. 由 contract/slice 从自然语言编译得到;
|
||||
// 2. 执行阶段(done 收口)与终审阶段(hard_check)共用同一份约束;
|
||||
// 3. 避免“执行逻辑与终审逻辑各说各话”。
|
||||
type RefineObjective struct {
|
||||
Mode string `json:"mode,omitempty"` // none | move_all | move_ratio
|
||||
|
||||
SourceWeeks []int `json:"source_weeks,omitempty"`
|
||||
TargetWeeks []int `json:"target_weeks,omitempty"`
|
||||
SourceDays []int `json:"source_days,omitempty"`
|
||||
TargetDays []int `json:"target_days,omitempty"`
|
||||
|
||||
ExcludeSections []int `json:"exclude_sections,omitempty"`
|
||||
|
||||
BaselineSourceTaskCount int `json:"baseline_source_task_count,omitempty"`
|
||||
RequiredMoveMin int `json:"required_move_min,omitempty"`
|
||||
RequiredMoveMax int `json:"required_move_max,omitempty"`
|
||||
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// ScheduleRefineState 是连续微调图的统一状态。
|
||||
type ScheduleRefineState struct {
|
||||
// 1. 基础请求上下文。
|
||||
// 1) 请求上下文
|
||||
TraceID string
|
||||
UserID int
|
||||
ConversationID string
|
||||
@@ -107,59 +126,85 @@ type ScheduleRefineState struct {
|
||||
RequestNow time.Time
|
||||
RequestNowText string
|
||||
|
||||
// 2. 继承自上一版预览快照的可调度数据。
|
||||
TaskClassIDs []int
|
||||
Constraints []string
|
||||
HybridEntries []model.HybridScheduleEntry
|
||||
AllocatedItems []model.TaskClassItem
|
||||
CandidatePlans []model.UserWeekSchedule
|
||||
// 2) 继承自预览快照的数据
|
||||
TaskClassIDs []int
|
||||
Constraints []string
|
||||
// InitialHybridEntries 保存本轮微调开始前的基线,用于终审做“前后对比”。
|
||||
// 说明:
|
||||
// 1. 只读语义,不参与执行期改写;
|
||||
// 2. 终审可基于它判断“来源任务是否真正迁移到目标区域”。
|
||||
InitialHybridEntries []model.HybridScheduleEntry
|
||||
HybridEntries []model.HybridScheduleEntry
|
||||
AllocatedItems []model.TaskClassItem
|
||||
CandidatePlans []model.UserWeekSchedule
|
||||
|
||||
// 3. 本轮微调过程状态。
|
||||
// 3) 本轮执行状态
|
||||
UserIntent string
|
||||
Contract RefineContract
|
||||
|
||||
PlanMax int
|
||||
ExecuteMax int
|
||||
ReplanMax int
|
||||
PlanMax int
|
||||
PerTaskBudget int
|
||||
ExecuteMax int
|
||||
ReplanMax int
|
||||
// CompositeRetryMax 表示复合路由失败后的最大重试次数(不含首次尝试)。
|
||||
CompositeRetryMax int
|
||||
|
||||
PlanUsed int
|
||||
ReplanUsed int
|
||||
|
||||
// MaxRounds 保留“总预算”语义,供终审修复节点继续复用:
|
||||
// MaxRounds = ExecuteMax + RepairReserve
|
||||
MaxRounds int
|
||||
RepairReserve int
|
||||
RoundUsed int
|
||||
ActionLogs []string
|
||||
|
||||
// ConsecutiveFailures 记录执行阶段连续失败次数,用于触发“失败兜底 thinking”。
|
||||
ConsecutiveFailures int
|
||||
// ThinkingBoostArmed 表示“当前失败串已触发过一次 thinking 兜底”。
|
||||
ThinkingBoostArmed bool
|
||||
ThinkingBoostArmed bool
|
||||
ObservationHistory []ReactRoundObservation
|
||||
|
||||
CurrentPlan PlannerPlan
|
||||
BatchMoveAllowed bool
|
||||
// DisableCompositeTools=true 表示已进入 ReAct 兜底,禁止再调用复合工具。
|
||||
DisableCompositeTools bool
|
||||
// CompositeRouteTried 标记是否尝试过“复合批处理路由”。
|
||||
CompositeRouteTried bool
|
||||
// CompositeRouteSucceeded 标记复合批处理路由是否成功收口。
|
||||
CompositeRouteSucceeded bool
|
||||
TaskActionUsed map[int]int
|
||||
EntriesVersion int
|
||||
SeenSlotQueries map[string]struct{}
|
||||
|
||||
// RequiredCompositeTool 表示本轮策略要求“必须至少成功一次”的复合工具。
|
||||
// 取值约定:"" | "SpreadEven" | "MinContextSwitch"。
|
||||
RequiredCompositeTool string
|
||||
// CompositeToolCalled 记录复合工具是否至少调用过一次(不区分成功失败)。
|
||||
CompositeToolCalled map[string]bool
|
||||
// CompositeToolSuccess 记录复合工具是否至少成功过一次。
|
||||
CompositeToolSuccess map[string]bool
|
||||
|
||||
SlicePlan RefineSlicePlan
|
||||
Objective RefineObjective
|
||||
WorksetTaskIDs []int
|
||||
WorksetCursor int
|
||||
CurrentTaskID int
|
||||
CurrentTaskAttempt int
|
||||
|
||||
LastToolResult string
|
||||
ObservationHistory []ReactRoundObservation
|
||||
CurrentPlan PlannerPlan
|
||||
LastPostStrategy string
|
||||
// LastFailedCallSignature 记录“上一轮失败动作签名(tool+params)”。用于后端硬拦截重复失败动作。
|
||||
LastFailedCallSignature string
|
||||
OriginOrderMap map[int]int
|
||||
|
||||
// 4. 终审校验状态。
|
||||
// 4) 终审状态
|
||||
HardCheck HardCheckReport
|
||||
|
||||
// 5. 最终输出。
|
||||
// 5) 最终输出
|
||||
FinalSummary string
|
||||
Completed bool
|
||||
}
|
||||
|
||||
// NewScheduleRefineState 基于“上一版排程预览快照”初始化连续微调状态。
|
||||
// NewScheduleRefineState 基于上一版预览快照初始化状态。
|
||||
//
|
||||
// 步骤化说明:
|
||||
// 1. 先初始化请求基础字段与默认预算,保证图内每个节点都能读取到稳定上下文。
|
||||
// 2. 再把 preview 的核心排程数据做深拷贝注入,避免跨请求引用污染。
|
||||
// 3. 最后构建 origin_order_map,作为“保持相对顺序”硬约束的判定基线。
|
||||
// 4. 若 preview 为空,仍返回可用 state,由上层决定是报错还是降级。
|
||||
// 职责边界:
|
||||
// 1. 负责初始化预算、上下文字段与可变状态容器;
|
||||
// 2. 负责拷贝 preview 数据,避免跨请求引用污染;
|
||||
// 3. 不负责做任何调度动作。
|
||||
func NewScheduleRefineState(traceID string, userID int, conversationID string, userMessage string, preview *model.SchedulePlanPreviewCache) *ScheduleRefineState {
|
||||
now := nowToMinute()
|
||||
st := &ScheduleRefineState{
|
||||
@@ -170,22 +215,38 @@ func NewScheduleRefineState(traceID string, userID int, conversationID string, u
|
||||
RequestNow: now,
|
||||
RequestNowText: now.In(loadLocation()).Format(datetimeLayout),
|
||||
PlanMax: defaultPlanMax,
|
||||
PerTaskBudget: defaultPerTaskBudget,
|
||||
ExecuteMax: defaultExecuteMax,
|
||||
ReplanMax: defaultReplanMax,
|
||||
CompositeRetryMax: defaultCompositeRetry,
|
||||
RepairReserve: defaultRepairReserve,
|
||||
MaxRounds: defaultExecuteMax + defaultRepairReserve,
|
||||
ActionLogs: make([]string, 0, 24),
|
||||
ObservationHistory: make([]ReactRoundObservation, 0, 16),
|
||||
ActionLogs: make([]string, 0, 32),
|
||||
ObservationHistory: make([]ReactRoundObservation, 0, 24),
|
||||
TaskActionUsed: make(map[int]int),
|
||||
SeenSlotQueries: make(map[string]struct{}),
|
||||
OriginOrderMap: make(map[int]int),
|
||||
CompositeToolCalled: map[string]bool{
|
||||
"SpreadEven": false,
|
||||
"MinContextSwitch": false,
|
||||
},
|
||||
CompositeToolSuccess: map[string]bool{
|
||||
"SpreadEven": false,
|
||||
"MinContextSwitch": false,
|
||||
},
|
||||
CurrentPlan: PlannerPlan{
|
||||
Summary: "初始化完成,等待 Planner 生成执行计划。",
|
||||
},
|
||||
SlicePlan: RefineSlicePlan{
|
||||
Reason: "尚未切片",
|
||||
},
|
||||
}
|
||||
if preview == nil {
|
||||
return st
|
||||
}
|
||||
|
||||
st.TaskClassIDs = append([]int(nil), preview.TaskClassIDs...)
|
||||
st.InitialHybridEntries = cloneHybridEntries(preview.HybridEntries)
|
||||
st.HybridEntries = cloneHybridEntries(preview.HybridEntries)
|
||||
st.AllocatedItems = cloneTaskClassItems(preview.AllocatedItems)
|
||||
st.CandidatePlans = cloneWeekSchedules(preview.CandidatePlans)
|
||||
@@ -193,7 +254,6 @@ func NewScheduleRefineState(traceID string, userID int, conversationID string, u
|
||||
return st
|
||||
}
|
||||
|
||||
// loadLocation 返回排程链路使用的业务时区。
|
||||
func loadLocation() *time.Location {
|
||||
loc, err := time.LoadLocation(timezoneName)
|
||||
if err != nil {
|
||||
@@ -202,12 +262,10 @@ func loadLocation() *time.Location {
|
||||
return loc
|
||||
}
|
||||
|
||||
// nowToMinute 返回当前时刻并截断到分钟级,降低 prompt 中秒级噪声。
|
||||
func nowToMinute() time.Time {
|
||||
return time.Now().In(loadLocation()).Truncate(time.Minute)
|
||||
}
|
||||
|
||||
// cloneHybridEntries 深拷贝混合日程切片。
|
||||
func cloneHybridEntries(src []model.HybridScheduleEntry) []model.HybridScheduleEntry {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
@@ -217,7 +275,6 @@ func cloneHybridEntries(src []model.HybridScheduleEntry) []model.HybridScheduleE
|
||||
return dst
|
||||
}
|
||||
|
||||
// cloneTaskClassItems 深拷贝任务块切片(包含指针字段)。
|
||||
func cloneTaskClassItems(src []model.TaskClassItem) []model.TaskClassItem {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
@@ -250,7 +307,6 @@ func cloneTaskClassItems(src []model.TaskClassItem) []model.TaskClassItem {
|
||||
return dst
|
||||
}
|
||||
|
||||
// cloneWeekSchedules 深拷贝周视图切片。
|
||||
func cloneWeekSchedules(src []model.UserWeekSchedule) []model.UserWeekSchedule {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
@@ -267,12 +323,7 @@ func cloneWeekSchedules(src []model.UserWeekSchedule) []model.UserWeekSchedule {
|
||||
return dst
|
||||
}
|
||||
|
||||
// buildOriginOrderMap 从当前 suggested 排程位置构建“初始相对顺序映射”。
|
||||
//
|
||||
// 步骤化说明:
|
||||
// 1. 先筛出所有可调的 suggested 任务;
|
||||
// 2. 按 week/day/section/task_item_id 稳定排序,得到“时间先后基线”;
|
||||
// 3. 把 task_item_id -> rank 写入 map,后续 Move/Swap 都基于该 rank 做顺序硬校验。
|
||||
// buildOriginOrderMap 构建 suggested 任务的初始顺序基线(task_item_id -> rank)。
|
||||
func buildOriginOrderMap(entries []model.HybridScheduleEntry) map[int]int {
|
||||
orderMap := make(map[int]int)
|
||||
if len(entries) == 0 {
|
||||
@@ -280,7 +331,7 @@ func buildOriginOrderMap(entries []model.HybridScheduleEntry) map[int]int {
|
||||
}
|
||||
suggested := make([]model.HybridScheduleEntry, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry.Status == "suggested" && entry.TaskItemID > 0 {
|
||||
if isMovableSuggestedTask(entry) {
|
||||
suggested = append(suggested, entry)
|
||||
}
|
||||
}
|
||||
@@ -301,8 +352,8 @@ func buildOriginOrderMap(entries []model.HybridScheduleEntry) map[int]int {
|
||||
}
|
||||
return left.TaskItemID < right.TaskItemID
|
||||
})
|
||||
for idx, entry := range suggested {
|
||||
orderMap[entry.TaskItemID] = idx + 1
|
||||
for i, entry := range suggested {
|
||||
orderMap[entry.TaskItemID] = i + 1
|
||||
}
|
||||
return orderMap
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,9 @@ package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/dao"
|
||||
@@ -12,41 +15,204 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var RefreshKey = []byte(viper.GetString("jwt.refreshSecret")) // 用于签名和验证刷新Token的密钥
|
||||
var AccessKey = []byte(viper.GetString("jwt.accessSecret")) // 用于签名和验证访问Token的密钥
|
||||
const (
|
||||
accessSecretConfigKey = "jwt.accessSecret"
|
||||
refreshSecretConfigKey = "jwt.refreshSecret"
|
||||
accessExpireConfigKey = "jwt.accessTokenExpire"
|
||||
refreshExpireConfigKey = "jwt.refreshTokenExpire"
|
||||
|
||||
// generateJTI 生成唯一的 JWT ID
|
||||
defaultAccessTokenExpire = 15 * time.Minute
|
||||
defaultRefreshTokenExpire = 7 * 24 * time.Hour
|
||||
)
|
||||
|
||||
type jwtRuntimeConfig struct {
|
||||
AccessKey []byte
|
||||
RefreshKey []byte
|
||||
AccessExpire time.Duration
|
||||
RefreshExpire time.Duration
|
||||
}
|
||||
|
||||
// AccessSigningKey 负责提供访问令牌签名/验签密钥。
|
||||
// 职责边界:
|
||||
// 1. 负责从运行时配置读取 accessSecret 并做空值校验。
|
||||
// 2. 不负责 token 解析、业务鉴权与错误码映射。
|
||||
// 3. 返回值语义:[]byte 为签名密钥;error 非空表示配置不可用。
|
||||
func AccessSigningKey() ([]byte, error) {
|
||||
cfg, err := loadJWTConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cfg.AccessKey, nil
|
||||
}
|
||||
|
||||
// generateJTI 生成唯一的 JWT ID。
|
||||
func generateJTI() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// GenerateTokens 生成访问令牌和刷新令牌
|
||||
func GenerateTokens(userID int) (string, string, error) {
|
||||
// 创建访问令牌
|
||||
sid := generateJTI()
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": userID, // 获取用户ID
|
||||
"exp": time.Now().Add(15 * time.Minute).Unix(), // 设置访问令牌过期时间为 15 分钟
|
||||
"token_type": "access_token", // 令牌类型为访问令牌
|
||||
"jti": sid, // 亲子共用的 JWT ID
|
||||
})
|
||||
// loadJWTConfig 负责聚合 JWT 运行时配置。
|
||||
// 职责边界:
|
||||
// 1. 负责读取密钥与过期时间配置,并转换为可直接使用的结构。
|
||||
// 2. 不负责持久化配置,也不负责降级到“不安全默认密钥”。
|
||||
// 3. 返回值语义:cfg 可直接用于签发/校验;error 非空表示配置不合法。
|
||||
func loadJWTConfig() (*jwtRuntimeConfig, error) {
|
||||
accessKey, err := readJWTSecret(accessSecretConfigKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
refreshKey, err := readJWTSecret(refreshSecretConfigKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 使用密钥签名访问令牌
|
||||
accessTokenString, err := accessToken.SignedString(AccessKey)
|
||||
accessExpire, err := readJWTExpireDuration(accessExpireConfigKey, defaultAccessTokenExpire)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
refreshExpire, err := readJWTExpireDuration(refreshExpireConfigKey, defaultRefreshTokenExpire)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &jwtRuntimeConfig{
|
||||
AccessKey: accessKey,
|
||||
RefreshKey: refreshKey,
|
||||
AccessExpire: accessExpire,
|
||||
RefreshExpire: refreshExpire,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// readJWTSecret 负责读取并校验 JWT 密钥配置。
|
||||
// 职责边界:
|
||||
// 1. 负责“读配置 + 去空白 + 空值校验”。
|
||||
// 2. 不负责任何默认值回退,避免静默使用弱配置。
|
||||
// 3. 返回值语义:[]byte 为密钥;error 非空表示该配置项不可用。
|
||||
func readJWTSecret(configKey string) ([]byte, error) {
|
||||
secret := strings.TrimSpace(viper.GetString(configKey))
|
||||
if secret == "" {
|
||||
return nil, fmt.Errorf("jwt 配置缺失: %s", configKey)
|
||||
}
|
||||
return []byte(secret), nil
|
||||
}
|
||||
|
||||
// readJWTExpireDuration 负责读取并解析 JWT 过期时间配置。
|
||||
// 职责边界:
|
||||
// 1. 负责把字符串配置解析成 time.Duration,并保证结果大于 0。
|
||||
// 2. 不负责签发 token;仅提供“可计算”的过期时长。
|
||||
// 3. 返回值语义:duration 为最终时长;error 非空表示格式非法。
|
||||
func readJWTExpireDuration(configKey string, fallback time.Duration) (time.Duration, error) {
|
||||
raw := strings.TrimSpace(viper.GetString(configKey))
|
||||
if raw == "" {
|
||||
return fallback, nil
|
||||
}
|
||||
d, err := parseFlexibleDuration(raw)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("jwt 配置项 %s 非法: %w", configKey, err)
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("jwt 配置项 %s 必须大于 0", configKey)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// parseFlexibleDuration 负责解析项目内常见时长格式。
|
||||
// 职责边界:
|
||||
// 1. 负责兼容 Go 标准格式(如 15m、168h)与项目常见格式(如 15min、7d)。
|
||||
// 2. 不负责读取配置键名;仅解析输入字符串。
|
||||
// 3. 输入输出语义:raw 为原始时长文本;返回解析后的正时长或错误。
|
||||
func parseFlexibleDuration(raw string) (time.Duration, error) {
|
||||
normalized := strings.ToLower(strings.TrimSpace(raw))
|
||||
if normalized == "" {
|
||||
return 0, errors.New("时长不能为空")
|
||||
}
|
||||
|
||||
// 1. 先走 Go 原生解析,优先兼容标准写法(如 15m/168h)。
|
||||
if d, err := time.ParseDuration(normalized); err == nil {
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// 2. 原生解析失败后,兼容项目常见简写(如 15min、7d)。
|
||||
type unitDef struct {
|
||||
Suffix string
|
||||
Multiplier time.Duration
|
||||
}
|
||||
unitDefs := []unitDef{
|
||||
{Suffix: "minutes", Multiplier: time.Minute},
|
||||
{Suffix: "minute", Multiplier: time.Minute},
|
||||
{Suffix: "mins", Multiplier: time.Minute},
|
||||
{Suffix: "min", Multiplier: time.Minute},
|
||||
{Suffix: "days", Multiplier: 24 * time.Hour},
|
||||
{Suffix: "day", Multiplier: 24 * time.Hour},
|
||||
{Suffix: "d", Multiplier: 24 * time.Hour},
|
||||
{Suffix: "hours", Multiplier: time.Hour},
|
||||
{Suffix: "hour", Multiplier: time.Hour},
|
||||
{Suffix: "h", Multiplier: time.Hour},
|
||||
{Suffix: "seconds", Multiplier: time.Second},
|
||||
{Suffix: "second", Multiplier: time.Second},
|
||||
{Suffix: "secs", Multiplier: time.Second},
|
||||
{Suffix: "sec", Multiplier: time.Second},
|
||||
{Suffix: "m", Multiplier: time.Minute},
|
||||
{Suffix: "s", Multiplier: time.Second},
|
||||
}
|
||||
|
||||
for _, unit := range unitDefs {
|
||||
if !strings.HasSuffix(normalized, unit.Suffix) {
|
||||
continue
|
||||
}
|
||||
numberPart := strings.TrimSpace(strings.TrimSuffix(normalized, unit.Suffix))
|
||||
value, err := strconv.Atoi(numberPart)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("时长数值非法: %q", numberPart)
|
||||
}
|
||||
if value <= 0 {
|
||||
return 0, fmt.Errorf("时长数值必须大于 0: %d", value)
|
||||
}
|
||||
return time.Duration(value) * unit.Multiplier, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("不支持的时长格式: %s", raw)
|
||||
}
|
||||
|
||||
// GenerateTokens 负责按配置签发访问令牌与刷新令牌。
|
||||
// 职责边界:
|
||||
// 1. 负责根据配置生成 exp,并签发 access/refresh 双 token。
|
||||
// 2. 不负责登录鉴权(用户名/密码验证在 service 层处理)。
|
||||
// 3. 返回值语义:第一个为 access token,第二个为 refresh token,error 非空表示签发失败。
|
||||
func GenerateTokens(userID int) (string, string, error) {
|
||||
cfg, err := loadJWTConfig()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// 创建刷新令牌
|
||||
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": userID, // 获取用户ID
|
||||
"exp": time.Now().Add(7 * 24 * time.Hour).Unix(), // 设置刷新令牌过期时间为 7 天
|
||||
"token_type": "refresh_token", // 令牌类型为刷新令牌
|
||||
"jti": sid, // 亲子共用的 JWT ID
|
||||
})
|
||||
now := time.Now()
|
||||
sid := generateJTI()
|
||||
|
||||
// 使用密钥签名刷新令牌
|
||||
refreshTokenString, err := refreshToken.SignedString(RefreshKey)
|
||||
// 1. 先签 access token:短期有效,面向接口访问。
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, model.MyCustomClaims{
|
||||
UserID: userID,
|
||||
TokenType: "access_token",
|
||||
Jti: sid,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(cfg.AccessExpire)),
|
||||
},
|
||||
})
|
||||
accessTokenString, err := accessToken.SignedString(cfg.AccessKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// 2. 再签 refresh token:长期有效,仅用于换发新 token。
|
||||
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, model.MyCustomClaims{
|
||||
UserID: userID,
|
||||
TokenType: "refresh_token",
|
||||
Jti: sid,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(cfg.RefreshExpire)),
|
||||
},
|
||||
})
|
||||
refreshTokenString, err := refreshToken.SignedString(cfg.RefreshKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
@@ -54,71 +220,45 @@ func GenerateTokens(userID int) (string, string, error) {
|
||||
return accessTokenString, refreshTokenString, nil
|
||||
}
|
||||
|
||||
// ValidateRefreshToken 验证刷新令牌的有效性
|
||||
/*func ValidateRefreshToken(tokenString string) (*jwt.Token, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// 检查签名方法是否为 HMAC
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, respond.InvalidTokenSingingMethod
|
||||
}
|
||||
// 返回用于验证的密钥
|
||||
return RefreshKey, nil
|
||||
})
|
||||
// ValidateRefreshToken 验证刷新令牌的有效性,并增加 Redis 黑名单检查。
|
||||
func ValidateRefreshToken(tokenString string, cache *dao.CacheDAO) (*jwt.Token, error) {
|
||||
cfg, err := loadJWTConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 进一步检查载荷中 token_type 是否正确
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, respond.InvalidClaims
|
||||
}
|
||||
// 检查 token_type 是否是 refresh_token
|
||||
if claimType, ok := claims["token_type"].(string); !ok || claimType != "refresh_token" {
|
||||
return nil, respond.WrongTokenType
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
*/
|
||||
|
||||
// ValidateRefreshToken 验证刷新令牌的有效性,并增加 Redis 黑名单检查
|
||||
func ValidateRefreshToken(tokenString string, cache *dao.CacheDAO) (*jwt.Token, error) {
|
||||
// 1. 解析 Token 并直接绑定到你的自定义结构体
|
||||
// 1. 解析 refresh token,并强制校验签名算法与密钥来源。
|
||||
token, err := jwt.ParseWithClaims(tokenString, &model.MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, respond.InvalidTokenSingingMethod
|
||||
}
|
||||
return RefreshKey, nil
|
||||
return cfg.RefreshKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, respond.InvalidRefreshToken
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, respond.InvalidRefreshToken
|
||||
}
|
||||
|
||||
// 2. 断言获取 Claims
|
||||
// 2. 断言 claims 类型,后续业务字段都从结构体读取。
|
||||
claims, ok := token.Claims.(*model.MyCustomClaims)
|
||||
if !ok {
|
||||
return nil, respond.InvalidClaims
|
||||
}
|
||||
|
||||
// 3. 核心“设卡”:检查 token_type 是否是 refresh_token
|
||||
// 3. 校验 token_type,防止把 access token 当 refresh token 用。
|
||||
if claims.TokenType != "refresh_token" {
|
||||
return nil, respond.WrongTokenType
|
||||
}
|
||||
|
||||
// 4. --- 🛡️ 终极关卡:检查 Redis 黑名单 ---
|
||||
// 即使签名没过期,如果 jti 在黑名单里(用户已登出),也视为无效
|
||||
// 4. 黑名单校验:签名合法也要确认 jti 未被主动注销。
|
||||
isBlack, err := cache.IsBlacklisted(claims.Jti)
|
||||
if err != nil {
|
||||
// Redis 出错时的处理逻辑,建议报错以防“漏网之鱼”
|
||||
return nil, errors.New("无法验证令牌状态")
|
||||
}
|
||||
if isBlack {
|
||||
return nil, respond.UserLoggedOut // 返回你定义的“用户已登出”错误
|
||||
return nil, respond.UserLoggedOut
|
||||
}
|
||||
|
||||
return token, nil
|
||||
|
||||
128
backend/auth/jwt_handler_test.go
Normal file
128
backend/auth/jwt_handler_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func TestParseFlexibleDuration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
raw string
|
||||
want time.Duration
|
||||
wantFail bool
|
||||
}{
|
||||
{name: "标准格式", raw: "15m", want: 15 * time.Minute},
|
||||
{name: "项目分钟简写", raw: "15min", want: 15 * time.Minute},
|
||||
{name: "项目天简写", raw: "7d", want: 7 * 24 * time.Hour},
|
||||
{name: "非法格式", raw: "abc", wantFail: true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := parseFlexibleDuration(tc.raw)
|
||||
if tc.wantFail {
|
||||
if err == nil {
|
||||
t.Fatalf("期望解析失败,但得到成功: %s", tc.raw)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("解析失败: %v", err)
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Fatalf("解析结果不符合预期,got=%v want=%v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTokens_UseConfigExpire(t *testing.T) {
|
||||
const (
|
||||
accessSecret = "unit-test-access-secret"
|
||||
refreshSecret = "unit-test-refresh-secret"
|
||||
accessExpire = "2h"
|
||||
refreshExpire = "3d"
|
||||
)
|
||||
|
||||
originAccessSecret := viper.GetString(accessSecretConfigKey)
|
||||
originRefreshSecret := viper.GetString(refreshSecretConfigKey)
|
||||
originAccessExpire := viper.GetString(accessExpireConfigKey)
|
||||
originRefreshExpire := viper.GetString(refreshExpireConfigKey)
|
||||
|
||||
viper.Set(accessSecretConfigKey, accessSecret)
|
||||
viper.Set(refreshSecretConfigKey, refreshSecret)
|
||||
viper.Set(accessExpireConfigKey, accessExpire)
|
||||
viper.Set(refreshExpireConfigKey, refreshExpire)
|
||||
|
||||
t.Cleanup(func() {
|
||||
viper.Set(accessSecretConfigKey, originAccessSecret)
|
||||
viper.Set(refreshSecretConfigKey, originRefreshSecret)
|
||||
viper.Set(accessExpireConfigKey, originAccessExpire)
|
||||
viper.Set(refreshExpireConfigKey, originRefreshExpire)
|
||||
})
|
||||
|
||||
start := time.Now()
|
||||
accessTokenString, refreshTokenString, err := GenerateTokens(9527)
|
||||
if err != nil {
|
||||
t.Fatalf("签发 token 失败: %v", err)
|
||||
}
|
||||
|
||||
accessClaims := parseTokenClaimsForTest(t, accessTokenString, []byte(accessSecret))
|
||||
refreshClaims := parseTokenClaimsForTest(t, refreshTokenString, []byte(refreshSecret))
|
||||
|
||||
if accessClaims.TokenType != "access_token" {
|
||||
t.Fatalf("access token_type 不符合预期: %s", accessClaims.TokenType)
|
||||
}
|
||||
if refreshClaims.TokenType != "refresh_token" {
|
||||
t.Fatalf("refresh token_type 不符合预期: %s", refreshClaims.TokenType)
|
||||
}
|
||||
if accessClaims.Jti == "" || refreshClaims.Jti == "" {
|
||||
t.Fatalf("jti 不能为空")
|
||||
}
|
||||
if accessClaims.Jti != refreshClaims.Jti {
|
||||
t.Fatalf("access/refresh 应共享同一个 jti")
|
||||
}
|
||||
|
||||
assertExpireNear(t, accessClaims.ExpiresAt.Time, start.Add(2*time.Hour), 3*time.Second)
|
||||
assertExpireNear(t, refreshClaims.ExpiresAt.Time, start.Add(3*24*time.Hour), 3*time.Second)
|
||||
}
|
||||
|
||||
func parseTokenClaimsForTest(t *testing.T, tokenString string, key []byte) *model.MyCustomClaims {
|
||||
t.Helper()
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &model.MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return key, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("解析 token 失败: %v", err)
|
||||
}
|
||||
if !token.Valid {
|
||||
t.Fatalf("token 无效")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*model.MyCustomClaims)
|
||||
if !ok {
|
||||
t.Fatalf("claims 类型断言失败")
|
||||
}
|
||||
return claims
|
||||
}
|
||||
|
||||
func assertExpireNear(t *testing.T, got time.Time, want time.Time, tolerance time.Duration) {
|
||||
t.Helper()
|
||||
delta := got.Sub(want)
|
||||
if delta < 0 {
|
||||
delta = -delta
|
||||
}
|
||||
if delta > tolerance {
|
||||
t.Fatalf("exp 偏差超出容忍范围,got=%s want=%s delta=%s tolerance=%s", got.Format(time.RFC3339), want.Format(time.RFC3339), delta, tolerance)
|
||||
}
|
||||
}
|
||||
373
backend/logic/refine_compound_ops.go
Normal file
373
backend/logic/refine_compound_ops.go
Normal file
@@ -0,0 +1,373 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// RefineTaskCandidate 表示复合工具规划阶段可移动的任务候选。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只承载“任务当前坐标 + 规划所需标签”;
|
||||
// 2. 不承载冲突判断、窗口判断等执行期逻辑;
|
||||
// 3. 由调用方保证 task_item_id 唯一且为正数。
|
||||
type RefineTaskCandidate struct {
|
||||
TaskItemID int
|
||||
Week int
|
||||
DayOfWeek int
|
||||
SectionFrom int
|
||||
SectionTo int
|
||||
Name string
|
||||
ContextTag string
|
||||
OriginRank int
|
||||
}
|
||||
|
||||
// RefineSlotCandidate 表示复合工具可选落点(坑位)。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只描述可候选的时段坐标;
|
||||
// 2. 不描述“为什么可用”,可用性由调用方预先筛好;
|
||||
// 3. Span 由 SectionFrom/SectionTo 推导,不单独存储。
|
||||
type RefineSlotCandidate struct {
|
||||
Week int
|
||||
DayOfWeek int
|
||||
SectionFrom int
|
||||
SectionTo int
|
||||
}
|
||||
|
||||
// RefineMovePlanItem 表示“任务 -> 目标坑位”的确定性规划结果。
|
||||
type RefineMovePlanItem struct {
|
||||
TaskItemID int
|
||||
ToWeek int
|
||||
ToDay int
|
||||
ToSectionFrom int
|
||||
ToSectionTo int
|
||||
}
|
||||
|
||||
// RefineCompositePlanOptions 是复合规划器的可选辅助输入。
|
||||
//
|
||||
// 说明:
|
||||
// 1. ExistingDayLoad 用于提供“目标范围内的既有负载基线”,用于均匀铺开打分;
|
||||
// 2. key 约定为 "week-day",例如 "16-3";
|
||||
// 3. 未提供时,规划器按 0 基线处理。
|
||||
type RefineCompositePlanOptions struct {
|
||||
ExistingDayLoad map[string]int
|
||||
}
|
||||
|
||||
// PlanEvenSpreadMoves 规划“均匀铺开”的确定性移动方案。
|
||||
//
|
||||
// 步骤化说明:
|
||||
// 1. 先按稳定顺序归一化任务与坑位,保证同输入必同输出;
|
||||
// 2. 逐任务选择“投放后日负载最小”的坑位,主目标是降低日负载离散度;
|
||||
// 3. 同分时按时间更早优先,进一步保证确定性;
|
||||
// 4. 若某任务不存在同跨度坑位,直接失败并返回明确错误。
|
||||
func PlanEvenSpreadMoves(tasks []RefineTaskCandidate, slots []RefineSlotCandidate, options RefineCompositePlanOptions) ([]RefineMovePlanItem, error) {
|
||||
normalizedTasks, err := normalizeRefineTasks(tasks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalizedSlots, err := normalizeRefineSlots(slots)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(normalizedSlots) < len(normalizedTasks) {
|
||||
return nil, fmt.Errorf("可用坑位不足:tasks=%d, slots=%d", len(normalizedTasks), len(normalizedSlots))
|
||||
}
|
||||
|
||||
// 1. dayLoad 记录“当前已占 + 本次规划已分配”的日负载。
|
||||
// 2. 这里先写入调用方提供的既有基线,再在循环中动态递增。
|
||||
dayLoad := make(map[string]int, len(options.ExistingDayLoad)+len(normalizedSlots))
|
||||
for key, value := range options.ExistingDayLoad {
|
||||
if value <= 0 {
|
||||
continue
|
||||
}
|
||||
dayLoad[strings.TrimSpace(key)] = value
|
||||
}
|
||||
|
||||
used := make([]bool, len(normalizedSlots))
|
||||
moves := make([]RefineMovePlanItem, 0, len(normalizedTasks))
|
||||
selectedSlots := make([]RefineSlotCandidate, 0, len(normalizedTasks))
|
||||
|
||||
for _, task := range normalizedTasks {
|
||||
taskSpan := sectionSpan(task.SectionFrom, task.SectionTo)
|
||||
bestIdx := -1
|
||||
bestScore := int(^uint(0) >> 1) // max int
|
||||
|
||||
for idx, slot := range normalizedSlots {
|
||||
if used[idx] {
|
||||
continue
|
||||
}
|
||||
if sectionSpan(slot.SectionFrom, slot.SectionTo) != taskSpan {
|
||||
continue
|
||||
}
|
||||
if slotOverlapsAny(slot, selectedSlots) {
|
||||
continue
|
||||
}
|
||||
dayKey := composeDayKey(slot.Week, slot.DayOfWeek)
|
||||
projectedLoad := dayLoad[dayKey] + 1
|
||||
// 1. projectedLoad 是主目标(越小越均衡);
|
||||
// 2. idx 是次级目标(越早的坑位越优先,保证稳定)。
|
||||
score := projectedLoad*10000 + idx
|
||||
if score < bestScore {
|
||||
bestScore = score
|
||||
bestIdx = idx
|
||||
}
|
||||
}
|
||||
if bestIdx < 0 {
|
||||
return nil, fmt.Errorf("任务 id=%d 无可用同跨度坑位", task.TaskItemID)
|
||||
}
|
||||
|
||||
chosen := normalizedSlots[bestIdx]
|
||||
used[bestIdx] = true
|
||||
selectedSlots = append(selectedSlots, chosen)
|
||||
dayLoad[composeDayKey(chosen.Week, chosen.DayOfWeek)]++
|
||||
moves = append(moves, RefineMovePlanItem{
|
||||
TaskItemID: task.TaskItemID,
|
||||
ToWeek: chosen.Week,
|
||||
ToDay: chosen.DayOfWeek,
|
||||
ToSectionFrom: chosen.SectionFrom,
|
||||
ToSectionTo: chosen.SectionTo,
|
||||
})
|
||||
}
|
||||
return moves, nil
|
||||
}
|
||||
|
||||
// PlanMinContextSwitchMoves 规划“同科目上下文切换最少”的确定性移动方案。
|
||||
//
|
||||
// 步骤化说明:
|
||||
// 1. 先把任务按 context_tag 分组,目标是让同组任务尽量连续;
|
||||
// 2. 分组顺序按“组大小降序 + 最早 origin_rank + 标签字典序”稳定排序;
|
||||
// 3. 组内按任务稳定顺序排,再顺序填入时间上最早可用同跨度坑位;
|
||||
// 4. 若某任务不存在同跨度坑位,立即失败并返回明确错误。
|
||||
func PlanMinContextSwitchMoves(tasks []RefineTaskCandidate, slots []RefineSlotCandidate, _ RefineCompositePlanOptions) ([]RefineMovePlanItem, error) {
|
||||
normalizedTasks, err := normalizeRefineTasks(tasks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalizedSlots, err := normalizeRefineSlots(slots)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(normalizedSlots) < len(normalizedTasks) {
|
||||
return nil, fmt.Errorf("可用坑位不足:tasks=%d, slots=%d", len(normalizedTasks), len(normalizedSlots))
|
||||
}
|
||||
|
||||
type taskGroup struct {
|
||||
ContextKey string
|
||||
Tasks []RefineTaskCandidate
|
||||
MinRank int
|
||||
}
|
||||
groupMap := make(map[string]*taskGroup)
|
||||
groupOrder := make([]string, 0, len(normalizedTasks))
|
||||
|
||||
for _, task := range normalizedTasks {
|
||||
key := normalizeContextKey(task.ContextTag)
|
||||
group, exists := groupMap[key]
|
||||
if !exists {
|
||||
group = &taskGroup{
|
||||
ContextKey: key,
|
||||
MinRank: normalizedOriginRank(task),
|
||||
}
|
||||
groupMap[key] = group
|
||||
groupOrder = append(groupOrder, key)
|
||||
}
|
||||
group.Tasks = append(group.Tasks, task)
|
||||
if rank := normalizedOriginRank(task); rank < group.MinRank {
|
||||
group.MinRank = rank
|
||||
}
|
||||
}
|
||||
|
||||
groups := make([]taskGroup, 0, len(groupMap))
|
||||
for _, key := range groupOrder {
|
||||
group := groupMap[key]
|
||||
sort.SliceStable(group.Tasks, func(i, j int) bool {
|
||||
return compareTaskOrder(group.Tasks[i], group.Tasks[j]) < 0
|
||||
})
|
||||
groups = append(groups, *group)
|
||||
}
|
||||
sort.SliceStable(groups, func(i, j int) bool {
|
||||
if len(groups[i].Tasks) != len(groups[j].Tasks) {
|
||||
return len(groups[i].Tasks) > len(groups[j].Tasks)
|
||||
}
|
||||
if groups[i].MinRank != groups[j].MinRank {
|
||||
return groups[i].MinRank < groups[j].MinRank
|
||||
}
|
||||
return groups[i].ContextKey < groups[j].ContextKey
|
||||
})
|
||||
|
||||
orderedTasks := make([]RefineTaskCandidate, 0, len(normalizedTasks))
|
||||
for _, group := range groups {
|
||||
orderedTasks = append(orderedTasks, group.Tasks...)
|
||||
}
|
||||
|
||||
used := make([]bool, len(normalizedSlots))
|
||||
moves := make([]RefineMovePlanItem, 0, len(orderedTasks))
|
||||
selectedSlots := make([]RefineSlotCandidate, 0, len(orderedTasks))
|
||||
for _, task := range orderedTasks {
|
||||
taskSpan := sectionSpan(task.SectionFrom, task.SectionTo)
|
||||
chosenIdx := -1
|
||||
for idx, slot := range normalizedSlots {
|
||||
if used[idx] {
|
||||
continue
|
||||
}
|
||||
if sectionSpan(slot.SectionFrom, slot.SectionTo) != taskSpan {
|
||||
continue
|
||||
}
|
||||
if slotOverlapsAny(slot, selectedSlots) {
|
||||
continue
|
||||
}
|
||||
chosenIdx = idx
|
||||
break
|
||||
}
|
||||
if chosenIdx < 0 {
|
||||
return nil, fmt.Errorf("任务 id=%d 无可用同跨度坑位", task.TaskItemID)
|
||||
}
|
||||
chosen := normalizedSlots[chosenIdx]
|
||||
used[chosenIdx] = true
|
||||
selectedSlots = append(selectedSlots, chosen)
|
||||
moves = append(moves, RefineMovePlanItem{
|
||||
TaskItemID: task.TaskItemID,
|
||||
ToWeek: chosen.Week,
|
||||
ToDay: chosen.DayOfWeek,
|
||||
ToSectionFrom: chosen.SectionFrom,
|
||||
ToSectionTo: chosen.SectionTo,
|
||||
})
|
||||
}
|
||||
return moves, nil
|
||||
}
|
||||
|
||||
func normalizeRefineTasks(tasks []RefineTaskCandidate) ([]RefineTaskCandidate, error) {
|
||||
if len(tasks) == 0 {
|
||||
return nil, fmt.Errorf("任务列表为空")
|
||||
}
|
||||
normalized := make([]RefineTaskCandidate, 0, len(tasks))
|
||||
seen := make(map[int]struct{}, len(tasks))
|
||||
for _, task := range tasks {
|
||||
if task.TaskItemID <= 0 {
|
||||
return nil, fmt.Errorf("存在非法 task_item_id=%d", task.TaskItemID)
|
||||
}
|
||||
if _, exists := seen[task.TaskItemID]; exists {
|
||||
return nil, fmt.Errorf("任务 id=%d 重复", task.TaskItemID)
|
||||
}
|
||||
if !isValidDay(task.DayOfWeek) {
|
||||
return nil, fmt.Errorf("任务 id=%d day_of_week 非法=%d", task.TaskItemID, task.DayOfWeek)
|
||||
}
|
||||
if !isValidSection(task.SectionFrom, task.SectionTo) {
|
||||
return nil, fmt.Errorf("任务 id=%d 节次区间非法=%d-%d", task.TaskItemID, task.SectionFrom, task.SectionTo)
|
||||
}
|
||||
seen[task.TaskItemID] = struct{}{}
|
||||
normalized = append(normalized, task)
|
||||
}
|
||||
sort.SliceStable(normalized, func(i, j int) bool {
|
||||
return compareTaskOrder(normalized[i], normalized[j]) < 0
|
||||
})
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func normalizeRefineSlots(slots []RefineSlotCandidate) ([]RefineSlotCandidate, error) {
|
||||
if len(slots) == 0 {
|
||||
return nil, fmt.Errorf("可用坑位为空")
|
||||
}
|
||||
normalized := make([]RefineSlotCandidate, 0, len(slots))
|
||||
seen := make(map[string]struct{}, len(slots))
|
||||
for _, slot := range slots {
|
||||
if slot.Week <= 0 {
|
||||
return nil, fmt.Errorf("存在非法 week=%d", slot.Week)
|
||||
}
|
||||
if !isValidDay(slot.DayOfWeek) {
|
||||
return nil, fmt.Errorf("存在非法 day_of_week=%d", slot.DayOfWeek)
|
||||
}
|
||||
if !isValidSection(slot.SectionFrom, slot.SectionTo) {
|
||||
return nil, fmt.Errorf("存在非法节次区间=%d-%d", slot.SectionFrom, slot.SectionTo)
|
||||
}
|
||||
key := fmt.Sprintf("%d-%d-%d-%d", slot.Week, slot.DayOfWeek, slot.SectionFrom, slot.SectionTo)
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
normalized = append(normalized, slot)
|
||||
}
|
||||
sort.SliceStable(normalized, func(i, j int) bool {
|
||||
if normalized[i].Week != normalized[j].Week {
|
||||
return normalized[i].Week < normalized[j].Week
|
||||
}
|
||||
if normalized[i].DayOfWeek != normalized[j].DayOfWeek {
|
||||
return normalized[i].DayOfWeek < normalized[j].DayOfWeek
|
||||
}
|
||||
if normalized[i].SectionFrom != normalized[j].SectionFrom {
|
||||
return normalized[i].SectionFrom < normalized[j].SectionFrom
|
||||
}
|
||||
return normalized[i].SectionTo < normalized[j].SectionTo
|
||||
})
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func compareTaskOrder(a, b RefineTaskCandidate) int {
|
||||
rankA := normalizedOriginRank(a)
|
||||
rankB := normalizedOriginRank(b)
|
||||
if rankA != rankB {
|
||||
return rankA - rankB
|
||||
}
|
||||
if a.Week != b.Week {
|
||||
return a.Week - b.Week
|
||||
}
|
||||
if a.DayOfWeek != b.DayOfWeek {
|
||||
return a.DayOfWeek - b.DayOfWeek
|
||||
}
|
||||
if a.SectionFrom != b.SectionFrom {
|
||||
return a.SectionFrom - b.SectionFrom
|
||||
}
|
||||
if a.SectionTo != b.SectionTo {
|
||||
return a.SectionTo - b.SectionTo
|
||||
}
|
||||
return a.TaskItemID - b.TaskItemID
|
||||
}
|
||||
|
||||
func normalizedOriginRank(task RefineTaskCandidate) int {
|
||||
if task.OriginRank > 0 {
|
||||
return task.OriginRank
|
||||
}
|
||||
// 1. 无 origin_rank 时回退到较大稳定值,避免把“未知顺序”抢到前面。
|
||||
// 2. 叠加 task_id 作为细粒度稳定因子,保证排序可复现。
|
||||
return 1_000_000 + task.TaskItemID
|
||||
}
|
||||
|
||||
func normalizeContextKey(tag string) string {
|
||||
text := strings.TrimSpace(tag)
|
||||
if text == "" {
|
||||
return "General"
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func composeDayKey(week, day int) string {
|
||||
return fmt.Sprintf("%d-%d", week, day)
|
||||
}
|
||||
|
||||
func sectionSpan(from, to int) int {
|
||||
return to - from + 1
|
||||
}
|
||||
|
||||
func isValidDay(day int) bool {
|
||||
return day >= 1 && day <= 7
|
||||
}
|
||||
|
||||
func isValidSection(from, to int) bool {
|
||||
if from < 1 || to > 12 {
|
||||
return false
|
||||
}
|
||||
return from <= to
|
||||
}
|
||||
|
||||
func slotOverlapsAny(candidate RefineSlotCandidate, selected []RefineSlotCandidate) bool {
|
||||
for _, current := range selected {
|
||||
if current.Week != candidate.Week || current.DayOfWeek != candidate.DayOfWeek {
|
||||
continue
|
||||
}
|
||||
if current.SectionFrom <= candidate.SectionTo && candidate.SectionFrom <= current.SectionTo {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
95
backend/logic/refine_compound_ops_test.go
Normal file
95
backend/logic/refine_compound_ops_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPlanEvenSpreadMovesPrefersLowerLoadDay(t *testing.T) {
|
||||
tasks := []RefineTaskCandidate{
|
||||
{TaskItemID: 101, Week: 16, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, OriginRank: 1},
|
||||
{TaskItemID: 102, Week: 16, DayOfWeek: 1, SectionFrom: 3, SectionTo: 4, OriginRank: 2},
|
||||
}
|
||||
slots := []RefineSlotCandidate{
|
||||
{Week: 12, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2},
|
||||
{Week: 12, DayOfWeek: 2, SectionFrom: 1, SectionTo: 2},
|
||||
{Week: 12, DayOfWeek: 3, SectionFrom: 1, SectionTo: 2},
|
||||
}
|
||||
moves, err := PlanEvenSpreadMoves(tasks, slots, RefineCompositePlanOptions{
|
||||
ExistingDayLoad: map[string]int{
|
||||
composeDayKey(12, 1): 5,
|
||||
composeDayKey(12, 2): 1,
|
||||
composeDayKey(12, 3): 0,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PlanEvenSpreadMoves 返回错误: %v", err)
|
||||
}
|
||||
if len(moves) != 2 {
|
||||
t.Fatalf("期望移动 2 条,实际=%d", len(moves))
|
||||
}
|
||||
|
||||
// 1. 低负载日(周三)应优先被填充;
|
||||
// 2. 第二条应落在次低负载日(周二),而不是高负载日(周一)。
|
||||
weekDayByID := make(map[int][2]int, len(moves))
|
||||
for _, move := range moves {
|
||||
weekDayByID[move.TaskItemID] = [2]int{move.ToWeek, move.ToDay}
|
||||
}
|
||||
if got := weekDayByID[101]; got != [2]int{12, 3} {
|
||||
t.Fatalf("任务101应优先落到 W12D3,实际=%v", got)
|
||||
}
|
||||
if got := weekDayByID[102]; got != [2]int{12, 2} {
|
||||
t.Fatalf("任务102应落到 W12D2,实际=%v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlanMinContextSwitchMovesGroupsSameContext(t *testing.T) {
|
||||
tasks := []RefineTaskCandidate{
|
||||
{TaskItemID: 201, Week: 16, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, ContextTag: "数学", OriginRank: 1},
|
||||
{TaskItemID: 202, Week: 16, DayOfWeek: 1, SectionFrom: 3, SectionTo: 4, ContextTag: "算法", OriginRank: 2},
|
||||
{TaskItemID: 203, Week: 16, DayOfWeek: 1, SectionFrom: 5, SectionTo: 6, ContextTag: "数学", OriginRank: 3},
|
||||
}
|
||||
slots := []RefineSlotCandidate{
|
||||
{Week: 12, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2},
|
||||
{Week: 12, DayOfWeek: 1, SectionFrom: 3, SectionTo: 4},
|
||||
{Week: 12, DayOfWeek: 1, SectionFrom: 5, SectionTo: 6},
|
||||
}
|
||||
moves, err := PlanMinContextSwitchMoves(tasks, slots, RefineCompositePlanOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("PlanMinContextSwitchMoves 返回错误: %v", err)
|
||||
}
|
||||
if len(moves) != 3 {
|
||||
t.Fatalf("期望移动 3 条,实际=%d", len(moves))
|
||||
}
|
||||
|
||||
// 1. “数学”有 2 条,分组后应先连续落在最早两个坑位;
|
||||
// 2. 因此 201 与 203 对应的目标节次应是 1-2 与 3-4(顺序由 origin_rank 决定)。
|
||||
sort.SliceStable(moves, func(i, j int) bool {
|
||||
if moves[i].ToWeek != moves[j].ToWeek {
|
||||
return moves[i].ToWeek < moves[j].ToWeek
|
||||
}
|
||||
if moves[i].ToDay != moves[j].ToDay {
|
||||
return moves[i].ToDay < moves[j].ToDay
|
||||
}
|
||||
return moves[i].ToSectionFrom < moves[j].ToSectionFrom
|
||||
})
|
||||
if moves[0].TaskItemID != 201 || moves[1].TaskItemID != 203 {
|
||||
t.Fatalf("期望前两个坑位由同上下文任务占据,实际=%+v", moves)
|
||||
}
|
||||
if moves[2].TaskItemID != 202 {
|
||||
t.Fatalf("期望最后一个坑位为算法任务,实际=%+v", moves[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlanEvenSpreadMovesReturnsErrorWhenSpanNotMatched(t *testing.T) {
|
||||
tasks := []RefineTaskCandidate{
|
||||
{TaskItemID: 301, Week: 16, DayOfWeek: 1, SectionFrom: 1, SectionTo: 3, OriginRank: 1}, // span=3
|
||||
}
|
||||
slots := []RefineSlotCandidate{
|
||||
{Week: 12, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2}, // span=2
|
||||
}
|
||||
_, err := PlanEvenSpreadMoves(tasks, slots, RefineCompositePlanOptions{})
|
||||
if err == nil {
|
||||
t.Fatalf("期望 span 不匹配时报错,实际 err=nil")
|
||||
}
|
||||
}
|
||||
@@ -3,57 +3,89 @@ package middleware
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/auth"
|
||||
"github.com/LoveLosita/smartflow/backend/dao"
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
"github.com/LoveLosita/smartflow/backend/respond"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
// JWTTokenAuth 接收 cache 实例,体现依赖注入
|
||||
// extractTokenFromAuthorization 负责解析 Authorization 头中的 token。
|
||||
// 职责边界:
|
||||
// 1. 兼容“裸 token”和“Bearer <token>”两种传参方式。
|
||||
// 2. 不负责 token 合法性校验,只做字符串提取。
|
||||
// 3. 输入输出语义:header 为空或格式非法时返回空字符串。
|
||||
func extractTokenFromAuthorization(header string) string {
|
||||
trimmed := strings.TrimSpace(header)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.Fields(trimmed)
|
||||
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
if len(parts) == 1 {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// JWTTokenAuth 负责 access token 的鉴权拦截。
|
||||
// 职责边界:
|
||||
// 1. 负责解析 token、验签、校验 token_type 与黑名单状态。
|
||||
// 2. 不负责签发 token,也不负责用户登录逻辑。
|
||||
// 3. 输出语义:校验通过时写入 user_id/claims 到上下文并放行;失败则中断请求。
|
||||
func JWTTokenAuth(cache *dao.CacheDAO) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 1. 获取 Token (Gin 的 GetHeader 直接返回 string)
|
||||
tokenString := c.GetHeader("Authorization")
|
||||
tokenString := extractTokenFromAuthorization(c.GetHeader("Authorization"))
|
||||
if tokenString == "" {
|
||||
c.JSON(http.StatusUnauthorized, respond.MissingToken)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 改动:使用 ParseWithClaims 直接解析到你的结构体
|
||||
// 假设你的结构体叫 model.MyCustomClaims
|
||||
token, err := jwt.ParseWithClaims(tokenString, &model.MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return auth.AccessKey, nil
|
||||
})
|
||||
accessKey, err := auth.AccessSigningKey()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, respond.InternalError(err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 1. 先验签并由 jwt 库统一校验 exp 等标准声明。
|
||||
token, err := jwt.ParseWithClaims(tokenString, &model.MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, respond.InvalidTokenSingingMethod
|
||||
}
|
||||
return accessKey, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
c.JSON(http.StatusUnauthorized, respond.InvalidToken)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 校验 Claims
|
||||
// 2. 再做业务声明校验,防止 refresh token 越权访问业务接口。
|
||||
claims, ok := token.Claims.(*model.MyCustomClaims)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, respond.InvalidClaims)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// --- 🛡️ 核心改造:设卡检查 ---
|
||||
if claims.TokenType != "access_token" {
|
||||
c.JSON(http.StatusUnauthorized, respond.WrongTokenType)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 拿着 jti 去 Redis 查一下
|
||||
// 3. 最后查黑名单,兜住“用户已登出但 token 仍未到期”的场景。
|
||||
isBlack, err := cache.IsBlacklisted(claims.Jti)
|
||||
if err != nil {
|
||||
// 如果 Redis 挂了,为了安全通常选择报错,或者降级放行(取决于你的业务)
|
||||
c.JSON(http.StatusInternalServerError, respond.InternalError(errors.New("无法验证令牌状态")))
|
||||
c.Abort()
|
||||
return
|
||||
@@ -64,9 +96,8 @@ func JWTTokenAuth(cache *dao.CacheDAO) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 存入上下文
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("claims", claims)
|
||||
c.Next() // 只有所有关卡都过了,才放行
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user