Files
smartmate/backend/agent2/node/taskquery.go
Losita e06284d0b0 Version: 0.7.6.dev.260325
后端:
- ♻️ 将 `taskquery` 模块迁移至 `agent2`,并完成与 `agent2` 业务链路及整体结构的正式接入

前端:
- 🧱 已完成基础框架搭建,并完成了登录、注册、主页等页面并对接了对应接口;但整体功能实现仍在完善中
2026-03-25 00:49:16 +08:00

730 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package agentnode
import (
"context"
"encoding/json"
"fmt"
"regexp"
"sort"
"strconv"
"strings"
"time"
agentllm "github.com/LoveLosita/smartflow/backend/agent2/llm"
agentmodel "github.com/LoveLosita/smartflow/backend/agent2/model"
agentprompt "github.com/LoveLosita/smartflow/backend/agent2/prompt"
agentstream "github.com/LoveLosita/smartflow/backend/agent2/stream"
"github.com/cloudwego/eino-ext/components/model/ark"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
)
// Node names used when wiring the task-query graph.
const (
	// TaskQueryGraphNodePlan names the node that turns the user's message into a query plan.
	TaskQueryGraphNodePlan = "task_query.plan"
	// TaskQueryGraphNodeQuadrant names the node that normalizes the quadrant filter.
	TaskQueryGraphNodeQuadrant = "task_query.quadrant"
	// TaskQueryGraphNodeTimeAnchor names the node that resolves deadline texts into time bounds.
	TaskQueryGraphNodeTimeAnchor = "task_query.time_anchor"
	// TaskQueryGraphNodeQuery names the node that runs the query tool; NextAfterReflect
	// routes back here when a retry is requested.
	TaskQueryGraphNodeQuery = "task_query.query"
	// TaskQueryGraphNodeReflect names the node that judges the results and decides on retries.
	TaskQueryGraphNodeReflect = "task_query.reflect"
)
var (
	// explicitLimitPatterns match Arabic-numeral count requests in the user's message.
	// Capture group 1 is the number (one or two digits). extractExplicitLimitFromUser
	// tries them in this order.
	explicitLimitPatterns = []*regexp.Regexp{
		regexp.MustCompile(`(?i)\btop\s*(\d{1,2})\b`),   // e.g. "top 5", "TOP10"
		regexp.MustCompile(`前\s*(\d{1,2})\s*(个|条|项)?`),  // e.g. "前3个"
		regexp.MustCompile(`(\d{1,2})\s*(个|条|项)?\s*任务`), // e.g. "5个任务"
		regexp.MustCompile(`给我\s*(\d{1,2})\s*(个|条|项)?`), // e.g. "给我5条"
	}
	// chineseDigitMap maps single Chinese numerals, including the colloquial 两 (= 2)
	// and 十 (= 10), to their integer values.
	chineseDigitMap = map[rune]int{
		'一': 1,
		'二': 2,
		'两': 2,
		'三': 3,
		'四': 4,
		'五': 5,
		'六': 6,
		'七': 7,
		'八': 8,
		'九': 9,
		'十': 10,
	}
)
// TaskQueryGraphRunInput describes the dependencies needed for one run of the task-query graph.
type TaskQueryGraphRunInput struct {
	// Model is the chat model used by the plan and reflect nodes; NewTaskQueryNodes rejects nil.
	Model *ark.ChatModel
	// State is the per-request query state; NewTaskQueryNodes rejects nil.
	State *agentmodel.TaskQueryState
	// Deps are the query-tool dependencies; NewTaskQueryNodes checks them with Deps.Validate.
	Deps TaskQueryToolDeps
	// EmitStage receives (stage, detail) progress notifications. It is wrapped by
	// agentstream.WrapStageEmitter; presumably nil is tolerated there — confirm.
	EmitStage func(stage, detail string)
}
// TaskQueryNodes is the node container for the task-query graph.
//
// Responsibilities:
//  1. It holds request-scoped dependencies and exposes methods the graph can mount
//     directly as nodes.
//  2. It does not compile the graph, wire services, or persist anything.
type TaskQueryNodes struct {
	// input is the run input validated by NewTaskQueryNodes.
	input TaskQueryGraphRunInput
	// queryTool executes a structured task query; executePlanByTool handles its JSON input/output.
	queryTool tool.InvokableTool
	// emitStage reports progress stages; it is built from input.EmitStage via agentstream.WrapStageEmitter.
	emitStage agentstream.StageEmitter
}
// NewTaskQueryNodes validates the run input and returns a node container bound to it.
//
// It fails fast when the model, state, tool dependencies or query tool are missing,
// so the individual graph nodes never need to re-check them.
func NewTaskQueryNodes(input TaskQueryGraphRunInput, queryTool tool.InvokableTool) (*TaskQueryNodes, error) {
	switch {
	case input.Model == nil:
		return nil, fmt.Errorf("task query nodes: model is nil")
	case input.State == nil:
		return nil, fmt.Errorf("task query nodes: state is nil")
	}
	if depsErr := input.Deps.Validate(); depsErr != nil {
		return nil, depsErr
	}
	if queryTool == nil {
		return nil, fmt.Errorf("task query nodes: queryTool is nil")
	}

	nodes := &TaskQueryNodes{
		input:     input,
		queryTool: queryTool,
	}
	nodes.emitStage = agentstream.WrapStageEmitter(input.EmitStage)
	return nodes, nil
}
// Plan turns the user's raw message into a structured query plan.
//
// Planning is best-effort. If the LLM planner fails or returns nothing, a default plan
// is used so the graph can still run, and the planner error is intentionally not
// surfaced. In both paths, an explicit result count found in the user's message
// overrides the plan's limit. That count is recorded in st.ExplicitLimit, so neither
// the planner nor a later retry patch can change the number the user asked for.
func (n *TaskQueryNodes) Plan(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
	if st == nil {
		return nil, fmt.Errorf("task query graph: nil state in plan node")
	}
	n.emitStage("task_query.plan.generating", "正在一次性规划查询范围、排序和时间条件。")

	planned, err := agentllm.PlanTaskQuery(ctx, n.input.Model, st.RequestNowText, st.UserMessage)
	if err != nil || planned == nil {
		// Deliberate fallback: keep the graph running with a safe default plan.
		st.UserGoal = "查询任务"
		st.Plan = defaultTaskQueryPlan()
	} else {
		st.UserGoal = strings.TrimSpace(planned.UserGoal)
		if st.UserGoal == "" {
			st.UserGoal = "查询任务"
		}
		st.Plan = normalizeTaskQueryPlan(*planned)
	}

	// The backend-detected count wins over whatever the planner produced (or failed to
	// produce). This applies to the fallback path too, so "前3个" is honored even when
	// planning fails.
	if explicitLimit, found := extractExplicitLimitFromUser(st.UserMessage); found {
		st.ExplicitLimit = explicitLimit
		st.Plan.Limit = explicitLimit
	}
	return st, nil
}
// NormalizeQuadrant de-duplicates the plan's quadrant filter and puts it in a stable
// ascending order (see normalizeQuadrants).
func (n *TaskQueryNodes) NormalizeQuadrant(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
	_ = ctx // this node only transforms in-memory state
	if st == nil {
		return nil, fmt.Errorf("task query graph: nil state in quadrant node")
	}
	n.emitStage("task_query.quadrant.routing", "正在归一化象限筛选范围。")
	normalized := normalizeQuadrants(st.Plan.Quadrants)
	st.Plan.Quadrants = normalized
	return st, nil
}
// AnchorTime resolves the plan's deadline texts into concrete time bounds.
// Bounds that cannot be parsed, or that form an inverted window, are dropped.
func (n *TaskQueryNodes) AnchorTime(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
	_ = ctx // this node only transforms in-memory state
	if st == nil {
		return nil, fmt.Errorf("task query graph: nil state in time anchor node")
	}
	n.emitStage("task_query.time.anchoring", "正在锁定时间过滤边界。")
	plan := &st.Plan
	applyTimeAnchorOnPlan(plan)
	return st, nil
}
// Query runs the current plan through the query tool and stores the results on the state.
//
// A tool failure is not returned as an error. Instead the results are cleared and
// ReflectReason records the failure, so the reflect node can answer the user.
//
// If the first successful query returns nothing and no auto-broadening has happened
// yet, the plan is relaxed once and the query is retried. Only the keyword, the
// completion status and the deadline bounds are relaxed; quadrants and limit are kept
// to avoid drifting from the user's intent. A failed retry keeps the (empty)
// first-round results.
func (n *TaskQueryNodes) Query(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
	if st == nil {
		return nil, fmt.Errorf("task query graph: nil state in query node")
	}
	n.emitStage("task_query.tool.querying", "正在查询任务数据。")

	items, err := n.executePlanByTool(ctx, st.Plan)
	if err != nil {
		st.LastQueryItems = make([]agentmodel.TaskQueryItem, 0)
		st.LastQueryTotal = 0
		st.ReflectReason = "查询工具执行失败"
		return st, nil
	}
	st.LastQueryItems, st.LastQueryTotal = items, len(items)

	// Only an empty first round that has not been broadened yet qualifies for a retry.
	if st.LastQueryTotal != 0 || st.AutoBroadenApplied {
		return st, nil
	}
	broadenedPlan, changed := autoBroadenPlan(st.Plan)
	if !changed {
		return st, nil
	}
	st.AutoBroadenApplied = true
	st.Plan = broadenedPlan
	n.emitStage("task_query.tool.broadened", "首次查询为空,已自动放宽条件再试一次。")
	if retryItems, retryErr := n.executePlanByTool(ctx, st.Plan); retryErr == nil {
		st.LastQueryItems, st.LastQueryTotal = retryItems, len(retryItems)
	}
	return st, nil
}
// Reflect asks the LLM whether the latest results satisfy the user and decides whether
// to run the query again.
//
// Outcomes:
//   - The reflection call fails: no retry; FinalReply is the locally built fallback
//     reply for the current items.
//   - The results are judged satisfactory: no retry; FinalReply is built from the items.
//   - A retry is requested and budget remains: the retry patch is applied (an explicit
//     user limit stays pinned), RetryCount is incremented, and a non-empty LLM reply
//     is kept as an interim FinalReply.
//   - Otherwise: no retry; FinalReply is built from the items.
func (n *TaskQueryNodes) Reflect(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
	if st == nil {
		return nil, fmt.Errorf("task query graph: nil state in reflect node")
	}
	n.emitStage("task_query.reflecting", "正在判断结果是否贴合你的需求。")

	prompt := agentprompt.BuildTaskQueryReflectUserPrompt(
		st.RequestNowText,
		st.UserMessage,
		st.UserGoal,
		summarizeTaskQueryPlan(st.Plan),
		st.RetryCount,
		st.MaxReflectRetry,
		summarizeTaskQueryItems(st.LastQueryItems, 6),
	)
	verdict, err := agentllm.ReflectTaskQuery(ctx, n.input.Model, prompt)
	if err != nil || verdict == nil {
		st.NeedRetry = false
		st.FinalReply = buildTaskQueryFallbackReply(st.LastQueryItems)
		return st, nil
	}

	st.ReflectReason = strings.TrimSpace(verdict.Reason)
	llmReply := strings.TrimSpace(verdict.Reply)

	retry := !verdict.Satisfied && verdict.NeedRetry && st.RetryCount < st.MaxReflectRetry
	st.NeedRetry = retry
	if !retry {
		st.FinalReply = buildTaskQueryFinalReply(st.LastQueryItems, st.Plan, llmReply)
		return st, nil
	}

	st.Plan = applyRetryPatch(st.Plan, verdict.RetryPatch, st.ExplicitLimit)
	st.RetryCount++
	if llmReply != "" {
		st.FinalReply = llmReply
	}
	return st, nil
}
// NextAfterReflect is the branch function after the reflect node. It routes back to
// the query node when Reflect requested a retry, and otherwise ends the graph.
func (n *TaskQueryNodes) NextAfterReflect(ctx context.Context, st *agentmodel.TaskQueryState) (string, error) {
	_ = ctx
	next := compose.END
	if st != nil && st.NeedRetry {
		next = TaskQueryGraphNodeQuery
	}
	return next, nil
}
// executePlanByTool runs the plan through the query tool and returns merged, sorted
// and truncated results.
//
// With no quadrant filter the tool is called once; otherwise it is called once per
// quadrant. Results are de-duplicated by task ID, ordered with sortTaskQueryItems and
// cut to the plan's limit.
//
// A non-positive limit (only possible for a plan that skipped normalization) falls
// back to agentmodel.DefaultTaskQueryLimit instead of panicking in make/slicing.
// Limits above agentmodel.MaxTaskQueryLimit are clamped. The first marshal, invoke
// or parse error aborts the whole call and is returned with context.
func (n *TaskQueryNodes) executePlanByTool(ctx context.Context, plan agentmodel.TaskQueryPlan) ([]agentmodel.TaskQueryItem, error) {
	if n.queryTool == nil {
		return nil, fmt.Errorf("task query tool is nil")
	}

	limit := plan.Limit
	if limit <= 0 {
		limit = agentmodel.DefaultTaskQueryLimit
	}
	if limit > agentmodel.MaxTaskQueryLimit {
		limit = agentmodel.MaxTaskQueryLimit
	}
	if limit < 0 {
		// Defensive: never feed a negative capacity to make.
		limit = 0
	}

	merged := make([]agentmodel.TaskQueryItem, 0, limit)
	seen := make(map[int]struct{}, limit*2)

	runOne := func(quadrant *int) error {
		includeCompleted := plan.IncludeCompleted
		input := TaskQueryToolInput{
			Quadrant:         quadrant,
			SortBy:           plan.SortBy,
			Order:            plan.Order,
			Limit:            limit,
			Keyword:          plan.Keyword,
			DeadlineBefore:   plan.DeadlineBeforeText,
			DeadlineAfter:    plan.DeadlineAfterText,
			IncludeCompleted: &includeCompleted,
		}
		rawInput, err := json.Marshal(input)
		if err != nil {
			return fmt.Errorf("marshal tool input: %w", err)
		}
		rawOutput, err := n.queryTool.InvokableRun(ctx, string(rawInput))
		if err != nil {
			return fmt.Errorf("invoke query tool: %w", err)
		}
		parsed, err := agentllm.ParseJSONObject[TaskQueryToolOutput](rawOutput)
		if err != nil {
			return fmt.Errorf("parse tool output: %w", err)
		}
		// The same task can come back from several per-quadrant calls; keep the first copy.
		for _, item := range parsed.Items {
			if _, exists := seen[item.ID]; exists {
				continue
			}
			seen[item.ID] = struct{}{}
			merged = append(merged, item)
		}
		return nil
	}

	if len(plan.Quadrants) == 0 {
		if err := runOne(nil); err != nil {
			return nil, fmt.Errorf("task query (all quadrants): %w", err)
		}
	} else {
		for _, quadrant := range plan.Quadrants {
			q := quadrant // fresh variable per iteration: its address is passed to the tool input
			if err := runOne(&q); err != nil {
				return nil, fmt.Errorf("task query (quadrant %d): %w", q, err)
			}
		}
	}

	sortTaskQueryItems(merged, plan)
	if len(merged) > limit {
		merged = merged[:limit]
	}
	return merged, nil
}
// defaultTaskQueryPlan returns the fallback plan. It applies no quadrant filter,
// sorts by deadline ascending, uses the default limit and excludes completed tasks.
func defaultTaskQueryPlan() agentmodel.TaskQueryPlan {
	var plan agentmodel.TaskQueryPlan
	plan.SortBy = "deadline"
	plan.Order = "asc"
	plan.Limit = agentmodel.DefaultTaskQueryLimit
	plan.IncludeCompleted = false
	return plan
}
// normalizeTaskQueryPlan converts the planner's raw output into a validated plan.
//
// It starts from defaultTaskQueryPlan and then applies these steps:
//   - keeps only recognized sort fields (deadline/priority/id) and orders (asc/desc);
//   - takes a positive limit, clamps it to agentmodel.MaxTaskQueryLimit, and falls
//     back to the default when the result is not positive;
//   - copies the optional include-completed flag;
//   - trims the keyword and deadline texts;
//   - resolves the deadline texts with applyTimeAnchorOnPlan.
func normalizeTaskQueryPlan(raw agentllm.TaskQueryPlanOutput) agentmodel.TaskQueryPlan {
	plan := defaultTaskQueryPlan()
	plan.Quadrants = normalizeQuadrants(raw.Quadrants)

	switch sortBy := strings.ToLower(strings.TrimSpace(raw.SortBy)); sortBy {
	case "deadline", "priority", "id":
		plan.SortBy = sortBy
	}
	switch order := strings.ToLower(strings.TrimSpace(raw.Order)); order {
	case "asc", "desc":
		plan.Order = order
	}

	if raw.Limit > 0 {
		plan.Limit = raw.Limit
	}
	switch {
	case plan.Limit > agentmodel.MaxTaskQueryLimit:
		plan.Limit = agentmodel.MaxTaskQueryLimit
	}
	if plan.Limit <= 0 {
		plan.Limit = agentmodel.DefaultTaskQueryLimit
	}

	if flag := raw.IncludeCompleted; flag != nil {
		plan.IncludeCompleted = *flag
	}
	plan.Keyword, plan.DeadlineBeforeText, plan.DeadlineAfterText =
		strings.TrimSpace(raw.Keyword),
		strings.TrimSpace(raw.DeadlineBefore),
		strings.TrimSpace(raw.DeadlineAfter)

	applyTimeAnchorOnPlan(&plan)
	return plan
}
// normalizeQuadrants keeps only valid quadrant numbers (1-4), removes duplicates and
// sorts them ascending. It returns nil, meaning "no quadrant filter", when nothing
// valid remains or when all four quadrants are selected.
func normalizeQuadrants(quadrants []int) []int {
	if len(quadrants) == 0 {
		return nil
	}
	valid := make([]int, 0, len(quadrants))
	for _, q := range quadrants {
		if q >= 1 && q <= 4 {
			valid = append(valid, q)
		}
	}
	sort.Ints(valid)

	// Sorting makes duplicates adjacent, so one pass compacts them in place.
	unique := valid[:0]
	for _, q := range valid {
		if n := len(unique); n == 0 || unique[n-1] != q {
			unique = append(unique, q)
		}
	}
	if len(unique) == 0 || len(unique) == 4 {
		return nil
	}
	return unique
}
// applyTimeAnchorOnPlan resolves the plan's deadline texts into time values in place.
//
// Both bounds are parsed first with parseTaskQueryBoundaryTime: the "before" text is
// parsed with true and the "after" text with false. A bound that fails to parse is
// dropped entirely, both its value and its text. If both bounds parse but "after" is
// later than "before", the window is contradictory and both bounds are cleared.
// A nil plan is ignored.
func applyTimeAnchorOnPlan(plan *agentmodel.TaskQueryPlan) {
	if plan == nil {
		return
	}
	before, errBefore := parseTaskQueryBoundaryTime(plan.DeadlineBeforeText, true)
	after, errAfter := parseTaskQueryBoundaryTime(plan.DeadlineAfterText, false)

	plan.DeadlineBefore = nil
	if errBefore == nil {
		plan.DeadlineBefore = before
	} else {
		plan.DeadlineBeforeText = ""
	}

	plan.DeadlineAfter = nil
	if errAfter == nil {
		plan.DeadlineAfter = after
	} else {
		plan.DeadlineAfterText = ""
	}

	inverted := plan.DeadlineBefore != nil &&
		plan.DeadlineAfter != nil &&
		plan.DeadlineAfter.After(*plan.DeadlineBefore)
	if inverted {
		plan.DeadlineBefore, plan.DeadlineAfter = nil, nil
		plan.DeadlineBeforeText, plan.DeadlineAfterText = "", ""
	}
}
// autoBroadenPlan returns a relaxed copy of plan for a retry after an empty result,
// and reports whether anything was actually relaxed.
//
// It clears a non-blank keyword, includes completed tasks and removes both deadline
// bounds. Quadrants, sorting and limit are left untouched.
func autoBroadenPlan(plan agentmodel.TaskQueryPlan) (agentmodel.TaskQueryPlan, bool) {
	hasKeyword := strings.TrimSpace(plan.Keyword) != ""
	hasTimeWindow := plan.DeadlineBefore != nil || plan.DeadlineAfter != nil ||
		plan.DeadlineBeforeText != "" || plan.DeadlineAfterText != ""
	changed := hasKeyword || !plan.IncludeCompleted || hasTimeWindow

	// plan is a value copy, so mutating it here does not affect the caller's plan.
	if hasKeyword {
		plan.Keyword = ""
	}
	plan.IncludeCompleted = true
	if hasTimeWindow {
		plan.DeadlineBefore, plan.DeadlineAfter = nil, nil
		plan.DeadlineBeforeText, plan.DeadlineAfterText = "", ""
	}
	return plan, changed
}
// applyRetryPatch applies the reflect node's retry patch to a copy of plan.
//
// Only fields that are present (non-nil) in the patch are considered:
//   - sort field and order are accepted only when they are recognized values;
//   - quadrants are normalized, and keyword and deadline texts are trimmed;
//   - a patch limit is honored only when the user gave no explicit limit, and is
//     coerced into range (non-positive becomes the default; values above
//     agentmodel.MaxTaskQueryLimit are capped).
//
// When any field changed, the deadline texts are re-anchored. Finally, an explicit
// user limit (> 0) always wins.
func applyRetryPatch(plan agentmodel.TaskQueryPlan, patch agentllm.TaskQueryRetryPatch, explicitLimit int) agentmodel.TaskQueryPlan {
	next := plan
	touched := 0

	if patch.Quadrants != nil {
		next.Quadrants = normalizeQuadrants(*patch.Quadrants)
		touched++
	}
	if patch.SortBy != nil {
		switch sortBy := strings.ToLower(strings.TrimSpace(*patch.SortBy)); sortBy {
		case "deadline", "priority", "id":
			next.SortBy = sortBy
			touched++
		}
	}
	if patch.Order != nil {
		switch order := strings.ToLower(strings.TrimSpace(*patch.Order)); order {
		case "asc", "desc":
			next.Order = order
			touched++
		}
	}
	if patch.Limit != nil && explicitLimit <= 0 {
		next.Limit = *patch.Limit
		if next.Limit <= 0 {
			next.Limit = agentmodel.DefaultTaskQueryLimit
		}
		if next.Limit > agentmodel.MaxTaskQueryLimit {
			next.Limit = agentmodel.MaxTaskQueryLimit
		}
		touched++
	}
	if patch.IncludeCompleted != nil {
		next.IncludeCompleted = *patch.IncludeCompleted
		touched++
	}
	if patch.Keyword != nil {
		next.Keyword = strings.TrimSpace(*patch.Keyword)
		touched++
	}
	if patch.DeadlineBefore != nil {
		next.DeadlineBeforeText = strings.TrimSpace(*patch.DeadlineBefore)
		touched++
	}
	if patch.DeadlineAfter != nil {
		next.DeadlineAfterText = strings.TrimSpace(*patch.DeadlineAfter)
		touched++
	}

	if touched > 0 {
		applyTimeAnchorOnPlan(&next)
	}
	// The user's own count is authoritative regardless of what the patch said.
	if explicitLimit > 0 {
		next.Limit = explicitLimit
	}
	return next
}
// summarizeTaskQueryPlan renders the plan as a compact single-line description, used
// in the reflect prompt. An empty quadrant list is shown as "全部象限" (all quadrants)
// and blank text fields are shown as "-".
func summarizeTaskQueryPlan(plan agentmodel.TaskQueryPlan) string {
	scope := "全部象限"
	if count := len(plan.Quadrants); count > 0 {
		labels := make([]string, count)
		for i, q := range plan.Quadrants {
			labels[i] = strconv.Itoa(q)
		}
		scope = strings.Join(labels, ",")
	}

	const layout = "quadrants=%s sort=%s/%s limit=%d include_completed=%t keyword=%s before=%s after=%s"
	return fmt.Sprintf(layout,
		scope,
		plan.SortBy, plan.Order,
		plan.Limit,
		plan.IncludeCompleted,
		emptyToDash(plan.Keyword),
		emptyToDash(plan.DeadlineBeforeText),
		emptyToDash(plan.DeadlineAfterText),
	)
}
// summarizeTaskQueryItems renders at most max items (5 when max <= 0) as one bullet
// line each, for the reflect prompt. It returns "无结果" (no results) for an empty list.
func summarizeTaskQueryItems(items []agentmodel.TaskQueryItem, max int) string {
	if len(items) == 0 {
		return "无结果"
	}
	limit := max
	if limit <= 0 {
		limit = 5
	}

	var b strings.Builder
	for i, item := range items {
		if i == limit {
			break
		}
		if i > 0 {
			b.WriteString("\n")
		}
		fmt.Fprintf(&b,
			"- #%d %s | 象限=%d | 完成=%t | 截止=%s",
			item.ID,
			item.Title,
			item.PriorityGroup,
			item.IsCompleted,
			emptyToDash(item.DeadlineAt),
		)
	}
	return b.String()
}
// buildTaskQueryFallbackReply builds a reply locally, without an LLM call.
//
// With no items it suggests how the user could refine the request. Otherwise it
// previews up to three tasks, each shown as its title followed by its priority label
// in parentheses (the label is omitted when blank), and offers to refine further.
func buildTaskQueryFallbackReply(items []agentmodel.TaskQueryItem) string {
	if len(items) == 0 {
		return "我这边暂时没找到匹配的任务。你可以再补一句,比如“按截止时间最早的前 3 个”或“只看简单不重要”。"
	}
	preview := items
	if len(preview) > 3 {
		preview = preview[:3]
	}
	names := make([]string, 0, len(preview))
	for _, item := range preview {
		// Separate the label from the title; printing them back to back makes them unreadable.
		if label := strings.TrimSpace(item.PriorityLabel); label != "" {
			names = append(names, fmt.Sprintf("%s(%s)", item.Title, label))
			continue
		}
		names = append(names, item.Title)
	}
	return fmt.Sprintf("我先给你筛到这些:%s。要不要我再按“更紧急”或“更简单”继续细化?", strings.Join(names, "、"))
}
// buildTaskQueryFinalReply builds the user-facing reply from the query results.
//
// With no items it returns buildTaskQueryFallbackReply, prefixed by the trimmed LLM
// reply when that is non-empty. Otherwise it lists up to plan.Limit items; the limit
// falls back to the default when it is not positive and is capped at
// agentmodel.MaxTaskQueryLimit. Each item shows its title, then its priority label
// (when non-blank), completion status and deadline in parentheses. A short, safe LLM
// lead line (see extractSafeReplyLead) is placed above the list, and a note is
// appended when more items matched than are shown.
func buildTaskQueryFinalReply(items []agentmodel.TaskQueryItem, plan agentmodel.TaskQueryPlan, llmReply string) string {
	if len(items) == 0 {
		base := buildTaskQueryFallbackReply(items)
		if strings.TrimSpace(llmReply) == "" {
			return base
		}
		return strings.TrimSpace(llmReply) + "\n" + base
	}

	desired := plan.Limit
	if desired <= 0 {
		desired = agentmodel.DefaultTaskQueryLimit
	}
	if desired > agentmodel.MaxTaskQueryLimit {
		desired = agentmodel.MaxTaskQueryLimit
	}
	showCount := desired
	if len(items) < showCount {
		showCount = len(items)
	}

	preview := items[:showCount]
	lines := make([]string, 0, len(preview))
	for idx, item := range preview {
		details := make([]string, 0, 3)
		if label := strings.TrimSpace(item.PriorityLabel); label != "" {
			details = append(details, label)
		}
		status := "未完成"
		if item.IsCompleted {
			status = "已完成"
		}
		details = append(details, status)
		if deadline := strings.TrimSpace(item.DeadlineAt); deadline != "" {
			details = append(details, "截止 "+deadline)
		} else {
			details = append(details, "无明确截止时间")
		}
		// Keep every field visibly separated; the fields used to be printed back to back.
		lines = append(lines, fmt.Sprintf("%d. %s(%s)", idx+1, item.Title, strings.Join(details, ",")))
	}

	header := fmt.Sprintf("给你整理了 %d 条任务:", showCount)
	if lead := extractSafeReplyLead(llmReply); lead != "" {
		header = lead + "\n" + header
	}
	reply := header + "\n" + strings.Join(lines, "\n")
	if len(items) > showCount {
		reply += fmt.Sprintf("\n另外还有 %d 条匹配任务,要不要我继续往下列?", len(items)-showCount)
	}
	return reply
}
// extractSafeReplyLead returns the trimmed LLM reply when it is safe to show as a
// one-line lead above the locally built list, and "" otherwise.
//
// A reply is rejected when it is blank, spans several lines, looks like Markdown or a
// numbered list ("#", "1.", "1、"), announces its own listing ("以下是"), or is longer
// than 30 runes. Any of these would duplicate or clash with the generated list.
func extractSafeReplyLead(llmReply string) string {
	lead := strings.TrimSpace(llmReply)
	if lead == "" {
		return ""
	}
	for _, marker := range []string{"\n", "#", "1.", "1、", "以下是"} {
		if strings.Contains(lead, marker) {
			return ""
		}
	}
	if len([]rune(lead)) > 30 {
		return ""
	}
	return lead
}
// sortTaskQueryItems stably sorts items in place according to plan.SortBy and plan.Order.
//
//   - "priority": by PriorityGroup in the requested order; ties put the larger ID first.
//   - "id": by ID in the requested order.
//   - anything else (deadline): items with a parseable deadline come first, ordered by
//     deadline in the requested order. Equal deadlines, and items without a parseable
//     deadline, put the larger ID first.
//
// Any order other than "desc" (after trimming and lower-casing) means ascending.
func sortTaskQueryItems(items []agentmodel.TaskQueryItem, plan agentmodel.TaskQueryPlan) {
	if len(items) < 2 {
		return
	}
	sortBy := strings.ToLower(strings.TrimSpace(plan.SortBy))
	desc := strings.ToLower(strings.TrimSpace(plan.Order)) == "desc"

	var less func(a, b agentmodel.TaskQueryItem) bool
	switch sortBy {
	case "priority":
		less = func(a, b agentmodel.TaskQueryItem) bool {
			if a.PriorityGroup == b.PriorityGroup {
				return a.ID > b.ID
			}
			if desc {
				return a.PriorityGroup > b.PriorityGroup
			}
			return a.PriorityGroup < b.PriorityGroup
		}
	case "id":
		less = func(a, b agentmodel.TaskQueryItem) bool {
			if desc {
				return a.ID > b.ID
			}
			return a.ID < b.ID
		}
	default:
		less = func(a, b agentmodel.TaskQueryItem) bool {
			aTime, aOK := parseTaskQueryItemDeadline(a.DeadlineAt)
			bTime, bOK := parseTaskQueryItemDeadline(b.DeadlineAt)
			switch {
			case aOK && bOK && !aTime.Equal(bTime):
				if desc {
					return aTime.After(bTime)
				}
				return aTime.Before(bTime)
			case aOK != bOK:
				// Exactly one side has a deadline: that side goes first.
				return aOK
			default:
				return a.ID > b.ID
			}
		}
	}

	sort.SliceStable(items, func(i, j int) bool {
		return less(items[i], items[j])
	})
}
// parseTaskQueryItemDeadline parses a task deadline in the "2006-01-02 15:04" layout
// in the local time zone. It returns the zero time and false for blank or
// unparseable input.
func parseTaskQueryItemDeadline(raw string) (time.Time, bool) {
	const layout = "2006-01-02 15:04"
	if text := strings.TrimSpace(raw); text != "" {
		if parsed, err := time.ParseInLocation(layout, text, time.Local); err == nil {
			return parsed, true
		}
	}
	return time.Time{}, false
}
// emptyToDash returns the trimmed text, or "-" when it is blank, so one-line summaries
// never show an empty field.
func emptyToDash(text string) string {
	if trimmed := strings.TrimSpace(text); trimmed != "" {
		return trimmed
	}
	return "-"
}
// extractExplicitLimitFromUser extracts an explicit result count from the user's message.
//
// Steps:
//  1. Arabic-numeral forms such as "前3个", "给我5条" or "top 10" are matched with
//     explicitLimitPatterns, trying every match of each pattern in order. A match is
//     skipped when the captured number is glued to further digits (the "10" inside
//     "前100个"). It is also skipped when the number describes a date or duration
//     rather than a count ("前3天", "给我3月的任务", "前3个月").
//  2. Chinese-numeral forms such as "前五个", "来三条" or "给我十二项" are matched after
//     the prefixes 前/来/给我. Prefixes are tried in that order, and each is scanned
//     left to right, so the result is deterministic. The numeral is built from
//     chineseDigitMap and may have two digits ("二十五"). It must be followed by a
//     count unit 个/条/项 (not describing a duration) or end the message. This avoids
//     false positives such as "前两天" or "给我一些".
//  3. Each candidate is validated with normalizeExplicitLimit: non-positive values are
//     skipped and values above agentmodel.MaxTaskQueryLimit are capped.
func extractExplicitLimitFromUser(userMessage string) (int, bool) {
	text := strings.TrimSpace(userMessage)
	if text == "" {
		return 0, false
	}

	for _, pattern := range explicitLimitPatterns {
		for _, loc := range pattern.FindAllStringSubmatchIndex(text, -1) {
			// loc[2]:loc[3] spans capture group 1, the number itself.
			if len(loc) < 4 || loc[2] < 0 {
				continue
			}
			start, end := loc[2], loc[3]
			gluedBefore := start > 0 && isTaskQueryLimitDigit(text[start-1])
			gluedAfter := end < len(text) && isTaskQueryLimitDigit(text[end])
			if gluedBefore || gluedAfter || taskQueryLimitFollowedByTimeUnit(text[end:]) {
				continue
			}
			number, err := strconv.Atoi(text[start:end])
			if err != nil {
				continue
			}
			if limit, ok := normalizeExplicitLimit(number); ok {
				return limit, true
			}
		}
	}

	for _, prefix := range []string{"前", "来", "给我"} {
		rest := text
		for {
			idx := strings.Index(rest, prefix)
			if idx < 0 {
				break
			}
			rest = rest[idx+len(prefix):]
			number, ok := parseTaskQueryChineseLimit(rest)
			if !ok {
				continue
			}
			if limit, ok := normalizeExplicitLimit(number); ok {
				return limit, true
			}
		}
	}
	return 0, false
}

