Files
smartmate/backend/newAgent/tools/registry.go
Losita bcee43b610 Version: 0.9.0.dev.260405
后端:
1.新建tools/write_helpers.go:写工具专用辅助函数(冲突检测、范围校验、嵌入宿主查找、锁定检查、格式化)
2.新建tools/write_tools.go:实现5个写工具(Place/Move/Swap/BatchMove/Unplace),含嵌入逻辑、原子性批量操作、双向嵌入关系清理,26个单元测试全部通过
3.新建tools/registry.go:工具注册表(ToolRegistry),统一管理10个工具的注册/查找/执行,支持读写工具区分和参数解析
4.更新model/graph_run_state.go:
新增 ScheduleStateProvider 接口和 ToolRegistry 依赖注入,AgentGraphState 支持按需加载ScheduleState
5.更新 node/execute.go:接入 ToolRegistry 实现真实工具调用,替换原骨架实现
6.更新 AGENTS.md
前端:无
仓库:无
2026-04-05 15:22:46 +08:00

327 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package newagenttools
import (
"fmt"
"sort"
"strings"
)
// ToolHandler 是所有工具的统一执行签名。
// 接收当前 ScheduleState + LLM 输出的原始参数,返回自然语言结果。
type ToolHandler func(state *ScheduleState, args map[string]any) string
// ToolSchemaEntry 是工具描述的轻量快照,用于 LLM prompt 注入。
// 在注入 ConversationContext 时转换为 model.ToolSchemaContext。
type ToolSchemaEntry struct {
Name string
Desc string
SchemaText string
}
// ToolRegistry 管理所有工具的注册、查找和执行。
//
// 职责边界:
// 1. 负责工具名 → handler 的映射;
// 2. 负责工具 schema 的存储(供 LLM prompt 注入);
// 3. 不负责 ScheduleState 的生命周期管理;
// 4. 不负责 confirm 流程(由 execute.go 的 action 分支处理)。
type ToolRegistry struct {
handlers map[string]ToolHandler
schemas []ToolSchemaEntry
}
// NewToolRegistry 创建空注册表。
func NewToolRegistry() *ToolRegistry {
return &ToolRegistry{
handlers: make(map[string]ToolHandler),
schemas: make([]ToolSchemaEntry, 0),
}
}
// Register 注册一个工具及其 schema 描述。
func (r *ToolRegistry) Register(name, desc, schemaText string, handler ToolHandler) {
r.handlers[name] = handler
r.schemas = append(r.schemas, ToolSchemaEntry{
Name: name,
Desc: desc,
SchemaText: schemaText,
})
}
// Execute 执行指定工具。
// 工具名不存在时返回错误提示字符串。
func (r *ToolRegistry) Execute(state *ScheduleState, toolName string, args map[string]any) string {
handler, ok := r.handlers[toolName]
if !ok {
return fmt.Sprintf("工具调用失败:未知工具 %q。可用工具%s", toolName, strings.Join(r.ToolNames(), "、"))
}
return handler(state, args)
}
// HasTool 检查工具是否已注册。
func (r *ToolRegistry) HasTool(name string) bool {
_, ok := r.handlers[name]
return ok
}
// ToolNames 返回所有已注册工具名(按注册顺序)。
func (r *ToolRegistry) ToolNames() []string {
names := make([]string, 0, len(r.handlers))
for _, s := range r.schemas {
names = append(names, s.Name)
}
return names
}
// Schemas 返回所有工具的 schema 描述(供 LLM prompt 注入)。
func (r *ToolRegistry) Schemas() []ToolSchemaEntry {
result := make([]ToolSchemaEntry, len(r.schemas))
copy(result, r.schemas)
return result
}
// IsWriteTool 判断指定工具是否为写工具(需要 confirm 流程)。
func (r *ToolRegistry) IsWriteTool(name string) bool {
return writeTools[name]
}
// ==================== 参数解析辅助 ====================
// argsInt 从 map 中提取 int 值。支持 float64JSON 反序列化的默认类型)。
func argsInt(args map[string]any, key string) (int, bool) {
v, ok := args[key]
if !ok {
return 0, false
}
switch n := v.(type) {
case float64:
return int(n), true
case int:
return n, true
}
return 0, false
}
// argsString 从 map 中提取 string 值。
func argsString(args map[string]any, key string) (string, bool) {
v, ok := args[key]
if !ok {
return "", false
}
s, ok := v.(string)
return s, ok
}
// argsIntPtr 从 map 中提取可选 int 值,不存在返回 nil。
func argsIntPtr(args map[string]any, key string) *int {
v, ok := argsInt(args, key)
if !ok {
return nil
}
return &v
}
// argsStringPtr 从 map 中提取可选 string 值,不存在返回 nil。
func argsStringPtr(args map[string]any, key string) *string {
v, ok := argsString(args, key)
if !ok {
return nil
}
return &v
}
// argsMoveList 从 map 中提取 batch_move 的 moves 数组。
func argsMoveList(args map[string]any) ([]MoveRequest, error) {
v, ok := args["moves"]
if !ok {
return nil, fmt.Errorf("缺少 moves 参数")
}
arr, ok := v.([]any)
if !ok {
return nil, fmt.Errorf("moves 参数必须是数组")
}
moves := make([]MoveRequest, 0, len(arr))
for i, item := range arr {
m, ok := item.(map[string]any)
if !ok {
return nil, fmt.Errorf("moves[%d] 不是有效对象", i)
}
taskID, ok := argsInt(m, "task_id")
if !ok {
return nil, fmt.Errorf("moves[%d].task_id 缺失或无效", i)
}
newDay, ok := argsInt(m, "new_day")
if !ok {
return nil, fmt.Errorf("moves[%d].new_day 缺失或无效", i)
}
newSlotStart, ok := argsInt(m, "new_slot_start")
if !ok {
return nil, fmt.Errorf("moves[%d].new_slot_start 缺失或无效", i)
}
moves = append(moves, MoveRequest{
TaskID: taskID,
NewDay: newDay,
NewSlotStart: newSlotStart,
})
}
return moves, nil
}
// ==================== 写工具名集合 ====================
var writeTools = map[string]bool{
"place": true,
"move": true,
"swap": true,
"batch_move": true,
"unplace": true,
}
// ==================== 默认注册表 ====================
// NewDefaultRegistry 创建包含全部 10 个日程工具的注册表。
func NewDefaultRegistry() *ToolRegistry {
r := NewToolRegistry()
// --- 读工具 ---
r.Register("get_overview",
"获取规划窗口的粗粒度总览,包括每日占用、可嵌入时段和待安排任务。",
`{"name":"get_overview","parameters":{}}`,
func(state *ScheduleState, args map[string]any) string {
return GetOverview(state)
},
)
r.Register("query_range",
"查看某天或某时段的细粒度占用详情。day 必填slot_start/slot_end 选填(不填查整天)。",
`{"name":"query_range","parameters":{"day":{"type":"int","required":true},"slot_start":{"type":"int"},"slot_end":{"type":"int"}}}`,
func(state *ScheduleState, args map[string]any) string {
day, ok := argsInt(args, "day")
if !ok {
return "查询失败:缺少必填参数 day。"
}
return QueryRange(state, day, argsIntPtr(args, "slot_start"), argsIntPtr(args, "slot_end"))
},
)
r.Register("find_free",
"查找满足指定连续时段长度的空闲位置。duration 必填day 选填(不填搜全部天)。",
`{"name":"find_free","parameters":{"duration":{"type":"int","required":true},"day":{"type":"int"}}}`,
func(state *ScheduleState, args map[string]any) string {
duration, ok := argsInt(args, "duration")
if !ok {
return "查询失败:缺少必填参数 duration。"
}
return FindFree(state, duration, argsIntPtr(args, "day"))
},
)
r.Register("list_tasks",
"列出任务清单可按类别和状态过滤。category 选填status 选填(默认 all。",
`{"name":"list_tasks","parameters":{"category":{"type":"string"},"status":{"type":"string","enum":["all","existing","pending"]}}}`,
func(state *ScheduleState, args map[string]any) string {
return ListTasks(state, argsStringPtr(args, "category"), argsStringPtr(args, "status"))
},
)
r.Register("get_task_info",
"查询单个任务的详细信息,包括类别、状态、占用时段、嵌入关系。",
`{"name":"get_task_info","parameters":{"task_id":{"type":"int","required":true}}}`,
func(state *ScheduleState, args map[string]any) string {
taskID, ok := argsInt(args, "task_id")
if !ok {
return "查询失败:缺少必填参数 task_id。"
}
return GetTaskInfo(state, taskID)
},
)
// --- 写工具 ---
r.Register("place",
"将一个待安排任务放到指定位置。自动检测可嵌入宿主。task_id/day/slot_start 必填。",
`{"name":"place","parameters":{"task_id":{"type":"int","required":true},"day":{"type":"int","required":true},"slot_start":{"type":"int","required":true}}}`,
func(state *ScheduleState, args map[string]any) string {
taskID, ok := argsInt(args, "task_id")
if !ok {
return "放置失败:缺少必填参数 task_id。"
}
day, ok := argsInt(args, "day")
if !ok {
return "放置失败:缺少必填参数 day。"
}
slotStart, ok := argsInt(args, "slot_start")
if !ok {
return "放置失败:缺少必填参数 slot_start。"
}
return Place(state, taskID, day, slotStart)
},
)
r.Register("move",
"将一个已安排任务移动到新位置。task_id/new_day/new_slot_start 必填。",
`{"name":"move","parameters":{"task_id":{"type":"int","required":true},"new_day":{"type":"int","required":true},"new_slot_start":{"type":"int","required":true}}}`,
func(state *ScheduleState, args map[string]any) string {
taskID, ok := argsInt(args, "task_id")
if !ok {
return "移动失败:缺少必填参数 task_id。"
}
newDay, ok := argsInt(args, "new_day")
if !ok {
return "移动失败:缺少必填参数 new_day。"
}
newSlotStart, ok := argsInt(args, "new_slot_start")
if !ok {
return "移动失败:缺少必填参数 new_slot_start。"
}
return Move(state, taskID, newDay, newSlotStart)
},
)
r.Register("swap",
"交换两个已安排任务的位置。两个任务必须时长相同。task_a/task_b 必填。",
`{"name":"swap","parameters":{"task_a":{"type":"int","required":true},"task_b":{"type":"int","required":true}}}`,
func(state *ScheduleState, args map[string]any) string {
taskA, ok := argsInt(args, "task_a")
if !ok {
return "交换失败:缺少必填参数 task_a。"
}
taskB, ok := argsInt(args, "task_b")
if !ok {
return "交换失败:缺少必填参数 task_b。"
}
return Swap(state, taskA, taskB)
},
)
r.Register("batch_move",
"原子性批量移动多个任务全部成功才生效。moves 数组必填。",
`{"name":"batch_move","parameters":{"moves":{"type":"array","required":true,"items":{"task_id":"int","new_day":"int","new_slot_start":"int"}}}}`,
func(state *ScheduleState, args map[string]any) string {
moves, err := argsMoveList(args)
if err != nil {
return fmt.Sprintf("批量移动失败:%s", err.Error())
}
return BatchMove(state, moves)
},
)
r.Register("unplace",
"将一个已安排任务移除恢复为待安排状态。会自动清理嵌入关系。task_id 必填。",
`{"name":"unplace","parameters":{"task_id":{"type":"int","required":true}}}`,
func(state *ScheduleState, args map[string]any) string {
taskID, ok := argsInt(args, "task_id")
if !ok {
return "移除失败:缺少必填参数 task_id。"
}
return Unplace(state, taskID)
},
)
// 按 schema name 排序,保证输出稳定。
sort.Slice(r.schemas, func(i, j int) bool {
return r.schemas[i].Name < r.schemas[j].Name
})
return r
}