Version: 0.9.0.dev.260405
后端: 1.新建tools/write_helpers.go:写工具专用辅助函数(冲突检测、范围校验、嵌入宿主查找、锁定检查、格式化) 2.新建tools/write_tools.go:实现5个写工具(Place/Move/Swap/BatchMove/Unplace),含嵌入逻辑、原子性批量操作、双向嵌入关系清理,26个单元测试全部通过 3.新建tools/registry.go:工具注册表(ToolRegistry),统一管理10个工具的注册/查找/执行,支持读写工具区分和参数解析 4.更新model/graph_run_state.go: 新增 ScheduleStateProvider 接口和 ToolRegistry 依赖注入,AgentGraphState 支持按需加载ScheduleState 5.更新 node/execute.go:接入 ToolRegistry 实现真实工具调用,替换原骨架实现 6.更新 AGENTS.md 前端:无 仓库:无
This commit is contained in:
326
backend/newAgent/tools/registry.go
Normal file
326
backend/newAgent/tools/registry.go
Normal file
@@ -0,0 +1,326 @@
|
||||
package newagenttools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ToolHandler 是所有工具的统一执行签名。
|
||||
// 接收当前 ScheduleState + LLM 输出的原始参数,返回自然语言结果。
|
||||
type ToolHandler func(state *ScheduleState, args map[string]any) string
|
||||
|
||||
// ToolSchemaEntry 是工具描述的轻量快照,用于 LLM prompt 注入。
|
||||
// 在注入 ConversationContext 时转换为 model.ToolSchemaContext。
|
||||
type ToolSchemaEntry struct {
|
||||
Name string
|
||||
Desc string
|
||||
SchemaText string
|
||||
}
|
||||
|
||||
// ToolRegistry 管理所有工具的注册、查找和执行。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责工具名 → handler 的映射;
|
||||
// 2. 负责工具 schema 的存储(供 LLM prompt 注入);
|
||||
// 3. 不负责 ScheduleState 的生命周期管理;
|
||||
// 4. 不负责 confirm 流程(由 execute.go 的 action 分支处理)。
|
||||
type ToolRegistry struct {
|
||||
handlers map[string]ToolHandler
|
||||
schemas []ToolSchemaEntry
|
||||
}
|
||||
|
||||
// NewToolRegistry 创建空注册表。
|
||||
func NewToolRegistry() *ToolRegistry {
|
||||
return &ToolRegistry{
|
||||
handlers: make(map[string]ToolHandler),
|
||||
schemas: make([]ToolSchemaEntry, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册一个工具及其 schema 描述。
|
||||
func (r *ToolRegistry) Register(name, desc, schemaText string, handler ToolHandler) {
|
||||
r.handlers[name] = handler
|
||||
r.schemas = append(r.schemas, ToolSchemaEntry{
|
||||
Name: name,
|
||||
Desc: desc,
|
||||
SchemaText: schemaText,
|
||||
})
|
||||
}
|
||||
|
||||
// Execute 执行指定工具。
|
||||
// 工具名不存在时返回错误提示字符串。
|
||||
func (r *ToolRegistry) Execute(state *ScheduleState, toolName string, args map[string]any) string {
|
||||
handler, ok := r.handlers[toolName]
|
||||
if !ok {
|
||||
return fmt.Sprintf("工具调用失败:未知工具 %q。可用工具:%s", toolName, strings.Join(r.ToolNames(), "、"))
|
||||
}
|
||||
return handler(state, args)
|
||||
}
|
||||
|
||||
// HasTool 检查工具是否已注册。
|
||||
func (r *ToolRegistry) HasTool(name string) bool {
|
||||
_, ok := r.handlers[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ToolNames 返回所有已注册工具名(按注册顺序)。
|
||||
func (r *ToolRegistry) ToolNames() []string {
|
||||
names := make([]string, 0, len(r.handlers))
|
||||
for _, s := range r.schemas {
|
||||
names = append(names, s.Name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// Schemas 返回所有工具的 schema 描述(供 LLM prompt 注入)。
|
||||
func (r *ToolRegistry) Schemas() []ToolSchemaEntry {
|
||||
result := make([]ToolSchemaEntry, len(r.schemas))
|
||||
copy(result, r.schemas)
|
||||
return result
|
||||
}
|
||||
|
||||
// IsWriteTool 判断指定工具是否为写工具(需要 confirm 流程)。
|
||||
func (r *ToolRegistry) IsWriteTool(name string) bool {
|
||||
return writeTools[name]
|
||||
}
|
||||
|
||||
// ==================== 参数解析辅助 ====================
|
||||
|
||||
// argsInt 从 map 中提取 int 值。支持 float64(JSON 反序列化的默认类型)。
|
||||
func argsInt(args map[string]any, key string) (int, bool) {
|
||||
v, ok := args[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return int(n), true
|
||||
case int:
|
||||
return n, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// argsString 从 map 中提取 string 值。
|
||||
func argsString(args map[string]any, key string) (string, bool) {
|
||||
v, ok := args[key]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
s, ok := v.(string)
|
||||
return s, ok
|
||||
}
|
||||
|
||||
// argsIntPtr 从 map 中提取可选 int 值,不存在返回 nil。
|
||||
func argsIntPtr(args map[string]any, key string) *int {
|
||||
v, ok := argsInt(args, key)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return &v
|
||||
}
|
||||
|
||||
// argsStringPtr 从 map 中提取可选 string 值,不存在返回 nil。
|
||||
func argsStringPtr(args map[string]any, key string) *string {
|
||||
v, ok := argsString(args, key)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return &v
|
||||
}
|
||||
|
||||
// argsMoveList 从 map 中提取 batch_move 的 moves 数组。
|
||||
func argsMoveList(args map[string]any) ([]MoveRequest, error) {
|
||||
v, ok := args["moves"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("缺少 moves 参数")
|
||||
}
|
||||
arr, ok := v.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("moves 参数必须是数组")
|
||||
}
|
||||
moves := make([]MoveRequest, 0, len(arr))
|
||||
for i, item := range arr {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("moves[%d] 不是有效对象", i)
|
||||
}
|
||||
taskID, ok := argsInt(m, "task_id")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("moves[%d].task_id 缺失或无效", i)
|
||||
}
|
||||
newDay, ok := argsInt(m, "new_day")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("moves[%d].new_day 缺失或无效", i)
|
||||
}
|
||||
newSlotStart, ok := argsInt(m, "new_slot_start")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("moves[%d].new_slot_start 缺失或无效", i)
|
||||
}
|
||||
moves = append(moves, MoveRequest{
|
||||
TaskID: taskID,
|
||||
NewDay: newDay,
|
||||
NewSlotStart: newSlotStart,
|
||||
})
|
||||
}
|
||||
return moves, nil
|
||||
}
|
||||
|
||||
// ==================== 写工具名集合 ====================
|
||||
|
||||
var writeTools = map[string]bool{
|
||||
"place": true,
|
||||
"move": true,
|
||||
"swap": true,
|
||||
"batch_move": true,
|
||||
"unplace": true,
|
||||
}
|
||||
|
||||
// ==================== 默认注册表 ====================
|
||||
|
||||
// NewDefaultRegistry 创建包含全部 10 个日程工具的注册表。
|
||||
func NewDefaultRegistry() *ToolRegistry {
|
||||
r := NewToolRegistry()
|
||||
|
||||
// --- 读工具 ---
|
||||
r.Register("get_overview",
|
||||
"获取规划窗口的粗粒度总览,包括每日占用、可嵌入时段和待安排任务。",
|
||||
`{"name":"get_overview","parameters":{}}`,
|
||||
func(state *ScheduleState, args map[string]any) string {
|
||||
return GetOverview(state)
|
||||
},
|
||||
)
|
||||
|
||||
r.Register("query_range",
|
||||
"查看某天或某时段的细粒度占用详情。day 必填,slot_start/slot_end 选填(不填查整天)。",
|
||||
`{"name":"query_range","parameters":{"day":{"type":"int","required":true},"slot_start":{"type":"int"},"slot_end":{"type":"int"}}}`,
|
||||
func(state *ScheduleState, args map[string]any) string {
|
||||
day, ok := argsInt(args, "day")
|
||||
if !ok {
|
||||
return "查询失败:缺少必填参数 day。"
|
||||
}
|
||||
return QueryRange(state, day, argsIntPtr(args, "slot_start"), argsIntPtr(args, "slot_end"))
|
||||
},
|
||||
)
|
||||
|
||||
r.Register("find_free",
|
||||
"查找满足指定连续时段长度的空闲位置。duration 必填,day 选填(不填搜全部天)。",
|
||||
`{"name":"find_free","parameters":{"duration":{"type":"int","required":true},"day":{"type":"int"}}}`,
|
||||
func(state *ScheduleState, args map[string]any) string {
|
||||
duration, ok := argsInt(args, "duration")
|
||||
if !ok {
|
||||
return "查询失败:缺少必填参数 duration。"
|
||||
}
|
||||
return FindFree(state, duration, argsIntPtr(args, "day"))
|
||||
},
|
||||
)
|
||||
|
||||
r.Register("list_tasks",
|
||||
"列出任务清单,可按类别和状态过滤。category 选填,status 选填(默认 all)。",
|
||||
`{"name":"list_tasks","parameters":{"category":{"type":"string"},"status":{"type":"string","enum":["all","existing","pending"]}}}`,
|
||||
func(state *ScheduleState, args map[string]any) string {
|
||||
return ListTasks(state, argsStringPtr(args, "category"), argsStringPtr(args, "status"))
|
||||
},
|
||||
)
|
||||
|
||||
r.Register("get_task_info",
|
||||
"查询单个任务的详细信息,包括类别、状态、占用时段、嵌入关系。",
|
||||
`{"name":"get_task_info","parameters":{"task_id":{"type":"int","required":true}}}`,
|
||||
func(state *ScheduleState, args map[string]any) string {
|
||||
taskID, ok := argsInt(args, "task_id")
|
||||
if !ok {
|
||||
return "查询失败:缺少必填参数 task_id。"
|
||||
}
|
||||
return GetTaskInfo(state, taskID)
|
||||
},
|
||||
)
|
||||
|
||||
// --- 写工具 ---
|
||||
r.Register("place",
|
||||
"将一个待安排任务放到指定位置。自动检测可嵌入宿主。task_id/day/slot_start 必填。",
|
||||
`{"name":"place","parameters":{"task_id":{"type":"int","required":true},"day":{"type":"int","required":true},"slot_start":{"type":"int","required":true}}}`,
|
||||
func(state *ScheduleState, args map[string]any) string {
|
||||
taskID, ok := argsInt(args, "task_id")
|
||||
if !ok {
|
||||
return "放置失败:缺少必填参数 task_id。"
|
||||
}
|
||||
day, ok := argsInt(args, "day")
|
||||
if !ok {
|
||||
return "放置失败:缺少必填参数 day。"
|
||||
}
|
||||
slotStart, ok := argsInt(args, "slot_start")
|
||||
if !ok {
|
||||
return "放置失败:缺少必填参数 slot_start。"
|
||||
}
|
||||
return Place(state, taskID, day, slotStart)
|
||||
},
|
||||
)
|
||||
|
||||
r.Register("move",
|
||||
"将一个已安排任务移动到新位置。task_id/new_day/new_slot_start 必填。",
|
||||
`{"name":"move","parameters":{"task_id":{"type":"int","required":true},"new_day":{"type":"int","required":true},"new_slot_start":{"type":"int","required":true}}}`,
|
||||
func(state *ScheduleState, args map[string]any) string {
|
||||
taskID, ok := argsInt(args, "task_id")
|
||||
if !ok {
|
||||
return "移动失败:缺少必填参数 task_id。"
|
||||
}
|
||||
newDay, ok := argsInt(args, "new_day")
|
||||
if !ok {
|
||||
return "移动失败:缺少必填参数 new_day。"
|
||||
}
|
||||
newSlotStart, ok := argsInt(args, "new_slot_start")
|
||||
if !ok {
|
||||
return "移动失败:缺少必填参数 new_slot_start。"
|
||||
}
|
||||
return Move(state, taskID, newDay, newSlotStart)
|
||||
},
|
||||
)
|
||||
|
||||
r.Register("swap",
|
||||
"交换两个已安排任务的位置。两个任务必须时长相同。task_a/task_b 必填。",
|
||||
`{"name":"swap","parameters":{"task_a":{"type":"int","required":true},"task_b":{"type":"int","required":true}}}`,
|
||||
func(state *ScheduleState, args map[string]any) string {
|
||||
taskA, ok := argsInt(args, "task_a")
|
||||
if !ok {
|
||||
return "交换失败:缺少必填参数 task_a。"
|
||||
}
|
||||
taskB, ok := argsInt(args, "task_b")
|
||||
if !ok {
|
||||
return "交换失败:缺少必填参数 task_b。"
|
||||
}
|
||||
return Swap(state, taskA, taskB)
|
||||
},
|
||||
)
|
||||
|
||||
r.Register("batch_move",
|
||||
"原子性批量移动多个任务,全部成功才生效。moves 数组必填。",
|
||||
`{"name":"batch_move","parameters":{"moves":{"type":"array","required":true,"items":{"task_id":"int","new_day":"int","new_slot_start":"int"}}}}`,
|
||||
func(state *ScheduleState, args map[string]any) string {
|
||||
moves, err := argsMoveList(args)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("批量移动失败:%s", err.Error())
|
||||
}
|
||||
return BatchMove(state, moves)
|
||||
},
|
||||
)
|
||||
|
||||
r.Register("unplace",
|
||||
"将一个已安排任务移除,恢复为待安排状态。会自动清理嵌入关系。task_id 必填。",
|
||||
`{"name":"unplace","parameters":{"task_id":{"type":"int","required":true}}}`,
|
||||
func(state *ScheduleState, args map[string]any) string {
|
||||
taskID, ok := argsInt(args, "task_id")
|
||||
if !ok {
|
||||
return "移除失败:缺少必填参数 task_id。"
|
||||
}
|
||||
return Unplace(state, taskID)
|
||||
},
|
||||
)
|
||||
|
||||
// 按 schema name 排序,保证输出稳定。
|
||||
sort.Slice(r.schemas, func(i, j int) bool {
|
||||
return r.schemas[i].Name < r.schemas[j].Name
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
Reference in New Issue
Block a user