Version: 0.8.3.dev.260328
后端: 1.彻底删除原agent文件夹,并将现agent2文件夹全量重命名为agent(包括全部涉及到的文件以及文档、注释),迁移工作完美结束 2.修复了重试消息的相关逻辑问题 前端: 1.改善了一些交互体验,修复了一些bug,现在只剩少的功能了,现存的bug基本都修复完毕 全仓库: 1.更新了决策记录和README文档
This commit is contained in:
504
backend/agent/node/quicknote.go
Normal file
504
backend/agent/node/quicknote.go
Normal file
@@ -0,0 +1,504 @@
|
||||
package agentnode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
agentllm "github.com/LoveLosita/smartflow/backend/agent/llm"
|
||||
agentmodel "github.com/LoveLosita/smartflow/backend/agent/model"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
const (
|
||||
// QuickNoteGraphNodeIntent 是随口记图中的“意图识别”节点名。
|
||||
QuickNoteGraphNodeIntent = "quick_note_intent"
|
||||
// QuickNoteGraphNodeRank 是随口记图中的“优先级评估”节点名。
|
||||
QuickNoteGraphNodeRank = "quick_note_priority"
|
||||
// QuickNoteGraphNodePersist 是随口记图中的“持久化写库”节点名。
|
||||
QuickNoteGraphNodePersist = "quick_note_persist"
|
||||
// QuickNoteGraphNodeExit 是随口记图中的“提前退出”节点名。
|
||||
QuickNoteGraphNodeExit = "quick_note_exit"
|
||||
)
|
||||
|
||||
// QuickNoteGraphRunInput 描述一次随口记图运行所需的请求级依赖。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责把模型、初始状态、工具依赖和阶段回调打包给 graph 层。
|
||||
// 2. 不负责做依赖校验,校验逻辑由 graph/node 构造阶段处理。
|
||||
type QuickNoteGraphRunInput struct {
|
||||
Model *ark.ChatModel
|
||||
State *agentmodel.QuickNoteState
|
||||
Deps QuickNoteToolDeps
|
||||
SkipIntentVerification bool
|
||||
EmitStage func(stage, detail string)
|
||||
}
|
||||
|
||||
// QuickNoteNodes 是随口记图的节点容器。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责承接节点运行时依赖,并向 graph 暴露可直接挂载的方法。
|
||||
// 2. 不负责 graph 编译,也不负责 service 层接口接线。
|
||||
type QuickNoteNodes struct {
|
||||
input QuickNoteGraphRunInput
|
||||
createTaskTool tool.InvokableTool
|
||||
emitStage func(stage, detail string)
|
||||
}
|
||||
|
||||
// NewQuickNoteNodes 负责构造随口记节点容器。
|
||||
//
|
||||
// 输入输出语义:
|
||||
// 1. createTaskTool 不能为空,否则 persist 节点无法落库。
|
||||
// 2. EmitStage 为空时会回退到空实现,避免节点内部到处判空。
|
||||
func NewQuickNoteNodes(input QuickNoteGraphRunInput, createTaskTool tool.InvokableTool) (*QuickNoteNodes, error) {
|
||||
if createTaskTool == nil {
|
||||
return nil, errors.New("quick note nodes: createTaskTool is nil")
|
||||
}
|
||||
|
||||
emitStage := input.EmitStage
|
||||
if emitStage == nil {
|
||||
emitStage = func(stage, detail string) {}
|
||||
}
|
||||
|
||||
return &QuickNoteNodes{
|
||||
input: input,
|
||||
createTaskTool: createTaskTool,
|
||||
emitStage: emitStage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Exit 是图中的显式退出节点。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 仅作为图收口占位,保持状态原样透传。
|
||||
// 2. 不做额外业务处理,避免退出节点再引入副作用。
|
||||
func (n *QuickNoteNodes) Exit(ctx context.Context, st *agentmodel.QuickNoteState) (*agentmodel.QuickNoteState, error) {
|
||||
_ = ctx
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// NextAfterIntent 根据意图识别结果决定 intent 节点后的分支走向。
|
||||
//
|
||||
// 步骤说明:
|
||||
// 1. 非随口记意图时直接退出,避免误把普通聊天写成任务。
|
||||
// 2. 截止时间校验失败时同样直接退出,让上层优先把错误提示给用户。
|
||||
// 3. 只有意图成立且时间合法,才进入优先级评估节点。
|
||||
func (n *QuickNoteNodes) NextAfterIntent(ctx context.Context, st *agentmodel.QuickNoteState) (string, error) {
|
||||
_ = ctx
|
||||
if st == nil || !st.IsQuickNoteIntent {
|
||||
return QuickNoteGraphNodeExit, nil
|
||||
}
|
||||
if st.DeadlineValidationError != "" {
|
||||
return QuickNoteGraphNodeExit, nil
|
||||
}
|
||||
return QuickNoteGraphNodeRank, nil
|
||||
}
|
||||
|
||||
// NextAfterPersist 根据持久化结果决定 persist 节点后的分支走向。
|
||||
//
|
||||
// 输入输出语义:
|
||||
// 1. Persisted=true 表示已经成功写库,可以直接结束。
|
||||
// 2. Persisted=false 且 CanRetryTool()=true 表示继续重试写库。
|
||||
// 3. 重试用尽后会补齐兜底回复,再结束链路,避免用户拿到空响应。
|
||||
func (n *QuickNoteNodes) NextAfterPersist(ctx context.Context, st *agentmodel.QuickNoteState) (string, error) {
|
||||
_ = ctx
|
||||
if st == nil {
|
||||
return compose.END, nil
|
||||
}
|
||||
if st.Persisted {
|
||||
return compose.END, nil
|
||||
}
|
||||
if st.CanRetryTool() {
|
||||
return QuickNoteGraphNodePersist, nil
|
||||
}
|
||||
if st.AssistantReply == "" {
|
||||
st.AssistantReply = "抱歉,我已经重试了多次,还是没能成功记录这条任务,请稍后再试。"
|
||||
}
|
||||
return compose.END, nil
|
||||
}
|
||||
|
||||
// Intent 负责“意图识别 + 聚合规划 + 时间校验”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责判断本次请求是否属于随口记;
|
||||
// 2. 负责把模型规划结果回填到 state;
|
||||
// 3. 负责做最后一层本地时间硬校验,避免非法时间被静默写成 NULL;
|
||||
// 4. 不负责真正写库。
|
||||
func (n *QuickNoteNodes) Intent(ctx context.Context, st *agentmodel.QuickNoteState) (*agentmodel.QuickNoteState, error) {
|
||||
if st == nil {
|
||||
return nil, errors.New("quick note graph: nil state in intent node")
|
||||
}
|
||||
|
||||
// 1. 若上游路由已经高置信命中 quick_note,则直接进入单次聚合规划。
|
||||
// 1.1 目的:尽量把“标题 / 时间 / 优先级 / banter”压缩到一次模型往返内;
|
||||
// 1.2 失败处理:若聚合规划失败,不中断整条链路,而是回退到本地兜底,保证可用性优先。
|
||||
if n.input.SkipIntentVerification {
|
||||
n.emitStage("quick_note.intent.analyzing", "已由上游路由判定为任务请求,跳过二次意图判断。")
|
||||
st.IsQuickNoteIntent = true
|
||||
st.IntentJudgeReason = "上游路由已命中 quick_note,跳过二次意图判定"
|
||||
st.PlannedBySingleCall = true
|
||||
|
||||
n.emitStage("quick_note.plan.generating", "正在一次性生成时间归一化、优先级与回复润色。")
|
||||
plan, planErr := planQuickNoteInSingleCall(ctx, n.input.Model, st.RequestNowText, st.RequestNow, st.UserInput)
|
||||
if planErr != nil {
|
||||
st.IntentJudgeReason += ";聚合规划失败,回退本地兜底"
|
||||
} else {
|
||||
if strings.TrimSpace(plan.Title) != "" {
|
||||
st.ExtractedTitle = strings.TrimSpace(plan.Title)
|
||||
}
|
||||
if plan.Deadline != nil {
|
||||
st.ExtractedDeadline = plan.Deadline
|
||||
}
|
||||
st.ExtractedDeadlineText = strings.TrimSpace(plan.DeadlineText)
|
||||
if plan.UrgencyThreshold != nil {
|
||||
st.ExtractedUrgencyThreshold = normalizeUrgencyThreshold(plan.UrgencyThreshold, plan.Deadline)
|
||||
}
|
||||
if agentmodel.IsValidTaskPriority(plan.PriorityGroup) {
|
||||
st.ExtractedPriority = plan.PriorityGroup
|
||||
st.ExtractedPriorityReason = strings.TrimSpace(plan.PriorityReason)
|
||||
}
|
||||
st.ExtractedBanter = strings.TrimSpace(plan.Banter)
|
||||
}
|
||||
|
||||
// 1.3 如果聚合规划没能给出标题,则回退到本地标题抽取,避免后续 persist 节点拿到空标题。
|
||||
if strings.TrimSpace(st.ExtractedTitle) == "" {
|
||||
st.ExtractedTitle = deriveQuickNoteTitleFromInput(st.UserInput)
|
||||
}
|
||||
|
||||
// 1.4 最后一定要做一轮本地时间硬校验。
|
||||
// 1.4.1 原因:模型即使给了时间,也可能和用户原句不一致,或者用户原句本身就是非法时间;
|
||||
// 1.4.2 若检测到“用户给了时间线索但格式非法”,直接退出图并给用户明确修正提示。
|
||||
n.emitStage("quick_note.deadline.validating", "正在校验并归一化任务时间。")
|
||||
userDeadline, userHasTimeHint, userDeadlineErr := parseOptionalDeadlineFromUserInput(st.UserInput, st.RequestNow)
|
||||
if userHasTimeHint && userDeadlineErr != nil {
|
||||
st.DeadlineValidationError = userDeadlineErr.Error()
|
||||
st.AssistantReply = "我识别到你给了时间信息,但这个时间格式我没法准确解析,请改成例如:2026-03-20 18:30、明天下午3点、下周一上午9点。"
|
||||
n.emitStage("quick_note.failed", "时间校验失败,未执行写入。")
|
||||
return st, nil
|
||||
}
|
||||
if userDeadline != nil {
|
||||
st.ExtractedDeadline = userDeadline
|
||||
st.ExtractedDeadlineText = strings.TrimSpace(st.UserInput)
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 2. 常规路径:先做一次意图识别,再做本地时间硬校验。
|
||||
n.emitStage("quick_note.intent.analyzing", "正在分析用户输入是否属于任务安排请求。")
|
||||
parsed, callErr := agentllm.IdentifyQuickNoteIntent(ctx, n.input.Model, st.RequestNowText, st.UserInput)
|
||||
if callErr != nil {
|
||||
// 2.1 这里不直接返回 error,而是把它视为“本次未能确认是 quick note”,交给上层回退普通聊天。
|
||||
st.IsQuickNoteIntent = false
|
||||
st.IntentJudgeReason = "意图识别失败,回退普通聊天"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.IsQuickNoteIntent = parsed.IsQuickNote
|
||||
st.IntentJudgeReason = strings.TrimSpace(parsed.Reason)
|
||||
if !st.IsQuickNoteIntent {
|
||||
return st, nil
|
||||
}
|
||||
|
||||
title := strings.TrimSpace(parsed.Title)
|
||||
if title == "" {
|
||||
title = strings.TrimSpace(st.UserInput)
|
||||
}
|
||||
st.ExtractedTitle = title
|
||||
|
||||
n.emitStage("quick_note.deadline.validating", "正在校验并归一化任务时间。")
|
||||
|
||||
// 2.2 先尝试吃模型返回的 deadline_at,用于减少后续重复推理。
|
||||
st.ExtractedDeadlineText = strings.TrimSpace(parsed.DeadlineAt)
|
||||
if st.ExtractedDeadlineText != "" {
|
||||
if deadline, deadlineErr := parseOptionalDeadlineWithNow(st.ExtractedDeadlineText, st.RequestNow); deadlineErr == nil {
|
||||
st.ExtractedDeadline = deadline
|
||||
}
|
||||
}
|
||||
|
||||
// 2.3 再强制对用户原句做一次时间线索校验。
|
||||
userDeadline, userHasTimeHint, userDeadlineErr := parseOptionalDeadlineFromUserInput(st.UserInput, st.RequestNow)
|
||||
if userHasTimeHint && userDeadlineErr != nil {
|
||||
st.DeadlineValidationError = userDeadlineErr.Error()
|
||||
st.AssistantReply = "我识别到你给了时间信息,但这个时间格式我没法准确解析,请改成例如:2026-03-20 18:30、明天下午3点、下周一上午9点。"
|
||||
n.emitStage("quick_note.failed", "时间校验失败,未执行写入。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 2.4 若模型没提到 deadline,但用户原句能解析出来,则以用户原句为准补齐。
|
||||
if st.ExtractedDeadline == nil && userDeadline != nil {
|
||||
st.ExtractedDeadline = userDeadline
|
||||
if st.ExtractedDeadlineText == "" {
|
||||
st.ExtractedDeadlineText = strings.TrimSpace(st.UserInput)
|
||||
}
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// Priority 负责“优先级评估”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责在 intent 节点之后补齐 priority_group;
|
||||
// 2. 若聚合规划已经给出合法优先级,则直接复用,不再重复调用模型;
|
||||
// 3. 若模型评估失败,则使用本地兜底策略,保证链路继续可走;
|
||||
// 4. 不负责写库。
|
||||
func (n *QuickNoteNodes) Priority(ctx context.Context, st *agentmodel.QuickNoteState) (*agentmodel.QuickNoteState, error) {
|
||||
if st == nil {
|
||||
return nil, errors.New("quick note graph: nil state in priority node")
|
||||
}
|
||||
if !st.IsQuickNoteIntent || strings.TrimSpace(st.DeadlineValidationError) != "" {
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 1. 聚合规划已经给出合法优先级时,直接复用,避免重复调模型。
|
||||
if agentmodel.IsValidTaskPriority(st.ExtractedPriority) {
|
||||
if strings.TrimSpace(st.ExtractedPriorityReason) == "" {
|
||||
st.ExtractedPriorityReason = "复用聚合规划优先级"
|
||||
}
|
||||
n.emitStage("quick_note.priority.evaluating", "已复用聚合规划结果中的优先级。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 2. 单请求聚合路径若没有给出合法 priority,则直接走本地兜底,优先保证低时延。
|
||||
if n.input.SkipIntentVerification || st.PlannedBySingleCall {
|
||||
st.ExtractedPriority = fallbackPriority(st)
|
||||
st.ExtractedPriorityReason = "聚合规划未给出合法优先级,使用本地兜底"
|
||||
n.emitStage("quick_note.priority.evaluating", "聚合优先级缺失,已使用本地兜底。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
n.emitStage("quick_note.priority.evaluating", "正在评估任务优先级。")
|
||||
deadlineText := "无"
|
||||
if st.ExtractedDeadline != nil {
|
||||
deadlineText = formatQuickNoteTimeToMinute(*st.ExtractedDeadline)
|
||||
}
|
||||
deadlineClue := strings.TrimSpace(st.ExtractedDeadlineText)
|
||||
if deadlineClue == "" {
|
||||
deadlineClue = "无"
|
||||
}
|
||||
|
||||
parsed, callErr := agentllm.PlanQuickNotePriority(ctx, n.input.Model, st.RequestNowText, st.ExtractedTitle, st.UserInput, deadlineClue, deadlineText)
|
||||
if callErr != nil {
|
||||
st.ExtractedPriority = fallbackPriority(st)
|
||||
st.ExtractedPriorityReason = "优先级评估失败,使用兜底策略"
|
||||
return st, nil
|
||||
}
|
||||
if parsed == nil || !agentmodel.IsValidTaskPriority(parsed.PriorityGroup) {
|
||||
st.ExtractedPriority = fallbackPriority(st)
|
||||
st.ExtractedPriorityReason = "优先级结果异常,使用兜底策略"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.ExtractedPriority = parsed.PriorityGroup
|
||||
st.ExtractedPriorityReason = strings.TrimSpace(parsed.Reason)
|
||||
if strings.TrimSpace(parsed.UrgencyThresholdAt) != "" {
|
||||
urgencyThreshold, thresholdErr := parseOptionalDeadlineWithNow(strings.TrimSpace(parsed.UrgencyThresholdAt), st.RequestNow)
|
||||
if thresholdErr == nil {
|
||||
st.ExtractedUrgencyThreshold = normalizeUrgencyThreshold(urgencyThreshold, st.ExtractedDeadline)
|
||||
}
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// Persist 负责“调工具写库 + 有限次重试状态回填”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责把 state 中已提取出的标题、时间、优先级组装成工具入参;
|
||||
// 2. 负责调用 createTaskTool 执行真正写库;
|
||||
// 3. 负责把成功/失败结果回填到 state,供后续分支与回复使用;
|
||||
// 4. 不负责最终回复润色,不负责 service 层的 Redis 与持久化收尾。
|
||||
func (n *QuickNoteNodes) Persist(ctx context.Context, st *agentmodel.QuickNoteState) (*agentmodel.QuickNoteState, error) {
|
||||
if st == nil {
|
||||
return nil, errors.New("quick note graph: nil state in persist node")
|
||||
}
|
||||
if !st.IsQuickNoteIntent || strings.TrimSpace(st.DeadlineValidationError) != "" {
|
||||
return st, nil
|
||||
}
|
||||
|
||||
n.emitStage("quick_note.persisting", "正在写入任务数据。")
|
||||
priority := st.ExtractedPriority
|
||||
if !agentmodel.IsValidTaskPriority(priority) {
|
||||
priority = fallbackPriority(st)
|
||||
st.ExtractedPriority = priority
|
||||
}
|
||||
|
||||
deadlineText := ""
|
||||
if st.ExtractedDeadline != nil {
|
||||
deadlineText = st.ExtractedDeadline.In(quickNoteLocation()).Format(time.RFC3339)
|
||||
}
|
||||
urgencyThresholdText := ""
|
||||
if st.ExtractedUrgencyThreshold != nil {
|
||||
urgencyThresholdText = st.ExtractedUrgencyThreshold.In(quickNoteLocation()).Format(time.RFC3339)
|
||||
}
|
||||
|
||||
toolInput := QuickNoteCreateTaskToolInput{
|
||||
Title: st.ExtractedTitle,
|
||||
PriorityGroup: priority,
|
||||
DeadlineAt: deadlineText,
|
||||
UrgencyThresholdAt: urgencyThresholdText,
|
||||
}
|
||||
rawInput, marshalErr := json.Marshal(toolInput)
|
||||
if marshalErr != nil {
|
||||
st.RecordToolError("构造工具参数失败: " + marshalErr.Error())
|
||||
if !st.CanRetryTool() {
|
||||
st.AssistantReply = "抱歉,记录任务时参数处理失败,请稍后重试。"
|
||||
n.emitStage("quick_note.failed", "参数构造失败,未完成写入。")
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
rawOutput, invokeErr := n.createTaskTool.InvokableRun(ctx, string(rawInput))
|
||||
if invokeErr != nil {
|
||||
st.RecordToolError(invokeErr.Error())
|
||||
if !st.CanRetryTool() {
|
||||
st.AssistantReply = "抱歉,我尝试了多次仍未能成功记录这条任务,请稍后再试。"
|
||||
n.emitStage("quick_note.failed", "多次重试后仍未完成写入。")
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
toolOutput, parseErr := agentllm.ParseJSONObject[QuickNoteCreateTaskToolOutput](rawOutput)
|
||||
if parseErr != nil {
|
||||
st.RecordToolError("解析工具返回失败: " + parseErr.Error())
|
||||
if !st.CanRetryTool() {
|
||||
st.AssistantReply = "抱歉,我拿到了异常结果,没能确认任务是否记录成功,请稍后再试。"
|
||||
n.emitStage("quick_note.failed", "结果解析异常,无法确认写入结果。")
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
if toolOutput.TaskID <= 0 {
|
||||
st.RecordToolError(fmt.Sprintf("工具返回非法 task_id=%d", toolOutput.TaskID))
|
||||
if !st.CanRetryTool() {
|
||||
st.AssistantReply = "抱歉,这次我没能确认任务写入成功,请再发一次我立刻补上。"
|
||||
n.emitStage("quick_note.failed", "写入结果缺少有效 task_id,已终止成功回包。")
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 1. 只有拿到有效 task_id,才视为真正写入成功;
|
||||
// 2. 这样可以避免出现“返回成功文案,但数据库里根本没记录”的假成功。
|
||||
st.RecordToolSuccess(toolOutput.TaskID)
|
||||
if strings.TrimSpace(toolOutput.Title) != "" {
|
||||
st.ExtractedTitle = strings.TrimSpace(toolOutput.Title)
|
||||
}
|
||||
if agentmodel.IsValidTaskPriority(toolOutput.PriorityGroup) {
|
||||
st.ExtractedPriority = toolOutput.PriorityGroup
|
||||
}
|
||||
|
||||
reply := strings.TrimSpace(toolOutput.Message)
|
||||
if reply == "" {
|
||||
reply = fmt.Sprintf("已为你记录:%s(%s)", st.ExtractedTitle, agentmodel.PriorityLabelCN(st.ExtractedPriority))
|
||||
}
|
||||
st.AssistantReply = reply
|
||||
n.emitStage("quick_note.persisted", "任务写入成功,正在组织回复内容。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
type quickNotePlannedResult struct {
|
||||
Title string
|
||||
Deadline *time.Time
|
||||
DeadlineText string
|
||||
UrgencyThreshold *time.Time
|
||||
UrgencyThresholdText string
|
||||
PriorityGroup int
|
||||
PriorityReason string
|
||||
Banter string
|
||||
}
|
||||
|
||||
// planQuickNoteInSingleCall 在一次模型调用里完成“时间 / 优先级 / banter”聚合规划。
|
||||
func planQuickNoteInSingleCall(ctx context.Context, chatModel *ark.ChatModel, nowText string, now time.Time, userInput string) (*quickNotePlannedResult, error) {
|
||||
parsed, err := agentllm.PlanQuickNoteInSingleCall(ctx, chatModel, nowText, userInput)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &quickNotePlannedResult{
|
||||
Title: strings.TrimSpace(parsed.Title),
|
||||
DeadlineText: strings.TrimSpace(parsed.DeadlineAt),
|
||||
UrgencyThresholdText: strings.TrimSpace(parsed.UrgencyThresholdAt),
|
||||
PriorityGroup: parsed.PriorityGroup,
|
||||
PriorityReason: strings.TrimSpace(parsed.PriorityReason),
|
||||
Banter: strings.TrimSpace(parsed.Banter),
|
||||
}
|
||||
if result.Banter != "" {
|
||||
if idx := strings.Index(result.Banter, "\n"); idx >= 0 {
|
||||
result.Banter = strings.TrimSpace(result.Banter[:idx])
|
||||
}
|
||||
}
|
||||
if result.DeadlineText != "" {
|
||||
if deadline, deadlineErr := parseOptionalDeadlineWithNow(result.DeadlineText, now); deadlineErr == nil {
|
||||
result.Deadline = deadline
|
||||
}
|
||||
}
|
||||
if result.UrgencyThresholdText != "" {
|
||||
if urgencyThreshold, thresholdErr := parseOptionalDeadlineWithNow(result.UrgencyThresholdText, now); thresholdErr == nil {
|
||||
result.UrgencyThreshold = normalizeUrgencyThreshold(urgencyThreshold, result.Deadline)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func normalizeUrgencyThreshold(threshold *time.Time, deadline *time.Time) *time.Time {
|
||||
if threshold == nil {
|
||||
return nil
|
||||
}
|
||||
if deadline == nil {
|
||||
return threshold
|
||||
}
|
||||
if threshold.After(*deadline) {
|
||||
normalized := *deadline
|
||||
return &normalized
|
||||
}
|
||||
return threshold
|
||||
}
|
||||
|
||||
func fallbackPriority(st *agentmodel.QuickNoteState) int {
|
||||
if st == nil {
|
||||
return agentmodel.QuickNotePrioritySimpleNotImportant
|
||||
}
|
||||
if st.ExtractedDeadline != nil {
|
||||
if time.Until(*st.ExtractedDeadline) <= 48*time.Hour {
|
||||
return agentmodel.QuickNotePriorityImportantUrgent
|
||||
}
|
||||
return agentmodel.QuickNotePriorityImportantNotUrgent
|
||||
}
|
||||
return agentmodel.QuickNotePrioritySimpleNotImportant
|
||||
}
|
||||
|
||||
// deriveQuickNoteTitleFromInput 在“跳过二次意图判定”场景下,从用户原句提取任务标题。
|
||||
func deriveQuickNoteTitleFromInput(userInput string) string {
|
||||
text := strings.TrimSpace(userInput)
|
||||
if text == "" {
|
||||
return "这条任务"
|
||||
}
|
||||
|
||||
prefixes := []string{
|
||||
"请帮我", "麻烦帮我", "麻烦你", "帮我", "提醒我", "请提醒我", "记一个", "记个", "帮我记一个",
|
||||
}
|
||||
for _, prefix := range prefixes {
|
||||
if strings.HasPrefix(text, prefix) {
|
||||
text = strings.TrimSpace(strings.TrimPrefix(text, prefix))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
suffixSeparators := []string{
|
||||
",记得", ",记得", ",到时候", ",到时候", " 到时候", ",别忘了", ",别忘了", "。记得",
|
||||
}
|
||||
for _, sep := range suffixSeparators {
|
||||
if idx := strings.Index(text, sep); idx > 0 {
|
||||
text = strings.TrimSpace(text[:idx])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
text = strings.Trim(text, ",。?!!? ")
|
||||
if text == "" {
|
||||
return strings.TrimSpace(userInput)
|
||||
}
|
||||
return text
|
||||
}
|
||||
585
backend/agent/node/quicknote_tool.go
Normal file
585
backend/agent/node/quicknote_tool.go
Normal file
@@ -0,0 +1,585 @@
|
||||
package agentnode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
agentmodel "github.com/LoveLosita/smartflow/backend/agent/model"
|
||||
agentshared "github.com/LoveLosita/smartflow/backend/agent/shared"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
toolutils "github.com/cloudwego/eino/components/tool/utils"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
const (
|
||||
// ToolNameQuickNoteCreateTask 是“AI随口记”写库工具的稳定名称。
|
||||
ToolNameQuickNoteCreateTask = "quick_note_create_task"
|
||||
// ToolDescQuickNoteCreateTask 是给大模型看的工具职责说明。
|
||||
ToolDescQuickNoteCreateTask = "把用户随口提到的事项落库为任务,支持可选截止时间与优先级"
|
||||
)
|
||||
|
||||
var (
|
||||
quickNoteDeadlineLayouts = []string{
|
||||
time.RFC3339,
|
||||
"2006-01-02T15:04",
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02 15:04",
|
||||
"2006/01/02 15:04:05",
|
||||
"2006/01/02 15:04",
|
||||
"2006.01.02 15:04:05",
|
||||
"2006.01.02 15:04",
|
||||
"2006-01-02",
|
||||
"2006/01/02",
|
||||
"2006.01.02",
|
||||
}
|
||||
quickNoteDateOnlyLayouts = map[string]struct{}{
|
||||
"2006-01-02": {},
|
||||
"2006/01/02": {},
|
||||
"2006.01.02": {},
|
||||
}
|
||||
|
||||
quickNoteClockHMRegex = regexp.MustCompile(`(\d{1,2})\s*[::]\s*(\d{1,2})`)
|
||||
quickNoteClockCNRegex = regexp.MustCompile(`(\d{1,2})\s*点\s*(半|(\d{1,2})\s*分?)?`)
|
||||
quickNoteYMDRegex = regexp.MustCompile(`(\d{4})\s*年\s*(\d{1,2})\s*月\s*(\d{1,2})\s*[日号]?`)
|
||||
quickNoteMDRegex = regexp.MustCompile(`(\d{1,2})\s*月\s*(\d{1,2})\s*[日号]?`)
|
||||
quickNoteDateSepRegex = regexp.MustCompile(`\d{1,4}\s*[-/.]\s*\d{1,2}(\s*[-/.]\s*\d{1,2})?`)
|
||||
quickNoteWeekdayRegex = regexp.MustCompile(`(下周|下星期|下礼拜|本周|这周|本星期|这星期|周|星期|礼拜)([一二三四五六日天])`)
|
||||
quickNoteRelativeTokens = []string{
|
||||
"今天", "今日", "今晚", "今早", "今晨", "明天", "明日", "后天", "大后天", "昨天", "昨日",
|
||||
"早上", "早晨", "上午", "中午", "下午", "晚上", "傍晚", "夜里", "凌晨",
|
||||
}
|
||||
)
|
||||
|
||||
// QuickNoteToolDeps 描述随口记工具所需的外部依赖。
|
||||
type QuickNoteToolDeps struct {
|
||||
ResolveUserID func(ctx context.Context) (int, error)
|
||||
CreateTask func(ctx context.Context, req QuickNoteCreateTaskRequest) (*QuickNoteCreateTaskResult, error)
|
||||
}
|
||||
|
||||
func (d QuickNoteToolDeps) Validate() error {
|
||||
if d.ResolveUserID == nil {
|
||||
return errors.New("quick note tool deps: ResolveUserID is nil")
|
||||
}
|
||||
if d.CreateTask == nil {
|
||||
return errors.New("quick note tool deps: CreateTask is nil")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// QuickNoteToolBundle 是随口记工具集合。
|
||||
type QuickNoteToolBundle struct {
|
||||
Tools []tool.BaseTool
|
||||
ToolInfos []*schema.ToolInfo
|
||||
}
|
||||
|
||||
// QuickNoteCreateTaskRequest 是工具层传给业务层的内部请求。
|
||||
type QuickNoteCreateTaskRequest struct {
|
||||
UserID int
|
||||
Title string
|
||||
PriorityGroup int
|
||||
DeadlineAt *time.Time
|
||||
UrgencyThresholdAt *time.Time
|
||||
}
|
||||
|
||||
// QuickNoteCreateTaskResult 是业务层回给工具层的结构化结果。
|
||||
type QuickNoteCreateTaskResult struct {
|
||||
TaskID int
|
||||
Title string
|
||||
PriorityGroup int
|
||||
DeadlineAt *time.Time
|
||||
UrgencyThresholdAt *time.Time
|
||||
}
|
||||
|
||||
// QuickNoteCreateTaskToolInput 是暴露给模型的工具入参。
|
||||
type QuickNoteCreateTaskToolInput struct {
|
||||
Title string `json:"title" jsonschema:"required,description=任务标题,简洁明确"`
|
||||
// PriorityGroup 与 tasks.priority 保持一致,取值 1~4。
|
||||
PriorityGroup int `json:"priority_group" jsonschema:"required,enum=1,enum=2,enum=3,enum=4,description=优先级分组(1重要且紧急,2重要不紧急,3简单不重要,4复杂不重要)"`
|
||||
// DeadlineAt 支持绝对时间与常见中文相对时间。
|
||||
DeadlineAt string `json:"deadline_at,omitempty" jsonschema:"description=可选截止时间,支持RFC3339、yyyy-MM-dd HH:mm:ss、yyyy-MM-dd HH:mm 以及常见中文相对时间"`
|
||||
// UrgencyThresholdAt 表示何时自动进入紧急象限。
|
||||
UrgencyThresholdAt string `json:"urgency_threshold_at,omitempty" jsonschema:"description=可选紧急分界时间,支持与deadline_at相同格式"`
|
||||
}
|
||||
|
||||
// QuickNoteCreateTaskToolOutput 是返回给模型的结构化结果。
|
||||
type QuickNoteCreateTaskToolOutput struct {
|
||||
TaskID int `json:"task_id"`
|
||||
Title string `json:"title"`
|
||||
PriorityGroup int `json:"priority_group"`
|
||||
PriorityLabel string `json:"priority_label"`
|
||||
DeadlineAt string `json:"deadline_at,omitempty"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// BuildQuickNoteToolBundle 构建随口记工具包。
|
||||
func BuildQuickNoteToolBundle(ctx context.Context, deps QuickNoteToolDeps) (*QuickNoteToolBundle, error) {
|
||||
if err := deps.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
createTaskTool, err := toolutils.InferTool(
|
||||
ToolNameQuickNoteCreateTask,
|
||||
ToolDescQuickNoteCreateTask,
|
||||
func(ctx context.Context, input *QuickNoteCreateTaskToolInput) (*QuickNoteCreateTaskToolOutput, error) {
|
||||
if input == nil {
|
||||
return nil, errors.New("工具参数不能为空")
|
||||
}
|
||||
|
||||
title := strings.TrimSpace(input.Title)
|
||||
if title == "" {
|
||||
return nil, errors.New("title 不能为空")
|
||||
}
|
||||
if !agentmodel.IsValidTaskPriority(input.PriorityGroup) {
|
||||
return nil, fmt.Errorf("priority_group=%d 非法,必须在 1~4", input.PriorityGroup)
|
||||
}
|
||||
|
||||
deadline, err := parseOptionalDeadline(input.DeadlineAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
urgencyThresholdAt, err := parseOptionalDeadline(input.UrgencyThresholdAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userID, err := deps.ResolveUserID(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析用户身份失败: %w", err)
|
||||
}
|
||||
if userID <= 0 {
|
||||
return nil, fmt.Errorf("非法 user_id=%d", userID)
|
||||
}
|
||||
|
||||
result, err := deps.CreateTask(ctx, QuickNoteCreateTaskRequest{
|
||||
UserID: userID,
|
||||
Title: title,
|
||||
PriorityGroup: input.PriorityGroup,
|
||||
DeadlineAt: deadline,
|
||||
UrgencyThresholdAt: urgencyThresholdAt,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result == nil || result.TaskID <= 0 {
|
||||
return nil, errors.New("写入任务后返回结果异常")
|
||||
}
|
||||
|
||||
finalTitle := title
|
||||
if strings.TrimSpace(result.Title) != "" {
|
||||
finalTitle = strings.TrimSpace(result.Title)
|
||||
}
|
||||
finalPriority := input.PriorityGroup
|
||||
if agentmodel.IsValidTaskPriority(result.PriorityGroup) {
|
||||
finalPriority = result.PriorityGroup
|
||||
}
|
||||
|
||||
deadlineStr := ""
|
||||
if result.DeadlineAt != nil {
|
||||
deadlineStr = result.DeadlineAt.In(quickNoteLocation()).Format(time.RFC3339)
|
||||
} else if deadline != nil {
|
||||
deadlineStr = deadline.In(quickNoteLocation()).Format(time.RFC3339)
|
||||
}
|
||||
|
||||
return &QuickNoteCreateTaskToolOutput{
|
||||
TaskID: result.TaskID,
|
||||
Title: finalTitle,
|
||||
PriorityGroup: finalPriority,
|
||||
PriorityLabel: agentmodel.PriorityLabelCN(finalPriority),
|
||||
DeadlineAt: deadlineStr,
|
||||
Message: fmt.Sprintf("已记录:%s(%s)", finalTitle, agentmodel.PriorityLabelCN(finalPriority)),
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("构建随口记工具失败: %w", err)
|
||||
}
|
||||
|
||||
tools := []tool.BaseTool{createTaskTool}
|
||||
infos, err := collectToolInfos(ctx, tools)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &QuickNoteToolBundle{
|
||||
Tools: tools,
|
||||
ToolInfos: infos,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetInvokableToolByName 通过工具名提取可执行工具实例。
|
||||
func GetInvokableToolByName(bundle *QuickNoteToolBundle, name string) (tool.InvokableTool, error) {
|
||||
if bundle == nil {
|
||||
return nil, errors.New("tool bundle is nil")
|
||||
}
|
||||
return getInvokableToolByName(bundle.Tools, bundle.ToolInfos, name)
|
||||
}
|
||||
|
||||
// parseOptionalDeadline 解析工具输入中的可选截止时间。
|
||||
func parseOptionalDeadline(raw string) (*time.Time, error) {
|
||||
value := normalizeDeadlineInput(raw)
|
||||
if value == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
deadline, hasHint, err := parseOptionalDeadlineFromText(value, quickNoteNowToMinute())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if deadline == nil {
|
||||
if !hasHint {
|
||||
return nil, fmt.Errorf("deadline_at 格式不支持: %s", value)
|
||||
}
|
||||
return nil, fmt.Errorf("deadline_at 无法解析: %s", value)
|
||||
}
|
||||
return deadline, nil
|
||||
}
|
||||
|
||||
// parseOptionalDeadlineWithNow 在给定时间基准下解析 deadline。
|
||||
func parseOptionalDeadlineWithNow(raw string, now time.Time) (*time.Time, error) {
|
||||
value := normalizeDeadlineInput(raw)
|
||||
if value == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
deadline, _, err := parseOptionalDeadlineFromText(value, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if deadline == nil {
|
||||
return nil, fmt.Errorf("deadline_at 格式不支持: %s", value)
|
||||
}
|
||||
return deadline, nil
|
||||
}
|
||||
|
||||
// parseOptionalDeadlineFromUserInput 是“用户原句解析”的宽松入口。
|
||||
func parseOptionalDeadlineFromUserInput(raw string, now time.Time) (*time.Time, bool, error) {
|
||||
value := normalizeDeadlineInput(raw)
|
||||
if value == "" {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
deadline, hasHint, err := parseOptionalDeadlineFromText(value, now)
|
||||
if err != nil {
|
||||
if hasHint {
|
||||
return nil, true, err
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
if deadline == nil {
|
||||
if hasHint {
|
||||
return nil, true, fmt.Errorf("deadline_at 无法解析: %s", value)
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
return deadline, true, nil
|
||||
}
|
||||
|
||||
// parseOptionalDeadlineFromText 是内部通用时间解析器。
|
||||
func parseOptionalDeadlineFromText(value string, now time.Time) (*time.Time, bool, error) {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
loc := quickNoteLocation()
|
||||
now = now.In(loc)
|
||||
hasHint := hasDeadlineHint(value)
|
||||
|
||||
if abs, ok := tryParseAbsoluteDeadline(value, loc); ok {
|
||||
return abs, true, nil
|
||||
}
|
||||
if rel, recognized, err := tryParseRelativeDeadline(value, now, loc); recognized {
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return rel, true, nil
|
||||
}
|
||||
if hasHint {
|
||||
return nil, true, fmt.Errorf("deadline_at 格式不支持: %s", value)
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func normalizeDeadlineInput(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
replacer := strings.NewReplacer(
|
||||
":", ":",
|
||||
",", ",",
|
||||
"。", ".",
|
||||
" ", " ",
|
||||
)
|
||||
return strings.TrimSpace(replacer.Replace(trimmed))
|
||||
}
|
||||
|
||||
func hasDeadlineHint(value string) bool {
|
||||
if quickNoteClockHMRegex.MatchString(value) ||
|
||||
quickNoteClockCNRegex.MatchString(value) ||
|
||||
quickNoteYMDRegex.MatchString(value) ||
|
||||
quickNoteMDRegex.MatchString(value) ||
|
||||
quickNoteDateSepRegex.MatchString(value) ||
|
||||
quickNoteWeekdayRegex.MatchString(value) {
|
||||
return true
|
||||
}
|
||||
for _, token := range quickNoteRelativeTokens {
|
||||
if strings.Contains(value, token) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func tryParseAbsoluteDeadline(value string, loc *time.Location) (*time.Time, bool) {
|
||||
for _, layout := range quickNoteDeadlineLayouts {
|
||||
var (
|
||||
parsed time.Time
|
||||
err error
|
||||
)
|
||||
if layout == time.RFC3339 {
|
||||
parsed, err = time.Parse(layout, value)
|
||||
if err == nil {
|
||||
parsed = parsed.In(loc)
|
||||
}
|
||||
} else {
|
||||
parsed, err = time.ParseInLocation(layout, value, loc)
|
||||
}
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, dateOnly := quickNoteDateOnlyLayouts[layout]; dateOnly {
|
||||
parsed = time.Date(parsed.Year(), parsed.Month(), parsed.Day(), 23, 59, 0, 0, loc)
|
||||
} else {
|
||||
parsed = time.Date(parsed.Year(), parsed.Month(), parsed.Day(), parsed.Hour(), parsed.Minute(), 0, 0, loc)
|
||||
}
|
||||
return &parsed, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func tryParseRelativeDeadline(value string, now time.Time, loc *time.Location) (*time.Time, bool, error) {
|
||||
baseDate, recognized := inferBaseDate(value, now, loc)
|
||||
if !recognized {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
hour, minute, hasExplicitClock, err := extractClock(value)
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
if !hasExplicitClock {
|
||||
hour, minute = defaultClockByHint(value)
|
||||
}
|
||||
|
||||
deadline := time.Date(baseDate.Year(), baseDate.Month(), baseDate.Day(), hour, minute, 0, 0, loc)
|
||||
return &deadline, true, nil
|
||||
}
|
||||
|
||||
func inferBaseDate(value string, now time.Time, loc *time.Location) (time.Time, bool) {
|
||||
if matched := quickNoteYMDRegex.FindStringSubmatch(value); len(matched) == 4 {
|
||||
year, _ := strconv.Atoi(matched[1])
|
||||
month, _ := strconv.Atoi(matched[2])
|
||||
day, _ := strconv.Atoi(matched[3])
|
||||
if isValidDate(year, month, day) {
|
||||
return time.Date(year, time.Month(month), day, 0, 0, 0, 0, loc), true
|
||||
}
|
||||
}
|
||||
|
||||
if matched := quickNoteMDRegex.FindStringSubmatch(value); len(matched) == 3 {
|
||||
month, _ := strconv.Atoi(matched[1])
|
||||
day, _ := strconv.Atoi(matched[2])
|
||||
year := now.Year()
|
||||
if !isValidDate(year, month, day) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
candidate := time.Date(year, time.Month(month), day, 0, 0, 0, 0, loc)
|
||||
if candidate.Before(startOfDay(now)) {
|
||||
year++
|
||||
if !isValidDate(year, month, day) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
candidate = time.Date(year, time.Month(month), day, 0, 0, 0, 0, loc)
|
||||
}
|
||||
return candidate, true
|
||||
}
|
||||
|
||||
if matched := quickNoteWeekdayRegex.FindStringSubmatch(value); len(matched) == 3 {
|
||||
prefix := matched[1]
|
||||
target, ok := toWeekday(matched[2])
|
||||
if ok {
|
||||
return resolveWeekdayDate(now, prefix, target), true
|
||||
}
|
||||
}
|
||||
|
||||
today := startOfDay(now)
|
||||
switch {
|
||||
case strings.Contains(value, "大后天"):
|
||||
return today.AddDate(0, 0, 3), true
|
||||
case strings.Contains(value, "后天"):
|
||||
return today.AddDate(0, 0, 2), true
|
||||
case strings.Contains(value, "明天") || strings.Contains(value, "明日"):
|
||||
return today.AddDate(0, 0, 1), true
|
||||
case strings.Contains(value, "今天") || strings.Contains(value, "今日") || strings.Contains(value, "今晚") || strings.Contains(value, "今早") || strings.Contains(value, "今晨"):
|
||||
return today, true
|
||||
case strings.Contains(value, "昨天") || strings.Contains(value, "昨日"):
|
||||
return today.AddDate(0, 0, -1), true
|
||||
default:
|
||||
return time.Time{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func extractClock(value string) (int, int, bool, error) {
|
||||
hour := 0
|
||||
minute := 0
|
||||
hasClock := false
|
||||
|
||||
if matched := quickNoteClockHMRegex.FindStringSubmatch(value); len(matched) == 3 {
|
||||
h, errH := strconv.Atoi(matched[1])
|
||||
m, errM := strconv.Atoi(matched[2])
|
||||
if errH != nil || errM != nil {
|
||||
return 0, 0, true, fmt.Errorf("deadline_at 时间解析失败: %s", value)
|
||||
}
|
||||
hour = h
|
||||
minute = m
|
||||
hasClock = true
|
||||
} else if matched := quickNoteClockCNRegex.FindStringSubmatch(value); len(matched) >= 2 {
|
||||
h, errH := strconv.Atoi(matched[1])
|
||||
if errH != nil {
|
||||
return 0, 0, true, fmt.Errorf("deadline_at 时间解析失败: %s", value)
|
||||
}
|
||||
hour = h
|
||||
minute = 0
|
||||
hasClock = true
|
||||
if len(matched) >= 3 {
|
||||
if matched[2] == "半" {
|
||||
minute = 30
|
||||
} else if len(matched) >= 4 && strings.TrimSpace(matched[3]) != "" {
|
||||
m, errM := strconv.Atoi(strings.TrimSpace(matched[3]))
|
||||
if errM != nil {
|
||||
return 0, 0, true, fmt.Errorf("deadline_at 时间解析失败: %s", value)
|
||||
}
|
||||
minute = m
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasClock {
|
||||
return 0, 0, false, nil
|
||||
}
|
||||
|
||||
if isPMHint(value) && hour < 12 {
|
||||
hour += 12
|
||||
}
|
||||
if isNoonHint(value) && hour >= 1 && hour <= 10 {
|
||||
hour += 12
|
||||
}
|
||||
if strings.Contains(value, "凌晨") && hour == 12 {
|
||||
hour = 0
|
||||
}
|
||||
|
||||
if hour < 0 || hour > 23 || minute < 0 || minute > 59 {
|
||||
return 0, 0, true, fmt.Errorf("deadline_at 时间超出范围: %s", value)
|
||||
}
|
||||
return hour, minute, true, nil
|
||||
}
|
||||
|
||||
func defaultClockByHint(value string) (int, int) {
|
||||
switch {
|
||||
case strings.Contains(value, "凌晨"):
|
||||
return 1, 0
|
||||
case strings.Contains(value, "早上") || strings.Contains(value, "早晨") || strings.Contains(value, "上午") || strings.Contains(value, "今早") || strings.Contains(value, "明早"):
|
||||
return 9, 0
|
||||
case strings.Contains(value, "中午"):
|
||||
return 12, 0
|
||||
case strings.Contains(value, "下午"):
|
||||
return 15, 0
|
||||
case strings.Contains(value, "晚上") || strings.Contains(value, "今晚") || strings.Contains(value, "傍晚") || strings.Contains(value, "夜里"):
|
||||
return 20, 0
|
||||
default:
|
||||
return 23, 59
|
||||
}
|
||||
}
|
||||
|
||||
func isPMHint(value string) bool {
|
||||
return strings.Contains(value, "下午") || strings.Contains(value, "晚上") || strings.Contains(value, "今晚") || strings.Contains(value, "傍晚")
|
||||
}
|
||||
|
||||
func isNoonHint(value string) bool {
|
||||
return strings.Contains(value, "中午")
|
||||
}
|
||||
|
||||
func startOfDay(t time.Time) time.Time {
|
||||
loc := t.Location()
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
func isValidDate(year, month, day int) bool {
|
||||
if month < 1 || month > 12 || day < 1 || day > 31 {
|
||||
return false
|
||||
}
|
||||
candidate := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)
|
||||
return candidate.Year() == year && int(candidate.Month()) == month && candidate.Day() == day
|
||||
}
|
||||
|
||||
func toWeekday(chinese string) (time.Weekday, bool) {
|
||||
switch chinese {
|
||||
case "一":
|
||||
return time.Monday, true
|
||||
case "二":
|
||||
return time.Tuesday, true
|
||||
case "三":
|
||||
return time.Wednesday, true
|
||||
case "四":
|
||||
return time.Thursday, true
|
||||
case "五":
|
||||
return time.Friday, true
|
||||
case "六":
|
||||
return time.Saturday, true
|
||||
case "日", "天":
|
||||
return time.Sunday, true
|
||||
default:
|
||||
return time.Sunday, false
|
||||
}
|
||||
}
|
||||
|
||||
func resolveWeekdayDate(now time.Time, prefix string, target time.Weekday) time.Time {
|
||||
today := startOfDay(now)
|
||||
weekdayOffset := (int(today.Weekday()) + 6) % 7
|
||||
weekStart := today.AddDate(0, 0, -weekdayOffset)
|
||||
targetOffset := (int(target) + 6) % 7
|
||||
candidateThisWeek := weekStart.AddDate(0, 0, targetOffset)
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(prefix, "下"):
|
||||
return candidateThisWeek.AddDate(0, 0, 7)
|
||||
case strings.HasPrefix(prefix, "本"), strings.HasPrefix(prefix, "这"):
|
||||
return candidateThisWeek
|
||||
default:
|
||||
if candidateThisWeek.Before(today) {
|
||||
return candidateThisWeek.AddDate(0, 0, 7)
|
||||
}
|
||||
return candidateThisWeek
|
||||
}
|
||||
}
|
||||
|
||||
func quickNoteLocation() *time.Location {
|
||||
loc, err := time.LoadLocation(agentmodel.QuickNoteTimezoneName)
|
||||
if err != nil {
|
||||
return time.Local
|
||||
}
|
||||
return loc
|
||||
}
|
||||
|
||||
func quickNoteNowToMinute() time.Time {
|
||||
return agentshared.NowToMinute()
|
||||
}
|
||||
|
||||
func formatQuickNoteTimeToMinute(t time.Time) string {
|
||||
return agentshared.FormatMinute(t.In(quickNoteLocation()))
|
||||
}
|
||||
2336
backend/agent/node/schedule_plan.go
Normal file
2336
backend/agent/node/schedule_plan.go
Normal file
File diff suppressed because it is too large
Load Diff
571
backend/agent/node/schedule_plan_tool.go
Normal file
571
backend/agent/node/schedule_plan_tool.go
Normal file
@@ -0,0 +1,571 @@
|
||||
package agentnode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
agentllm "github.com/LoveLosita/smartflow/backend/agent/llm"
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
)
|
||||
|
||||
// SchedulePlanToolDeps 描述“智能排程 graph”运行所需的外部业务依赖。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只负责声明“需要哪些能力”,不负责具体实现(实现由 service 层注入)。
|
||||
// 2. 只收口函数签名,不承载业务状态,避免跨请求共享可变数据。
|
||||
// 3. 当前统一采用 task_class_ids 语义,不再依赖单 task_class_id 主路径。
|
||||
type SchedulePlanToolDeps struct {
|
||||
// SmartPlanningMultiRaw 是可选依赖:
|
||||
// 1) 用于需要单独输出“粗排预览”时复用;
|
||||
// 2) 当前主链路已由 HybridScheduleWithPlanMulti 覆盖,可不注入。
|
||||
SmartPlanningMultiRaw func(ctx context.Context, userID int, taskClassIDs []int) ([]model.UserWeekSchedule, []model.TaskClassItem, error)
|
||||
|
||||
// HybridScheduleWithPlanMulti 把“既有日程 + 粗排结果”合并成统一的 HybridScheduleEntry 切片,
|
||||
// 供 daily/weekly ReAct 节点在内存中继续优化。
|
||||
HybridScheduleWithPlanMulti func(ctx context.Context, userID int, taskClassIDs []int) ([]model.HybridScheduleEntry, []model.TaskClassItem, error)
|
||||
|
||||
// ResolvePlanningWindow 根据 task_class_ids 解析“全局排程窗口”的相对周/天边界。
|
||||
//
|
||||
// 返回语义:
|
||||
// 1. startWeek/startDay:窗口起点(含);
|
||||
// 2. endWeek/endDay:窗口终点(含);
|
||||
// 3. error:解析失败(如任务类不存在、日期非法)。
|
||||
//
|
||||
// 用途:
|
||||
// 1. 给周级 Move 工具加硬边界,避免把任务移动到窗口外的天数;
|
||||
// 2. 解决“首尾不足一周”场景下的周内越界问题。
|
||||
ResolvePlanningWindow func(ctx context.Context, userID int, taskClassIDs []int) (startWeek, startDay, endWeek, endDay int, err error)
|
||||
}
|
||||
|
||||
// Validate 校验依赖完整性。
|
||||
//
|
||||
// 失败处理:
|
||||
// 1. 任意依赖缺失都直接返回错误,避免 graph 运行到中途才 panic。
|
||||
// 2. 调用方(runSchedulePlanFlow)收到错误后会走回退链路,不影响普通聊天可用性。
|
||||
func (d SchedulePlanToolDeps) Validate() error {
|
||||
if d.HybridScheduleWithPlanMulti == nil {
|
||||
return errors.New("schedule plan tool deps: HybridScheduleWithPlanMulti is nil")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtraInt 从 extra map 中安全提取整数值。
|
||||
//
|
||||
// 兼容策略:
|
||||
// 1) JSON 数字默认解析为 float64,做 int 转换;
|
||||
// 2) 兼容字符串形式(如 "42"),用 Atoi 解析;
|
||||
// 3) 其余类型返回 false,由调用方决定后续处理。
|
||||
func ExtraInt(extra map[string]any, key string) (int, bool) {
|
||||
v, ok := extra[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return int(n), true
|
||||
case int:
|
||||
return n, true
|
||||
case string:
|
||||
i, err := strconv.Atoi(n)
|
||||
return i, err == nil
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// ExtraIntSlice 从 extra map 中安全提取整数切片。
|
||||
//
|
||||
// 兼容输入:
|
||||
// 1) []any(JSON 数组反序列化后的常见类型);
|
||||
// 2) []int;
|
||||
// 3) []float64;
|
||||
// 4) 逗号分隔字符串(例如 "1,2,3")。
|
||||
//
|
||||
// 返回语义:
|
||||
// 1) ok=true:至少成功解析出一个整数;
|
||||
// 2) ok=false:字段不存在或全部解析失败。
|
||||
func ExtraIntSlice(extra map[string]any, key string) ([]int, bool) {
|
||||
v, exists := extra[key]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
parseOne := func(raw any) (int, error) {
|
||||
switch n := raw.(type) {
|
||||
case int:
|
||||
return n, nil
|
||||
case float64:
|
||||
return int(n), nil
|
||||
case string:
|
||||
i, err := strconv.Atoi(n)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return i, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported type: %T", raw)
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]int, 0)
|
||||
switch arr := v.(type) {
|
||||
case []int:
|
||||
for _, item := range arr {
|
||||
out = append(out, item)
|
||||
}
|
||||
case []float64:
|
||||
for _, item := range arr {
|
||||
out = append(out, int(item))
|
||||
}
|
||||
case []any:
|
||||
for _, item := range arr {
|
||||
if parsed, err := parseOne(item); err == nil {
|
||||
out = append(out, parsed)
|
||||
}
|
||||
}
|
||||
case string:
|
||||
parts := strings.Split(arr, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if parsed, err := strconv.Atoi(part); err == nil {
|
||||
out = append(out, parsed)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if len(out) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
|
||||
// ── ReAct Tool 调用/结果结构 ──
|
||||
|
||||
// reactToolCall 是 LLM 输出的单个工具调用。
|
||||
type reactToolCall = agentllm.ReactToolCall
|
||||
|
||||
// reactToolResult 是单个工具调用的执行结果。
|
||||
type reactToolResult struct {
|
||||
Tool string `json:"tool"`
|
||||
Success bool `json:"success"`
|
||||
Result string `json:"result"`
|
||||
}
|
||||
|
||||
// reactLLMOutput 是 LLM 输出的完整 JSON 结构。
|
||||
type reactLLMOutput = agentllm.ReactLLMOutput
|
||||
|
||||
// weeklyPlanningWindow 表示周级优化可用的全局周/天窗口。
|
||||
//
|
||||
// 语义:
|
||||
// 1. Enabled=false:不启用窗口硬边界,仅做基础合法性校验;
|
||||
// 2. Enabled=true:Move 必须落在 [StartWeek/StartDay, EndWeek/EndDay] 内;
|
||||
// 3. 该窗口用于处理“首尾不足一周”场景下的越界移动问题。
|
||||
type weeklyPlanningWindow struct {
|
||||
Enabled bool
|
||||
StartWeek int
|
||||
StartDay int
|
||||
EndWeek int
|
||||
EndDay int
|
||||
}
|
||||
|
||||
// ── 工具分发器 ──
|
||||
|
||||
// dispatchReactTool 根据工具名分发调用,返回(可能修改后的)entries 和执行结果。
|
||||
func dispatchReactTool(entries []model.HybridScheduleEntry, call reactToolCall) ([]model.HybridScheduleEntry, reactToolResult) {
|
||||
switch call.Tool {
|
||||
case "Swap":
|
||||
return reactToolSwap(entries, call.Params)
|
||||
case "Move":
|
||||
return reactToolMove(entries, call.Params)
|
||||
case "TimeAvailable":
|
||||
return entries, reactToolTimeAvailable(entries, call.Params)
|
||||
case "GetAvailableSlots":
|
||||
return entries, reactToolGetAvailableSlots(entries, call.Params)
|
||||
default:
|
||||
return entries, reactToolResult{Tool: call.Tool, Success: false, Result: fmt.Sprintf("未知工具: %s", call.Tool)}
|
||||
}
|
||||
}
|
||||
|
||||
// dispatchWeeklySingleActionTool 是“周级单步动作模式”的专用分发器。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 仅允许 Move / Swap 两个工具,禁止 TimeAvailable / GetAvailableSlots;
|
||||
// 2. 强制 Move 的目标周必须等于 currentWeek,避免并发周优化时发生跨周写穿;
|
||||
// 3. 统一返回工具执行结果,供上层决定预算扣减与下一轮上下文拼接。
|
||||
func dispatchWeeklySingleActionTool(entries []model.HybridScheduleEntry, call reactToolCall, currentWeek int, window weeklyPlanningWindow) ([]model.HybridScheduleEntry, reactToolResult) {
|
||||
tool := strings.TrimSpace(call.Tool)
|
||||
switch tool {
|
||||
case "Swap":
|
||||
return reactToolSwap(entries, call.Params)
|
||||
case "Move":
|
||||
// 1. 周级并发模式下,每个 worker 只负责单周数据。
|
||||
// 2. 为避免“一个 worker 改到别的周”导致并发写冲突,这里做硬约束。
|
||||
// 3. 失败时不抛异常,返回工具失败结果,让上层继续下一轮决策。
|
||||
toWeek, ok := paramInt(call.Params, "to_week")
|
||||
if !ok {
|
||||
return entries, reactToolResult{Tool: "Move", Success: false, Result: "参数缺失:需要 to_week"}
|
||||
}
|
||||
if toWeek != currentWeek {
|
||||
return entries, reactToolResult{
|
||||
Tool: "Move",
|
||||
Success: false,
|
||||
Result: fmt.Sprintf("当前仅允许优化本周:worker_week=%d,目标周=%d", currentWeek, toWeek),
|
||||
}
|
||||
}
|
||||
// 4. 若已配置全局窗口边界,再做“首尾不足一周”硬校验。
|
||||
// 4.1 这样可避免把任务移动到窗口外的天数(例如起始周的起始日前、结束周的结束日后)。
|
||||
// 4.2 窗口未启用时不阻断,保持兼容旧链路。
|
||||
if window.Enabled {
|
||||
toDay, ok := paramInt(call.Params, "to_day")
|
||||
if !ok {
|
||||
return entries, reactToolResult{Tool: "Move", Success: false, Result: "参数缺失:需要 to_day"}
|
||||
}
|
||||
allowed, dayFrom, dayTo := isDayWithinPlanningWindow(window, toWeek, toDay)
|
||||
if !allowed {
|
||||
return entries, reactToolResult{
|
||||
Tool: "Move",
|
||||
Success: false,
|
||||
Result: fmt.Sprintf("目标日期超出排程窗口:W%d 仅允许 D%d-D%d,当前目标为 D%d", toWeek, dayFrom, dayTo, toDay),
|
||||
}
|
||||
}
|
||||
}
|
||||
return reactToolMove(entries, call.Params)
|
||||
default:
|
||||
return entries, reactToolResult{
|
||||
Tool: tool,
|
||||
Success: false,
|
||||
Result: fmt.Sprintf("周级单步模式不支持工具: %s,仅允许 Move/Swap", tool),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isDayWithinPlanningWindow 判断目标 week/day 是否落在窗口范围内。
|
||||
//
|
||||
// 返回值:
|
||||
// 1. allowed:是否允许;
|
||||
// 2. dayFrom/dayTo:该周允许的 day 区间(用于错误提示)。
|
||||
func isDayWithinPlanningWindow(window weeklyPlanningWindow, week int, day int) (allowed bool, dayFrom int, dayTo int) {
|
||||
// 1. 窗口未启用时默认允许(调用方会跳过此分支,这里是兜底)。
|
||||
if !window.Enabled {
|
||||
return true, 1, 7
|
||||
}
|
||||
// 2. 先做周范围校验。
|
||||
if week < window.StartWeek || week > window.EndWeek {
|
||||
return false, 1, 7
|
||||
}
|
||||
// 3. 计算当前周允许的 day 边界。
|
||||
from := 1
|
||||
to := 7
|
||||
if week == window.StartWeek {
|
||||
from = window.StartDay
|
||||
}
|
||||
if week == window.EndWeek {
|
||||
to = window.EndDay
|
||||
}
|
||||
if day < from || day > to {
|
||||
return false, from, to
|
||||
}
|
||||
return true, from, to
|
||||
}
|
||||
|
||||
// ── 参数提取辅助 ──
|
||||
|
||||
func paramInt(params map[string]any, key string) (int, bool) {
|
||||
v, ok := params[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return int(n), true
|
||||
case int:
|
||||
return n, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// findSuggestedByID 在 entries 中查找指定 TaskItemID 的 suggested 条目索引。
|
||||
func findSuggestedByID(entries []model.HybridScheduleEntry, taskItemID int) int {
|
||||
for i, e := range entries {
|
||||
if e.TaskItemID == taskItemID && e.Status == "suggested" {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// sectionsOverlap 判断两个节次区间是否有交集。
|
||||
func sectionsOverlap(aFrom, aTo, bFrom, bTo int) bool {
|
||||
return aFrom <= bTo && bFrom <= aTo
|
||||
}
|
||||
|
||||
// entryBlocksSuggested 判断某条目是否应阻塞 suggested 任务占位。
|
||||
//
|
||||
// 规则:
|
||||
// 1. suggested 任务永远阻塞(任务之间不能重叠);
|
||||
// 2. existing 条目按 BlockForSuggested 字段决定;
|
||||
// 3. 其余场景默认阻塞(保守策略,避免放出脏可用槽)。
|
||||
func entryBlocksSuggested(entry model.HybridScheduleEntry) bool {
|
||||
if entry.Status == "suggested" {
|
||||
return true
|
||||
}
|
||||
// existing 走显式字段语义。
|
||||
if entry.Status == "existing" {
|
||||
return entry.BlockForSuggested
|
||||
}
|
||||
// 未知状态兜底:按阻塞处理。
|
||||
return true
|
||||
}
|
||||
|
||||
// hasConflict 检查目标时间段是否与 entries 中任何条目冲突(排除 excludeIdx)。
|
||||
func hasConflict(entries []model.HybridScheduleEntry, week, day, sf, st, excludeIdx int) (bool, string) {
|
||||
for i, e := range entries {
|
||||
if i == excludeIdx {
|
||||
continue
|
||||
}
|
||||
// 1. 可嵌入且未占用的课程槽(BlockForSuggested=false)不参与冲突判断。
|
||||
// 2. 这样可以避免把“水课可嵌入位”误判为硬冲突。
|
||||
if !entryBlocksSuggested(e) {
|
||||
continue
|
||||
}
|
||||
if e.Week == week && e.DayOfWeek == day && sectionsOverlap(e.SectionFrom, e.SectionTo, sf, st) {
|
||||
return true, fmt.Sprintf("%s(%s)", e.Name, e.Type)
|
||||
}
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════
|
||||
// Tool 1: Swap — 交换两个 suggested 任务的时间
|
||||
// ══════════════════════════════════════════════════════════════
|
||||
|
||||
func reactToolSwap(entries []model.HybridScheduleEntry, params map[string]any) ([]model.HybridScheduleEntry, reactToolResult) {
|
||||
idA, okA := paramInt(params, "task_a")
|
||||
idB, okB := paramInt(params, "task_b")
|
||||
if !okA || !okB {
|
||||
return entries, reactToolResult{Tool: "Swap", Success: false, Result: "参数缺失:需要 task_a 和 task_b(task_item_id)"}
|
||||
}
|
||||
if idA == idB {
|
||||
return entries, reactToolResult{Tool: "Swap", Success: false, Result: "task_a 和 task_b 不能相同"}
|
||||
}
|
||||
|
||||
idxA := findSuggestedByID(entries, idA)
|
||||
idxB := findSuggestedByID(entries, idB)
|
||||
if idxA == -1 {
|
||||
return entries, reactToolResult{Tool: "Swap", Success: false, Result: fmt.Sprintf("找不到 task_item_id=%d 的 suggested 任务", idA)}
|
||||
}
|
||||
if idxB == -1 {
|
||||
return entries, reactToolResult{Tool: "Swap", Success: false, Result: fmt.Sprintf("找不到 task_item_id=%d 的 suggested 任务", idB)}
|
||||
}
|
||||
|
||||
// 交换时间坐标
|
||||
a, b := &entries[idxA], &entries[idxB]
|
||||
a.Week, b.Week = b.Week, a.Week
|
||||
a.DayOfWeek, b.DayOfWeek = b.DayOfWeek, a.DayOfWeek
|
||||
a.SectionFrom, b.SectionFrom = b.SectionFrom, a.SectionFrom
|
||||
a.SectionTo, b.SectionTo = b.SectionTo, a.SectionTo
|
||||
|
||||
return entries, reactToolResult{
|
||||
Tool: "Swap", Success: true,
|
||||
Result: fmt.Sprintf("已交换 [%s](id=%d) 和 [%s](id=%d) 的时间", a.Name, idA, b.Name, idB),
|
||||
}
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════
|
||||
// Tool 2: Move — 将一个 suggested 任务移动到新时间
|
||||
// ══════════════════════════════════════════════════════════════
|
||||
|
||||
func reactToolMove(entries []model.HybridScheduleEntry, params map[string]any) ([]model.HybridScheduleEntry, reactToolResult) {
|
||||
taskID, ok := paramInt(params, "task_item_id")
|
||||
if !ok {
|
||||
return entries, reactToolResult{Tool: "Move", Success: false, Result: "参数缺失:需要 task_item_id"}
|
||||
}
|
||||
toWeek, ok1 := paramInt(params, "to_week")
|
||||
toDay, ok2 := paramInt(params, "to_day")
|
||||
toSF, ok3 := paramInt(params, "to_section_from")
|
||||
toST, ok4 := paramInt(params, "to_section_to")
|
||||
if !ok1 || !ok2 || !ok3 || !ok4 {
|
||||
return entries, reactToolResult{Tool: "Move", Success: false, Result: "参数缺失:需要 to_week, to_day, to_section_from, to_section_to"}
|
||||
}
|
||||
|
||||
// 基础校验
|
||||
if toDay < 1 || toDay > 7 {
|
||||
return entries, reactToolResult{Tool: "Move", Success: false, Result: fmt.Sprintf("day_of_week=%d 不合法,应为 1-7", toDay)}
|
||||
}
|
||||
if toSF < 1 || toST > 12 || toSF > toST {
|
||||
return entries, reactToolResult{Tool: "Move", Success: false, Result: fmt.Sprintf("节次范围 %d-%d 不合法,应为 1-12 且 from<=to", toSF, toST)}
|
||||
}
|
||||
|
||||
idx := findSuggestedByID(entries, taskID)
|
||||
if idx == -1 {
|
||||
return entries, reactToolResult{Tool: "Move", Success: false, Result: fmt.Sprintf("找不到 task_item_id=%d 的 suggested 任务", taskID)}
|
||||
}
|
||||
|
||||
// 节次跨度必须一致
|
||||
origSpan := entries[idx].SectionTo - entries[idx].SectionFrom
|
||||
newSpan := toST - toSF
|
||||
if origSpan != newSpan {
|
||||
return entries, reactToolResult{Tool: "Move", Success: false,
|
||||
Result: fmt.Sprintf("节次跨度不一致:原任务占 %d 节,目标占 %d 节", origSpan+1, newSpan+1)}
|
||||
}
|
||||
|
||||
// 冲突检测(排除自身)
|
||||
if conflict, name := hasConflict(entries, toWeek, toDay, toSF, toST, idx); conflict {
|
||||
return entries, reactToolResult{Tool: "Move", Success: false,
|
||||
Result: fmt.Sprintf("目标时间 W%dD%d 第%d-%d节 已被 %s 占用", toWeek, toDay, toSF, toST, name)}
|
||||
}
|
||||
|
||||
// 执行移动
|
||||
e := &entries[idx]
|
||||
oldDesc := fmt.Sprintf("W%dD%d 第%d-%d节", e.Week, e.DayOfWeek, e.SectionFrom, e.SectionTo)
|
||||
e.Week, e.DayOfWeek, e.SectionFrom, e.SectionTo = toWeek, toDay, toSF, toST
|
||||
newDesc := fmt.Sprintf("W%dD%d 第%d-%d节", toWeek, toDay, toSF, toST)
|
||||
|
||||
return entries, reactToolResult{
|
||||
Tool: "Move", Success: true,
|
||||
Result: fmt.Sprintf("已将 [%s](id=%d) 从 %s 移动到 %s", e.Name, taskID, oldDesc, newDesc),
|
||||
}
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════
|
||||
// Tool 3: TimeAvailable — 检查目标时间段是否可用
|
||||
// ══════════════════════════════════════════════════════════════
|
||||
|
||||
func reactToolTimeAvailable(entries []model.HybridScheduleEntry, params map[string]any) reactToolResult {
|
||||
week, ok1 := paramInt(params, "week")
|
||||
day, ok2 := paramInt(params, "day_of_week")
|
||||
sf, ok3 := paramInt(params, "section_from")
|
||||
st, ok4 := paramInt(params, "section_to")
|
||||
if !ok1 || !ok2 || !ok3 || !ok4 {
|
||||
return reactToolResult{Tool: "TimeAvailable", Success: false, Result: "参数缺失:需要 week, day_of_week, section_from, section_to"}
|
||||
}
|
||||
|
||||
if conflict, name := hasConflict(entries, week, day, sf, st, -1); conflict {
|
||||
return reactToolResult{Tool: "TimeAvailable", Success: true,
|
||||
Result: fmt.Sprintf(`{"available":false,"conflict_with":"%s"}`, name)}
|
||||
}
|
||||
return reactToolResult{Tool: "TimeAvailable", Success: true, Result: `{"available":true}`}
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════
|
||||
// Tool 4: GetAvailableSlots — 返回可用时间段列表
|
||||
// ══════════════════════════════════════════════════════════════
|
||||
|
||||
func reactToolGetAvailableSlots(entries []model.HybridScheduleEntry, params map[string]any) reactToolResult {
|
||||
filterWeek, _ := paramInt(params, "week") // 0 表示不过滤
|
||||
|
||||
// 1. 收集所有周次范围
|
||||
minW, maxW := 999, 0
|
||||
for _, e := range entries {
|
||||
if e.Week < minW {
|
||||
minW = e.Week
|
||||
}
|
||||
if e.Week > maxW {
|
||||
maxW = e.Week
|
||||
}
|
||||
}
|
||||
if minW > maxW {
|
||||
return reactToolResult{Tool: "GetAvailableSlots", Success: true, Result: "[]"}
|
||||
}
|
||||
|
||||
// 2. 构建占用集合
|
||||
type slotKey struct{ W, D, S int }
|
||||
occupied := make(map[slotKey]bool)
|
||||
for _, e := range entries {
|
||||
if !entryBlocksSuggested(e) {
|
||||
continue
|
||||
}
|
||||
for s := e.SectionFrom; s <= e.SectionTo; s++ {
|
||||
occupied[slotKey{e.Week, e.DayOfWeek, s}] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 遍历所有时间格,找出空闲并合并连续节次
|
||||
type availSlot struct {
|
||||
Week, Day, From, To int
|
||||
}
|
||||
var slots []availSlot
|
||||
|
||||
startW, endW := minW, maxW
|
||||
if filterWeek > 0 {
|
||||
startW, endW = filterWeek, filterWeek
|
||||
}
|
||||
|
||||
for w := startW; w <= endW; w++ {
|
||||
for d := 1; d <= 7; d++ {
|
||||
runStart := 0
|
||||
for s := 1; s <= 12; s++ {
|
||||
if !occupied[slotKey{w, d, s}] {
|
||||
if runStart == 0 {
|
||||
runStart = s
|
||||
}
|
||||
} else {
|
||||
if runStart > 0 {
|
||||
slots = append(slots, availSlot{w, d, runStart, s - 1})
|
||||
runStart = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
if runStart > 0 {
|
||||
slots = append(slots, availSlot{w, d, runStart, 12})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 按自然顺序排序(已经是了,但确保)
|
||||
sort.Slice(slots, func(i, j int) bool {
|
||||
if slots[i].Week != slots[j].Week {
|
||||
return slots[i].Week < slots[j].Week
|
||||
}
|
||||
if slots[i].Day != slots[j].Day {
|
||||
return slots[i].Day < slots[j].Day
|
||||
}
|
||||
return slots[i].From < slots[j].From
|
||||
})
|
||||
|
||||
// 5. 序列化
|
||||
type slotJSON struct {
|
||||
Week int `json:"week"`
|
||||
DayOfWeek int `json:"day_of_week"`
|
||||
SectionFrom int `json:"section_from"`
|
||||
SectionTo int `json:"section_to"`
|
||||
}
|
||||
out := make([]slotJSON, 0, len(slots))
|
||||
for _, s := range slots {
|
||||
out = append(out, slotJSON{s.Week, s.Day, s.From, s.To})
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(out)
|
||||
return reactToolResult{Tool: "GetAvailableSlots", Success: true, Result: string(data)}
|
||||
}
|
||||
|
||||
// ── 辅助:解析 LLM 输出 ──
|
||||
|
||||
// parseReactLLMOutput 解析 LLM 的 JSON 输出。
|
||||
// 兼容 ```json ... ``` 包裹。
|
||||
func parseReactLLMOutput(raw string) (*reactLLMOutput, error) {
|
||||
return agentllm.ParseScheduleReactOutput(raw)
|
||||
}
|
||||
|
||||
// truncate 截断字符串到指定长度。
|
||||
func truncate(s string, maxLen int) string {
|
||||
if maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
runes := []rune(s)
|
||||
if len(runes) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return string(runes[:maxLen]) + "..."
|
||||
}
|
||||
3526
backend/agent/node/schedule_refine.go
Normal file
3526
backend/agent/node/schedule_refine.go
Normal file
File diff suppressed because it is too large
Load Diff
2027
backend/agent/node/schedule_refine_tool.go
Normal file
2027
backend/agent/node/schedule_refine_tool.go
Normal file
File diff suppressed because it is too large
Load Diff
729
backend/agent/node/taskquery.go
Normal file
729
backend/agent/node/taskquery.go
Normal file
@@ -0,0 +1,729 @@
|
||||
package agentnode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
agentllm "github.com/LoveLosita/smartflow/backend/agent/llm"
|
||||
agentmodel "github.com/LoveLosita/smartflow/backend/agent/model"
|
||||
agentprompt "github.com/LoveLosita/smartflow/backend/agent/prompt"
|
||||
agentstream "github.com/LoveLosita/smartflow/backend/agent/stream"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
const (
|
||||
TaskQueryGraphNodePlan = "task_query.plan"
|
||||
TaskQueryGraphNodeQuadrant = "task_query.quadrant"
|
||||
TaskQueryGraphNodeTimeAnchor = "task_query.time_anchor"
|
||||
TaskQueryGraphNodeQuery = "task_query.query"
|
||||
TaskQueryGraphNodeReflect = "task_query.reflect"
|
||||
)
|
||||
|
||||
var (
|
||||
explicitLimitPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)\btop\s*(\d{1,2})\b`),
|
||||
regexp.MustCompile(`前\s*(\d{1,2})\s*(个|条|项)?`),
|
||||
regexp.MustCompile(`(\d{1,2})\s*(个|条|项)?\s*任务`),
|
||||
regexp.MustCompile(`给我\s*(\d{1,2})\s*(个|条|项)?`),
|
||||
}
|
||||
chineseDigitMap = map[rune]int{
|
||||
'一': 1,
|
||||
'二': 2,
|
||||
'两': 2,
|
||||
'三': 3,
|
||||
'四': 4,
|
||||
'五': 5,
|
||||
'六': 6,
|
||||
'七': 7,
|
||||
'八': 8,
|
||||
'九': 9,
|
||||
'十': 10,
|
||||
}
|
||||
)
|
||||
|
||||
// TaskQueryGraphRunInput 描述一次任务查询图运行需要的依赖。
|
||||
type TaskQueryGraphRunInput struct {
|
||||
Model *ark.ChatModel
|
||||
State *agentmodel.TaskQueryState
|
||||
Deps TaskQueryToolDeps
|
||||
EmitStage func(stage, detail string)
|
||||
}
|
||||
|
||||
// TaskQueryNodes 是任务查询图的节点容器。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责承接请求级依赖,并向 graph 暴露可直接挂载的方法。
|
||||
// 2. 不负责 graph 编译、service 接线和持久化。
|
||||
type TaskQueryNodes struct {
|
||||
input TaskQueryGraphRunInput
|
||||
queryTool tool.InvokableTool
|
||||
emitStage agentstream.StageEmitter
|
||||
}
|
||||
|
||||
func NewTaskQueryNodes(input TaskQueryGraphRunInput, queryTool tool.InvokableTool) (*TaskQueryNodes, error) {
|
||||
if input.Model == nil {
|
||||
return nil, fmt.Errorf("task query nodes: model is nil")
|
||||
}
|
||||
if input.State == nil {
|
||||
return nil, fmt.Errorf("task query nodes: state is nil")
|
||||
}
|
||||
if err := input.Deps.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if queryTool == nil {
|
||||
return nil, fmt.Errorf("task query nodes: queryTool is nil")
|
||||
}
|
||||
return &TaskQueryNodes{
|
||||
input: input,
|
||||
queryTool: queryTool,
|
||||
emitStage: agentstream.WrapStageEmitter(input.EmitStage),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Plan 负责把用户原话规划成结构化查询计划。
|
||||
func (n *TaskQueryNodes) Plan(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in plan node")
|
||||
}
|
||||
|
||||
n.emitStage("task_query.plan.generating", "正在一次性规划查询范围、排序和时间条件。")
|
||||
planned, err := agentllm.PlanTaskQuery(ctx, n.input.Model, st.RequestNowText, st.UserMessage)
|
||||
if err != nil || planned == nil {
|
||||
st.UserGoal = "查询任务"
|
||||
st.Plan = defaultTaskQueryPlan()
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.UserGoal = strings.TrimSpace(planned.UserGoal)
|
||||
if st.UserGoal == "" {
|
||||
st.UserGoal = "查询任务"
|
||||
}
|
||||
st.Plan = normalizeTaskQueryPlan(*planned)
|
||||
|
||||
// 1. 若用户原话里明确指定了返回条数,则以后端识别结果为准。
|
||||
// 2. 这样可以避免规划模型漏掉数量要求,或后续反思 patch 意外改写 limit。
|
||||
if explicitLimit, found := extractExplicitLimitFromUser(st.UserMessage); found {
|
||||
st.ExplicitLimit = explicitLimit
|
||||
st.Plan.Limit = explicitLimit
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// NormalizeQuadrant 负责把象限参数去重并统一成稳定顺序。
|
||||
func (n *TaskQueryNodes) NormalizeQuadrant(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
|
||||
_ = ctx
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in quadrant node")
|
||||
}
|
||||
|
||||
n.emitStage("task_query.quadrant.routing", "正在归一化象限筛选范围。")
|
||||
st.Plan.Quadrants = normalizeQuadrants(st.Plan.Quadrants)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// AnchorTime 负责把时间文本边界解析成可执行时间对象。
|
||||
func (n *TaskQueryNodes) AnchorTime(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
|
||||
_ = ctx
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in time anchor node")
|
||||
}
|
||||
|
||||
n.emitStage("task_query.time.anchoring", "正在锁定时间过滤边界。")
|
||||
applyTimeAnchorOnPlan(&st.Plan)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// Query 负责真正调用工具查询任务。
|
||||
func (n *TaskQueryNodes) Query(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in query node")
|
||||
}
|
||||
|
||||
n.emitStage("task_query.tool.querying", "正在查询任务数据。")
|
||||
items, err := n.executePlanByTool(ctx, st.Plan)
|
||||
if err != nil {
|
||||
st.LastQueryItems = make([]agentmodel.TaskQueryItem, 0)
|
||||
st.LastQueryTotal = 0
|
||||
st.ReflectReason = "查询工具执行失败"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.LastQueryItems = items
|
||||
st.LastQueryTotal = len(items)
|
||||
|
||||
// 1. 如果首轮为空且还没自动放宽过,则做一次可解释的自动放宽。
|
||||
// 2. 放宽范围仅限关键词、完成状态、时间边界,不主动改象限与 limit,避免语义漂移。
|
||||
if st.LastQueryTotal == 0 && !st.AutoBroadenApplied {
|
||||
broadenedPlan, changed := autoBroadenPlan(st.Plan)
|
||||
if changed {
|
||||
st.AutoBroadenApplied = true
|
||||
st.Plan = broadenedPlan
|
||||
n.emitStage("task_query.tool.broadened", "首次查询为空,已自动放宽条件再试一次。")
|
||||
retryItems, retryErr := n.executePlanByTool(ctx, st.Plan)
|
||||
if retryErr == nil {
|
||||
st.LastQueryItems = retryItems
|
||||
st.LastQueryTotal = len(retryItems)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// Reflect 负责判断当前结果是否满足用户诉求,并决定是否重试。
|
||||
func (n *TaskQueryNodes) Reflect(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in reflect node")
|
||||
}
|
||||
|
||||
n.emitStage("task_query.reflecting", "正在判断结果是否贴合你的需求。")
|
||||
reflectPrompt := agentprompt.BuildTaskQueryReflectUserPrompt(
|
||||
st.RequestNowText,
|
||||
st.UserMessage,
|
||||
st.UserGoal,
|
||||
summarizeTaskQueryPlan(st.Plan),
|
||||
st.RetryCount,
|
||||
st.MaxReflectRetry,
|
||||
summarizeTaskQueryItems(st.LastQueryItems, 6),
|
||||
)
|
||||
reflectResult, err := agentllm.ReflectTaskQuery(ctx, n.input.Model, reflectPrompt)
|
||||
if err != nil || reflectResult == nil {
|
||||
st.NeedRetry = false
|
||||
st.FinalReply = buildTaskQueryFallbackReply(st.LastQueryItems)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.ReflectReason = strings.TrimSpace(reflectResult.Reason)
|
||||
|
||||
if reflectResult.Satisfied {
|
||||
st.NeedRetry = false
|
||||
st.FinalReply = buildTaskQueryFinalReply(st.LastQueryItems, st.Plan, strings.TrimSpace(reflectResult.Reply))
|
||||
return st, nil
|
||||
}
|
||||
|
||||
if reflectResult.NeedRetry && st.RetryCount < st.MaxReflectRetry {
|
||||
st.Plan = applyRetryPatch(st.Plan, reflectResult.RetryPatch, st.ExplicitLimit)
|
||||
st.RetryCount++
|
||||
st.NeedRetry = true
|
||||
if reply := strings.TrimSpace(reflectResult.Reply); reply != "" {
|
||||
st.FinalReply = reply
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.NeedRetry = false
|
||||
st.FinalReply = buildTaskQueryFinalReply(st.LastQueryItems, st.Plan, strings.TrimSpace(reflectResult.Reply))
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func (n *TaskQueryNodes) NextAfterReflect(ctx context.Context, st *agentmodel.TaskQueryState) (string, error) {
|
||||
_ = ctx
|
||||
if st != nil && st.NeedRetry {
|
||||
return TaskQueryGraphNodeQuery, nil
|
||||
}
|
||||
return compose.END, nil
|
||||
}
|
||||
|
||||
func (n *TaskQueryNodes) executePlanByTool(ctx context.Context, plan agentmodel.TaskQueryPlan) ([]agentmodel.TaskQueryItem, error) {
|
||||
if n.queryTool == nil {
|
||||
return nil, fmt.Errorf("task query tool is nil")
|
||||
}
|
||||
|
||||
merged := make([]agentmodel.TaskQueryItem, 0, plan.Limit)
|
||||
seen := make(map[int]struct{}, plan.Limit*2)
|
||||
|
||||
runOne := func(quadrant *int) error {
|
||||
input := TaskQueryToolInput{
|
||||
Quadrant: quadrant,
|
||||
SortBy: plan.SortBy,
|
||||
Order: plan.Order,
|
||||
Limit: plan.Limit,
|
||||
Keyword: plan.Keyword,
|
||||
DeadlineBefore: plan.DeadlineBeforeText,
|
||||
DeadlineAfter: plan.DeadlineAfterText,
|
||||
}
|
||||
includeCompleted := plan.IncludeCompleted
|
||||
input.IncludeCompleted = &includeCompleted
|
||||
|
||||
rawInput, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rawOutput, err := n.queryTool.InvokableRun(ctx, string(rawInput))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
parsed, err := agentllm.ParseJSONObject[TaskQueryToolOutput](rawOutput)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, item := range parsed.Items {
|
||||
if _, exists := seen[item.ID]; exists {
|
||||
continue
|
||||
}
|
||||
seen[item.ID] = struct{}{}
|
||||
merged = append(merged, item)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(plan.Quadrants) == 0 {
|
||||
if err := runOne(nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
for _, quadrant := range plan.Quadrants {
|
||||
q := quadrant
|
||||
if err := runOne(&q); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sortTaskQueryItems(merged, plan)
|
||||
if len(merged) > plan.Limit {
|
||||
merged = merged[:plan.Limit]
|
||||
}
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
func defaultTaskQueryPlan() agentmodel.TaskQueryPlan {
|
||||
return agentmodel.TaskQueryPlan{
|
||||
SortBy: "deadline",
|
||||
Order: "asc",
|
||||
Limit: agentmodel.DefaultTaskQueryLimit,
|
||||
IncludeCompleted: false,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeTaskQueryPlan(raw agentllm.TaskQueryPlanOutput) agentmodel.TaskQueryPlan {
|
||||
plan := defaultTaskQueryPlan()
|
||||
plan.Quadrants = normalizeQuadrants(raw.Quadrants)
|
||||
|
||||
if sortBy := strings.ToLower(strings.TrimSpace(raw.SortBy)); sortBy == "deadline" || sortBy == "priority" || sortBy == "id" {
|
||||
plan.SortBy = sortBy
|
||||
}
|
||||
if order := strings.ToLower(strings.TrimSpace(raw.Order)); order == "asc" || order == "desc" {
|
||||
plan.Order = order
|
||||
}
|
||||
if raw.Limit > 0 {
|
||||
plan.Limit = raw.Limit
|
||||
}
|
||||
if plan.Limit > agentmodel.MaxTaskQueryLimit {
|
||||
plan.Limit = agentmodel.MaxTaskQueryLimit
|
||||
}
|
||||
if plan.Limit <= 0 {
|
||||
plan.Limit = agentmodel.DefaultTaskQueryLimit
|
||||
}
|
||||
if raw.IncludeCompleted != nil {
|
||||
plan.IncludeCompleted = *raw.IncludeCompleted
|
||||
}
|
||||
plan.Keyword = strings.TrimSpace(raw.Keyword)
|
||||
plan.DeadlineBeforeText = strings.TrimSpace(raw.DeadlineBefore)
|
||||
plan.DeadlineAfterText = strings.TrimSpace(raw.DeadlineAfter)
|
||||
applyTimeAnchorOnPlan(&plan)
|
||||
return plan
|
||||
}
|
||||
|
||||
func normalizeQuadrants(quadrants []int) []int {
|
||||
if len(quadrants) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
seen := make(map[int]struct{}, len(quadrants))
|
||||
result := make([]int, 0, len(quadrants))
|
||||
for _, quadrant := range quadrants {
|
||||
if quadrant < 1 || quadrant > 4 {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[quadrant]; exists {
|
||||
continue
|
||||
}
|
||||
seen[quadrant] = struct{}{}
|
||||
result = append(result, quadrant)
|
||||
}
|
||||
|
||||
sort.Ints(result)
|
||||
if len(result) == 0 || len(result) == 4 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func applyTimeAnchorOnPlan(plan *agentmodel.TaskQueryPlan) {
|
||||
if plan == nil {
|
||||
return
|
||||
}
|
||||
|
||||
before, errBefore := parseTaskQueryBoundaryTime(plan.DeadlineBeforeText, true)
|
||||
after, errAfter := parseTaskQueryBoundaryTime(plan.DeadlineAfterText, false)
|
||||
|
||||
if errBefore != nil {
|
||||
plan.DeadlineBefore = nil
|
||||
plan.DeadlineBeforeText = ""
|
||||
} else {
|
||||
plan.DeadlineBefore = before
|
||||
}
|
||||
if errAfter != nil {
|
||||
plan.DeadlineAfter = nil
|
||||
plan.DeadlineAfterText = ""
|
||||
} else {
|
||||
plan.DeadlineAfter = after
|
||||
}
|
||||
|
||||
if plan.DeadlineBefore != nil && plan.DeadlineAfter != nil && plan.DeadlineAfter.After(*plan.DeadlineBefore) {
|
||||
plan.DeadlineBefore = nil
|
||||
plan.DeadlineAfter = nil
|
||||
plan.DeadlineBeforeText = ""
|
||||
plan.DeadlineAfterText = ""
|
||||
}
|
||||
}
|
||||
|
||||
func autoBroadenPlan(plan agentmodel.TaskQueryPlan) (agentmodel.TaskQueryPlan, bool) {
|
||||
broadened := plan
|
||||
changed := false
|
||||
|
||||
if strings.TrimSpace(broadened.Keyword) != "" {
|
||||
broadened.Keyword = ""
|
||||
changed = true
|
||||
}
|
||||
if !broadened.IncludeCompleted {
|
||||
broadened.IncludeCompleted = true
|
||||
changed = true
|
||||
}
|
||||
if broadened.DeadlineBefore != nil || broadened.DeadlineAfter != nil || broadened.DeadlineBeforeText != "" || broadened.DeadlineAfterText != "" {
|
||||
broadened.DeadlineBefore = nil
|
||||
broadened.DeadlineAfter = nil
|
||||
broadened.DeadlineBeforeText = ""
|
||||
broadened.DeadlineAfterText = ""
|
||||
changed = true
|
||||
}
|
||||
return broadened, changed
|
||||
}
|
||||
|
||||
func applyRetryPatch(plan agentmodel.TaskQueryPlan, patch agentllm.TaskQueryRetryPatch, explicitLimit int) agentmodel.TaskQueryPlan {
|
||||
next := plan
|
||||
changed := false
|
||||
|
||||
if patch.Quadrants != nil {
|
||||
next.Quadrants = normalizeQuadrants(*patch.Quadrants)
|
||||
changed = true
|
||||
}
|
||||
if patch.SortBy != nil {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(*patch.SortBy))
|
||||
if sortBy == "deadline" || sortBy == "priority" || sortBy == "id" {
|
||||
next.SortBy = sortBy
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if patch.Order != nil {
|
||||
order := strings.ToLower(strings.TrimSpace(*patch.Order))
|
||||
if order == "asc" || order == "desc" {
|
||||
next.Order = order
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if patch.Limit != nil && explicitLimit <= 0 {
|
||||
limit := *patch.Limit
|
||||
if limit <= 0 {
|
||||
limit = agentmodel.DefaultTaskQueryLimit
|
||||
}
|
||||
if limit > agentmodel.MaxTaskQueryLimit {
|
||||
limit = agentmodel.MaxTaskQueryLimit
|
||||
}
|
||||
next.Limit = limit
|
||||
changed = true
|
||||
}
|
||||
if patch.IncludeCompleted != nil {
|
||||
next.IncludeCompleted = *patch.IncludeCompleted
|
||||
changed = true
|
||||
}
|
||||
if patch.Keyword != nil {
|
||||
next.Keyword = strings.TrimSpace(*patch.Keyword)
|
||||
changed = true
|
||||
}
|
||||
if patch.DeadlineBefore != nil {
|
||||
next.DeadlineBeforeText = strings.TrimSpace(*patch.DeadlineBefore)
|
||||
changed = true
|
||||
}
|
||||
if patch.DeadlineAfter != nil {
|
||||
next.DeadlineAfterText = strings.TrimSpace(*patch.DeadlineAfter)
|
||||
changed = true
|
||||
}
|
||||
|
||||
if changed {
|
||||
applyTimeAnchorOnPlan(&next)
|
||||
}
|
||||
if explicitLimit > 0 {
|
||||
next.Limit = explicitLimit
|
||||
}
|
||||
return next
|
||||
}
|
||||
|
||||
func summarizeTaskQueryPlan(plan agentmodel.TaskQueryPlan) string {
|
||||
quadrants := "全部象限"
|
||||
if len(plan.Quadrants) > 0 {
|
||||
parts := make([]string, 0, len(plan.Quadrants))
|
||||
for _, quadrant := range plan.Quadrants {
|
||||
parts = append(parts, strconv.Itoa(quadrant))
|
||||
}
|
||||
quadrants = strings.Join(parts, ",")
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"quadrants=%s sort=%s/%s limit=%d include_completed=%t keyword=%s before=%s after=%s",
|
||||
quadrants,
|
||||
plan.SortBy,
|
||||
plan.Order,
|
||||
plan.Limit,
|
||||
plan.IncludeCompleted,
|
||||
emptyToDash(plan.Keyword),
|
||||
emptyToDash(plan.DeadlineBeforeText),
|
||||
emptyToDash(plan.DeadlineAfterText),
|
||||
)
|
||||
}
|
||||
|
||||
func summarizeTaskQueryItems(items []agentmodel.TaskQueryItem, max int) string {
|
||||
if len(items) == 0 {
|
||||
return "无结果"
|
||||
}
|
||||
if max <= 0 {
|
||||
max = 5
|
||||
}
|
||||
if len(items) > max {
|
||||
items = items[:max]
|
||||
}
|
||||
|
||||
lines := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
lines = append(lines, fmt.Sprintf(
|
||||
"- #%d %s | 象限=%d | 完成=%t | 截止=%s",
|
||||
item.ID,
|
||||
item.Title,
|
||||
item.PriorityGroup,
|
||||
item.IsCompleted,
|
||||
emptyToDash(item.DeadlineAt),
|
||||
))
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func buildTaskQueryFallbackReply(items []agentmodel.TaskQueryItem) string {
|
||||
if len(items) == 0 {
|
||||
return "我这边暂时没找到匹配的任务。你可以再补一句,比如“按截止时间最早的前 3 个”或“只看简单不重要”。"
|
||||
}
|
||||
|
||||
preview := items
|
||||
if len(preview) > 3 {
|
||||
preview = preview[:3]
|
||||
}
|
||||
lines := make([]string, 0, len(preview))
|
||||
for _, item := range preview {
|
||||
lines = append(lines, fmt.Sprintf("%s(%s)", item.Title, item.PriorityLabel))
|
||||
}
|
||||
return fmt.Sprintf("我先给你筛到这些:%s。要不要我再按“更紧急”或“更简单”继续细化?", strings.Join(lines, "、"))
|
||||
}
|
||||
|
||||
func buildTaskQueryFinalReply(items []agentmodel.TaskQueryItem, plan agentmodel.TaskQueryPlan, llmReply string) string {
|
||||
if len(items) == 0 {
|
||||
base := buildTaskQueryFallbackReply(items)
|
||||
if strings.TrimSpace(llmReply) == "" {
|
||||
return base
|
||||
}
|
||||
return strings.TrimSpace(llmReply) + "\n" + base
|
||||
}
|
||||
|
||||
desired := plan.Limit
|
||||
if desired <= 0 {
|
||||
desired = agentmodel.DefaultTaskQueryLimit
|
||||
}
|
||||
if desired > agentmodel.MaxTaskQueryLimit {
|
||||
desired = agentmodel.MaxTaskQueryLimit
|
||||
}
|
||||
|
||||
showCount := desired
|
||||
if len(items) < showCount {
|
||||
showCount = len(items)
|
||||
}
|
||||
|
||||
preview := items[:showCount]
|
||||
lines := make([]string, 0, len(preview))
|
||||
for idx, item := range preview {
|
||||
deadline := strings.TrimSpace(item.DeadlineAt)
|
||||
if deadline == "" {
|
||||
deadline = "无明确截止时间"
|
||||
}
|
||||
status := "未完成"
|
||||
if item.IsCompleted {
|
||||
status = "已完成"
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf(
|
||||
"%d. %s(%s,%s,截止:%s)",
|
||||
idx+1,
|
||||
item.Title,
|
||||
item.PriorityLabel,
|
||||
status,
|
||||
deadline,
|
||||
))
|
||||
}
|
||||
|
||||
header := fmt.Sprintf("给你整理了 %d 条任务:", showCount)
|
||||
if lead := extractSafeReplyLead(llmReply); lead != "" {
|
||||
header = lead + "\n" + header
|
||||
}
|
||||
|
||||
reply := header + "\n" + strings.Join(lines, "\n")
|
||||
if len(items) > showCount {
|
||||
reply += fmt.Sprintf("\n另外还有 %d 条匹配任务,要不要我继续往下列?", len(items)-showCount)
|
||||
}
|
||||
return reply
|
||||
}
|
||||
|
||||
func extractSafeReplyLead(llmReply string) string {
|
||||
text := strings.TrimSpace(llmReply)
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
lower := strings.ToLower(text)
|
||||
if strings.Contains(text, "\n") ||
|
||||
strings.Contains(text, "#") ||
|
||||
strings.Contains(lower, "1.") ||
|
||||
strings.Contains(text, "1、") ||
|
||||
strings.Contains(text, "以下是") {
|
||||
return ""
|
||||
}
|
||||
if len([]rune(text)) > 30 {
|
||||
return ""
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func sortTaskQueryItems(items []agentmodel.TaskQueryItem, plan agentmodel.TaskQueryPlan) {
|
||||
if len(items) <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
sortBy := strings.ToLower(strings.TrimSpace(plan.SortBy))
|
||||
order := strings.ToLower(strings.TrimSpace(plan.Order))
|
||||
if order != "desc" {
|
||||
order = "asc"
|
||||
}
|
||||
|
||||
sort.SliceStable(items, func(i, j int) bool {
|
||||
left := items[i]
|
||||
right := items[j]
|
||||
|
||||
switch sortBy {
|
||||
case "priority":
|
||||
if left.PriorityGroup != right.PriorityGroup {
|
||||
if order == "desc" {
|
||||
return left.PriorityGroup > right.PriorityGroup
|
||||
}
|
||||
return left.PriorityGroup < right.PriorityGroup
|
||||
}
|
||||
return left.ID > right.ID
|
||||
case "id":
|
||||
if order == "desc" {
|
||||
return left.ID > right.ID
|
||||
}
|
||||
return left.ID < right.ID
|
||||
default:
|
||||
leftTime, leftOK := parseTaskQueryItemDeadline(left.DeadlineAt)
|
||||
rightTime, rightOK := parseTaskQueryItemDeadline(right.DeadlineAt)
|
||||
if leftOK && rightOK {
|
||||
if !leftTime.Equal(rightTime) {
|
||||
if order == "desc" {
|
||||
return leftTime.After(rightTime)
|
||||
}
|
||||
return leftTime.Before(rightTime)
|
||||
}
|
||||
return left.ID > right.ID
|
||||
}
|
||||
if leftOK && !rightOK {
|
||||
return true
|
||||
}
|
||||
if !leftOK && rightOK {
|
||||
return false
|
||||
}
|
||||
return left.ID > right.ID
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func parseTaskQueryItemDeadline(raw string) (time.Time, bool) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
parsed, err := time.ParseInLocation("2006-01-02 15:04", text, time.Local)
|
||||
if err != nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return parsed, true
|
||||
}
|
||||
|
||||
func emptyToDash(text string) string {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return "-"
|
||||
}
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
// extractExplicitLimitFromUser 从用户原话里提取显式条数要求。
|
||||
//
|
||||
// 步骤说明:
|
||||
// 1. 先识别阿拉伯数字表达,例如“前3个”“给我5条”“top 10”。
|
||||
// 2. 再识别中文数字表达,例如“前五个”“来三个”。
|
||||
// 3. 最终统一约束到 1~20 范围内。
|
||||
func extractExplicitLimitFromUser(userMessage string) (int, bool) {
|
||||
text := strings.TrimSpace(userMessage)
|
||||
if text == "" {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
for _, pattern := range explicitLimitPatterns {
|
||||
matched := pattern.FindStringSubmatch(text)
|
||||
if len(matched) < 2 {
|
||||
continue
|
||||
}
|
||||
number, err := strconv.Atoi(strings.TrimSpace(matched[1]))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
return normalizeExplicitLimit(number)
|
||||
}
|
||||
|
||||
for _, prefix := range []string{"前", "来", "给我"} {
|
||||
for digit, number := range chineseDigitMap {
|
||||
token := prefix + string(digit)
|
||||
if strings.Contains(text, token) {
|
||||
return normalizeExplicitLimit(number)
|
||||
}
|
||||
for _, suffix := range []string{"个", "条", "项"} {
|
||||
if strings.Contains(text, token+suffix) {
|
||||
return normalizeExplicitLimit(number)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func normalizeExplicitLimit(number int) (int, bool) {
|
||||
if number <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
if number > agentmodel.MaxTaskQueryLimit {
|
||||
number = agentmodel.MaxTaskQueryLimit
|
||||
}
|
||||
return number, true
|
||||
}
|
||||
286
backend/agent/node/taskquery_tool.go
Normal file
286
backend/agent/node/taskquery_tool.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package agentnode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
agentmodel "github.com/LoveLosita/smartflow/backend/agent/model"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
toolutils "github.com/cloudwego/eino/components/tool/utils"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
const (
|
||||
ToolNameTaskQueryTasks = "query_tasks"
|
||||
ToolDescTaskQueryTasks = "按象限、关键词、截止时间筛选并排序任务,返回结构化任务列表"
|
||||
)
|
||||
|
||||
var taskQueryTimeLayouts = []string{
|
||||
time.RFC3339,
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02 15:04",
|
||||
"2006-01-02",
|
||||
}
|
||||
|
||||
// TaskQueryToolDeps 描述任务查询工具依赖的外部查询能力。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. QueryTasks 负责读取真实任务数据。
|
||||
// 2. 工具层只负责参数校验、归一化和结构化输出,不直接耦合 DAO。
|
||||
type TaskQueryToolDeps struct {
|
||||
QueryTasks func(ctx context.Context, req TaskQueryRequest) ([]TaskQueryTaskRecord, error)
|
||||
}
|
||||
|
||||
// Validate 负责校验任务查询工具依赖是否齐全。
|
||||
func (d TaskQueryToolDeps) Validate() error {
|
||||
if d.QueryTasks == nil {
|
||||
return errors.New("task query tool deps: QueryTasks is nil")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TaskQueryToolBundle 同时返回工具实例和工具元信息。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. Tools 给执行节点使用。
|
||||
// 2. ToolInfos 给模型注册 schema 使用。
|
||||
type TaskQueryToolBundle struct {
|
||||
Tools []tool.BaseTool
|
||||
ToolInfos []*schema.ToolInfo
|
||||
}
|
||||
|
||||
// TaskQueryRequest 是工具层传给业务层的内部查询请求。
|
||||
type TaskQueryRequest struct {
|
||||
UserID int
|
||||
Quadrant *int
|
||||
SortBy string
|
||||
Order string
|
||||
Limit int
|
||||
IncludeCompleted bool
|
||||
Keyword string
|
||||
DeadlineBefore *time.Time
|
||||
DeadlineAfter *time.Time
|
||||
}
|
||||
|
||||
// TaskQueryTaskRecord 是业务层返回给工具层的任务记录。
|
||||
type TaskQueryTaskRecord struct {
|
||||
ID int
|
||||
Title string
|
||||
PriorityGroup int
|
||||
IsCompleted bool
|
||||
DeadlineAt *time.Time
|
||||
UrgencyThresholdAt *time.Time
|
||||
}
|
||||
|
||||
// TaskQueryToolInput 是暴露给大模型的工具入参。
|
||||
type TaskQueryToolInput struct {
|
||||
Quadrant *int `json:"quadrant,omitempty" jsonschema:"description=可选象限(1~4)"`
|
||||
SortBy string `json:"sort_by,omitempty" jsonschema:"description=排序字段(deadline|priority|id)"`
|
||||
Order string `json:"order,omitempty" jsonschema:"description=排序方向(asc|desc)"`
|
||||
Limit int `json:"limit,omitempty" jsonschema:"description=返回条数,默认5,上限20"`
|
||||
IncludeCompleted *bool `json:"include_completed,omitempty" jsonschema:"description=是否包含已完成任务,默认false"`
|
||||
Keyword string `json:"keyword,omitempty" jsonschema:"description=可选标题关键词,模糊匹配"`
|
||||
DeadlineBefore string `json:"deadline_before,omitempty" jsonschema:"description=可选截止时间上界,支持RFC3339或yyyy-MM-dd HH:mm"`
|
||||
DeadlineAfter string `json:"deadline_after,omitempty" jsonschema:"description=可选截止时间下界,支持RFC3339或yyyy-MM-dd HH:mm"`
|
||||
}
|
||||
|
||||
// TaskQueryToolOutput 是返回给模型的结构化结果。
|
||||
type TaskQueryToolOutput struct {
|
||||
Total int `json:"total"`
|
||||
Items []agentmodel.TaskQueryItem `json:"items"`
|
||||
}
|
||||
|
||||
// BuildTaskQueryToolBundle 负责构建任务查询工具包。
|
||||
//
|
||||
// 步骤说明:
|
||||
// 1. 先校验依赖是否完整,避免生成一个运行时必定失败的工具。
|
||||
// 2. 再把输入归一化成内部请求,调用业务查询函数拿到真实数据。
|
||||
// 3. 最后把业务记录转换成统一的轻量任务视图,供模型和反思节点复用。
|
||||
func BuildTaskQueryToolBundle(ctx context.Context, deps TaskQueryToolDeps) (*TaskQueryToolBundle, error) {
|
||||
if err := deps.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queryTool, err := toolutils.InferTool(
|
||||
ToolNameTaskQueryTasks,
|
||||
ToolDescTaskQueryTasks,
|
||||
func(ctx context.Context, input *TaskQueryToolInput) (*TaskQueryToolOutput, error) {
|
||||
req, err := normalizeTaskQueryToolInput(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
records, err := deps.QueryTasks(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
items := make([]agentmodel.TaskQueryItem, 0, len(records))
|
||||
for _, record := range records {
|
||||
items = append(items, agentmodel.TaskQueryItem{
|
||||
ID: record.ID,
|
||||
Title: record.Title,
|
||||
PriorityGroup: record.PriorityGroup,
|
||||
PriorityLabel: agentmodel.PriorityLabelCN(record.PriorityGroup),
|
||||
IsCompleted: record.IsCompleted,
|
||||
DeadlineAt: formatTaskQueryTime(record.DeadlineAt),
|
||||
UrgencyThresholdAt: formatTaskQueryTime(record.UrgencyThresholdAt),
|
||||
})
|
||||
}
|
||||
|
||||
return &TaskQueryToolOutput{
|
||||
Total: len(items),
|
||||
Items: items,
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("构建任务查询工具失败: %w", err)
|
||||
}
|
||||
|
||||
tools := []tool.BaseTool{queryTool}
|
||||
infos, err := collectToolInfos(ctx, tools)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TaskQueryToolBundle{
|
||||
Tools: tools,
|
||||
ToolInfos: infos,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetTaskQueryInvokableToolByName 按工具名提取可执行工具。
|
||||
func GetTaskQueryInvokableToolByName(bundle *TaskQueryToolBundle, name string) (tool.InvokableTool, error) {
|
||||
if bundle == nil {
|
||||
return nil, errors.New("task query tool bundle is nil")
|
||||
}
|
||||
return getInvokableToolByName(bundle.Tools, bundle.ToolInfos, name)
|
||||
}
|
||||
|
||||
// normalizeTaskQueryToolInput 负责参数默认值回填与合法性校验。
|
||||
//
|
||||
// 步骤说明:
|
||||
// 1. 先准备默认值,保证空参数也能执行一次合理查询。
|
||||
// 2. 再校验象限、排序、条数和时间区间,阻止非法参数下沉到业务层。
|
||||
// 3. 若上下界冲突,则直接返回错误,避免查出必为空的结果。
|
||||
func normalizeTaskQueryToolInput(input *TaskQueryToolInput) (TaskQueryRequest, error) {
|
||||
req := TaskQueryRequest{
|
||||
SortBy: "deadline",
|
||||
Order: "asc",
|
||||
Limit: agentmodel.DefaultTaskQueryLimit,
|
||||
IncludeCompleted: false,
|
||||
}
|
||||
if input == nil {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
if input.Quadrant != nil {
|
||||
if *input.Quadrant < 1 || *input.Quadrant > 4 {
|
||||
return TaskQueryRequest{}, fmt.Errorf("quadrant=%d 非法,必须在 1~4", *input.Quadrant)
|
||||
}
|
||||
quadrant := *input.Quadrant
|
||||
req.Quadrant = &quadrant
|
||||
}
|
||||
|
||||
if sortBy := strings.ToLower(strings.TrimSpace(input.SortBy)); sortBy != "" {
|
||||
req.SortBy = sortBy
|
||||
}
|
||||
switch req.SortBy {
|
||||
case "deadline", "priority", "id":
|
||||
default:
|
||||
return TaskQueryRequest{}, fmt.Errorf("sort_by=%s 非法,仅支持 deadline|priority|id", req.SortBy)
|
||||
}
|
||||
|
||||
if order := strings.ToLower(strings.TrimSpace(input.Order)); order != "" {
|
||||
req.Order = order
|
||||
}
|
||||
switch req.Order {
|
||||
case "asc", "desc":
|
||||
default:
|
||||
return TaskQueryRequest{}, fmt.Errorf("order=%s 非法,仅支持 asc|desc", req.Order)
|
||||
}
|
||||
|
||||
if input.Limit > 0 {
|
||||
req.Limit = input.Limit
|
||||
}
|
||||
if req.Limit > agentmodel.MaxTaskQueryLimit {
|
||||
req.Limit = agentmodel.MaxTaskQueryLimit
|
||||
}
|
||||
if req.Limit <= 0 {
|
||||
req.Limit = agentmodel.DefaultTaskQueryLimit
|
||||
}
|
||||
|
||||
if input.IncludeCompleted != nil {
|
||||
req.IncludeCompleted = *input.IncludeCompleted
|
||||
}
|
||||
req.Keyword = strings.TrimSpace(input.Keyword)
|
||||
|
||||
before, err := parseTaskQueryBoundaryTime(input.DeadlineBefore, true)
|
||||
if err != nil {
|
||||
return TaskQueryRequest{}, err
|
||||
}
|
||||
after, err := parseTaskQueryBoundaryTime(input.DeadlineAfter, false)
|
||||
if err != nil {
|
||||
return TaskQueryRequest{}, err
|
||||
}
|
||||
req.DeadlineBefore = before
|
||||
req.DeadlineAfter = after
|
||||
if req.DeadlineBefore != nil && req.DeadlineAfter != nil && req.DeadlineAfter.After(*req.DeadlineBefore) {
|
||||
return TaskQueryRequest{}, errors.New("deadline_after 不能晚于 deadline_before")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// parseTaskQueryBoundaryTime 解析截止时间上下界。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. isUpper=true 时,纯日期补到当天 23:59:59。
|
||||
// 2. isUpper=false 时,纯日期补到当天 00:00:00。
|
||||
// 3. 不支持的格式直接返回错误,由调用方决定是否回退。
|
||||
func parseTaskQueryBoundaryTime(raw string, isUpper bool) (*time.Time, error) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
loc := time.Local
|
||||
for _, layout := range taskQueryTimeLayouts {
|
||||
var (
|
||||
parsed time.Time
|
||||
err error
|
||||
)
|
||||
if layout == time.RFC3339 {
|
||||
parsed, err = time.Parse(layout, text)
|
||||
if err == nil {
|
||||
parsed = parsed.In(loc)
|
||||
}
|
||||
} else {
|
||||
parsed, err = time.ParseInLocation(layout, text, loc)
|
||||
}
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if layout == "2006-01-02" {
|
||||
if isUpper {
|
||||
parsed = time.Date(parsed.Year(), parsed.Month(), parsed.Day(), 23, 59, 59, 0, loc)
|
||||
} else {
|
||||
parsed = time.Date(parsed.Year(), parsed.Month(), parsed.Day(), 0, 0, 0, 0, loc)
|
||||
}
|
||||
}
|
||||
return &parsed, nil
|
||||
}
|
||||
return nil, fmt.Errorf("时间格式不支持: %s", text)
|
||||
}
|
||||
|
||||
// formatTaskQueryTime 负责把内部时间格式化为给模型展示的分钟级文本。
|
||||
func formatTaskQueryTime(value *time.Time) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
return value.In(time.Local).Format("2006-01-02 15:04")
|
||||
}
|
||||
74
backend/agent/node/tool_common.go
Normal file
74
backend/agent/node/tool_common.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package agentnode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// collectToolInfos 负责批量提取工具元信息,供模型注册与工具索引复用。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只负责调用 tool.Info 并聚合返回结果。
|
||||
// 2. 不负责校验工具是否可执行,也不负责按名称检索工具。
|
||||
func collectToolInfos(ctx context.Context, tools []tool.BaseTool) ([]*schema.ToolInfo, error) {
|
||||
infos := make([]*schema.ToolInfo, 0, len(tools))
|
||||
for _, currentTool := range tools {
|
||||
info, err := currentTool.Info(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取工具信息失败: %w", err)
|
||||
}
|
||||
infos = append(infos, info)
|
||||
}
|
||||
return infos, nil
|
||||
}
|
||||
|
||||
// buildInvokableToolMap 负责把工具列表转换成“工具名 -> 可执行工具”的索引表。
|
||||
//
|
||||
// 步骤说明:
|
||||
// 1. 先校验 tools 与 infos 是否一一对应,避免后续按下标取值时出现错配。
|
||||
// 2. 再校验每个工具都带有合法名字,并且确实实现了 InvokableTool 接口。
|
||||
// 3. 任一步失败都立即返回错误,避免 graph 在运行期拿到半残缺的工具集合。
|
||||
func buildInvokableToolMap(tools []tool.BaseTool, infos []*schema.ToolInfo) (map[string]tool.InvokableTool, error) {
|
||||
if len(tools) == 0 || len(infos) == 0 {
|
||||
return nil, errors.New("tool bundle is empty")
|
||||
}
|
||||
if len(tools) != len(infos) {
|
||||
return nil, errors.New("tool bundle mismatch")
|
||||
}
|
||||
|
||||
result := make(map[string]tool.InvokableTool, len(tools))
|
||||
for idx, currentTool := range tools {
|
||||
info := infos[idx]
|
||||
if info == nil || strings.TrimSpace(info.Name) == "" {
|
||||
return nil, errors.New("tool info is invalid")
|
||||
}
|
||||
invokable, ok := currentTool.(tool.InvokableTool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tool %s is not invokable", info.Name)
|
||||
}
|
||||
result[info.Name] = invokable
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// getInvokableToolByName 负责从工具集合中提取指定名称的可执行工具。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责复用统一索引逻辑,避免各业务链路重复写名称查找代码。
|
||||
// 2. 不负责兜底选择其他工具;未命中时直接返回错误,由上层决定如何处理。
|
||||
func getInvokableToolByName(tools []tool.BaseTool, infos []*schema.ToolInfo, name string) (tool.InvokableTool, error) {
|
||||
invokableMap, err := buildInvokableToolMap(tools, infos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
invokable, ok := invokableMap[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tool %s not found", name)
|
||||
}
|
||||
return invokable, nil
|
||||
}
|
||||
Reference in New Issue
Block a user