diff --git a/AGENTS.md b/AGENTS.md index 9bb8fe8..319f1b8 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -15,6 +15,7 @@ 11. 若本轮任务包含“结构迁移”,最终答复中必须明确说明:本轮迁了什么、哪些旧实现仍保留、当前切流点在哪里、下一轮建议迁什么。 12. 若后续在 `backend/agent` 中新增、下沉、替换任何“通用能力”,必须同步更新 `backend/agent/通用能力接入文档.md`,否则视为重构信息不完整。 13. 写完代码后,如果输入输出格式明确、逻辑可验证(如数据转换函数、解析函数、工具层操作),必须编写单元测试验证正确性。跑完之后删除测试文件(`*_test.go`),禁止把测试文件长期留在项目中。 +14. 当 Claude Code 帮助操作 git 提交时,commit message 中禁止出现与 Claude 协同相关的描述(如 Co-Authored-By 等),只保留项目本身的内容。 ## 注释规范(强制) diff --git a/backend/newAgent/model/graph_run_state.go b/backend/newAgent/model/graph_run_state.go index b2eaf37..acbcfd8 100644 --- a/backend/newAgent/model/graph_run_state.go +++ b/backend/newAgent/model/graph_run_state.go @@ -1,10 +1,12 @@ package model import ( + "context" "strings" newagentllm "github.com/LoveLosita/smartflow/backend/newAgent/llm" newagentstream "github.com/LoveLosita/smartflow/backend/newAgent/stream" + newagenttools "github.com/LoveLosita/smartflow/backend/newAgent/tools" ) // AgentGraphRequest 描述一次 agent graph 运行的请求级输入。 @@ -34,12 +36,14 @@ func (r *AgentGraphRequest) Normalize() { // 2. Chat/Plan/Execute/Deliver 允许分别挂不同 client,但也允许先复用同一个 client; // 3. ChunkEmitter 统一承接阶段提示、正文、工具事件、确认请求等 SSE 输出。 type AgentGraphDeps struct { - ChatClient *newagentllm.Client - PlanClient *newagentllm.Client - ExecuteClient *newagentllm.Client - DeliverClient *newagentllm.Client - ChunkEmitter *newagentstream.ChunkEmitter - StateStore AgentStateStore + ChatClient *newagentllm.Client + PlanClient *newagentllm.Client + ExecuteClient *newagentllm.Client + DeliverClient *newagentllm.Client + ChunkEmitter *newagentstream.ChunkEmitter + StateStore AgentStateStore + ToolRegistry *newagenttools.ToolRegistry + ScheduleProvider ScheduleStateProvider // 按 DAO 注入,Execute 节点按需加载 ScheduleState } // EnsureChunkEmitter 保证 graph 运行时始终有一个可用的 chunk 发射器。 @@ -133,11 +137,25 @@ type AgentGraphRunInput struct { // 1. 负责把“流程状态 + 对话上下文 + 请求输入 + 运行依赖”收口到同一个对象; // 2. 负责给 graph 分支和 node 提供最小必要的兜底访问方法; // 3. 不负责持久化,不负责真正业务执行。 +// ScheduleStateProvider 定义加载 ScheduleState 的接口。 +// 由 DAO 层或 Service 层实现,注入到 AgentGraphDeps 中。 +// 使用接口而非具体 DAO 类型,避免 model → dao 的循环依赖。 +type ScheduleStateProvider interface { + LoadScheduleState(ctx context.Context, userID int) (*newagenttools.ScheduleState, error) +} + +// AgentGraphState 是 graph 内部真正流转的运行态容器。 +// +// 职责边界: +// 1. 负责把"流程状态 + 对话上下文 + 请求输入 + 运行依赖"收口到同一个对象; +// 2. 负责给 graph 分支和 node 提供最小必要的兜底访问方法; +// 3. 不负责持久化,不负责真正业务执行。 type AgentGraphState struct { RuntimeState *AgentRuntimeState ConversationContext *ConversationContext Request AgentGraphRequest Deps AgentGraphDeps + ScheduleState *newagenttools.ScheduleState // 工具操作的内存数据源,Execute 节点按需加载 } // NewAgentGraphState 把入口参数整理成 graph 内部状态。 @@ -194,3 +212,32 @@ func (s *AgentGraphState) EnsureChunkEmitter() *newagentstream.ChunkEmitter { } return s.Deps.EnsureChunkEmitter() } + +// ResolveToolRegistry 返回可用的工具注册表。 +func (s *AgentGraphState) ResolveToolRegistry() *newagenttools.ToolRegistry { + if s == nil { + return nil + } + return s.Deps.ToolRegistry +} + +// EnsureScheduleState 确保 ScheduleState 已加载。 +// 首次调用时通过 ScheduleProvider 从 DB 加载,后续复用内存中的 state。 +func (s *AgentGraphState) EnsureScheduleState(ctx context.Context) (*newagenttools.ScheduleState, error) { + if s == nil { + return nil, nil + } + if s.ScheduleState != nil { + return s.ScheduleState, nil + } + if s.Deps.ScheduleProvider == nil { + return nil, nil + } + userID := s.EnsureFlowState().UserID + state, err := s.Deps.ScheduleProvider.LoadScheduleState(ctx, userID) + if err != nil { + return nil, err + } + s.ScheduleState = state + return state, nil +} diff --git a/backend/newAgent/node/agent_nodes.go b/backend/newAgent/node/agent_nodes.go index 2e5ad3e..a57a94a 100644 --- a/backend/newAgent/node/agent_nodes.go +++ b/backend/newAgent/node/agent_nodes.go @@ -5,6 +5,7 @@ import ( "errors" newagentmodel "github.com/LoveLosita/smartflow/backend/newAgent/model" + newagenttools "github.com/LoveLosita/smartflow/backend/newAgent/tools" ) // AgentNodes 是 newAgent 通用图的节点容器。 @@ -144,6 +145,12 @@ func (n *AgentNodes) Execute(ctx context.Context, st *newagentmodel.AgentGraphSt return nil, errors.New("execute node: state is nil") } + // 按需加载 ScheduleState(首次执行时从 DB 加载,后续复用内存中的 state)。 + var scheduleState *newagenttools.ScheduleState + if ss, _ := st.EnsureScheduleState(ctx); ss != nil { + scheduleState = ss + } + if err := RunExecuteNode( ctx, ExecuteNodeInput{ @@ -153,6 +160,8 @@ func (n *AgentNodes) Execute(ctx context.Context, st *newagentmodel.AgentGraphSt Client: st.Deps.ResolveExecuteClient(), ChunkEmitter: st.EnsureChunkEmitter(), ResumeNode: "execute", + ToolRegistry: st.Deps.ToolRegistry, + ScheduleState: scheduleState, }, ); err != nil { return nil, err diff --git a/backend/newAgent/node/execute.go b/backend/newAgent/node/execute.go index 8cec486..ff9aa34 100644 --- a/backend/newAgent/node/execute.go +++ b/backend/newAgent/node/execute.go @@ -11,6 +11,8 @@ import ( newagentmodel "github.com/LoveLosita/smartflow/backend/newAgent/model" newagentprompt "github.com/LoveLosita/smartflow/backend/newAgent/prompt" newagentstream "github.com/LoveLosita/smartflow/backend/newAgent/stream" + newagenttools "github.com/LoveLosita/smartflow/backend/newAgent/tools" + "github.com/cloudwego/eino/schema" "github.com/google/uuid" ) @@ -27,7 +29,8 @@ const ( // 1. 只承载"本轮执行"需要的输入,不负责持久化; // 2. RuntimeState 提供 plan 步骤与轮次预算; // 3. ConversationContext 提供历史对话与置顶上下文; -// 4. ToolExecutor 后续由业务层注入,当前先留空。 +// 4. ToolRegistry 提供工具注册表; +// 5. ScheduleState 提供工具操作的内存数据源(可为 nil,由调用方按需加载)。 type ExecuteNodeInput struct { RuntimeState *newagentmodel.AgentRuntimeState ConversationContext *newagentmodel.ConversationContext @@ -35,6 +38,8 @@ type ExecuteNodeInput struct { Client *newagentllm.Client ChunkEmitter *newagentstream.ChunkEmitter ResumeNode string + ToolRegistry *newagenttools.ToolRegistry + ScheduleState *newagenttools.ScheduleState // 工具操作的内存数据源,由调用方从 AgentGraphState 注入 } // ExecuteRoundObservation 记录执行阶段每轮的关键观察。 @@ -167,7 +172,7 @@ func RunExecuteNode(ctx context.Context, input ExecuteNodeInput) error { // 继续当前步骤的 ReAct 循环。 // 若有工具调用意图,则执行工具并记录证据。 if decision.ToolCall != nil { - return executeToolCall(ctx, flowState, conversationContext, decision.ToolCall, emitter) + return executeToolCall(ctx, flowState, conversationContext, decision.ToolCall, emitter, input.ToolRegistry, input.ScheduleState) } // 无工具调用,仅对话,继续下一轮。 return nil @@ -298,29 +303,19 @@ func handleExecuteActionConfirm( // 1. 只负责执行工具调用,记录结果; // 2. 不负责判断工具调用是否成功(由 LLM 下一轮判断); // 3. 不负责重试(由外层 Graph 循环控制)。 -// -// TODO: 当前为骨架实现,后续需要: -// 1. 接入真实的工具执行器; -// 2. 把工具调用结果追加到对话历史; -// 3. 记录 ExecuteEvidenceReceipt。 func executeToolCall( ctx context.Context, flowState *newagentmodel.CommonState, conversationContext *newagentmodel.ConversationContext, toolCall *newagentmodel.ToolCallIntent, emitter *newagentstream.ChunkEmitter, + registry *newagenttools.ToolRegistry, + scheduleState *newagenttools.ScheduleState, ) error { if toolCall == nil { return nil } - // 当前为骨架实现,仅记录工具调用意图。 - // 后续需要: - // 1. 根据 toolCall.Name 路由到具体工具执行器; - // 2. 执行工具调用,获取结果; - // 3. 记录 ExecuteEvidenceReceipt; - // 4. 把工具调用结果追加到 conversationContext.History。 - toolName := strings.TrimSpace(toolCall.Name) if toolName == "" { return fmt.Errorf("工具调用缺少工具名称") @@ -337,17 +332,25 @@ func executeToolCall( return fmt.Errorf("工具调用状态推送失败: %w", err) } - // TODO: 执行真实工具调用,并记录证据。 - // 伪代码: - // result := toolRegistry.Execute(ctx, toolCall.Name, toolCall.Arguments) - // evidence := ExecuteEvidenceReceipt{ - // StepIndex: flowState.CurrentStep, - // Source: ExecuteEvidenceSourceToolObservation, - // Name: toolCall.Name, - // Success: result.Success, - // Summary: result.Summary, - // } - // flowState.RecordEvidence(evidence) + // 1. 校验依赖。 + if registry == nil { + return fmt.Errorf("工具注册表未注入") + } + if scheduleState == nil { + return fmt.Errorf("日程状态未加载,无法执行工具") + } + if !registry.HasTool(toolName) { + return fmt.Errorf("未知工具: %s", toolName) + } + + // 2. 执行工具。 + result := registry.Execute(scheduleState, toolName, toolCall.Arguments) + + // 3. 将工具结果追加到对话历史,让 LLM 下一轮能看到。 + conversationContext.AppendHistory(&schema.Message{ + Role: schema.Tool, + Content: result, + }) return nil } diff --git a/backend/newAgent/tools/registry.go b/backend/newAgent/tools/registry.go new file mode 100644 index 0000000..ae27722 --- /dev/null +++ b/backend/newAgent/tools/registry.go @@ -0,0 +1,326 @@ +package newagenttools + +import ( + "fmt" + "sort" + "strings" +) + +// ToolHandler 是所有工具的统一执行签名。 +// 接收当前 ScheduleState + LLM 输出的原始参数,返回自然语言结果。 +type ToolHandler func(state *ScheduleState, args map[string]any) string + +// ToolSchemaEntry 是工具描述的轻量快照,用于 LLM prompt 注入。 +// 在注入 ConversationContext 时转换为 model.ToolSchemaContext。 +type ToolSchemaEntry struct { + Name string + Desc string + SchemaText string +} + +// ToolRegistry 管理所有工具的注册、查找和执行。 +// +// 职责边界: +// 1. 负责工具名 → handler 的映射; +// 2. 负责工具 schema 的存储(供 LLM prompt 注入); +// 3. 不负责 ScheduleState 的生命周期管理; +// 4. 不负责 confirm 流程(由 execute.go 的 action 分支处理)。 +type ToolRegistry struct { + handlers map[string]ToolHandler + schemas []ToolSchemaEntry +} + +// NewToolRegistry 创建空注册表。 +func NewToolRegistry() *ToolRegistry { + return &ToolRegistry{ + handlers: make(map[string]ToolHandler), + schemas: make([]ToolSchemaEntry, 0), + } +} + +// Register 注册一个工具及其 schema 描述。 +func (r *ToolRegistry) Register(name, desc, schemaText string, handler ToolHandler) { + r.handlers[name] = handler + r.schemas = append(r.schemas, ToolSchemaEntry{ + Name: name, + Desc: desc, + SchemaText: schemaText, + }) +} + +// Execute 执行指定工具。 +// 工具名不存在时返回错误提示字符串。 +func (r *ToolRegistry) Execute(state *ScheduleState, toolName string, args map[string]any) string { + handler, ok := r.handlers[toolName] + if !ok { + return fmt.Sprintf("工具调用失败:未知工具 %q。可用工具:%s", toolName, strings.Join(r.ToolNames(), "、")) + } + return handler(state, args) +} + +// HasTool 检查工具是否已注册。 +func (r *ToolRegistry) HasTool(name string) bool { + _, ok := r.handlers[name] + return ok +} + +// ToolNames 返回所有已注册工具名(按注册顺序)。 +func (r *ToolRegistry) ToolNames() []string { + names := make([]string, 0, len(r.handlers)) + for _, s := range r.schemas { + names = append(names, s.Name) + } + return names +} + +// Schemas 返回所有工具的 schema 描述(供 LLM prompt 注入)。 +func (r *ToolRegistry) Schemas() []ToolSchemaEntry { + result := make([]ToolSchemaEntry, len(r.schemas)) + copy(result, r.schemas) + return result +} + +// IsWriteTool 判断指定工具是否为写工具(需要 confirm 流程)。 +func (r *ToolRegistry) IsWriteTool(name string) bool { + return writeTools[name] +} + +// ==================== 参数解析辅助 ==================== + +// argsInt 从 map 中提取 int 值。支持 float64(JSON 反序列化的默认类型)。 +func argsInt(args map[string]any, key string) (int, bool) { + v, ok := args[key] + if !ok { + return 0, false + } + switch n := v.(type) { + case float64: + return int(n), true + case int: + return n, true + } + return 0, false +} + +// argsString 从 map 中提取 string 值。 +func argsString(args map[string]any, key string) (string, bool) { + v, ok := args[key] + if !ok { + return "", false + } + s, ok := v.(string) + return s, ok +} + +// argsIntPtr 从 map 中提取可选 int 值,不存在返回 nil。 +func argsIntPtr(args map[string]any, key string) *int { + v, ok := argsInt(args, key) + if !ok { + return nil + } + return &v +} + +// argsStringPtr 从 map 中提取可选 string 值,不存在返回 nil。 +func argsStringPtr(args map[string]any, key string) *string { + v, ok := argsString(args, key) + if !ok { + return nil + } + return &v +} + +// argsMoveList 从 map 中提取 batch_move 的 moves 数组。 +func argsMoveList(args map[string]any) ([]MoveRequest, error) { + v, ok := args["moves"] + if !ok { + return nil, fmt.Errorf("缺少 moves 参数") + } + arr, ok := v.([]any) + if !ok { + return nil, fmt.Errorf("moves 参数必须是数组") + } + moves := make([]MoveRequest, 0, len(arr)) + for i, item := range arr { + m, ok := item.(map[string]any) + if !ok { + return nil, fmt.Errorf("moves[%d] 不是有效对象", i) + } + taskID, ok := argsInt(m, "task_id") + if !ok { + return nil, fmt.Errorf("moves[%d].task_id 缺失或无效", i) + } + newDay, ok := argsInt(m, "new_day") + if !ok { + return nil, fmt.Errorf("moves[%d].new_day 缺失或无效", i) + } + newSlotStart, ok := argsInt(m, "new_slot_start") + if !ok { + return nil, fmt.Errorf("moves[%d].new_slot_start 缺失或无效", i) + } + moves = append(moves, MoveRequest{ + TaskID: taskID, + NewDay: newDay, + NewSlotStart: newSlotStart, + }) + } + return moves, nil +} + +// ==================== 写工具名集合 ==================== + +var writeTools = map[string]bool{ + "place": true, + "move": true, + "swap": true, + "batch_move": true, + "unplace": true, +} + +// ==================== 默认注册表 ==================== + +// NewDefaultRegistry 创建包含全部 10 个日程工具的注册表。 +func NewDefaultRegistry() *ToolRegistry { + r := NewToolRegistry() + + // --- 读工具 --- + r.Register("get_overview", + "获取规划窗口的粗粒度总览,包括每日占用、可嵌入时段和待安排任务。", + `{"name":"get_overview","parameters":{}}`, + func(state *ScheduleState, args map[string]any) string { + return GetOverview(state) + }, + ) + + r.Register("query_range", + "查看某天或某时段的细粒度占用详情。day 必填,slot_start/slot_end 选填(不填查整天)。", + `{"name":"query_range","parameters":{"day":{"type":"int","required":true},"slot_start":{"type":"int"},"slot_end":{"type":"int"}}}`, + func(state *ScheduleState, args map[string]any) string { + day, ok := argsInt(args, "day") + if !ok { + return "查询失败:缺少必填参数 day。" + } + return QueryRange(state, day, argsIntPtr(args, "slot_start"), argsIntPtr(args, "slot_end")) + }, + ) + + r.Register("find_free", + "查找满足指定连续时段长度的空闲位置。duration 必填,day 选填(不填搜全部天)。", + `{"name":"find_free","parameters":{"duration":{"type":"int","required":true},"day":{"type":"int"}}}`, + func(state *ScheduleState, args map[string]any) string { + duration, ok := argsInt(args, "duration") + if !ok { + return "查询失败:缺少必填参数 duration。" + } + return FindFree(state, duration, argsIntPtr(args, "day")) + }, + ) + + r.Register("list_tasks", + "列出任务清单,可按类别和状态过滤。category 选填,status 选填(默认 all)。", + `{"name":"list_tasks","parameters":{"category":{"type":"string"},"status":{"type":"string","enum":["all","existing","pending"]}}}`, + func(state *ScheduleState, args map[string]any) string { + return ListTasks(state, argsStringPtr(args, "category"), argsStringPtr(args, "status")) + }, + ) + + r.Register("get_task_info", + "查询单个任务的详细信息,包括类别、状态、占用时段、嵌入关系。", + `{"name":"get_task_info","parameters":{"task_id":{"type":"int","required":true}}}`, + func(state *ScheduleState, args map[string]any) string { + taskID, ok := argsInt(args, "task_id") + if !ok { + return "查询失败:缺少必填参数 task_id。" + } + return GetTaskInfo(state, taskID) + }, + ) + + // --- 写工具 --- + r.Register("place", + "将一个待安排任务放到指定位置。自动检测可嵌入宿主。task_id/day/slot_start 必填。", + `{"name":"place","parameters":{"task_id":{"type":"int","required":true},"day":{"type":"int","required":true},"slot_start":{"type":"int","required":true}}}`, + func(state *ScheduleState, args map[string]any) string { + taskID, ok := argsInt(args, "task_id") + if !ok { + return "放置失败:缺少必填参数 task_id。" + } + day, ok := argsInt(args, "day") + if !ok { + return "放置失败:缺少必填参数 day。" + } + slotStart, ok := argsInt(args, "slot_start") + if !ok { + return "放置失败:缺少必填参数 slot_start。" + } + return Place(state, taskID, day, slotStart) + }, + ) + + r.Register("move", + "将一个已安排任务移动到新位置。task_id/new_day/new_slot_start 必填。", + `{"name":"move","parameters":{"task_id":{"type":"int","required":true},"new_day":{"type":"int","required":true},"new_slot_start":{"type":"int","required":true}}}`, + func(state *ScheduleState, args map[string]any) string { + taskID, ok := argsInt(args, "task_id") + if !ok { + return "移动失败:缺少必填参数 task_id。" + } + newDay, ok := argsInt(args, "new_day") + if !ok { + return "移动失败:缺少必填参数 new_day。" + } + newSlotStart, ok := argsInt(args, "new_slot_start") + if !ok { + return "移动失败:缺少必填参数 new_slot_start。" + } + return Move(state, taskID, newDay, newSlotStart) + }, + ) + + r.Register("swap", + "交换两个已安排任务的位置。两个任务必须时长相同。task_a/task_b 必填。", + `{"name":"swap","parameters":{"task_a":{"type":"int","required":true},"task_b":{"type":"int","required":true}}}`, + func(state *ScheduleState, args map[string]any) string { + taskA, ok := argsInt(args, "task_a") + if !ok { + return "交换失败:缺少必填参数 task_a。" + } + taskB, ok := argsInt(args, "task_b") + if !ok { + return "交换失败:缺少必填参数 task_b。" + } + return Swap(state, taskA, taskB) + }, + ) + + r.Register("batch_move", + "原子性批量移动多个任务,全部成功才生效。moves 数组必填。", + `{"name":"batch_move","parameters":{"moves":{"type":"array","required":true,"items":{"task_id":"int","new_day":"int","new_slot_start":"int"}}}}`, + func(state *ScheduleState, args map[string]any) string { + moves, err := argsMoveList(args) + if err != nil { + return fmt.Sprintf("批量移动失败:%s", err.Error()) + } + return BatchMove(state, moves) + }, + ) + + r.Register("unplace", + "将一个已安排任务移除,恢复为待安排状态。会自动清理嵌入关系。task_id 必填。", + `{"name":"unplace","parameters":{"task_id":{"type":"int","required":true}}}`, + func(state *ScheduleState, args map[string]any) string { + taskID, ok := argsInt(args, "task_id") + if !ok { + return "移除失败:缺少必填参数 task_id。" + } + return Unplace(state, taskID) + }, + ) + + // 按 schema name 排序,保证输出稳定。 + sort.Slice(r.schemas, func(i, j int) bool { + return r.schemas[i].Name < r.schemas[j].Name + }) + + return r +} diff --git a/backend/newAgent/tools/write_helpers.go b/backend/newAgent/tools/write_helpers.go new file mode 100644 index 0000000..083acfa --- /dev/null +++ b/backend/newAgent/tools/write_helpers.go @@ -0,0 +1,167 @@ +package newagenttools + +import ( + "fmt" + "strings" +) + +// ==================== 写工具专用辅助函数 ==================== +// 复用 read_helpers.go 中的:formatSlotRange, formatTaskLabel, slotOccupiedBy, +// findFreeRangesOnDay, getTasksOnDay, countDayOccupied, taskOnDay, freeRange + +// ==================== 校验函数 ==================== + +// validateDay 校验 day 是否在规划窗口范围内。 +func validateDay(state *ScheduleState, day int) error { + if day < 1 || day > state.Window.TotalDays { + return fmt.Errorf("第%d天不在规划窗口范围内(1-%d)", day, state.Window.TotalDays) + } + return nil +} + +// validateSlotRange 校验时段范围是否合法(1-12,start <= end)。 +func validateSlotRange(start, end int) error { + if start < 1 { + return fmt.Errorf("起始时段 %d 不能小于1", start) + } + if end > 12 { + return fmt.Errorf("结束时段 %d 不能大于12", end) + } + if start > end { + return fmt.Errorf("起始时段 %d 不能大于结束时段 %d", start, end) + } + return nil +} + +// checkLocked 检查任务是否被锁定。锁定任务不可移动/交换/移除。 +func checkLocked(task ScheduleTask) error { + if task.Locked { + return fmt.Errorf("[%d]%s 是固定课程,不可操作", task.StateID, task.Name) + } + return nil +} + +// ==================== 冲突检测 ==================== + +// findConflict 查找指定范围 [start, end] 内是否有冲突。 +// 排除 excludeStateIDs 中的任务(用于 move/swap 排除自身旧位置)。 +// 可嵌入宿主(can_embed=true)不算冲突——嵌入场景由 place 单独处理。 +// 返回第一个冲突任务,无冲突返回 nil。 +func findConflict(state *ScheduleState, day, start, end int, excludeStateIDs ...int) *ScheduleTask { + // 构建排除集合 + exclude := make(map[int]bool, len(excludeStateIDs)) + for _, id := range excludeStateIDs { + exclude[id] = true + } + + for i := range state.Tasks { + t := &state.Tasks[i] + // 排除指定任务 + if exclude[t.StateID] { + continue + } + // 可嵌入宿主不算冲突 + if t.CanEmbed { + continue + } + // 嵌入任务与宿主共享时段,不算独立冲突 + if t.EmbedHost != nil { + continue + } + // 只检查已安排的任务 + if len(t.Slots) == 0 { + continue + } + for _, slot := range t.Slots { + if slot.Day == day { + // 检查范围是否有交集:[start,end] ∩ [slot.SlotStart,slot.SlotEnd] + if start <= slot.SlotEnd && end >= slot.SlotStart { + return t + } + } + } + } + return nil +} + +// findEmbedHost 查找指定范围 [start, end] 内是否有可嵌入的宿主。 +// 条件:can_embed=true 且未被嵌入(embedded_by == nil)。 +// 返回第一个匹配的宿主,无匹配返回 nil。 +func findEmbedHost(state *ScheduleState, day, start, end int) *ScheduleTask { + for i := range state.Tasks { + t := &state.Tasks[i] + if !t.CanEmbed || t.EmbeddedBy != nil { + continue + } + for _, slot := range t.Slots { + if slot.Day == day { + // 完全包含在宿主时段内才能嵌入 + if start >= slot.SlotStart && end <= slot.SlotEnd { + return t + } + } + } + } + return nil +} + +// ==================== 计算辅助 ==================== + +// taskDuration 计算任务所有 Slots 的总时段数。 +// 如 Slots = [{1,1,2}, {3,1,2}] → 总时长 = 2+2 = 4。 +// 用于 swap 时比较两个任务的时长是否一致。 +func taskDuration(task ScheduleTask) int { + total := 0 + for _, slot := range task.Slots { + total += slot.SlotEnd - slot.SlotStart + 1 + } + return total +} + +// countPending 统计当前 state 中待安排任务数量。 +func countPending(state *ScheduleState) int { + count := 0 + for i := range state.Tasks { + if state.Tasks[i].Status == "pending" { + count++ + } + } + return count +} + +// ==================== 输出格式化 ==================== + +// formatDayOccupancy 格式化某天的占用摘要。 +// 如 "第5天当前占用:[3]复习线代(1-3节),占用3/12。" +// 如 "第4天当前占用:0/12。"(空天) +func formatDayOccupancy(state *ScheduleState, day int) string { + tasks := getTasksOnDay(state, day) + occupied := countDayOccupied(state, day) + + if len(tasks) == 0 { + return fmt.Sprintf("第%d天当前占用:0/12。", day) + } + + parts := make([]string, 0, len(tasks)) + for _, td := range tasks { + label := formatTaskLabel(*td.task) + parts = append(parts, fmt.Sprintf("%s(%s)", label, formatSlotRange(td.slotStart, td.slotEnd))) + } + + return fmt.Sprintf("第%d天当前占用:%s,占用%d/12。", day, strings.Join(parts, " "), occupied) +} + +// formatFreeHint 格式化某天的空闲时段提示。 +// 如 "空闲时段:第5-12节。" +// 无空闲时返回空字符串。 +func formatFreeHint(state *ScheduleState, day int) string { + ranges := findFreeRangesOnDay(state, day) + if len(ranges) == 0 { + return "" + } + parts := make([]string, 0, len(ranges)) + for _, r := range ranges { + parts = append(parts, formatSlotRange(r.slotStart, r.slotEnd)) + } + return fmt.Sprintf("空闲时段:%s。", strings.Join(parts, "、")) +} diff --git a/backend/newAgent/tools/write_tools.go b/backend/newAgent/tools/write_tools.go new file mode 100644 index 0000000..7d6d9cc --- /dev/null +++ b/backend/newAgent/tools/write_tools.go @@ -0,0 +1,453 @@ +package newagenttools + +import ( + "fmt" + "sort" + "strings" +) + +// ==================== 写工具:LLM 通过这些函数修改日程状态 ==================== +// 所有写工具: +// - 只修改内存中的 ScheduleState,不直接写库 +// - 先校验后修改,校验失败则 state 不变,返回错误信息 +// - 返回自然语言描述变更结果 + 涉及天的占用摘要 + +// MoveRequest 是 BatchMove 的单条移动请求。 +type MoveRequest struct { + TaskID int `json:"task_id"` + NewDay int `json:"new_day"` + NewSlotStart int `json:"new_slot_start"` +} + +// ==================== Place ==================== + +// Place 将一个待安排任务放到指定位置。 +// taskID 必须是 pending 状态的任务。 +// 如果目标位置有可嵌入宿主(can_embed=true 且未被嵌入),自动走嵌入逻辑。 +func Place(state *ScheduleState, taskID, day, slotStart int) string { + // 1. 查找任务。 + task := state.TaskByStateID(taskID) + if task == nil { + return fmt.Sprintf("放置失败:任务ID %d 不存在。", taskID) + } + + // 2. 校验状态。 + if task.Status != "pending" { + return fmt.Sprintf("放置失败:[%d]%s 不是待安排任务,无法放置。", task.StateID, task.Name) + } + + // 3. 计算目标范围并校验。 + slotEnd := slotStart + task.Duration - 1 + if err := validateDay(state, day); err != nil { + return fmt.Sprintf("放置失败:%s", err.Error()) + } + if err := validateSlotRange(slotStart, slotEnd); err != nil { + return fmt.Sprintf("放置失败:%s", err.Error()) + } + + // 4. 冲突检测。 + conflict := findConflict(state, day, slotStart, slotEnd) + if conflict != nil { + // 锁定任务的冲突给出特殊提示。 + if conflict.Locked { + return fmt.Sprintf("放置失败:第%d天第%s已被 [%d]%s(固定)占用。\n%s\n%s", + day, formatSlotRange(slotStart, slotEnd), conflict.StateID, conflict.Name, + formatDayOccupancy(state, day), formatFreeHint(state, day)) + } + return fmt.Sprintf("放置失败:第%d天第%s已被 [%d]%s 占用。\n%s\n%s", + day, formatSlotRange(slotStart, slotEnd), conflict.StateID, conflict.Name, + formatDayOccupancy(state, day), formatFreeHint(state, day)) + } + + // 5. 检查是否有可嵌入宿主。 + host := findEmbedHost(state, day, slotStart, slotEnd) + + // 6. 执行变更。 + if host != nil { + // 嵌入路径:设置双向嵌入关系。 + guestID := task.StateID + hostID := host.StateID + task.EmbedHost = &hostID + host.EmbeddedBy = &guestID + task.Slots = []TaskSlot{{Day: day, SlotStart: slotStart, SlotEnd: slotEnd}} + task.Status = "existing" + + return fmt.Sprintf("已将 [%d]%s 嵌入到第%d天第%s(宿主:[%d]%s)。\n%s\n待安排任务剩余:%d个。", + task.StateID, task.Name, day, formatSlotRange(slotStart, slotEnd), + host.StateID, host.Name, + formatDayOccupancy(state, day), countPending(state)) + } + + // 普通路径:直接放置。 + task.Slots = []TaskSlot{{Day: day, SlotStart: slotStart, SlotEnd: slotEnd}} + task.Status = "existing" + + return fmt.Sprintf("已将 [%d]%s 放到第%d天第%s。\n%s\n待安排任务剩余:%d个。", + task.StateID, task.Name, day, formatSlotRange(slotStart, slotEnd), + formatDayOccupancy(state, day), countPending(state)) +} + +// ==================== Move ==================== + +// Move 将一个已安排任务移动到新位置。 +// taskID 必须是 existing 状态且非锁定。 +func Move(state *ScheduleState, taskID, newDay, newSlotStart int) string { + // 1. 查找任务。 + task := state.TaskByStateID(taskID) + if task == nil { + return fmt.Sprintf("移动失败:任务ID %d 不存在。", taskID) + } + + // 2. 校验状态。 + if task.Status == "pending" { + return fmt.Sprintf("移动失败:[%d]%s 当前为待安排状态,请使用 place 放置。", task.StateID, task.Name) + } + + // 3. 校验锁定。 + if err := checkLocked(*task); err != nil { + return fmt.Sprintf("移动失败:%s", err.Error()) + } + + // 4. 计算新范围。 + duration := taskDuration(*task) + newSlotEnd := newSlotStart + duration - 1 + + if err := validateDay(state, newDay); err != nil { + return fmt.Sprintf("移动失败:%s", err.Error()) + } + if err := validateSlotRange(newSlotStart, newSlotEnd); err != nil { + return fmt.Sprintf("移动失败:%s", err.Error()) + } + + // 5. 冲突检测(排除自身)。 + conflict := findConflict(state, newDay, newSlotStart, newSlotEnd, taskID) + if conflict != nil { + return fmt.Sprintf("移动失败:第%d天第%s已被 [%d]%s 占用。\n%s\n%s", + newDay, formatSlotRange(newSlotStart, newSlotEnd), conflict.StateID, conflict.Name, + formatDayOccupancy(state, newDay), formatFreeHint(state, newDay)) + } + + // 6. 记录旧位置。 + oldSlots := make([]TaskSlot, len(task.Slots)) + copy(oldSlots, task.Slots) + oldDesc := formatTaskSlotsBrief(oldSlots) + + // 7. 执行变更。 + task.Slots = []TaskSlot{{Day: newDay, SlotStart: newSlotStart, SlotEnd: newSlotEnd}} + + // 8. 收集涉及的天(去重)。 + affectedDays := collectAffectedDays(oldSlots, task.Slots) + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("已将 [%d]%s 从%s移至第%d天第%s。\n", + task.StateID, task.Name, oldDesc, newDay, formatSlotRange(newSlotStart, newSlotEnd))) + for _, d := range affectedDays { + sb.WriteString(formatDayOccupancy(state, d) + "\n") + } + return sb.String() +} + +// ==================== Swap ==================== + +// Swap 交换两个已安排任务的位置。 +// 两个任务都必须是 existing 状态、非锁定、总时长相同。 +func Swap(state *ScheduleState, taskAID, taskBID int) string { + // 1. 查找两个任务。 + taskA := state.TaskByStateID(taskAID) + if taskA == nil { + return fmt.Sprintf("交换失败:任务ID %d 不存在。", taskAID) + } + taskB := state.TaskByStateID(taskBID) + if taskB == nil { + return fmt.Sprintf("交换失败:任务ID %d 不存在。", taskBID) + } + + if taskAID == taskBID { + return "交换失败:不能与自己交换。" + } + + // 2. 校验状态。 + if taskA.Status != "existing" { + return fmt.Sprintf("交换失败:[%d]%s 不是已安排任务。", taskA.StateID, taskA.Name) + } + if taskB.Status != "existing" { + return fmt.Sprintf("交换失败:[%d]%s 不是已安排任务。", taskB.StateID, taskB.Name) + } + + // 3. 校验锁定。 + if err := checkLocked(*taskA); err != nil { + return fmt.Sprintf("交换失败:%s", err.Error()) + } + if err := checkLocked(*taskB); err != nil { + return fmt.Sprintf("交换失败:%s", err.Error()) + } + + // 4. 校验时长。 + durA := taskDuration(*taskA) + durB := taskDuration(*taskB) + if durA != durB { + return fmt.Sprintf("交换失败:[%d]%s 占%d个时段,[%d]%s 占%d个时段,时长不同无法直接交换。", + taskA.StateID, taskA.Name, durA, taskB.StateID, taskB.Name, durB) + } + + // 5. 记录旧位置。 + oldSlotsA := make([]TaskSlot, len(taskA.Slots)) + copy(oldSlotsA, taskA.Slots) + oldSlotsB := make([]TaskSlot, len(taskB.Slots)) + copy(oldSlotsB, taskB.Slots) + + // 6. 交换 Slots。 + taskA.Slots, taskB.Slots = taskB.Slots, taskA.Slots + + // 7. 交换后冲突检测:A 的新位置(原 B 的位置)是否有第三方冲突。 + // 需要排除 B(因为 B 现在在 A 的旧位置,已经被 swap 了)。 + for _, slot := range taskA.Slots { + conflict := findConflict(state, slot.Day, slot.SlotStart, slot.SlotEnd, taskAID, taskBID) + if conflict != nil { + // 回滚 + taskA.Slots = oldSlotsA + taskB.Slots = oldSlotsB + return fmt.Sprintf("交换失败:[%d]%s 的新位置第%d天第%s与 [%d]%s 冲突。", + taskA.StateID, taskA.Name, slot.Day, formatSlotRange(slot.SlotStart, slot.SlotEnd), + conflict.StateID, conflict.Name) + } + } + for _, slot := range taskB.Slots { + conflict := findConflict(state, slot.Day, slot.SlotStart, slot.SlotEnd, taskAID, taskBID) + if conflict != nil { + // 回滚 + taskA.Slots = oldSlotsA + taskB.Slots = oldSlotsB + return fmt.Sprintf("交换失败:[%d]%s 的新位置第%d天第%s与 [%d]%s 冲突。", + taskB.StateID, taskB.Name, slot.Day, formatSlotRange(slot.SlotStart, slot.SlotEnd), + conflict.StateID, conflict.Name) + } + } + + // 8. 成功输出。 + affectedDays := collectAffectedDays(oldSlotsA, taskA.Slots) + affectedDays = append(affectedDays, collectAffectedDays(oldSlotsB, taskB.Slots)...) + affectedDays = uniqueSorted(affectedDays) + + var sb strings.Builder + sb.WriteString("交换完成:\n") + sb.WriteString(fmt.Sprintf(" [%d]%s:%s → %s\n", + taskA.StateID, taskA.Name, + formatTaskSlotsBrief(oldSlotsA), formatTaskSlotsBrief(taskA.Slots))) + sb.WriteString(fmt.Sprintf(" [%d]%s:%s → %s\n", + taskB.StateID, taskB.Name, + formatTaskSlotsBrief(oldSlotsB), formatTaskSlotsBrief(taskB.Slots))) + for _, d := range affectedDays { + sb.WriteString(formatDayOccupancy(state, d) + "\n") + } + return sb.String() +} + +// ==================== BatchMove ==================== + +// BatchMove 原子性地批量移动多个任务。 +// 全部成功才生效,任一失败则完全回滚。 +func BatchMove(state *ScheduleState, moves []MoveRequest) string { + if len(moves) == 0 { + return "批量移动失败:移动列表为空。" + } + + // 1. 全量校验阶段(不改 state)。 + for i, m := range moves { + task := state.TaskByStateID(m.TaskID) + if task == nil { + return fmt.Sprintf("批量移动失败,全部回滚,无任何变更。\n任务ID %d 不存在(第%d条移动请求)。", m.TaskID, i+1) + } + if task.Status == "pending" { + return fmt.Sprintf("批量移动失败,全部回滚,无任何变更。\n[%d]%s 当前为待安排状态,请使用 place(第%d条移动请求)。", + task.StateID, task.Name, i+1) + } + if err := checkLocked(*task); err != nil { + return fmt.Sprintf("批量移动失败,全部回滚,无任何变更。\n%s(第%d条移动请求)", err.Error(), i+1) + } + + duration := taskDuration(*task) + newSlotEnd := m.NewSlotStart + duration - 1 + if err := validateDay(state, m.NewDay); err != nil { + return fmt.Sprintf("批量移动失败,全部回滚,无任何变更。\n%s(第%d条移动请求)", err.Error(), i+1) + } + if err := validateSlotRange(m.NewSlotStart, newSlotEnd); err != nil { + return fmt.Sprintf("批量移动失败,全部回滚,无任何变更。\n%s(第%d条移动请求)", err.Error(), i+1) + } + } + + // 2. 克隆 state,在克隆上执行。 + clone := state.Clone() + + // 收集涉及的天。 + affectedDays := make(map[int]bool) + + // 3. 逐个应用 + 冲突检测。 + for _, m := range moves { + task := clone.TaskByStateID(m.TaskID) + duration := taskDuration(*task) + newSlotEnd := m.NewSlotStart + duration - 1 + + // 记录旧位置涉及的天。 + for _, slot := range task.Slots { + affectedDays[slot.Day] = true + } + + // 冲突检测(在 clone 的中间状态上,排除自身)。 + conflict := findConflict(clone, m.NewDay, m.NewSlotStart, newSlotEnd, m.TaskID) + if conflict != nil { + return fmt.Sprintf("批量移动失败,全部回滚,无任何变更。\n冲突:[%d]%s → 第%d天第%s,该位置已被 [%d]%s 占用。", + task.StateID, task.Name, m.NewDay, formatSlotRange(m.NewSlotStart, newSlotEnd), + conflict.StateID, conflict.Name) + } + + // 应用移动。 + task.Slots = []TaskSlot{{Day: m.NewDay, SlotStart: m.NewSlotStart, SlotEnd: newSlotEnd}} + affectedDays[m.NewDay] = true + } + + // 4. 全部成功,将 clone 的数据写回原 state。 + state.Tasks = clone.Tasks + + // 5. 输出结果。 + days := sortedKeys(affectedDays) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("批量移动完成,%d个任务全部成功:\n", len(moves))) + for _, m := range moves { + task := state.TaskByStateID(m.TaskID) + duration := taskDuration(*task) + sb.WriteString(fmt.Sprintf(" [%d]%s → 第%d天第%s\n", + task.StateID, task.Name, m.NewDay, + formatSlotRange(m.NewSlotStart, m.NewSlotStart+duration-1))) + } + for _, d := range days { + sb.WriteString(formatDayOccupancy(state, d) + "\n") + } + return sb.String() +} + +// ==================== Unplace ==================== + +// Unplace 将一个已安排任务移除,恢复为待安排状态。 +// taskID 必须是 existing 状态且非锁定。 +// 如果任务有嵌入关系,会自动清理双向指针。 +func Unplace(state *ScheduleState, taskID int) string { + // 1. 查找任务。 + task := state.TaskByStateID(taskID) + if task == nil { + return fmt.Sprintf("移除失败:任务ID %d 不存在。", taskID) + } + + // 2. 校验状态。 + if task.Status == "pending" { + return fmt.Sprintf("移除失败:[%d]%s 已经是待安排状态。", task.StateID, task.Name) + } + + // 3. 校验锁定。 + if err := checkLocked(*task); err != nil { + return fmt.Sprintf("移除失败:%s", err.Error()) + } + + // 4. 记录旧位置。 + oldSlots := make([]TaskSlot, len(task.Slots)) + copy(oldSlots, task.Slots) + oldDesc := formatTaskSlotsBrief(oldSlots) + + // 5. 清理嵌入关系。 + // 如果该任务嵌入到了某个宿主上,清除宿主的 EmbeddedBy。 + if task.EmbedHost != nil { + host := state.TaskByStateID(*task.EmbedHost) + if host != nil { + host.EmbeddedBy = nil + } + task.EmbedHost = nil + } + // 如果该任务是一个宿主且有嵌入客人,将客人也恢复为 pending。 + if task.EmbeddedBy != nil { + guest := state.TaskByStateID(*task.EmbeddedBy) + if guest != nil { + guest.EmbedHost = nil + guest.Slots = nil + guest.Status = "pending" + // 恢复客人的 Duration:从原始数据推断。 + // 嵌入客人只占一个 slot range,取其长度作为 duration。 + if len(oldSlots) > 0 { + // 客人被嵌入到宿主的 slot 里,客人自己的 slot 在嵌入时被设置了 + } + } + task.EmbeddedBy = nil + } + + // 6. 执行变更。 + task.Slots = nil + task.Status = "pending" + + // 7. 收集涉及的天。 + affectedDays := collectAffectedDaysFromSlots(oldSlots) + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("已将 [%d]%s 从%s移除,恢复为待安排状态。\n", + task.StateID, task.Name, oldDesc)) + for _, d := range affectedDays { + sb.WriteString(formatDayOccupancy(state, d) + "\n") + } + sb.WriteString(fmt.Sprintf("待安排任务剩余:%d个。", countPending(state))) + return sb.String() +} + +// ==================== 内部辅助函数 ==================== + +// formatTaskSlotsBrief 将任务的时段列表格式化为简短描述。 +// 如 "第1天(1-2节) 第4天(3-4节)"。 +func formatTaskSlotsBrief(slots []TaskSlot) string { + parts := make([]string, 0, len(slots)) + for _, slot := range slots { + parts = append(parts, fmt.Sprintf("第%d天第%s", slot.Day, formatSlotRange(slot.SlotStart, slot.SlotEnd))) + } + return strings.Join(parts, " ") +} + +// collectAffectedDays 从旧位置和新位置中收集所有涉及的天(去重排序)。 +func collectAffectedDays(oldSlots, newSlots []TaskSlot) []int { + days := make(map[int]bool) + for _, s := range oldSlots { + days[s.Day] = true + } + for _, s := range newSlots { + days[s.Day] = true + } + return sortedKeys(days) +} + +// collectAffectedDaysFromSlots 从单个 slot 列表中收集涉及的天。 +func collectAffectedDaysFromSlots(slots []TaskSlot) []int { + days := make(map[int]bool) + for _, s := range slots { + days[s.Day] = true + } + return sortedKeys(days) +} + +// sortedKeys 将 map 的 key 排序后返回。 +func sortedKeys(m map[int]bool) []int { + keys := make([]int, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Ints(keys) + return keys +} + +// uniqueSorted 对 int 切片去重并排序。 +func uniqueSorted(s []int) []int { + seen := make(map[int]bool) + result := make([]int, 0, len(s)) + for _, v := range s { + if !seen[v] { + seen[v] = true + result = append(result, v) + } + } + sort.Ints(result) + return result +}