// isTaskQueryLimitDigit reports whether b is an ASCII digit (the only digits `\d` matches in RE2).
func isTaskQueryLimitDigit(b byte) bool {
	return b >= '0' && b <= '9'
}

// taskQueryLimitFollowedByTimeUnit reports whether rest, the text right after a
// number, marks that number as a date or duration rather than a result count
// ("3天", "3月", "3号", "3小时", "3个月", "3个星期").
func taskQueryLimitFollowedByTimeUnit(rest string) bool {
	const spaces = " \t\r\n"
	rest = strings.TrimLeft(rest, spaces)
	units := []string{"天", "日", "号", "月", "周", "年", "点", "时", "小时", "分钟", "秒", "星期", "礼拜"}
	if strings.HasPrefix(rest, "个") {
		// After the measure word 个 only duration nouns make it a time span, so
		// "3个小任务" still counts as three tasks while "3个小时" does not.
		rest = strings.TrimLeft(strings.TrimPrefix(rest, "个"), spaces)
		units = []string{"月", "小时", "星期", "礼拜", "钟头"}
	}
	for _, unit := range units {
		if strings.HasPrefix(rest, unit) {
			return true
		}
	}
	return false
}

// parseTaskQueryChineseLimit parses the Chinese numeral at the start of s. It returns
// the value only when the numeral is followed by a count unit 个/条/项 that does not
// describe a duration, or when the numeral ends the message.
func parseTaskQueryChineseLimit(s string) (int, bool) {
	numeral := make([]rune, 0, 3)
	rest := s
	for _, r := range s {
		if _, ok := chineseDigitMap[r]; !ok || len(numeral) == 3 {
			break
		}
		numeral = append(numeral, r)
		rest = rest[len(string(r)):]
	}
	if len(numeral) == 0 {
		return 0, false
	}
	if rest != "" {
		hasUnit := strings.HasPrefix(rest, "个") || strings.HasPrefix(rest, "条") || strings.HasPrefix(rest, "项")
		if !hasUnit || taskQueryLimitFollowedByTimeUnit(rest) {
			return 0, false
		}
	}
	return taskQueryLimitNumeralValue(numeral)
}

// taskQueryLimitNumeralValue converts chineseDigitMap digits, with at most one 十,
// into an integer: "五"→5, "十"→10, "十二"→12, "二十"→20, "二十五"→25.
func taskQueryLimitNumeralValue(numeral []rune) (int, bool) {
	tenAt := -1
	for i, r := range numeral {
		if r != '十' {
			continue
		}
		if tenAt >= 0 {
			return 0, false
		}
		tenAt = i
	}
	switch {
	case tenAt < 0:
		if len(numeral) != 1 {
			return 0, false
		}
		return chineseDigitMap[numeral[0]], true
	case tenAt > 1 || len(numeral)-tenAt > 2:
		return 0, false
	}
	tens, ones := 1, 0
	if tenAt == 1 {
		tens = chineseDigitMap[numeral[0]]
	}
	if tenAt+1 < len(numeral) {
		ones = chineseDigitMap[numeral[tenAt+1]]
	}
	return tens*10 + ones, true
}
// normalizeExplicitLimit validates a user-requested count. Non-positive values are
// rejected (false), and values above agentmodel.MaxTaskQueryLimit are capped.
func normalizeExplicitLimit(number int) (int, bool) {
	switch {
	case number <= 0:
		return 0, false
	case number > agentmodel.MaxTaskQueryLimit:
		return agentmodel.MaxTaskQueryLimit, true
	default:
		return number, true
	}
}