Version: 0.8.3.dev.260328
后端: 1.彻底删除原agent文件夹,并将现agent2文件夹全量重命名为agent(包括全部涉及到的文件以及文档、注释),迁移工作完美结束 2.修复了重试消息的相关逻辑问题 前端: 1.改善了一些交互体验,修复了一些bug,现在只剩少的功能了,现存的bug基本都修复完毕 全仓库: 1.更新了决策记录和README文档
This commit is contained in:
@@ -1,23 +0,0 @@
|
||||
# backend/agent 目录说明
|
||||
|
||||
该目录已按“路由 / 聊天 / 随口记”三层拆分,便于阅读、调试与扩展:
|
||||
|
||||
1. `route/`
|
||||
- `route.go`:只负责模型控制码分流(`quick_note` / `chat`)。
|
||||
- 提供控制码解析、nonce 校验、路由兜底,不参与写库与回复拼装。
|
||||
|
||||
2. `chat/`
|
||||
- `stream.go`:普通聊天流式输出封装(SSE/OpenAI 兼容 chunk 转换)。
|
||||
- `prompt.go`:聊天主系统提示词。
|
||||
|
||||
3. `quicknote/`
|
||||
- `graph.go`:只负责图编排连线与分支,不承载节点内部实现。
|
||||
- `nodes.go`:节点实现(意图识别、优先级评估、持久化、分支选择)。
|
||||
- `tool.go`:工具定义、参数校验、deadline 解析、写库工具打包。
|
||||
- `state.go`:随口记状态容器与重试状态记录。
|
||||
- `prompt.go`:随口记提示词(控制码路由、聚合规划、优先级评估、回复润色)。
|
||||
|
||||
4. `README.md`(当前文件)
|
||||
- 记录目录职责边界,帮助后续继续按同样范式扩展 `query/update` 等技能链路。
|
||||
|
||||
> 说明:服务层仍通过 `RunQuickNoteGraph` 调用随口记图;若判定为非随口记意图,会自动回落到普通流式聊天链路。
|
||||
@@ -1,4 +1,4 @@
|
||||
package chat
|
||||
package agentchat
|
||||
|
||||
const (
|
||||
// SystemPrompt 全局系统人设:定义 SmartFlow 的基本调性
|
||||
|
||||
@@ -1,93 +1,19 @@
|
||||
package chat
|
||||
package agentchat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
agentllm "github.com/LoveLosita/smartflow/backend/agent/llm"
|
||||
agentstream "github.com/LoveLosita/smartflow/backend/agent/stream"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/google/uuid"
|
||||
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
// StreamResponse 是 OpenAI/DeepSeek 兼容的流式 chunk 结构。
|
||||
type StreamResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []StreamChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type StreamChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta StreamDelta `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type StreamDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
}
|
||||
|
||||
// ToOpenAIStream 将单个 Eino chunk 转为 OpenAI 兼容 JSON。
|
||||
func ToOpenAIStream(chunk *schema.Message, requestID, modelName string, created int64, includeRole bool) (string, error) {
|
||||
delta := StreamDelta{}
|
||||
if includeRole {
|
||||
delta.Role = "assistant"
|
||||
}
|
||||
if chunk != nil {
|
||||
delta.Content = chunk.Content
|
||||
delta.ReasoningContent = chunk.ReasoningContent
|
||||
}
|
||||
|
||||
if delta.Role == "" && delta.Content == "" && delta.ReasoningContent == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
dto := StreamResponse{
|
||||
ID: requestID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: created,
|
||||
Model: modelName,
|
||||
Choices: []StreamChoice{{
|
||||
Index: 0,
|
||||
Delta: delta,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(dto)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
|
||||
// ToOpenAIFinishStream 生成结束 chunk(finish_reason=stop)。
|
||||
func ToOpenAIFinishStream(requestID, modelName string, created int64) (string, error) {
|
||||
stop := "stop"
|
||||
dto := StreamResponse{
|
||||
ID: requestID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: created,
|
||||
Model: modelName,
|
||||
Choices: []StreamChoice{{
|
||||
Index: 0,
|
||||
Delta: StreamDelta{},
|
||||
FinishReason: &stop,
|
||||
}},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(dto)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
|
||||
// StreamChat 负责模型流式输出,并在关键节点打点:
|
||||
// 1) 流连接建立(llm.Stream 返回)
|
||||
// 2) 首包到达(首字延迟)
|
||||
@@ -103,7 +29,8 @@ func StreamChat(
|
||||
traceID string,
|
||||
chatID string,
|
||||
requestStart time.Time,
|
||||
) (string, *schema.TokenUsage, error) {
|
||||
reasoningStartAt *time.Time,
|
||||
) (string, string, int, *schema.TokenUsage, error) {
|
||||
/*callStart := time.Now()*/
|
||||
|
||||
messages := make([]*schema.Message, 0)
|
||||
@@ -123,7 +50,7 @@ func StreamChat(
|
||||
/*connectStart := time.Now()*/
|
||||
reader, err := llm.Stream(ctx, messages, ark.WithThinking(thinking))
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return "", "", 0, nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
@@ -135,6 +62,12 @@ func StreamChat(
|
||||
firstChunk := true
|
||||
chunkCount := 0
|
||||
var tokenUsage *schema.TokenUsage
|
||||
var localReasoningStartAt *time.Time
|
||||
if reasoningStartAt != nil && !reasoningStartAt.IsZero() {
|
||||
startCopy := reasoningStartAt.In(time.Local)
|
||||
localReasoningStartAt = &startCopy
|
||||
}
|
||||
var reasoningEndAt *time.Time
|
||||
/*streamRecvStart := time.Now()
|
||||
|
||||
log.Printf("打点|流连接建立|trace_id=%s|chat_id=%s|request_id=%s|本步耗时_ms=%d|请求累计_ms=%d|history_len=%d",
|
||||
@@ -147,29 +80,42 @@ func StreamChat(
|
||||
)*/
|
||||
|
||||
var fullText strings.Builder
|
||||
var reasoningText strings.Builder
|
||||
for {
|
||||
chunk, err := reader.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return "", "", 0, nil, err
|
||||
}
|
||||
|
||||
// 优先记录模型真实 usage(通常在尾块返回,部分模型也可能中途返回)。
|
||||
if chunk != nil && chunk.ResponseMeta != nil && chunk.ResponseMeta.Usage != nil {
|
||||
tokenUsage = mergeTokenUsage(tokenUsage, chunk.ResponseMeta.Usage)
|
||||
tokenUsage = agentllm.MergeUsage(tokenUsage, chunk.ResponseMeta.Usage)
|
||||
}
|
||||
|
||||
fullText.WriteString(chunk.Content)
|
||||
if chunk != nil {
|
||||
if strings.TrimSpace(chunk.ReasoningContent) != "" && localReasoningStartAt == nil {
|
||||
now := time.Now()
|
||||
localReasoningStartAt = &now
|
||||
}
|
||||
if strings.TrimSpace(chunk.Content) != "" && localReasoningStartAt != nil && reasoningEndAt == nil {
|
||||
now := time.Now()
|
||||
reasoningEndAt = &now
|
||||
}
|
||||
fullText.WriteString(chunk.Content)
|
||||
reasoningText.WriteString(chunk.ReasoningContent)
|
||||
}
|
||||
|
||||
payload, err := ToOpenAIStream(chunk, requestID, modelName, created, firstChunk)
|
||||
payload, err := agentstream.ToOpenAIStream(chunk, requestID, modelName, created, firstChunk)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return "", "", 0, nil, err
|
||||
}
|
||||
if payload != "" {
|
||||
outChan <- payload
|
||||
chunkCount++
|
||||
firstChunk = false
|
||||
/*if firstChunk {
|
||||
log.Printf("打点|首包到达|trace_id=%s|chat_id=%s|request_id=%s|本步耗时_ms=%d|请求累计_ms=%d",
|
||||
traceID,
|
||||
@@ -183,9 +129,9 @@ func StreamChat(
|
||||
}
|
||||
}
|
||||
|
||||
finishChunk, err := ToOpenAIFinishStream(requestID, modelName, created)
|
||||
finishChunk, err := agentstream.ToOpenAIFinishStream(requestID, modelName, created)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return "", "", 0, nil, err
|
||||
}
|
||||
outChan <- finishChunk
|
||||
outChan <- "[DONE]"
|
||||
@@ -200,39 +146,16 @@ func StreamChat(
|
||||
time.Since(requestStart).Milliseconds(),
|
||||
)*/
|
||||
|
||||
return fullText.String(), tokenUsage, nil
|
||||
}
|
||||
|
||||
// mergeTokenUsage 合并流式分片中的 usage。
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. 不同模型的 usage 回传时机不同(中间块/尾块);
|
||||
// 2. 这里按“更大值覆盖”合并,确保最终拿到完整统计;
|
||||
// 3. 只用于统计,不影响流式正文输出。
|
||||
func mergeTokenUsage(base *schema.TokenUsage, incoming *schema.TokenUsage) *schema.TokenUsage {
|
||||
if incoming == nil {
|
||||
return base
|
||||
}
|
||||
if base == nil {
|
||||
copied := *incoming
|
||||
return &copied
|
||||
reasoningDurationSeconds := 0
|
||||
if localReasoningStartAt != nil {
|
||||
if reasoningEndAt == nil {
|
||||
now := time.Now()
|
||||
reasoningEndAt = &now
|
||||
}
|
||||
if reasoningEndAt.After(*localReasoningStartAt) {
|
||||
reasoningDurationSeconds = int(reasoningEndAt.Sub(*localReasoningStartAt) / time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
merged := *base
|
||||
if incoming.PromptTokens > merged.PromptTokens {
|
||||
merged.PromptTokens = incoming.PromptTokens
|
||||
}
|
||||
if incoming.CompletionTokens > merged.CompletionTokens {
|
||||
merged.CompletionTokens = incoming.CompletionTokens
|
||||
}
|
||||
if incoming.TotalTokens > merged.TotalTokens {
|
||||
merged.TotalTokens = incoming.TotalTokens
|
||||
}
|
||||
if incoming.PromptTokenDetails.CachedTokens > merged.PromptTokenDetails.CachedTokens {
|
||||
merged.PromptTokenDetails.CachedTokens = incoming.PromptTokenDetails.CachedTokens
|
||||
}
|
||||
if incoming.CompletionTokensDetails.ReasoningTokens > merged.CompletionTokensDetails.ReasoningTokens {
|
||||
merged.CompletionTokensDetails.ReasoningTokens = incoming.CompletionTokensDetails.ReasoningTokens
|
||||
}
|
||||
return &merged
|
||||
return fullText.String(), reasoningText.String(), reasoningDurationSeconds, tokenUsage, nil
|
||||
}
|
||||
|
||||
41
backend/agent/entrance.go
Normal file
41
backend/agent/entrance.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
agentrouter "github.com/LoveLosita/smartflow/backend/agent/router"
|
||||
)
|
||||
|
||||
// Service 是 agent 模块的总入口。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责接住一次完整的 Agent 请求,并把请求交给统一路由器分流;
|
||||
// 2. 负责维护“路由器 + 各 skill handler”的装配关系;
|
||||
// 3. 不负责具体 skill 的 graph 连线,也不负责节点内部业务实现。
|
||||
type Service struct {
|
||||
dispatcher *agentrouter.Dispatcher
|
||||
}
|
||||
|
||||
// NewService 创建 agent 总入口服务。
|
||||
func NewService(resolver agentrouter.Resolver) *Service {
|
||||
return &Service{
|
||||
dispatcher: agentrouter.NewDispatcher(resolver),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHandler 注册某个 skill 的执行入口。
|
||||
func (s *Service) RegisterHandler(action agentrouter.Action, handler agentrouter.SkillHandler) error {
|
||||
if s == nil || s.dispatcher == nil {
|
||||
return errors.New("agent service is not initialized")
|
||||
}
|
||||
return s.dispatcher.Register(action, handler)
|
||||
}
|
||||
|
||||
// Handle 是 agent 的统一处理入口。
|
||||
func (s *Service) Handle(ctx context.Context, req *agentrouter.AgentRequest) (*agentrouter.AgentResponse, error) {
|
||||
if s == nil || s.dispatcher == nil {
|
||||
return nil, errors.New("agent service is not initialized")
|
||||
}
|
||||
return s.dispatcher.Dispatch(ctx, req)
|
||||
}
|
||||
122
backend/agent/graph/quicknote.go
Normal file
122
backend/agent/graph/quicknote.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package agentgraph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
agentmodel "github.com/LoveLosita/smartflow/backend/agent/model"
|
||||
agentnode "github.com/LoveLosita/smartflow/backend/agent/node"
|
||||
agentshared "github.com/LoveLosita/smartflow/backend/agent/shared"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
const (
|
||||
// QuickNoteGraphName 是随口记图编排的稳定标识。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 仅用于 graph 编译和链路标识,方便日志与排障统一定位。
|
||||
// 2. 不参与意图判断,也不承载任务写库的业务语义。
|
||||
QuickNoteGraphName = "quick_note"
|
||||
)
|
||||
|
||||
// RunQuickNoteGraph 负责执行“随口记 -> 判断 -> 提取 -> 落库 -> 收口”的整条图链路。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责输入兜底、工具装配、节点注册与 graph 运行。
|
||||
// 2. 不负责每个节点的具体业务决策,节点内部逻辑由 node 层实现。
|
||||
// 3. 返回的 state 表示整条链路的最终状态,供上层继续拼接响应或写日志。
|
||||
func RunQuickNoteGraph(ctx context.Context, input agentnode.QuickNoteGraphRunInput) (*agentmodel.QuickNoteState, error) {
|
||||
// 1. 先校验最基础依赖,避免图已经启动后才发现模型或状态为空。
|
||||
if input.Model == nil {
|
||||
return nil, errors.New("quick note graph: model is nil")
|
||||
}
|
||||
if input.State == nil {
|
||||
return nil, errors.New("quick note graph: state is nil")
|
||||
}
|
||||
if err := input.Deps.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 补齐当前请求时间,保证后续提示词、时间解析和落库字段都基于同一时刻。
|
||||
if input.State.RequestNow.IsZero() {
|
||||
input.State.RequestNow = agentshared.NowToMinute()
|
||||
}
|
||||
if strings.TrimSpace(input.State.RequestNowText) == "" {
|
||||
input.State.RequestNowText = agentshared.FormatMinute(input.State.RequestNow)
|
||||
}
|
||||
|
||||
// 3. 图运行前统一准备工具与节点容器,避免节点内部重复做依赖解析。
|
||||
toolBundle, err := agentnode.BuildQuickNoteToolBundle(ctx, input.Deps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
createTaskTool, err := agentnode.GetInvokableToolByName(toolBundle, agentnode.ToolNameQuickNoteCreateTask)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes, err := agentnode.NewQuickNoteNodes(input, createTaskTool)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. 主链路保持“意图识别 -> 优先级评估 -> 持久化 -> 退出”,中间通过 branch 决定是否提前结束或重试写库。
|
||||
graph := compose.NewGraph[*agentmodel.QuickNoteState, *agentmodel.QuickNoteState]()
|
||||
if err = graph.AddLambdaNode(agentnode.QuickNoteGraphNodeIntent, compose.InvokableLambda(nodes.Intent)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.QuickNoteGraphNodeRank, compose.InvokableLambda(nodes.Priority)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.QuickNoteGraphNodePersist, compose.InvokableLambda(nodes.Persist)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.QuickNoteGraphNodeExit, compose.InvokableLambda(nodes.Exit)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = graph.AddEdge(compose.START, agentnode.QuickNoteGraphNodeIntent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddBranch(agentnode.QuickNoteGraphNodeIntent, compose.NewGraphBranch(
|
||||
nodes.NextAfterIntent,
|
||||
map[string]bool{
|
||||
agentnode.QuickNoteGraphNodeRank: true,
|
||||
agentnode.QuickNoteGraphNodeExit: true,
|
||||
},
|
||||
)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.QuickNoteGraphNodeExit, compose.END); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.QuickNoteGraphNodeRank, agentnode.QuickNoteGraphNodePersist); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddBranch(agentnode.QuickNoteGraphNodePersist, compose.NewGraphBranch(
|
||||
nodes.NextAfterPersist,
|
||||
map[string]bool{
|
||||
agentnode.QuickNoteGraphNodePersist: true,
|
||||
compose.END: true,
|
||||
},
|
||||
)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 5. persist 节点允许有限次重试,因此最大步数要覆盖首次执行与重试回路。
|
||||
maxSteps := input.State.MaxToolRetry + 10
|
||||
if maxSteps < 12 {
|
||||
maxSteps = 12
|
||||
}
|
||||
|
||||
runnable, err := graph.Compile(ctx,
|
||||
compose.WithGraphName(QuickNoteGraphName),
|
||||
compose.WithMaxRunSteps(maxSteps),
|
||||
compose.WithNodeTriggerMode(compose.AnyPredecessor),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return runnable.Invoke(ctx, input.State)
|
||||
}
|
||||
202
backend/agent/graph/schedule.go
Normal file
202
backend/agent/graph/schedule.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package agentgraph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
agentmodel "github.com/LoveLosita/smartflow/backend/agent/model"
|
||||
agentnode "github.com/LoveLosita/smartflow/backend/agent/node"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
const (
|
||||
SchedulePlanGraphName = "schedule_plan"
|
||||
ScheduleRefineGraphName = "schedule_refine"
|
||||
)
|
||||
|
||||
func RunSchedulePlanGraph(ctx context.Context, input agentnode.SchedulePlanGraphRunInput) (*agentmodel.SchedulePlanState, error) {
|
||||
if input.Model == nil {
|
||||
return nil, errors.New("schedule plan graph: model is nil")
|
||||
}
|
||||
if input.State == nil {
|
||||
return nil, errors.New("schedule plan graph: state is nil")
|
||||
}
|
||||
if err := input.Deps.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if input.DailyRefineConcurrency > 0 {
|
||||
input.State.DailyRefineConcurrency = input.DailyRefineConcurrency
|
||||
}
|
||||
if input.WeeklyAdjustBudget > 0 {
|
||||
input.State.WeeklyAdjustBudget = input.WeeklyAdjustBudget
|
||||
}
|
||||
|
||||
nodes, err := agentnode.NewSchedulePlanNodes(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
graph := compose.NewGraph[*agentmodel.SchedulePlanState, *agentmodel.SchedulePlanState]()
|
||||
if err = graph.AddLambdaNode(agentnode.SchedulePlanGraphNodePlan, compose.InvokableLambda(nodes.Plan)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.SchedulePlanGraphNodeRoughBuild, compose.InvokableLambda(nodes.RoughBuild)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.SchedulePlanGraphNodeExit, compose.InvokableLambda(nodes.Exit)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.SchedulePlanGraphNodeDailySplit, compose.InvokableLambda(nodes.DailySplit)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.SchedulePlanGraphNodeQuickRefine, compose.InvokableLambda(nodes.QuickRefine)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.SchedulePlanGraphNodeDailyRefine, compose.InvokableLambda(nodes.DailyRefine)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.SchedulePlanGraphNodeMerge, compose.InvokableLambda(nodes.Merge)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.SchedulePlanGraphNodeWeeklyRefine, compose.InvokableLambda(nodes.WeeklyRefine)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.SchedulePlanGraphNodeFinalCheck, compose.InvokableLambda(nodes.FinalCheck)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.SchedulePlanGraphNodeReturnPreview, compose.InvokableLambda(nodes.ReturnPreview)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = graph.AddEdge(compose.START, agentnode.SchedulePlanGraphNodePlan); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddBranch(agentnode.SchedulePlanGraphNodePlan, compose.NewGraphBranch(
|
||||
nodes.NextAfterPlan,
|
||||
map[string]bool{
|
||||
agentnode.SchedulePlanGraphNodeRoughBuild: true,
|
||||
agentnode.SchedulePlanGraphNodeExit: true,
|
||||
},
|
||||
)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddBranch(agentnode.SchedulePlanGraphNodeRoughBuild, compose.NewGraphBranch(
|
||||
nodes.NextAfterRoughBuild,
|
||||
map[string]bool{
|
||||
agentnode.SchedulePlanGraphNodeDailySplit: true,
|
||||
agentnode.SchedulePlanGraphNodeQuickRefine: true,
|
||||
agentnode.SchedulePlanGraphNodeWeeklyRefine: true,
|
||||
agentnode.SchedulePlanGraphNodeExit: true,
|
||||
},
|
||||
)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = graph.AddEdge(agentnode.SchedulePlanGraphNodeQuickRefine, agentnode.SchedulePlanGraphNodeWeeklyRefine); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.SchedulePlanGraphNodeDailySplit, agentnode.SchedulePlanGraphNodeDailyRefine); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.SchedulePlanGraphNodeDailyRefine, agentnode.SchedulePlanGraphNodeMerge); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.SchedulePlanGraphNodeMerge, agentnode.SchedulePlanGraphNodeWeeklyRefine); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.SchedulePlanGraphNodeWeeklyRefine, agentnode.SchedulePlanGraphNodeFinalCheck); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.SchedulePlanGraphNodeFinalCheck, agentnode.SchedulePlanGraphNodeReturnPreview); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.SchedulePlanGraphNodeReturnPreview, compose.END); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.SchedulePlanGraphNodeExit, compose.END); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
runnable, err := graph.Compile(ctx,
|
||||
compose.WithGraphName(SchedulePlanGraphName),
|
||||
compose.WithMaxRunSteps(20),
|
||||
compose.WithNodeTriggerMode(compose.AnyPredecessor),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return runnable.Invoke(ctx, input.State)
|
||||
}
|
||||
|
||||
func RunScheduleRefineGraph(ctx context.Context, input agentnode.ScheduleRefineGraphRunInput) (*agentnode.ScheduleRefineState, error) {
|
||||
if input.Model == nil {
|
||||
return nil, errors.New("schedule refine graph: model is nil")
|
||||
}
|
||||
if input.State == nil {
|
||||
return nil, errors.New("schedule refine graph: state is nil")
|
||||
}
|
||||
|
||||
nodes, err := agentnode.NewScheduleRefineNodes(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
graph := compose.NewGraph[*agentmodel.ScheduleRefineState, *agentmodel.ScheduleRefineState]()
|
||||
if err = graph.AddLambdaNode(agentnode.ScheduleRefineGraphNodeContract, compose.InvokableLambda(nodes.Contract)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.ScheduleRefineGraphNodePlan, compose.InvokableLambda(nodes.Plan)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.ScheduleRefineGraphNodeSlice, compose.InvokableLambda(nodes.Slice)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.ScheduleRefineGraphNodeRoute, compose.InvokableLambda(nodes.Route)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.ScheduleRefineGraphNodeReact, compose.InvokableLambda(nodes.React)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.ScheduleRefineGraphNodeHardCheck, compose.InvokableLambda(nodes.HardCheck)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.ScheduleRefineGraphNodeSummary, compose.InvokableLambda(nodes.Summary)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = graph.AddEdge(compose.START, agentnode.ScheduleRefineGraphNodeContract); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.ScheduleRefineGraphNodeContract, agentnode.ScheduleRefineGraphNodePlan); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.ScheduleRefineGraphNodePlan, agentnode.ScheduleRefineGraphNodeSlice); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.ScheduleRefineGraphNodeSlice, agentnode.ScheduleRefineGraphNodeRoute); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.ScheduleRefineGraphNodeRoute, agentnode.ScheduleRefineGraphNodeReact); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.ScheduleRefineGraphNodeReact, agentnode.ScheduleRefineGraphNodeHardCheck); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.ScheduleRefineGraphNodeHardCheck, agentnode.ScheduleRefineGraphNodeSummary); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.ScheduleRefineGraphNodeSummary, compose.END); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
runnable, err := graph.Compile(ctx,
|
||||
compose.WithGraphName(ScheduleRefineGraphName),
|
||||
compose.WithMaxRunSteps(20),
|
||||
compose.WithNodeTriggerMode(compose.AnyPredecessor),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return runnable.Invoke(ctx, input.State)
|
||||
}
|
||||
126
backend/agent/graph/taskquery.go
Normal file
126
backend/agent/graph/taskquery.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package agentgraph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
agentmodel "github.com/LoveLosita/smartflow/backend/agent/model"
|
||||
agentnode "github.com/LoveLosita/smartflow/backend/agent/node"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
const (
|
||||
// TaskQueryGraphName 是任务查询图编排的稳定标识。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 仅用于 graph 编译、日志和排障时标识当前链路。
|
||||
// 2. 不承载路由判断,也不负责描述具体业务含义。
|
||||
TaskQueryGraphName = "task_query"
|
||||
)
|
||||
|
||||
// RunTaskQueryGraph 负责串起任务查询图,并返回最终给用户的回复文本。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责做图运行前的依赖校验、默认值补齐、节点装配与 graph 编译执行。
|
||||
// 2. 不负责实现单个节点的业务细节,这些逻辑由 node 层承接。
|
||||
// 3. 返回值中的 string 是最终可直接透传给上层的回复;error 仅表示链路级失败。
|
||||
func RunTaskQueryGraph(ctx context.Context, input agentnode.TaskQueryGraphRunInput) (string, error) {
|
||||
// 1. 先拦住空模型、空状态和依赖缺失,避免 graph 运行到一半才出现不可恢复错误。
|
||||
if input.Model == nil {
|
||||
return "", errors.New("task query graph: model is nil")
|
||||
}
|
||||
if input.State == nil {
|
||||
return "", errors.New("task query graph: state is nil")
|
||||
}
|
||||
if err := input.Deps.Validate(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 2. 请求时间缺失时补齐当前时间,保证后续时间锚定与提示词上下文稳定。
|
||||
if strings.TrimSpace(input.State.RequestNowText) == "" {
|
||||
input.State.RequestNowText = time.Now().In(time.Local).Format("2006-01-02 15:04")
|
||||
}
|
||||
|
||||
// 3. 先准备工具,再构造节点容器;这样 graph 中每个节点都能拿到已校验好的依赖。
|
||||
toolBundle, err := agentnode.BuildTaskQueryToolBundle(ctx, input.Deps)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
queryTool, err := agentnode.GetTaskQueryInvokableToolByName(toolBundle, agentnode.ToolNameTaskQueryTasks)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
nodes, err := agentnode.NewTaskQueryNodes(input, queryTool)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 4. 注册节点与边,保持“计划 -> 归一化 -> 时间锚定 -> 查询 -> 反思”的单向主链。
|
||||
graph := compose.NewGraph[*agentmodel.TaskQueryState, *agentmodel.TaskQueryState]()
|
||||
if err = graph.AddLambdaNode(agentnode.TaskQueryGraphNodePlan, compose.InvokableLambda(nodes.Plan)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.TaskQueryGraphNodeQuadrant, compose.InvokableLambda(nodes.NormalizeQuadrant)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.TaskQueryGraphNodeTimeAnchor, compose.InvokableLambda(nodes.AnchorTime)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.TaskQueryGraphNodeQuery, compose.InvokableLambda(nodes.Query)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddLambdaNode(agentnode.TaskQueryGraphNodeReflect, compose.InvokableLambda(nodes.Reflect)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddEdge(compose.START, agentnode.TaskQueryGraphNodePlan); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.TaskQueryGraphNodePlan, agentnode.TaskQueryGraphNodeQuadrant); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.TaskQueryGraphNodeQuadrant, agentnode.TaskQueryGraphNodeTimeAnchor); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.TaskQueryGraphNodeTimeAnchor, agentnode.TaskQueryGraphNodeQuery); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddEdge(agentnode.TaskQueryGraphNodeQuery, agentnode.TaskQueryGraphNodeReflect); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddBranch(agentnode.TaskQueryGraphNodeReflect, compose.NewGraphBranch(nodes.NextAfterReflect, map[string]bool{
|
||||
agentnode.TaskQueryGraphNodeQuery: true,
|
||||
compose.END: true,
|
||||
})); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 5. 反思节点支持按配置重试,因此最大步数需要覆盖“首次查询 + 多轮回看”的上限。
|
||||
maxSteps := 24 + input.State.MaxReflectRetry*4
|
||||
if maxSteps < 24 {
|
||||
maxSteps = 24
|
||||
}
|
||||
runnable, err := graph.Compile(ctx,
|
||||
compose.WithGraphName(TaskQueryGraphName),
|
||||
compose.WithMaxRunSteps(maxSteps),
|
||||
compose.WithNodeTriggerMode(compose.AnyPredecessor),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
finalState, err := runnable.Invoke(ctx, input.State)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if finalState == nil {
|
||||
return "", errors.New("task query graph: final state is nil")
|
||||
}
|
||||
|
||||
// 6. 最终回复为空时给一个稳定兜底,避免上层拿到空字符串后再次拼接出异常文案。
|
||||
reply := strings.TrimSpace(finalState.FinalReply)
|
||||
if reply == "" {
|
||||
reply = "我这边暂时没整理出稳定结果,你可以换一个更具体的筛选条件再试一次。"
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
83
backend/agent/llm/ark.go
Normal file
83
backend/agent/llm/ark.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package agentllm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
einoModel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
// ArkCallOptions 是基于 ark.ChatModel 的通用调用选项。
|
||||
//
|
||||
// 设计目的:
|
||||
// 1. 当前 route / quicknote 都还直接持有 *ark.ChatModel;
|
||||
// 2. 在它们完全收敛到更抽象的 Client 前,先把重复的 ark 调用样板抽成公共层;
|
||||
// 3. 这样本轮就能先删除 route/quicknote 里那几份重复的 Generate 样板代码。
|
||||
type ArkCallOptions struct {
|
||||
Temperature float64
|
||||
MaxTokens int
|
||||
Thinking ThinkingMode
|
||||
}
|
||||
|
||||
// CallArkText 调用 ark 模型并返回纯文本。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责拼 system + user 两段消息;
|
||||
// 2. 负责统一配置 thinking / temperature / maxTokens;
|
||||
// 3. 负责拦截空响应;
|
||||
// 4. 不负责 JSON 解析,不负责业务字段校验。
|
||||
func CallArkText(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, options ArkCallOptions) (string, error) {
|
||||
if chatModel == nil {
|
||||
return "", errors.New("ark model is nil")
|
||||
}
|
||||
|
||||
messages := []*schema.Message{
|
||||
schema.SystemMessage(systemPrompt),
|
||||
schema.UserMessage(userPrompt),
|
||||
}
|
||||
resp, err := chatModel.Generate(ctx, messages, buildArkOptions(options)...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp == nil {
|
||||
return "", errors.New("模型返回为空")
|
||||
}
|
||||
|
||||
text := strings.TrimSpace(resp.Content)
|
||||
if text == "" {
|
||||
return "", errors.New("模型返回内容为空")
|
||||
}
|
||||
return text, nil
|
||||
}
|
||||
|
||||
// CallArkJSON 调用 ark 模型并直接解析 JSON。
|
||||
func CallArkJSON[T any](ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, options ArkCallOptions) (*T, string, error) {
|
||||
raw, err := CallArkText(ctx, chatModel, systemPrompt, userPrompt, options)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
parsed, err := ParseJSONObject[T](raw)
|
||||
if err != nil {
|
||||
return nil, raw, err
|
||||
}
|
||||
return parsed, raw, nil
|
||||
}
|
||||
|
||||
func buildArkOptions(options ArkCallOptions) []einoModel.Option {
|
||||
thinkingType := arkModel.ThinkingTypeDisabled
|
||||
if options.Thinking == ThinkingModeEnabled {
|
||||
thinkingType = arkModel.ThinkingTypeEnabled
|
||||
}
|
||||
opts := []einoModel.Option{
|
||||
ark.WithThinking(&arkModel.Thinking{Type: thinkingType}),
|
||||
einoModel.WithTemperature(float32(options.Temperature)),
|
||||
}
|
||||
if options.MaxTokens > 0 {
|
||||
opts = append(opts, einoModel.WithMaxTokens(options.MaxTokens))
|
||||
}
|
||||
return opts
|
||||
}
|
||||
216
backend/agent/llm/client.go
Normal file
216
backend/agent/llm/client.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package agentllm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// ThinkingMode 描述本次模型调用对 thinking 的期望。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 这里只表达“调用方希望怎样配置推理模式”;
|
||||
// 2. 不直接绑定某个具体模型厂商的参数枚举;
|
||||
// 3. 真正如何把它翻译成 ark / OpenAI / 其他 provider 的 option,由后续适配层负责。
|
||||
type ThinkingMode string
|
||||
|
||||
const (
|
||||
ThinkingModeDefault ThinkingMode = "default"
|
||||
ThinkingModeEnabled ThinkingMode = "enabled"
|
||||
ThinkingModeDisabled ThinkingMode = "disabled"
|
||||
)
|
||||
|
||||
// GenerateOptions 是 Agent 内部统一的模型调用选项。
|
||||
//
|
||||
// 设计目的:
|
||||
// 1. 先把“每个 skill 都会反复传的参数”收敛成一份结构;
|
||||
// 2. 让 node 层以后只表达“我要什么”,不再自己重复组织 option;
|
||||
// 3. 暂时不追求覆盖所有 provider 参数,先把最常用的几个公共位抽出来。
|
||||
type GenerateOptions struct {
|
||||
Temperature float64
|
||||
MaxTokens int
|
||||
Thinking ThinkingMode
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
// TextResult 是统一文本生成结果。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. Text 保存模型最终返回的纯文本;
|
||||
// 2. Usage 保存本次调用的 token 使用量,供后续统一统计;
|
||||
// 3. 不负责 JSON 解析,不负责业务字段映射。
|
||||
type TextResult struct {
|
||||
Text string
|
||||
Usage *schema.TokenUsage
|
||||
}
|
||||
|
||||
// StreamReader 抽象了“可逐块 Recv 的流式返回器”。
|
||||
//
|
||||
// 之所以不直接依赖某个具体 SDK 的 reader 类型,是因为 Agent 现在还在建骨架阶段,
|
||||
// 后续接 ark、OpenAI 兼容层还是别的 provider,都可以往这个最小接口上适配。
|
||||
type StreamReader interface {
|
||||
Recv() (*schema.Message, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// TextGenerateFunc 是文本生成的统一适配函数签名。
|
||||
type TextGenerateFunc func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error)
|
||||
|
||||
// StreamGenerateFunc 是流式生成的统一适配函数签名。
|
||||
type StreamGenerateFunc func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error)
|
||||
|
||||
// Client 是 Agent 里的统一模型客户端门面。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责把 node 层的“模型调用意图”收敛到统一入口;
|
||||
// 2. 负责统一参数校验、空响应防御、GenerateJSON 复用;
|
||||
// 3. 不负责写 prompt,不负责业务 fallback,也不直接持有具体厂商 SDK 细节。
|
||||
type Client struct {
|
||||
generateText TextGenerateFunc
|
||||
streamText StreamGenerateFunc
|
||||
}
|
||||
|
||||
// NewClient 创建统一模型客户端。
|
||||
func NewClient(generateText TextGenerateFunc, streamText StreamGenerateFunc) *Client {
|
||||
return &Client{
|
||||
generateText: generateText,
|
||||
streamText: streamText,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateText 执行一次统一文本生成。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责做最小必要的入参校验;
|
||||
// 2. 负责统一拦截“模型空响应”这类公共问题;
|
||||
// 3. 不负责业务 prompt 拼接,也不负责把文本再映射成业务结构。
|
||||
func (c *Client) GenerateText(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error) {
|
||||
if c == nil || c.generateText == nil {
|
||||
return nil, errors.New("agent llm client is not ready")
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return nil, errors.New("llm messages is empty")
|
||||
}
|
||||
|
||||
result, err := c.generateText(ctx, messages, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result == nil {
|
||||
return nil, errors.New("llm result is nil")
|
||||
}
|
||||
if strings.TrimSpace(result.Text) == "" {
|
||||
return nil, errors.New("llm returned empty text")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateJSON 先走统一文本生成,再走统一 JSON 解析。
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. 旧 agent 里每个 skill 都各自写了一份“Generate -> 提取 JSON -> 反序列化”;
|
||||
// 2. 这里先把这一整段收敛成公共链路,后续 quicknote/taskquery/schedule 都直接复用;
|
||||
// 3. 返回 parsed + rawResult,方便上层既能拿结构化字段,也能在打点/回退时保留原文。
|
||||
// 4. 这里做成泛型函数而不是方法,是因为 Go 不支持“方法自带类型参数”。
|
||||
func GenerateJSON[T any](ctx context.Context, client *Client, messages []*schema.Message, options GenerateOptions) (*T, *TextResult, error) {
|
||||
result, err := client.GenerateText(ctx, messages, options)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
parsed, err := ParseJSONObject[T](result.Text)
|
||||
if err != nil {
|
||||
return nil, result, err
|
||||
}
|
||||
return parsed, result, nil
|
||||
}
|
||||
|
||||
// Stream 打开统一流式调用入口。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只负责把“流式生成能力”暴露给上层;
|
||||
// 2. 不负责 chunk 到 OpenAI 协议的转换,那部分应放在 stream/;
|
||||
// 3. 不负责累计全文,也不负责 token 统计落库。
|
||||
func (c *Client) Stream(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error) {
|
||||
if c == nil || c.streamText == nil {
|
||||
return nil, errors.New("agent llm stream client is not ready")
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return nil, errors.New("llm messages is empty")
|
||||
}
|
||||
return c.streamText(ctx, messages, options)
|
||||
}
|
||||
|
||||
// BuildSystemUserMessages 构造最常见的“system + history + user”消息列表。
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. 这是旧 agent 中高频重复片段,几乎每个 skill 都会拼一次;
|
||||
// 2. 这里先把最稳定的消息编排方式沉淀下来,减少 node 层样板代码;
|
||||
// 3. 只做消息切片装配,不做 prompt 生成。
|
||||
func BuildSystemUserMessages(systemPrompt string, history []*schema.Message, userPrompt string) []*schema.Message {
|
||||
messages := make([]*schema.Message, 0, len(history)+2)
|
||||
if strings.TrimSpace(systemPrompt) != "" {
|
||||
messages = append(messages, schema.SystemMessage(systemPrompt))
|
||||
}
|
||||
if len(history) > 0 {
|
||||
messages = append(messages, history...)
|
||||
}
|
||||
if strings.TrimSpace(userPrompt) != "" {
|
||||
messages = append(messages, schema.UserMessage(userPrompt))
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
// CloneUsage 深拷贝 token usage,避免后续多处累加时共享同一指针。
|
||||
func CloneUsage(usage *schema.TokenUsage) *schema.TokenUsage {
|
||||
if usage == nil {
|
||||
return nil
|
||||
}
|
||||
copied := *usage
|
||||
return &copied
|
||||
}
|
||||
|
||||
// MergeUsage 合并两段 usage。
|
||||
//
|
||||
// 合并策略:
|
||||
// 1. 对“同一次调用不同流分片”的场景,取更大值作为最终值;
|
||||
// 2. 对“多次独立调用累计”的场景,应由上层显式做加法,而不是用这个函数;
|
||||
// 3. 该函数只适用于“同一次调用的分块 usage 收敛”。
|
||||
func MergeUsage(base *schema.TokenUsage, incoming *schema.TokenUsage) *schema.TokenUsage {
|
||||
if incoming == nil {
|
||||
return CloneUsage(base)
|
||||
}
|
||||
if base == nil {
|
||||
return CloneUsage(incoming)
|
||||
}
|
||||
|
||||
merged := *base
|
||||
if incoming.PromptTokens > merged.PromptTokens {
|
||||
merged.PromptTokens = incoming.PromptTokens
|
||||
}
|
||||
if incoming.CompletionTokens > merged.CompletionTokens {
|
||||
merged.CompletionTokens = incoming.CompletionTokens
|
||||
}
|
||||
if incoming.TotalTokens > merged.TotalTokens {
|
||||
merged.TotalTokens = incoming.TotalTokens
|
||||
}
|
||||
if incoming.PromptTokenDetails.CachedTokens > merged.PromptTokenDetails.CachedTokens {
|
||||
merged.PromptTokenDetails.CachedTokens = incoming.PromptTokenDetails.CachedTokens
|
||||
}
|
||||
if incoming.CompletionTokensDetails.ReasoningTokens > merged.CompletionTokensDetails.ReasoningTokens {
|
||||
merged.CompletionTokensDetails.ReasoningTokens = incoming.CompletionTokensDetails.ReasoningTokens
|
||||
}
|
||||
return &merged
|
||||
}
|
||||
|
||||
// FormatEmptyResponseError 统一生成“模型返回空结果”的错误文案。
|
||||
func FormatEmptyResponseError(scene string) error {
|
||||
scene = strings.TrimSpace(scene)
|
||||
if scene == "" {
|
||||
scene = "unknown"
|
||||
}
|
||||
return fmt.Errorf("模型在 %s 场景返回空结果", scene)
|
||||
}
|
||||
112
backend/agent/llm/json.go
Normal file
112
backend/agent/llm/json.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package agentllm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseJSONObject 解析模型返回中的 JSON 对象。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责处理“模型输出前后夹杂解释文字 / markdown 代码块”的常见情况;
|
||||
// 2. 负责提取最外层 JSON object 并反序列化为目标结构;
|
||||
// 3. 不负责业务字段合法性校验,例如 priority 是否在 1~4,应由上层 node 再校验。
|
||||
func ParseJSONObject[T any](raw string) (*T, error) {
|
||||
clean := strings.TrimSpace(raw)
|
||||
if clean == "" {
|
||||
return nil, errors.New("模型返回为空,无法解析 JSON")
|
||||
}
|
||||
|
||||
objectText := ExtractJSONObject(clean)
|
||||
if objectText == "" {
|
||||
return nil, fmt.Errorf("模型返回中未找到 JSON 对象: %s", truncateForError(clean))
|
||||
}
|
||||
|
||||
var out T
|
||||
if err := json.Unmarshal([]byte(objectText), &out); err != nil {
|
||||
return nil, fmt.Errorf("JSON 解析失败: %w", err)
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
// ExtractJSONObject 从混合文本里提取第一个完整 JSON 对象。
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. LLM 很容易输出“这里是结果:{...}”这种半结构化文本;
|
||||
// 2. 这里用括号计数而不是正则,避免嵌套对象一多就误截断;
|
||||
// 3. 目前只提取 object,不提取 array,因为当前 agent 的路由/规划契约基本都是对象。
|
||||
func ExtractJSONObject(text string) string {
|
||||
clean := trimMarkdownCodeFence(strings.TrimSpace(text))
|
||||
if clean == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
start := strings.Index(clean, "{")
|
||||
if start < 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
depth := 0
|
||||
inString := false
|
||||
escaped := false
|
||||
for idx := start; idx < len(clean); idx++ {
|
||||
ch := clean[idx]
|
||||
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' && inString {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
inString = !inString
|
||||
continue
|
||||
}
|
||||
if inString {
|
||||
continue
|
||||
}
|
||||
|
||||
switch ch {
|
||||
case '{':
|
||||
depth++
|
||||
case '}':
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return clean[start : idx+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func trimMarkdownCodeFence(text string) string {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if !strings.HasPrefix(trimmed, "```") {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
lines := strings.Split(trimmed, "\n")
|
||||
if len(lines) == 0 {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
// 1. 去掉首行 ```json / ```;
|
||||
// 2. 若末行是 ```,一并去掉;
|
||||
// 3. 中间正文保持原样,避免破坏 JSON 的换行结构。
|
||||
body := lines[1:]
|
||||
if len(body) > 0 && strings.TrimSpace(body[len(body)-1]) == "```" {
|
||||
body = body[:len(body)-1]
|
||||
}
|
||||
return strings.TrimSpace(strings.Join(body, "\n"))
|
||||
}
|
||||
|
||||
func truncateForError(text string) string {
|
||||
if len(text) <= 160 {
|
||||
return text
|
||||
}
|
||||
return text[:160] + "..."
|
||||
}
|
||||
170
backend/agent/llm/quicknote.go
Normal file
170
backend/agent/llm/quicknote.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package agentllm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
agentprompt "github.com/LoveLosita/smartflow/backend/agent/prompt"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
)
|
||||
|
||||
// QuickNoteIntentOutput 是“随口记意图识别”模型契约。
|
||||
type QuickNoteIntentOutput struct {
|
||||
IsQuickNote bool `json:"is_quick_note"`
|
||||
Title string `json:"title"`
|
||||
DeadlineAt string `json:"deadline_at"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// QuickNotePriorityOutput 是“随口记优先级评估”模型契约。
|
||||
type QuickNotePriorityOutput struct {
|
||||
PriorityGroup int `json:"priority_group"`
|
||||
Reason string `json:"reason"`
|
||||
UrgencyThresholdAt string `json:"urgency_threshold_at"`
|
||||
}
|
||||
|
||||
// QuickNotePlanOutput 是“随口记单请求聚合规划”模型契约。
|
||||
type QuickNotePlanOutput struct {
|
||||
Title string `json:"title"`
|
||||
DeadlineAt string `json:"deadline_at"`
|
||||
UrgencyThresholdAt string `json:"urgency_threshold_at"`
|
||||
PriorityGroup int `json:"priority_group"`
|
||||
PriorityReason string `json:"priority_reason"`
|
||||
Banter string `json:"banter"`
|
||||
}
|
||||
|
||||
// IdentifyQuickNoteIntent 调用模型识别“是否随口记”。
|
||||
func IdentifyQuickNoteIntent(ctx context.Context, chatModel *ark.ChatModel, nowText, userInput string) (*QuickNoteIntentOutput, error) {
|
||||
prompt := fmt.Sprintf(`当前时间(北京时间,精确到分钟):%s
|
||||
用户输入:%s
|
||||
请仅输出 JSON(不要 markdown,不要解释),字段如下:
|
||||
{
|
||||
"is_quick_note": boolean,
|
||||
"title": string,
|
||||
"deadline_at": string,
|
||||
"reason": string
|
||||
}
|
||||
字段约束:
|
||||
1) deadline_at 只允许输出绝对时间,格式必须为 "yyyy-MM-dd HH:mm"。
|
||||
2) 如果用户说了“明天/后天/下周一/今晚”等相对时间,必须基于上面的当前时间换算成绝对时间。
|
||||
3) 如果用户没有提及时间,deadline_at 输出空字符串。`,
|
||||
nowText,
|
||||
userInput,
|
||||
)
|
||||
|
||||
parsed, _, err := CallArkJSON[QuickNoteIntentOutput](ctx, chatModel, agentprompt.QuickNoteIntentPrompt, prompt, ArkCallOptions{
|
||||
Temperature: 0,
|
||||
MaxTokens: 256,
|
||||
Thinking: ThinkingModeDisabled,
|
||||
})
|
||||
return parsed, err
|
||||
}
|
||||
|
||||
// PlanQuickNotePriority 调用模型评估优先级与紧急分界线。
|
||||
func PlanQuickNotePriority(ctx context.Context, chatModel *ark.ChatModel, nowText, title, userInput, deadlineClue, deadlineText string) (*QuickNotePriorityOutput, error) {
|
||||
prompt := fmt.Sprintf(`当前时间(北京时间,精确到分钟):%s
|
||||
请对以下任务评估优先级:
|
||||
- 任务标题:%s
|
||||
- 用户原始输入:%s
|
||||
- 时间线索原文:%s
|
||||
- 归一化截止时间:%s
|
||||
|
||||
请仅输出 JSON(不要 markdown,不要解释):
|
||||
{
|
||||
"priority_group": 1|2|3|4,
|
||||
"reason": "简短理由",
|
||||
"urgency_threshold_at": "yyyy-MM-dd HH:mm 或空字符串"
|
||||
}
|
||||
|
||||
额外约束:
|
||||
1) urgency_threshold_at 表示“何时从不紧急象限自动平移到紧急象限”;
|
||||
2) 若该任务不需要自动平移,可输出空字符串;
|
||||
3) 若任务已在紧急象限(priority_group=1 或 3),优先输出空字符串;
|
||||
4) 若输出非空时间,必须是绝对时间,且不晚于归一化截止时间(若有)。`,
|
||||
nowText,
|
||||
title,
|
||||
userInput,
|
||||
deadlineClue,
|
||||
deadlineText,
|
||||
)
|
||||
|
||||
parsed, _, err := CallArkJSON[QuickNotePriorityOutput](ctx, chatModel, agentprompt.QuickNotePriorityPrompt, prompt, ArkCallOptions{
|
||||
Temperature: 0,
|
||||
MaxTokens: 256,
|
||||
Thinking: ThinkingModeDisabled,
|
||||
})
|
||||
return parsed, err
|
||||
}
|
||||
|
||||
// PlanQuickNoteInSingleCall 一次性完成标题/时间/优先级/banter 聚合规划。
|
||||
func PlanQuickNoteInSingleCall(ctx context.Context, chatModel *ark.ChatModel, nowText, userInput string) (*QuickNotePlanOutput, error) {
|
||||
prompt := fmt.Sprintf(`当前时间(北京时间,精确到分钟):%s
|
||||
用户输入:%s
|
||||
|
||||
请仅输出 JSON(不要 markdown,不要解释),字段如下:
|
||||
{
|
||||
"title": string,
|
||||
"deadline_at": string,
|
||||
"urgency_threshold_at": string,
|
||||
"priority_group": 1|2|3|4,
|
||||
"priority_reason": string,
|
||||
"banter": string
|
||||
}
|
||||
|
||||
约束:
|
||||
1) deadline_at 只允许 "yyyy-MM-dd HH:mm" 或空字符串;
|
||||
2) urgency_threshold_at 只允许 "yyyy-MM-dd HH:mm" 或空字符串;
|
||||
3) 若用户给了相对时间(如明天/今晚/下周一),必须换算为绝对时间;
|
||||
4) 若任务不需要自动平移,可让 urgency_threshold_at 为空;
|
||||
5) banter 只允许一句中文,不超过30字,不得改动任务事实。`,
|
||||
nowText,
|
||||
strings.TrimSpace(userInput),
|
||||
)
|
||||
|
||||
parsed, _, err := CallArkJSON[QuickNotePlanOutput](ctx, chatModel, agentprompt.QuickNotePlanPrompt, prompt, ArkCallOptions{
|
||||
Temperature: 0,
|
||||
MaxTokens: 220,
|
||||
Thinking: ThinkingModeDisabled,
|
||||
})
|
||||
return parsed, err
|
||||
}
|
||||
|
||||
// GenerateQuickNoteBanter 生成成功写入后的轻松跟进句。
|
||||
func GenerateQuickNoteBanter(ctx context.Context, chatModel *ark.ChatModel, userMessage, title, priorityText, deadlineText string) (string, error) {
|
||||
if chatModel == nil {
|
||||
return "", fmt.Errorf("model is nil")
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(`用户原话:%s
|
||||
已确认事实:
|
||||
- 任务标题:%s
|
||||
- %s
|
||||
- %s
|
||||
|
||||
请输出一句轻松自然的跟进话术(仅一句)。`,
|
||||
strings.TrimSpace(userMessage),
|
||||
strings.TrimSpace(title),
|
||||
strings.TrimSpace(priorityText),
|
||||
strings.TrimSpace(deadlineText),
|
||||
)
|
||||
|
||||
text, err := CallArkText(ctx, chatModel, agentprompt.QuickNoteReplyBanterPrompt, prompt, ArkCallOptions{
|
||||
Temperature: 0.7,
|
||||
MaxTokens: 72,
|
||||
Thinking: ThinkingModeDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
text = strings.TrimSpace(text)
|
||||
text = strings.Trim(text, "\"'“”‘’")
|
||||
if text == "" {
|
||||
return "", fmt.Errorf("empty content")
|
||||
}
|
||||
if idx := strings.Index(text, "\n"); idx >= 0 {
|
||||
text = strings.TrimSpace(text[:idx])
|
||||
}
|
||||
return text, nil
|
||||
}
|
||||
50
backend/agent/llm/route.go
Normal file
50
backend/agent/llm/route.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package agentllm
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
agentmodel "github.com/LoveLosita/smartflow/backend/agent/model"
|
||||
)
|
||||
|
||||
// RouteDecisionOutput 是一级路由模型的结构化输出契约。
|
||||
//
|
||||
// 说明:
|
||||
// 1. 这里只定义“模型应该吐什么 JSON”;
|
||||
// 2. 真正的 prompt 归 prompt/ 管;
|
||||
// 3. 真正的业务分发归 router/ 管。
|
||||
type RouteDecisionOutput struct {
|
||||
Action string `json:"action"`
|
||||
TrustRoute bool `json:"trust_route"`
|
||||
Detail string `json:"detail"`
|
||||
Confidence float64 `json:"confidence"`
|
||||
}
|
||||
|
||||
// ToDecision 把模型契约输出映射成 Agent 内部统一路由结果。
|
||||
func (o *RouteDecisionOutput) ToDecision() *agentmodel.RouteDecision {
|
||||
if o == nil {
|
||||
return &agentmodel.RouteDecision{Action: agentmodel.ActionChat}
|
||||
}
|
||||
|
||||
action := normalizeRouteAction(o.Action)
|
||||
return &agentmodel.RouteDecision{
|
||||
Action: action,
|
||||
TrustRoute: o.TrustRoute,
|
||||
Detail: strings.TrimSpace(o.Detail),
|
||||
Confidence: o.Confidence,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeRouteAction(raw string) agentmodel.AgentAction {
|
||||
switch strings.TrimSpace(strings.ToLower(raw)) {
|
||||
case "quick_note", "quick_note_create":
|
||||
return agentmodel.ActionQuickNoteCreate
|
||||
case "task_query":
|
||||
return agentmodel.ActionTaskQuery
|
||||
case "schedule_plan", "schedule_plan_create":
|
||||
return agentmodel.ActionSchedulePlanCreate
|
||||
case "schedule_refine", "schedule_plan_refine":
|
||||
return agentmodel.ActionSchedulePlanRefine
|
||||
default:
|
||||
return agentmodel.ActionChat
|
||||
}
|
||||
}
|
||||
175
backend/agent/llm/schedule.go
Normal file
175
backend/agent/llm/schedule.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package agentllm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
agentprompt "github.com/LoveLosita/smartflow/backend/agent/prompt"
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
einoModel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
// ScheduleIntentOutput 是 plan 节点要求模型返回的结构化结果。
|
||||
//
|
||||
// 兼容说明:
|
||||
// 1. 新主语义是 task_class_ids(数组);
|
||||
// 2. 为兼容旧 prompt/旧缓存输出,保留 task_class_id(单值)兜底解析;
|
||||
// 3. TaskTags 的 key 兼容两种写法:
|
||||
// 3.1 推荐:task_item_id(例如 "12");
|
||||
// 3.2 兼容:任务名称(例如 "高数复习")。
|
||||
type ScheduleIntentOutput struct {
|
||||
Intent string `json:"intent"`
|
||||
Constraints []string `json:"constraints"`
|
||||
TaskClassIDs []int `json:"task_class_ids"`
|
||||
TaskClassID int `json:"task_class_id"`
|
||||
Strategy string `json:"strategy"`
|
||||
TaskTags map[string]string `json:"task_tags"`
|
||||
Restart bool `json:"restart"`
|
||||
AdjustmentScope string `json:"adjustment_scope"`
|
||||
Reason string `json:"reason"`
|
||||
Confidence float64 `json:"confidence"`
|
||||
}
|
||||
|
||||
// ReactToolCall 是 LLM 输出的单个工具调用。
|
||||
type ReactToolCall struct {
|
||||
Tool string `json:"tool"`
|
||||
Params map[string]any `json:"params"`
|
||||
}
|
||||
|
||||
// ReactLLMOutput 是 ReAct 节点要求模型返回的统一 JSON。
|
||||
type ReactLLMOutput struct {
|
||||
Done bool `json:"done"`
|
||||
Summary string `json:"summary"`
|
||||
ToolCalls []ReactToolCall `json:"tool_calls"`
|
||||
}
|
||||
|
||||
// IdentifySchedulePlanIntent 调用模型识别“排程意图 + 约束 + 任务类集合”。
|
||||
func IdentifySchedulePlanIntent(
|
||||
ctx context.Context,
|
||||
chatModel *ark.ChatModel,
|
||||
nowText string,
|
||||
userMessage string,
|
||||
adjustmentHint string,
|
||||
) (*ScheduleIntentOutput, error) {
|
||||
prompt := fmt.Sprintf(
|
||||
"当前时间(北京时间):%s\n用户输入:%s%s\n\n请提取排程意图与约束。",
|
||||
strings.TrimSpace(nowText),
|
||||
strings.TrimSpace(userMessage),
|
||||
strings.TrimSpace(adjustmentHint),
|
||||
)
|
||||
|
||||
parsed, _, err := CallArkJSON[ScheduleIntentOutput](ctx, chatModel, agentprompt.SchedulePlanIntentPrompt, prompt, ArkCallOptions{
|
||||
Temperature: 0,
|
||||
MaxTokens: 256,
|
||||
Thinking: ThinkingModeDisabled,
|
||||
})
|
||||
return parsed, err
|
||||
}
|
||||
|
||||
// ParseScheduleReactOutput 解析 ReAct 节点的 JSON 输出。
|
||||
func ParseScheduleReactOutput(raw string) (*ReactLLMOutput, error) {
|
||||
return ParseJSONObject[ReactLLMOutput](raw)
|
||||
}
|
||||
|
||||
// GenerateScheduleDailyReactRound 调用模型生成“单天日内优化”的一轮决策。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只负责统一关闭 thinking、设置温度,并返回纯文本;
|
||||
// 2. 不负责工具执行,不负责结果回灌。
|
||||
func GenerateScheduleDailyReactRound(
|
||||
ctx context.Context,
|
||||
chatModel *ark.ChatModel,
|
||||
messages []*schema.Message,
|
||||
) (string, error) {
|
||||
resp, err := chatModel.Generate(
|
||||
ctx,
|
||||
messages,
|
||||
ark.WithThinking(&arkModel.Thinking{Type: arkModel.ThinkingTypeDisabled}),
|
||||
einoModel.WithTemperature(0),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp == nil {
|
||||
return "", fmt.Errorf("日内优化调用返回为空")
|
||||
}
|
||||
content := strings.TrimSpace(resp.Content)
|
||||
if content == "" {
|
||||
return "", fmt.Errorf("日内优化调用返回内容为空")
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// GenerateScheduleWeeklyReactRound 调用模型生成“单周单步优化”的一轮决策。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 周级仍保留 thinking,提高复杂排程准确率;
|
||||
// 2. 仅返回最终 content,是否透出思考流由上层决定。
|
||||
func GenerateScheduleWeeklyReactRound(
|
||||
ctx context.Context,
|
||||
chatModel *ark.ChatModel,
|
||||
messages []*schema.Message,
|
||||
) (string, error) {
|
||||
resp, err := chatModel.Generate(
|
||||
ctx,
|
||||
messages,
|
||||
ark.WithThinking(&arkModel.Thinking{Type: arkModel.ThinkingTypeEnabled}),
|
||||
einoModel.WithTemperature(0.2),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp == nil {
|
||||
return "", fmt.Errorf("周级单步调用返回为空")
|
||||
}
|
||||
content := strings.TrimSpace(resp.Content)
|
||||
if content == "" {
|
||||
return "", fmt.Errorf("周级单步调用返回内容为空")
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// GenerateScheduleHumanSummary 调用模型生成“用户可读”的最终总结。
|
||||
func GenerateScheduleHumanSummary(
|
||||
ctx context.Context,
|
||||
chatModel *ark.ChatModel,
|
||||
entries []model.HybridScheduleEntry,
|
||||
constraints []string,
|
||||
actionLogs []string,
|
||||
) (string, error) {
|
||||
if chatModel == nil {
|
||||
return "", fmt.Errorf("final summary model is nil")
|
||||
}
|
||||
|
||||
entriesJSON, _ := json.Marshal(entries)
|
||||
constraintText := "无"
|
||||
if len(constraints) > 0 {
|
||||
constraintText = strings.Join(constraints, "、")
|
||||
}
|
||||
actionLogText := "无"
|
||||
if len(actionLogs) > 0 {
|
||||
start := 0
|
||||
if len(actionLogs) > 30 {
|
||||
start = len(actionLogs) - 30
|
||||
}
|
||||
actionLogText = strings.Join(actionLogs[start:], "\n")
|
||||
}
|
||||
|
||||
userPrompt := fmt.Sprintf(
|
||||
"以下是最终排程方案(JSON):\n%s\n\n用户约束:%s\n\n以下是本次周级优化动作日志(按时间顺序):\n%s\n\n请基于“结果+过程”输出2-3句自然中文总结,重点说明本方案的优点和改进点。",
|
||||
string(entriesJSON),
|
||||
constraintText,
|
||||
actionLogText,
|
||||
)
|
||||
|
||||
return CallArkText(ctx, chatModel, agentprompt.SchedulePlanFinalCheckPrompt, userPrompt, ArkCallOptions{
|
||||
Temperature: 0.4,
|
||||
MaxTokens: 256,
|
||||
Thinking: ThinkingModeDisabled,
|
||||
})
|
||||
}
|
||||
132
backend/agent/llm/schedule_refine.go
Normal file
132
backend/agent/llm/schedule_refine.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package agentllm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
)
|
||||
|
||||
const scheduleRefineNodeTimeout = 120 * time.Second
|
||||
|
||||
type ScheduleRefineContractOutput struct {
|
||||
Intent string `json:"intent"`
|
||||
Strategy string `json:"strategy"`
|
||||
HardRequirements []string `json:"hard_requirements"`
|
||||
HardAssertions []ScheduleRefineAssertionLite `json:"hard_assertions"`
|
||||
KeepRelativeOrder bool `json:"keep_relative_order"`
|
||||
OrderScope string `json:"order_scope"`
|
||||
}
|
||||
|
||||
type ScheduleRefineAssertionLite struct {
|
||||
Metric string `json:"metric"`
|
||||
Operator string `json:"operator"`
|
||||
Value int `json:"value"`
|
||||
Min int `json:"min"`
|
||||
Max int `json:"max"`
|
||||
Week int `json:"week"`
|
||||
TargetWeek int `json:"target_week"`
|
||||
}
|
||||
|
||||
type ScheduleRefinePlannerOutput struct {
|
||||
Summary string `json:"summary"`
|
||||
Steps []string `json:"steps"`
|
||||
}
|
||||
|
||||
type ScheduleRefineToolCall struct {
|
||||
Tool string `json:"tool"`
|
||||
Params map[string]any `json:"params"`
|
||||
}
|
||||
|
||||
type ScheduleRefineReactOutput struct {
|
||||
Done bool `json:"done"`
|
||||
Summary string `json:"summary"`
|
||||
GoalCheck string `json:"goal_check"`
|
||||
Decision string `json:"decision"`
|
||||
MissingInfo []string `json:"missing_info,omitempty"`
|
||||
ToolCalls []ScheduleRefineToolCall `json:"tool_calls"`
|
||||
}
|
||||
|
||||
type ScheduleRefinePostReflectOutput struct {
|
||||
Reflection string `json:"reflection"`
|
||||
NextStrategy string `json:"next_strategy"`
|
||||
ShouldStop bool `json:"should_stop"`
|
||||
}
|
||||
|
||||
type ScheduleRefineReviewOutput struct {
|
||||
Pass bool `json:"pass"`
|
||||
Reason string `json:"reason"`
|
||||
Unmet []string `json:"unmet"`
|
||||
}
|
||||
|
||||
func GenerateScheduleRefineContract(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string) (*ScheduleRefineContractOutput, string, error) {
|
||||
return callScheduleRefineJSON[ScheduleRefineContractOutput](ctx, chatModel, systemPrompt, userPrompt, ArkCallOptions{
|
||||
Temperature: 0,
|
||||
MaxTokens: 260,
|
||||
Thinking: ThinkingModeDisabled,
|
||||
})
|
||||
}
|
||||
|
||||
func GenerateScheduleRefinePlanner(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, maxTokens int) (*ScheduleRefinePlannerOutput, string, error) {
|
||||
return callScheduleRefineJSON[ScheduleRefinePlannerOutput](ctx, chatModel, systemPrompt, userPrompt, ArkCallOptions{
|
||||
Temperature: 0,
|
||||
MaxTokens: maxTokens,
|
||||
Thinking: ThinkingModeDisabled,
|
||||
})
|
||||
}
|
||||
|
||||
func GenerateScheduleRefineReact(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, useThinking bool, maxTokens int) (string, error) {
|
||||
thinking := ThinkingModeDisabled
|
||||
if useThinking {
|
||||
thinking = ThinkingModeEnabled
|
||||
}
|
||||
return callScheduleRefineText(ctx, chatModel, systemPrompt, userPrompt, ArkCallOptions{
|
||||
Temperature: 0,
|
||||
MaxTokens: maxTokens,
|
||||
Thinking: thinking,
|
||||
})
|
||||
}
|
||||
|
||||
func GenerateScheduleRefinePostReflect(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string) (*ScheduleRefinePostReflectOutput, string, error) {
|
||||
return callScheduleRefineJSON[ScheduleRefinePostReflectOutput](ctx, chatModel, systemPrompt, userPrompt, ArkCallOptions{
|
||||
Temperature: 0,
|
||||
MaxTokens: 220,
|
||||
Thinking: ThinkingModeDisabled,
|
||||
})
|
||||
}
|
||||
|
||||
func GenerateScheduleRefineReview(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string) (*ScheduleRefineReviewOutput, string, error) {
|
||||
return callScheduleRefineJSON[ScheduleRefineReviewOutput](ctx, chatModel, systemPrompt, userPrompt, ArkCallOptions{
|
||||
Temperature: 0,
|
||||
MaxTokens: 240,
|
||||
Thinking: ThinkingModeDisabled,
|
||||
})
|
||||
}
|
||||
|
||||
func GenerateScheduleRefineSummary(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string) (string, error) {
|
||||
return callScheduleRefineText(ctx, chatModel, systemPrompt, userPrompt, ArkCallOptions{
|
||||
Temperature: 0.35,
|
||||
MaxTokens: 280,
|
||||
Thinking: ThinkingModeDisabled,
|
||||
})
|
||||
}
|
||||
|
||||
func GenerateScheduleRefineRepair(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string) (string, error) {
|
||||
return callScheduleRefineText(ctx, chatModel, systemPrompt, userPrompt, ArkCallOptions{
|
||||
Temperature: 0.15,
|
||||
MaxTokens: 240,
|
||||
Thinking: ThinkingModeDisabled,
|
||||
})
|
||||
}
|
||||
|
||||
func callScheduleRefineText(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, options ArkCallOptions) (string, error) {
|
||||
nodeCtx, cancel := context.WithTimeout(ctx, scheduleRefineNodeTimeout)
|
||||
defer cancel()
|
||||
return CallArkText(nodeCtx, chatModel, systemPrompt, userPrompt, options)
|
||||
}
|
||||
|
||||
func callScheduleRefineJSON[T any](ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, options ArkCallOptions) (*T, string, error) {
|
||||
nodeCtx, cancel := context.WithTimeout(ctx, scheduleRefineNodeTimeout)
|
||||
defer cancel()
|
||||
return CallArkJSON[T](nodeCtx, chatModel, systemPrompt, userPrompt, options)
|
||||
}
|
||||
83
backend/agent/llm/taskquery.go
Normal file
83
backend/agent/llm/taskquery.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package agentllm
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
agentprompt "github.com/LoveLosita/smartflow/backend/agent/prompt"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
)
|
||||
|
||||
// TaskQueryPlanOutput 描述计划节点返回的结构化查询方案。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只承接模型输出,不在这里做合法性校验。
|
||||
// 2. 字段为空或非法时,由 node 层继续归一化与兜底。
|
||||
type TaskQueryPlanOutput struct {
|
||||
UserGoal string `json:"user_goal"`
|
||||
Quadrants []int `json:"quadrants"`
|
||||
SortBy string `json:"sort_by"`
|
||||
Order string `json:"order"`
|
||||
Limit int `json:"limit"`
|
||||
IncludeCompleted *bool `json:"include_completed"`
|
||||
Keyword string `json:"keyword"`
|
||||
DeadlineBefore string `json:"deadline_before"`
|
||||
DeadlineAfter string `json:"deadline_after"`
|
||||
}
|
||||
|
||||
// TaskQueryRetryPatch 描述反思节点允许回写的计划补丁。
|
||||
//
|
||||
// 输入输出语义:
|
||||
// 1. 指针字段为 nil 表示“不改这个字段”。
|
||||
// 2. 非 nil 但值为空字符串,表示显式清空该条件。
|
||||
type TaskQueryRetryPatch struct {
|
||||
Quadrants *[]int `json:"quadrants,omitempty"`
|
||||
SortBy *string `json:"sort_by,omitempty"`
|
||||
Order *string `json:"order,omitempty"`
|
||||
Limit *int `json:"limit,omitempty"`
|
||||
IncludeCompleted *bool `json:"include_completed,omitempty"`
|
||||
Keyword *string `json:"keyword,omitempty"`
|
||||
DeadlineBefore *string `json:"deadline_before,omitempty"`
|
||||
DeadlineAfter *string `json:"deadline_after,omitempty"`
|
||||
}
|
||||
|
||||
// TaskQueryReflectOutput 描述反思节点对本轮查询结果的判定。
|
||||
//
|
||||
// 输入输出语义:
|
||||
// 1. Satisfied=true 表示当前结果可直接收口。
|
||||
// 2. NeedRetry=true 表示建议再跑一轮,但真正是否重试由 node 层结合次数上限决定。
|
||||
// 3. Reply 是可直接给用户的候选文案,允许为空。
|
||||
type TaskQueryReflectOutput struct {
|
||||
Satisfied bool `json:"satisfied"`
|
||||
NeedRetry bool `json:"need_retry"`
|
||||
Reason string `json:"reason"`
|
||||
Reply string `json:"reply"`
|
||||
RetryPatch TaskQueryRetryPatch `json:"retry_patch"`
|
||||
}
|
||||
|
||||
// PlanTaskQuery 负责调用模型,把自然语言查询规划成结构化检索参数。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只负责模型调用与 JSON 解析。
|
||||
// 2. 不负责结果兜底、限流裁剪或时间归一化。
|
||||
func PlanTaskQuery(ctx context.Context, chatModel *ark.ChatModel, nowText, userInput string) (*TaskQueryPlanOutput, error) {
|
||||
parsed, _, err := CallArkJSON[TaskQueryPlanOutput](ctx, chatModel, agentprompt.TaskQueryPlanPrompt, agentprompt.BuildTaskQueryPlanUserPrompt(nowText, userInput), ArkCallOptions{
|
||||
Temperature: 0,
|
||||
MaxTokens: 260,
|
||||
Thinking: ThinkingModeDisabled,
|
||||
})
|
||||
return parsed, err
|
||||
}
|
||||
|
||||
// ReflectTaskQuery 负责让模型判断当前查询结果是否满足用户意图。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只负责反思提示词调用与结构化解析。
|
||||
// 2. 不负责实际执行重试,也不负责拼接最终兜底回复。
|
||||
func ReflectTaskQuery(ctx context.Context, chatModel *ark.ChatModel, prompt string) (*TaskQueryReflectOutput, error) {
|
||||
parsed, _, err := CallArkJSON[TaskQueryReflectOutput](ctx, chatModel, agentprompt.TaskQueryReflectPrompt, prompt, ArkCallOptions{
|
||||
Temperature: 0,
|
||||
MaxTokens: 380,
|
||||
Thinking: ThinkingModeDisabled,
|
||||
})
|
||||
return parsed, err
|
||||
}
|
||||
17
backend/agent/model/common.go
Normal file
17
backend/agent/model/common.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package agentmodel
|
||||
|
||||
// AgentRequest 是 Agent 总入口接收的统一请求结构。
|
||||
type AgentRequest struct {
|
||||
UserID int
|
||||
ConversationID string
|
||||
UserMessage string
|
||||
ModelName string
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
// AgentResponse 是 Agent 总入口返回的统一响应结构。
|
||||
type AgentResponse struct {
|
||||
Action AgentAction
|
||||
Reply string
|
||||
Meta map[string]any
|
||||
}
|
||||
102
backend/agent/model/quicknote.go
Normal file
102
backend/agent/model/quicknote.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package agentmodel
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
agentshared "github.com/LoveLosita/smartflow/backend/agent/shared"
|
||||
)
|
||||
|
||||
const (
|
||||
// QuickNoteDatetimeMinuteLayout 是随口记链路统一使用的分钟级时间格式。
|
||||
QuickNoteDatetimeMinuteLayout = "2006-01-02 15:04"
|
||||
// QuickNoteTimezoneName 是随口记时间解析与展示优先使用的时区。
|
||||
QuickNoteTimezoneName = "Asia/Shanghai"
|
||||
|
||||
QuickNotePriorityImportantUrgent = TaskPriorityImportantUrgent
|
||||
QuickNotePriorityImportantNotUrgent = TaskPriorityImportantNotUrgent
|
||||
QuickNotePrioritySimpleNotImportant = TaskPrioritySimpleNotImportant
|
||||
QuickNotePriorityComplexNotImportant = TaskPriorityComplexNotImportant
|
||||
)
|
||||
|
||||
// QuickNoteState 是随口记图在节点间流转的完整状态。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责保存意图识别、任务提取、工具重试和最终回复所需状态。
|
||||
// 2. 不负责图编排,也不直接映射数据库任务实体。
|
||||
type QuickNoteState struct {
|
||||
TraceID string
|
||||
UserID int
|
||||
ConversationID string
|
||||
RequestNow time.Time
|
||||
RequestNowText string
|
||||
UserInput string
|
||||
|
||||
IsQuickNoteIntent bool
|
||||
IntentJudgeReason string
|
||||
|
||||
ExtractedTitle string
|
||||
ExtractedDeadline *time.Time
|
||||
ExtractedDeadlineText string
|
||||
ExtractedUrgencyThreshold *time.Time
|
||||
ExtractedPriority int
|
||||
ExtractedBanter string
|
||||
PlannedBySingleCall bool
|
||||
ExtractedPriorityReason string
|
||||
DeadlineValidationError string
|
||||
|
||||
ToolAttemptCount int
|
||||
MaxToolRetry int
|
||||
LastToolError string
|
||||
PersistedTaskID int
|
||||
Persisted bool
|
||||
AssistantReply string
|
||||
}
|
||||
|
||||
// NewQuickNoteState 负责创建随口记图的初始状态。
|
||||
//
|
||||
// 输入输出语义:
|
||||
// 1. RequestNow 与 RequestNowText 会在创建时同步写入,保证整条链路共用同一时间基准。
|
||||
// 2. MaxToolRetry 默认给 3,避免上层未配置时完全失去重试能力。
|
||||
func NewQuickNoteState(traceID string, userID int, conversationID, userInput string) *QuickNoteState {
|
||||
requestNow := agentshared.NowToMinute()
|
||||
return &QuickNoteState{
|
||||
TraceID: traceID,
|
||||
UserID: userID,
|
||||
ConversationID: conversationID,
|
||||
RequestNow: requestNow,
|
||||
RequestNowText: agentshared.FormatMinute(requestNow),
|
||||
UserInput: userInput,
|
||||
MaxToolRetry: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// CanRetryTool 返回当前是否还允许再次调用持久化工具。
|
||||
//
|
||||
// 输入输出语义:
|
||||
// 1. true 表示“尚未达到最大重试次数”,调用方仍可继续重试。
|
||||
// 2. false 表示必须收口,避免无限重试。
|
||||
func (s *QuickNoteState) CanRetryTool() bool {
|
||||
return s.ToolAttemptCount < s.MaxToolRetry
|
||||
}
|
||||
|
||||
// RecordToolError 记录一次工具失败,并推进重试计数。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只更新与工具失败相关的状态。
|
||||
// 2. 不决定是否继续重试,是否重试由节点分支逻辑判断。
|
||||
func (s *QuickNoteState) RecordToolError(errMsg string) {
|
||||
s.ToolAttemptCount++
|
||||
s.LastToolError = errMsg
|
||||
}
|
||||
|
||||
// RecordToolSuccess 记录一次工具成功结果。
|
||||
//
|
||||
// 输入输出语义:
|
||||
// 1. taskID 必须是持久化后的真实任务 ID。
|
||||
// 2. 成功后会清空 LastToolError,表示当前链路已进入稳定态。
|
||||
func (s *QuickNoteState) RecordToolSuccess(taskID int) {
|
||||
s.ToolAttemptCount++
|
||||
s.PersistedTaskID = taskID
|
||||
s.Persisted = true
|
||||
s.LastToolError = ""
|
||||
}
|
||||
20
backend/agent/model/route.go
Normal file
20
backend/agent/model/route.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package agentmodel
|
||||
|
||||
// AgentAction 表示一级路由动作。
|
||||
type AgentAction string
|
||||
|
||||
const (
|
||||
ActionChat AgentAction = "chat"
|
||||
ActionQuickNoteCreate AgentAction = "quick_note_create"
|
||||
ActionTaskQuery AgentAction = "task_query"
|
||||
ActionSchedulePlanCreate AgentAction = "schedule_plan_create"
|
||||
ActionSchedulePlanRefine AgentAction = "schedule_plan_refine"
|
||||
)
|
||||
|
||||
// RouteDecision 是统一一级分流结果。
|
||||
type RouteDecision struct {
|
||||
Action AgentAction
|
||||
TrustRoute bool
|
||||
Detail string
|
||||
Confidence float64
|
||||
}
|
||||
200
backend/agent/model/schedule.go
Normal file
200
backend/agent/model/schedule.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package agentmodel
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
)
|
||||
|
||||
const (
|
||||
// SchedulePlanTimezoneName 是排程链路默认业务时区。
|
||||
// 与随口记保持一致,固定东八区,避免容器运行在 UTC 导致“明天/今晚”偏移。
|
||||
SchedulePlanTimezoneName = "Asia/Shanghai"
|
||||
|
||||
// SchedulePlanDatetimeLayout 是排程链路内部统一的分钟级时间格式。
|
||||
SchedulePlanDatetimeLayout = "2006-01-02 15:04"
|
||||
|
||||
// SchedulePlanDefaultDailyRefineConcurrency 是日内并发优化默认并发度。
|
||||
// 这里给一个保守默认值,避免未配置时直接把模型并发打满导致限流。
|
||||
SchedulePlanDefaultDailyRefineConcurrency = 3
|
||||
|
||||
// SchedulePlanDefaultWeeklyAdjustBudget 是周级配平默认调整额度。
|
||||
// 额度存在的目的:
|
||||
// 1. 防止周级 ReAct 过度调整导致震荡;
|
||||
// 2. 控制 token 与时延成本;
|
||||
// 3. 让方案改动更可解释。
|
||||
SchedulePlanDefaultWeeklyAdjustBudget = 5
|
||||
|
||||
// SchedulePlanDefaultWeeklyTotalBudget 是周级“总尝试次数”默认预算。
|
||||
//
|
||||
// 设计意图:
|
||||
// 1. 总预算统计“动作尝试次数”(成功/失败都记一次);
|
||||
// 2. 有效预算统计“成功动作次数”(仅成功时记一次);
|
||||
// 3. 通过双预算把“探索次数”和“有效改动次数”分离,降低模型无效空转成本。
|
||||
SchedulePlanDefaultWeeklyTotalBudget = 8
|
||||
|
||||
// SchedulePlanDefaultWeeklyRefineConcurrency 是周级“按周并发”默认并发度。
|
||||
// 说明:
|
||||
// 1. 周级输入规模通常比单天更大,默认并发度不宜过高,避免触发模型侧限流;
|
||||
// 2. 可在运行时按请求状态覆盖。
|
||||
SchedulePlanDefaultWeeklyRefineConcurrency = 2
|
||||
|
||||
// SchedulePlanAdjustmentScopeSmall 表示“小改动微调”。
|
||||
// 语义:优先走快速路径,只做轻量周级调整。
|
||||
SchedulePlanAdjustmentScopeSmall = "small"
|
||||
// SchedulePlanAdjustmentScopeMedium 表示“中等改动微调”。
|
||||
// 语义:跳过日内拆分,直接进入周级配平。
|
||||
SchedulePlanAdjustmentScopeMedium = "medium"
|
||||
// SchedulePlanAdjustmentScopeLarge 表示“大改动重排”。
|
||||
// 语义:必要时重新走全量路径(日内并发 + 周级配平)。
|
||||
SchedulePlanAdjustmentScopeLarge = "large"
|
||||
)
|
||||
|
||||
const (
|
||||
schedulePlanTimezoneName = SchedulePlanTimezoneName
|
||||
schedulePlanDatetimeLayout = SchedulePlanDatetimeLayout
|
||||
schedulePlanDefaultDailyRefineConcurrency = SchedulePlanDefaultDailyRefineConcurrency
|
||||
schedulePlanDefaultWeeklyAdjustBudget = SchedulePlanDefaultWeeklyAdjustBudget
|
||||
schedulePlanDefaultWeeklyTotalBudget = SchedulePlanDefaultWeeklyTotalBudget
|
||||
schedulePlanDefaultWeeklyRefineConcurrency = SchedulePlanDefaultWeeklyRefineConcurrency
|
||||
schedulePlanAdjustmentScopeSmall = SchedulePlanAdjustmentScopeSmall
|
||||
schedulePlanAdjustmentScopeMedium = SchedulePlanAdjustmentScopeMedium
|
||||
schedulePlanAdjustmentScopeLarge = SchedulePlanAdjustmentScopeLarge
|
||||
)
|
||||
|
||||
// DayGroup 是“按天拆分后”的最小优化单元。
|
||||
//
|
||||
// 设计目的:
|
||||
// 1. 把全量周视角数据拆成“单天小包”,降低日内 ReAct 输入规模;
|
||||
// 2. 支持并发优化不同天的数据,缩短整体等待;
|
||||
// 3. 通过 SkipRefine 让低收益天数直接跳过,节省模型调用成本。
|
||||
type DayGroup struct {
|
||||
Week int
|
||||
DayOfWeek int
|
||||
Entries []model.HybridScheduleEntry
|
||||
SkipRefine bool
|
||||
}
|
||||
|
||||
// SchedulePlanState 是“智能排程”链路在 graph 节点间传递的统一状态容器。
|
||||
//
|
||||
// 设计目标:
|
||||
// 1) 收拢排程请求全生命周期的上下文,降低节点间参数散落;
|
||||
// 2) 支持“粗排 -> 日内并发优化 -> 周级配平 -> 终审校验”的完整链路追踪;
|
||||
// 3) 支持连续对话微调:保留上版方案 + 本次约束变更,便于增量重排。
|
||||
type SchedulePlanState struct {
|
||||
// ── 基础上下文 ──
|
||||
TraceID string
|
||||
UserID int
|
||||
ConversationID string
|
||||
RequestNow time.Time
|
||||
RequestNowText string
|
||||
|
||||
// ── plan 节点输出 ──
|
||||
UserIntent string
|
||||
Constraints []string
|
||||
TaskClassIDs []int
|
||||
Strategy string
|
||||
TaskTags map[int]string
|
||||
TaskTagHintsByName map[string]string
|
||||
|
||||
// ── preview 节点输出 ──
|
||||
CandidatePlans []model.UserWeekSchedule
|
||||
AllocatedItems []model.TaskClassItem
|
||||
HasPlanningWindow bool
|
||||
PlanStartWeek int
|
||||
PlanStartDay int
|
||||
PlanEndWeek int
|
||||
PlanEndDay int
|
||||
|
||||
// ── 日内并发优化阶段 ──
|
||||
DailyGroups map[int]map[int]*DayGroup
|
||||
DailyResults map[int]map[int][]model.HybridScheduleEntry
|
||||
DailyRefineConcurrency int
|
||||
|
||||
// ── 周级 ReAct 精排阶段 ──
|
||||
HybridEntries []model.HybridScheduleEntry
|
||||
MergeSnapshot []model.HybridScheduleEntry
|
||||
ReactRound int
|
||||
ReactMaxRound int
|
||||
ReactSummary string
|
||||
ReactDone bool
|
||||
WeeklyAdjustBudget int
|
||||
WeeklyAdjustUsed int
|
||||
WeeklyTotalBudget int
|
||||
WeeklyTotalUsed int
|
||||
WeeklyRefineConcurrency int
|
||||
WeeklyActionLogs []string
|
||||
|
||||
// ── 连续对话微调 ──
|
||||
PreviousPlanJSON string
|
||||
IsAdjustment bool
|
||||
RestartRequested bool
|
||||
AdjustmentScope string
|
||||
AdjustmentReason string
|
||||
AdjustmentConfidence float64
|
||||
HasPreviousPreview bool
|
||||
PreviousTaskClassIDs []int
|
||||
PreviousHybridEntries []model.HybridScheduleEntry
|
||||
PreviousAllocatedItems []model.TaskClassItem
|
||||
PreviousCandidatePlans []model.UserWeekSchedule
|
||||
|
||||
// ── 最终输出 ──
|
||||
FinalSummary string
|
||||
Completed bool
|
||||
}
|
||||
|
||||
// NewSchedulePlanState 创建排程状态对象并初始化默认值。
|
||||
func NewSchedulePlanState(traceID string, userID int, conversationID string) *SchedulePlanState {
|
||||
now := schedulePlanNowToMinute()
|
||||
return &SchedulePlanState{
|
||||
TraceID: traceID,
|
||||
UserID: userID,
|
||||
ConversationID: conversationID,
|
||||
RequestNow: now,
|
||||
RequestNowText: now.In(schedulePlanLocation()).Format(schedulePlanDatetimeLayout),
|
||||
Strategy: "steady",
|
||||
TaskTags: make(map[int]string),
|
||||
TaskTagHintsByName: make(map[string]string),
|
||||
DailyRefineConcurrency: schedulePlanDefaultDailyRefineConcurrency,
|
||||
WeeklyRefineConcurrency: schedulePlanDefaultWeeklyRefineConcurrency,
|
||||
AdjustmentScope: schedulePlanAdjustmentScopeLarge,
|
||||
ReactMaxRound: 2,
|
||||
WeeklyAdjustBudget: schedulePlanDefaultWeeklyAdjustBudget,
|
||||
WeeklyTotalBudget: schedulePlanDefaultWeeklyTotalBudget,
|
||||
}
|
||||
}
|
||||
|
||||
// NormalizeSchedulePlanAdjustmentScope 归一化排程微调力度字段。
|
||||
//
|
||||
// 兜底策略:
|
||||
// 1. 只接受 small/medium/large;
|
||||
// 2. 任何未知值都回退为 large,保证不会误走“过轻”路径。
|
||||
func NormalizeSchedulePlanAdjustmentScope(raw string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case schedulePlanAdjustmentScopeSmall:
|
||||
return schedulePlanAdjustmentScopeSmall
|
||||
case schedulePlanAdjustmentScopeMedium:
|
||||
return schedulePlanAdjustmentScopeMedium
|
||||
default:
|
||||
return schedulePlanAdjustmentScopeLarge
|
||||
}
|
||||
}
|
||||
|
||||
// schedulePlanLocation 返回排程链路使用的业务时区。
|
||||
func schedulePlanLocation() *time.Location {
|
||||
loc, err := time.LoadLocation(schedulePlanTimezoneName)
|
||||
if err != nil {
|
||||
return time.Local
|
||||
}
|
||||
return loc
|
||||
}
|
||||
|
||||
// schedulePlanNowToMinute 返回当前时间并截断到分钟级。
|
||||
func schedulePlanNowToMinute() time.Time {
|
||||
return time.Now().In(schedulePlanLocation()).Truncate(time.Minute)
|
||||
}
|
||||
|
||||
func normalizeAdjustmentScope(raw string) string {
|
||||
return NormalizeSchedulePlanAdjustmentScope(raw)
|
||||
}
|
||||
@@ -1,29 +1,35 @@
|
||||
package schedulerefine
|
||||
package agentmodel
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
agentshared "github.com/LoveLosita/smartflow/backend/agent/shared"
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
)
|
||||
|
||||
const (
|
||||
// 固定业务时区,避免“今天/明天”在容器默认时区下偏移。
|
||||
timezoneName = "Asia/Shanghai"
|
||||
// 统一分钟级时间文本格式。
|
||||
datetimeLayout = "2006-01-02 15:04"
|
||||
datetimeLayout = agentshared.MinuteLayout
|
||||
|
||||
// 预算默认值。
|
||||
defaultPlanMax = 2
|
||||
defaultExecuteMax = 24
|
||||
defaultPerTaskBudget = 4
|
||||
defaultReplanMax = 2
|
||||
defaultCompositeRetry = 2
|
||||
defaultRepairReserve = 1
|
||||
ScheduleRefineDefaultPlanMax = 2
|
||||
ScheduleRefineDefaultExecuteMax = 24
|
||||
ScheduleRefineDefaultPerTaskBudget = 4
|
||||
ScheduleRefineDefaultReplanMax = 2
|
||||
ScheduleRefineDefaultCompositeRetry = 2
|
||||
ScheduleRefineDefaultRepairReserve = 1
|
||||
)
|
||||
|
||||
// RefineContract 表示本轮微调意图契约。
|
||||
const (
|
||||
defaultPlanMax = ScheduleRefineDefaultPlanMax
|
||||
defaultExecuteMax = ScheduleRefineDefaultExecuteMax
|
||||
defaultPerTaskBudget = ScheduleRefineDefaultPerTaskBudget
|
||||
defaultReplanMax = ScheduleRefineDefaultReplanMax
|
||||
defaultCompositeRetry = ScheduleRefineDefaultCompositeRetry
|
||||
defaultRepairReserve = ScheduleRefineDefaultRepairReserve
|
||||
)
|
||||
|
||||
// RefineContract 琛ㄧず鏈疆寰皟鎰忓浘濂戠害銆?
|
||||
type RefineContract struct {
|
||||
Intent string `json:"intent"`
|
||||
Strategy string `json:"strategy"`
|
||||
@@ -33,13 +39,13 @@ type RefineContract struct {
|
||||
OrderScope string `json:"order_scope"`
|
||||
}
|
||||
|
||||
// RefineAssertion 表示可由后端直接判定的结构化硬断言。
|
||||
// RefineAssertion 琛ㄧず鍙敱鍚庣鐩存帴鍒ゅ畾鐨勭粨鏋勫寲纭柇瑷€銆?
|
||||
//
|
||||
// 字段说明:
|
||||
// 1. Metric:断言指标名,例如 source_move_ratio_percent;
|
||||
// 2. Operator:比较操作符,支持 == / <= / >= / between;
|
||||
// 3. Value/Min/Max:阈值;
|
||||
// 4. Week/TargetWeek:可选周次上下文。
|
||||
// 瀛楁璇存槑锛?
|
||||
// 1. Metric锛氭柇瑷€鎸囨爣鍚嶏紝渚嬪 source_move_ratio_percent锛?
|
||||
// 2. Operator锛氭瘮杈冩搷浣滅锛屾敮鎸?== / <= / >= / between锛?
|
||||
// 3. Value/Min/Max锛氶槇鍊硷紱
|
||||
// 4. Week/TargetWeek锛氬彲閫夊懆娆′笂涓嬫枃銆?
|
||||
type RefineAssertion struct {
|
||||
Metric string `json:"metric"`
|
||||
Operator string `json:"operator"`
|
||||
@@ -50,7 +56,7 @@ type RefineAssertion struct {
|
||||
TargetWeek int `json:"target_week,omitempty"`
|
||||
}
|
||||
|
||||
// HardCheckReport 表示终审硬校验结果。
|
||||
// HardCheckReport 琛ㄧず缁堝纭牎楠岀粨鏋溿€?
|
||||
type HardCheckReport struct {
|
||||
PhysicsPassed bool `json:"physics_passed"`
|
||||
PhysicsIssues []string `json:"physics_issues,omitempty"`
|
||||
@@ -65,7 +71,7 @@ type HardCheckReport struct {
|
||||
RepairTried bool `json:"repair_tried"`
|
||||
}
|
||||
|
||||
// ReactRoundObservation 记录每轮 ReAct 的关键观察。
|
||||
// ReactRoundObservation 璁板綍姣忚疆 ReAct 鐨勫叧閿瀵熴€?
|
||||
type ReactRoundObservation struct {
|
||||
Round int `json:"round"`
|
||||
GoalCheck string `json:"goal_check,omitempty"`
|
||||
@@ -78,13 +84,13 @@ type ReactRoundObservation struct {
|
||||
Reflect string `json:"reflect,omitempty"`
|
||||
}
|
||||
|
||||
// PlannerPlan 表示 Planner 生成的阶段执行计划。
|
||||
// PlannerPlan 琛ㄧず Planner 鐢熸垚鐨勯樁娈垫墽琛岃鍒掋€?
|
||||
type PlannerPlan struct {
|
||||
Summary string `json:"summary"`
|
||||
Steps []string `json:"steps,omitempty"`
|
||||
}
|
||||
|
||||
// RefineSlicePlan 表示切片节点输出。
|
||||
// RefineSlicePlan 琛ㄧず鍒囩墖鑺傜偣杈撳嚭銆?
|
||||
type RefineSlicePlan struct {
|
||||
WeekFilter []int `json:"week_filter,omitempty"`
|
||||
SourceDays []int `json:"source_days,omitempty"`
|
||||
@@ -93,12 +99,12 @@ type RefineSlicePlan struct {
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// RefineObjective 表示“可执行且可校验”的目标约束。
|
||||
// RefineObjective 琛ㄧず鈥滃彲鎵ц涓斿彲鏍¢獙鈥濈殑鐩爣绾︽潫銆?
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. 由 contract/slice 从自然语言编译得到;
|
||||
// 2. 执行阶段(done 收口)与终审阶段(hard_check)共用同一份约束;
|
||||
// 3. 避免“执行逻辑与终审逻辑各说各话”。
|
||||
// 璁捐璇存槑锛?
|
||||
// 1. 鐢?contract/slice 浠庤嚜鐒惰瑷€缂栬瘧寰楀埌锛?
|
||||
// 2. 鎵ц闃舵锛坉one 鏀跺彛锛変笌缁堝闃舵锛坔ard_check锛夊叡鐢ㄥ悓涓€浠界害鏉燂紱
|
||||
// 3. 閬垮厤鈥滄墽琛岄€昏緫涓庣粓瀹¢€昏緫鍚勮鍚勮瘽鈥濄€?
|
||||
type RefineObjective struct {
|
||||
Mode string `json:"mode,omitempty"` // none | move_all | move_ratio
|
||||
|
||||
@@ -116,9 +122,9 @@ type RefineObjective struct {
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// ScheduleRefineState 是连续微调图的统一状态。
|
||||
// ScheduleRefineState 鏄繛缁井璋冨浘鐨勭粺涓€鐘舵€併€?
|
||||
type ScheduleRefineState struct {
|
||||
// 1) 请求上下文
|
||||
// 1) 璇锋眰涓婁笅鏂?
|
||||
TraceID string
|
||||
UserID int
|
||||
ConversationID string
|
||||
@@ -126,19 +132,19 @@ type ScheduleRefineState struct {
|
||||
RequestNow time.Time
|
||||
RequestNowText string
|
||||
|
||||
// 2) 继承自预览快照的数据
|
||||
// 2) 缁ф壙鑷瑙堝揩鐓х殑鏁版嵁
|
||||
TaskClassIDs []int
|
||||
Constraints []string
|
||||
// InitialHybridEntries 保存本轮微调开始前的基线,用于终审做“前后对比”。
|
||||
// 说明:
|
||||
// 1. 只读语义,不参与执行期改写;
|
||||
// 2. 终审可基于它判断“来源任务是否真正迁移到目标区域”。
|
||||
// InitialHybridEntries 淇濆瓨鏈疆寰皟寮€濮嬪墠鐨勫熀绾匡紝鐢ㄤ簬缁堝鍋氣€滃墠鍚庡姣斺€濄€?
|
||||
// 璇存槑锛?
|
||||
// 1. 鍙璇箟锛屼笉鍙備笌鎵ц鏈熸敼鍐欙紱
|
||||
// 2. 缁堝鍙熀浜庡畠鍒ゆ柇鈥滄潵婧愪换鍔℃槸鍚︾湡姝h縼绉诲埌鐩爣鍖哄煙鈥濄€?
|
||||
InitialHybridEntries []model.HybridScheduleEntry
|
||||
HybridEntries []model.HybridScheduleEntry
|
||||
AllocatedItems []model.TaskClassItem
|
||||
CandidatePlans []model.UserWeekSchedule
|
||||
|
||||
// 3) 本轮执行状态
|
||||
// 3) 鏈疆鎵ц鐘舵€?
|
||||
UserIntent string
|
||||
Contract RefineContract
|
||||
|
||||
@@ -146,7 +152,7 @@ type ScheduleRefineState struct {
|
||||
PerTaskBudget int
|
||||
ExecuteMax int
|
||||
ReplanMax int
|
||||
// CompositeRetryMax 表示复合路由失败后的最大重试次数(不含首次尝试)。
|
||||
// CompositeRetryMax 琛ㄧず澶嶅悎璺敱澶辫触鍚庣殑鏈€澶ч噸璇曟鏁帮紙涓嶅惈棣栨灏濊瘯锛夈€?
|
||||
CompositeRetryMax int
|
||||
|
||||
PlanUsed int
|
||||
@@ -163,27 +169,27 @@ type ScheduleRefineState struct {
|
||||
|
||||
CurrentPlan PlannerPlan
|
||||
BatchMoveAllowed bool
|
||||
// DisableCompositeTools=true 表示已进入 ReAct 兜底,禁止再调用复合工具。
|
||||
// DisableCompositeTools=true 琛ㄧず宸茶繘鍏?ReAct 鍏滃簳锛岀姝㈠啀璋冪敤澶嶅悎宸ュ叿銆?
|
||||
DisableCompositeTools bool
|
||||
// CompositeRouteTried 标记是否尝试过“复合批处理路由”。
|
||||
// CompositeRouteTried 鏍囪鏄惁灏濊瘯杩団€滃鍚堟壒澶勭悊璺敱鈥濄€?
|
||||
CompositeRouteTried bool
|
||||
// CompositeRouteSucceeded 标记复合批处理路由是否已完成“复合分支出站”。
|
||||
// CompositeRouteSucceeded 鏍囪澶嶅悎鎵瑰鐞嗚矾鐢辨槸鍚﹀凡瀹屾垚鈥滃鍚堝垎鏀嚭绔欌€濄€?
|
||||
//
|
||||
// 说明:
|
||||
// 1. true 表示当前链路可以跳过 ReAct 兜底,直接进入 hard_check;
|
||||
// 2. 它不等价于“终审已通过”,终审是否通过仍以后续 HardCheck 结果为准;
|
||||
// 3. 这样区分是为了避免“复合工具已成功执行,但业务目标要等终审裁决”时被误判为失败。
|
||||
// 璇存槑锛?
|
||||
// 1. true 琛ㄧず褰撳墠閾捐矾鍙互璺宠繃 ReAct 鍏滃簳锛岀洿鎺ヨ繘鍏?hard_check锛?
|
||||
// 2. 瀹冧笉绛変环浜庘€滅粓瀹″凡閫氳繃鈥濓紝缁堝鏄惁閫氳繃浠嶄互鍚庣画 HardCheck 缁撴灉涓哄噯锛?
|
||||
// 3. 杩欐牱鍖哄垎鏄负浜嗛伩鍏嶁€滃鍚堝伐鍏峰凡鎴愬姛鎵ц锛屼絾涓氬姟鐩爣瑕佺瓑缁堝瑁佸喅鈥濇椂琚鍒や负澶辫触銆?
|
||||
CompositeRouteSucceeded bool
|
||||
TaskActionUsed map[int]int
|
||||
EntriesVersion int
|
||||
SeenSlotQueries map[string]struct{}
|
||||
|
||||
// RequiredCompositeTool 表示本轮策略要求“必须至少成功一次”的复合工具。
|
||||
// 取值约定:"" | "SpreadEven" | "MinContextSwitch"。
|
||||
// RequiredCompositeTool 琛ㄧず鏈疆绛栫暐瑕佹眰鈥滃繀椤昏嚦灏戞垚鍔熶竴娆♀€濈殑澶嶅悎宸ュ叿銆?
|
||||
// 鍙栧€肩害瀹氾細"" | "SpreadEven" | "MinContextSwitch"銆?
|
||||
RequiredCompositeTool string
|
||||
// CompositeToolCalled 记录复合工具是否至少调用过一次(不区分成功失败)。
|
||||
// CompositeToolCalled 璁板綍澶嶅悎宸ュ叿鏄惁鑷冲皯璋冪敤杩囦竴娆★紙涓嶅尯鍒嗘垚鍔熷け璐ワ級銆?
|
||||
CompositeToolCalled map[string]bool
|
||||
// CompositeToolSuccess 记录复合工具是否至少成功过一次。
|
||||
// CompositeToolSuccess 璁板綍澶嶅悎宸ュ叿鏄惁鑷冲皯鎴愬姛杩囦竴娆°€?
|
||||
CompositeToolSuccess map[string]bool
|
||||
|
||||
SlicePlan RefineSlicePlan
|
||||
@@ -196,20 +202,20 @@ type ScheduleRefineState struct {
|
||||
LastFailedCallSignature string
|
||||
OriginOrderMap map[int]int
|
||||
|
||||
// 4) 终审状态
|
||||
// 4) 缁堝鐘舵€?
|
||||
HardCheck HardCheckReport
|
||||
|
||||
// 5) 最终输出
|
||||
// 5) 鏈€缁堣緭鍑?
|
||||
FinalSummary string
|
||||
Completed bool
|
||||
}
|
||||
|
||||
// NewScheduleRefineState 基于上一版预览快照初始化状态。
|
||||
// NewScheduleRefineState 鍩轰簬涓婁竴鐗堥瑙堝揩鐓у垵濮嬪寲鐘舵€併€?
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责初始化预算、上下文字段与可变状态容器;
|
||||
// 2. 负责拷贝 preview 数据,避免跨请求引用污染;
|
||||
// 3. 不负责做任何调度动作。
|
||||
// 鑱岃矗杈圭晫锛?
|
||||
// 1. 璐熻矗鍒濆鍖栭绠椼€佷笂涓嬫枃瀛楁涓庡彲鍙樼姸鎬佸鍣紱
|
||||
// 2. 璐熻矗鎷疯礉 preview 鏁版嵁锛岄伩鍏嶈法璇锋眰寮曠敤姹℃煋锛?
|
||||
// 3. 涓嶈礋璐e仛浠讳綍璋冨害鍔ㄤ綔銆?
|
||||
func NewScheduleRefineState(traceID string, userID int, conversationID string, userMessage string, preview *model.SchedulePlanPreviewCache) *ScheduleRefineState {
|
||||
now := nowToMinute()
|
||||
st := &ScheduleRefineState{
|
||||
@@ -240,10 +246,10 @@ func NewScheduleRefineState(traceID string, userID int, conversationID string, u
|
||||
"MinContextSwitch": false,
|
||||
},
|
||||
CurrentPlan: PlannerPlan{
|
||||
Summary: "初始化完成,等待 Planner 生成执行计划。",
|
||||
Summary: "initialized, waiting for planner output",
|
||||
},
|
||||
SlicePlan: RefineSlicePlan{
|
||||
Reason: "尚未切片",
|
||||
Reason: "灏氭湭鍒囩墖",
|
||||
},
|
||||
}
|
||||
if preview == nil {
|
||||
@@ -260,75 +266,26 @@ func NewScheduleRefineState(traceID string, userID int, conversationID string, u
|
||||
}
|
||||
|
||||
func loadLocation() *time.Location {
|
||||
loc, err := time.LoadLocation(timezoneName)
|
||||
if err != nil {
|
||||
return time.Local
|
||||
}
|
||||
return loc
|
||||
return agentshared.ShanghaiLocation()
|
||||
}
|
||||
|
||||
func nowToMinute() time.Time {
|
||||
return time.Now().In(loadLocation()).Truncate(time.Minute)
|
||||
return agentshared.NowToMinute()
|
||||
}
|
||||
|
||||
func cloneHybridEntries(src []model.HybridScheduleEntry) []model.HybridScheduleEntry {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
dst := make([]model.HybridScheduleEntry, len(src))
|
||||
copy(dst, src)
|
||||
return dst
|
||||
return agentshared.CloneHybridEntries(src)
|
||||
}
|
||||
|
||||
func cloneTaskClassItems(src []model.TaskClassItem) []model.TaskClassItem {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
dst := make([]model.TaskClassItem, 0, len(src))
|
||||
for _, item := range src {
|
||||
copied := item
|
||||
if item.CategoryID != nil {
|
||||
v := *item.CategoryID
|
||||
copied.CategoryID = &v
|
||||
}
|
||||
if item.Order != nil {
|
||||
v := *item.Order
|
||||
copied.Order = &v
|
||||
}
|
||||
if item.Content != nil {
|
||||
v := *item.Content
|
||||
copied.Content = &v
|
||||
}
|
||||
if item.Status != nil {
|
||||
v := *item.Status
|
||||
copied.Status = &v
|
||||
}
|
||||
if item.EmbeddedTime != nil {
|
||||
t := *item.EmbeddedTime
|
||||
copied.EmbeddedTime = &t
|
||||
}
|
||||
dst = append(dst, copied)
|
||||
}
|
||||
return dst
|
||||
return agentshared.CloneTaskClassItems(src)
|
||||
}
|
||||
|
||||
func cloneWeekSchedules(src []model.UserWeekSchedule) []model.UserWeekSchedule {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
dst := make([]model.UserWeekSchedule, 0, len(src))
|
||||
for _, week := range src {
|
||||
eventsCopy := make([]model.WeeklyEventBrief, len(week.Events))
|
||||
copy(eventsCopy, week.Events)
|
||||
dst = append(dst, model.UserWeekSchedule{
|
||||
Week: week.Week,
|
||||
Events: eventsCopy,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
return agentshared.CloneWeekSchedules(src)
|
||||
}
|
||||
|
||||
// buildOriginOrderMap 构建 suggested 任务的初始顺序基线(task_item_id -> rank)。
|
||||
// buildOriginOrderMap 鏋勫缓 suggested 浠诲姟鐨勫垵濮嬮『搴忓熀绾匡紙task_item_id -> rank锛夈€?
|
||||
func buildOriginOrderMap(entries []model.HybridScheduleEntry) map[int]int {
|
||||
orderMap := make(map[int]int)
|
||||
if len(entries) == 0 {
|
||||
@@ -363,15 +320,25 @@ func buildOriginOrderMap(entries []model.HybridScheduleEntry) map[int]int {
|
||||
return orderMap
|
||||
}
|
||||
|
||||
// FinalHardCheckPassed 判断“最终终审”是否整体通过。
|
||||
// FinalHardCheckPassed 鍒ゆ柇鈥滄渶缁堢粓瀹♀€濇槸鍚︽暣浣撻€氳繃銆?
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责聚合 physics/order/intent 三类硬校验结果,给服务层与总结阶段统一复用;
|
||||
// 2. 不负责触发终审,也不负责推导修复动作;
|
||||
// 3. nil state 视为未通过,避免上层把缺失结果误判为成功。
|
||||
// 鑱岃矗杈圭晫锛?
|
||||
// 1. 璐熻矗鑱氬悎 physics/order/intent 涓夌被纭牎楠岀粨鏋滐紝缁欐湇鍔″眰涓庢€荤粨闃舵缁熶竴澶嶇敤锛?
|
||||
// 2. 涓嶈礋璐hЕ鍙戠粓瀹★紝涔熶笉璐熻矗鎺ㄥ淇鍔ㄤ綔锛?
|
||||
// 3. nil state 瑙嗕负鏈€氳繃锛岄伩鍏嶄笂灞傛妸缂哄け缁撴灉璇垽涓烘垚鍔熴€?
|
||||
func FinalHardCheckPassed(st *ScheduleRefineState) bool {
|
||||
if st == nil {
|
||||
return false
|
||||
}
|
||||
return st.HardCheck.PhysicsPassed && st.HardCheck.OrderPassed && st.HardCheck.IntentPassed
|
||||
}
|
||||
|
||||
func isMovableSuggestedTask(entry model.HybridScheduleEntry) bool {
|
||||
if strings.TrimSpace(entry.Status) != "suggested" || entry.TaskItemID <= 0 {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(entry.Type), "course") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
37
backend/agent/model/task_priority.go
Normal file
37
backend/agent/model/task_priority.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package agentmodel
|
||||
|
||||
const (
|
||||
TaskPriorityImportantUrgent = 1
|
||||
TaskPriorityImportantNotUrgent = 2
|
||||
TaskPrioritySimpleNotImportant = 3
|
||||
TaskPriorityComplexNotImportant = 4
|
||||
)
|
||||
|
||||
// IsValidTaskPriority 用于校验任务优先级是否合法。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只负责判断 priority 是否落在系统支持的 1~4 范围内。
|
||||
// 2. 不负责把自然语言映射成优先级,也不负责做业务兜底推断。
|
||||
func IsValidTaskPriority(priority int) bool {
|
||||
return priority >= TaskPriorityImportantUrgent && priority <= TaskPriorityComplexNotImportant
|
||||
}
|
||||
|
||||
// PriorityLabelCN 返回任务优先级对应的中文标签。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只负责“优先级枚举 -> 中文展示文案”的稳定映射。
|
||||
// 2. 不负责国际化、多语言切换或业务规则解释。
|
||||
func PriorityLabelCN(priority int) string {
|
||||
switch priority {
|
||||
case TaskPriorityImportantUrgent:
|
||||
return "重要且紧急"
|
||||
case TaskPriorityImportantNotUrgent:
|
||||
return "重要不紧急"
|
||||
case TaskPrioritySimpleNotImportant:
|
||||
return "简单不重要"
|
||||
case TaskPriorityComplexNotImportant:
|
||||
return "复杂不重要"
|
||||
default:
|
||||
return "未知优先级"
|
||||
}
|
||||
}
|
||||
87
backend/agent/model/taskquery.go
Normal file
87
backend/agent/model/taskquery.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package agentmodel
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
// DefaultTaskQueryLimit 是任务查询默认返回条数。
|
||||
DefaultTaskQueryLimit = 5
|
||||
// MaxTaskQueryLimit 是任务查询允许的最大返回条数,用于限制模型输出范围。
|
||||
MaxTaskQueryLimit = 20
|
||||
// DefaultTaskQueryReflectRetry 是任务查询反思节点的默认重试次数。
|
||||
DefaultTaskQueryReflectRetry = 2
|
||||
)
|
||||
|
||||
// TaskQueryItem 是任务查询链路最终展示给模型和用户的轻量任务视图。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只承载展示和反思所需字段,避免把底层数据库结构直接暴露给图层。
|
||||
// 2. 不负责描述完整任务实体,也不负责持久化。
|
||||
type TaskQueryItem struct {
|
||||
ID int `json:"id"`
|
||||
Title string `json:"title"`
|
||||
PriorityGroup int `json:"priority_group"`
|
||||
PriorityLabel string `json:"priority_label"`
|
||||
IsCompleted bool `json:"is_completed"`
|
||||
DeadlineAt string `json:"deadline_at,omitempty"`
|
||||
UrgencyThresholdAt string `json:"urgency_threshold_at,omitempty"`
|
||||
}
|
||||
|
||||
// TaskQueryPlan 是计划节点产出的内部查询方案。
|
||||
//
|
||||
// 输入输出语义:
|
||||
// 1. DeadlineBeforeText / DeadlineAfterText 保留原始文本,便于继续透传给工具和日志。
|
||||
// 2. DeadlineBefore / DeadlineAfter 是归一化后的时间对象,仅供执行期使用。
|
||||
// 3. IncludeCompleted=true 表示允许把已完成任务纳入候选集。
|
||||
type TaskQueryPlan struct {
|
||||
Quadrants []int
|
||||
SortBy string
|
||||
Order string
|
||||
Limit int
|
||||
|
||||
IncludeCompleted bool
|
||||
Keyword string
|
||||
DeadlineBeforeText string
|
||||
DeadlineAfterText string
|
||||
DeadlineBefore *time.Time
|
||||
DeadlineAfter *time.Time
|
||||
}
|
||||
|
||||
// TaskQueryState 是任务查询图在各节点之间流转的完整状态。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责保存用户输入、结构化计划、工具结果和反思过程状态。
|
||||
// 2. 不负责图编排本身,也不直接绑定外部数据库实体。
|
||||
type TaskQueryState struct {
|
||||
UserMessage string
|
||||
RequestNowText string
|
||||
UserGoal string
|
||||
Plan TaskQueryPlan
|
||||
ExplicitLimit int
|
||||
|
||||
LastQueryItems []TaskQueryItem
|
||||
LastQueryTotal int
|
||||
AutoBroadenApplied bool
|
||||
RetryCount int
|
||||
MaxReflectRetry int
|
||||
NeedRetry bool
|
||||
ReflectReason string
|
||||
FinalReply string
|
||||
}
|
||||
|
||||
// NewTaskQueryState 负责创建任务查询图的初始状态。
|
||||
//
|
||||
// 输入输出语义:
|
||||
// 1. maxReflectRetry <= 0 时会自动回退到默认值,避免上层遗漏配置导致无法重试。
|
||||
// 2. 返回的状态对象已初始化空切片,可直接进入 graph 执行。
|
||||
func NewTaskQueryState(userMessage, requestNowText string, maxReflectRetry int) *TaskQueryState {
|
||||
if maxReflectRetry <= 0 {
|
||||
maxReflectRetry = DefaultTaskQueryReflectRetry
|
||||
}
|
||||
return &TaskQueryState{
|
||||
UserMessage: userMessage,
|
||||
RequestNowText: requestNowText,
|
||||
MaxReflectRetry: maxReflectRetry,
|
||||
LastQueryItems: make([]TaskQueryItem, 0),
|
||||
AutoBroadenApplied: false,
|
||||
}
|
||||
}
|
||||
504
backend/agent/node/quicknote.go
Normal file
504
backend/agent/node/quicknote.go
Normal file
@@ -0,0 +1,504 @@
|
||||
package agentnode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
agentllm "github.com/LoveLosita/smartflow/backend/agent/llm"
|
||||
agentmodel "github.com/LoveLosita/smartflow/backend/agent/model"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
const (
|
||||
// QuickNoteGraphNodeIntent 是随口记图中的“意图识别”节点名。
|
||||
QuickNoteGraphNodeIntent = "quick_note_intent"
|
||||
// QuickNoteGraphNodeRank 是随口记图中的“优先级评估”节点名。
|
||||
QuickNoteGraphNodeRank = "quick_note_priority"
|
||||
// QuickNoteGraphNodePersist 是随口记图中的“持久化写库”节点名。
|
||||
QuickNoteGraphNodePersist = "quick_note_persist"
|
||||
// QuickNoteGraphNodeExit 是随口记图中的“提前退出”节点名。
|
||||
QuickNoteGraphNodeExit = "quick_note_exit"
|
||||
)
|
||||
|
||||
// QuickNoteGraphRunInput 描述一次随口记图运行所需的请求级依赖。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责把模型、初始状态、工具依赖和阶段回调打包给 graph 层。
|
||||
// 2. 不负责做依赖校验,校验逻辑由 graph/node 构造阶段处理。
|
||||
type QuickNoteGraphRunInput struct {
|
||||
Model *ark.ChatModel
|
||||
State *agentmodel.QuickNoteState
|
||||
Deps QuickNoteToolDeps
|
||||
SkipIntentVerification bool
|
||||
EmitStage func(stage, detail string)
|
||||
}
|
||||
|
||||
// QuickNoteNodes 是随口记图的节点容器。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责承接节点运行时依赖,并向 graph 暴露可直接挂载的方法。
|
||||
// 2. 不负责 graph 编译,也不负责 service 层接口接线。
|
||||
type QuickNoteNodes struct {
|
||||
input QuickNoteGraphRunInput
|
||||
createTaskTool tool.InvokableTool
|
||||
emitStage func(stage, detail string)
|
||||
}
|
||||
|
||||
// NewQuickNoteNodes 负责构造随口记节点容器。
|
||||
//
|
||||
// 输入输出语义:
|
||||
// 1. createTaskTool 不能为空,否则 persist 节点无法落库。
|
||||
// 2. EmitStage 为空时会回退到空实现,避免节点内部到处判空。
|
||||
func NewQuickNoteNodes(input QuickNoteGraphRunInput, createTaskTool tool.InvokableTool) (*QuickNoteNodes, error) {
|
||||
if createTaskTool == nil {
|
||||
return nil, errors.New("quick note nodes: createTaskTool is nil")
|
||||
}
|
||||
|
||||
emitStage := input.EmitStage
|
||||
if emitStage == nil {
|
||||
emitStage = func(stage, detail string) {}
|
||||
}
|
||||
|
||||
return &QuickNoteNodes{
|
||||
input: input,
|
||||
createTaskTool: createTaskTool,
|
||||
emitStage: emitStage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Exit 是图中的显式退出节点。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 仅作为图收口占位,保持状态原样透传。
|
||||
// 2. 不做额外业务处理,避免退出节点再引入副作用。
|
||||
func (n *QuickNoteNodes) Exit(ctx context.Context, st *agentmodel.QuickNoteState) (*agentmodel.QuickNoteState, error) {
|
||||
_ = ctx
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// NextAfterIntent 根据意图识别结果决定 intent 节点后的分支走向。
|
||||
//
|
||||
// 步骤说明:
|
||||
// 1. 非随口记意图时直接退出,避免误把普通聊天写成任务。
|
||||
// 2. 截止时间校验失败时同样直接退出,让上层优先把错误提示给用户。
|
||||
// 3. 只有意图成立且时间合法,才进入优先级评估节点。
|
||||
func (n *QuickNoteNodes) NextAfterIntent(ctx context.Context, st *agentmodel.QuickNoteState) (string, error) {
|
||||
_ = ctx
|
||||
if st == nil || !st.IsQuickNoteIntent {
|
||||
return QuickNoteGraphNodeExit, nil
|
||||
}
|
||||
if st.DeadlineValidationError != "" {
|
||||
return QuickNoteGraphNodeExit, nil
|
||||
}
|
||||
return QuickNoteGraphNodeRank, nil
|
||||
}
|
||||
|
||||
// NextAfterPersist 根据持久化结果决定 persist 节点后的分支走向。
|
||||
//
|
||||
// 输入输出语义:
|
||||
// 1. Persisted=true 表示已经成功写库,可以直接结束。
|
||||
// 2. Persisted=false 且 CanRetryTool()=true 表示继续重试写库。
|
||||
// 3. 重试用尽后会补齐兜底回复,再结束链路,避免用户拿到空响应。
|
||||
func (n *QuickNoteNodes) NextAfterPersist(ctx context.Context, st *agentmodel.QuickNoteState) (string, error) {
|
||||
_ = ctx
|
||||
if st == nil {
|
||||
return compose.END, nil
|
||||
}
|
||||
if st.Persisted {
|
||||
return compose.END, nil
|
||||
}
|
||||
if st.CanRetryTool() {
|
||||
return QuickNoteGraphNodePersist, nil
|
||||
}
|
||||
if st.AssistantReply == "" {
|
||||
st.AssistantReply = "抱歉,我已经重试了多次,还是没能成功记录这条任务,请稍后再试。"
|
||||
}
|
||||
return compose.END, nil
|
||||
}
|
||||
|
||||
// Intent 负责“意图识别 + 聚合规划 + 时间校验”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责判断本次请求是否属于随口记;
|
||||
// 2. 负责把模型规划结果回填到 state;
|
||||
// 3. 负责做最后一层本地时间硬校验,避免非法时间被静默写成 NULL;
|
||||
// 4. 不负责真正写库。
|
||||
func (n *QuickNoteNodes) Intent(ctx context.Context, st *agentmodel.QuickNoteState) (*agentmodel.QuickNoteState, error) {
|
||||
if st == nil {
|
||||
return nil, errors.New("quick note graph: nil state in intent node")
|
||||
}
|
||||
|
||||
// 1. 若上游路由已经高置信命中 quick_note,则直接进入单次聚合规划。
|
||||
// 1.1 目的:尽量把“标题 / 时间 / 优先级 / banter”压缩到一次模型往返内;
|
||||
// 1.2 失败处理:若聚合规划失败,不中断整条链路,而是回退到本地兜底,保证可用性优先。
|
||||
if n.input.SkipIntentVerification {
|
||||
n.emitStage("quick_note.intent.analyzing", "已由上游路由判定为任务请求,跳过二次意图判断。")
|
||||
st.IsQuickNoteIntent = true
|
||||
st.IntentJudgeReason = "上游路由已命中 quick_note,跳过二次意图判定"
|
||||
st.PlannedBySingleCall = true
|
||||
|
||||
n.emitStage("quick_note.plan.generating", "正在一次性生成时间归一化、优先级与回复润色。")
|
||||
plan, planErr := planQuickNoteInSingleCall(ctx, n.input.Model, st.RequestNowText, st.RequestNow, st.UserInput)
|
||||
if planErr != nil {
|
||||
st.IntentJudgeReason += ";聚合规划失败,回退本地兜底"
|
||||
} else {
|
||||
if strings.TrimSpace(plan.Title) != "" {
|
||||
st.ExtractedTitle = strings.TrimSpace(plan.Title)
|
||||
}
|
||||
if plan.Deadline != nil {
|
||||
st.ExtractedDeadline = plan.Deadline
|
||||
}
|
||||
st.ExtractedDeadlineText = strings.TrimSpace(plan.DeadlineText)
|
||||
if plan.UrgencyThreshold != nil {
|
||||
st.ExtractedUrgencyThreshold = normalizeUrgencyThreshold(plan.UrgencyThreshold, plan.Deadline)
|
||||
}
|
||||
if agentmodel.IsValidTaskPriority(plan.PriorityGroup) {
|
||||
st.ExtractedPriority = plan.PriorityGroup
|
||||
st.ExtractedPriorityReason = strings.TrimSpace(plan.PriorityReason)
|
||||
}
|
||||
st.ExtractedBanter = strings.TrimSpace(plan.Banter)
|
||||
}
|
||||
|
||||
// 1.3 如果聚合规划没能给出标题,则回退到本地标题抽取,避免后续 persist 节点拿到空标题。
|
||||
if strings.TrimSpace(st.ExtractedTitle) == "" {
|
||||
st.ExtractedTitle = deriveQuickNoteTitleFromInput(st.UserInput)
|
||||
}
|
||||
|
||||
// 1.4 最后一定要做一轮本地时间硬校验。
|
||||
// 1.4.1 原因:模型即使给了时间,也可能和用户原句不一致,或者用户原句本身就是非法时间;
|
||||
// 1.4.2 若检测到“用户给了时间线索但格式非法”,直接退出图并给用户明确修正提示。
|
||||
n.emitStage("quick_note.deadline.validating", "正在校验并归一化任务时间。")
|
||||
userDeadline, userHasTimeHint, userDeadlineErr := parseOptionalDeadlineFromUserInput(st.UserInput, st.RequestNow)
|
||||
if userHasTimeHint && userDeadlineErr != nil {
|
||||
st.DeadlineValidationError = userDeadlineErr.Error()
|
||||
st.AssistantReply = "我识别到你给了时间信息,但这个时间格式我没法准确解析,请改成例如:2026-03-20 18:30、明天下午3点、下周一上午9点。"
|
||||
n.emitStage("quick_note.failed", "时间校验失败,未执行写入。")
|
||||
return st, nil
|
||||
}
|
||||
if userDeadline != nil {
|
||||
st.ExtractedDeadline = userDeadline
|
||||
st.ExtractedDeadlineText = strings.TrimSpace(st.UserInput)
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 2. 常规路径:先做一次意图识别,再做本地时间硬校验。
|
||||
n.emitStage("quick_note.intent.analyzing", "正在分析用户输入是否属于任务安排请求。")
|
||||
parsed, callErr := agentllm.IdentifyQuickNoteIntent(ctx, n.input.Model, st.RequestNowText, st.UserInput)
|
||||
if callErr != nil {
|
||||
// 2.1 这里不直接返回 error,而是把它视为“本次未能确认是 quick note”,交给上层回退普通聊天。
|
||||
st.IsQuickNoteIntent = false
|
||||
st.IntentJudgeReason = "意图识别失败,回退普通聊天"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.IsQuickNoteIntent = parsed.IsQuickNote
|
||||
st.IntentJudgeReason = strings.TrimSpace(parsed.Reason)
|
||||
if !st.IsQuickNoteIntent {
|
||||
return st, nil
|
||||
}
|
||||
|
||||
title := strings.TrimSpace(parsed.Title)
|
||||
if title == "" {
|
||||
title = strings.TrimSpace(st.UserInput)
|
||||
}
|
||||
st.ExtractedTitle = title
|
||||
|
||||
n.emitStage("quick_note.deadline.validating", "正在校验并归一化任务时间。")
|
||||
|
||||
// 2.2 先尝试吃模型返回的 deadline_at,用于减少后续重复推理。
|
||||
st.ExtractedDeadlineText = strings.TrimSpace(parsed.DeadlineAt)
|
||||
if st.ExtractedDeadlineText != "" {
|
||||
if deadline, deadlineErr := parseOptionalDeadlineWithNow(st.ExtractedDeadlineText, st.RequestNow); deadlineErr == nil {
|
||||
st.ExtractedDeadline = deadline
|
||||
}
|
||||
}
|
||||
|
||||
// 2.3 再强制对用户原句做一次时间线索校验。
|
||||
userDeadline, userHasTimeHint, userDeadlineErr := parseOptionalDeadlineFromUserInput(st.UserInput, st.RequestNow)
|
||||
if userHasTimeHint && userDeadlineErr != nil {
|
||||
st.DeadlineValidationError = userDeadlineErr.Error()
|
||||
st.AssistantReply = "我识别到你给了时间信息,但这个时间格式我没法准确解析,请改成例如:2026-03-20 18:30、明天下午3点、下周一上午9点。"
|
||||
n.emitStage("quick_note.failed", "时间校验失败,未执行写入。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 2.4 若模型没提到 deadline,但用户原句能解析出来,则以用户原句为准补齐。
|
||||
if st.ExtractedDeadline == nil && userDeadline != nil {
|
||||
st.ExtractedDeadline = userDeadline
|
||||
if st.ExtractedDeadlineText == "" {
|
||||
st.ExtractedDeadlineText = strings.TrimSpace(st.UserInput)
|
||||
}
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// Priority 负责“优先级评估”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责在 intent 节点之后补齐 priority_group;
|
||||
// 2. 若聚合规划已经给出合法优先级,则直接复用,不再重复调用模型;
|
||||
// 3. 若模型评估失败,则使用本地兜底策略,保证链路继续可走;
|
||||
// 4. 不负责写库。
|
||||
func (n *QuickNoteNodes) Priority(ctx context.Context, st *agentmodel.QuickNoteState) (*agentmodel.QuickNoteState, error) {
|
||||
if st == nil {
|
||||
return nil, errors.New("quick note graph: nil state in priority node")
|
||||
}
|
||||
if !st.IsQuickNoteIntent || strings.TrimSpace(st.DeadlineValidationError) != "" {
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 1. 聚合规划已经给出合法优先级时,直接复用,避免重复调模型。
|
||||
if agentmodel.IsValidTaskPriority(st.ExtractedPriority) {
|
||||
if strings.TrimSpace(st.ExtractedPriorityReason) == "" {
|
||||
st.ExtractedPriorityReason = "复用聚合规划优先级"
|
||||
}
|
||||
n.emitStage("quick_note.priority.evaluating", "已复用聚合规划结果中的优先级。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 2. 单请求聚合路径若没有给出合法 priority,则直接走本地兜底,优先保证低时延。
|
||||
if n.input.SkipIntentVerification || st.PlannedBySingleCall {
|
||||
st.ExtractedPriority = fallbackPriority(st)
|
||||
st.ExtractedPriorityReason = "聚合规划未给出合法优先级,使用本地兜底"
|
||||
n.emitStage("quick_note.priority.evaluating", "聚合优先级缺失,已使用本地兜底。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
n.emitStage("quick_note.priority.evaluating", "正在评估任务优先级。")
|
||||
deadlineText := "无"
|
||||
if st.ExtractedDeadline != nil {
|
||||
deadlineText = formatQuickNoteTimeToMinute(*st.ExtractedDeadline)
|
||||
}
|
||||
deadlineClue := strings.TrimSpace(st.ExtractedDeadlineText)
|
||||
if deadlineClue == "" {
|
||||
deadlineClue = "无"
|
||||
}
|
||||
|
||||
parsed, callErr := agentllm.PlanQuickNotePriority(ctx, n.input.Model, st.RequestNowText, st.ExtractedTitle, st.UserInput, deadlineClue, deadlineText)
|
||||
if callErr != nil {
|
||||
st.ExtractedPriority = fallbackPriority(st)
|
||||
st.ExtractedPriorityReason = "优先级评估失败,使用兜底策略"
|
||||
return st, nil
|
||||
}
|
||||
if parsed == nil || !agentmodel.IsValidTaskPriority(parsed.PriorityGroup) {
|
||||
st.ExtractedPriority = fallbackPriority(st)
|
||||
st.ExtractedPriorityReason = "优先级结果异常,使用兜底策略"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.ExtractedPriority = parsed.PriorityGroup
|
||||
st.ExtractedPriorityReason = strings.TrimSpace(parsed.Reason)
|
||||
if strings.TrimSpace(parsed.UrgencyThresholdAt) != "" {
|
||||
urgencyThreshold, thresholdErr := parseOptionalDeadlineWithNow(strings.TrimSpace(parsed.UrgencyThresholdAt), st.RequestNow)
|
||||
if thresholdErr == nil {
|
||||
st.ExtractedUrgencyThreshold = normalizeUrgencyThreshold(urgencyThreshold, st.ExtractedDeadline)
|
||||
}
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// Persist 负责“调工具写库 + 有限次重试状态回填”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责把 state 中已提取出的标题、时间、优先级组装成工具入参;
|
||||
// 2. 负责调用 createTaskTool 执行真正写库;
|
||||
// 3. 负责把成功/失败结果回填到 state,供后续分支与回复使用;
|
||||
// 4. 不负责最终回复润色,不负责 service 层的 Redis 与持久化收尾。
|
||||
func (n *QuickNoteNodes) Persist(ctx context.Context, st *agentmodel.QuickNoteState) (*agentmodel.QuickNoteState, error) {
|
||||
if st == nil {
|
||||
return nil, errors.New("quick note graph: nil state in persist node")
|
||||
}
|
||||
if !st.IsQuickNoteIntent || strings.TrimSpace(st.DeadlineValidationError) != "" {
|
||||
return st, nil
|
||||
}
|
||||
|
||||
n.emitStage("quick_note.persisting", "正在写入任务数据。")
|
||||
priority := st.ExtractedPriority
|
||||
if !agentmodel.IsValidTaskPriority(priority) {
|
||||
priority = fallbackPriority(st)
|
||||
st.ExtractedPriority = priority
|
||||
}
|
||||
|
||||
deadlineText := ""
|
||||
if st.ExtractedDeadline != nil {
|
||||
deadlineText = st.ExtractedDeadline.In(quickNoteLocation()).Format(time.RFC3339)
|
||||
}
|
||||
urgencyThresholdText := ""
|
||||
if st.ExtractedUrgencyThreshold != nil {
|
||||
urgencyThresholdText = st.ExtractedUrgencyThreshold.In(quickNoteLocation()).Format(time.RFC3339)
|
||||
}
|
||||
|
||||
toolInput := QuickNoteCreateTaskToolInput{
|
||||
Title: st.ExtractedTitle,
|
||||
PriorityGroup: priority,
|
||||
DeadlineAt: deadlineText,
|
||||
UrgencyThresholdAt: urgencyThresholdText,
|
||||
}
|
||||
rawInput, marshalErr := json.Marshal(toolInput)
|
||||
if marshalErr != nil {
|
||||
st.RecordToolError("构造工具参数失败: " + marshalErr.Error())
|
||||
if !st.CanRetryTool() {
|
||||
st.AssistantReply = "抱歉,记录任务时参数处理失败,请稍后重试。"
|
||||
n.emitStage("quick_note.failed", "参数构造失败,未完成写入。")
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
rawOutput, invokeErr := n.createTaskTool.InvokableRun(ctx, string(rawInput))
|
||||
if invokeErr != nil {
|
||||
st.RecordToolError(invokeErr.Error())
|
||||
if !st.CanRetryTool() {
|
||||
st.AssistantReply = "抱歉,我尝试了多次仍未能成功记录这条任务,请稍后再试。"
|
||||
n.emitStage("quick_note.failed", "多次重试后仍未完成写入。")
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
toolOutput, parseErr := agentllm.ParseJSONObject[QuickNoteCreateTaskToolOutput](rawOutput)
|
||||
if parseErr != nil {
|
||||
st.RecordToolError("解析工具返回失败: " + parseErr.Error())
|
||||
if !st.CanRetryTool() {
|
||||
st.AssistantReply = "抱歉,我拿到了异常结果,没能确认任务是否记录成功,请稍后再试。"
|
||||
n.emitStage("quick_note.failed", "结果解析异常,无法确认写入结果。")
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
if toolOutput.TaskID <= 0 {
|
||||
st.RecordToolError(fmt.Sprintf("工具返回非法 task_id=%d", toolOutput.TaskID))
|
||||
if !st.CanRetryTool() {
|
||||
st.AssistantReply = "抱歉,这次我没能确认任务写入成功,请再发一次我立刻补上。"
|
||||
n.emitStage("quick_note.failed", "写入结果缺少有效 task_id,已终止成功回包。")
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 1. 只有拿到有效 task_id,才视为真正写入成功;
|
||||
// 2. 这样可以避免出现“返回成功文案,但数据库里根本没记录”的假成功。
|
||||
st.RecordToolSuccess(toolOutput.TaskID)
|
||||
if strings.TrimSpace(toolOutput.Title) != "" {
|
||||
st.ExtractedTitle = strings.TrimSpace(toolOutput.Title)
|
||||
}
|
||||
if agentmodel.IsValidTaskPriority(toolOutput.PriorityGroup) {
|
||||
st.ExtractedPriority = toolOutput.PriorityGroup
|
||||
}
|
||||
|
||||
reply := strings.TrimSpace(toolOutput.Message)
|
||||
if reply == "" {
|
||||
reply = fmt.Sprintf("已为你记录:%s(%s)", st.ExtractedTitle, agentmodel.PriorityLabelCN(st.ExtractedPriority))
|
||||
}
|
||||
st.AssistantReply = reply
|
||||
n.emitStage("quick_note.persisted", "任务写入成功,正在组织回复内容。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
type quickNotePlannedResult struct {
|
||||
Title string
|
||||
Deadline *time.Time
|
||||
DeadlineText string
|
||||
UrgencyThreshold *time.Time
|
||||
UrgencyThresholdText string
|
||||
PriorityGroup int
|
||||
PriorityReason string
|
||||
Banter string
|
||||
}
|
||||
|
||||
// planQuickNoteInSingleCall 在一次模型调用里完成“时间 / 优先级 / banter”聚合规划。
|
||||
func planQuickNoteInSingleCall(ctx context.Context, chatModel *ark.ChatModel, nowText string, now time.Time, userInput string) (*quickNotePlannedResult, error) {
|
||||
parsed, err := agentllm.PlanQuickNoteInSingleCall(ctx, chatModel, nowText, userInput)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &quickNotePlannedResult{
|
||||
Title: strings.TrimSpace(parsed.Title),
|
||||
DeadlineText: strings.TrimSpace(parsed.DeadlineAt),
|
||||
UrgencyThresholdText: strings.TrimSpace(parsed.UrgencyThresholdAt),
|
||||
PriorityGroup: parsed.PriorityGroup,
|
||||
PriorityReason: strings.TrimSpace(parsed.PriorityReason),
|
||||
Banter: strings.TrimSpace(parsed.Banter),
|
||||
}
|
||||
if result.Banter != "" {
|
||||
if idx := strings.Index(result.Banter, "\n"); idx >= 0 {
|
||||
result.Banter = strings.TrimSpace(result.Banter[:idx])
|
||||
}
|
||||
}
|
||||
if result.DeadlineText != "" {
|
||||
if deadline, deadlineErr := parseOptionalDeadlineWithNow(result.DeadlineText, now); deadlineErr == nil {
|
||||
result.Deadline = deadline
|
||||
}
|
||||
}
|
||||
if result.UrgencyThresholdText != "" {
|
||||
if urgencyThreshold, thresholdErr := parseOptionalDeadlineWithNow(result.UrgencyThresholdText, now); thresholdErr == nil {
|
||||
result.UrgencyThreshold = normalizeUrgencyThreshold(urgencyThreshold, result.Deadline)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func normalizeUrgencyThreshold(threshold *time.Time, deadline *time.Time) *time.Time {
|
||||
if threshold == nil {
|
||||
return nil
|
||||
}
|
||||
if deadline == nil {
|
||||
return threshold
|
||||
}
|
||||
if threshold.After(*deadline) {
|
||||
normalized := *deadline
|
||||
return &normalized
|
||||
}
|
||||
return threshold
|
||||
}
|
||||
|
||||
func fallbackPriority(st *agentmodel.QuickNoteState) int {
|
||||
if st == nil {
|
||||
return agentmodel.QuickNotePrioritySimpleNotImportant
|
||||
}
|
||||
if st.ExtractedDeadline != nil {
|
||||
if time.Until(*st.ExtractedDeadline) <= 48*time.Hour {
|
||||
return agentmodel.QuickNotePriorityImportantUrgent
|
||||
}
|
||||
return agentmodel.QuickNotePriorityImportantNotUrgent
|
||||
}
|
||||
return agentmodel.QuickNotePrioritySimpleNotImportant
|
||||
}
|
||||
|
||||
// deriveQuickNoteTitleFromInput 在“跳过二次意图判定”场景下,从用户原句提取任务标题。
|
||||
func deriveQuickNoteTitleFromInput(userInput string) string {
|
||||
text := strings.TrimSpace(userInput)
|
||||
if text == "" {
|
||||
return "这条任务"
|
||||
}
|
||||
|
||||
prefixes := []string{
|
||||
"请帮我", "麻烦帮我", "麻烦你", "帮我", "提醒我", "请提醒我", "记一个", "记个", "帮我记一个",
|
||||
}
|
||||
for _, prefix := range prefixes {
|
||||
if strings.HasPrefix(text, prefix) {
|
||||
text = strings.TrimSpace(strings.TrimPrefix(text, prefix))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
suffixSeparators := []string{
|
||||
",记得", ",记得", ",到时候", ",到时候", " 到时候", ",别忘了", ",别忘了", "。记得",
|
||||
}
|
||||
for _, sep := range suffixSeparators {
|
||||
if idx := strings.Index(text, sep); idx > 0 {
|
||||
text = strings.TrimSpace(text[:idx])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
text = strings.Trim(text, ",。?!!? ")
|
||||
if text == "" {
|
||||
return strings.TrimSpace(userInput)
|
||||
}
|
||||
return text
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package quicknote
|
||||
package agentnode
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,22 +9,21 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
agentmodel "github.com/LoveLosita/smartflow/backend/agent/model"
|
||||
agentshared "github.com/LoveLosita/smartflow/backend/agent/shared"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
toolutils "github.com/cloudwego/eino/components/tool/utils"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
const (
|
||||
// ToolNameQuickNoteCreateTask 是“AI随口记”写库工具的标准名称。
|
||||
// 该名称会直接暴露给大模型,因此建议保持稳定,避免后续提示词和历史上下文失配。
|
||||
// ToolNameQuickNoteCreateTask 是“AI随口记”写库工具的稳定名称。
|
||||
ToolNameQuickNoteCreateTask = "quick_note_create_task"
|
||||
// ToolDescQuickNoteCreateTask 是工具的简要职责说明。
|
||||
// ToolDescQuickNoteCreateTask 是给大模型看的工具职责说明。
|
||||
ToolDescQuickNoteCreateTask = "把用户随口提到的事项落库为任务,支持可选截止时间与优先级"
|
||||
)
|
||||
|
||||
var (
|
||||
// quickNoteDeadlineLayouts 是“绝对时间”白名单格式。
|
||||
// 只要命中任意一个 layout,就会被归一化为分钟级时间并进入写库流程。
|
||||
quickNoteDeadlineLayouts = []string{
|
||||
time.RFC3339,
|
||||
"2006-01-02T15:04",
|
||||
@@ -44,9 +43,6 @@ var (
|
||||
"2006.01.02": {},
|
||||
}
|
||||
|
||||
// 正则区:
|
||||
// 1) 用于解析明确时间表达;
|
||||
// 2) 用于“是否存在时间线索”的判定(即使格式错误,也会触发校验失败而非静默忽略)。
|
||||
quickNoteClockHMRegex = regexp.MustCompile(`(\d{1,2})\s*[::]\s*(\d{1,2})`)
|
||||
quickNoteClockCNRegex = regexp.MustCompile(`(\d{1,2})\s*点\s*(半|(\d{1,2})\s*分?)?`)
|
||||
quickNoteYMDRegex = regexp.MustCompile(`(\d{4})\s*年\s*(\d{1,2})\s*月\s*(\d{1,2})\s*[日号]?`)
|
||||
@@ -59,48 +55,38 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// QuickNoteToolDeps 描述“随口记工具包”需要的外部依赖。
|
||||
// 这里采用函数注入的方式,避免 agent 包和 service/dao 强耦合,后续更容易演进为 mock 测试或多实现切换。
|
||||
// QuickNoteToolDeps 描述随口记工具所需的外部依赖。
|
||||
type QuickNoteToolDeps struct {
|
||||
// ResolveUserID 从上下文中解析当前登录用户 ID。
|
||||
ResolveUserID func(ctx context.Context) (int, error)
|
||||
// CreateTask 执行真实写库动作。
|
||||
CreateTask func(ctx context.Context, req QuickNoteCreateTaskRequest) (*QuickNoteCreateTaskResult, error)
|
||||
CreateTask func(ctx context.Context, req QuickNoteCreateTaskRequest) (*QuickNoteCreateTaskResult, error)
|
||||
}
|
||||
|
||||
func (d QuickNoteToolDeps) validate() error {
|
||||
// 1. ResolveUserID 为空会导致工具无法绑定当前用户,必须提前失败。
|
||||
func (d QuickNoteToolDeps) Validate() error {
|
||||
if d.ResolveUserID == nil {
|
||||
return errors.New("quick note tool deps: ResolveUserID is nil")
|
||||
}
|
||||
// 2. CreateTask 为空说明没有真实写库实现,工具无法完成核心职责。
|
||||
if d.CreateTask == nil {
|
||||
return errors.New("quick note tool deps: CreateTask is nil")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// QuickNoteToolBundle 是随口记工具集合的打包结果。
|
||||
// - Tools: 给 ToolsNode 使用
|
||||
// - ToolInfos: 给 ChatModel 绑定工具 schema 使用
|
||||
// 两者分开返回,可以适配你后面用 chain、graph、react 的不同挂载姿势。
|
||||
// QuickNoteToolBundle 是随口记工具集合。
|
||||
type QuickNoteToolBundle struct {
|
||||
Tools []tool.BaseTool
|
||||
ToolInfos []*schema.ToolInfo
|
||||
}
|
||||
|
||||
// QuickNoteCreateTaskRequest 是工具层到业务层的内部请求结构。
|
||||
// 与模型输入解耦,避免模型字段变化直接影响业务签名。
|
||||
// QuickNoteCreateTaskRequest 是工具层传给业务层的内部请求。
|
||||
type QuickNoteCreateTaskRequest struct {
|
||||
UserID int
|
||||
Title string
|
||||
PriorityGroup int
|
||||
DeadlineAt *time.Time
|
||||
// UrgencyThresholdAt 是“进入紧急象限”的分界时间,允许为空。
|
||||
UserID int
|
||||
Title string
|
||||
PriorityGroup int
|
||||
DeadlineAt *time.Time
|
||||
UrgencyThresholdAt *time.Time
|
||||
}
|
||||
|
||||
// QuickNoteCreateTaskResult 是业务层返回给工具层的结构化结果。
|
||||
// QuickNoteCreateTaskResult 是业务层回给工具层的结构化结果。
|
||||
type QuickNoteCreateTaskResult struct {
|
||||
TaskID int
|
||||
Title string
|
||||
@@ -109,21 +95,18 @@ type QuickNoteCreateTaskResult struct {
|
||||
UrgencyThresholdAt *time.Time
|
||||
}
|
||||
|
||||
// QuickNoteCreateTaskToolInput 是提供给大模型的工具参数定义。
|
||||
// 注意:user_id 不对模型暴露,统一从鉴权上下文提取,避免越权写入。
|
||||
// QuickNoteCreateTaskToolInput 是暴露给模型的工具入参。
|
||||
type QuickNoteCreateTaskToolInput struct {
|
||||
Title string `json:"title" jsonschema:"required,description=任务标题,简洁明确"`
|
||||
// PriorityGroup 使用 1~4,和后端 tasks.priority 保持一致。
|
||||
PriorityGroup int `json:"priority_group" jsonschema:"required,enum=1,enum=2,enum=3,enum=4,description=优先级分组(1重要且紧急,2重要不紧急,3简单不重要,4不简单不重要)"`
|
||||
// DeadlineAt 支持绝对时间与常见相对时间(如明天/后天/下周一/今晚),内部会归一化为绝对时间。
|
||||
// PriorityGroup 与 tasks.priority 保持一致,取值 1~4。
|
||||
PriorityGroup int `json:"priority_group" jsonschema:"required,enum=1,enum=2,enum=3,enum=4,description=优先级分组(1重要且紧急,2重要不紧急,3简单不重要,4复杂不重要)"`
|
||||
// DeadlineAt 支持绝对时间与常见中文相对时间。
|
||||
DeadlineAt string `json:"deadline_at,omitempty" jsonschema:"description=可选截止时间,支持RFC3339、yyyy-MM-dd HH:mm:ss、yyyy-MM-dd HH:mm 以及常见中文相对时间"`
|
||||
// UrgencyThresholdAt 表示“何时从不紧急象限自动平移到紧急象限”。
|
||||
// 允许为空;非空时会走同样的时间解析与合法性校验。
|
||||
// UrgencyThresholdAt 表示何时自动进入紧急象限。
|
||||
UrgencyThresholdAt string `json:"urgency_threshold_at,omitempty" jsonschema:"description=可选紧急分界时间,支持与deadline_at相同格式"`
|
||||
}
|
||||
|
||||
// QuickNoteCreateTaskToolOutput 是返回给大模型的工具结果。
|
||||
// 该结构可直接给模型用于“向用户解释已记录到哪个优先级”。
|
||||
// QuickNoteCreateTaskToolOutput 是返回给模型的结构化结果。
|
||||
type QuickNoteCreateTaskToolOutput struct {
|
||||
TaskID int `json:"task_id"`
|
||||
Title string `json:"title"`
|
||||
@@ -133,37 +116,28 @@ type QuickNoteCreateTaskToolOutput struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// BuildQuickNoteToolBundle 构建“AI随口记”工具包。
|
||||
// 这是 agent 目录给上层编排层(chain/graph/react)提供的统一入口。
|
||||
// BuildQuickNoteToolBundle 构建随口记工具包。
|
||||
func BuildQuickNoteToolBundle(ctx context.Context, deps QuickNoteToolDeps) (*QuickNoteToolBundle, error) {
|
||||
// 1. 启动期做依赖校验,尽早暴露 wiring 问题,避免运行时才 panic。
|
||||
if err := deps.validate(); err != nil {
|
||||
if err := deps.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 通过 InferTool 把 Go 函数声明成“模型可调用工具”。
|
||||
// 该闭包函数是工具的真实执行体,后续所有参数校验都在这里兜底。
|
||||
createTaskTool, err := toolutils.InferTool(
|
||||
ToolNameQuickNoteCreateTask,
|
||||
ToolDescQuickNoteCreateTask,
|
||||
func(ctx context.Context, input *QuickNoteCreateTaskToolInput) (*QuickNoteCreateTaskToolOutput, error) {
|
||||
// 2.1 防御式检查:工具调用参数不能为 nil。
|
||||
if input == nil {
|
||||
return nil, errors.New("工具参数不能为空")
|
||||
}
|
||||
|
||||
// 2.2 标题与优先级是写库硬条件,必须先校验。
|
||||
title := strings.TrimSpace(input.Title)
|
||||
if title == "" {
|
||||
return nil, errors.New("title 不能为空")
|
||||
}
|
||||
if !IsValidTaskPriority(input.PriorityGroup) {
|
||||
if !agentmodel.IsValidTaskPriority(input.PriorityGroup) {
|
||||
return nil, fmt.Errorf("priority_group=%d 非法,必须在 1~4", input.PriorityGroup)
|
||||
}
|
||||
|
||||
// 这里对 deadline_at 做“强校验”:
|
||||
// - 空值允许(代表没有截止时间);
|
||||
// - 非空但无法解析直接报错,避免把有问题的时间静默写成 NULL。
|
||||
deadline, err := parseOptionalDeadline(input.DeadlineAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -173,7 +147,6 @@ func BuildQuickNoteToolBundle(ctx context.Context, deps QuickNoteToolDeps) (*Qui
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2.3 user_id 一律来自鉴权上下文,不信任模型侧入参,防止越权写别人的任务。
|
||||
userID, err := deps.ResolveUserID(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析用户身份失败: %w", err)
|
||||
@@ -182,7 +155,6 @@ func BuildQuickNoteToolBundle(ctx context.Context, deps QuickNoteToolDeps) (*Qui
|
||||
return nil, fmt.Errorf("非法 user_id=%d", userID)
|
||||
}
|
||||
|
||||
// 2.4 走业务层写库。
|
||||
result, err := deps.CreateTask(ctx, QuickNoteCreateTaskRequest{
|
||||
UserID: userID,
|
||||
Title: title,
|
||||
@@ -197,18 +169,15 @@ func BuildQuickNoteToolBundle(ctx context.Context, deps QuickNoteToolDeps) (*Qui
|
||||
return nil, errors.New("写入任务后返回结果异常")
|
||||
}
|
||||
|
||||
// 2.5 结果归一化:优先使用业务层返回值,其次回退到入参,保证输出稳定可读。
|
||||
finalTitle := title
|
||||
if strings.TrimSpace(result.Title) != "" {
|
||||
finalTitle = strings.TrimSpace(result.Title)
|
||||
}
|
||||
|
||||
finalPriority := input.PriorityGroup
|
||||
if IsValidTaskPriority(result.PriorityGroup) {
|
||||
if agentmodel.IsValidTaskPriority(result.PriorityGroup) {
|
||||
finalPriority = result.PriorityGroup
|
||||
}
|
||||
|
||||
// 2.6 截止时间输出统一为 RFC3339,便于跨系统传输与调试。
|
||||
deadlineStr := ""
|
||||
if result.DeadlineAt != nil {
|
||||
deadlineStr = result.DeadlineAt.In(quickNoteLocation()).Format(time.RFC3339)
|
||||
@@ -216,14 +185,13 @@ func BuildQuickNoteToolBundle(ctx context.Context, deps QuickNoteToolDeps) (*Qui
|
||||
deadlineStr = deadline.In(quickNoteLocation()).Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// 2.7 组装给模型的结构化结果,包含可直接面向用户的 message 草稿。
|
||||
return &QuickNoteCreateTaskToolOutput{
|
||||
TaskID: result.TaskID,
|
||||
Title: finalTitle,
|
||||
PriorityGroup: finalPriority,
|
||||
PriorityLabel: PriorityLabelCN(finalPriority),
|
||||
PriorityLabel: agentmodel.PriorityLabelCN(finalPriority),
|
||||
DeadlineAt: deadlineStr,
|
||||
Message: fmt.Sprintf("已记录:%s(%s)", finalTitle, PriorityLabelCN(finalPriority)),
|
||||
Message: fmt.Sprintf("已记录:%s(%s)", finalTitle, agentmodel.PriorityLabelCN(finalPriority)),
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
@@ -231,7 +199,6 @@ func BuildQuickNoteToolBundle(ctx context.Context, deps QuickNoteToolDeps) (*Qui
|
||||
return nil, fmt.Errorf("构建随口记工具失败: %w", err)
|
||||
}
|
||||
|
||||
// 3. Tools 给执行节点使用,ToolInfos 给模型注册 schema 使用,二者都要返回。
|
||||
tools := []tool.BaseTool{createTaskTool}
|
||||
infos, err := collectToolInfos(ctx, tools)
|
||||
if err != nil {
|
||||
@@ -244,36 +211,26 @@ func BuildQuickNoteToolBundle(ctx context.Context, deps QuickNoteToolDeps) (*Qui
|
||||
}, nil
|
||||
}
|
||||
|
||||
func collectToolInfos(ctx context.Context, tools []tool.BaseTool) ([]*schema.ToolInfo, error) {
|
||||
// 按工具列表顺序提取 ToolInfo,确保“tools[idx] <-> infos[idx]”一一对应。
|
||||
infos := make([]*schema.ToolInfo, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
info, err := t.Info(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取工具信息失败: %w", err)
|
||||
}
|
||||
infos = append(infos, info)
|
||||
// GetInvokableToolByName 通过工具名提取可执行工具实例。
|
||||
func GetInvokableToolByName(bundle *QuickNoteToolBundle, name string) (tool.InvokableTool, error) {
|
||||
if bundle == nil {
|
||||
return nil, errors.New("tool bundle is nil")
|
||||
}
|
||||
return infos, nil
|
||||
return getInvokableToolByName(bundle.Tools, bundle.ToolInfos, name)
|
||||
}
|
||||
|
||||
// parseOptionalDeadline 解析工具输入中的可选截止时间。
|
||||
// 该入口用于“工具参数强校验”:只要调用方给了非空 deadline_at,就必须能被解析。
|
||||
func parseOptionalDeadline(raw string) (*time.Time, error) {
|
||||
// 1. 先做标点与空白归一化,避免中文输入噪声影响解析。
|
||||
value := normalizeDeadlineInput(raw)
|
||||
if value == "" {
|
||||
// 2. 空字符串合法,表示任务无截止时间。
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 3. 统一按“严格模式”解析:给了时间就必须成功解析。
|
||||
deadline, hasHint, err := parseOptionalDeadlineFromText(value, quickNoteNowToMinute())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if deadline == nil {
|
||||
// 4. 区分“无时间线索”和“有线索但不支持”,返回更准确错误信息。
|
||||
if !hasHint {
|
||||
return nil, fmt.Errorf("deadline_at 格式不支持: %s", value)
|
||||
}
|
||||
@@ -283,9 +240,7 @@ func parseOptionalDeadline(raw string) (*time.Time, error) {
|
||||
}
|
||||
|
||||
// parseOptionalDeadlineWithNow 在给定时间基准下解析 deadline。
|
||||
// 该函数保持“严格模式”:非空字符串无法解析时会直接返回 error。
|
||||
func parseOptionalDeadlineWithNow(raw string, now time.Time) (*time.Time, error) {
|
||||
// 场景:模型已给出 deadline_at,需要基于同一 requestNow 再次硬校验。
|
||||
value := normalizeDeadlineInput(raw)
|
||||
if value == "" {
|
||||
return nil, nil
|
||||
@@ -302,12 +257,7 @@ func parseOptionalDeadlineWithNow(raw string, now time.Time) (*time.Time, error)
|
||||
}
|
||||
|
||||
// parseOptionalDeadlineFromUserInput 是“用户原句解析”的宽松入口。
|
||||
// 返回值说明:
|
||||
// - deadline != nil:成功解析出时间;
|
||||
// - hasHint=false 且 err=nil:文本里没有明显时间线索,应视为“用户没给时间”;
|
||||
// - hasHint=true 且 err!=nil:用户给了时间但格式非法,应提示用户修正,不应落库。
|
||||
func parseOptionalDeadlineFromUserInput(raw string, now time.Time) (*time.Time, bool, error) {
|
||||
// 场景:解析用户原始句子时,允许“没给时间”,但不允许“给了错误时间却静默通过”。
|
||||
value := normalizeDeadlineInput(raw)
|
||||
if value == "" {
|
||||
return nil, false, nil
|
||||
@@ -316,10 +266,8 @@ func parseOptionalDeadlineFromUserInput(raw string, now time.Time) (*time.Time,
|
||||
deadline, hasHint, err := parseOptionalDeadlineFromText(value, now)
|
||||
if err != nil {
|
||||
if hasHint {
|
||||
// 有时间线索 + 解析失败:上层应明确提示用户改时间格式。
|
||||
return nil, true, err
|
||||
}
|
||||
// 无明显时间线索:按“未提供时间”处理。
|
||||
return nil, false, nil
|
||||
}
|
||||
if deadline == nil {
|
||||
@@ -331,49 +279,36 @@ func parseOptionalDeadlineFromUserInput(raw string, now time.Time) (*time.Time,
|
||||
return deadline, true, nil
|
||||
}
|
||||
|
||||
// parseOptionalDeadlineFromText 是内部通用解析器。
|
||||
// 解析顺序:
|
||||
// 1) 绝对时间(明确年月日时分);
|
||||
// 2) 相对时间(明天/下周一/今晚);
|
||||
// 3) 若识别到时间线索但仍失败,返回 hasHint=true + error,交给上层决定是否拦截。
|
||||
// parseOptionalDeadlineFromText 是内部通用时间解析器。
|
||||
func parseOptionalDeadlineFromText(value string, now time.Time) (*time.Time, bool, error) {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// 1. 统一时区与时间基准,保证相对时间可重复计算。
|
||||
loc := quickNoteLocation()
|
||||
now = now.In(loc)
|
||||
hasHint := hasDeadlineHint(value)
|
||||
|
||||
// 2. 先尝试绝对时间(优先级更高,歧义更小)。
|
||||
if abs, ok := tryParseAbsoluteDeadline(value, loc); ok {
|
||||
return abs, true, nil
|
||||
}
|
||||
|
||||
// 3. 再尝试相对时间(明天/下周一/今晚)。
|
||||
if rel, recognized, err := tryParseRelativeDeadline(value, now, loc); recognized {
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return rel, true, nil
|
||||
}
|
||||
|
||||
// 4. 到这里仍失败时,根据 hasHint 决定返回“软失败”还是“硬失败”。
|
||||
if hasHint {
|
||||
return nil, true, fmt.Errorf("deadline_at 格式不支持: %s", value)
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// normalizeDeadlineInput 把中文标点和空白先归一化,降低格式解析的噪声。
|
||||
func normalizeDeadlineInput(raw string) string {
|
||||
// 先 trim,避免纯空格输入影响后续逻辑。
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
// 将中文标点统一成英文形态,降低正则和 layout 解析复杂度。
|
||||
replacer := strings.NewReplacer(
|
||||
":", ":",
|
||||
",", ",",
|
||||
@@ -383,12 +318,7 @@ func normalizeDeadlineInput(raw string) string {
|
||||
return strings.TrimSpace(replacer.Replace(trimmed))
|
||||
}
|
||||
|
||||
// hasDeadlineHint 判断文本里是否存在“时间相关线索”。
|
||||
// 该函数的意义是区分两种情况:
|
||||
// 1) 用户根本没给时间(允许 deadline 为空);
|
||||
// 2) 用户给了时间但写错(必须提示修正,不能静默写 NULL)。
|
||||
func hasDeadlineHint(value string) bool {
|
||||
// 1. 先用结构化正则快速判断(时间格式、日期格式、周几格式)。
|
||||
if quickNoteClockHMRegex.MatchString(value) ||
|
||||
quickNoteClockCNRegex.MatchString(value) ||
|
||||
quickNoteYMDRegex.MatchString(value) ||
|
||||
@@ -397,7 +327,6 @@ func hasDeadlineHint(value string) bool {
|
||||
quickNoteWeekdayRegex.MatchString(value) {
|
||||
return true
|
||||
}
|
||||
// 2. 再用词元判断“明天/今晚”等语义线索。
|
||||
for _, token := range quickNoteRelativeTokens {
|
||||
if strings.Contains(value, token) {
|
||||
return true
|
||||
@@ -406,51 +335,40 @@ func hasDeadlineHint(value string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// tryParseAbsoluteDeadline 尝试按绝对时间格式解析。
|
||||
// 若只提供日期(无时分),默认归一到当天 23:59,表示“当日截止”。
|
||||
func tryParseAbsoluteDeadline(value string, loc *time.Location) (*time.Time, bool) {
|
||||
// 逐个 layout 尝试,命中即返回。
|
||||
for _, layout := range quickNoteDeadlineLayouts {
|
||||
var (
|
||||
t time.Time
|
||||
err error
|
||||
parsed time.Time
|
||||
err error
|
||||
)
|
||||
if layout == time.RFC3339 {
|
||||
t, err = time.Parse(layout, value)
|
||||
parsed, err = time.Parse(layout, value)
|
||||
if err == nil {
|
||||
t = t.In(loc)
|
||||
parsed = parsed.In(loc)
|
||||
}
|
||||
} else {
|
||||
t, err = time.ParseInLocation(layout, value, loc)
|
||||
parsed, err = time.ParseInLocation(layout, value, loc)
|
||||
}
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Date-only 输入(例如 2026-03-20)默认补到 23:59。
|
||||
if _, dateOnly := quickNoteDateOnlyLayouts[layout]; dateOnly {
|
||||
t = time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 0, 0, loc)
|
||||
parsed = time.Date(parsed.Year(), parsed.Month(), parsed.Day(), 23, 59, 0, 0, loc)
|
||||
} else {
|
||||
// 非 date-only 则统一清零秒级,保持分钟粒度一致。
|
||||
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), 0, 0, loc)
|
||||
parsed = time.Date(parsed.Year(), parsed.Month(), parsed.Day(), parsed.Hour(), parsed.Minute(), 0, 0, loc)
|
||||
}
|
||||
return &t, true
|
||||
return &parsed, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// tryParseRelativeDeadline 尝试解析“相对时间 + 可选时刻”。
|
||||
// 例子:
|
||||
// - 明天交报告(默认 23:59)
|
||||
// - 下周一上午9点开会(解析为下周一 09:00)
|
||||
func tryParseRelativeDeadline(value string, now time.Time, loc *time.Location) (*time.Time, bool, error) {
|
||||
// 1. 先确定“哪一天”。
|
||||
baseDate, recognized := inferBaseDate(value, now, loc)
|
||||
if !recognized {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// 2. 再解析“几点几分”,若缺失则按语义默认时刻兜底。
|
||||
hour, minute, hasExplicitClock, err := extractClock(value)
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
@@ -463,14 +381,7 @@ func tryParseRelativeDeadline(value string, now time.Time, loc *time.Location) (
|
||||
return &deadline, true, nil
|
||||
}
|
||||
|
||||
// inferBaseDate 负责先确定“哪一天”。
|
||||
// 解析优先级:
|
||||
// 1) 明确年月日;
|
||||
// 2) 月日(自动推断年份);
|
||||
// 3) 周几表达(本周/下周);
|
||||
// 4) 明天/后天/今晚等相对词。
|
||||
func inferBaseDate(value string, now time.Time, loc *time.Location) (time.Time, bool) {
|
||||
// 1) yyyy年MM月dd日
|
||||
if matched := quickNoteYMDRegex.FindStringSubmatch(value); len(matched) == 4 {
|
||||
year, _ := strconv.Atoi(matched[1])
|
||||
month, _ := strconv.Atoi(matched[2])
|
||||
@@ -480,7 +391,6 @@ func inferBaseDate(value string, now time.Time, loc *time.Location) (time.Time,
|
||||
}
|
||||
}
|
||||
|
||||
// 2) MM月dd日(自动推断年份:若今年已过则滚到明年)
|
||||
if matched := quickNoteMDRegex.FindStringSubmatch(value); len(matched) == 3 {
|
||||
month, _ := strconv.Atoi(matched[1])
|
||||
day, _ := strconv.Atoi(matched[2])
|
||||
@@ -499,7 +409,6 @@ func inferBaseDate(value string, now time.Time, loc *time.Location) (time.Time,
|
||||
return candidate, true
|
||||
}
|
||||
|
||||
// 3) 本周/下周 + 周几
|
||||
if matched := quickNoteWeekdayRegex.FindStringSubmatch(value); len(matched) == 3 {
|
||||
prefix := matched[1]
|
||||
target, ok := toWeekday(matched[2])
|
||||
@@ -508,7 +417,6 @@ func inferBaseDate(value string, now time.Time, loc *time.Location) (time.Time,
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 今天/明天/后天/大后天/昨天等相对词
|
||||
today := startOfDay(now)
|
||||
switch {
|
||||
case strings.Contains(value, "大后天"):
|
||||
@@ -526,17 +434,11 @@ func inferBaseDate(value string, now time.Time, loc *time.Location) (time.Time,
|
||||
}
|
||||
}
|
||||
|
||||
// extractClock 从文本提取时刻(时/分)。
|
||||
// 支持:
|
||||
// - 24h 表达:18:30
|
||||
// - 中文表达:3点、3点半、3点20分
|
||||
func extractClock(value string) (int, int, bool, error) {
|
||||
// hour/minute 最终会用于 time.Date,需要先做范围约束。
|
||||
hour := 0
|
||||
minute := 0
|
||||
hasClock := false
|
||||
|
||||
// 1) 24 小时制:18:30
|
||||
if matched := quickNoteClockHMRegex.FindStringSubmatch(value); len(matched) == 3 {
|
||||
h, errH := strconv.Atoi(matched[1])
|
||||
m, errM := strconv.Atoi(matched[2])
|
||||
@@ -547,7 +449,6 @@ func extractClock(value string) (int, int, bool, error) {
|
||||
minute = m
|
||||
hasClock = true
|
||||
} else if matched := quickNoteClockCNRegex.FindStringSubmatch(value); len(matched) >= 2 {
|
||||
// 2) 中文时刻:3点 / 3点半 / 3点20分
|
||||
h, errH := strconv.Atoi(matched[1])
|
||||
if errH != nil {
|
||||
return 0, 0, true, fmt.Errorf("deadline_at 时间解析失败: %s", value)
|
||||
@@ -569,11 +470,9 @@ func extractClock(value string) (int, int, bool, error) {
|
||||
}
|
||||
|
||||
if !hasClock {
|
||||
// 没有显式时刻并不是错误,交给默认时刻策略处理。
|
||||
return 0, 0, false, nil
|
||||
}
|
||||
|
||||
// 3) 根据“下午/晚上/中午/凌晨”等语义修正 12/24 小时制。
|
||||
if isPMHint(value) && hour < 12 {
|
||||
hour += 12
|
||||
}
|
||||
@@ -590,9 +489,7 @@ func extractClock(value string) (int, int, bool, error) {
|
||||
return hour, minute, true, nil
|
||||
}
|
||||
|
||||
// defaultClockByHint 当文本只给了“日期/相对日”但没给具体时刻时,按语义兜底。
|
||||
func defaultClockByHint(value string) (int, int) {
|
||||
// 没有明确时刻时按中文语义设置一个“可解释的默认值”。
|
||||
switch {
|
||||
case strings.Contains(value, "凌晨"):
|
||||
return 1, 0
|
||||
@@ -605,29 +502,24 @@ func defaultClockByHint(value string) (int, int) {
|
||||
case strings.Contains(value, "晚上") || strings.Contains(value, "今晚") || strings.Contains(value, "傍晚") || strings.Contains(value, "夜里"):
|
||||
return 20, 0
|
||||
default:
|
||||
// 只给了日期没有具体时刻时,默认当天结束前。
|
||||
return 23, 59
|
||||
}
|
||||
}
|
||||
|
||||
func isPMHint(value string) bool {
|
||||
// 下午/晚上/傍晚通常应映射到 12:00 之后。
|
||||
return strings.Contains(value, "下午") || strings.Contains(value, "晚上") || strings.Contains(value, "今晚") || strings.Contains(value, "傍晚")
|
||||
}
|
||||
|
||||
func isNoonHint(value string) bool {
|
||||
// “中午 1 点”这类表达通常是 13:00 而非 01:00。
|
||||
return strings.Contains(value, "中午")
|
||||
}
|
||||
|
||||
func startOfDay(t time.Time) time.Time {
|
||||
// 保留原时区,只把时分秒归零。
|
||||
loc := t.Location()
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
func isValidDate(year, month, day int) bool {
|
||||
// 先做快速范围筛,再用 time.Date 回填校验闰月闰年和越界日期。
|
||||
if month < 1 || month > 12 || day < 1 || day > 31 {
|
||||
return false
|
||||
}
|
||||
@@ -636,7 +528,6 @@ func isValidDate(year, month, day int) bool {
|
||||
}
|
||||
|
||||
func toWeekday(chinese string) (time.Weekday, bool) {
|
||||
// 把中文周几映射到 Go 的 Weekday 枚举。
|
||||
switch chinese {
|
||||
case "一":
|
||||
return time.Monday, true
|
||||
@@ -657,16 +548,13 @@ func toWeekday(chinese string) (time.Weekday, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// resolveWeekdayDate 根据“本周/下周 + 周几”换算目标日期。
|
||||
func resolveWeekdayDate(now time.Time, prefix string, target time.Weekday) time.Time {
|
||||
// 1. 先定位本周周一。
|
||||
today := startOfDay(now)
|
||||
weekdayOffset := (int(today.Weekday()) + 6) % 7
|
||||
weekStart := today.AddDate(0, 0, -weekdayOffset)
|
||||
targetOffset := (int(target) + 6) % 7
|
||||
candidateThisWeek := weekStart.AddDate(0, 0, targetOffset)
|
||||
|
||||
// 2. 再根据“本周/下周/无前缀”选择最终日期。
|
||||
switch {
|
||||
case strings.HasPrefix(prefix, "下"):
|
||||
return candidateThisWeek.AddDate(0, 0, 7)
|
||||
@@ -679,3 +567,19 @@ func resolveWeekdayDate(now time.Time, prefix string, target time.Weekday) time.
|
||||
return candidateThisWeek
|
||||
}
|
||||
}
|
||||
|
||||
func quickNoteLocation() *time.Location {
|
||||
loc, err := time.LoadLocation(agentmodel.QuickNoteTimezoneName)
|
||||
if err != nil {
|
||||
return time.Local
|
||||
}
|
||||
return loc
|
||||
}
|
||||
|
||||
func quickNoteNowToMinute() time.Time {
|
||||
return agentshared.NowToMinute()
|
||||
}
|
||||
|
||||
func formatQuickNoteTimeToMinute(t time.Time) string {
|
||||
return agentshared.FormatMinute(t.In(quickNoteLocation()))
|
||||
}
|
||||
2336
backend/agent/node/schedule_plan.go
Normal file
2336
backend/agent/node/schedule_plan.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,21 +1,158 @@
|
||||
package scheduleplan
|
||||
package agentnode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
agentllm "github.com/LoveLosita/smartflow/backend/agent/llm"
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
)
|
||||
|
||||
// SchedulePlanToolDeps 描述“智能排程 graph”运行所需的外部业务依赖。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只负责声明“需要哪些能力”,不负责具体实现(实现由 service 层注入)。
|
||||
// 2. 只收口函数签名,不承载业务状态,避免跨请求共享可变数据。
|
||||
// 3. 当前统一采用 task_class_ids 语义,不再依赖单 task_class_id 主路径。
|
||||
type SchedulePlanToolDeps struct {
|
||||
// SmartPlanningMultiRaw 是可选依赖:
|
||||
// 1) 用于需要单独输出“粗排预览”时复用;
|
||||
// 2) 当前主链路已由 HybridScheduleWithPlanMulti 覆盖,可不注入。
|
||||
SmartPlanningMultiRaw func(ctx context.Context, userID int, taskClassIDs []int) ([]model.UserWeekSchedule, []model.TaskClassItem, error)
|
||||
|
||||
// HybridScheduleWithPlanMulti 把“既有日程 + 粗排结果”合并成统一的 HybridScheduleEntry 切片,
|
||||
// 供 daily/weekly ReAct 节点在内存中继续优化。
|
||||
HybridScheduleWithPlanMulti func(ctx context.Context, userID int, taskClassIDs []int) ([]model.HybridScheduleEntry, []model.TaskClassItem, error)
|
||||
|
||||
// ResolvePlanningWindow 根据 task_class_ids 解析“全局排程窗口”的相对周/天边界。
|
||||
//
|
||||
// 返回语义:
|
||||
// 1. startWeek/startDay:窗口起点(含);
|
||||
// 2. endWeek/endDay:窗口终点(含);
|
||||
// 3. error:解析失败(如任务类不存在、日期非法)。
|
||||
//
|
||||
// 用途:
|
||||
// 1. 给周级 Move 工具加硬边界,避免把任务移动到窗口外的天数;
|
||||
// 2. 解决“首尾不足一周”场景下的周内越界问题。
|
||||
ResolvePlanningWindow func(ctx context.Context, userID int, taskClassIDs []int) (startWeek, startDay, endWeek, endDay int, err error)
|
||||
}
|
||||
|
||||
// Validate 校验依赖完整性。
|
||||
//
|
||||
// 失败处理:
|
||||
// 1. 任意依赖缺失都直接返回错误,避免 graph 运行到中途才 panic。
|
||||
// 2. 调用方(runSchedulePlanFlow)收到错误后会走回退链路,不影响普通聊天可用性。
|
||||
func (d SchedulePlanToolDeps) Validate() error {
|
||||
if d.HybridScheduleWithPlanMulti == nil {
|
||||
return errors.New("schedule plan tool deps: HybridScheduleWithPlanMulti is nil")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtraInt 从 extra map 中安全提取整数值。
|
||||
//
|
||||
// 兼容策略:
|
||||
// 1) JSON 数字默认解析为 float64,做 int 转换;
|
||||
// 2) 兼容字符串形式(如 "42"),用 Atoi 解析;
|
||||
// 3) 其余类型返回 false,由调用方决定后续处理。
|
||||
func ExtraInt(extra map[string]any, key string) (int, bool) {
|
||||
v, ok := extra[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return int(n), true
|
||||
case int:
|
||||
return n, true
|
||||
case string:
|
||||
i, err := strconv.Atoi(n)
|
||||
return i, err == nil
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// ExtraIntSlice 从 extra map 中安全提取整数切片。
|
||||
//
|
||||
// 兼容输入:
|
||||
// 1) []any(JSON 数组反序列化后的常见类型);
|
||||
// 2) []int;
|
||||
// 3) []float64;
|
||||
// 4) 逗号分隔字符串(例如 "1,2,3")。
|
||||
//
|
||||
// 返回语义:
|
||||
// 1) ok=true:至少成功解析出一个整数;
|
||||
// 2) ok=false:字段不存在或全部解析失败。
|
||||
func ExtraIntSlice(extra map[string]any, key string) ([]int, bool) {
|
||||
v, exists := extra[key]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
parseOne := func(raw any) (int, error) {
|
||||
switch n := raw.(type) {
|
||||
case int:
|
||||
return n, nil
|
||||
case float64:
|
||||
return int(n), nil
|
||||
case string:
|
||||
i, err := strconv.Atoi(n)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return i, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported type: %T", raw)
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]int, 0)
|
||||
switch arr := v.(type) {
|
||||
case []int:
|
||||
for _, item := range arr {
|
||||
out = append(out, item)
|
||||
}
|
||||
case []float64:
|
||||
for _, item := range arr {
|
||||
out = append(out, int(item))
|
||||
}
|
||||
case []any:
|
||||
for _, item := range arr {
|
||||
if parsed, err := parseOne(item); err == nil {
|
||||
out = append(out, parsed)
|
||||
}
|
||||
}
|
||||
case string:
|
||||
parts := strings.Split(arr, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if parsed, err := strconv.Atoi(part); err == nil {
|
||||
out = append(out, parsed)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if len(out) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
|
||||
// ── ReAct Tool 调用/结果结构 ──
|
||||
|
||||
// reactToolCall 是 LLM 输出的单个工具调用。
|
||||
type reactToolCall struct {
|
||||
Tool string `json:"tool"`
|
||||
Params map[string]any `json:"params"`
|
||||
}
|
||||
type reactToolCall = agentllm.ReactToolCall
|
||||
|
||||
// reactToolResult 是单个工具调用的执行结果。
|
||||
type reactToolResult struct {
|
||||
@@ -25,11 +162,7 @@ type reactToolResult struct {
|
||||
}
|
||||
|
||||
// reactLLMOutput 是 LLM 输出的完整 JSON 结构。
|
||||
type reactLLMOutput struct {
|
||||
Done bool `json:"done"`
|
||||
Summary string `json:"summary"`
|
||||
ToolCalls []reactToolCall `json:"tool_calls"`
|
||||
}
|
||||
type reactLLMOutput = agentllm.ReactLLMOutput
|
||||
|
||||
// weeklyPlanningWindow 表示周级优化可用的全局周/天窗口。
|
||||
//
|
||||
@@ -422,34 +555,7 @@ func reactToolGetAvailableSlots(entries []model.HybridScheduleEntry, params map[
|
||||
// parseReactLLMOutput 解析 LLM 的 JSON 输出。
|
||||
// 兼容 ```json ... ``` 包裹。
|
||||
func parseReactLLMOutput(raw string) (*reactLLMOutput, error) {
|
||||
clean := strings.TrimSpace(raw)
|
||||
if clean == "" {
|
||||
return nil, fmt.Errorf("LLM 输出为空")
|
||||
}
|
||||
// 兼容 markdown 包裹
|
||||
if strings.HasPrefix(clean, "```") {
|
||||
clean = strings.TrimPrefix(clean, "```json")
|
||||
clean = strings.TrimPrefix(clean, "```")
|
||||
clean = strings.TrimSuffix(clean, "```")
|
||||
clean = strings.TrimSpace(clean)
|
||||
}
|
||||
|
||||
var out reactLLMOutput
|
||||
if err := json.Unmarshal([]byte(clean), &out); err == nil {
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
// 提取最外层 JSON 对象
|
||||
start := strings.Index(clean, "{")
|
||||
end := strings.LastIndex(clean, "}")
|
||||
if start == -1 || end == -1 || end <= start {
|
||||
return nil, fmt.Errorf("无法从 LLM 输出中提取 JSON: %s", truncate(clean, 200))
|
||||
}
|
||||
obj := clean[start : end+1]
|
||||
if err := json.Unmarshal([]byte(obj), &out); err != nil {
|
||||
return nil, fmt.Errorf("JSON 解析失败: %w", err)
|
||||
}
|
||||
return &out, nil
|
||||
return agentllm.ParseScheduleReactOutput(raw)
|
||||
}
|
||||
|
||||
// truncate 截断字符串到指定长度。
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
729
backend/agent/node/taskquery.go
Normal file
729
backend/agent/node/taskquery.go
Normal file
@@ -0,0 +1,729 @@
|
||||
package agentnode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
agentllm "github.com/LoveLosita/smartflow/backend/agent/llm"
|
||||
agentmodel "github.com/LoveLosita/smartflow/backend/agent/model"
|
||||
agentprompt "github.com/LoveLosita/smartflow/backend/agent/prompt"
|
||||
agentstream "github.com/LoveLosita/smartflow/backend/agent/stream"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
const (
|
||||
TaskQueryGraphNodePlan = "task_query.plan"
|
||||
TaskQueryGraphNodeQuadrant = "task_query.quadrant"
|
||||
TaskQueryGraphNodeTimeAnchor = "task_query.time_anchor"
|
||||
TaskQueryGraphNodeQuery = "task_query.query"
|
||||
TaskQueryGraphNodeReflect = "task_query.reflect"
|
||||
)
|
||||
|
||||
var (
|
||||
explicitLimitPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)\btop\s*(\d{1,2})\b`),
|
||||
regexp.MustCompile(`前\s*(\d{1,2})\s*(个|条|项)?`),
|
||||
regexp.MustCompile(`(\d{1,2})\s*(个|条|项)?\s*任务`),
|
||||
regexp.MustCompile(`给我\s*(\d{1,2})\s*(个|条|项)?`),
|
||||
}
|
||||
chineseDigitMap = map[rune]int{
|
||||
'一': 1,
|
||||
'二': 2,
|
||||
'两': 2,
|
||||
'三': 3,
|
||||
'四': 4,
|
||||
'五': 5,
|
||||
'六': 6,
|
||||
'七': 7,
|
||||
'八': 8,
|
||||
'九': 9,
|
||||
'十': 10,
|
||||
}
|
||||
)
|
||||
|
||||
// TaskQueryGraphRunInput 描述一次任务查询图运行需要的依赖。
|
||||
type TaskQueryGraphRunInput struct {
|
||||
Model *ark.ChatModel
|
||||
State *agentmodel.TaskQueryState
|
||||
Deps TaskQueryToolDeps
|
||||
EmitStage func(stage, detail string)
|
||||
}
|
||||
|
||||
// TaskQueryNodes 是任务查询图的节点容器。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责承接请求级依赖,并向 graph 暴露可直接挂载的方法。
|
||||
// 2. 不负责 graph 编译、service 接线和持久化。
|
||||
type TaskQueryNodes struct {
|
||||
input TaskQueryGraphRunInput
|
||||
queryTool tool.InvokableTool
|
||||
emitStage agentstream.StageEmitter
|
||||
}
|
||||
|
||||
func NewTaskQueryNodes(input TaskQueryGraphRunInput, queryTool tool.InvokableTool) (*TaskQueryNodes, error) {
|
||||
if input.Model == nil {
|
||||
return nil, fmt.Errorf("task query nodes: model is nil")
|
||||
}
|
||||
if input.State == nil {
|
||||
return nil, fmt.Errorf("task query nodes: state is nil")
|
||||
}
|
||||
if err := input.Deps.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if queryTool == nil {
|
||||
return nil, fmt.Errorf("task query nodes: queryTool is nil")
|
||||
}
|
||||
return &TaskQueryNodes{
|
||||
input: input,
|
||||
queryTool: queryTool,
|
||||
emitStage: agentstream.WrapStageEmitter(input.EmitStage),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Plan 负责把用户原话规划成结构化查询计划。
|
||||
func (n *TaskQueryNodes) Plan(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in plan node")
|
||||
}
|
||||
|
||||
n.emitStage("task_query.plan.generating", "正在一次性规划查询范围、排序和时间条件。")
|
||||
planned, err := agentllm.PlanTaskQuery(ctx, n.input.Model, st.RequestNowText, st.UserMessage)
|
||||
if err != nil || planned == nil {
|
||||
st.UserGoal = "查询任务"
|
||||
st.Plan = defaultTaskQueryPlan()
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.UserGoal = strings.TrimSpace(planned.UserGoal)
|
||||
if st.UserGoal == "" {
|
||||
st.UserGoal = "查询任务"
|
||||
}
|
||||
st.Plan = normalizeTaskQueryPlan(*planned)
|
||||
|
||||
// 1. 若用户原话里明确指定了返回条数,则以后端识别结果为准。
|
||||
// 2. 这样可以避免规划模型漏掉数量要求,或后续反思 patch 意外改写 limit。
|
||||
if explicitLimit, found := extractExplicitLimitFromUser(st.UserMessage); found {
|
||||
st.ExplicitLimit = explicitLimit
|
||||
st.Plan.Limit = explicitLimit
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// NormalizeQuadrant 负责把象限参数去重并统一成稳定顺序。
|
||||
func (n *TaskQueryNodes) NormalizeQuadrant(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
|
||||
_ = ctx
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in quadrant node")
|
||||
}
|
||||
|
||||
n.emitStage("task_query.quadrant.routing", "正在归一化象限筛选范围。")
|
||||
st.Plan.Quadrants = normalizeQuadrants(st.Plan.Quadrants)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// AnchorTime 负责把时间文本边界解析成可执行时间对象。
|
||||
func (n *TaskQueryNodes) AnchorTime(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
|
||||
_ = ctx
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in time anchor node")
|
||||
}
|
||||
|
||||
n.emitStage("task_query.time.anchoring", "正在锁定时间过滤边界。")
|
||||
applyTimeAnchorOnPlan(&st.Plan)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// Query 负责真正调用工具查询任务。
|
||||
func (n *TaskQueryNodes) Query(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in query node")
|
||||
}
|
||||
|
||||
n.emitStage("task_query.tool.querying", "正在查询任务数据。")
|
||||
items, err := n.executePlanByTool(ctx, st.Plan)
|
||||
if err != nil {
|
||||
st.LastQueryItems = make([]agentmodel.TaskQueryItem, 0)
|
||||
st.LastQueryTotal = 0
|
||||
st.ReflectReason = "查询工具执行失败"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.LastQueryItems = items
|
||||
st.LastQueryTotal = len(items)
|
||||
|
||||
// 1. 如果首轮为空且还没自动放宽过,则做一次可解释的自动放宽。
|
||||
// 2. 放宽范围仅限关键词、完成状态、时间边界,不主动改象限与 limit,避免语义漂移。
|
||||
if st.LastQueryTotal == 0 && !st.AutoBroadenApplied {
|
||||
broadenedPlan, changed := autoBroadenPlan(st.Plan)
|
||||
if changed {
|
||||
st.AutoBroadenApplied = true
|
||||
st.Plan = broadenedPlan
|
||||
n.emitStage("task_query.tool.broadened", "首次查询为空,已自动放宽条件再试一次。")
|
||||
retryItems, retryErr := n.executePlanByTool(ctx, st.Plan)
|
||||
if retryErr == nil {
|
||||
st.LastQueryItems = retryItems
|
||||
st.LastQueryTotal = len(retryItems)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// Reflect 负责判断当前结果是否满足用户诉求,并决定是否重试。
|
||||
func (n *TaskQueryNodes) Reflect(ctx context.Context, st *agentmodel.TaskQueryState) (*agentmodel.TaskQueryState, error) {
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in reflect node")
|
||||
}
|
||||
|
||||
n.emitStage("task_query.reflecting", "正在判断结果是否贴合你的需求。")
|
||||
reflectPrompt := agentprompt.BuildTaskQueryReflectUserPrompt(
|
||||
st.RequestNowText,
|
||||
st.UserMessage,
|
||||
st.UserGoal,
|
||||
summarizeTaskQueryPlan(st.Plan),
|
||||
st.RetryCount,
|
||||
st.MaxReflectRetry,
|
||||
summarizeTaskQueryItems(st.LastQueryItems, 6),
|
||||
)
|
||||
reflectResult, err := agentllm.ReflectTaskQuery(ctx, n.input.Model, reflectPrompt)
|
||||
if err != nil || reflectResult == nil {
|
||||
st.NeedRetry = false
|
||||
st.FinalReply = buildTaskQueryFallbackReply(st.LastQueryItems)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.ReflectReason = strings.TrimSpace(reflectResult.Reason)
|
||||
|
||||
if reflectResult.Satisfied {
|
||||
st.NeedRetry = false
|
||||
st.FinalReply = buildTaskQueryFinalReply(st.LastQueryItems, st.Plan, strings.TrimSpace(reflectResult.Reply))
|
||||
return st, nil
|
||||
}
|
||||
|
||||
if reflectResult.NeedRetry && st.RetryCount < st.MaxReflectRetry {
|
||||
st.Plan = applyRetryPatch(st.Plan, reflectResult.RetryPatch, st.ExplicitLimit)
|
||||
st.RetryCount++
|
||||
st.NeedRetry = true
|
||||
if reply := strings.TrimSpace(reflectResult.Reply); reply != "" {
|
||||
st.FinalReply = reply
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.NeedRetry = false
|
||||
st.FinalReply = buildTaskQueryFinalReply(st.LastQueryItems, st.Plan, strings.TrimSpace(reflectResult.Reply))
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func (n *TaskQueryNodes) NextAfterReflect(ctx context.Context, st *agentmodel.TaskQueryState) (string, error) {
|
||||
_ = ctx
|
||||
if st != nil && st.NeedRetry {
|
||||
return TaskQueryGraphNodeQuery, nil
|
||||
}
|
||||
return compose.END, nil
|
||||
}
|
||||
|
||||
func (n *TaskQueryNodes) executePlanByTool(ctx context.Context, plan agentmodel.TaskQueryPlan) ([]agentmodel.TaskQueryItem, error) {
|
||||
if n.queryTool == nil {
|
||||
return nil, fmt.Errorf("task query tool is nil")
|
||||
}
|
||||
|
||||
merged := make([]agentmodel.TaskQueryItem, 0, plan.Limit)
|
||||
seen := make(map[int]struct{}, plan.Limit*2)
|
||||
|
||||
runOne := func(quadrant *int) error {
|
||||
input := TaskQueryToolInput{
|
||||
Quadrant: quadrant,
|
||||
SortBy: plan.SortBy,
|
||||
Order: plan.Order,
|
||||
Limit: plan.Limit,
|
||||
Keyword: plan.Keyword,
|
||||
DeadlineBefore: plan.DeadlineBeforeText,
|
||||
DeadlineAfter: plan.DeadlineAfterText,
|
||||
}
|
||||
includeCompleted := plan.IncludeCompleted
|
||||
input.IncludeCompleted = &includeCompleted
|
||||
|
||||
rawInput, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rawOutput, err := n.queryTool.InvokableRun(ctx, string(rawInput))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
parsed, err := agentllm.ParseJSONObject[TaskQueryToolOutput](rawOutput)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, item := range parsed.Items {
|
||||
if _, exists := seen[item.ID]; exists {
|
||||
continue
|
||||
}
|
||||
seen[item.ID] = struct{}{}
|
||||
merged = append(merged, item)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(plan.Quadrants) == 0 {
|
||||
if err := runOne(nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
for _, quadrant := range plan.Quadrants {
|
||||
q := quadrant
|
||||
if err := runOne(&q); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sortTaskQueryItems(merged, plan)
|
||||
if len(merged) > plan.Limit {
|
||||
merged = merged[:plan.Limit]
|
||||
}
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
func defaultTaskQueryPlan() agentmodel.TaskQueryPlan {
|
||||
return agentmodel.TaskQueryPlan{
|
||||
SortBy: "deadline",
|
||||
Order: "asc",
|
||||
Limit: agentmodel.DefaultTaskQueryLimit,
|
||||
IncludeCompleted: false,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeTaskQueryPlan(raw agentllm.TaskQueryPlanOutput) agentmodel.TaskQueryPlan {
|
||||
plan := defaultTaskQueryPlan()
|
||||
plan.Quadrants = normalizeQuadrants(raw.Quadrants)
|
||||
|
||||
if sortBy := strings.ToLower(strings.TrimSpace(raw.SortBy)); sortBy == "deadline" || sortBy == "priority" || sortBy == "id" {
|
||||
plan.SortBy = sortBy
|
||||
}
|
||||
if order := strings.ToLower(strings.TrimSpace(raw.Order)); order == "asc" || order == "desc" {
|
||||
plan.Order = order
|
||||
}
|
||||
if raw.Limit > 0 {
|
||||
plan.Limit = raw.Limit
|
||||
}
|
||||
if plan.Limit > agentmodel.MaxTaskQueryLimit {
|
||||
plan.Limit = agentmodel.MaxTaskQueryLimit
|
||||
}
|
||||
if plan.Limit <= 0 {
|
||||
plan.Limit = agentmodel.DefaultTaskQueryLimit
|
||||
}
|
||||
if raw.IncludeCompleted != nil {
|
||||
plan.IncludeCompleted = *raw.IncludeCompleted
|
||||
}
|
||||
plan.Keyword = strings.TrimSpace(raw.Keyword)
|
||||
plan.DeadlineBeforeText = strings.TrimSpace(raw.DeadlineBefore)
|
||||
plan.DeadlineAfterText = strings.TrimSpace(raw.DeadlineAfter)
|
||||
applyTimeAnchorOnPlan(&plan)
|
||||
return plan
|
||||
}
|
||||
|
||||
func normalizeQuadrants(quadrants []int) []int {
|
||||
if len(quadrants) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
seen := make(map[int]struct{}, len(quadrants))
|
||||
result := make([]int, 0, len(quadrants))
|
||||
for _, quadrant := range quadrants {
|
||||
if quadrant < 1 || quadrant > 4 {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[quadrant]; exists {
|
||||
continue
|
||||
}
|
||||
seen[quadrant] = struct{}{}
|
||||
result = append(result, quadrant)
|
||||
}
|
||||
|
||||
sort.Ints(result)
|
||||
if len(result) == 0 || len(result) == 4 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func applyTimeAnchorOnPlan(plan *agentmodel.TaskQueryPlan) {
|
||||
if plan == nil {
|
||||
return
|
||||
}
|
||||
|
||||
before, errBefore := parseTaskQueryBoundaryTime(plan.DeadlineBeforeText, true)
|
||||
after, errAfter := parseTaskQueryBoundaryTime(plan.DeadlineAfterText, false)
|
||||
|
||||
if errBefore != nil {
|
||||
plan.DeadlineBefore = nil
|
||||
plan.DeadlineBeforeText = ""
|
||||
} else {
|
||||
plan.DeadlineBefore = before
|
||||
}
|
||||
if errAfter != nil {
|
||||
plan.DeadlineAfter = nil
|
||||
plan.DeadlineAfterText = ""
|
||||
} else {
|
||||
plan.DeadlineAfter = after
|
||||
}
|
||||
|
||||
if plan.DeadlineBefore != nil && plan.DeadlineAfter != nil && plan.DeadlineAfter.After(*plan.DeadlineBefore) {
|
||||
plan.DeadlineBefore = nil
|
||||
plan.DeadlineAfter = nil
|
||||
plan.DeadlineBeforeText = ""
|
||||
plan.DeadlineAfterText = ""
|
||||
}
|
||||
}
|
||||
|
||||
func autoBroadenPlan(plan agentmodel.TaskQueryPlan) (agentmodel.TaskQueryPlan, bool) {
|
||||
broadened := plan
|
||||
changed := false
|
||||
|
||||
if strings.TrimSpace(broadened.Keyword) != "" {
|
||||
broadened.Keyword = ""
|
||||
changed = true
|
||||
}
|
||||
if !broadened.IncludeCompleted {
|
||||
broadened.IncludeCompleted = true
|
||||
changed = true
|
||||
}
|
||||
if broadened.DeadlineBefore != nil || broadened.DeadlineAfter != nil || broadened.DeadlineBeforeText != "" || broadened.DeadlineAfterText != "" {
|
||||
broadened.DeadlineBefore = nil
|
||||
broadened.DeadlineAfter = nil
|
||||
broadened.DeadlineBeforeText = ""
|
||||
broadened.DeadlineAfterText = ""
|
||||
changed = true
|
||||
}
|
||||
return broadened, changed
|
||||
}
|
||||
|
||||
func applyRetryPatch(plan agentmodel.TaskQueryPlan, patch agentllm.TaskQueryRetryPatch, explicitLimit int) agentmodel.TaskQueryPlan {
|
||||
next := plan
|
||||
changed := false
|
||||
|
||||
if patch.Quadrants != nil {
|
||||
next.Quadrants = normalizeQuadrants(*patch.Quadrants)
|
||||
changed = true
|
||||
}
|
||||
if patch.SortBy != nil {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(*patch.SortBy))
|
||||
if sortBy == "deadline" || sortBy == "priority" || sortBy == "id" {
|
||||
next.SortBy = sortBy
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if patch.Order != nil {
|
||||
order := strings.ToLower(strings.TrimSpace(*patch.Order))
|
||||
if order == "asc" || order == "desc" {
|
||||
next.Order = order
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if patch.Limit != nil && explicitLimit <= 0 {
|
||||
limit := *patch.Limit
|
||||
if limit <= 0 {
|
||||
limit = agentmodel.DefaultTaskQueryLimit
|
||||
}
|
||||
if limit > agentmodel.MaxTaskQueryLimit {
|
||||
limit = agentmodel.MaxTaskQueryLimit
|
||||
}
|
||||
next.Limit = limit
|
||||
changed = true
|
||||
}
|
||||
if patch.IncludeCompleted != nil {
|
||||
next.IncludeCompleted = *patch.IncludeCompleted
|
||||
changed = true
|
||||
}
|
||||
if patch.Keyword != nil {
|
||||
next.Keyword = strings.TrimSpace(*patch.Keyword)
|
||||
changed = true
|
||||
}
|
||||
if patch.DeadlineBefore != nil {
|
||||
next.DeadlineBeforeText = strings.TrimSpace(*patch.DeadlineBefore)
|
||||
changed = true
|
||||
}
|
||||
if patch.DeadlineAfter != nil {
|
||||
next.DeadlineAfterText = strings.TrimSpace(*patch.DeadlineAfter)
|
||||
changed = true
|
||||
}
|
||||
|
||||
if changed {
|
||||
applyTimeAnchorOnPlan(&next)
|
||||
}
|
||||
if explicitLimit > 0 {
|
||||
next.Limit = explicitLimit
|
||||
}
|
||||
return next
|
||||
}
|
||||
|
||||
func summarizeTaskQueryPlan(plan agentmodel.TaskQueryPlan) string {
|
||||
quadrants := "全部象限"
|
||||
if len(plan.Quadrants) > 0 {
|
||||
parts := make([]string, 0, len(plan.Quadrants))
|
||||
for _, quadrant := range plan.Quadrants {
|
||||
parts = append(parts, strconv.Itoa(quadrant))
|
||||
}
|
||||
quadrants = strings.Join(parts, ",")
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"quadrants=%s sort=%s/%s limit=%d include_completed=%t keyword=%s before=%s after=%s",
|
||||
quadrants,
|
||||
plan.SortBy,
|
||||
plan.Order,
|
||||
plan.Limit,
|
||||
plan.IncludeCompleted,
|
||||
emptyToDash(plan.Keyword),
|
||||
emptyToDash(plan.DeadlineBeforeText),
|
||||
emptyToDash(plan.DeadlineAfterText),
|
||||
)
|
||||
}
|
||||
|
||||
func summarizeTaskQueryItems(items []agentmodel.TaskQueryItem, max int) string {
|
||||
if len(items) == 0 {
|
||||
return "无结果"
|
||||
}
|
||||
if max <= 0 {
|
||||
max = 5
|
||||
}
|
||||
if len(items) > max {
|
||||
items = items[:max]
|
||||
}
|
||||
|
||||
lines := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
lines = append(lines, fmt.Sprintf(
|
||||
"- #%d %s | 象限=%d | 完成=%t | 截止=%s",
|
||||
item.ID,
|
||||
item.Title,
|
||||
item.PriorityGroup,
|
||||
item.IsCompleted,
|
||||
emptyToDash(item.DeadlineAt),
|
||||
))
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func buildTaskQueryFallbackReply(items []agentmodel.TaskQueryItem) string {
|
||||
if len(items) == 0 {
|
||||
return "我这边暂时没找到匹配的任务。你可以再补一句,比如“按截止时间最早的前 3 个”或“只看简单不重要”。"
|
||||
}
|
||||
|
||||
preview := items
|
||||
if len(preview) > 3 {
|
||||
preview = preview[:3]
|
||||
}
|
||||
lines := make([]string, 0, len(preview))
|
||||
for _, item := range preview {
|
||||
lines = append(lines, fmt.Sprintf("%s(%s)", item.Title, item.PriorityLabel))
|
||||
}
|
||||
return fmt.Sprintf("我先给你筛到这些:%s。要不要我再按“更紧急”或“更简单”继续细化?", strings.Join(lines, "、"))
|
||||
}
|
||||
|
||||
func buildTaskQueryFinalReply(items []agentmodel.TaskQueryItem, plan agentmodel.TaskQueryPlan, llmReply string) string {
|
||||
if len(items) == 0 {
|
||||
base := buildTaskQueryFallbackReply(items)
|
||||
if strings.TrimSpace(llmReply) == "" {
|
||||
return base
|
||||
}
|
||||
return strings.TrimSpace(llmReply) + "\n" + base
|
||||
}
|
||||
|
||||
desired := plan.Limit
|
||||
if desired <= 0 {
|
||||
desired = agentmodel.DefaultTaskQueryLimit
|
||||
}
|
||||
if desired > agentmodel.MaxTaskQueryLimit {
|
||||
desired = agentmodel.MaxTaskQueryLimit
|
||||
}
|
||||
|
||||
showCount := desired
|
||||
if len(items) < showCount {
|
||||
showCount = len(items)
|
||||
}
|
||||
|
||||
preview := items[:showCount]
|
||||
lines := make([]string, 0, len(preview))
|
||||
for idx, item := range preview {
|
||||
deadline := strings.TrimSpace(item.DeadlineAt)
|
||||
if deadline == "" {
|
||||
deadline = "无明确截止时间"
|
||||
}
|
||||
status := "未完成"
|
||||
if item.IsCompleted {
|
||||
status = "已完成"
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf(
|
||||
"%d. %s(%s,%s,截止:%s)",
|
||||
idx+1,
|
||||
item.Title,
|
||||
item.PriorityLabel,
|
||||
status,
|
||||
deadline,
|
||||
))
|
||||
}
|
||||
|
||||
header := fmt.Sprintf("给你整理了 %d 条任务:", showCount)
|
||||
if lead := extractSafeReplyLead(llmReply); lead != "" {
|
||||
header = lead + "\n" + header
|
||||
}
|
||||
|
||||
reply := header + "\n" + strings.Join(lines, "\n")
|
||||
if len(items) > showCount {
|
||||
reply += fmt.Sprintf("\n另外还有 %d 条匹配任务,要不要我继续往下列?", len(items)-showCount)
|
||||
}
|
||||
return reply
|
||||
}
|
||||
|
||||
func extractSafeReplyLead(llmReply string) string {
|
||||
text := strings.TrimSpace(llmReply)
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
lower := strings.ToLower(text)
|
||||
if strings.Contains(text, "\n") ||
|
||||
strings.Contains(text, "#") ||
|
||||
strings.Contains(lower, "1.") ||
|
||||
strings.Contains(text, "1、") ||
|
||||
strings.Contains(text, "以下是") {
|
||||
return ""
|
||||
}
|
||||
if len([]rune(text)) > 30 {
|
||||
return ""
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func sortTaskQueryItems(items []agentmodel.TaskQueryItem, plan agentmodel.TaskQueryPlan) {
|
||||
if len(items) <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
sortBy := strings.ToLower(strings.TrimSpace(plan.SortBy))
|
||||
order := strings.ToLower(strings.TrimSpace(plan.Order))
|
||||
if order != "desc" {
|
||||
order = "asc"
|
||||
}
|
||||
|
||||
sort.SliceStable(items, func(i, j int) bool {
|
||||
left := items[i]
|
||||
right := items[j]
|
||||
|
||||
switch sortBy {
|
||||
case "priority":
|
||||
if left.PriorityGroup != right.PriorityGroup {
|
||||
if order == "desc" {
|
||||
return left.PriorityGroup > right.PriorityGroup
|
||||
}
|
||||
return left.PriorityGroup < right.PriorityGroup
|
||||
}
|
||||
return left.ID > right.ID
|
||||
case "id":
|
||||
if order == "desc" {
|
||||
return left.ID > right.ID
|
||||
}
|
||||
return left.ID < right.ID
|
||||
default:
|
||||
leftTime, leftOK := parseTaskQueryItemDeadline(left.DeadlineAt)
|
||||
rightTime, rightOK := parseTaskQueryItemDeadline(right.DeadlineAt)
|
||||
if leftOK && rightOK {
|
||||
if !leftTime.Equal(rightTime) {
|
||||
if order == "desc" {
|
||||
return leftTime.After(rightTime)
|
||||
}
|
||||
return leftTime.Before(rightTime)
|
||||
}
|
||||
return left.ID > right.ID
|
||||
}
|
||||
if leftOK && !rightOK {
|
||||
return true
|
||||
}
|
||||
if !leftOK && rightOK {
|
||||
return false
|
||||
}
|
||||
return left.ID > right.ID
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func parseTaskQueryItemDeadline(raw string) (time.Time, bool) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
parsed, err := time.ParseInLocation("2006-01-02 15:04", text, time.Local)
|
||||
if err != nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return parsed, true
|
||||
}
|
||||
|
||||
func emptyToDash(text string) string {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return "-"
|
||||
}
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
// extractExplicitLimitFromUser 从用户原话里提取显式条数要求。
|
||||
//
|
||||
// 步骤说明:
|
||||
// 1. 先识别阿拉伯数字表达,例如“前3个”“给我5条”“top 10”。
|
||||
// 2. 再识别中文数字表达,例如“前五个”“来三个”。
|
||||
// 3. 最终统一约束到 1~20 范围内。
|
||||
func extractExplicitLimitFromUser(userMessage string) (int, bool) {
|
||||
text := strings.TrimSpace(userMessage)
|
||||
if text == "" {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
for _, pattern := range explicitLimitPatterns {
|
||||
matched := pattern.FindStringSubmatch(text)
|
||||
if len(matched) < 2 {
|
||||
continue
|
||||
}
|
||||
number, err := strconv.Atoi(strings.TrimSpace(matched[1]))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
return normalizeExplicitLimit(number)
|
||||
}
|
||||
|
||||
for _, prefix := range []string{"前", "来", "给我"} {
|
||||
for digit, number := range chineseDigitMap {
|
||||
token := prefix + string(digit)
|
||||
if strings.Contains(text, token) {
|
||||
return normalizeExplicitLimit(number)
|
||||
}
|
||||
for _, suffix := range []string{"个", "条", "项"} {
|
||||
if strings.Contains(text, token+suffix) {
|
||||
return normalizeExplicitLimit(number)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func normalizeExplicitLimit(number int) (int, bool) {
|
||||
if number <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
if number > agentmodel.MaxTaskQueryLimit {
|
||||
number = agentmodel.MaxTaskQueryLimit
|
||||
}
|
||||
return number, true
|
||||
}
|
||||
286
backend/agent/node/taskquery_tool.go
Normal file
286
backend/agent/node/taskquery_tool.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package agentnode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
agentmodel "github.com/LoveLosita/smartflow/backend/agent/model"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
toolutils "github.com/cloudwego/eino/components/tool/utils"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
const (
|
||||
ToolNameTaskQueryTasks = "query_tasks"
|
||||
ToolDescTaskQueryTasks = "按象限、关键词、截止时间筛选并排序任务,返回结构化任务列表"
|
||||
)
|
||||
|
||||
var taskQueryTimeLayouts = []string{
|
||||
time.RFC3339,
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02 15:04",
|
||||
"2006-01-02",
|
||||
}
|
||||
|
||||
// TaskQueryToolDeps 描述任务查询工具依赖的外部查询能力。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. QueryTasks 负责读取真实任务数据。
|
||||
// 2. 工具层只负责参数校验、归一化和结构化输出,不直接耦合 DAO。
|
||||
type TaskQueryToolDeps struct {
|
||||
QueryTasks func(ctx context.Context, req TaskQueryRequest) ([]TaskQueryTaskRecord, error)
|
||||
}
|
||||
|
||||
// Validate 负责校验任务查询工具依赖是否齐全。
|
||||
func (d TaskQueryToolDeps) Validate() error {
|
||||
if d.QueryTasks == nil {
|
||||
return errors.New("task query tool deps: QueryTasks is nil")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TaskQueryToolBundle 同时返回工具实例和工具元信息。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. Tools 给执行节点使用。
|
||||
// 2. ToolInfos 给模型注册 schema 使用。
|
||||
type TaskQueryToolBundle struct {
|
||||
Tools []tool.BaseTool
|
||||
ToolInfos []*schema.ToolInfo
|
||||
}
|
||||
|
||||
// TaskQueryRequest 是工具层传给业务层的内部查询请求。
|
||||
type TaskQueryRequest struct {
|
||||
UserID int
|
||||
Quadrant *int
|
||||
SortBy string
|
||||
Order string
|
||||
Limit int
|
||||
IncludeCompleted bool
|
||||
Keyword string
|
||||
DeadlineBefore *time.Time
|
||||
DeadlineAfter *time.Time
|
||||
}
|
||||
|
||||
// TaskQueryTaskRecord 是业务层返回给工具层的任务记录。
|
||||
type TaskQueryTaskRecord struct {
|
||||
ID int
|
||||
Title string
|
||||
PriorityGroup int
|
||||
IsCompleted bool
|
||||
DeadlineAt *time.Time
|
||||
UrgencyThresholdAt *time.Time
|
||||
}
|
||||
|
||||
// TaskQueryToolInput 是暴露给大模型的工具入参。
|
||||
type TaskQueryToolInput struct {
|
||||
Quadrant *int `json:"quadrant,omitempty" jsonschema:"description=可选象限(1~4)"`
|
||||
SortBy string `json:"sort_by,omitempty" jsonschema:"description=排序字段(deadline|priority|id)"`
|
||||
Order string `json:"order,omitempty" jsonschema:"description=排序方向(asc|desc)"`
|
||||
Limit int `json:"limit,omitempty" jsonschema:"description=返回条数,默认5,上限20"`
|
||||
IncludeCompleted *bool `json:"include_completed,omitempty" jsonschema:"description=是否包含已完成任务,默认false"`
|
||||
Keyword string `json:"keyword,omitempty" jsonschema:"description=可选标题关键词,模糊匹配"`
|
||||
DeadlineBefore string `json:"deadline_before,omitempty" jsonschema:"description=可选截止时间上界,支持RFC3339或yyyy-MM-dd HH:mm"`
|
||||
DeadlineAfter string `json:"deadline_after,omitempty" jsonschema:"description=可选截止时间下界,支持RFC3339或yyyy-MM-dd HH:mm"`
|
||||
}
|
||||
|
||||
// TaskQueryToolOutput 是返回给模型的结构化结果。
|
||||
type TaskQueryToolOutput struct {
|
||||
Total int `json:"total"`
|
||||
Items []agentmodel.TaskQueryItem `json:"items"`
|
||||
}
|
||||
|
||||
// BuildTaskQueryToolBundle 负责构建任务查询工具包。
|
||||
//
|
||||
// 步骤说明:
|
||||
// 1. 先校验依赖是否完整,避免生成一个运行时必定失败的工具。
|
||||
// 2. 再把输入归一化成内部请求,调用业务查询函数拿到真实数据。
|
||||
// 3. 最后把业务记录转换成统一的轻量任务视图,供模型和反思节点复用。
|
||||
func BuildTaskQueryToolBundle(ctx context.Context, deps TaskQueryToolDeps) (*TaskQueryToolBundle, error) {
|
||||
if err := deps.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queryTool, err := toolutils.InferTool(
|
||||
ToolNameTaskQueryTasks,
|
||||
ToolDescTaskQueryTasks,
|
||||
func(ctx context.Context, input *TaskQueryToolInput) (*TaskQueryToolOutput, error) {
|
||||
req, err := normalizeTaskQueryToolInput(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
records, err := deps.QueryTasks(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
items := make([]agentmodel.TaskQueryItem, 0, len(records))
|
||||
for _, record := range records {
|
||||
items = append(items, agentmodel.TaskQueryItem{
|
||||
ID: record.ID,
|
||||
Title: record.Title,
|
||||
PriorityGroup: record.PriorityGroup,
|
||||
PriorityLabel: agentmodel.PriorityLabelCN(record.PriorityGroup),
|
||||
IsCompleted: record.IsCompleted,
|
||||
DeadlineAt: formatTaskQueryTime(record.DeadlineAt),
|
||||
UrgencyThresholdAt: formatTaskQueryTime(record.UrgencyThresholdAt),
|
||||
})
|
||||
}
|
||||
|
||||
return &TaskQueryToolOutput{
|
||||
Total: len(items),
|
||||
Items: items,
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("构建任务查询工具失败: %w", err)
|
||||
}
|
||||
|
||||
tools := []tool.BaseTool{queryTool}
|
||||
infos, err := collectToolInfos(ctx, tools)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TaskQueryToolBundle{
|
||||
Tools: tools,
|
||||
ToolInfos: infos,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetTaskQueryInvokableToolByName 按工具名提取可执行工具。
|
||||
func GetTaskQueryInvokableToolByName(bundle *TaskQueryToolBundle, name string) (tool.InvokableTool, error) {
|
||||
if bundle == nil {
|
||||
return nil, errors.New("task query tool bundle is nil")
|
||||
}
|
||||
return getInvokableToolByName(bundle.Tools, bundle.ToolInfos, name)
|
||||
}
|
||||
|
||||
// normalizeTaskQueryToolInput 负责参数默认值回填与合法性校验。
|
||||
//
|
||||
// 步骤说明:
|
||||
// 1. 先准备默认值,保证空参数也能执行一次合理查询。
|
||||
// 2. 再校验象限、排序、条数和时间区间,阻止非法参数下沉到业务层。
|
||||
// 3. 若上下界冲突,则直接返回错误,避免查出必为空的结果。
|
||||
func normalizeTaskQueryToolInput(input *TaskQueryToolInput) (TaskQueryRequest, error) {
|
||||
req := TaskQueryRequest{
|
||||
SortBy: "deadline",
|
||||
Order: "asc",
|
||||
Limit: agentmodel.DefaultTaskQueryLimit,
|
||||
IncludeCompleted: false,
|
||||
}
|
||||
if input == nil {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
if input.Quadrant != nil {
|
||||
if *input.Quadrant < 1 || *input.Quadrant > 4 {
|
||||
return TaskQueryRequest{}, fmt.Errorf("quadrant=%d 非法,必须在 1~4", *input.Quadrant)
|
||||
}
|
||||
quadrant := *input.Quadrant
|
||||
req.Quadrant = &quadrant
|
||||
}
|
||||
|
||||
if sortBy := strings.ToLower(strings.TrimSpace(input.SortBy)); sortBy != "" {
|
||||
req.SortBy = sortBy
|
||||
}
|
||||
switch req.SortBy {
|
||||
case "deadline", "priority", "id":
|
||||
default:
|
||||
return TaskQueryRequest{}, fmt.Errorf("sort_by=%s 非法,仅支持 deadline|priority|id", req.SortBy)
|
||||
}
|
||||
|
||||
if order := strings.ToLower(strings.TrimSpace(input.Order)); order != "" {
|
||||
req.Order = order
|
||||
}
|
||||
switch req.Order {
|
||||
case "asc", "desc":
|
||||
default:
|
||||
return TaskQueryRequest{}, fmt.Errorf("order=%s 非法,仅支持 asc|desc", req.Order)
|
||||
}
|
||||
|
||||
if input.Limit > 0 {
|
||||
req.Limit = input.Limit
|
||||
}
|
||||
if req.Limit > agentmodel.MaxTaskQueryLimit {
|
||||
req.Limit = agentmodel.MaxTaskQueryLimit
|
||||
}
|
||||
if req.Limit <= 0 {
|
||||
req.Limit = agentmodel.DefaultTaskQueryLimit
|
||||
}
|
||||
|
||||
if input.IncludeCompleted != nil {
|
||||
req.IncludeCompleted = *input.IncludeCompleted
|
||||
}
|
||||
req.Keyword = strings.TrimSpace(input.Keyword)
|
||||
|
||||
before, err := parseTaskQueryBoundaryTime(input.DeadlineBefore, true)
|
||||
if err != nil {
|
||||
return TaskQueryRequest{}, err
|
||||
}
|
||||
after, err := parseTaskQueryBoundaryTime(input.DeadlineAfter, false)
|
||||
if err != nil {
|
||||
return TaskQueryRequest{}, err
|
||||
}
|
||||
req.DeadlineBefore = before
|
||||
req.DeadlineAfter = after
|
||||
if req.DeadlineBefore != nil && req.DeadlineAfter != nil && req.DeadlineAfter.After(*req.DeadlineBefore) {
|
||||
return TaskQueryRequest{}, errors.New("deadline_after 不能晚于 deadline_before")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// parseTaskQueryBoundaryTime 解析截止时间上下界。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. isUpper=true 时,纯日期补到当天 23:59:59。
|
||||
// 2. isUpper=false 时,纯日期补到当天 00:00:00。
|
||||
// 3. 不支持的格式直接返回错误,由调用方决定是否回退。
|
||||
func parseTaskQueryBoundaryTime(raw string, isUpper bool) (*time.Time, error) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
loc := time.Local
|
||||
for _, layout := range taskQueryTimeLayouts {
|
||||
var (
|
||||
parsed time.Time
|
||||
err error
|
||||
)
|
||||
if layout == time.RFC3339 {
|
||||
parsed, err = time.Parse(layout, text)
|
||||
if err == nil {
|
||||
parsed = parsed.In(loc)
|
||||
}
|
||||
} else {
|
||||
parsed, err = time.ParseInLocation(layout, text, loc)
|
||||
}
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if layout == "2006-01-02" {
|
||||
if isUpper {
|
||||
parsed = time.Date(parsed.Year(), parsed.Month(), parsed.Day(), 23, 59, 59, 0, loc)
|
||||
} else {
|
||||
parsed = time.Date(parsed.Year(), parsed.Month(), parsed.Day(), 0, 0, 0, 0, loc)
|
||||
}
|
||||
}
|
||||
return &parsed, nil
|
||||
}
|
||||
return nil, fmt.Errorf("时间格式不支持: %s", text)
|
||||
}
|
||||
|
||||
// formatTaskQueryTime 负责把内部时间格式化为给模型展示的分钟级文本。
|
||||
func formatTaskQueryTime(value *time.Time) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
return value.In(time.Local).Format("2006-01-02 15:04")
|
||||
}
|
||||
74
backend/agent/node/tool_common.go
Normal file
74
backend/agent/node/tool_common.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package agentnode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// collectToolInfos 负责批量提取工具元信息,供模型注册与工具索引复用。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只负责调用 tool.Info 并聚合返回结果。
|
||||
// 2. 不负责校验工具是否可执行,也不负责按名称检索工具。
|
||||
func collectToolInfos(ctx context.Context, tools []tool.BaseTool) ([]*schema.ToolInfo, error) {
|
||||
infos := make([]*schema.ToolInfo, 0, len(tools))
|
||||
for _, currentTool := range tools {
|
||||
info, err := currentTool.Info(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取工具信息失败: %w", err)
|
||||
}
|
||||
infos = append(infos, info)
|
||||
}
|
||||
return infos, nil
|
||||
}
|
||||
|
||||
// buildInvokableToolMap 负责把工具列表转换成“工具名 -> 可执行工具”的索引表。
|
||||
//
|
||||
// 步骤说明:
|
||||
// 1. 先校验 tools 与 infos 是否一一对应,避免后续按下标取值时出现错配。
|
||||
// 2. 再校验每个工具都带有合法名字,并且确实实现了 InvokableTool 接口。
|
||||
// 3. 任一步失败都立即返回错误,避免 graph 在运行期拿到半残缺的工具集合。
|
||||
func buildInvokableToolMap(tools []tool.BaseTool, infos []*schema.ToolInfo) (map[string]tool.InvokableTool, error) {
|
||||
if len(tools) == 0 || len(infos) == 0 {
|
||||
return nil, errors.New("tool bundle is empty")
|
||||
}
|
||||
if len(tools) != len(infos) {
|
||||
return nil, errors.New("tool bundle mismatch")
|
||||
}
|
||||
|
||||
result := make(map[string]tool.InvokableTool, len(tools))
|
||||
for idx, currentTool := range tools {
|
||||
info := infos[idx]
|
||||
if info == nil || strings.TrimSpace(info.Name) == "" {
|
||||
return nil, errors.New("tool info is invalid")
|
||||
}
|
||||
invokable, ok := currentTool.(tool.InvokableTool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tool %s is not invokable", info.Name)
|
||||
}
|
||||
result[info.Name] = invokable
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// getInvokableToolByName 负责从工具集合中提取指定名称的可执行工具。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责复用统一索引逻辑,避免各业务链路重复写名称查找代码。
|
||||
// 2. 不负责兜底选择其他工具;未命中时直接返回错误,由上层决定如何处理。
|
||||
func getInvokableToolByName(tools []tool.BaseTool, infos []*schema.ToolInfo, name string) (tool.InvokableTool, error) {
|
||||
invokableMap, err := buildInvokableToolMap(tools, infos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
invokable, ok := invokableMap[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tool %s not found", name)
|
||||
}
|
||||
return invokable, nil
|
||||
}
|
||||
@@ -1,35 +1,7 @@
|
||||
package quicknote
|
||||
package agentprompt
|
||||
|
||||
const (
|
||||
// QuickNoteRouteControlPrompt 用于“首段控制码分流”:
|
||||
// - 仅负责判断用户输入应走 quick_note 还是 chat;
|
||||
// - 不直接回答用户问题;
|
||||
// - 必须输出可机读控制码,便于后端无歧义解析。
|
||||
// 额外说明:
|
||||
// 1) 这里要求固定 XML 结构,是为了让后端做严格字符串/标签解析,而不是模糊关键词匹配;
|
||||
// 2) 增加 reason 标签,主要用于日志排障(看模型为何判到 quick_note/chat);
|
||||
// 3) 明确“禁止输出其他内容”,是为了减少模型附加寒暄导致解析失败。
|
||||
QuickNoteRouteControlPrompt = `你是 SmartFlow 的请求分流控制器。
|
||||
你的唯一任务是给后端返回可机读控制码,不要做用户可见回复,不要解释。
|
||||
|
||||
判定规则:
|
||||
1) 若用户表达“希望你在将来提醒/记录/安排某件事”,输出 quick_note。
|
||||
2) 其余情况输出 chat(包括闲聊、知识问答、纯讨论、观点交流)。
|
||||
3) 口语变体(如“d我/q我/戳我/到点喊我/记得提醒我”)也属于 quick_note。
|
||||
|
||||
输出格式必须严格如下(两行,大小写不敏感):
|
||||
<SMARTFLOW_ROUTE nonce="给定nonce" action="quick_note|chat"></SMARTFLOW_ROUTE>
|
||||
<SMARTFLOW_REASON>一句不超过30字的中文理由</SMARTFLOW_REASON>
|
||||
|
||||
禁止输出任何其他内容。`
|
||||
|
||||
// QuickNotePlanPrompt 用于“单请求聚合规划”:
|
||||
// - 在一次调用内完成标题抽取、时间归一化、紧急分界线评估、优先级评估、跟进句生成;
|
||||
// - 主要用于路由已明确命中 quick_note 的场景,以降低串行 LLM 调用次数。
|
||||
// 额外说明:
|
||||
// 1) 强制 JSON 输出,减少后端解析分支复杂度;
|
||||
// 2) deadline_at / urgency_threshold_at 统一分钟级,方便直接映射到数据库时间字段;
|
||||
// 3) banter 与事实分离,避免润色文案污染结构化字段。
|
||||
// QuickNotePlanPrompt 用于“单请求聚合规划”。
|
||||
QuickNotePlanPrompt = `你是 SmartFlow 的任务聚合规划器。
|
||||
你将基于用户输入,一次性输出任务规划结果,供后端直接写库。
|
||||
|
||||
@@ -48,11 +20,6 @@ const (
|
||||
- banter 不得新增或修改任务事实(任务名、时间、优先级)。`
|
||||
|
||||
// QuickNoteIntentPrompt 用于第一阶段:判断用户输入是否属于“随口记”。
|
||||
// 设计约束:
|
||||
// 1) 只做识别与抽取,不允许模型宣称“已写库”;
|
||||
// 2) 遇到相对时间必须先换算成绝对时间,减少后续工具层歧义;
|
||||
// 3) 若无时间信息必须返回空字符串,避免幻觉时间污染数据库。
|
||||
// 4) 把“当前时间”明确注入 prompt,保证相对时间换算有统一基准。
|
||||
QuickNoteIntentPrompt = `你是 SmartFlow 的“随口记分诊器”。
|
||||
请判断用户输入是否表达了“帮我记一个任务/日程”的需求。
|
||||
- 若是,请提取任务标题与时间线索。
|
||||
@@ -61,8 +28,6 @@ const (
|
||||
- 不要声称已经写入数据库。`
|
||||
|
||||
// QuickNotePriorityPrompt 用于第二阶段:将任务归类到四象限优先级,并评估紧急分界线。
|
||||
// 输出会直接映射到 tasks.priority(1~4),因此要求结果必须可解释。
|
||||
// 这里强调“理由必须可解释”,是为了后续日志复盘时能看懂模型为何这么判。
|
||||
QuickNotePriorityPrompt = `你是 SmartFlow 的任务优先级评估器。
|
||||
根据任务内容、时间约束和执行成本,输出优先级 priority_group:
|
||||
1=重要且紧急,2=重要不紧急,3=简单不重要,4=不简单不重要。
|
||||
@@ -70,11 +35,6 @@ const (
|
||||
若你认为该任务需要后续自动平移,请额外输出 urgency_threshold_at(绝对时间,yyyy-MM-dd HH:mm);否则输出空字符串。`
|
||||
|
||||
// QuickNoteReplyBanterPrompt 用于随口记成功后的“轻松跟进句”生成。
|
||||
// 约束重点:
|
||||
// 1) 只输出一句自然中文;
|
||||
// 2) 贴合用户原话题(例如吃早餐、开会、写报告);
|
||||
// 3) 禁止新增事实(尤其不能改时间、优先级、任务内容);
|
||||
// 4) 不要 markdown,不要列表,不要引号包裹。
|
||||
QuickNoteReplyBanterPrompt = `你是 SmartFlow 的中文口语化回复润色助手。
|
||||
请根据用户原话生成一句轻松自然的跟进话术,让回复更有温度。
|
||||
要求:
|
||||
24
backend/agent/prompt/route.go
Normal file
24
backend/agent/prompt/route.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package agentprompt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const routeSystemPrompt = `
|
||||
你是 SmartFlow 的一级路由助手。
|
||||
你的职责不是回答用户,而是判断这条消息更适合走哪条能力链路。
|
||||
|
||||
当前 Agent 仍在逐批迁移阶段,因此这里只先保留 prompt 落点与职责说明。
|
||||
真正迁移旧 route 提示词时,应把正式版本收敛到这里,而不是散落在 node 或 service 中。
|
||||
`
|
||||
|
||||
// BuildRouteSystemPrompt 返回一级路由系统提示词。
|
||||
func BuildRouteSystemPrompt() string {
|
||||
return strings.TrimSpace(routeSystemPrompt)
|
||||
}
|
||||
|
||||
// BuildRouteUserPrompt 构造一级路由用户提示词。
|
||||
func BuildRouteUserPrompt(userInput string) string {
|
||||
return fmt.Sprintf("用户输入:%s", strings.TrimSpace(userInput))
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package scheduleplan
|
||||
package agentprompt
|
||||
|
||||
const (
|
||||
// SchedulePlanIntentPrompt 用于 plan 节点:从用户输入提取排程意图与约束。
|
||||
@@ -7,12 +7,6 @@ const (
|
||||
// 1. 负责把自然语言转成结构化 JSON,供后端节点分流与执行;
|
||||
// 2. 负责抽取 task_class_ids / strategy / task_tags 等关键字段;
|
||||
// 3. 不负责做排程计算,不负责做工具调用。
|
||||
//
|
||||
// 输出约束:
|
||||
// 1. 必须只输出 JSON,禁止附加解释文本;
|
||||
// 2. task_class_ids 是主语义;
|
||||
// 3. task_class_id 仅作为兼容字段保留,便于老链路平滑过渡;
|
||||
// 4. 需要额外给出 restart + adjustment_scope,用于图分流。
|
||||
SchedulePlanIntentPrompt = `你是 SmartFlow 的排程意图分析器。
|
||||
请根据用户输入,提取排程意图与约束条件。
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package schedulerefine
|
||||
package agentprompt
|
||||
|
||||
const (
|
||||
// contractPrompt 负责把用户自然语言微调请求抽取为结构化契约。
|
||||
contractPrompt = `你是 SmartFlow 的排程微调契约分析器。
|
||||
// ScheduleRefineContractPrompt 负责把用户自然语言微调请求抽取为结构化契约。
|
||||
ScheduleRefineContractPrompt = `你是 SmartFlow 的排程微调契约分析器。
|
||||
你会收到:当前时间、用户请求、已有排程摘要。
|
||||
请只输出 JSON,不要 Markdown,不要解释,不要代码块:
|
||||
{
|
||||
@@ -31,8 +31,8 @@ const (
|
||||
4. hard_requirements 必须可验证,避免空泛描述。
|
||||
5. hard_assertions 必须尽量结构化,避免只给自然语言目标。`
|
||||
|
||||
// plannerPrompt 只负责生成“执行路径”,不直接执行动作。
|
||||
plannerPrompt = `你是 SmartFlow 的排程微调 Planner。
|
||||
// ScheduleRefinePlannerPrompt 只负责生成“执行路径”,不直接执行动作。
|
||||
ScheduleRefinePlannerPrompt = `你是 SmartFlow 的排程微调 Planner。
|
||||
你会收到:用户请求、契约、最近动作观察。
|
||||
请只输出 JSON,不要 Markdown,不要解释,不要代码块:
|
||||
{
|
||||
@@ -47,8 +47,8 @@ const (
|
||||
4. 若目标是“上下文切换最少/同科目连续”,steps 必须体现 MinContextSwitch 且包含“成功后才收口”的硬条件。
|
||||
5. 不要输出半截 JSON。`
|
||||
|
||||
// reactPrompt 用于“单任务微步 ReAct”执行器。
|
||||
reactPrompt = `你是 SmartFlow 的单任务微步 ReAct 执行器。
|
||||
// ScheduleRefineReactPrompt 用于“单任务微步 ReAct”执行器。
|
||||
ScheduleRefineReactPrompt = `你是 SmartFlow 的单任务微步 ReAct 执行器。
|
||||
当前只处理一个任务(CURRENT_TASK),不能发散到其它任务的主动改动。
|
||||
你每轮只能做两件事之一:
|
||||
1) 调用一个工具(基础工具或复合工具)
|
||||
@@ -121,8 +121,8 @@ const (
|
||||
17. 若 COMPOSITE_TOOLS_ALLOWED=false,禁止调用 SpreadEven/MinContextSwitch,只能使用基础工具逐步处理。
|
||||
18. 为保证解析稳定:goal_check<=50字,decision<=90字,summary<=60字。`
|
||||
|
||||
// postReflectPrompt 要求模型基于真实工具结果做复盘,不允许“脑补成功”。
|
||||
postReflectPrompt = `你是 SmartFlow 的 ReAct 复盘器。
|
||||
// ScheduleRefinePostReflectPrompt 要求模型基于真实工具结果做复盘,不允许“脑补成功”。
|
||||
ScheduleRefinePostReflectPrompt = `你是 SmartFlow 的 ReAct 复盘器。
|
||||
你会收到:本轮工具参数、后端真实执行结果、上一轮上下文。
|
||||
请只输出 JSON,不要 Markdown,不要解释:
|
||||
{
|
||||
@@ -136,8 +136,8 @@ const (
|
||||
2. 若 error_code 属于 ORDER_VIOLATION/SLOT_CONFLICT/REPEAT_FAILED_ACTION,next_strategy 必须给出规避方法。
|
||||
3. should_stop=true 仅用于“目标已满足”或“继续收益很低”。`
|
||||
|
||||
// reviewPrompt 用于终审语义校验。
|
||||
reviewPrompt = `你是 SmartFlow 的终审校验器。
|
||||
// ScheduleRefineReviewPrompt 用于终审语义校验。
|
||||
ScheduleRefineReviewPrompt = `你是 SmartFlow 的终审校验器。
|
||||
请判断“当前排程”是否满足“本轮用户微调请求 + 契约硬要求”。
|
||||
只输出 JSON:
|
||||
{
|
||||
@@ -150,16 +150,16 @@ const (
|
||||
1. pass=true 时 unmet 必须为空数组。
|
||||
2. pass=false 时 reason 必须给出核心差距。`
|
||||
|
||||
// summaryPrompt 用于最终面向用户的自然语言总结。
|
||||
summaryPrompt = `你是 SmartFlow 的排程结果解读助手。
|
||||
// ScheduleRefineSummaryPrompt 用于最终面向用户的自然语言总结。
|
||||
ScheduleRefineSummaryPrompt = `你是 SmartFlow 的排程结果解读助手。
|
||||
请基于输入输出 2~4 句中文总结:
|
||||
1) 先说明本轮改了什么;
|
||||
2) 再说明改动收益;
|
||||
3) 若终审未完全通过,明确还差什么。
|
||||
不要输出 JSON。`
|
||||
|
||||
// repairPrompt 用于终审失败后的单次修复动作。
|
||||
repairPrompt = `你是 SmartFlow 的修复执行器。
|
||||
// ScheduleRefineRepairPrompt 用于终审失败后的单次修复动作。
|
||||
ScheduleRefineRepairPrompt = `你是 SmartFlow 的修复执行器。
|
||||
当前方案未通过终审,请根据“未满足点”只做一次修复动作。
|
||||
只允许输出一个 tool_call(Move 或 Swap),不允许 done。
|
||||
|
||||
79
backend/agent/prompt/taskquery.go
Normal file
79
backend/agent/prompt/taskquery.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package agentprompt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const TaskQueryPlanPrompt = `你是 SmartFlow 的任务查询规划器。请根据用户原话,输出结构化查询计划 JSON,供后端直接执行。
|
||||
只允许输出 JSON,不要输出解释、代码块或多余文字。
|
||||
|
||||
输出字段:
|
||||
{
|
||||
"user_goal": "一句话总结用户诉求",
|
||||
"quadrants": [1,2,3,4],
|
||||
"sort_by": "deadline|priority|id",
|
||||
"order": "asc|desc",
|
||||
"limit": 1-20,
|
||||
"include_completed": false,
|
||||
"keyword": "可选关键词,或空字符串",
|
||||
"deadline_before": "yyyy-MM-dd HH:mm 或空字符串",
|
||||
"deadline_after": "yyyy-MM-dd HH:mm 或空字符串"
|
||||
}
|
||||
|
||||
规则:
|
||||
1. quadrants 为空数组表示“全部象限”。
|
||||
2. 用户未提排序时,默认 sort_by=deadline 且 order=asc。
|
||||
3. 用户未提数量时,limit 默认 5。
|
||||
4. 时间字段必须输出绝对时间或空字符串,不要输出“明天”“下周一”这类相对时间。
|
||||
5. 如果用户语义更偏向“我还有什么要做”“看看待办”,优先考虑 1、2 象限;如果 1、2 象限为空,再考虑 3、4 象限。
|
||||
6. 如果用户语义更偏向“来点事做做”“给我点轻松的任务”,优先考虑 3、4 象限。
|
||||
7. 允许多选象限。`
|
||||
|
||||
const TaskQueryReflectPrompt = `你是 SmartFlow 的任务查询结果审阅器。你会看到:用户原话、当前查询计划、查询结果摘要、当前重试次数。
|
||||
请只输出 JSON,不要输出解释、代码块或多余文字。
|
||||
|
||||
输出字段:
|
||||
{
|
||||
"satisfied": true,
|
||||
"need_retry": false,
|
||||
"reason": "一句话原因",
|
||||
"reply": "可直接给用户看的中文回复",
|
||||
"retry_patch": {
|
||||
"quadrants": [1,2,3,4],
|
||||
"sort_by": "deadline|priority|id",
|
||||
"order": "asc|desc",
|
||||
"limit": 1-20,
|
||||
"include_completed": true,
|
||||
"keyword": "可选关键词,或空字符串",
|
||||
"deadline_before": "yyyy-MM-dd HH:mm 或空字符串",
|
||||
"deadline_after": "yyyy-MM-dd HH:mm 或空字符串"
|
||||
}
|
||||
}
|
||||
|
||||
规则:
|
||||
1. 如果当前结果已经满足用户诉求,返回 satisfied=true 且 need_retry=false。
|
||||
2. 如果当前结果不满足,但仍值得再查一次,返回 need_retry=true,并尽量只给最小必要 patch。
|
||||
3. 如果不建议再试,返回 need_retry=false,并在 reply 里说明当前最接近的结果。
|
||||
4. reply 应该是自然中文,不要输出表格。`
|
||||
|
||||
func BuildTaskQueryPlanUserPrompt(nowText, userInput string) string {
|
||||
return fmt.Sprintf(
|
||||
"当前时间(北京时间,精确到分钟):%s\n用户输入:%s\n\n请输出任务查询计划 JSON。",
|
||||
strings.TrimSpace(nowText),
|
||||
strings.TrimSpace(userInput),
|
||||
)
|
||||
}
|
||||
|
||||
func BuildTaskQueryReflectUserPrompt(nowText, userInput, userGoal, planSummary string, retryCount, maxRetry int, resultSummary string) string {
|
||||
return fmt.Sprintf(
|
||||
"当前时间:%s\n用户原话:%s\n用户目标:%s\n当前查询计划:%s\n当前重试:%d/%d\n查询结果摘要:\n%s",
|
||||
strings.TrimSpace(nowText),
|
||||
strings.TrimSpace(userInput),
|
||||
strings.TrimSpace(userGoal),
|
||||
strings.TrimSpace(planSummary),
|
||||
retryCount,
|
||||
maxRetry,
|
||||
strings.TrimSpace(resultSummary),
|
||||
)
|
||||
}
|
||||
@@ -1,163 +0,0 @@
|
||||
package quicknote
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
const (
|
||||
// 图节点:意图识别(含聚合规划与时间校验)
|
||||
quickNoteGraphNodeIntent = "quick_note_intent"
|
||||
// 图节点:优先级评估(或本地兜底)
|
||||
quickNoteGraphNodeRank = "quick_note_priority"
|
||||
// 图节点:持久化(调用写库工具)
|
||||
quickNoteGraphNodePersist = "quick_note_persist"
|
||||
// 图节点:退出(用于非随口记/校验失败分支)
|
||||
quickNoteGraphNodeExit = "quick_note_exit"
|
||||
)
|
||||
|
||||
// QuickNoteGraphRunInput 是运行“随口记 graph”所需的输入依赖。
|
||||
// 说明:
|
||||
// 1) EmitStage 可选,用于把节点进度推送给外层(例如 SSE 状态块);
|
||||
// 2) 不传 EmitStage 时,图逻辑保持静默执行;
|
||||
// 3) SkipIntentVerification=true 时,表示上游路由已信任 quick_note,可跳过二次意图判定。
|
||||
type QuickNoteGraphRunInput struct {
|
||||
Model *ark.ChatModel
|
||||
State *QuickNoteState
|
||||
Deps QuickNoteToolDeps
|
||||
|
||||
SkipIntentVerification bool
|
||||
EmitStage func(stage, detail string)
|
||||
}
|
||||
|
||||
// RunQuickNoteGraph 执行“随口记”图编排。
|
||||
// 该文件只负责“连线与分支”,节点内部逻辑全部下沉到 nodes.go。
|
||||
func RunQuickNoteGraph(ctx context.Context, input QuickNoteGraphRunInput) (*QuickNoteState, error) {
|
||||
// 1. 启动前硬校验:模型、状态、依赖缺一不可。
|
||||
if input.Model == nil {
|
||||
return nil, errors.New("quick note graph: model is nil")
|
||||
}
|
||||
if input.State == nil {
|
||||
return nil, errors.New("quick note graph: state is nil")
|
||||
}
|
||||
if err := input.Deps.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 统一封装阶段推送函数,避免各节点反复判空。
|
||||
emitStage := func(stage, detail string) {
|
||||
if input.EmitStage != nil {
|
||||
input.EmitStage(stage, detail)
|
||||
}
|
||||
}
|
||||
|
||||
// 统一初始化“当前时间基准”,避免同一请求内相对时间口径漂移。
|
||||
// 2.1 若上游未设置 RequestNow,这里补齐。
|
||||
if input.State.RequestNow.IsZero() {
|
||||
input.State.RequestNow = quickNoteNowToMinute()
|
||||
}
|
||||
// 2.2 若上游未设置文本基准,这里按统一格式补齐。
|
||||
if strings.TrimSpace(input.State.RequestNowText) == "" {
|
||||
input.State.RequestNowText = formatQuickNoteTimeToMinute(input.State.RequestNow)
|
||||
}
|
||||
|
||||
// 3. 构建工具包并取出写库工具。
|
||||
// 这样 graph 运行时只关心“调用工具”,不关心工具如何注册。
|
||||
toolBundle, err := BuildQuickNoteToolBundle(ctx, input.Deps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
createTaskTool, err := getInvokableToolByName(toolBundle, ToolNameQuickNoteCreateTask)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. runner 负责把依赖收口,graph 只保留连线定义。
|
||||
runner := newQuickNoteRunner(input, createTaskTool, emitStage)
|
||||
|
||||
// 5. 创建状态图容器:输入/输出类型都为 *QuickNoteState。
|
||||
graph := compose.NewGraph[*QuickNoteState, *QuickNoteState]()
|
||||
|
||||
// 6. 注册节点(意图 -> 优先级 -> 持久化 -> 退出)。
|
||||
if err = graph.AddLambdaNode(quickNoteGraphNodeIntent, compose.InvokableLambda(runner.intentNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = graph.AddLambdaNode(quickNoteGraphNodeRank, compose.InvokableLambda(runner.priorityNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = graph.AddLambdaNode(quickNoteGraphNodePersist, compose.InvokableLambda(runner.persistNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = graph.AddLambdaNode(quickNoteGraphNodeExit, compose.InvokableLambda(runner.exitNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 连线:START -> intent
|
||||
// 7. 所有请求统一先过 intent 节点,确保意图和时间校验在前。
|
||||
if err = graph.AddEdge(compose.START, quickNoteGraphNodeIntent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 分支:intent 后决定去 priority 还是 exit。
|
||||
// 8. 非随口记或时间非法时直接 exit,避免进入后续写库路径。
|
||||
if err = graph.AddBranch(quickNoteGraphNodeIntent, compose.NewGraphBranch(
|
||||
runner.nextAfterIntent,
|
||||
map[string]bool{
|
||||
quickNoteGraphNodeRank: true,
|
||||
quickNoteGraphNodeExit: true,
|
||||
},
|
||||
)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// exit 直接结束。
|
||||
// 9. exit 是显式终点前节点,方便后续插入“统一收尾逻辑”。
|
||||
if err = graph.AddEdge(quickNoteGraphNodeExit, compose.END); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// priority -> persist。
|
||||
// 10. 通过优先级节点后,进入持久化节点。
|
||||
if err = graph.AddEdge(quickNoteGraphNodeRank, quickNoteGraphNodePersist); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// persist 后决定“重试 persist”还是结束。
|
||||
// 11. 重试策略由状态字段驱动,不在 graph 层写重试计数逻辑。
|
||||
if err = graph.AddBranch(quickNoteGraphNodePersist, compose.NewGraphBranch(
|
||||
runner.nextAfterPersist,
|
||||
map[string]bool{
|
||||
quickNoteGraphNodePersist: true,
|
||||
compose.END: true,
|
||||
},
|
||||
)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 12. 运行步数上限:至少 12 步,并根据 MaxToolRetry 预留重试步数。
|
||||
// 防止异常分支导致无限循环。
|
||||
maxSteps := input.State.MaxToolRetry + 10
|
||||
if maxSteps < 12 {
|
||||
maxSteps = 12
|
||||
}
|
||||
|
||||
// 13. 编译图得到可执行实例。
|
||||
runnable, err := graph.Compile(ctx,
|
||||
compose.WithGraphName("QuickNoteGraph"),
|
||||
compose.WithMaxRunSteps(maxSteps),
|
||||
compose.WithNodeTriggerMode(compose.AnyPredecessor),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 14. 执行图并返回最终状态。
|
||||
return runnable.Invoke(ctx, input.State)
|
||||
}
|
||||
@@ -1,670 +0,0 @@
|
||||
package quicknote
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
einoModel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
type quickNoteIntentModelOutput struct {
|
||||
IsQuickNote bool `json:"is_quick_note"`
|
||||
Title string `json:"title"`
|
||||
DeadlineAt string `json:"deadline_at"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type quickNotePriorityModelOutput struct {
|
||||
PriorityGroup int `json:"priority_group"`
|
||||
Reason string `json:"reason"`
|
||||
UrgencyThresholdAt string `json:"urgency_threshold_at"`
|
||||
}
|
||||
|
||||
// quickNotePlanModelOutput 是“单请求聚合规划”节点的模型输出。
|
||||
type quickNotePlanModelOutput struct {
|
||||
Title string `json:"title"`
|
||||
DeadlineAt string `json:"deadline_at"`
|
||||
UrgencyThresholdAt string `json:"urgency_threshold_at"`
|
||||
PriorityGroup int `json:"priority_group"`
|
||||
PriorityReason string `json:"priority_reason"`
|
||||
Banter string `json:"banter"`
|
||||
}
|
||||
|
||||
// runQuickNoteIntentNode 负责“意图识别 + 聚合规划 + 时间校验”。
|
||||
// 说明:
|
||||
// 1) trustRoute 命中时,直接走单请求聚合规划,跳过二次意图识别;
|
||||
// 2) 无论是否走快路径,最终都要走本地时间硬校验,防止脏时间落库。
|
||||
func runQuickNoteIntentNode(ctx context.Context, st *QuickNoteState, input QuickNoteGraphRunInput, emitStage func(stage, detail string)) (*QuickNoteState, error) {
|
||||
// 0. 基础防御:state 为空直接返回错误,避免后续节点空指针。
|
||||
if st == nil {
|
||||
return nil, errors.New("quick note graph: nil state in intent node")
|
||||
}
|
||||
|
||||
// 1. 如果上游路由已高置信命中 quick_note,则走“单请求聚合快路径”。
|
||||
if input.SkipIntentVerification {
|
||||
emitStage("quick_note.intent.analyzing", "已由上游路由判定为任务请求,跳过二次意图判断。")
|
||||
st.IsQuickNoteIntent = true
|
||||
st.IntentJudgeReason = "上游路由已命中 quick_note,跳过二次意图判定"
|
||||
st.PlannedBySingleCall = true
|
||||
|
||||
// 1.1 一次调用里尽量拿齐 title/deadline/priority/banter,减少串行模型开销。
|
||||
emitStage("quick_note.plan.generating", "正在一次性生成时间归一化、优先级与回复润色。")
|
||||
plan, planErr := planQuickNoteInSingleCall(ctx, input.Model, st.RequestNowText, st.RequestNow, st.UserInput)
|
||||
if planErr != nil {
|
||||
// 1.2 聚合规划失败不终止链路,改为后续本地兜底。
|
||||
st.IntentJudgeReason += ";聚合规划失败,回退本地兜底"
|
||||
} else {
|
||||
// 1.3 仅在字段有效时回填,避免无效值污染状态。
|
||||
if strings.TrimSpace(plan.Title) != "" {
|
||||
st.ExtractedTitle = strings.TrimSpace(plan.Title)
|
||||
}
|
||||
if plan.Deadline != nil {
|
||||
st.ExtractedDeadline = plan.Deadline
|
||||
}
|
||||
st.ExtractedDeadlineText = strings.TrimSpace(plan.DeadlineText)
|
||||
if plan.UrgencyThreshold != nil {
|
||||
st.ExtractedUrgencyThreshold = normalizeUrgencyThreshold(plan.UrgencyThreshold, plan.Deadline)
|
||||
}
|
||||
if IsValidTaskPriority(plan.PriorityGroup) {
|
||||
st.ExtractedPriority = plan.PriorityGroup
|
||||
st.ExtractedPriorityReason = strings.TrimSpace(plan.PriorityReason)
|
||||
}
|
||||
st.ExtractedBanter = strings.TrimSpace(plan.Banter)
|
||||
}
|
||||
|
||||
// 1.4 如果模型没给标题,基于原句做本地标题提取兜底。
|
||||
if strings.TrimSpace(st.ExtractedTitle) == "" {
|
||||
st.ExtractedTitle = deriveQuickNoteTitleFromInput(st.UserInput)
|
||||
}
|
||||
|
||||
// 1.5 无论是否聚合成功,都要进行本地时间硬校验,防止脏时间写库。
|
||||
emitStage("quick_note.deadline.validating", "正在校验并归一化任务时间。")
|
||||
userDeadline, userHasTimeHint, userDeadlineErr := parseOptionalDeadlineFromUserInput(st.UserInput, st.RequestNow)
|
||||
if userHasTimeHint && userDeadlineErr != nil {
|
||||
st.DeadlineValidationError = userDeadlineErr.Error()
|
||||
st.AssistantReply = "我识别到你给了时间信息,但这个时间格式我没法准确解析,请改成例如:2026-03-20 18:30、明天下午3点、下周一上午9点。"
|
||||
emitStage("quick_note.failed", "时间校验失败,未执行写入。")
|
||||
return st, nil
|
||||
}
|
||||
if userDeadline != nil {
|
||||
// 用户原句能解析出时间时,以原句解析结果为准(更贴近真实输入)。
|
||||
st.ExtractedDeadline = userDeadline
|
||||
st.ExtractedDeadlineText = strings.TrimSpace(st.UserInput)
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 2. 常规路径:先让模型做意图识别 + 初步抽取。
|
||||
emitStage("quick_note.intent.analyzing", "正在分析用户输入是否属于任务安排请求。")
|
||||
prompt := fmt.Sprintf(`当前时间(北京时间,精确到分钟):%s
|
||||
用户输入:%s
|
||||
请仅输出 JSON(不要 markdown,不要解释),字段如下:
|
||||
{
|
||||
"is_quick_note": boolean,
|
||||
"title": string,
|
||||
"deadline_at": string,
|
||||
"reason": string
|
||||
}
|
||||
字段约束:
|
||||
1) deadline_at 只允许输出绝对时间,格式必须为 "yyyy-MM-dd HH:mm"。
|
||||
2) 如果用户说了“明天/后天/下周一/今晚”等相对时间,必须基于上面的当前时间换算成绝对时间。
|
||||
3) 如果用户没有提及时间,deadline_at 输出空字符串。`,
|
||||
st.RequestNowText,
|
||||
st.UserInput,
|
||||
)
|
||||
|
||||
// 2.1 模型调用失败时,保守回退普通聊天,避免误写任务。
|
||||
raw, callErr := callModelForJSON(ctx, input.Model, QuickNoteIntentPrompt, prompt)
|
||||
if callErr != nil {
|
||||
st.IsQuickNoteIntent = false
|
||||
st.IntentJudgeReason = "意图识别失败,回退普通聊天"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 2.2 解析失败同样回退普通聊天,保证稳定性优先。
|
||||
parsed, parseErr := parseJSONPayload[quickNoteIntentModelOutput](raw)
|
||||
if parseErr != nil {
|
||||
st.IsQuickNoteIntent = false
|
||||
st.IntentJudgeReason = "意图识别结果不可解析,回退普通聊天"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.IsQuickNoteIntent = parsed.IsQuickNote
|
||||
st.IntentJudgeReason = strings.TrimSpace(parsed.Reason)
|
||||
if !st.IsQuickNoteIntent {
|
||||
// 非随口记:后续通过分支直接退出 graph。
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 2.3 处理标题字段:为空时回退到用户原句。
|
||||
title := strings.TrimSpace(parsed.Title)
|
||||
if title == "" {
|
||||
title = strings.TrimSpace(st.UserInput)
|
||||
}
|
||||
st.ExtractedTitle = title
|
||||
|
||||
emitStage("quick_note.deadline.validating", "正在校验并归一化任务时间。")
|
||||
|
||||
// Step A:优先尝试解析模型抽取出来的 deadline。
|
||||
// 这样可利用模型“结构化理解”能力先拿一次候选时间。
|
||||
st.ExtractedDeadlineText = strings.TrimSpace(parsed.DeadlineAt)
|
||||
if st.ExtractedDeadlineText != "" {
|
||||
if deadline, deadlineErr := parseOptionalDeadlineWithNow(st.ExtractedDeadlineText, st.RequestNow); deadlineErr == nil {
|
||||
st.ExtractedDeadline = deadline
|
||||
}
|
||||
}
|
||||
|
||||
// Step B:基于用户原句执行“本地时间解析 + 合法性校验”。
|
||||
// 本地校验是最终硬门槛,确保“用户给错时间不会被静默写成 NULL”。
|
||||
userDeadline, userHasTimeHint, userDeadlineErr := parseOptionalDeadlineFromUserInput(st.UserInput, st.RequestNow)
|
||||
if userHasTimeHint && userDeadlineErr != nil {
|
||||
st.DeadlineValidationError = userDeadlineErr.Error()
|
||||
st.AssistantReply = "我识别到你给了时间信息,但这个时间格式我没法准确解析,请改成例如:2026-03-20 18:30、明天下午3点、下周一上午9点。"
|
||||
emitStage("quick_note.failed", "时间校验失败,未执行写入。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
if st.ExtractedDeadline == nil && userDeadline != nil {
|
||||
// 当模型未提取出时间,但原句能解析时,补写时间结果。
|
||||
st.ExtractedDeadline = userDeadline
|
||||
if st.ExtractedDeadlineText == "" {
|
||||
st.ExtractedDeadlineText = strings.TrimSpace(st.UserInput)
|
||||
}
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// runQuickNotePriorityNode 负责“优先级评估”。
|
||||
// 说明:
|
||||
// 1) 聚合规划已给出合法优先级时直接复用;
|
||||
// 2) 快路径下缺失优先级时直接本地兜底,避免额外模型调用;
|
||||
// 3) 其余场景走独立评估模型,失败再兜底。
|
||||
func runQuickNotePriorityNode(ctx context.Context, st *QuickNoteState, input QuickNoteGraphRunInput, emitStage func(stage, detail string)) (*QuickNoteState, error) {
|
||||
if st == nil {
|
||||
return nil, errors.New("quick note graph: nil state in priority node")
|
||||
}
|
||||
// 1. 非随口记或时间校验失败时,不做优先级评估。
|
||||
if !st.IsQuickNoteIntent || strings.TrimSpace(st.DeadlineValidationError) != "" {
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 2. 已有合法优先级则直接复用,避免重复调用模型。
|
||||
if IsValidTaskPriority(st.ExtractedPriority) {
|
||||
if strings.TrimSpace(st.ExtractedPriorityReason) == "" {
|
||||
st.ExtractedPriorityReason = "复用聚合规划优先级"
|
||||
}
|
||||
emitStage("quick_note.priority.evaluating", "已复用聚合规划结果中的优先级。")
|
||||
return st, nil
|
||||
}
|
||||
// 3. 快路径下若缺失优先级,直接本地兜底,追求低延迟。
|
||||
if input.SkipIntentVerification || st.PlannedBySingleCall {
|
||||
st.ExtractedPriority = fallbackPriority(st)
|
||||
st.ExtractedPriorityReason = "聚合规划未给出合法优先级,使用本地兜底"
|
||||
emitStage("quick_note.priority.evaluating", "聚合优先级缺失,已使用本地兜底。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 4. 常规路径才调用独立优先级模型。
|
||||
emitStage("quick_note.priority.evaluating", "正在评估任务优先级。")
|
||||
deadlineText := "无"
|
||||
if st.ExtractedDeadline != nil {
|
||||
deadlineText = formatQuickNoteTimeToMinute(*st.ExtractedDeadline)
|
||||
}
|
||||
deadlineClue := strings.TrimSpace(st.ExtractedDeadlineText)
|
||||
if deadlineClue == "" {
|
||||
deadlineClue = "无"
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(`当前时间(北京时间,精确到分钟):%s
|
||||
请对以下任务评估优先级:
|
||||
- 任务标题:%s
|
||||
- 用户原始输入:%s
|
||||
- 时间线索原文:%s
|
||||
- 归一化截止时间:%s
|
||||
|
||||
请仅输出 JSON(不要 markdown,不要解释):
|
||||
{
|
||||
"priority_group": 1|2|3|4,
|
||||
"reason": "简短理由",
|
||||
"urgency_threshold_at": "yyyy-MM-dd HH:mm 或空字符串"
|
||||
}
|
||||
|
||||
额外约束:
|
||||
1) urgency_threshold_at 表示“何时从不紧急象限自动平移到紧急象限”;
|
||||
2) 若该任务不需要自动平移,可输出空字符串;
|
||||
3) 若任务已在紧急象限(priority_group=1 或 3),优先输出空字符串;
|
||||
4) 若输出非空时间,必须是绝对时间,且不晚于归一化截止时间(若有)。`,
|
||||
st.RequestNowText,
|
||||
st.ExtractedTitle,
|
||||
st.UserInput,
|
||||
deadlineClue,
|
||||
deadlineText,
|
||||
)
|
||||
|
||||
// 4.1 调用失败:使用本地兜底,不中断主链路。
|
||||
raw, callErr := callModelForJSON(ctx, input.Model, QuickNotePriorityPrompt, prompt)
|
||||
if callErr != nil {
|
||||
st.ExtractedPriority = fallbackPriority(st)
|
||||
st.ExtractedPriorityReason = "优先级评估失败,使用兜底策略"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 4.2 解析失败或非法值:同样兜底。
|
||||
parsed, parseErr := parseJSONPayload[quickNotePriorityModelOutput](raw)
|
||||
if parseErr != nil || !IsValidTaskPriority(parsed.PriorityGroup) {
|
||||
st.ExtractedPriority = fallbackPriority(st)
|
||||
st.ExtractedPriorityReason = "优先级结果异常,使用兜底策略"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.ExtractedPriority = parsed.PriorityGroup
|
||||
st.ExtractedPriorityReason = strings.TrimSpace(parsed.Reason)
|
||||
if strings.TrimSpace(parsed.UrgencyThresholdAt) != "" {
|
||||
urgencyThreshold, thresholdErr := parseOptionalDeadlineWithNow(strings.TrimSpace(parsed.UrgencyThresholdAt), st.RequestNow)
|
||||
if thresholdErr == nil {
|
||||
st.ExtractedUrgencyThreshold = normalizeUrgencyThreshold(urgencyThreshold, st.ExtractedDeadline)
|
||||
}
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// runQuickNotePersistNodeInternal 负责“写库工具调用 + 重试态回填”。
|
||||
func runQuickNotePersistNodeInternal(ctx context.Context, st *QuickNoteState, createTaskTool tool.InvokableTool, input QuickNoteGraphRunInput, emitStage func(stage, detail string)) (*QuickNoteState, error) {
|
||||
_ = input // 保留参数形状,后续若需要基于输入开关扩展可直接使用。
|
||||
|
||||
if st == nil {
|
||||
return nil, errors.New("quick note graph: nil state in persist node")
|
||||
}
|
||||
// 1. 非随口记或时间非法时不允许落库。
|
||||
if !st.IsQuickNoteIntent || strings.TrimSpace(st.DeadlineValidationError) != "" {
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 2. 准备工具入参:优先使用已评估优先级,缺失则兜底。
|
||||
emitStage("quick_note.persisting", "正在写入任务数据。")
|
||||
priority := st.ExtractedPriority
|
||||
if !IsValidTaskPriority(priority) {
|
||||
priority = fallbackPriority(st)
|
||||
st.ExtractedPriority = priority
|
||||
}
|
||||
|
||||
deadlineText := ""
|
||||
if st.ExtractedDeadline != nil {
|
||||
deadlineText = st.ExtractedDeadline.In(quickNoteLocation()).Format(time.RFC3339)
|
||||
}
|
||||
urgencyThresholdText := ""
|
||||
if st.ExtractedUrgencyThreshold != nil {
|
||||
urgencyThresholdText = st.ExtractedUrgencyThreshold.In(quickNoteLocation()).Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// 3. 工具参数序列化失败视作一次失败尝试,交由重试分支处理。
|
||||
toolInput := QuickNoteCreateTaskToolInput{
|
||||
Title: st.ExtractedTitle,
|
||||
PriorityGroup: priority,
|
||||
DeadlineAt: deadlineText,
|
||||
UrgencyThresholdAt: urgencyThresholdText,
|
||||
}
|
||||
rawInput, marshalErr := json.Marshal(toolInput)
|
||||
if marshalErr != nil {
|
||||
st.RecordToolError("构造工具参数失败: " + marshalErr.Error())
|
||||
if !st.CanRetryTool() {
|
||||
st.AssistantReply = "抱歉,记录任务时参数处理失败,请稍后重试。"
|
||||
emitStage("quick_note.failed", "参数构造失败,未完成写入。")
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 4. 调用写库工具。
|
||||
rawOutput, invokeErr := createTaskTool.InvokableRun(ctx, string(rawInput))
|
||||
if invokeErr != nil {
|
||||
st.RecordToolError(invokeErr.Error())
|
||||
if !st.CanRetryTool() {
|
||||
st.AssistantReply = "抱歉,我尝试了多次仍未能成功记录这条任务,请稍后再试。"
|
||||
emitStage("quick_note.failed", "多次重试后仍未完成写入。")
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 5. 工具返回解析失败同样按“可重试错误”处理。
|
||||
toolOutput, parseErr := parseJSONPayload[QuickNoteCreateTaskToolOutput](rawOutput)
|
||||
if parseErr != nil {
|
||||
st.RecordToolError("解析工具返回失败: " + parseErr.Error())
|
||||
if !st.CanRetryTool() {
|
||||
st.AssistantReply = "抱歉,我拿到了异常结果,没能确认任务是否记录成功,请稍后再试。"
|
||||
emitStage("quick_note.failed", "结果解析异常,无法确认写入结果。")
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 成功判定硬门槛:必须拿到有效 task_id,防止“假成功”。
|
||||
if toolOutput.TaskID <= 0 {
|
||||
st.RecordToolError(fmt.Sprintf("工具返回非法 task_id=%d", toolOutput.TaskID))
|
||||
if !st.CanRetryTool() {
|
||||
st.AssistantReply = "抱歉,这次我没能确认任务写入成功,请再发一次我立刻补上。"
|
||||
emitStage("quick_note.failed", "写入结果缺少有效 task_id,已终止成功回包。")
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 6. 写库成功后回填状态,并准备最终回复内容。
|
||||
st.RecordToolSuccess(toolOutput.TaskID)
|
||||
if strings.TrimSpace(toolOutput.Title) != "" {
|
||||
st.ExtractedTitle = strings.TrimSpace(toolOutput.Title)
|
||||
}
|
||||
if IsValidTaskPriority(toolOutput.PriorityGroup) {
|
||||
st.ExtractedPriority = toolOutput.PriorityGroup
|
||||
}
|
||||
|
||||
reply := strings.TrimSpace(toolOutput.Message)
|
||||
if reply == "" {
|
||||
reply = fmt.Sprintf("已为你记录:%s(%s)", st.ExtractedTitle, PriorityLabelCN(st.ExtractedPriority))
|
||||
}
|
||||
st.AssistantReply = reply
|
||||
emitStage("quick_note.persisted", "任务写入成功,正在组织回复内容。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// selectQuickNoteNextAfterIntent 根据意图与时间校验结果决定 intent 后分支。
|
||||
func selectQuickNoteNextAfterIntent(st *QuickNoteState) string {
|
||||
// 1) 非随口记 -> exit;
|
||||
// 2) 时间校验失败 -> exit;
|
||||
// 3) 其余 -> priority 节点。
|
||||
if st == nil || !st.IsQuickNoteIntent {
|
||||
return quickNoteGraphNodeExit
|
||||
}
|
||||
if strings.TrimSpace(st.DeadlineValidationError) != "" {
|
||||
return quickNoteGraphNodeExit
|
||||
}
|
||||
return quickNoteGraphNodeRank
|
||||
}
|
||||
|
||||
// selectQuickNoteNextAfterPersist 根据持久化状态决定 persist 后分支。
|
||||
func selectQuickNoteNextAfterPersist(st *QuickNoteState) string {
|
||||
// 分支规则:
|
||||
// 1) state=nil:防御式结束;
|
||||
// 2) 已持久化:结束;
|
||||
// 3) 可重试:回到 persist 重试;
|
||||
// 4) 不可重试:写失败文案并结束。
|
||||
if st == nil {
|
||||
return compose.END
|
||||
}
|
||||
if st.Persisted {
|
||||
return compose.END
|
||||
}
|
||||
if st.CanRetryTool() {
|
||||
return quickNoteGraphNodePersist
|
||||
}
|
||||
if strings.TrimSpace(st.AssistantReply) == "" {
|
||||
st.AssistantReply = "抱歉,我尝试了多次仍未能成功记录这条任务,请稍后再试。"
|
||||
}
|
||||
return compose.END
|
||||
}
|
||||
|
||||
func getInvokableToolByName(bundle *QuickNoteToolBundle, name string) (tool.InvokableTool, error) {
|
||||
// 1. 校验工具包有效性。
|
||||
if bundle == nil {
|
||||
return nil, errors.New("tool bundle is nil")
|
||||
}
|
||||
if len(bundle.Tools) == 0 || len(bundle.ToolInfos) == 0 {
|
||||
return nil, errors.New("tool bundle is empty")
|
||||
}
|
||||
// 2. 通过 ToolInfo 名称定位并拿到同索引的 Tool 实例。
|
||||
for idx, info := range bundle.ToolInfos {
|
||||
if info == nil || info.Name != name {
|
||||
continue
|
||||
}
|
||||
invokable, ok := bundle.Tools[idx].(tool.InvokableTool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tool %s is not invokable", name)
|
||||
}
|
||||
return invokable, nil
|
||||
}
|
||||
return nil, fmt.Errorf("tool %s not found", name)
|
||||
}
|
||||
|
||||
func callModelForJSON(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string) (string, error) {
|
||||
// 默认 JSON 输出场景 token 足够小,使用 256 作为保守上限。
|
||||
return callModelForJSONWithMaxTokens(ctx, chatModel, systemPrompt, userPrompt, 256)
|
||||
}
|
||||
|
||||
func callModelForJSONWithMaxTokens(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, maxTokens int) (string, error) {
|
||||
// 1. 构造 system + user 两段消息。
|
||||
messages := []*schema.Message{
|
||||
schema.SystemMessage(systemPrompt),
|
||||
schema.UserMessage(userPrompt),
|
||||
}
|
||||
// 2. 统一关闭 thinking,降低额外延迟,并用温度 0 提升结构化稳定性。
|
||||
opts := []einoModel.Option{
|
||||
ark.WithThinking(&arkModel.Thinking{Type: arkModel.ThinkingTypeDisabled}),
|
||||
einoModel.WithTemperature(0),
|
||||
}
|
||||
if maxTokens > 0 {
|
||||
opts = append(opts, einoModel.WithMaxTokens(maxTokens))
|
||||
}
|
||||
|
||||
// 3. 调模型并对空响应做防御校验。
|
||||
resp, err := chatModel.Generate(ctx, messages, opts...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp == nil {
|
||||
return "", errors.New("模型返回为空")
|
||||
}
|
||||
content := strings.TrimSpace(resp.Content)
|
||||
if content == "" {
|
||||
return "", errors.New("模型返回内容为空")
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
type quickNotePlannedResult struct {
|
||||
Title string
|
||||
Deadline *time.Time
|
||||
DeadlineText string
|
||||
UrgencyThreshold *time.Time
|
||||
UrgencyThresholdText string
|
||||
PriorityGroup int
|
||||
PriorityReason string
|
||||
Banter string
|
||||
}
|
||||
|
||||
// planQuickNoteInSingleCall 在一次模型调用里完成“时间/优先级/banter”聚合规划。
|
||||
func planQuickNoteInSingleCall(
|
||||
ctx context.Context,
|
||||
chatModel *ark.ChatModel,
|
||||
nowText string,
|
||||
now time.Time,
|
||||
userInput string,
|
||||
) (*quickNotePlannedResult, error) {
|
||||
// 1. 构造聚合 prompt:一次返回所有结构化字段,减少多次 LLM 往返。
|
||||
prompt := fmt.Sprintf(`当前时间(北京时间,精确到分钟):%s
|
||||
用户输入:%s
|
||||
|
||||
请仅输出 JSON(不要 markdown,不要解释),字段如下:
|
||||
{
|
||||
"title": string,
|
||||
"deadline_at": string,
|
||||
"urgency_threshold_at": string,
|
||||
"priority_group": 1|2|3|4,
|
||||
"priority_reason": string,
|
||||
"banter": string
|
||||
}
|
||||
|
||||
约束:
|
||||
1) deadline_at 只允许 "yyyy-MM-dd HH:mm" 或空字符串;
|
||||
2) urgency_threshold_at 只允许 "yyyy-MM-dd HH:mm" 或空字符串;
|
||||
3) 若用户给了相对时间(如明天/今晚/下周一),必须换算为绝对时间;
|
||||
4) 若任务不需要自动平移,可让 urgency_threshold_at 为空;
|
||||
5) banter 只允许一句中文,不超过30字,不得改动任务事实。`,
|
||||
nowText,
|
||||
strings.TrimSpace(userInput),
|
||||
)
|
||||
|
||||
// 2. 控制 maxTokens,避免模型冗长输出导致延迟上升。
|
||||
raw, err := callModelForJSONWithMaxTokens(ctx, chatModel, QuickNotePlanPrompt, prompt, 220)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 3. 解析模型输出 JSON。
|
||||
parsed, parseErr := parseJSONPayload[quickNotePlanModelOutput](raw)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
|
||||
result := &quickNotePlannedResult{
|
||||
Title: strings.TrimSpace(parsed.Title),
|
||||
DeadlineText: strings.TrimSpace(parsed.DeadlineAt),
|
||||
UrgencyThresholdText: strings.TrimSpace(parsed.UrgencyThresholdAt),
|
||||
PriorityGroup: parsed.PriorityGroup,
|
||||
PriorityReason: strings.TrimSpace(parsed.PriorityReason),
|
||||
Banter: strings.TrimSpace(parsed.Banter),
|
||||
}
|
||||
|
||||
// 4. banter 只保留首行,防止模型输出多行破坏最终回复风格。
|
||||
if result.Banter != "" {
|
||||
if idx := strings.Index(result.Banter, "\n"); idx >= 0 {
|
||||
result.Banter = strings.TrimSpace(result.Banter[:idx])
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 对 deadline 做本地二次校验,确保可落库。
|
||||
if result.DeadlineText != "" {
|
||||
if deadline, deadlineErr := parseOptionalDeadlineWithNow(result.DeadlineText, now); deadlineErr == nil {
|
||||
result.Deadline = deadline
|
||||
}
|
||||
}
|
||||
// 6. 对 urgency_threshold_at 做本地二次校验,并与 deadline 做上界约束。
|
||||
if result.UrgencyThresholdText != "" {
|
||||
if urgencyThreshold, thresholdErr := parseOptionalDeadlineWithNow(result.UrgencyThresholdText, now); thresholdErr == nil {
|
||||
result.UrgencyThreshold = normalizeUrgencyThreshold(urgencyThreshold, result.Deadline)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func parseJSONPayload[T any](raw string) (*T, error) {
|
||||
// 1. 空字符串直接失败。
|
||||
clean := strings.TrimSpace(raw)
|
||||
if clean == "" {
|
||||
return nil, errors.New("empty response")
|
||||
}
|
||||
|
||||
// 2. 兼容 ```json ... ``` 包裹输出。
|
||||
if strings.HasPrefix(clean, "```") {
|
||||
clean = strings.TrimPrefix(clean, "```json")
|
||||
clean = strings.TrimPrefix(clean, "```")
|
||||
clean = strings.TrimSuffix(clean, "```")
|
||||
clean = strings.TrimSpace(clean)
|
||||
}
|
||||
|
||||
// 3. 先尝试整体反序列化(最快路径)。
|
||||
var out T
|
||||
if err := json.Unmarshal([]byte(clean), &out); err == nil {
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
// 4. 若模型附带额外文本,则提取最外层 JSON 对象再解析。
|
||||
obj := extractJSONObject(clean)
|
||||
if obj == "" {
|
||||
return nil, fmt.Errorf("no json object found in: %s", clean)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(obj), &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func extractJSONObject(text string) string {
|
||||
// 简化提取策略:取首个“{”到最后“}”的片段。
|
||||
// 对当前 prompt 场景足够稳定,且实现成本低。
|
||||
start := strings.Index(text, "{")
|
||||
end := strings.LastIndex(text, "}")
|
||||
if start == -1 || end == -1 || end <= start {
|
||||
return ""
|
||||
}
|
||||
return text[start : end+1]
|
||||
}
|
||||
|
||||
// normalizeUrgencyThreshold 归一化“紧急分界线时间”。
|
||||
//
|
||||
// 规则:
|
||||
// 1. 分界线为空时直接返回空;
|
||||
// 2. 存在 deadline 且分界线晚于 deadline 时,收敛到 deadline;
|
||||
// 3. 其余情况保持原值。
|
||||
func normalizeUrgencyThreshold(threshold *time.Time, deadline *time.Time) *time.Time {
|
||||
if threshold == nil {
|
||||
return nil
|
||||
}
|
||||
if deadline == nil {
|
||||
return threshold
|
||||
}
|
||||
if threshold.After(*deadline) {
|
||||
normalized := *deadline
|
||||
return &normalized
|
||||
}
|
||||
return threshold
|
||||
}
|
||||
|
||||
func fallbackPriority(st *QuickNoteState) int {
|
||||
// 兜底规则:
|
||||
// 1) 有截止时间且 <=48h:重要且紧急;
|
||||
// 2) 有截止时间但较远:重要不紧急;
|
||||
// 3) 无截止时间:简单不重要。
|
||||
if st == nil {
|
||||
return QuickNotePrioritySimpleNotImportant
|
||||
}
|
||||
if st.ExtractedDeadline != nil {
|
||||
if time.Until(*st.ExtractedDeadline) <= 48*time.Hour {
|
||||
return QuickNotePriorityImportantUrgent
|
||||
}
|
||||
return QuickNotePriorityImportantNotUrgent
|
||||
}
|
||||
return QuickNotePrioritySimpleNotImportant
|
||||
}
|
||||
|
||||
// deriveQuickNoteTitleFromInput 在“跳过二次意图判定”场景下,从用户原句提取任务标题。
|
||||
func deriveQuickNoteTitleFromInput(userInput string) string {
|
||||
// 1. 先清理空白。
|
||||
text := strings.TrimSpace(userInput)
|
||||
if text == "" {
|
||||
return "这条任务"
|
||||
}
|
||||
|
||||
// 2. 去掉常见指令前缀,保留核心任务语义。
|
||||
prefixes := []string{
|
||||
"请帮我", "麻烦帮我", "麻烦你", "帮我", "提醒我", "请提醒我", "记一下", "记个", "帮我记一下",
|
||||
}
|
||||
for _, prefix := range prefixes {
|
||||
if strings.HasPrefix(text, prefix) {
|
||||
text = strings.TrimSpace(strings.TrimPrefix(text, prefix))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 截断“记得/到时候”等尾部提醒语,避免标题过长。
|
||||
suffixSeparators := []string{
|
||||
",记得", ",记得", ",到时候", ",到时候", " 到时候", ",别忘了", ",别忘了", "。记得",
|
||||
}
|
||||
for _, sep := range suffixSeparators {
|
||||
if idx := strings.Index(text, sep); idx > 0 {
|
||||
text = strings.TrimSpace(text[:idx])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 收尾清理标点;若清理后为空则回退原句。
|
||||
text = strings.Trim(text, ",,。.!!?;; ")
|
||||
if text == "" {
|
||||
return strings.TrimSpace(userInput)
|
||||
}
|
||||
return text
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
package quicknote
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDeriveQuickNoteTitleFromInput(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "保留核心事项并去掉尾部提醒口头语",
|
||||
input: "明天上午12点我要去取快递,到时候记得q我",
|
||||
want: "明天上午12点我要去取快递",
|
||||
},
|
||||
{
|
||||
name: "去掉常见前缀口头语",
|
||||
input: "提醒我周五下午三点交实验报告",
|
||||
want: "周五下午三点交实验报告",
|
||||
},
|
||||
{
|
||||
name: "空输入兜底",
|
||||
input: " ",
|
||||
want: "这条任务",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := deriveQuickNoteTitleFromInput(tc.input)
|
||||
if got != tc.want {
|
||||
t.Fatalf("title 提取不符合预期,got=%q want=%q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
package quicknote
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
)
|
||||
|
||||
// quickNoteRunner 是“单次图运行”的请求级依赖容器。
|
||||
// 设计目标:
|
||||
// 1) 把节点运行所需依赖(input/tool/emit)就近收口;
|
||||
// 2) 让 graph.go 只保留“节点连线”和“方法引用”,提升可读性;
|
||||
// 3) 避免在 graph.go 里重复出现内联闭包和参数透传。
|
||||
type quickNoteRunner struct {
|
||||
input QuickNoteGraphRunInput
|
||||
createTaskTool tool.InvokableTool
|
||||
emitStage func(stage, detail string)
|
||||
}
|
||||
|
||||
// newQuickNoteRunner 构造请求级 runner。
|
||||
// 说明:runner 生命周期仅限一次 graph invoke,不做跨请求复用。
|
||||
func newQuickNoteRunner(input QuickNoteGraphRunInput, createTaskTool tool.InvokableTool, emitStage func(stage, detail string)) *quickNoteRunner {
|
||||
return &quickNoteRunner{
|
||||
input: input,
|
||||
createTaskTool: createTaskTool,
|
||||
emitStage: emitStage,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *quickNoteRunner) intentNode(ctx context.Context, st *QuickNoteState) (*QuickNoteState, error) {
|
||||
// 方法引用适配层:把 runner 内部依赖透传到纯函数节点实现。
|
||||
return runQuickNoteIntentNode(ctx, st, r.input, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *quickNoteRunner) priorityNode(ctx context.Context, st *QuickNoteState) (*QuickNoteState, error) {
|
||||
// 方法引用适配层:让 graph.go 保持“只连线,不写业务细节”。
|
||||
return runQuickNotePriorityNode(ctx, st, r.input, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *quickNoteRunner) persistNode(ctx context.Context, st *QuickNoteState) (*QuickNoteState, error) {
|
||||
// 这里注入 createTaskTool,是为了让 persist 节点不直接依赖外部容器对象。
|
||||
return runQuickNotePersistNodeInternal(ctx, st, r.createTaskTool, r.input, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *quickNoteRunner) nextAfterIntent(ctx context.Context, st *QuickNoteState) (string, error) {
|
||||
// 当前分支决策是纯状态函数,不依赖 context,保留参数仅为适配 GraphBranch 签名。
|
||||
_ = ctx
|
||||
return selectQuickNoteNextAfterIntent(st), nil
|
||||
}
|
||||
|
||||
func (r *quickNoteRunner) nextAfterPersist(ctx context.Context, st *QuickNoteState) (string, error) {
|
||||
// 当前分支决策是纯状态函数,不依赖 context,保留参数仅为适配 GraphBranch 签名。
|
||||
_ = ctx
|
||||
return selectQuickNoteNextAfterPersist(st), nil
|
||||
}
|
||||
|
||||
func (r *quickNoteRunner) exitNode(ctx context.Context, st *QuickNoteState) (*QuickNoteState, error) {
|
||||
// exit 节点不做任何业务逻辑,仅把当前状态原样透传到 END。
|
||||
_ = ctx
|
||||
return st, nil
|
||||
}
|
||||
@@ -1,174 +0,0 @@
|
||||
package quicknote
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
// QuickNoteDatetimeMinuteLayout 是“随口记”链路内部统一的分钟级时间格式。
|
||||
// 说明:
|
||||
// 1) 用于把“当前时间基准”传给模型,避免模型在相对时间推断时出现秒级抖动。
|
||||
// 2) 用于日志和调试,读起来比 RFC3339 更直观。
|
||||
QuickNoteDatetimeMinuteLayout = "2006-01-02 15:04"
|
||||
|
||||
// quickNoteTimezoneName 是随口记链路默认业务时区。
|
||||
// 这里固定为东八区,避免容器运行在 UTC 时把“明天/今晚”解释偏移到错误日期。
|
||||
quickNoteTimezoneName = "Asia/Shanghai"
|
||||
|
||||
// QuickNotePriorityImportantUrgent 对应四象限里的“重要且紧急”。
|
||||
// 在当前 tasks 表中映射为 priority=1(数值越小优先级越高)。
|
||||
QuickNotePriorityImportantUrgent = 1
|
||||
// QuickNotePriorityImportantNotUrgent 对应“重要不紧急”。
|
||||
QuickNotePriorityImportantNotUrgent = 2
|
||||
// QuickNotePrioritySimpleNotImportant 对应“简单不重要”。
|
||||
QuickNotePrioritySimpleNotImportant = 3
|
||||
// QuickNotePriorityComplexNotImportant 对应“不简单不重要”。
|
||||
QuickNotePriorityComplexNotImportant = 4
|
||||
)
|
||||
|
||||
// IsValidTaskPriority 判断优先级是否合法。
|
||||
// 目前后端任务模型限定为 1~4。
|
||||
func IsValidTaskPriority(priority int) bool {
|
||||
return priority >= QuickNotePriorityImportantUrgent && priority <= QuickNotePriorityComplexNotImportant
|
||||
}
|
||||
|
||||
// PriorityLabelCN 把优先级数值转换为中文标签,便于拼接给用户的自然语言回复。
|
||||
func PriorityLabelCN(priority int) string {
|
||||
switch priority {
|
||||
case QuickNotePriorityImportantUrgent:
|
||||
return "重要且紧急"
|
||||
case QuickNotePriorityImportantNotUrgent:
|
||||
return "重要不紧急"
|
||||
case QuickNotePrioritySimpleNotImportant:
|
||||
return "简单不重要"
|
||||
case QuickNotePriorityComplexNotImportant:
|
||||
return "不简单不重要"
|
||||
default:
|
||||
return "未知优先级"
|
||||
}
|
||||
}
|
||||
|
||||
// QuickNoteState 是“AI随口记”链路在 graph 节点间传递的统一状态容器。
|
||||
// 设计目标:
|
||||
// 1) 把本次请求的上下文收拢到一个结构里,降低节点函数参数散落;
|
||||
// 2) 让“识别、评估、写库、重试、回复”每一步都可追踪;
|
||||
// 3) 便于后续扩展打点和可观测字段(例如时间解析失败原因)。
|
||||
type QuickNoteState struct {
|
||||
// 基础上下文:用于日志关联与用户隔离。
|
||||
TraceID string
|
||||
UserID int
|
||||
ConversationID string
|
||||
|
||||
// RequestNow 记录“请求进入随口记链路时”的时间基准(分钟级)。
|
||||
// 所有相对时间(明天/后天/下周一)都必须基于这个时间计算,
|
||||
// 这样同一次请求内不会因为时间流逝产生口径漂移。
|
||||
RequestNow time.Time
|
||||
// RequestNowText 是 RequestNow 的字符串形式,主要用于 prompt 注入。
|
||||
RequestNowText string
|
||||
|
||||
// 用户原始输入(例如:提醒我下周日之前完成大作业)。
|
||||
UserInput string
|
||||
|
||||
// 意图判定结果。
|
||||
IsQuickNoteIntent bool
|
||||
IntentJudgeReason string
|
||||
|
||||
// 结构化抽取结果:由“意图识别/信息抽取”节点写入。
|
||||
ExtractedTitle string
|
||||
ExtractedDeadline *time.Time
|
||||
ExtractedDeadlineText string
|
||||
// ExtractedUrgencyThreshold 表示“进入紧急象限的分界时间”。
|
||||
//
|
||||
// 语义说明:
|
||||
// 1. 该时间由模型规划后给出,并在后端做解析校验;
|
||||
// 2. 到达该时间后,任务可在“读时派生 + 异步落库”链路中被自动平移;
|
||||
// 3. 为空表示该任务不参与自动平移。
|
||||
ExtractedUrgencyThreshold *time.Time
|
||||
ExtractedPriority int
|
||||
// ExtractedBanter 是聚合规划阶段生成的“轻松跟进句”。
|
||||
// 该字段非空时,最终回复阶段可直接复用,避免再触发一次独立润色模型调用。
|
||||
ExtractedBanter string
|
||||
// PlannedBySingleCall 标记本次是否走了“单请求聚合规划”快路径。
|
||||
// 用于在后续节点做更激进的性能策略(例如缺失字段时直接本地兜底,避免再触发模型调用)。
|
||||
PlannedBySingleCall bool
|
||||
|
||||
// ExtractedPriorityReason 记录优先级评估理由,便于后续排查模型判断是否符合预期。
|
||||
ExtractedPriorityReason string
|
||||
// DeadlineValidationError 记录时间校验失败原因。
|
||||
// 只要该字段非空,就说明用户提供了无法解析的时间表达,本次请求不应落库。
|
||||
DeadlineValidationError string
|
||||
|
||||
// 工具调用过程状态:用于重试与故障回溯。
|
||||
ToolAttemptCount int
|
||||
MaxToolRetry int
|
||||
LastToolError string
|
||||
|
||||
// 最终持久化结果:由“写库工具”节点回填。
|
||||
PersistedTaskID int
|
||||
Persisted bool
|
||||
|
||||
// AssistantReply 是 graph 最终给用户的回复文案。
|
||||
AssistantReply string
|
||||
}
|
||||
|
||||
// NewQuickNoteState 创建随口记状态对象并初始化默认重试次数。
|
||||
func NewQuickNoteState(traceID string, userID int, conversationID, userInput string) *QuickNoteState {
|
||||
// 1. 在“进入链路”这一刻固化时间基准,后续所有相对时间都以它为准。
|
||||
requestNow := quickNoteNowToMinute()
|
||||
return &QuickNoteState{
|
||||
TraceID: traceID,
|
||||
UserID: userID,
|
||||
ConversationID: conversationID,
|
||||
RequestNow: requestNow,
|
||||
RequestNowText: formatQuickNoteTimeToMinute(requestNow),
|
||||
UserInput: userInput,
|
||||
MaxToolRetry: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// CanRetryTool 判断当前是否还能继续重试工具调用。
|
||||
func (s *QuickNoteState) CanRetryTool() bool {
|
||||
// 规则:已尝试次数 < 最大重试次数 才允许继续。
|
||||
// 这里不做 <=,是为了让“第 MaxToolRetry 次失败后”及时停机并给用户明确反馈。
|
||||
return s.ToolAttemptCount < s.MaxToolRetry
|
||||
}
|
||||
|
||||
// RecordToolError 记录一次工具调用失败。
|
||||
func (s *QuickNoteState) RecordToolError(errMsg string) {
|
||||
// 1. 每失败一次都要累加计数,供分支节点判断是否继续重试。
|
||||
s.ToolAttemptCount++
|
||||
// 2. 保留最后一次错误,便于日志与排障定位“最终失败原因”。
|
||||
s.LastToolError = errMsg
|
||||
}
|
||||
|
||||
// RecordToolSuccess 记录一次工具调用成功。
|
||||
func (s *QuickNoteState) RecordToolSuccess(taskID int) {
|
||||
// 1. 成功同样计入尝试次数,便于还原完整调用轨迹。
|
||||
s.ToolAttemptCount++
|
||||
// 2. 回填 task_id 和成功标志,供后续节点拼接成功回复。
|
||||
s.PersistedTaskID = taskID
|
||||
s.Persisted = true
|
||||
// 3. 成功后清空错误,避免后续误读历史失败信息。
|
||||
s.LastToolError = ""
|
||||
}
|
||||
|
||||
// quickNoteLocation 返回随口记链路使用的业务时区。
|
||||
func quickNoteLocation() *time.Location {
|
||||
// 1. 优先加载业务固定时区,保证“明天/今晚”等语义与用户预期一致。
|
||||
loc, err := time.LoadLocation(quickNoteTimezoneName)
|
||||
if err != nil {
|
||||
// 2. 极端情况下回退到系统本地时区,避免因时区加载失败导致链路整体不可用。
|
||||
return time.Local
|
||||
}
|
||||
return loc
|
||||
}
|
||||
|
||||
// quickNoteNowToMinute 返回当前时间并截断到分钟级。
|
||||
func quickNoteNowToMinute() time.Time {
|
||||
// 统一截断到分钟,避免秒级抖动导致“同一次请求前后解析口径不一致”。
|
||||
return time.Now().In(quickNoteLocation()).Truncate(time.Minute)
|
||||
}
|
||||
|
||||
// formatQuickNoteTimeToMinute 将时间格式化为分钟级字符串。
|
||||
func formatQuickNoteTimeToMinute(t time.Time) string {
|
||||
// 输出前统一转换到业务时区,避免日志和 prompt 出现跨时区混淆。
|
||||
return t.In(quickNoteLocation()).Format(QuickNoteDatetimeMinuteLayout)
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
package quicknote
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseOptionalDeadlineWithNow_Absolute(t *testing.T) {
|
||||
loc := quickNoteLocation()
|
||||
now := time.Date(2026, 3, 12, 10, 15, 0, 0, loc)
|
||||
|
||||
deadline, err := parseOptionalDeadlineWithNow("2026-03-20 18:30", now)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if deadline == nil {
|
||||
t.Fatalf("deadline should not be nil")
|
||||
}
|
||||
|
||||
want := time.Date(2026, 3, 20, 18, 30, 0, 0, loc)
|
||||
if !deadline.Equal(want) {
|
||||
t.Fatalf("unexpected deadline, got=%s want=%s", deadline.Format(time.RFC3339), want.Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOptionalDeadlineWithNow_RelativeTomorrowWithoutClock(t *testing.T) {
|
||||
loc := quickNoteLocation()
|
||||
now := time.Date(2026, 3, 12, 10, 15, 0, 0, loc)
|
||||
|
||||
deadline, err := parseOptionalDeadlineWithNow("明天交计网实验报告", now)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if deadline == nil {
|
||||
t.Fatalf("deadline should not be nil")
|
||||
}
|
||||
|
||||
want := time.Date(2026, 3, 13, 23, 59, 0, 0, loc)
|
||||
if !deadline.Equal(want) {
|
||||
t.Fatalf("unexpected deadline, got=%s want=%s", deadline.Format(time.RFC3339), want.Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOptionalDeadlineWithNow_RelativeTomorrowWithClock(t *testing.T) {
|
||||
loc := quickNoteLocation()
|
||||
now := time.Date(2026, 3, 12, 10, 15, 0, 0, loc)
|
||||
|
||||
deadline, err := parseOptionalDeadlineWithNow("明天下午3点交计网实验报告", now)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if deadline == nil {
|
||||
t.Fatalf("deadline should not be nil")
|
||||
}
|
||||
|
||||
want := time.Date(2026, 3, 13, 15, 0, 0, 0, loc)
|
||||
if !deadline.Equal(want) {
|
||||
t.Fatalf("unexpected deadline, got=%s want=%s", deadline.Format(time.RFC3339), want.Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOptionalDeadlineWithNow_RelativeWeekday(t *testing.T) {
|
||||
loc := quickNoteLocation()
|
||||
now := time.Date(2026, 3, 12, 10, 15, 0, 0, loc) // 周四
|
||||
|
||||
deadline, err := parseOptionalDeadlineWithNow("下周一上午9点开组会", now)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if deadline == nil {
|
||||
t.Fatalf("deadline should not be nil")
|
||||
}
|
||||
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, loc)
|
||||
if !deadline.Equal(want) {
|
||||
t.Fatalf("unexpected deadline, got=%s want=%s", deadline.Format(time.RFC3339), want.Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOptionalDeadlineFromUserInput_NoHint(t *testing.T) {
|
||||
loc := quickNoteLocation()
|
||||
now := time.Date(2026, 3, 12, 10, 15, 0, 0, loc)
|
||||
|
||||
deadline, hasHint, err := parseOptionalDeadlineFromUserInput("帮我记一下要复习计网", now)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if hasHint {
|
||||
t.Fatalf("expected no time hint")
|
||||
}
|
||||
if deadline != nil {
|
||||
t.Fatalf("deadline should be nil when no time hint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOptionalDeadlineFromUserInput_InvalidDate(t *testing.T) {
|
||||
loc := quickNoteLocation()
|
||||
now := time.Date(2026, 3, 12, 10, 15, 0, 0, loc)
|
||||
|
||||
deadline, hasHint, err := parseOptionalDeadlineFromUserInput("2026-13-45 25:99 交实验", now)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error but got nil")
|
||||
}
|
||||
if !hasHint {
|
||||
t.Fatalf("expected hasHint=true")
|
||||
}
|
||||
if deadline != nil {
|
||||
t.Fatalf("deadline should be nil for invalid date")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOptionalDeadlineWithNow_Invalid(t *testing.T) {
|
||||
loc := quickNoteLocation()
|
||||
now := time.Date(2026, 3, 12, 10, 15, 0, 0, loc)
|
||||
|
||||
deadline, err := parseOptionalDeadlineWithNow("记得尽快处理", now)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error but got nil")
|
||||
}
|
||||
if deadline != nil {
|
||||
t.Fatalf("deadline should be nil for invalid input")
|
||||
}
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
package route
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseRouteControlTag_SchedulePlanCreate(t *testing.T) {
|
||||
nonce := "nonce-create"
|
||||
raw := `<SMARTFLOW_ROUTE nonce="nonce-create" action="schedule_plan_create"></SMARTFLOW_ROUTE>
|
||||
<SMARTFLOW_REASON>新建排程</SMARTFLOW_REASON>`
|
||||
|
||||
decision, err := ParseRouteControlTag(raw, nonce)
|
||||
if err != nil {
|
||||
t.Fatalf("解析失败: %v", err)
|
||||
}
|
||||
if decision.Action != ActionSchedulePlanCreate {
|
||||
t.Fatalf("action 不匹配,期望=%s 实际=%s", ActionSchedulePlanCreate, decision.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRouteControlTag_SchedulePlanRefine(t *testing.T) {
|
||||
nonce := "nonce-refine"
|
||||
raw := `<SMARTFLOW_ROUTE nonce="nonce-refine" action="schedule_plan_refine"></SMARTFLOW_ROUTE>
|
||||
<SMARTFLOW_REASON>微调排程</SMARTFLOW_REASON>`
|
||||
|
||||
decision, err := ParseRouteControlTag(raw, nonce)
|
||||
if err != nil {
|
||||
t.Fatalf("解析失败: %v", err)
|
||||
}
|
||||
if decision.Action != ActionSchedulePlanRefine {
|
||||
t.Fatalf("action 不匹配,期望=%s 实际=%s", ActionSchedulePlanRefine, decision.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRouteControlTag_LegacySchedulePlan(t *testing.T) {
|
||||
nonce := "nonce-legacy"
|
||||
raw := `<SMARTFLOW_ROUTE nonce="nonce-legacy" action="schedule_plan"></SMARTFLOW_ROUTE>
|
||||
<SMARTFLOW_REASON>兼容旧动作</SMARTFLOW_REASON>`
|
||||
|
||||
decision, err := ParseRouteControlTag(raw, nonce)
|
||||
if err != nil {
|
||||
t.Fatalf("解析失败: %v", err)
|
||||
}
|
||||
if decision.Action != ActionSchedulePlanCreate {
|
||||
t.Fatalf("旧动作映射错误,期望=%s 实际=%s", ActionSchedulePlanCreate, decision.Action)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package route
|
||||
package agentrouter
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -8,11 +8,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
agentllm "github.com/LoveLosita/smartflow/backend/agent/llm"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
einoModel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/google/uuid"
|
||||
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -61,7 +59,12 @@ const routeControlPrompt = `你是 SmartFlow 的请求分流控制器。
|
||||
|
||||
禁止输出任何其他内容。`
|
||||
|
||||
// Action 表示分流动作。
|
||||
// Action 是 Agent 路由层对业务动作的统一命名。
|
||||
//
|
||||
// 这里直接定义在 router 包,而不是复用旧 route 包:
|
||||
// 1. 当前这轮迁移要求只有 router 可以保留对旧链路的兼容语义;
|
||||
// 2. chat / quicknote 已经要完全切到 Agent,自然不该再依赖旧包常量;
|
||||
// 3. schedule/taskquery 尚未搬迁完成时,也能继续靠这些常量在 service 层做统一分发。
|
||||
type Action string
|
||||
|
||||
const (
|
||||
@@ -169,27 +172,17 @@ func routeByModelControlTag(ctx context.Context, selectedModel *ark.ChatModel, u
|
||||
nowText := time.Now().In(time.Local).Format("2006-01-02 15:04")
|
||||
userPrompt := fmt.Sprintf("nonce=%s\n当前时间=%s\n用户输入=%s", nonce, nowText, strings.TrimSpace(userMessage))
|
||||
|
||||
resp, err := selectedModel.Generate(routeCtx, []*schema.Message{
|
||||
schema.SystemMessage(routeControlPrompt),
|
||||
schema.UserMessage(userPrompt),
|
||||
},
|
||||
ark.WithThinking(&arkModel.Thinking{Type: arkModel.ThinkingTypeDisabled}),
|
||||
einoModel.WithTemperature(0),
|
||||
einoModel.WithMaxTokens(120),
|
||||
)
|
||||
// 1. 调用目的:路由场景只需要稳定、短文本、禁用 thinking 的结构化输出。
|
||||
// 2. 这里复用 Agent 公共 LLM 封装,删除与 quicknote 重复的 JSON/文本调用样板代码。
|
||||
resp, err := agentllm.CallArkText(routeCtx, selectedModel, routeControlPrompt, userPrompt, agentllm.ArkCallOptions{
|
||||
Temperature: 0,
|
||||
MaxTokens: 120,
|
||||
Thinking: agentllm.ThinkingModeDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("empty route response")
|
||||
}
|
||||
|
||||
raw := strings.TrimSpace(resp.Content)
|
||||
if raw == "" {
|
||||
return nil, fmt.Errorf("empty route content")
|
||||
}
|
||||
|
||||
return ParseRouteControlTag(raw, nonce)
|
||||
return ParseRouteControlTag(resp, nonce)
|
||||
}
|
||||
|
||||
// deriveRouteControlContext 为“控制码路由”创建子上下文。
|
||||
67
backend/agent/router/route.go
Normal file
67
backend/agent/router/route.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package agentrouter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Dispatcher 是 Agent 的统一分发器。
|
||||
type Dispatcher struct {
|
||||
resolver Resolver
|
||||
handlers map[Action]SkillHandler
|
||||
}
|
||||
|
||||
// NewDispatcher 创建统一分发器。
|
||||
func NewDispatcher(resolver Resolver) *Dispatcher {
|
||||
return &Dispatcher{
|
||||
resolver: resolver,
|
||||
handlers: make(map[Action]SkillHandler),
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册某个动作的处理函数。
|
||||
func (d *Dispatcher) Register(action Action, handler SkillHandler) error {
|
||||
if d == nil {
|
||||
return errors.New("dispatcher is nil")
|
||||
}
|
||||
if action == "" {
|
||||
return errors.New("route action is empty")
|
||||
}
|
||||
if handler == nil {
|
||||
return fmt.Errorf("handler for action %s is nil", action)
|
||||
}
|
||||
if _, exists := d.handlers[action]; exists {
|
||||
return fmt.Errorf("handler for action %s already registered", action)
|
||||
}
|
||||
d.handlers[action] = handler
|
||||
return nil
|
||||
}
|
||||
|
||||
// Dispatch 执行“分流 -> skill handler”完整入口。
|
||||
func (d *Dispatcher) Dispatch(ctx context.Context, req *AgentRequest) (*AgentResponse, error) {
|
||||
if d == nil || d.resolver == nil {
|
||||
return nil, errors.New("route dispatcher is not ready")
|
||||
}
|
||||
if req == nil {
|
||||
return nil, errors.New("agent request is nil")
|
||||
}
|
||||
|
||||
// 1. 调用目的:统一先走一级路由,让入口层只关心“请求来了”,
|
||||
// 不需要提前知道这是普通聊天、随口记、任务查询还是后续排程。
|
||||
decision, err := d.resolver.Resolve(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if decision == nil {
|
||||
return nil, errors.New("route decision is nil")
|
||||
}
|
||||
|
||||
// 2. 路由结果出来后,只根据 action 查找对应 handler。
|
||||
// 这里故意不做 skill 级 fallback,避免路由层和 skill 内部职责再次缠在一起。
|
||||
handler, exists := d.handlers[decision.Action]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no handler registered for action %s", decision.Action)
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
||||
34
backend/agent/router/route_model.go
Normal file
34
backend/agent/router/route_model.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package agentrouter
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Resolver 定义一级路由器能力。
|
||||
type Resolver interface {
|
||||
Resolve(ctx context.Context, req *AgentRequest) (*RoutingDecision, error)
|
||||
}
|
||||
|
||||
// SkillHandler 是某个 skill 的统一执行入口。
|
||||
type SkillHandler func(ctx context.Context, req *AgentRequest) (*AgentResponse, error)
|
||||
|
||||
// AgentRequest 是 Agent 路由层可见的最小请求结构。
|
||||
//
|
||||
// 设计目的:
|
||||
// 1. 让 router 层只依赖自己真正关心的字段;
|
||||
// 2. 避免把整份 agentmodel 结构在迁移早期层层透传;
|
||||
// 3. 后续若总入口还要追加别的字段,只需要在入口层做一次映射。
|
||||
type AgentRequest struct {
|
||||
UserID int
|
||||
ConversationID string
|
||||
UserMessage string
|
||||
ModelName string
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
// AgentResponse 是路由分发器对 skill handler 的统一响应外壳。
|
||||
type AgentResponse struct {
|
||||
Action Action
|
||||
Reply string
|
||||
Meta map[string]any
|
||||
}
|
||||
@@ -1,315 +0,0 @@
|
||||
package scheduleplan
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
const (
|
||||
// dailyReactRoundTimeout 是日内单轮模型调用超时。
|
||||
// 日内节点走并发调用,超时要比周级更保守,避免占满资源。
|
||||
dailyReactRoundTimeout = 3 * time.Minute
|
||||
)
|
||||
|
||||
// runDailyRefineNode 负责“并发日内优化”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责按 DayGroup 并发调用单日 ReAct;
|
||||
// 2. 负责输出“按天开始/完成”的阶段状态块(不推 reasoning 细流);
|
||||
// 3. 负责把单日失败回退到原始数据,确保全链路可继续;
|
||||
// 4. 不负责跨天配平(交给 weekly_refine),不负责最终总结(交给 final_check)。
|
||||
func runDailyRefineNode(
|
||||
ctx context.Context,
|
||||
st *SchedulePlanState,
|
||||
chatModel *ark.ChatModel,
|
||||
dailyRefineConcurrency int,
|
||||
emitStage func(stage, detail string),
|
||||
) (*SchedulePlanState, error) {
|
||||
if st == nil || len(st.DailyGroups) == 0 {
|
||||
return st, nil
|
||||
}
|
||||
if chatModel == nil {
|
||||
return st, fmt.Errorf("schedule plan daily refine: model is nil")
|
||||
}
|
||||
|
||||
// 1. 并发度兜底:
|
||||
// 1.1 优先使用注入参数;
|
||||
// 1.2 若注入参数非法,则回退到 state 值;
|
||||
// 1.3 state 也非法时,回退到编译期默认值。
|
||||
if dailyRefineConcurrency <= 0 {
|
||||
dailyRefineConcurrency = st.DailyRefineConcurrency
|
||||
}
|
||||
if dailyRefineConcurrency <= 0 {
|
||||
dailyRefineConcurrency = schedulePlanDefaultDailyRefineConcurrency
|
||||
}
|
||||
|
||||
emitStage(
|
||||
"schedule_plan.daily_refine.start",
|
||||
fmt.Sprintf("正在并发优化各天日程,并发度=%d。", dailyRefineConcurrency),
|
||||
)
|
||||
|
||||
// 2. 拉平所有 DayGroup 并排序,确保日志与阶段输出稳定可读。
|
||||
allGroups := flattenAndSortDayGroups(st.DailyGroups)
|
||||
if len(allGroups) == 0 {
|
||||
st.DailyResults = make(map[int]map[int][]model.HybridScheduleEntry)
|
||||
emitStage("schedule_plan.daily_refine.done", "没有可优化的天,跳过日内优化。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 3. 并发执行:
|
||||
// 3.1 sem 控制并发上限;
|
||||
// 3.2 wg 等待全部 goroutine 完成;
|
||||
// 3.3 mu 保护 results/firstErr,避免竞态。
|
||||
sem := make(chan struct{}, dailyRefineConcurrency)
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
totalGroups := int32(len(allGroups))
|
||||
var finishedGroups int32
|
||||
|
||||
results := make(map[int]map[int][]model.HybridScheduleEntry)
|
||||
var firstErr error
|
||||
|
||||
for _, group := range allGroups {
|
||||
g := group
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
// 3.4 先申请并发令牌;若 ctx 已取消,直接回退原始数据并结束。
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
defer func() { <-sem }()
|
||||
case <-ctx.Done():
|
||||
mu.Lock()
|
||||
if firstErr == nil {
|
||||
firstErr = ctx.Err()
|
||||
}
|
||||
ensureDayResult(results, g.Week, g.DayOfWeek, g.Entries)
|
||||
mu.Unlock()
|
||||
// 3.4.1 取消场景也要计入进度,避免前端看到“卡住不动”。
|
||||
done := atomic.AddInt32(&finishedGroups, 1)
|
||||
emitStage(
|
||||
"schedule_plan.daily_refine.day_done",
|
||||
fmt.Sprintf("W%dD%d 已取消并回退原方案。(进度 %d/%d)", g.Week, g.DayOfWeek, done, totalGroups),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
emitStage(
|
||||
"schedule_plan.daily_refine.day_start",
|
||||
fmt.Sprintf("正在安排 W%dD%d。(当前进度 %d/%d)", g.Week, g.DayOfWeek, atomic.LoadInt32(&finishedGroups), totalGroups),
|
||||
)
|
||||
|
||||
// 3.5 低收益天直接跳过模型调用,原样透传。
|
||||
if g.SkipRefine {
|
||||
mu.Lock()
|
||||
ensureDayResult(results, g.Week, g.DayOfWeek, g.Entries)
|
||||
mu.Unlock()
|
||||
done := atomic.AddInt32(&finishedGroups, 1)
|
||||
emitStage(
|
||||
"schedule_plan.daily_refine.day_done",
|
||||
fmt.Sprintf("W%dD%d suggested 较少,已跳过优化。(进度 %d/%d)", g.Week, g.DayOfWeek, done, totalGroups),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// 3.6 深拷贝输入,避免并发场景下意外修改共享切片。
|
||||
localEntries := deepCopyEntries(g.Entries)
|
||||
|
||||
// 3.7 动态轮次:
|
||||
// 3.7.1 suggested <= 4:1轮足够;
|
||||
// 3.7.2 suggested > 4:最多2轮,提升复杂天优化质量。
|
||||
maxRounds := 1
|
||||
if countSuggested(localEntries) > 4 {
|
||||
maxRounds = 2
|
||||
}
|
||||
|
||||
optimized, refineErr := runSingleDayReact(ctx, chatModel, localEntries, maxRounds, g.Week, g.DayOfWeek)
|
||||
if refineErr != nil {
|
||||
mu.Lock()
|
||||
if firstErr == nil {
|
||||
firstErr = refineErr
|
||||
}
|
||||
// 3.8 单天失败回退:
|
||||
// 3.8.1 保证失败只影响该天;
|
||||
// 3.8.2 保证总流程可继续推进到 merge/weekly/final。
|
||||
ensureDayResult(results, g.Week, g.DayOfWeek, g.Entries)
|
||||
mu.Unlock()
|
||||
done := atomic.AddInt32(&finishedGroups, 1)
|
||||
emitStage(
|
||||
"schedule_plan.daily_refine.day_done",
|
||||
fmt.Sprintf("W%dD%d 优化失败,已回退原方案。(进度 %d/%d)", g.Week, g.DayOfWeek, done, totalGroups),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
ensureDayResult(results, g.Week, g.DayOfWeek, optimized)
|
||||
mu.Unlock()
|
||||
done := atomic.AddInt32(&finishedGroups, 1)
|
||||
emitStage(
|
||||
"schedule_plan.daily_refine.day_done",
|
||||
fmt.Sprintf("W%dD%d 已安排完成。(进度 %d/%d)", g.Week, g.DayOfWeek, done, totalGroups),
|
||||
)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
st.DailyResults = results
|
||||
if firstErr != nil {
|
||||
emitStage("schedule_plan.daily_refine.partial_error", fmt.Sprintf("部分天优化失败,已自动回退。原因:%s", firstErr.Error()))
|
||||
}
|
||||
emitStage("schedule_plan.daily_refine.done", "日内优化阶段完成。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// runSingleDayReact 执行单天封闭式 ReAct 优化。
|
||||
//
|
||||
// 关键约束:
|
||||
// 1. prompt 只包含当天数据;
|
||||
// 2. 代码层再做“Move 不能跨天”硬校验;
|
||||
// 3. Thinking 默认关闭,优先降低日内并发阶段的长尾时延。
|
||||
func runSingleDayReact(
|
||||
ctx context.Context,
|
||||
chatModel *ark.ChatModel,
|
||||
entries []model.HybridScheduleEntry,
|
||||
maxRounds int,
|
||||
week int,
|
||||
dayOfWeek int,
|
||||
) ([]model.HybridScheduleEntry, error) {
|
||||
hybridJSON, err := json.Marshal(entries)
|
||||
if err != nil {
|
||||
return entries, err
|
||||
}
|
||||
|
||||
messages := []*schema.Message{
|
||||
schema.SystemMessage(SchedulePlanDailyReactPrompt),
|
||||
schema.UserMessage(fmt.Sprintf(
|
||||
"以下是今天的日程(JSON):\n%s\n\n仅优化这一天的数据,不要跨天移动。",
|
||||
string(hybridJSON),
|
||||
)),
|
||||
}
|
||||
|
||||
for round := 0; round < maxRounds; round++ {
|
||||
roundCtx, cancel := context.WithTimeout(ctx, dailyReactRoundTimeout)
|
||||
resp, generateErr := chatModel.Generate(
|
||||
roundCtx,
|
||||
messages,
|
||||
// 1. 日内优化只做“单天局部微调”,任务边界清晰,默认关闭 thinking 以降低时延。
|
||||
// 2. 周级全局配平仍保留 thinking(在 weekly_refine),这里不承担跨天复杂推理职责。
|
||||
// 3. 若后续观测到质量回退,可只在 suggested 很多时按条件重开 thinking,而不是全量开启。
|
||||
ark.WithThinking(&arkModel.Thinking{Type: arkModel.ThinkingTypeDisabled}),
|
||||
)
|
||||
cancel()
|
||||
if generateErr != nil {
|
||||
return entries, fmt.Errorf("日内 ReAct 第%d轮失败: %w", round+1, generateErr)
|
||||
}
|
||||
if resp == nil {
|
||||
return entries, fmt.Errorf("日内 ReAct 第%d轮返回为空", round+1)
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(resp.Content)
|
||||
parsed, parseErr := parseReactLLMOutput(content)
|
||||
if parseErr != nil {
|
||||
// 解析失败时回退当前轮,不把异常向上放大成整条链路失败。
|
||||
return entries, nil
|
||||
}
|
||||
if parsed.Done || len(parsed.ToolCalls) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// 1. 执行工具调用。
|
||||
// 1.1 每个调用都经过“日内策略约束”校验;
|
||||
// 1.2 任何单次调用失败都只返回 failed result,不中断整轮。
|
||||
results := make([]reactToolResult, 0, len(parsed.ToolCalls))
|
||||
for _, call := range parsed.ToolCalls {
|
||||
var result reactToolResult
|
||||
entries, result = dispatchDailyReactTool(entries, call, week, dayOfWeek)
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
// 2. 把“本轮模型输出 + 工具执行结果”拼入下一轮上下文。
|
||||
// 2.1 这样模型可以看到操作反馈,继续迭代;
|
||||
// 2.2 若下一轮仍无有效动作,会自然在 done/空 tool_calls 退出。
|
||||
messages = append(messages, schema.AssistantMessage(content, nil))
|
||||
resultJSON, _ := json.Marshal(results)
|
||||
messages = append(messages, schema.UserMessage(
|
||||
fmt.Sprintf("工具执行结果:\n%s\n\n请继续优化或输出 {\"done\":true,\"summary\":\"...\"}。", string(resultJSON)),
|
||||
))
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// dispatchDailyReactTool 在通用工具分发前增加“日内硬约束”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只负责校验 Move 的目标是否仍在当前天;
|
||||
// 2. 通过后复用 dispatchReactTool 执行;
|
||||
// 3. 不负责复杂冲突判定(冲突判定由底层工具函数处理)。
|
||||
func dispatchDailyReactTool(entries []model.HybridScheduleEntry, call reactToolCall, week int, dayOfWeek int) ([]model.HybridScheduleEntry, reactToolResult) {
|
||||
if call.Tool == "Move" {
|
||||
toWeek, weekOK := paramInt(call.Params, "to_week")
|
||||
toDay, dayOK := paramInt(call.Params, "to_day")
|
||||
if !weekOK || !dayOK {
|
||||
return entries, reactToolResult{
|
||||
Tool: "Move",
|
||||
Success: false,
|
||||
Result: "参数缺失:to_week/to_day",
|
||||
}
|
||||
}
|
||||
if toWeek != week || toDay != dayOfWeek {
|
||||
return entries, reactToolResult{
|
||||
Tool: "Move",
|
||||
Success: false,
|
||||
Result: fmt.Sprintf("日内优化禁止跨天移动:当前仅允许 W%dD%d", week, dayOfWeek),
|
||||
}
|
||||
}
|
||||
}
|
||||
return dispatchReactTool(entries, call)
|
||||
}
|
||||
|
||||
// flattenAndSortDayGroups 把 map 结构摊平成有序切片,便于稳定并发调度。
|
||||
func flattenAndSortDayGroups(groups map[int]map[int]*DayGroup) []*DayGroup {
|
||||
out := make([]*DayGroup, 0)
|
||||
for _, dayMap := range groups {
|
||||
for _, g := range dayMap {
|
||||
if g != nil {
|
||||
out = append(out, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool {
|
||||
if out[i].Week != out[j].Week {
|
||||
return out[i].Week < out[j].Week
|
||||
}
|
||||
return out[i].DayOfWeek < out[j].DayOfWeek
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
// ensureDayResult 确保 results[week][day] 存在并写入值。
|
||||
func ensureDayResult(results map[int]map[int][]model.HybridScheduleEntry, week int, day int, entries []model.HybridScheduleEntry) {
|
||||
if results[week] == nil {
|
||||
results[week] = make(map[int][]model.HybridScheduleEntry)
|
||||
}
|
||||
results[week][day] = entries
|
||||
}
|
||||
|
||||
// deepCopyEntries 深拷贝 HybridScheduleEntry 切片。
|
||||
func deepCopyEntries(src []model.HybridScheduleEntry) []model.HybridScheduleEntry {
|
||||
dst := make([]model.HybridScheduleEntry, len(src))
|
||||
copy(dst, src)
|
||||
return dst
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
package scheduleplan
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// runDailySplitNode 负责“按天拆分 + 标签注入 + 跳过判断”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责把全量 HybridEntries 拆成 DayGroup,供后续并发日内优化;
|
||||
// 2. 负责把 TaskTags(task_item_id -> tag) 注入到条目的 ContextTag;
|
||||
// 3. 负责识别“低收益天”(suggested<=2)并标记 SkipRefine;
|
||||
// 4. 不负责调用模型,不负责并发执行,不负责结果合并。
|
||||
func runDailySplitNode(
|
||||
ctx context.Context,
|
||||
st *SchedulePlanState,
|
||||
emitStage func(stage, detail string),
|
||||
) (*SchedulePlanState, error) {
|
||||
_ = ctx
|
||||
if st == nil || len(st.HybridEntries) == 0 {
|
||||
return st, nil
|
||||
}
|
||||
|
||||
emitStage("schedule_plan.daily_split.start", "正在按天拆分排程并标记优化单元。")
|
||||
|
||||
// 1. 初始化容器:
|
||||
// 1.1 groups 以 week/day 二级索引保存 DayGroup;
|
||||
// 1.2 这么做的目的是后续 daily_refine 可以直接并发遍历,不再重复分组。
|
||||
groups := make(map[int]map[int]*DayGroup)
|
||||
|
||||
// 2. 遍历混合条目,执行“标签注入 + 分组”。
|
||||
for i := range st.HybridEntries {
|
||||
entry := &st.HybridEntries[i]
|
||||
|
||||
// 2.1 仅对 suggested 条目注入 ContextTag。
|
||||
// 2.1.1 existing 条目是固定课表/已落库任务,不参与认知标签优化。
|
||||
// 2.1.2 注入失败时兜底 General,避免后续 prompt 出现空标签。
|
||||
if entry.Status == "suggested" && entry.TaskItemID > 0 {
|
||||
if tag, ok := st.TaskTags[entry.TaskItemID]; ok {
|
||||
entry.ContextTag = normalizeContextTag(tag)
|
||||
} else {
|
||||
entry.ContextTag = "General"
|
||||
}
|
||||
}
|
||||
|
||||
// 2.2 建立分组索引。
|
||||
if groups[entry.Week] == nil {
|
||||
groups[entry.Week] = make(map[int]*DayGroup)
|
||||
}
|
||||
if groups[entry.Week][entry.DayOfWeek] == nil {
|
||||
groups[entry.Week][entry.DayOfWeek] = &DayGroup{
|
||||
Week: entry.Week,
|
||||
DayOfWeek: entry.DayOfWeek,
|
||||
}
|
||||
}
|
||||
groups[entry.Week][entry.DayOfWeek].Entries = append(groups[entry.Week][entry.DayOfWeek].Entries, *entry)
|
||||
}
|
||||
|
||||
// 3. 逐天计算 suggested 数量,标记是否跳过日内优化。
|
||||
//
|
||||
// 3.1 为什么阈值设为 <=2:
|
||||
// 3.1.1 suggested 很少时,模型优化收益通常不足以覆盖请求成本;
|
||||
// 3.1.2 直接跳过可减少无效模型调用和阶段等待。
|
||||
// 3.2 失败策略:
|
||||
// 3.2.1 这里只做内存标记,不会失败;
|
||||
// 3.2.2 即使阈值判断不完美,也只影响优化深度,不影响功能正确性。
|
||||
totalDays := 0
|
||||
skipDays := 0
|
||||
for _, dayMap := range groups {
|
||||
for _, dayGroup := range dayMap {
|
||||
totalDays++
|
||||
suggestedCount := 0
|
||||
for _, e := range dayGroup.Entries {
|
||||
if e.Status == "suggested" {
|
||||
suggestedCount++
|
||||
}
|
||||
}
|
||||
if suggestedCount <= 2 {
|
||||
dayGroup.SkipRefine = true
|
||||
skipDays++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 回填状态,交给后续节点使用。
|
||||
st.DailyGroups = groups
|
||||
emitStage(
|
||||
"schedule_plan.daily_split.done",
|
||||
fmt.Sprintf("已拆分为 %d 天,其中 %d 天跳过日内优化。", totalDays, skipDays),
|
||||
)
|
||||
return st, nil
|
||||
}
|
||||
@@ -1,171 +0,0 @@
|
||||
package scheduleplan
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
einoModel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
// runFinalCheckNode 负责“终审校验 + 总结生成”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责执行物理校验(冲突、节次越界、数量核对);
|
||||
// 2. 负责在校验失败时回退到 MergeSnapshot;
|
||||
// 3. 负责生成最终给用户看的自然语言总结;
|
||||
// 4. 不负责写库(本期只做预览)。
|
||||
func runFinalCheckNode(
|
||||
ctx context.Context,
|
||||
st *SchedulePlanState,
|
||||
chatModel *ark.ChatModel,
|
||||
emitStage func(stage, detail string),
|
||||
) (*SchedulePlanState, error) {
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("schedule plan final check: nil state")
|
||||
}
|
||||
|
||||
emitStage("schedule_plan.final_check.start", "正在进行终审校验。")
|
||||
|
||||
// 1. 先做物理校验。
|
||||
issues := physicsCheck(st)
|
||||
if len(issues) > 0 {
|
||||
emitStage("schedule_plan.final_check.issues", fmt.Sprintf("发现 %d 个问题,已回退到日内优化结果。", len(issues)))
|
||||
// 1.1 回退策略:
|
||||
// 1.1.1 优先回退到 merge 快照(已经过冲突校验);
|
||||
// 1.1.2 若快照为空,保留当前结果继续走总结,保证可返回。
|
||||
if len(st.MergeSnapshot) > 0 {
|
||||
st.HybridEntries = deepCopyEntries(st.MergeSnapshot)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 生成人性化总结。
|
||||
//
|
||||
// 2.1 总结失败不影响主流程;
|
||||
// 2.2 失败时使用兜底文案,保证前端始终有可展示文本。
|
||||
summary, err := generateHumanSummary(ctx, chatModel, st.HybridEntries, st.Constraints, st.WeeklyActionLogs)
|
||||
if err != nil || strings.TrimSpace(summary) == "" {
|
||||
st.FinalSummary = fmt.Sprintf("排程优化完成,共安排了 %d 个任务。", countSuggested(st.HybridEntries))
|
||||
} else {
|
||||
st.FinalSummary = strings.TrimSpace(summary)
|
||||
}
|
||||
|
||||
emitStage("schedule_plan.final_check.done", "终审校验完成。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// physicsCheck 执行物理层面校验。
|
||||
//
|
||||
// 校验项:
|
||||
// 1. 时间冲突:同一 slot 不允许多任务占用;
|
||||
// 2. 节次越界:section 必须落在 1..12 且 from<=to;
|
||||
// 3. 数量核对:suggested 数量应与原始 AllocatedItems 数量一致。
|
||||
func physicsCheck(st *SchedulePlanState) []string {
|
||||
issues := make([]string, 0)
|
||||
if st == nil {
|
||||
return append(issues, "state 为空")
|
||||
}
|
||||
|
||||
// 1. 时间冲突校验。
|
||||
if conflict := detectConflicts(st.HybridEntries); conflict != "" {
|
||||
issues = append(issues, "时间冲突:"+conflict)
|
||||
}
|
||||
|
||||
// 2. 节次越界校验。
|
||||
for _, entry := range st.HybridEntries {
|
||||
if entry.SectionFrom < 1 || entry.SectionTo > 12 || entry.SectionFrom > entry.SectionTo {
|
||||
issues = append(
|
||||
issues,
|
||||
fmt.Sprintf("节次越界:[%s] W%dD%d 第%d-%d节", entry.Name, entry.Week, entry.DayOfWeek, entry.SectionFrom, entry.SectionTo),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 数量一致性校验。
|
||||
// 3.1 判断依据:suggested 表示“待应用任务块”,应与 allocatedItems 数量匹配;
|
||||
// 3.2 若不匹配,可能表示工具调用丢失或重复覆盖。
|
||||
suggestedCount := countSuggested(st.HybridEntries)
|
||||
if suggestedCount != len(st.AllocatedItems) {
|
||||
issues = append(
|
||||
issues,
|
||||
fmt.Sprintf("任务数量不匹配:suggested=%d,原始分配=%d", suggestedCount, len(st.AllocatedItems)),
|
||||
)
|
||||
}
|
||||
|
||||
return issues
|
||||
}
|
||||
|
||||
// countSuggested 统计 suggested 条目数量。
|
||||
func countSuggested(entries []model.HybridScheduleEntry) int {
|
||||
count := 0
|
||||
for _, entry := range entries {
|
||||
if entry.Status == "suggested" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// generateHumanSummary 调用模型生成“用户可读”的总结文案。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只做读模型,不修改任何 state;
|
||||
// 2. 输出纯文本;
|
||||
// 3. 失败时把错误返回给上层,由上层决定兜底文案。
|
||||
func generateHumanSummary(
|
||||
ctx context.Context,
|
||||
chatModel *ark.ChatModel,
|
||||
entries []model.HybridScheduleEntry,
|
||||
constraints []string,
|
||||
actionLogs []string,
|
||||
) (string, error) {
|
||||
if chatModel == nil {
|
||||
return "", fmt.Errorf("final summary model is nil")
|
||||
}
|
||||
entriesJSON, _ := json.Marshal(entries)
|
||||
constraintText := "无"
|
||||
if len(constraints) > 0 {
|
||||
constraintText = strings.Join(constraints, "、")
|
||||
}
|
||||
actionLogText := "无"
|
||||
if len(actionLogs) > 0 {
|
||||
// 1. 只取最后 30 条动作日志,避免上下文无限膨胀。
|
||||
// 2. 周级优化是“渐进式动作链”,取尾部更能体现最终收敛过程。
|
||||
// 3. 这里仅做展示收敛,不改原日志,保证调试信息完整保留在 state 中。
|
||||
start := 0
|
||||
if len(actionLogs) > 30 {
|
||||
start = len(actionLogs) - 30
|
||||
}
|
||||
actionLogText = strings.Join(actionLogs[start:], "\n")
|
||||
}
|
||||
|
||||
userPrompt := fmt.Sprintf(
|
||||
"以下是最终排程方案(JSON):\n%s\n\n用户约束:%s\n\n以下是本次周级优化动作日志(按时间顺序):\n%s\n\n请基于“结果+过程”输出2-3句自然中文总结,重点说明本方案的优点和改进点。",
|
||||
string(entriesJSON),
|
||||
constraintText,
|
||||
actionLogText,
|
||||
)
|
||||
|
||||
resp, err := chatModel.Generate(
|
||||
ctx,
|
||||
[]*schema.Message{
|
||||
schema.SystemMessage(SchedulePlanFinalCheckPrompt),
|
||||
schema.UserMessage(userPrompt),
|
||||
},
|
||||
ark.WithThinking(&arkModel.Thinking{Type: arkModel.ThinkingTypeDisabled}),
|
||||
einoModel.WithTemperature(0.4),
|
||||
einoModel.WithMaxTokens(256),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp == nil {
|
||||
return "", fmt.Errorf("final summary response is nil")
|
||||
}
|
||||
return strings.TrimSpace(resp.Content), nil
|
||||
}
|
||||
@@ -1,210 +0,0 @@
|
||||
package scheduleplan
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
const (
|
||||
// 图节点:意图识别与约束提取
|
||||
schedulePlanGraphNodePlan = "schedule_plan_plan"
|
||||
// 图节点:粗排构建(替代旧 preview + hybridBuild)
|
||||
schedulePlanGraphNodeRoughBuild = "schedule_plan_rough_build"
|
||||
// 图节点:提前退出
|
||||
schedulePlanGraphNodeExit = "schedule_plan_exit"
|
||||
// 图节点:按天拆分并注入上下文标签
|
||||
schedulePlanGraphNodeDailySplit = "schedule_plan_daily_split"
|
||||
// 图节点:小改动快速微调(用于 small scope)
|
||||
schedulePlanGraphNodeQuickRefine = "schedule_plan_quick_refine"
|
||||
// 图节点:并发日内优化
|
||||
schedulePlanGraphNodeDailyRefine = "schedule_plan_daily_refine"
|
||||
// 图节点:合并日内优化结果
|
||||
schedulePlanGraphNodeMerge = "schedule_plan_merge"
|
||||
// 图节点:周级配平优化(单步动作模式,输出阶段状态)
|
||||
schedulePlanGraphNodeWeeklyRefine = "schedule_plan_weekly_refine"
|
||||
// 图节点:终审校验
|
||||
schedulePlanGraphNodeFinalCheck = "schedule_plan_final_check"
|
||||
// 图节点:返回预览结果(不落库)
|
||||
schedulePlanGraphNodeReturnPreview = "schedule_plan_return_preview"
|
||||
)
|
||||
|
||||
// SchedulePlanGraphRunInput 是执行“智能排程 graph”所需输入。
|
||||
//
|
||||
// 字段说明:
|
||||
// 1. Extra:前端附加参数(重点是 task_class_ids);
|
||||
// 2. ChatHistory:支持连续对话微调;
|
||||
// 3. OutChan/ModelName:保留兼容字段(当前 weekly refine 主要输出阶段状态);
|
||||
// 4. DailyRefineConcurrency/WeeklyAdjustBudget:可选运行参数覆盖。
|
||||
type SchedulePlanGraphRunInput struct {
|
||||
Model *ark.ChatModel
|
||||
State *SchedulePlanState
|
||||
Deps SchedulePlanToolDeps
|
||||
UserMessage string
|
||||
Extra map[string]any
|
||||
ChatHistory []*schema.Message
|
||||
EmitStage func(stage, detail string)
|
||||
|
||||
OutChan chan<- string
|
||||
ModelName string
|
||||
|
||||
DailyRefineConcurrency int
|
||||
WeeklyAdjustBudget int
|
||||
}
|
||||
|
||||
// RunSchedulePlanGraph 执行“智能排程”图编排。
|
||||
//
|
||||
// 当前链路:
|
||||
// START
|
||||
// -> plan
|
||||
// -> roughBuild
|
||||
// -> (len(task_class_ids)>=2 ? dailySplit -> dailyRefine -> merge : weeklyRefine)
|
||||
// -> finalCheck
|
||||
// -> returnPreview
|
||||
// -> END
|
||||
//
|
||||
// 说明:
|
||||
// 1. exit 分支可从 plan/roughBuild 直接提前终止;
|
||||
// 2. 本文件只负责“连线与分支”,节点内业务都在 nodes/daily/weekly 文件中。
|
||||
func RunSchedulePlanGraph(ctx context.Context, input SchedulePlanGraphRunInput) (*SchedulePlanState, error) {
|
||||
// 1. 启动前硬校验。
|
||||
if input.Model == nil {
|
||||
return nil, errors.New("schedule plan graph: model is nil")
|
||||
}
|
||||
if input.State == nil {
|
||||
return nil, errors.New("schedule plan graph: state is nil")
|
||||
}
|
||||
if err := input.Deps.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 注入运行时配置(可选覆盖)。
|
||||
if input.DailyRefineConcurrency > 0 {
|
||||
input.State.DailyRefineConcurrency = input.DailyRefineConcurrency
|
||||
}
|
||||
if input.WeeklyAdjustBudget > 0 {
|
||||
input.State.WeeklyAdjustBudget = input.WeeklyAdjustBudget
|
||||
}
|
||||
|
||||
emitStage := func(stage, detail string) {
|
||||
if input.EmitStage != nil {
|
||||
input.EmitStage(stage, detail)
|
||||
}
|
||||
}
|
||||
|
||||
runner := newSchedulePlanRunner(
|
||||
input.Model,
|
||||
input.Deps,
|
||||
emitStage,
|
||||
input.UserMessage,
|
||||
input.Extra,
|
||||
input.ChatHistory,
|
||||
input.OutChan,
|
||||
input.ModelName,
|
||||
input.State.DailyRefineConcurrency,
|
||||
)
|
||||
|
||||
graph := compose.NewGraph[*SchedulePlanState, *SchedulePlanState]()
|
||||
|
||||
// 3. 注册节点。
|
||||
if err := graph.AddLambdaNode(schedulePlanGraphNodePlan, compose.InvokableLambda(runner.planNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(schedulePlanGraphNodeRoughBuild, compose.InvokableLambda(runner.roughBuildNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(schedulePlanGraphNodeExit, compose.InvokableLambda(runner.exitNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(schedulePlanGraphNodeDailySplit, compose.InvokableLambda(runner.dailySplitNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(schedulePlanGraphNodeQuickRefine, compose.InvokableLambda(runner.quickRefineNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(schedulePlanGraphNodeDailyRefine, compose.InvokableLambda(runner.dailyRefineNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(schedulePlanGraphNodeMerge, compose.InvokableLambda(runner.mergeNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(schedulePlanGraphNodeWeeklyRefine, compose.InvokableLambda(runner.weeklyRefineNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(schedulePlanGraphNodeFinalCheck, compose.InvokableLambda(runner.finalCheckNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(schedulePlanGraphNodeReturnPreview, compose.InvokableLambda(runner.returnPreviewNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. 连线:START -> plan
|
||||
if err := graph.AddEdge(compose.START, schedulePlanGraphNodePlan); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 5. plan 分支:roughBuild | exit
|
||||
if err := graph.AddBranch(schedulePlanGraphNodePlan, compose.NewGraphBranch(
|
||||
runner.nextAfterPlan,
|
||||
map[string]bool{
|
||||
schedulePlanGraphNodeRoughBuild: true,
|
||||
schedulePlanGraphNodeExit: true,
|
||||
},
|
||||
)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 6. roughBuild 分支:dailySplit | weeklyRefine | exit
|
||||
if err := graph.AddBranch(schedulePlanGraphNodeRoughBuild, compose.NewGraphBranch(
|
||||
runner.nextAfterRoughBuild,
|
||||
map[string]bool{
|
||||
schedulePlanGraphNodeDailySplit: true,
|
||||
schedulePlanGraphNodeQuickRefine: true,
|
||||
schedulePlanGraphNodeWeeklyRefine: true,
|
||||
schedulePlanGraphNodeExit: true,
|
||||
},
|
||||
)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 7. 固定边:quickRefine -> weeklyRefine;dailySplit -> dailyRefine -> merge -> weeklyRefine -> finalCheck -> returnPreview -> END
|
||||
if err := graph.AddEdge(schedulePlanGraphNodeQuickRefine, schedulePlanGraphNodeWeeklyRefine); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(schedulePlanGraphNodeDailySplit, schedulePlanGraphNodeDailyRefine); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(schedulePlanGraphNodeDailyRefine, schedulePlanGraphNodeMerge); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(schedulePlanGraphNodeMerge, schedulePlanGraphNodeWeeklyRefine); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(schedulePlanGraphNodeWeeklyRefine, schedulePlanGraphNodeFinalCheck); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(schedulePlanGraphNodeFinalCheck, schedulePlanGraphNodeReturnPreview); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(schedulePlanGraphNodeReturnPreview, compose.END); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(schedulePlanGraphNodeExit, compose.END); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 8. 编译并执行。
|
||||
// 路径最多约 8~9 个节点,保守预留 20 步避免误判。
|
||||
runnable, err := graph.Compile(ctx,
|
||||
compose.WithGraphName("SchedulePlanGraph"),
|
||||
compose.WithMaxRunSteps(20),
|
||||
compose.WithNodeTriggerMode(compose.AnyPredecessor),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return runnable.Invoke(ctx, input.State)
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
package scheduleplan
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
)
|
||||
|
||||
// runMergeNode 负责“合并日内结果 + 冲突校验 + 回退快照”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责把 DailyResults 合并回全量 HybridEntries;
|
||||
// 2. 负责执行时间冲突检测;
|
||||
// 3. 负责在冲突时回退原始数据;
|
||||
// 4. 负责产出 MergeSnapshot,供 final_check 失败时回退。
|
||||
func runMergeNode(
|
||||
ctx context.Context,
|
||||
st *SchedulePlanState,
|
||||
emitStage func(stage, detail string),
|
||||
) (*SchedulePlanState, error) {
|
||||
_ = ctx
|
||||
if st == nil || len(st.DailyResults) == 0 {
|
||||
return st, nil
|
||||
}
|
||||
|
||||
emitStage("schedule_plan.merge.start", "正在合并日内优化结果。")
|
||||
|
||||
// 1. 先保存 merge 前原始数据,作为冲突时的第一层回退兜底。
|
||||
originalEntries := deepCopyEntries(st.HybridEntries)
|
||||
|
||||
// 2. 展平 daily results。
|
||||
merged := make([]model.HybridScheduleEntry, 0)
|
||||
for _, dayMap := range st.DailyResults {
|
||||
for _, dayEntries := range dayMap {
|
||||
merged = append(merged, dayEntries...)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 冲突校验。
|
||||
//
|
||||
// 3.1 判断依据:同一 (week, day, section) 只能有一个条目占用;
|
||||
// 3.2 失败处理:一旦冲突,整批回退到 merge 前原始结果;
|
||||
// 3.3 回退策略:回退后仍继续链路,避免请求直接失败。
|
||||
if conflict := detectConflicts(merged); conflict != "" {
|
||||
st.HybridEntries = originalEntries
|
||||
emitStage("schedule_plan.merge.conflict", fmt.Sprintf("检测到冲突并回退:%s", conflict))
|
||||
} else {
|
||||
st.HybridEntries = merged
|
||||
emitStage("schedule_plan.merge.done", fmt.Sprintf("合并完成,共 %d 个条目。", len(merged)))
|
||||
}
|
||||
|
||||
// 4. 无论是否冲突,都生成“可回退快照”。
|
||||
st.MergeSnapshot = deepCopyEntries(st.HybridEntries)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// detectConflicts 检测条目是否存在时间冲突。
|
||||
//
|
||||
// 返回语义:
|
||||
// 1. 返回空字符串:无冲突;
|
||||
// 2. 返回非空字符串:冲突描述,可直接用于日志/阶段提示。
|
||||
func detectConflicts(entries []model.HybridScheduleEntry) string {
|
||||
type slotKey struct {
|
||||
week, day, section int
|
||||
}
|
||||
occupied := make(map[slotKey]string)
|
||||
for _, entry := range entries {
|
||||
// 1. 仅“阻塞建议任务”的条目参与冲突校验。
|
||||
// 2. 可嵌入且当前未占用的课程槽位不应被判定为冲突。
|
||||
if !entryBlocksSuggested(entry) {
|
||||
continue
|
||||
}
|
||||
for section := entry.SectionFrom; section <= entry.SectionTo; section++ {
|
||||
key := slotKey{week: entry.Week, day: entry.DayOfWeek, section: section}
|
||||
if prevName, exists := occupied[key]; exists {
|
||||
return fmt.Sprintf(
|
||||
"W%dD%d 第%d节 冲突:[%s] 与 [%s]",
|
||||
entry.Week, entry.DayOfWeek, section, prevName, entry.Name,
|
||||
)
|
||||
}
|
||||
occupied[key] = entry.Name
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,767 +0,0 @@
|
||||
package scheduleplan
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
einoModel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
// schedulePlanIntentOutput 是 plan 节点要求模型返回的结构化结果。
|
||||
//
|
||||
// 兼容说明:
|
||||
// 1. 新主语义是 task_class_ids(数组);
|
||||
// 2. 为兼容旧 prompt/旧缓存输出,保留 task_class_id(单值)兜底解析;
|
||||
// 3. TaskTags 的 key 兼容两种写法:
|
||||
// 3.1 推荐:task_item_id(例如 "12");
|
||||
// 3.2 兼容:任务名称(例如 "高数复习")。
|
||||
type schedulePlanIntentOutput struct {
|
||||
Intent string `json:"intent"`
|
||||
Constraints []string `json:"constraints"`
|
||||
TaskClassIDs []int `json:"task_class_ids"`
|
||||
TaskClassID int `json:"task_class_id"`
|
||||
Strategy string `json:"strategy"`
|
||||
TaskTags map[string]string `json:"task_tags"`
|
||||
Restart bool `json:"restart"`
|
||||
AdjustmentScope string `json:"adjustment_scope"`
|
||||
Reason string `json:"reason"`
|
||||
Confidence float64 `json:"confidence"`
|
||||
}
|
||||
|
||||
// runPlanNode 负责“识别排程意图 + 提取约束 + 收敛任务类 ID”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责把用户自然语言和 extra 参数收敛为统一状态;
|
||||
// 2. 负责输出后续节点需要的最小上下文(TaskClassIDs/约束/策略/标签);
|
||||
// 3. 不负责调用粗排算法,不负责写库。
|
||||
func runPlanNode(
|
||||
ctx context.Context,
|
||||
st *SchedulePlanState,
|
||||
chatModel *ark.ChatModel,
|
||||
userMessage string,
|
||||
extra map[string]any,
|
||||
chatHistory []*schema.Message,
|
||||
emitStage func(stage, detail string),
|
||||
) (*SchedulePlanState, error) {
|
||||
if st == nil {
|
||||
return nil, errors.New("schedule plan graph: nil state in plan node")
|
||||
}
|
||||
st.RestartRequested = false
|
||||
st.AdjustmentReason = ""
|
||||
st.AdjustmentConfidence = 0
|
||||
st.AdjustmentScope = schedulePlanAdjustmentScopeLarge
|
||||
|
||||
emitStage("schedule_plan.plan.analyzing", "正在分析你的排程需求。")
|
||||
|
||||
// 1. 先收敛 extra 中显式传入的任务类 ID(优先级高于模型推断)。
|
||||
// 1.1 先读 task_class_ids 数组;
|
||||
// 1.2 再兼容读取单值 task_class_id;
|
||||
// 1.3 最后统一做过滤 + 去重,防止非法值或重复值污染状态机。
|
||||
if extra != nil {
|
||||
mergedIDs := make([]int, 0, len(st.TaskClassIDs)+2)
|
||||
mergedIDs = append(mergedIDs, st.TaskClassIDs...)
|
||||
if tcIDs, ok := ExtraIntSlice(extra, "task_class_ids"); ok {
|
||||
mergedIDs = append(mergedIDs, tcIDs...)
|
||||
}
|
||||
if tcID, ok := ExtraInt(extra, "task_class_id"); ok && tcID > 0 {
|
||||
mergedIDs = append(mergedIDs, tcID)
|
||||
}
|
||||
st.TaskClassIDs = normalizeTaskClassIDs(mergedIDs)
|
||||
}
|
||||
// 1.4 若本轮请求没带 task_class_ids,但会话里存在上一次排程快照,则用快照中的任务类兜底。
|
||||
// 1.4.1 这样用户可以直接说“把周三晚上的高数挪到周五”,无需每轮都重复传任务类集合;
|
||||
// 1.4.2 失败兜底:若快照也没有任务类,后续按原逻辑处理(可能提前退出并提示补参)。
|
||||
if len(st.TaskClassIDs) == 0 && len(st.PreviousTaskClassIDs) > 0 {
|
||||
st.TaskClassIDs = normalizeTaskClassIDs(append([]int(nil), st.PreviousTaskClassIDs...))
|
||||
}
|
||||
|
||||
// 2. 识别“是否为连续对话微调”场景。
|
||||
// 2.1 只做历史探测,不做历史改写;
|
||||
// 2.2 探测失败不影响主链路,只是少一个 prompt hint。
|
||||
if st.HasPreviousPreview && len(st.PreviousHybridEntries) > 0 {
|
||||
st.IsAdjustment = true
|
||||
st.AdjustmentScope = schedulePlanAdjustmentScopeMedium
|
||||
}
|
||||
previousPlan := extractPreviousPlanFromHistory(chatHistory)
|
||||
if previousPlan != "" {
|
||||
st.PreviousPlanJSON = previousPlan
|
||||
st.IsAdjustment = true
|
||||
st.AdjustmentScope = schedulePlanAdjustmentScopeMedium
|
||||
}
|
||||
|
||||
// 3. 组装模型提示词。
|
||||
adjustmentHint := ""
|
||||
if st.IsAdjustment {
|
||||
adjustmentHint = "\n注意:这是对已有排程的微调请求,请重点抽取本次新增或变更的约束。"
|
||||
}
|
||||
prompt := fmt.Sprintf(
|
||||
"当前时间(北京时间):%s\n用户输入:%s%s\n\n请提取排程意图与约束。",
|
||||
st.RequestNowText,
|
||||
strings.TrimSpace(userMessage),
|
||||
adjustmentHint,
|
||||
)
|
||||
|
||||
// 4. 调模型拿结构化输出。
|
||||
// 4.1 如果失败但已经有 TaskClassIDs,则降级继续;
|
||||
// 4.2 如果失败且没有任务类 ID,直接给出可执行错误提示。
|
||||
raw, callErr := callScheduleModelForJSON(ctx, chatModel, SchedulePlanIntentPrompt, prompt, 256)
|
||||
if callErr != nil {
|
||||
if len(st.TaskClassIDs) > 0 {
|
||||
st.UserIntent = strings.TrimSpace(userMessage)
|
||||
emitStage("schedule_plan.plan.fallback", "意图识别失败,已使用请求参数兜底继续。")
|
||||
return st, nil
|
||||
}
|
||||
st.FinalSummary = "抱歉,我没拿到有效的任务类信息。请在请求中传入 task_class_ids。"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
parsed, parseErr := parseScheduleJSON[schedulePlanIntentOutput](raw)
|
||||
if parseErr != nil {
|
||||
if len(st.TaskClassIDs) > 0 {
|
||||
st.UserIntent = strings.TrimSpace(userMessage)
|
||||
emitStage("schedule_plan.plan.fallback", "模型返回解析失败,已使用请求参数兜底继续。")
|
||||
return st, nil
|
||||
}
|
||||
st.FinalSummary = "抱歉,我没能解析排程意图。请重试,或直接传入 task_class_ids。"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 5. 回填基础字段。
|
||||
st.UserIntent = strings.TrimSpace(parsed.Intent)
|
||||
if st.UserIntent == "" {
|
||||
st.UserIntent = strings.TrimSpace(userMessage)
|
||||
}
|
||||
if len(parsed.Constraints) > 0 {
|
||||
st.Constraints = parsed.Constraints
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(parsed.Strategy), "rapid") {
|
||||
st.Strategy = "rapid"
|
||||
}
|
||||
st.RestartRequested = parsed.Restart
|
||||
st.AdjustmentScope = normalizeAdjustmentScope(parsed.AdjustmentScope)
|
||||
st.AdjustmentReason = strings.TrimSpace(parsed.Reason)
|
||||
st.AdjustmentConfidence = clampAdjustmentConfidence(parsed.Confidence)
|
||||
|
||||
// 5.1 分级语义兜底:
|
||||
// 5.1.1 非微调请求不走 small/medium,强制按 large 进入完整排程;
|
||||
// 5.1.2 微调请求默认至少走 medium,避免 scope 缺失时误判;
|
||||
// 5.1.3 restart=true 时强制重排并清空历史快照承接。
|
||||
if !st.IsAdjustment {
|
||||
st.AdjustmentScope = schedulePlanAdjustmentScopeLarge
|
||||
} else if st.AdjustmentScope == "" {
|
||||
st.AdjustmentScope = schedulePlanAdjustmentScopeMedium
|
||||
}
|
||||
if st.RestartRequested {
|
||||
st.IsAdjustment = false
|
||||
st.AdjustmentScope = schedulePlanAdjustmentScopeLarge
|
||||
st.clearPreviousPreviewContext()
|
||||
}
|
||||
|
||||
// 6. 合并任务类 ID(新字段 + 旧字段双兼容)。
|
||||
// 6.1 先拼接已有值与模型输出;
|
||||
// 6.2 再统一清洗,保证后续节点使用稳定语义。
|
||||
mergedIDs := make([]int, 0, len(st.TaskClassIDs)+len(parsed.TaskClassIDs)+1)
|
||||
mergedIDs = append(mergedIDs, st.TaskClassIDs...)
|
||||
mergedIDs = append(mergedIDs, parsed.TaskClassIDs...)
|
||||
if parsed.TaskClassID > 0 {
|
||||
mergedIDs = append(mergedIDs, parsed.TaskClassID)
|
||||
}
|
||||
st.TaskClassIDs = normalizeTaskClassIDs(mergedIDs)
|
||||
|
||||
// 7. 回填任务标签映射(给 daily_split 注入 context_tag 用)。
|
||||
// 7.1 TaskTags(按 task_item_id)优先;
|
||||
// 7.2 无法转成 ID 的 key 先存到 TaskTagHintsByName,等 roughBuild 阶段再映射;
|
||||
// 7.3 单条标签解析失败不影响主流程。
|
||||
if st.TaskTags == nil {
|
||||
st.TaskTags = make(map[int]string)
|
||||
}
|
||||
if st.TaskTagHintsByName == nil {
|
||||
st.TaskTagHintsByName = make(map[string]string)
|
||||
}
|
||||
for rawKey, rawTag := range parsed.TaskTags {
|
||||
tag := normalizeContextTag(rawTag)
|
||||
key := strings.TrimSpace(rawKey)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if id, convErr := strconv.Atoi(key); convErr == nil && id > 0 {
|
||||
st.TaskTags[id] = tag
|
||||
continue
|
||||
}
|
||||
st.TaskTagHintsByName[key] = tag
|
||||
}
|
||||
|
||||
emitStage(
|
||||
"schedule_plan.plan.done",
|
||||
fmt.Sprintf(
|
||||
"已识别排程意图,任务类数量=%d,微调=%t,力度=%s,重排=%t。",
|
||||
len(st.TaskClassIDs),
|
||||
st.IsAdjustment,
|
||||
st.AdjustmentScope,
|
||||
st.RestartRequested,
|
||||
),
|
||||
)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// selectNextAfterPlan 根据 plan 节点结果决定下一步。
|
||||
//
|
||||
// 分支规则:
|
||||
// 1. 如果 FinalSummary 已经有内容,说明已确定要提前退出 -> exit;
|
||||
// 2. 如果任务类为空,说明无法继续构建方案 -> exit;
|
||||
// 3. 其余情况 -> roughBuild。
|
||||
func selectNextAfterPlan(st *SchedulePlanState) string {
|
||||
if st == nil {
|
||||
return schedulePlanGraphNodeExit
|
||||
}
|
||||
if strings.TrimSpace(st.FinalSummary) != "" {
|
||||
return schedulePlanGraphNodeExit
|
||||
}
|
||||
if len(st.TaskClassIDs) == 0 {
|
||||
return schedulePlanGraphNodeExit
|
||||
}
|
||||
return schedulePlanGraphNodeRoughBuild
|
||||
}
|
||||
|
||||
// runRoughBuildNode 负责“一次性完成粗排结果构建”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 调用多任务类混排能力,生成 HybridEntries + AllocatedItems;
|
||||
// 2. 把 HybridEntries 转成 CandidatePlans,便于后续预览输出;
|
||||
// 3. 不做 daily/weekly 优化本身,只提供下游输入。
|
||||
func runRoughBuildNode(
|
||||
ctx context.Context,
|
||||
st *SchedulePlanState,
|
||||
deps SchedulePlanToolDeps,
|
||||
emitStage func(stage, detail string),
|
||||
) (*SchedulePlanState, error) {
|
||||
if st == nil {
|
||||
return nil, errors.New("schedule plan graph: nil state in roughBuild node")
|
||||
}
|
||||
if deps.HybridScheduleWithPlanMulti == nil {
|
||||
return nil, errors.New("schedule plan graph: HybridScheduleWithPlanMulti dependency not injected")
|
||||
}
|
||||
|
||||
// 1. 清洗并校验任务类 ID。
|
||||
// 1.1 统一在节点入口做一次最终收敛,避免上游遗漏导致语义漂移;
|
||||
// 1.2 若最终仍为空,直接结束,避免无意义调用下游服务。
|
||||
taskClassIDs := normalizeTaskClassIDs(st.TaskClassIDs)
|
||||
// 1.3 连续对话兜底:若本轮任务类为空且命中历史快照,则回退到上轮任务类集合。
|
||||
if len(taskClassIDs) == 0 && st.IsAdjustment && len(st.PreviousTaskClassIDs) > 0 {
|
||||
taskClassIDs = normalizeTaskClassIDs(append([]int(nil), st.PreviousTaskClassIDs...))
|
||||
}
|
||||
if len(taskClassIDs) == 0 {
|
||||
st.FinalSummary = "缺少有效的任务类 ID,无法生成排程方案。请传入 task_class_ids。"
|
||||
return st, nil
|
||||
}
|
||||
st.TaskClassIDs = taskClassIDs
|
||||
|
||||
// 2. 连续对话微调优先复用上一版混合日程作为起点,避免“每轮都重新粗排”。
|
||||
// 2.1 触发条件:IsAdjustment=true 且 PreviousHybridEntries 非空;
|
||||
// 2.2 失败兜底:若快照不完整(例如 AllocatedItems 为空),会构造最小占位任务块,保持下游校验可运行;
|
||||
// 2.3 回退策略:若没有可复用快照,再走全量粗排构建路径。
|
||||
canReusePreviousPlan := st.IsAdjustment &&
|
||||
!st.RestartRequested &&
|
||||
len(st.PreviousHybridEntries) > 0 &&
|
||||
sameTaskClassSet(taskClassIDs, st.PreviousTaskClassIDs)
|
||||
if canReusePreviousPlan {
|
||||
emitStage("schedule_plan.rough_build.reuse_previous", "检测到连续对话微调,复用上一版排程作为优化起点。")
|
||||
st.HybridEntries = deepCopyEntries(st.PreviousHybridEntries)
|
||||
st.CandidatePlans = deepCopyWeekSchedules(st.PreviousCandidatePlans)
|
||||
if len(st.CandidatePlans) == 0 {
|
||||
st.CandidatePlans = hybridEntriesToWeekSchedules(st.HybridEntries)
|
||||
}
|
||||
st.AllocatedItems = deepCopyTaskClassItems(st.PreviousAllocatedItems)
|
||||
if len(st.AllocatedItems) == 0 {
|
||||
st.AllocatedItems = buildAllocatedItemsFromHybridEntries(st.HybridEntries)
|
||||
}
|
||||
|
||||
// 2.2 复用模式下同样尝试解析窗口边界,保证周级 Move 约束仍然有效。
|
||||
if deps.ResolvePlanningWindow != nil {
|
||||
startWeek, startDay, endWeek, endDay, windowErr := deps.ResolvePlanningWindow(ctx, st.UserID, taskClassIDs)
|
||||
if windowErr != nil {
|
||||
st.FinalSummary = fmt.Sprintf("解析排程窗口失败:%s。", windowErr.Error())
|
||||
return st, nil
|
||||
}
|
||||
st.HasPlanningWindow = true
|
||||
st.PlanStartWeek = startWeek
|
||||
st.PlanStartDay = startDay
|
||||
st.PlanEndWeek = endWeek
|
||||
st.PlanEndDay = endDay
|
||||
}
|
||||
|
||||
st.MergeSnapshot = deepCopyEntries(st.HybridEntries)
|
||||
suggestedCount := 0
|
||||
for _, e := range st.HybridEntries {
|
||||
if e.Status == "suggested" {
|
||||
suggestedCount++
|
||||
}
|
||||
}
|
||||
emitStage(
|
||||
"schedule_plan.rough_build.done",
|
||||
fmt.Sprintf("已复用历史方案,条目总数=%d,可优化条目=%d。", len(st.HybridEntries), suggestedCount),
|
||||
)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
emitStage("schedule_plan.rough_build.building", "正在构建粗排候选方案。")
|
||||
|
||||
// 3. 调用服务层统一能力构建混合日程。
|
||||
// 3.1 该能力内部会完成“多任务类粗排 + 既有日程合并”;
|
||||
// 3.2 这里不再拆成 preview/hybrid 两段,避免跨节点重复计算。
|
||||
entries, allocatedItems, err := deps.HybridScheduleWithPlanMulti(ctx, st.UserID, taskClassIDs)
|
||||
if err != nil {
|
||||
st.FinalSummary = fmt.Sprintf("构建粗排方案失败:%s。", err.Error())
|
||||
return st, nil
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
st.FinalSummary = "没有生成可优化的排程条目,请检查任务类时间范围或课表占用。"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 4. 回填状态。
|
||||
st.HybridEntries = entries
|
||||
st.AllocatedItems = allocatedItems
|
||||
st.CandidatePlans = hybridEntriesToWeekSchedules(entries)
|
||||
|
||||
// 4.1 解析全局排程窗口(可选依赖)。
|
||||
// 4.1.1 目的:给周级 Move 增加“首尾不足一周”的硬边界校验;
|
||||
// 4.1.2 失败策略:若依赖已注入但解析失败,直接结束本次排程,避免带着错误窗口继续优化。
|
||||
if deps.ResolvePlanningWindow != nil {
|
||||
startWeek, startDay, endWeek, endDay, windowErr := deps.ResolvePlanningWindow(ctx, st.UserID, taskClassIDs)
|
||||
if windowErr != nil {
|
||||
st.FinalSummary = fmt.Sprintf("解析排程窗口失败:%s。", windowErr.Error())
|
||||
return st, nil
|
||||
}
|
||||
st.HasPlanningWindow = true
|
||||
st.PlanStartWeek = startWeek
|
||||
st.PlanStartDay = startDay
|
||||
st.PlanEndWeek = endWeek
|
||||
st.PlanEndDay = endDay
|
||||
}
|
||||
|
||||
// 4.2 记录 merge 快照:
|
||||
// 4.2.1 单任务类路径可直接作为 final_check 回退基线;
|
||||
// 4.2.2 多任务类路径后续 merge 节点会覆盖成“日内优化后快照”。
|
||||
st.MergeSnapshot = deepCopyEntries(entries)
|
||||
|
||||
// 5. 把“按名称提示的标签”尽可能映射到 task_item_id。
|
||||
// 5.1 目的:后续 daily_split 统一按 task_item_id 维度写入 context_tag;
|
||||
// 5.2 失败策略:映射不上不报错,后续默认走 General 标签。
|
||||
if st.TaskTags == nil {
|
||||
st.TaskTags = make(map[int]string)
|
||||
}
|
||||
if len(st.TaskTagHintsByName) > 0 {
|
||||
for i := range st.HybridEntries {
|
||||
entry := &st.HybridEntries[i]
|
||||
if entry.Status != "suggested" || entry.TaskItemID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, exists := st.TaskTags[entry.TaskItemID]; exists {
|
||||
continue
|
||||
}
|
||||
if tag, ok := st.TaskTagHintsByName[entry.Name]; ok {
|
||||
st.TaskTags[entry.TaskItemID] = normalizeContextTag(tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
suggestedCount := 0
|
||||
for _, e := range entries {
|
||||
if e.Status == "suggested" {
|
||||
suggestedCount++
|
||||
}
|
||||
}
|
||||
emitStage(
|
||||
"schedule_plan.rough_build.done",
|
||||
fmt.Sprintf("粗排构建完成,条目总数=%d,可优化条目=%d。", len(entries), suggestedCount),
|
||||
)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// callScheduleModelForJSON 调用模型并要求返回 JSON。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 仅负责模型调用参数装配,不做业务字段解释;
|
||||
// 2. 统一关闭 thinking,减少路由/抽取场景的延迟和 token 开销。
|
||||
func callScheduleModelForJSON(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, maxTokens int) (string, error) {
|
||||
if chatModel == nil {
|
||||
return "", errors.New("schedule plan: model is nil")
|
||||
}
|
||||
|
||||
messages := []*schema.Message{
|
||||
schema.SystemMessage(systemPrompt),
|
||||
schema.UserMessage(userPrompt),
|
||||
}
|
||||
opts := []einoModel.Option{
|
||||
ark.WithThinking(&arkModel.Thinking{Type: arkModel.ThinkingTypeDisabled}),
|
||||
einoModel.WithTemperature(0),
|
||||
}
|
||||
if maxTokens > 0 {
|
||||
opts = append(opts, einoModel.WithMaxTokens(maxTokens))
|
||||
}
|
||||
|
||||
resp, err := chatModel.Generate(ctx, messages, opts...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp == nil {
|
||||
return "", errors.New("模型返回为空")
|
||||
}
|
||||
content := strings.TrimSpace(resp.Content)
|
||||
if content == "" {
|
||||
return "", errors.New("模型返回内容为空")
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// parseScheduleJSON 解析模型返回的 JSON 内容。
|
||||
//
|
||||
// 兼容策略:
|
||||
// 1. 兼容 ```json ... ``` 包裹;
|
||||
// 2. 兼容模型在 JSON 前后带解释文本(提取最外层对象)。
|
||||
func parseScheduleJSON[T any](raw string) (*T, error) {
|
||||
clean := strings.TrimSpace(raw)
|
||||
if clean == "" {
|
||||
return nil, errors.New("empty response")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(clean, "```") {
|
||||
clean = strings.TrimPrefix(clean, "```json")
|
||||
clean = strings.TrimPrefix(clean, "```")
|
||||
clean = strings.TrimSuffix(clean, "```")
|
||||
clean = strings.TrimSpace(clean)
|
||||
}
|
||||
|
||||
var out T
|
||||
if err := json.Unmarshal([]byte(clean), &out); err == nil {
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
start := strings.Index(clean, "{")
|
||||
end := strings.LastIndex(clean, "}")
|
||||
if start == -1 || end == -1 || end <= start {
|
||||
return nil, fmt.Errorf("no json object found in: %s", clean)
|
||||
}
|
||||
obj := clean[start : end+1]
|
||||
if err := json.Unmarshal([]byte(obj), &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
// extractPreviousPlanFromHistory 从对话历史中提取最近一次排程结果文本。
|
||||
func extractPreviousPlanFromHistory(history []*schema.Message) string {
|
||||
if len(history) == 0 {
|
||||
return ""
|
||||
}
|
||||
for i := len(history) - 1; i >= 0; i-- {
|
||||
msg := history[i]
|
||||
if msg == nil || msg.Role != schema.Assistant {
|
||||
continue
|
||||
}
|
||||
content := strings.TrimSpace(msg.Content)
|
||||
if strings.Contains(content, "排程完成") || strings.Contains(content, "已成功安排") {
|
||||
return content
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// runReturnPreviewNode 负责把优化后的 HybridEntries 转成“前端可直接展示”的预览结构。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 把 suggested 结果回填到 AllocatedItems,便于后续确认后直接落库;
|
||||
// 2. 生成 CandidatePlans;
|
||||
// 3. 生成最终文案;
|
||||
// 4. 不执行实际写库。
|
||||
func runReturnPreviewNode(
|
||||
ctx context.Context,
|
||||
st *SchedulePlanState,
|
||||
emitStage func(stage, detail string),
|
||||
) (*SchedulePlanState, error) {
|
||||
_ = ctx
|
||||
if st == nil {
|
||||
return nil, errors.New("schedule plan graph: nil state in returnPreview node")
|
||||
}
|
||||
|
||||
emitStage("schedule_plan.preview_return.building", "正在生成优化后的排程预览。")
|
||||
|
||||
// 1. 把 HybridEntries 中 suggested 的最终位置回填到 AllocatedItems。
|
||||
suggestedMap := make(map[int]*model.HybridScheduleEntry)
|
||||
for i := range st.HybridEntries {
|
||||
e := &st.HybridEntries[i]
|
||||
if e.Status == "suggested" && e.TaskItemID > 0 {
|
||||
suggestedMap[e.TaskItemID] = e
|
||||
}
|
||||
}
|
||||
for i := range st.AllocatedItems {
|
||||
item := &st.AllocatedItems[i]
|
||||
if entry, ok := suggestedMap[item.ID]; ok && item.EmbeddedTime != nil {
|
||||
item.EmbeddedTime.Week = entry.Week
|
||||
item.EmbeddedTime.DayOfWeek = entry.DayOfWeek
|
||||
item.EmbeddedTime.SectionFrom = entry.SectionFrom
|
||||
item.EmbeddedTime.SectionTo = entry.SectionTo
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 生成前端预览结构。
|
||||
st.CandidatePlans = hybridEntriesToWeekSchedules(st.HybridEntries)
|
||||
|
||||
// 3. 生成最终摘要:
|
||||
// 3.1 优先保留 final_check 的输出;
|
||||
// 3.2 若没有 final_check 输出,则回退 weekly refine 摘要;
|
||||
// 3.3 都没有时给兜底文案。
|
||||
if strings.TrimSpace(st.FinalSummary) == "" {
|
||||
if strings.TrimSpace(st.ReactSummary) != "" {
|
||||
st.FinalSummary = st.ReactSummary
|
||||
} else {
|
||||
st.FinalSummary = fmt.Sprintf("排程优化完成,共 %d 个任务已安排,请确认后应用。", len(suggestedMap))
|
||||
}
|
||||
}
|
||||
st.Completed = true
|
||||
|
||||
emitStage("schedule_plan.preview_return.done", "排程预览已生成,等待你确认。")
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// buildAllocatedItemsFromHybridEntries 根据 suggested 条目构造最小可用的任务块快照。
|
||||
//
|
||||
// 设计目的:
|
||||
// 1. 连续微调复用历史方案时,若缓存里没有 AllocatedItems,仍然保证 final_check 的数量核对可运行;
|
||||
// 2. return_preview 仍可依据 TaskItemID 回填最终 embedded_time;
|
||||
// 3. 该函数只做“兜底构造”,不替代真实粗排输出。
|
||||
func buildAllocatedItemsFromHybridEntries(entries []model.HybridScheduleEntry) []model.TaskClassItem {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
items := make([]model.TaskClassItem, 0)
|
||||
for _, entry := range entries {
|
||||
if entry.Status != "suggested" {
|
||||
continue
|
||||
}
|
||||
embedded := &model.TargetTime{
|
||||
Week: entry.Week,
|
||||
DayOfWeek: entry.DayOfWeek,
|
||||
SectionFrom: entry.SectionFrom,
|
||||
SectionTo: entry.SectionTo,
|
||||
}
|
||||
taskID := entry.TaskItemID
|
||||
items = append(items, model.TaskClassItem{
|
||||
ID: taskID,
|
||||
EmbeddedTime: embedded,
|
||||
})
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
// deepCopyTaskClassItems 深拷贝任务块切片(包含指针字段),避免跨节点共享引用。
|
||||
func deepCopyTaskClassItems(src []model.TaskClassItem) []model.TaskClassItem {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
dst := make([]model.TaskClassItem, 0, len(src))
|
||||
for _, item := range src {
|
||||
copied := item
|
||||
if item.CategoryID != nil {
|
||||
v := *item.CategoryID
|
||||
copied.CategoryID = &v
|
||||
}
|
||||
if item.Order != nil {
|
||||
v := *item.Order
|
||||
copied.Order = &v
|
||||
}
|
||||
if item.Content != nil {
|
||||
v := *item.Content
|
||||
copied.Content = &v
|
||||
}
|
||||
if item.Status != nil {
|
||||
v := *item.Status
|
||||
copied.Status = &v
|
||||
}
|
||||
if item.EmbeddedTime != nil {
|
||||
t := *item.EmbeddedTime
|
||||
copied.EmbeddedTime = &t
|
||||
}
|
||||
dst = append(dst, copied)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// normalizeContextTag 归一化任务标签。
|
||||
//
|
||||
// 失败兜底:
|
||||
// 1. 未识别/空值统一回落到 General;
|
||||
// 2. 保证后续 prompt 构造不会出现空标签。
|
||||
func normalizeContextTag(raw string) string {
|
||||
tag := strings.TrimSpace(raw)
|
||||
if tag == "" {
|
||||
return "General"
|
||||
}
|
||||
switch strings.ToLower(tag) {
|
||||
case "high-logic", "high_logic", "logic":
|
||||
return "High-Logic"
|
||||
case "memory":
|
||||
return "Memory"
|
||||
case "review":
|
||||
return "Review"
|
||||
case "general":
|
||||
return "General"
|
||||
default:
|
||||
return "General"
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeTaskClassIDs 清洗 task_class_ids(去重 + 过滤非法值)。
|
||||
func normalizeTaskClassIDs(ids []int) []int {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[int]struct{}, len(ids))
|
||||
out := make([]int, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[id]; exists {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// clearPreviousPreviewContext 清空会话承接快照字段。
|
||||
//
|
||||
// 触发场景:
|
||||
// 1. 用户明确要求 restart(重新排);
|
||||
// 2. 需要强制断开“沿用历史方案”的路径,避免脏状态渗透到新方案。
|
||||
func (st *SchedulePlanState) clearPreviousPreviewContext() {
|
||||
if st == nil {
|
||||
return
|
||||
}
|
||||
st.HasPreviousPreview = false
|
||||
st.PreviousTaskClassIDs = nil
|
||||
st.PreviousHybridEntries = nil
|
||||
st.PreviousAllocatedItems = nil
|
||||
st.PreviousCandidatePlans = nil
|
||||
st.PreviousPlanJSON = ""
|
||||
}
|
||||
|
||||
// clampAdjustmentConfidence 约束置信度字段到 [0,1]。
|
||||
func clampAdjustmentConfidence(v float64) float64 {
|
||||
if v < 0 {
|
||||
return 0
|
||||
}
|
||||
if v > 1 {
|
||||
return 1
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// deepCopyWeekSchedules 深拷贝周视图方案切片,避免跨节点共享引用。
|
||||
func deepCopyWeekSchedules(src []model.UserWeekSchedule) []model.UserWeekSchedule {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
dst := make([]model.UserWeekSchedule, 0, len(src))
|
||||
for _, week := range src {
|
||||
eventsCopy := make([]model.WeeklyEventBrief, len(week.Events))
|
||||
copy(eventsCopy, week.Events)
|
||||
dst = append(dst, model.UserWeekSchedule{
|
||||
Week: week.Week,
|
||||
Events: eventsCopy,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// sameTaskClassSet 判断两组 task_class_ids 是否表示同一集合(忽略顺序,忽略重复)。
|
||||
//
|
||||
// 语义:
|
||||
// 1. 两边经清洗后都为空,返回 false(空集合不作为“可复用历史方案”的依据);
|
||||
// 2. 元素集合完全一致返回 true;
|
||||
// 3. 任一元素差异返回 false。
|
||||
func sameTaskClassSet(left []int, right []int) bool {
|
||||
l := normalizeTaskClassIDs(left)
|
||||
r := normalizeTaskClassIDs(right)
|
||||
if len(l) == 0 || len(r) == 0 {
|
||||
return false
|
||||
}
|
||||
if len(l) != len(r) {
|
||||
return false
|
||||
}
|
||||
seen := make(map[int]struct{}, len(l))
|
||||
for _, id := range l {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for _, id := range r {
|
||||
if _, ok := seen[id]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// hybridEntriesToWeekSchedules 把内存中的混合条目转换成前端周视图格式。
|
||||
func hybridEntriesToWeekSchedules(entries []model.HybridScheduleEntry) []model.UserWeekSchedule {
|
||||
sectionTimeMap := map[int][2]string{
|
||||
1: {"08:00", "08:45"}, 2: {"08:55", "09:40"},
|
||||
3: {"10:15", "11:00"}, 4: {"11:10", "11:55"},
|
||||
5: {"14:00", "14:45"}, 6: {"14:55", "15:40"},
|
||||
7: {"16:15", "17:00"}, 8: {"17:10", "17:55"},
|
||||
9: {"19:00", "19:45"}, 10: {"19:55", "20:40"},
|
||||
11: {"20:50", "21:35"}, 12: {"21:45", "22:30"},
|
||||
}
|
||||
|
||||
weekMap := make(map[int][]model.WeeklyEventBrief)
|
||||
for _, e := range entries {
|
||||
startTime := ""
|
||||
endTime := ""
|
||||
if t, ok := sectionTimeMap[e.SectionFrom]; ok {
|
||||
startTime = t[0]
|
||||
}
|
||||
if t, ok := sectionTimeMap[e.SectionTo]; ok {
|
||||
endTime = t[1]
|
||||
}
|
||||
|
||||
brief := model.WeeklyEventBrief{
|
||||
DayOfWeek: e.DayOfWeek,
|
||||
Name: e.Name,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
Type: e.Type,
|
||||
Span: e.SectionTo - e.SectionFrom + 1,
|
||||
Status: e.Status,
|
||||
}
|
||||
if e.EventID > 0 {
|
||||
brief.ID = e.EventID
|
||||
}
|
||||
weekMap[e.Week] = append(weekMap[e.Week], brief)
|
||||
}
|
||||
|
||||
result := make([]model.UserWeekSchedule, 0, len(weekMap))
|
||||
for week, events := range weekMap {
|
||||
result = append(result, model.UserWeekSchedule{
|
||||
Week: week,
|
||||
Events: events,
|
||||
})
|
||||
}
|
||||
for i := 0; i < len(result); i++ {
|
||||
for j := i + 1; j < len(result); j++ {
|
||||
if result[j].Week < result[i].Week {
|
||||
result[i], result[j] = result[j], result[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
package scheduleplan
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// runQuickRefineNode 是 small 微调分支的“轻量预算收缩节点”。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责在进入 weekly_refine 前收缩预算与并发,避免小改动走重链路;
|
||||
// 2. 负责保留“可回退”的最低预算,避免直接压成 0 导致无动作可执行;
|
||||
// 3. 不负责执行任何 Move/Swap(真正动作仍由 weekly_refine 完成)。
|
||||
func runQuickRefineNode(
|
||||
ctx context.Context,
|
||||
st *SchedulePlanState,
|
||||
emitStage func(stage, detail string),
|
||||
) (*SchedulePlanState, error) {
|
||||
_ = ctx
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("schedule plan quick refine: nil state")
|
||||
}
|
||||
|
||||
emitStage("schedule_plan.quick_refine.start", "检测到小幅微调,正在切换到快速优化路径。")
|
||||
|
||||
// 1. 预算收缩策略:
|
||||
// 1.1 small 场景目标是“快速响应 + 可解释改动”,不追求大规模重排;
|
||||
// 1.2 因此把总预算压到最多 2 次尝试、有效预算压到最多 1 次成功动作;
|
||||
// 1.3 如果上游已配置更小预算,则尊重更小值,不做反向放大。
|
||||
if st.WeeklyTotalBudget <= 0 {
|
||||
st.WeeklyTotalBudget = schedulePlanDefaultWeeklyTotalBudget
|
||||
}
|
||||
if st.WeeklyAdjustBudget <= 0 {
|
||||
st.WeeklyAdjustBudget = schedulePlanDefaultWeeklyAdjustBudget
|
||||
}
|
||||
st.WeeklyTotalBudget = clampBudgetUpper(st.WeeklyTotalBudget, 2)
|
||||
st.WeeklyAdjustBudget = clampBudgetUpper(st.WeeklyAdjustBudget, 1)
|
||||
|
||||
// 2. 预算一致性兜底:
|
||||
// 2.1 总预算至少为 1(否则 weekly worker 无法执行);
|
||||
// 2.2 有效预算至少为 1(否则所有成功动作都不被允许);
|
||||
// 2.3 有效预算永远不能超过总预算。
|
||||
if st.WeeklyTotalBudget < 1 {
|
||||
st.WeeklyTotalBudget = 1
|
||||
}
|
||||
if st.WeeklyAdjustBudget < 1 {
|
||||
st.WeeklyAdjustBudget = 1
|
||||
}
|
||||
if st.WeeklyAdjustBudget > st.WeeklyTotalBudget {
|
||||
st.WeeklyAdjustBudget = st.WeeklyTotalBudget
|
||||
}
|
||||
|
||||
// 3. 小改动路径把周级并发收敛到 1,优先保证稳定与可观察性。
|
||||
st.WeeklyRefineConcurrency = 1
|
||||
|
||||
emitStage(
|
||||
"schedule_plan.quick_refine.done",
|
||||
fmt.Sprintf(
|
||||
"快速微调预算已生效:总预算=%d,有效预算=%d,并发=%d。",
|
||||
st.WeeklyTotalBudget,
|
||||
st.WeeklyAdjustBudget,
|
||||
st.WeeklyRefineConcurrency,
|
||||
),
|
||||
)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// clampBudgetUpper 把预算裁剪到“非负且不超过上限”。
|
||||
func clampBudgetUpper(current int, upper int) int {
|
||||
if current < 0 {
|
||||
return 0
|
||||
}
|
||||
if current > upper {
|
||||
return upper
|
||||
}
|
||||
return current
|
||||
}
|
||||
@@ -1,847 +0,0 @@
|
||||
package scheduleplan
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
einoModel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
const (
|
||||
// weeklyReactRoundTimeout 是周级“单步动作”单轮超时时间。
|
||||
//
|
||||
// 说明:
|
||||
// 1. 当前周级策略是“每轮只做一个动作”,单轮输入较短,超时可比旧版更保守;
|
||||
// 2. 过长超时会放大长尾等待,影响并发周优化的整体收口速度。
|
||||
weeklyReactRoundTimeout = 4 * time.Minute
|
||||
)
|
||||
|
||||
// weeklyRefineWorkerResult 是“单周 worker”输出。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 记录该周优化后的 entries;
|
||||
// 2. 记录预算消耗(总动作/有效动作);
|
||||
// 3. 记录动作日志,供 final_check 生成“过程可解释”总结;
|
||||
// 4. 记录该周摘要,便于最终汇总。
|
||||
type weeklyRefineWorkerResult struct {
|
||||
Week int
|
||||
Entries []model.HybridScheduleEntry
|
||||
TotalUsed int
|
||||
EffectiveUsed int
|
||||
Summary string
|
||||
ActionLogs []string
|
||||
}
|
||||
|
||||
// runWeeklyRefineNode 执行“周级单步优化”。
|
||||
//
|
||||
// 新链路目标:
|
||||
// 1. 把全量周数据拆成“按周并发”执行,降低单次模型输入规模;
|
||||
// 2. 每轮只允许一个动作(Move/Swap)或 done,减少模型犹豫;
|
||||
// 3. 使用“双预算”约束迭代:
|
||||
// 3.1 总动作预算:成功/失败都扣减;
|
||||
// 3.2 有效动作预算:仅成功动作扣减;
|
||||
// 4. 不在该阶段输出 reasoning 文本,改为阶段状态 + 动作结果,避免刷屏。
|
||||
func runWeeklyRefineNode(
|
||||
ctx context.Context,
|
||||
st *SchedulePlanState,
|
||||
chatModel *ark.ChatModel,
|
||||
outChan chan<- string,
|
||||
modelName string,
|
||||
emitStage func(stage, detail string),
|
||||
) (*SchedulePlanState, error) {
|
||||
_ = outChan
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("schedule plan weekly refine: nil state")
|
||||
}
|
||||
if chatModel == nil {
|
||||
return nil, fmt.Errorf("schedule plan weekly refine: model is nil")
|
||||
}
|
||||
if len(st.HybridEntries) == 0 {
|
||||
st.ReactDone = true
|
||||
st.ReactSummary = "无可优化的排程条目。"
|
||||
return st, nil
|
||||
}
|
||||
if strings.TrimSpace(modelName) == "" {
|
||||
modelName = "worker"
|
||||
}
|
||||
|
||||
// 1. 预算与并发兜底。
|
||||
// 1.1 有效预算(旧字段)<=0 时回退默认值;
|
||||
// 1.2 总预算 <=0 时回退默认值;
|
||||
// 1.3 为避免“有效预算 > 总预算”的反直觉状态,做一次归一化修正;
|
||||
// 1.4 周级并发度默认不高于周数,避免空并发浪费。
|
||||
if st.WeeklyAdjustBudget <= 0 {
|
||||
st.WeeklyAdjustBudget = schedulePlanDefaultWeeklyAdjustBudget
|
||||
}
|
||||
if st.WeeklyTotalBudget <= 0 {
|
||||
st.WeeklyTotalBudget = schedulePlanDefaultWeeklyTotalBudget
|
||||
}
|
||||
if st.WeeklyAdjustBudget > st.WeeklyTotalBudget {
|
||||
st.WeeklyAdjustBudget = st.WeeklyTotalBudget
|
||||
}
|
||||
if st.WeeklyRefineConcurrency <= 0 {
|
||||
st.WeeklyRefineConcurrency = schedulePlanDefaultWeeklyRefineConcurrency
|
||||
}
|
||||
|
||||
// 2. 按周拆分输入。
|
||||
weekOrder, weekEntries := splitHybridEntriesByWeek(st.HybridEntries)
|
||||
if len(weekOrder) == 0 {
|
||||
st.ReactDone = true
|
||||
st.ReactSummary = "无可优化的排程条目。"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 3. 只对“包含 suggested 的周”分配预算,其余周直接透传。
|
||||
activeWeeks := make([]int, 0, len(weekOrder))
|
||||
for _, week := range weekOrder {
|
||||
if countSuggested(weekEntries[week]) > 0 {
|
||||
activeWeeks = append(activeWeeks, week)
|
||||
}
|
||||
}
|
||||
if len(activeWeeks) == 0 {
|
||||
st.ReactDone = true
|
||||
st.ReactSummary = "当前方案中没有可调整的 suggested 任务,已跳过周级优化。"
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 3.1 强制“每个有效周至少 1 个总预算 + 1 个有效预算”。
|
||||
// 3.1.1 判断依据:任何有效周都必须有机会进入优化,避免出现 0 预算跳过。
|
||||
// 3.1.2 实现方式:当全局预算不足时,自动抬升到 activeWeeks 数量。
|
||||
// 3.1.3 失败/兜底:该步骤仅做内存字段修正,不依赖外部资源,不会新增失败点。
|
||||
minBudgetRequired := len(activeWeeks)
|
||||
if st.WeeklyTotalBudget < minBudgetRequired {
|
||||
st.WeeklyTotalBudget = minBudgetRequired
|
||||
}
|
||||
if st.WeeklyAdjustBudget < minBudgetRequired {
|
||||
st.WeeklyAdjustBudget = minBudgetRequired
|
||||
}
|
||||
if st.WeeklyAdjustBudget > st.WeeklyTotalBudget {
|
||||
st.WeeklyAdjustBudget = st.WeeklyTotalBudget
|
||||
}
|
||||
|
||||
totalBudgetByWeek, effectiveBudgetByWeek, weeklyLoads, coveredWeeks := splitWeeklyBudgetsByLoad(
|
||||
activeWeeks,
|
||||
weekEntries,
|
||||
st.WeeklyTotalBudget,
|
||||
st.WeeklyAdjustBudget,
|
||||
)
|
||||
budgetIndexByWeek := make(map[int]int, len(activeWeeks))
|
||||
for idx, week := range activeWeeks {
|
||||
budgetIndexByWeek[week] = idx
|
||||
}
|
||||
if coveredWeeks < len(activeWeeks) {
|
||||
emitStage(
|
||||
"schedule_plan.weekly_refine.budget_fallback",
|
||||
fmt.Sprintf(
|
||||
"周级预算不足以覆盖全部有效周(有效周=%d,至少需预算=%d;当前总预算=%d,有效预算=%d)。已按周负载优先覆盖 %d 个周,其余周预算置 0 并透传原方案。",
|
||||
len(activeWeeks),
|
||||
len(activeWeeks),
|
||||
st.WeeklyTotalBudget,
|
||||
st.WeeklyAdjustBudget,
|
||||
coveredWeeks,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
workerConcurrency := st.WeeklyRefineConcurrency
|
||||
if workerConcurrency > len(activeWeeks) {
|
||||
workerConcurrency = len(activeWeeks)
|
||||
}
|
||||
if workerConcurrency <= 0 {
|
||||
workerConcurrency = 1
|
||||
}
|
||||
|
||||
emitStage(
|
||||
"schedule_plan.weekly_refine.start",
|
||||
fmt.Sprintf(
|
||||
"周级单步优化开始:周数=%d(可优化=%d),并发度=%d,总动作预算=%d,有效动作预算=%d,覆盖周=%d/%d,周负载=%v。",
|
||||
len(weekOrder),
|
||||
len(activeWeeks),
|
||||
workerConcurrency,
|
||||
st.WeeklyTotalBudget,
|
||||
st.WeeklyAdjustBudget,
|
||||
coveredWeeks,
|
||||
len(activeWeeks),
|
||||
weeklyLoads,
|
||||
),
|
||||
)
|
||||
|
||||
// 4. 并发执行“单周 worker”。
|
||||
sem := make(chan struct{}, workerConcurrency)
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
|
||||
workerResults := make(map[int]weeklyRefineWorkerResult, len(weekOrder))
|
||||
var firstErr error
|
||||
completedWeeks := 0
|
||||
|
||||
for _, week := range weekOrder {
|
||||
week := week
|
||||
entries := deepCopyEntries(weekEntries[week])
|
||||
|
||||
// 4.1 没有 suggested 的周直接透传,不占模型调用预算。
|
||||
if countSuggested(entries) == 0 {
|
||||
workerResults[week] = weeklyRefineWorkerResult{
|
||||
Week: week,
|
||||
Entries: entries,
|
||||
Summary: fmt.Sprintf("W%d 无 suggested 任务,跳过周级优化。", week),
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
defer func() { <-sem }()
|
||||
case <-ctx.Done():
|
||||
mu.Lock()
|
||||
if firstErr == nil {
|
||||
firstErr = ctx.Err()
|
||||
}
|
||||
completedWeeks++
|
||||
workerResults[week] = weeklyRefineWorkerResult{
|
||||
Week: week,
|
||||
Entries: entries,
|
||||
Summary: fmt.Sprintf("W%d 优化取消,已保留原方案。", week),
|
||||
}
|
||||
emitStage(
|
||||
"schedule_plan.weekly_refine.week_done",
|
||||
fmt.Sprintf("W%d 已取消并回退原方案。(进度 %d/%d)", week, completedWeeks, len(activeWeeks)),
|
||||
)
|
||||
mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
idx := budgetIndexByWeek[week]
|
||||
weekTotalBudget := totalBudgetByWeek[idx]
|
||||
weekEffectiveBudget := effectiveBudgetByWeek[idx]
|
||||
emitStage(
|
||||
"schedule_plan.weekly_refine.week_start",
|
||||
fmt.Sprintf(
|
||||
"W%d 开始周级单步优化:总预算=%d,有效预算=%d。",
|
||||
week,
|
||||
weekTotalBudget,
|
||||
weekEffectiveBudget,
|
||||
),
|
||||
)
|
||||
|
||||
result, workerErr := runSingleWeekRefineWorker(
|
||||
ctx,
|
||||
chatModel,
|
||||
modelName,
|
||||
week,
|
||||
entries,
|
||||
st.Constraints,
|
||||
weeklyPlanningWindow{
|
||||
Enabled: st.HasPlanningWindow,
|
||||
StartWeek: st.PlanStartWeek,
|
||||
StartDay: st.PlanStartDay,
|
||||
EndWeek: st.PlanEndWeek,
|
||||
EndDay: st.PlanEndDay,
|
||||
},
|
||||
weekTotalBudget,
|
||||
weekEffectiveBudget,
|
||||
emitStage,
|
||||
)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if workerErr != nil && firstErr == nil {
|
||||
firstErr = workerErr
|
||||
}
|
||||
completedWeeks++
|
||||
workerResults[week] = result
|
||||
emitStage(
|
||||
"schedule_plan.weekly_refine.week_done",
|
||||
fmt.Sprintf(
|
||||
"W%d 周级优化完成(总已用=%d/%d,有效已用=%d/%d)。(进度 %d/%d)",
|
||||
week,
|
||||
result.TotalUsed,
|
||||
weekTotalBudget,
|
||||
result.EffectiveUsed,
|
||||
weekEffectiveBudget,
|
||||
completedWeeks,
|
||||
len(activeWeeks),
|
||||
),
|
||||
)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// 5. 汇总 worker 结果,重建全量 HybridEntries。
|
||||
mergedEntries := make([]model.HybridScheduleEntry, 0, len(st.HybridEntries))
|
||||
st.WeeklyTotalUsed = 0
|
||||
st.WeeklyAdjustUsed = 0
|
||||
st.WeeklyActionLogs = st.WeeklyActionLogs[:0]
|
||||
weekSummaries := make([]string, 0, len(weekOrder))
|
||||
|
||||
for _, week := range weekOrder {
|
||||
result, exists := workerResults[week]
|
||||
if !exists {
|
||||
// 理论上不会发生;兜底透传该周原始条目。
|
||||
result = weeklyRefineWorkerResult{
|
||||
Week: week,
|
||||
Entries: deepCopyEntries(weekEntries[week]),
|
||||
Summary: fmt.Sprintf("W%d 未拿到 worker 结果,已保留原方案。", week),
|
||||
}
|
||||
}
|
||||
mergedEntries = append(mergedEntries, result.Entries...)
|
||||
st.WeeklyTotalUsed += result.TotalUsed
|
||||
st.WeeklyAdjustUsed += result.EffectiveUsed
|
||||
st.WeeklyActionLogs = append(st.WeeklyActionLogs, result.ActionLogs...)
|
||||
if strings.TrimSpace(result.Summary) != "" {
|
||||
weekSummaries = append(weekSummaries, result.Summary)
|
||||
}
|
||||
}
|
||||
sortHybridEntries(mergedEntries)
|
||||
st.HybridEntries = mergedEntries
|
||||
|
||||
// 6. 生成阶段摘要并收口状态。
|
||||
st.ReactDone = true
|
||||
st.ReactRound = st.WeeklyTotalUsed
|
||||
if len(weekSummaries) == 0 {
|
||||
st.ReactSummary = fmt.Sprintf(
|
||||
"周级优化完成:总动作已用 %d/%d,有效动作已用 %d/%d。",
|
||||
st.WeeklyTotalUsed, st.WeeklyTotalBudget, st.WeeklyAdjustUsed, st.WeeklyAdjustBudget,
|
||||
)
|
||||
} else {
|
||||
st.ReactSummary = strings.Join(weekSummaries, ";")
|
||||
}
|
||||
if firstErr != nil {
|
||||
emitStage("schedule_plan.weekly_refine.partial_error", fmt.Sprintf("周级并发优化部分失败,已自动保留失败周原方案。原因:%s", firstErr.Error()))
|
||||
}
|
||||
emitStage(
|
||||
"schedule_plan.weekly_refine.done",
|
||||
fmt.Sprintf(
|
||||
"周级单步优化结束:总动作已用 %d/%d,有效动作已用 %d/%d。",
|
||||
st.WeeklyTotalUsed, st.WeeklyTotalBudget, st.WeeklyAdjustUsed, st.WeeklyAdjustBudget,
|
||||
),
|
||||
)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// runSingleWeekRefineWorker 执行“单周 + 单步动作”循环。
|
||||
//
|
||||
// 流程说明:
|
||||
// 1. 每轮只允许 1 个工具调用或 done;
|
||||
// 2. 每次工具调用都扣“总预算”;
|
||||
// 3. 仅成功调用再扣“有效预算”;
|
||||
// 4. 工具结果会回灌到下一轮上下文,驱动“走一步看一步”。
|
||||
func runSingleWeekRefineWorker(
|
||||
ctx context.Context,
|
||||
chatModel *ark.ChatModel,
|
||||
modelName string,
|
||||
week int,
|
||||
entries []model.HybridScheduleEntry,
|
||||
constraints []string,
|
||||
window weeklyPlanningWindow,
|
||||
totalBudget int,
|
||||
effectiveBudget int,
|
||||
emitStage func(stage, detail string),
|
||||
) (weeklyRefineWorkerResult, error) {
|
||||
result := weeklyRefineWorkerResult{
|
||||
Week: week,
|
||||
Entries: deepCopyEntries(entries),
|
||||
}
|
||||
|
||||
if totalBudget <= 0 || effectiveBudget <= 0 {
|
||||
result.Summary = fmt.Sprintf("W%d 预算为 0,跳过周级优化。", week)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
hybridJSON, err := json.Marshal(result.Entries)
|
||||
if err != nil {
|
||||
result.Summary = fmt.Sprintf("W%d 序列化失败,已保留原方案。", week)
|
||||
return result, fmt.Errorf("周级 worker 序列化失败 week=%d: %w", week, err)
|
||||
}
|
||||
constraintsText := "无"
|
||||
if len(constraints) > 0 {
|
||||
constraintsText = strings.Join(constraints, "、")
|
||||
}
|
||||
|
||||
messages := []*schema.Message{
|
||||
schema.SystemMessage(
|
||||
renderWeeklyPromptWithBudget(
|
||||
effectiveBudget-result.EffectiveUsed,
|
||||
effectiveBudget,
|
||||
result.EffectiveUsed,
|
||||
totalBudget-result.TotalUsed,
|
||||
totalBudget,
|
||||
result.TotalUsed,
|
||||
),
|
||||
),
|
||||
schema.UserMessage(fmt.Sprintf(
|
||||
"当前处理周次:W%d\n以下是当前周混合日程(JSON):\n%s\n\n用户约束:%s\n\n注意:本 worker 仅允许优化 W%d 内的任务。",
|
||||
week,
|
||||
string(hybridJSON),
|
||||
constraintsText,
|
||||
week,
|
||||
)),
|
||||
}
|
||||
|
||||
for result.TotalUsed < totalBudget && result.EffectiveUsed < effectiveBudget {
|
||||
remainingTotal := totalBudget - result.TotalUsed
|
||||
remainingEffective := effectiveBudget - result.EffectiveUsed
|
||||
emitStage(
|
||||
"schedule_plan.weekly_refine.round",
|
||||
fmt.Sprintf(
|
||||
"W%d 新一轮决策:总预算剩余=%d/%d,有效预算剩余=%d/%d。",
|
||||
week,
|
||||
remainingTotal,
|
||||
totalBudget,
|
||||
remainingEffective,
|
||||
effectiveBudget,
|
||||
),
|
||||
)
|
||||
|
||||
// 1. 每轮更新系统提示中的预算占位符。
|
||||
messages[0] = schema.SystemMessage(
|
||||
renderWeeklyPromptWithBudget(
|
||||
remainingEffective,
|
||||
effectiveBudget,
|
||||
result.EffectiveUsed,
|
||||
remainingTotal,
|
||||
totalBudget,
|
||||
result.TotalUsed,
|
||||
),
|
||||
)
|
||||
|
||||
roundCtx, cancel := context.WithTimeout(ctx, weeklyReactRoundTimeout)
|
||||
content, genErr := generateWeeklyRefineRound(roundCtx, chatModel, messages)
|
||||
cancel()
|
||||
if genErr != nil {
|
||||
result.Summary = fmt.Sprintf("W%d 模型调用失败,已保留当前结果。", week)
|
||||
return result, fmt.Errorf("周级 worker 调用失败 week=%d: %w", week, genErr)
|
||||
}
|
||||
|
||||
parsed, parseErr := parseReactLLMOutput(content)
|
||||
if parseErr != nil {
|
||||
result.Summary = fmt.Sprintf("W%d 输出格式异常,已保留当前结果。", week)
|
||||
return result, fmt.Errorf("周级 worker 解析失败 week=%d: %w", week, parseErr)
|
||||
}
|
||||
|
||||
// 2. done=true 直接正常结束,不再消耗预算。
|
||||
if parsed.Done {
|
||||
summary := strings.TrimSpace(parsed.Summary)
|
||||
if summary == "" {
|
||||
summary = fmt.Sprintf(
|
||||
"W%d 优化结束(总动作已用 %d/%d,有效动作已用 %d/%d)。",
|
||||
week,
|
||||
result.TotalUsed, totalBudget,
|
||||
result.EffectiveUsed, effectiveBudget,
|
||||
)
|
||||
}
|
||||
result.Summary = summary
|
||||
break
|
||||
}
|
||||
|
||||
// 3. 只取一个工具调用,强制单步。
|
||||
call, warn := pickSingleToolCall(parsed.ToolCalls)
|
||||
if call == nil {
|
||||
result.Summary = fmt.Sprintf(
|
||||
"W%d 无可执行动作,提前结束(总动作已用 %d/%d,有效动作已用 %d/%d)。",
|
||||
week,
|
||||
result.TotalUsed, totalBudget,
|
||||
result.EffectiveUsed, effectiveBudget,
|
||||
)
|
||||
break
|
||||
}
|
||||
if warn != "" {
|
||||
result.ActionLogs = append(result.ActionLogs, fmt.Sprintf("W%d 警告:%s", week, warn))
|
||||
}
|
||||
|
||||
// 4. 执行工具:总预算总是扣减;有效预算仅成功时扣减。
|
||||
result.TotalUsed++
|
||||
nextEntries, toolResult := dispatchWeeklySingleActionTool(result.Entries, *call, week, window)
|
||||
if toolResult.Success {
|
||||
result.EffectiveUsed++
|
||||
result.Entries = nextEntries
|
||||
}
|
||||
|
||||
logLine := fmt.Sprintf(
|
||||
"W%d 动作[%s] 结果=%t,总预算=%d/%d,有效预算=%d/%d,详情=%s",
|
||||
week,
|
||||
toolResult.Tool,
|
||||
toolResult.Success,
|
||||
result.TotalUsed,
|
||||
totalBudget,
|
||||
result.EffectiveUsed,
|
||||
effectiveBudget,
|
||||
toolResult.Result,
|
||||
)
|
||||
result.ActionLogs = append(result.ActionLogs, logLine)
|
||||
statusMark := "FAIL"
|
||||
if toolResult.Success {
|
||||
statusMark = "OK"
|
||||
}
|
||||
emitStage("schedule_plan.weekly_refine.tool_call", fmt.Sprintf("[%s] %s", statusMark, logLine))
|
||||
|
||||
// 5. 把“本轮输出 + 工具结果”拼回下一轮上下文,驱动增量推理。
|
||||
messages = append(messages, schema.AssistantMessage(content, nil))
|
||||
toolResultJSON, _ := json.Marshal([]reactToolResult{toolResult})
|
||||
messages = append(messages, schema.UserMessage(
|
||||
fmt.Sprintf(
|
||||
"上一轮工具结果:%s\n当前预算:总剩余=%d,有效剩余=%d\n请继续按“单步动作”规则决策(仅一个工具调用或 done)。",
|
||||
string(toolResultJSON),
|
||||
totalBudget-result.TotalUsed,
|
||||
effectiveBudget-result.EffectiveUsed,
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(result.Summary) == "" {
|
||||
result.Summary = fmt.Sprintf(
|
||||
"W%d 预算耗尽停止(总动作已用 %d/%d,有效动作已用 %d/%d)。",
|
||||
week,
|
||||
result.TotalUsed, totalBudget,
|
||||
result.EffectiveUsed, effectiveBudget,
|
||||
)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// generateWeeklyRefineRound 调用模型生成“单周单步”决策输出。
|
||||
//
|
||||
// 说明:
|
||||
// 1. 周级仍保留 thinking(提高复杂排程准确率);
|
||||
// 2. 但不把 reasoning 实时透传给前端,避免刷屏;
|
||||
// 3. 仅返回最终 content,交给 JSON 解析器处理。
|
||||
func generateWeeklyRefineRound(
|
||||
ctx context.Context,
|
||||
chatModel *ark.ChatModel,
|
||||
messages []*schema.Message,
|
||||
) (string, error) {
|
||||
resp, err := chatModel.Generate(
|
||||
ctx,
|
||||
messages,
|
||||
ark.WithThinking(&arkModel.Thinking{Type: arkModel.ThinkingTypeEnabled}),
|
||||
einoModel.WithTemperature(0.2),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp == nil {
|
||||
return "", fmt.Errorf("周级单步调用返回为空")
|
||||
}
|
||||
content := strings.TrimSpace(resp.Content)
|
||||
if content == "" {
|
||||
return "", fmt.Errorf("周级单步调用返回内容为空")
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// renderWeeklyPromptWithBudget 渲染周级单步优化的预算占位符。
|
||||
//
|
||||
// 1. 保留旧占位符 {{budget*}} 兼容历史模板;
|
||||
// 2. 新增 action_total/action_effective 占位符表达双预算语义;
|
||||
// 3. 所有负值都会在这里兜底归零,避免传给模型异常预算。
|
||||
func renderWeeklyPromptWithBudget(
|
||||
remainingEffective int,
|
||||
effectiveBudget int,
|
||||
usedEffective int,
|
||||
remainingTotal int,
|
||||
totalBudget int,
|
||||
usedTotal int,
|
||||
) string {
|
||||
if effectiveBudget <= 0 {
|
||||
effectiveBudget = schedulePlanDefaultWeeklyAdjustBudget
|
||||
}
|
||||
if totalBudget <= 0 {
|
||||
totalBudget = schedulePlanDefaultWeeklyTotalBudget
|
||||
}
|
||||
if remainingEffective < 0 {
|
||||
remainingEffective = 0
|
||||
}
|
||||
if remainingTotal < 0 {
|
||||
remainingTotal = 0
|
||||
}
|
||||
if usedEffective < 0 {
|
||||
usedEffective = 0
|
||||
}
|
||||
if usedTotal < 0 {
|
||||
usedTotal = 0
|
||||
}
|
||||
if usedEffective > effectiveBudget {
|
||||
usedEffective = effectiveBudget
|
||||
}
|
||||
if usedTotal > totalBudget {
|
||||
usedTotal = totalBudget
|
||||
}
|
||||
|
||||
prompt := SchedulePlanWeeklyReactPrompt
|
||||
prompt = strings.ReplaceAll(prompt, "{{action_total_remaining}}", fmt.Sprintf("%d", remainingTotal))
|
||||
prompt = strings.ReplaceAll(prompt, "{{action_total_budget}}", fmt.Sprintf("%d", totalBudget))
|
||||
prompt = strings.ReplaceAll(prompt, "{{action_total_used}}", fmt.Sprintf("%d", usedTotal))
|
||||
prompt = strings.ReplaceAll(prompt, "{{action_effective_remaining}}", fmt.Sprintf("%d", remainingEffective))
|
||||
prompt = strings.ReplaceAll(prompt, "{{action_effective_budget}}", fmt.Sprintf("%d", effectiveBudget))
|
||||
prompt = strings.ReplaceAll(prompt, "{{action_effective_used}}", fmt.Sprintf("%d", usedEffective))
|
||||
|
||||
// 兼容旧模板占位符,避免历史 prompt 残留时出现未替换文本。
|
||||
prompt = strings.ReplaceAll(prompt, "{{budget_remaining}}", fmt.Sprintf("%d", remainingEffective))
|
||||
prompt = strings.ReplaceAll(prompt, "{{budget_total}}", fmt.Sprintf("%d", effectiveBudget))
|
||||
prompt = strings.ReplaceAll(prompt, "{{budget_used}}", fmt.Sprintf("%d", usedEffective))
|
||||
prompt = strings.ReplaceAll(prompt, "{{budget}}", fmt.Sprintf("%d(总额度 %d,已用 %d)", remainingEffective, effectiveBudget, usedEffective))
|
||||
return prompt
|
||||
}
|
||||
|
||||
// pickSingleToolCall 在“单步动作模式”下选择一个工具调用。
|
||||
//
|
||||
// 返回语义:
|
||||
// 1. call=nil:没有可执行工具;
|
||||
// 2. warn 非空:模型返回了多个工具,本轮仅执行第一个。
|
||||
func pickSingleToolCall(toolCalls []reactToolCall) (*reactToolCall, string) {
|
||||
if len(toolCalls) == 0 {
|
||||
return nil, ""
|
||||
}
|
||||
call := toolCalls[0]
|
||||
if len(toolCalls) == 1 {
|
||||
return &call, ""
|
||||
}
|
||||
return &call, fmt.Sprintf("模型返回了 %d 个工具调用,单步模式仅执行第一个:%s", len(toolCalls), call.Tool)
|
||||
}
|
||||
|
||||
// splitHybridEntriesByWeek 按 week 对混合条目分组并返回稳定周序。
|
||||
func splitHybridEntriesByWeek(entries []model.HybridScheduleEntry) ([]int, map[int][]model.HybridScheduleEntry) {
|
||||
byWeek := make(map[int][]model.HybridScheduleEntry)
|
||||
for _, entry := range entries {
|
||||
byWeek[entry.Week] = append(byWeek[entry.Week], entry)
|
||||
}
|
||||
weeks := make([]int, 0, len(byWeek))
|
||||
for week := range byWeek {
|
||||
weeks = append(weeks, week)
|
||||
}
|
||||
sort.Ints(weeks)
|
||||
return weeks, byWeek
|
||||
}
|
||||
|
||||
type weightedBudgetRemainder struct {
|
||||
Index int
|
||||
Remainder int
|
||||
Load int
|
||||
}
|
||||
|
||||
// splitWeeklyBudgetsByLoad 根据“有效周保底 + 周负载加权”拆分预算。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责:返回与 activeWeeks 同索引对齐的总预算/有效预算;
|
||||
// 2. 负责:在预算不足时按负载优先覆盖高负载周;
|
||||
// 3. 不负责:执行周级动作与状态落盘(由 runSingleWeekRefineWorker / runWeeklyRefineNode 负责)。
|
||||
//
|
||||
// 输入输出语义:
|
||||
// 1. coveredWeeks 表示“同时拿到 >=1 总预算和 >=1 有效预算”的周数;
|
||||
// 2. 当任一全局预算 <=0 时,返回全 0;上游将据此跳过对应周优化;
|
||||
// 3. 返回的 weeklyLoads 仅用于可观测性,不参与后续状态持久化。
|
||||
func splitWeeklyBudgetsByLoad(
|
||||
activeWeeks []int,
|
||||
weekEntries map[int][]model.HybridScheduleEntry,
|
||||
totalBudget int,
|
||||
effectiveBudget int,
|
||||
) (totalByWeek []int, effectiveByWeek []int, weeklyLoads []int, coveredWeeks int) {
|
||||
weekCount := len(activeWeeks)
|
||||
if weekCount == 0 {
|
||||
return nil, nil, nil, 0
|
||||
}
|
||||
|
||||
if totalBudget < 0 {
|
||||
totalBudget = 0
|
||||
}
|
||||
if effectiveBudget < 0 {
|
||||
effectiveBudget = 0
|
||||
}
|
||||
|
||||
weeklyLoads = buildWeeklyLoadScores(activeWeeks, weekEntries)
|
||||
totalByWeek = make([]int, weekCount)
|
||||
effectiveByWeek = make([]int, weekCount)
|
||||
if totalBudget == 0 || effectiveBudget == 0 {
|
||||
return totalByWeek, effectiveByWeek, weeklyLoads, 0
|
||||
}
|
||||
|
||||
// 1. 先计算“可保底覆盖周数”。
|
||||
// 1.1 目标是每个有效周至少 1 个总预算 + 1 个有效预算;
|
||||
// 1.2 失败场景:当预算小于有效周数量时,不可能全覆盖;
|
||||
// 1.3 兜底策略:只覆盖高负载周,避免把预算分散到无法执行的周。
|
||||
coveredWeeks = weekCount
|
||||
if totalBudget < coveredWeeks {
|
||||
coveredWeeks = totalBudget
|
||||
}
|
||||
if effectiveBudget < coveredWeeks {
|
||||
coveredWeeks = effectiveBudget
|
||||
}
|
||||
if coveredWeeks <= 0 {
|
||||
return totalByWeek, effectiveByWeek, weeklyLoads, 0
|
||||
}
|
||||
|
||||
coveredIndexes := pickTopLoadWeekIndexes(weeklyLoads, coveredWeeks)
|
||||
for _, idx := range coveredIndexes {
|
||||
totalByWeek[idx]++
|
||||
effectiveByWeek[idx]++
|
||||
}
|
||||
|
||||
// 2. 再把剩余预算按周负载加权分配。
|
||||
// 2.1 判断依据:负载越高,给到的额外预算越多,优先解决高密度周;
|
||||
// 2.2 失败场景:负载异常(<=0)会导致权重失真;
|
||||
// 2.3 兜底策略:权重最小按 1 处理,保证分配可持续、不会 panic。
|
||||
addWeightedBudget(totalByWeek, weeklyLoads, coveredIndexes, totalBudget-coveredWeeks)
|
||||
addWeightedBudget(effectiveByWeek, weeklyLoads, coveredIndexes, effectiveBudget-coveredWeeks)
|
||||
return totalByWeek, effectiveByWeek, weeklyLoads, coveredWeeks
|
||||
}
|
||||
|
||||
// buildWeeklyLoadScores 计算每个有效周的负载评分。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责:以 suggested 任务的节次跨度作为周负载;
|
||||
// 2. 不负责:预算分配策略与排序决策(由 splitWeeklyBudgetsByLoad/pickTopLoadWeekIndexes 负责)。
|
||||
func buildWeeklyLoadScores(
|
||||
activeWeeks []int,
|
||||
weekEntries map[int][]model.HybridScheduleEntry,
|
||||
) []int {
|
||||
loads := make([]int, len(activeWeeks))
|
||||
for idx, week := range activeWeeks {
|
||||
load := 0
|
||||
for _, entry := range weekEntries[week] {
|
||||
if entry.Status != "suggested" {
|
||||
continue
|
||||
}
|
||||
span := entry.SectionTo - entry.SectionFrom + 1
|
||||
if span <= 0 {
|
||||
span = 1
|
||||
}
|
||||
load += span
|
||||
}
|
||||
if load <= 0 {
|
||||
// 兜底:脏数据或异常节次下仍给该周最小权重,避免被完全饿死。
|
||||
load = 1
|
||||
}
|
||||
loads[idx] = load
|
||||
}
|
||||
return loads
|
||||
}
|
||||
|
||||
// pickTopLoadWeekIndexes 选择负载最高的 topN 个周索引。
|
||||
func pickTopLoadWeekIndexes(loads []int, topN int) []int {
|
||||
if topN <= 0 || len(loads) == 0 {
|
||||
return nil
|
||||
}
|
||||
indexes := make([]int, len(loads))
|
||||
for i := range loads {
|
||||
indexes[i] = i
|
||||
}
|
||||
sort.SliceStable(indexes, func(i, j int) bool {
|
||||
left := loads[indexes[i]]
|
||||
right := loads[indexes[j]]
|
||||
if left != right {
|
||||
return left > right
|
||||
}
|
||||
return indexes[i] < indexes[j]
|
||||
})
|
||||
if topN > len(indexes) {
|
||||
topN = len(indexes)
|
||||
}
|
||||
selected := append([]int(nil), indexes[:topN]...)
|
||||
sort.Ints(selected)
|
||||
return selected
|
||||
}
|
||||
|
||||
// addWeightedBudget 把剩余预算按权重分配到目标周。
|
||||
//
|
||||
// 说明:
|
||||
// 1. 先按整数份额分配;
|
||||
// 2. 再按“最大余数法”分发尾差,保证总和严格守恒;
|
||||
// 3. 余数相同时优先高负载周,再按索引稳定排序,避免结果抖动。
|
||||
func addWeightedBudget(
|
||||
budgets []int,
|
||||
loads []int,
|
||||
targetIndexes []int,
|
||||
remainingBudget int,
|
||||
) {
|
||||
if remainingBudget <= 0 || len(targetIndexes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
totalLoad := 0
|
||||
normalizedLoadByIndex := make(map[int]int, len(targetIndexes))
|
||||
for _, idx := range targetIndexes {
|
||||
load := 1
|
||||
if idx >= 0 && idx < len(loads) && loads[idx] > 0 {
|
||||
load = loads[idx]
|
||||
}
|
||||
normalizedLoadByIndex[idx] = load
|
||||
totalLoad += load
|
||||
}
|
||||
if totalLoad <= 0 {
|
||||
// 理论上不会出现;兜底均匀轮询分配,保证不会丢预算。
|
||||
for i := 0; i < remainingBudget; i++ {
|
||||
budgets[targetIndexes[i%len(targetIndexes)]]++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
allocated := 0
|
||||
remainders := make([]weightedBudgetRemainder, 0, len(targetIndexes))
|
||||
for _, idx := range targetIndexes {
|
||||
load := normalizedLoadByIndex[idx]
|
||||
shareProduct := remainingBudget * load
|
||||
share := shareProduct / totalLoad
|
||||
budgets[idx] += share
|
||||
allocated += share
|
||||
remainders = append(remainders, weightedBudgetRemainder{
|
||||
Index: idx,
|
||||
Remainder: shareProduct % totalLoad,
|
||||
Load: load,
|
||||
})
|
||||
}
|
||||
|
||||
left := remainingBudget - allocated
|
||||
if left <= 0 {
|
||||
return
|
||||
}
|
||||
sort.SliceStable(remainders, func(i, j int) bool {
|
||||
if remainders[i].Remainder != remainders[j].Remainder {
|
||||
return remainders[i].Remainder > remainders[j].Remainder
|
||||
}
|
||||
if remainders[i].Load != remainders[j].Load {
|
||||
return remainders[i].Load > remainders[j].Load
|
||||
}
|
||||
return remainders[i].Index < remainders[j].Index
|
||||
})
|
||||
for i := 0; i < left; i++ {
|
||||
budgets[remainders[i%len(remainders)].Index]++
|
||||
}
|
||||
}
|
||||
|
||||
// sortHybridEntries 对条目做稳定排序,确保后续预览输出稳定。
|
||||
func sortHybridEntries(entries []model.HybridScheduleEntry) {
|
||||
sort.SliceStable(entries, func(i, j int) bool {
|
||||
left := entries[i]
|
||||
right := entries[j]
|
||||
if left.Week != right.Week {
|
||||
return left.Week < right.Week
|
||||
}
|
||||
if left.DayOfWeek != right.DayOfWeek {
|
||||
return left.DayOfWeek < right.DayOfWeek
|
||||
}
|
||||
if left.SectionFrom != right.SectionFrom {
|
||||
return left.SectionFrom < right.SectionFrom
|
||||
}
|
||||
if left.SectionTo != right.SectionTo {
|
||||
return left.SectionTo < right.SectionTo
|
||||
}
|
||||
if left.Status != right.Status {
|
||||
// existing 放前,suggested 放后,便于观察课表底板与建议层。
|
||||
return left.Status < right.Status
|
||||
}
|
||||
return left.Name < right.Name
|
||||
})
|
||||
}
|
||||
@@ -1,128 +0,0 @@
|
||||
package scheduleplan
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// schedulePlanRunner 是“单次图执行”的请求级依赖容器。
|
||||
//
|
||||
// 设计目标:
|
||||
// 1. 把节点运行所需依赖(model/deps/emit/extra/history)就近收口;
|
||||
// 2. 让 graph.go 只保留“节点连线与分支决策”,提升可读性;
|
||||
// 3. 避免在 graph.go 里重复出现大量闭包和参数透传。
|
||||
type schedulePlanRunner struct {
|
||||
chatModel *ark.ChatModel
|
||||
deps SchedulePlanToolDeps
|
||||
emitStage func(stage, detail string)
|
||||
userMessage string
|
||||
extra map[string]any
|
||||
chatHistory []*schema.Message
|
||||
|
||||
// weekly refine 需要的上下文
|
||||
outChan chan<- string
|
||||
modelName string
|
||||
|
||||
// daily refine 并发度
|
||||
dailyRefineConcurrency int
|
||||
}
|
||||
|
||||
func newSchedulePlanRunner(
|
||||
chatModel *ark.ChatModel,
|
||||
deps SchedulePlanToolDeps,
|
||||
emitStage func(stage, detail string),
|
||||
userMessage string,
|
||||
extra map[string]any,
|
||||
chatHistory []*schema.Message,
|
||||
outChan chan<- string,
|
||||
modelName string,
|
||||
dailyRefineConcurrency int,
|
||||
) *schedulePlanRunner {
|
||||
return &schedulePlanRunner{
|
||||
chatModel: chatModel,
|
||||
deps: deps,
|
||||
emitStage: emitStage,
|
||||
userMessage: userMessage,
|
||||
extra: extra,
|
||||
chatHistory: chatHistory,
|
||||
outChan: outChan,
|
||||
modelName: modelName,
|
||||
dailyRefineConcurrency: dailyRefineConcurrency,
|
||||
}
|
||||
}
|
||||
|
||||
// 节点方法适配层
|
||||
|
||||
func (r *schedulePlanRunner) planNode(ctx context.Context, st *SchedulePlanState) (*SchedulePlanState, error) {
|
||||
return runPlanNode(ctx, st, r.chatModel, r.userMessage, r.extra, r.chatHistory, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *schedulePlanRunner) roughBuildNode(ctx context.Context, st *SchedulePlanState) (*SchedulePlanState, error) {
|
||||
return runRoughBuildNode(ctx, st, r.deps, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *schedulePlanRunner) dailySplitNode(ctx context.Context, st *SchedulePlanState) (*SchedulePlanState, error) {
|
||||
return runDailySplitNode(ctx, st, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *schedulePlanRunner) quickRefineNode(ctx context.Context, st *SchedulePlanState) (*SchedulePlanState, error) {
|
||||
return runQuickRefineNode(ctx, st, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *schedulePlanRunner) dailyRefineNode(ctx context.Context, st *SchedulePlanState) (*SchedulePlanState, error) {
|
||||
return runDailyRefineNode(ctx, st, r.chatModel, r.dailyRefineConcurrency, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *schedulePlanRunner) mergeNode(ctx context.Context, st *SchedulePlanState) (*SchedulePlanState, error) {
|
||||
return runMergeNode(ctx, st, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *schedulePlanRunner) weeklyRefineNode(ctx context.Context, st *SchedulePlanState) (*SchedulePlanState, error) {
|
||||
return runWeeklyRefineNode(ctx, st, r.chatModel, r.outChan, r.modelName, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *schedulePlanRunner) finalCheckNode(ctx context.Context, st *SchedulePlanState) (*SchedulePlanState, error) {
|
||||
return runFinalCheckNode(ctx, st, r.chatModel, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *schedulePlanRunner) returnPreviewNode(ctx context.Context, st *SchedulePlanState) (*SchedulePlanState, error) {
|
||||
return runReturnPreviewNode(ctx, st, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *schedulePlanRunner) exitNode(_ context.Context, st *SchedulePlanState) (*SchedulePlanState, error) {
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 分支决策适配层
|
||||
|
||||
func (r *schedulePlanRunner) nextAfterPlan(_ context.Context, st *SchedulePlanState) (string, error) {
|
||||
return selectNextAfterPlan(st), nil
|
||||
}
|
||||
|
||||
// nextAfterRoughBuild 根据粗排构建结果决定后续路径。
|
||||
//
|
||||
// 规则:
|
||||
// 1. 没有可优化条目 -> exit;
|
||||
// 2. task_class_ids >= 2 -> dailySplit(多任务类混排,先做日内并发);
|
||||
// 3. task_class_ids == 1 -> weeklyRefine(单任务类直接周级配平)。
|
||||
func (r *schedulePlanRunner) nextAfterRoughBuild(_ context.Context, st *SchedulePlanState) (string, error) {
|
||||
if st == nil || len(st.HybridEntries) == 0 {
|
||||
return schedulePlanGraphNodeExit, nil
|
||||
}
|
||||
|
||||
// 1. 连续微调且判定为 small:先走快速微调节点,收缩预算后再进 weekly。
|
||||
if st.IsAdjustment && st.AdjustmentScope == schedulePlanAdjustmentScopeSmall {
|
||||
return schedulePlanGraphNodeQuickRefine, nil
|
||||
}
|
||||
// 2. 连续微调且判定为 medium:直接走 weekly,跳过 daily。
|
||||
if st.IsAdjustment && st.AdjustmentScope == schedulePlanAdjustmentScopeMedium {
|
||||
return schedulePlanGraphNodeWeeklyRefine, nil
|
||||
}
|
||||
// 3. large 或非微调:保持原有逻辑,多任务类走 daily,单任务类直达 weekly。
|
||||
if len(st.TaskClassIDs) >= 2 {
|
||||
return schedulePlanGraphNodeDailySplit, nil
|
||||
}
|
||||
return schedulePlanGraphNodeWeeklyRefine, nil
|
||||
}
|
||||
@@ -1,287 +0,0 @@
|
||||
package scheduleplan
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
)
|
||||
|
||||
const (
|
||||
// schedulePlanTimezoneName 是排程链路默认业务时区。
|
||||
// 与随口记保持一致,固定东八区,避免容器运行在 UTC 导致"明天/今晚"偏移。
|
||||
schedulePlanTimezoneName = "Asia/Shanghai"
|
||||
|
||||
// schedulePlanDatetimeLayout 是排程链路内部统一的分钟级时间格式。
|
||||
schedulePlanDatetimeLayout = "2006-01-02 15:04"
|
||||
|
||||
// schedulePlanDefaultDailyRefineConcurrency 是日内并发优化默认并发度。
|
||||
// 这里给一个保守默认值,避免未配置时直接把模型并发打满导致限流。
|
||||
schedulePlanDefaultDailyRefineConcurrency = 3
|
||||
|
||||
// schedulePlanDefaultWeeklyAdjustBudget 是周级配平默认调整额度。
|
||||
// 额度存在的目的:
|
||||
// 1. 防止周级 ReAct 过度调整导致震荡;
|
||||
// 2. 控制 token 与时延成本;
|
||||
// 3. 让方案改动更可解释。
|
||||
schedulePlanDefaultWeeklyAdjustBudget = 5
|
||||
|
||||
// schedulePlanDefaultWeeklyTotalBudget 是周级“总尝试次数”默认预算。
|
||||
//
|
||||
// 设计意图:
|
||||
// 1. 总预算统计“动作尝试次数”(成功/失败都记一次);
|
||||
// 2. 有效预算统计“成功动作次数”(仅成功时记一次);
|
||||
// 3. 通过双预算把“探索次数”和“有效改动次数”分离,降低模型无效空转成本。
|
||||
schedulePlanDefaultWeeklyTotalBudget = 8
|
||||
|
||||
// schedulePlanDefaultWeeklyRefineConcurrency 是周级“按周并发”默认并发度。
|
||||
// 说明:
|
||||
// 1. 周级输入规模通常比单天更大,默认并发度不宜过高,避免触发模型侧限流;
|
||||
// 2. 可在运行时按请求状态覆盖。
|
||||
schedulePlanDefaultWeeklyRefineConcurrency = 2
|
||||
|
||||
// schedulePlanAdjustmentScopeSmall 表示“小改动微调”。
|
||||
// 语义:优先走快速路径,只做轻量周级调整。
|
||||
schedulePlanAdjustmentScopeSmall = "small"
|
||||
// schedulePlanAdjustmentScopeMedium 表示“中等改动微调”。
|
||||
// 语义:跳过日内拆分,直接进入周级配平。
|
||||
schedulePlanAdjustmentScopeMedium = "medium"
|
||||
// schedulePlanAdjustmentScopeLarge 表示“大改动重排”。
|
||||
// 语义:必要时重新走全量路径(日内并发 + 周级配平)。
|
||||
schedulePlanAdjustmentScopeLarge = "large"
|
||||
)
|
||||
|
||||
// DayGroup 是“按天拆分后”的最小优化单元。
|
||||
//
|
||||
// 设计目的:
|
||||
// 1. 把全量周视角数据拆成“单天小包”,降低日内 ReAct 输入规模;
|
||||
// 2. 支持并发优化不同天的数据,缩短整体等待;
|
||||
// 3. 通过 SkipRefine 让低收益天数直接跳过,节省模型调用成本。
|
||||
type DayGroup struct {
|
||||
Week int
|
||||
DayOfWeek int
|
||||
Entries []model.HybridScheduleEntry
|
||||
SkipRefine bool
|
||||
}
|
||||
|
||||
// SchedulePlanState 是“智能排程”链路在 graph 节点间传递的统一状态容器。
|
||||
//
|
||||
// 设计目标:
|
||||
// 1) 收拢排程请求全生命周期的上下文,降低节点间参数散落;
|
||||
// 2) 支持“粗排 -> 日内并发优化 -> 周级配平 -> 终审校验”的完整链路追踪;
|
||||
// 3) 支持连续对话微调:保留上版方案 + 本次约束变更,便于增量重排。
|
||||
type SchedulePlanState struct {
|
||||
// ── 基础上下文 ──
|
||||
TraceID string
|
||||
UserID int
|
||||
ConversationID string
|
||||
RequestNow time.Time
|
||||
RequestNowText string
|
||||
|
||||
// ── plan 节点输出 ──
|
||||
|
||||
// UserIntent 是模型对用户排程意图的结构化摘要(如"帮我安排高数复习计划")。
|
||||
UserIntent string
|
||||
// Constraints 是用户提出的硬约束列表(如 ["早八不排", "周末休息"])。
|
||||
Constraints []string
|
||||
// TaskClassIDs 是本次请求携带的任务类集合(统一主语义)。
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. 这里明确不再维护单值 task_class_id,避免“单值和切片同时存在”导致语义漂移;
|
||||
// 2. 分流依据统一为 len(TaskClassIDs):
|
||||
// 2.1 len==1:跳过 daily 并发,直接进入 weekly refine;
|
||||
// 2.2 len>=2:进入 daily 并发后再 weekly refine;
|
||||
// 3. 输入清洗(去重、过滤非法值)由 plan 节点完成,这里只承载最终状态。
|
||||
TaskClassIDs []int
|
||||
// Strategy 是排程策略(steady/rapid),默认 steady。
|
||||
Strategy string
|
||||
// TaskTags 是“任务项 ID -> 认知类型标签”的映射。
|
||||
// 使用 ID 而不是名称,目的是规避“同名任务”带来的映射冲突。
|
||||
TaskTags map[int]string
|
||||
// TaskTagHintsByName 是“任务名称 -> 认知类型标签”的临时映射。
|
||||
// 该字段只作为 plan 输出兼容层:
|
||||
// 1. 若模型暂时给不出 task_item_id,只给名称;
|
||||
// 2. 后续在 hybridBuild/dailySplit 阶段再转换为 TaskTags(ID 维度)。
|
||||
TaskTagHintsByName map[string]string
|
||||
|
||||
// ── preview 节点输出 ──
|
||||
|
||||
// CandidatePlans 是粗排算法生成的候选方案(展示型结构,供后续节点做预览与总结)。
|
||||
CandidatePlans []model.UserWeekSchedule
|
||||
// AllocatedItems 是粗排算法已分配的任务项(EmbeddedTime 已回填),供 ReAct 精排使用。
|
||||
AllocatedItems []model.TaskClassItem
|
||||
// HasPlanningWindow 标记是否成功解析出“任务类时间窗”的相对周/天边界。
|
||||
//
|
||||
// 语义:
|
||||
// 1. true:PlanStart*/PlanEnd* 字段可用于 Move 工具的硬边界校验;
|
||||
// 2. false:表示当前运行未拿到窗口信息(例如依赖未注入),工具层将仅做基础校验。
|
||||
HasPlanningWindow bool
|
||||
// PlanStartWeek / PlanStartDay 表示全局排程窗口起点(相对周/天)。
|
||||
PlanStartWeek int
|
||||
PlanStartDay int
|
||||
// PlanEndWeek / PlanEndDay 表示全局排程窗口终点(相对周/天)。
|
||||
PlanEndWeek int
|
||||
PlanEndDay int
|
||||
|
||||
// ── 日内并发优化阶段 ──
|
||||
|
||||
// DailyGroups 是按 (week, day) 拆分后的单日优化输入。
|
||||
// 结构:week -> day -> DayGroup。
|
||||
DailyGroups map[int]map[int]*DayGroup
|
||||
// DailyResults 是单日优化输出。
|
||||
// 结构:week -> day -> []HybridScheduleEntry。
|
||||
DailyResults map[int]map[int][]model.HybridScheduleEntry
|
||||
// DailyRefineConcurrency 是日内并发优化的并发度。
|
||||
// 说明:该值由配置注入,可按环境调节。
|
||||
DailyRefineConcurrency int
|
||||
|
||||
// ── 周级 ReAct 精排阶段 ──
|
||||
|
||||
// HybridEntries 是混合日程条目列表,包含既有日程(existing)和粗排建议(suggested)。
|
||||
// 周级 ReAct 工具直接在此切片上操作(内存修改,不涉及 DB)。
|
||||
HybridEntries []model.HybridScheduleEntry
|
||||
// MergeSnapshot 是 merge 后快照。
|
||||
// 终审失败时回退到该快照,确保至少保留“日内优化成果”。
|
||||
MergeSnapshot []model.HybridScheduleEntry
|
||||
// ReactRound 当前周级 ReAct 循环轮次。
|
||||
ReactRound int
|
||||
// ReactMaxRound 周级 ReAct 最大循环轮次。
|
||||
ReactMaxRound int
|
||||
// ReactSummary 周级 ReAct 输出的优化摘要。
|
||||
ReactSummary string
|
||||
// ReactDone 标记周级 ReAct 是否已完成。
|
||||
ReactDone bool
|
||||
// WeeklyAdjustBudget 是周级跨天调整额度上限。
|
||||
// 语义:有效动作预算(仅工具调用成功时扣减)。
|
||||
WeeklyAdjustBudget int
|
||||
// WeeklyAdjustUsed 是周级跨天调整已使用额度。
|
||||
// 语义:有效动作已使用次数(仅成功调用时递增)。
|
||||
WeeklyAdjustUsed int
|
||||
// WeeklyTotalBudget 是周级总动作预算。
|
||||
// 语义:总尝试次数预算(成功/失败都扣减)。
|
||||
WeeklyTotalBudget int
|
||||
// WeeklyTotalUsed 是周级总动作已使用次数。
|
||||
// 语义:成功/失败每执行一次工具调用都递增。
|
||||
WeeklyTotalUsed int
|
||||
// WeeklyRefineConcurrency 是周级“按周并发”并发度。
|
||||
WeeklyRefineConcurrency int
|
||||
// WeeklyActionLogs 记录周级优化阶段的关键动作流水。
|
||||
//
|
||||
// 设计目的:
|
||||
// 1. 供 final_check 的总结模型理解“优化过程”,而非只看最终静态结果;
|
||||
// 2. 供调试排查时快速回放“每轮做了什么动作、是否成功、为何失败”。
|
||||
WeeklyActionLogs []string
|
||||
|
||||
// ── 连续对话微调 ──
|
||||
|
||||
// PreviousPlanJSON 是上一版已落库方案的 JSON 序列化,用于增量微调。
|
||||
// 从对话历史中提取,不做持久化。
|
||||
PreviousPlanJSON string
|
||||
// IsAdjustment 标记本次是否为微调请求(而非全新排程)。
|
||||
IsAdjustment bool
|
||||
// RestartRequested 标记本轮是否要求“放弃历史快照并重新排程”。
|
||||
//
|
||||
// 语义:
|
||||
// 1. true:强制清空 Previous* 并走全新构建;
|
||||
// 2. false:允许按同会话历史快照做增量微调。
|
||||
RestartRequested bool
|
||||
// AdjustmentScope 表示本轮改动力度分级(small/medium/large)。
|
||||
//
|
||||
// 分流语义:
|
||||
// 1. small:走快速微调节点,再进入周级优化;
|
||||
// 2. medium:跳过 daily,直接周级优化;
|
||||
// 3. large:优先走全量路径(多任务类时会经过 daily 并发)。
|
||||
AdjustmentScope string
|
||||
// AdjustmentReason 是模型给出的力度判定理由,用于日志排障与 review。
|
||||
AdjustmentReason string
|
||||
// AdjustmentConfidence 是模型给出的力度判定置信度(0-1)。
|
||||
AdjustmentConfidence float64
|
||||
// HasPreviousPreview 标记是否命中“同会话上一次排程预览快照”。
|
||||
//
|
||||
// 语义:
|
||||
// 1. true:可以尝试复用上次 HybridEntries 作为本轮优化起点;
|
||||
// 2. false:按全新排程路径构建粗排底板。
|
||||
HasPreviousPreview bool
|
||||
// PreviousTaskClassIDs 是上一次预览对应的任务类集合。
|
||||
//
|
||||
// 用途:
|
||||
// 1. 本轮未显式传 task_class_ids 时作为兜底;
|
||||
// 2. 仅会话内承接,不改动数据库。
|
||||
PreviousTaskClassIDs []int
|
||||
// PreviousHybridEntries 是上一次预览保存的混合日程条目。
|
||||
//
|
||||
// 用途:
|
||||
// 1. 连续对话微调时直接复用,避免重新粗排;
|
||||
// 2. 若为空则回退到粗排构建路径。
|
||||
PreviousHybridEntries []model.HybridScheduleEntry
|
||||
// PreviousAllocatedItems 是上一次预览保存的任务块分配结果。
|
||||
//
|
||||
// 用途:
|
||||
// 1. 保持 final_check 的数量核对口径稳定;
|
||||
// 2. return_preview 阶段可继续回填 embedded_time。
|
||||
PreviousAllocatedItems []model.TaskClassItem
|
||||
// PreviousCandidatePlans 是上一版预览保存的周视图结构化结果。
|
||||
//
|
||||
// 用途:
|
||||
// 1. 连续微调时可直接复用,避免重复转换;
|
||||
// 2. 兜底展示层(即使本轮未走全量粗排,仍可给前端稳定结构)。
|
||||
PreviousCandidatePlans []model.UserWeekSchedule
|
||||
|
||||
// ── 最终输出 ──
|
||||
|
||||
// FinalSummary 是 graph 最终给用户的回复文案。
|
||||
FinalSummary string
|
||||
// Completed 标记整个排程链路是否成功完成。
|
||||
Completed bool
|
||||
}
|
||||
|
||||
// NewSchedulePlanState 创建排程状态对象并初始化默认值。
|
||||
func NewSchedulePlanState(traceID string, userID int, conversationID string) *SchedulePlanState {
|
||||
now := schedulePlanNowToMinute()
|
||||
return &SchedulePlanState{
|
||||
TraceID: traceID,
|
||||
UserID: userID,
|
||||
ConversationID: conversationID,
|
||||
RequestNow: now,
|
||||
RequestNowText: now.In(schedulePlanLocation()).Format(schedulePlanDatetimeLayout),
|
||||
Strategy: "steady",
|
||||
TaskTags: make(map[int]string),
|
||||
TaskTagHintsByName: make(map[string]string),
|
||||
DailyRefineConcurrency: schedulePlanDefaultDailyRefineConcurrency,
|
||||
WeeklyRefineConcurrency: schedulePlanDefaultWeeklyRefineConcurrency,
|
||||
AdjustmentScope: schedulePlanAdjustmentScopeLarge,
|
||||
ReactMaxRound: 2,
|
||||
WeeklyAdjustBudget: schedulePlanDefaultWeeklyAdjustBudget,
|
||||
WeeklyTotalBudget: schedulePlanDefaultWeeklyTotalBudget,
|
||||
}
|
||||
}
|
||||
|
||||
// schedulePlanLocation 返回排程链路使用的业务时区。
|
||||
func schedulePlanLocation() *time.Location {
|
||||
loc, err := time.LoadLocation(schedulePlanTimezoneName)
|
||||
if err != nil {
|
||||
return time.Local
|
||||
}
|
||||
return loc
|
||||
}
|
||||
|
||||
// schedulePlanNowToMinute 返回当前时间并截断到分钟级。
|
||||
func schedulePlanNowToMinute() time.Time {
|
||||
return time.Now().In(schedulePlanLocation()).Truncate(time.Minute)
|
||||
}
|
||||
|
||||
// normalizeAdjustmentScope 归一化排程微调力度字段。
|
||||
//
|
||||
// 兜底策略:
|
||||
// 1. 只接受 small/medium/large;
|
||||
// 2. 任何未知值都回退为 large,保证不会误走“过轻”路径。
|
||||
func normalizeAdjustmentScope(raw string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case schedulePlanAdjustmentScopeSmall:
|
||||
return schedulePlanAdjustmentScopeSmall
|
||||
case schedulePlanAdjustmentScopeMedium:
|
||||
return schedulePlanAdjustmentScopeMedium
|
||||
default:
|
||||
return schedulePlanAdjustmentScopeLarge
|
||||
}
|
||||
}
|
||||
@@ -1,147 +0,0 @@
|
||||
package scheduleplan
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
)
|
||||
|
||||
// SchedulePlanToolDeps 描述“智能排程 graph”运行所需的外部业务依赖。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只负责声明“需要哪些能力”,不负责具体实现(实现由 service 层注入)。
|
||||
// 2. 只收口函数签名,不承载业务状态,避免跨请求共享可变数据。
|
||||
// 3. 当前统一采用 task_class_ids 语义,不再依赖单 task_class_id 主路径。
|
||||
type SchedulePlanToolDeps struct {
|
||||
// SmartPlanningMultiRaw 是可选依赖:
|
||||
// 1) 用于需要单独输出“粗排预览”时复用;
|
||||
// 2) 当前主链路已由 HybridScheduleWithPlanMulti 覆盖,可不注入。
|
||||
SmartPlanningMultiRaw func(ctx context.Context, userID int, taskClassIDs []int) ([]model.UserWeekSchedule, []model.TaskClassItem, error)
|
||||
|
||||
// HybridScheduleWithPlanMulti 把“既有日程 + 粗排结果”合并成统一的 HybridScheduleEntry 切片,
|
||||
// 供 daily/weekly ReAct 节点在内存中继续优化。
|
||||
HybridScheduleWithPlanMulti func(ctx context.Context, userID int, taskClassIDs []int) ([]model.HybridScheduleEntry, []model.TaskClassItem, error)
|
||||
|
||||
// ResolvePlanningWindow 根据 task_class_ids 解析“全局排程窗口”的相对周/天边界。
|
||||
//
|
||||
// 返回语义:
|
||||
// 1. startWeek/startDay:窗口起点(含);
|
||||
// 2. endWeek/endDay:窗口终点(含);
|
||||
// 3. error:解析失败(如任务类不存在、日期非法)。
|
||||
//
|
||||
// 用途:
|
||||
// 1. 给周级 Move 工具加硬边界,避免把任务移动到窗口外的天数;
|
||||
// 2. 解决“首尾不足一周”场景下的周内越界问题。
|
||||
ResolvePlanningWindow func(ctx context.Context, userID int, taskClassIDs []int) (startWeek, startDay, endWeek, endDay int, err error)
|
||||
}
|
||||
|
||||
// validate 校验依赖完整性。
|
||||
//
|
||||
// 失败处理:
|
||||
// 1. 任意依赖缺失都直接返回错误,避免 graph 运行到中途才 panic。
|
||||
// 2. 调用方(runSchedulePlanFlow)收到错误后会走回退链路,不影响普通聊天可用性。
|
||||
func (d SchedulePlanToolDeps) validate() error {
|
||||
if d.HybridScheduleWithPlanMulti == nil {
|
||||
return errors.New("schedule plan tool deps: HybridScheduleWithPlanMulti is nil")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtraInt 从 extra map 中安全提取整数值。
|
||||
//
|
||||
// 兼容策略:
|
||||
// 1) JSON 数字默认解析为 float64,做 int 转换;
|
||||
// 2) 兼容字符串形式(如 "42"),用 Atoi 解析;
|
||||
// 3) 其余类型返回 false,由调用方决定后续处理。
|
||||
func ExtraInt(extra map[string]any, key string) (int, bool) {
|
||||
v, ok := extra[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return int(n), true
|
||||
case int:
|
||||
return n, true
|
||||
case string:
|
||||
i, err := strconv.Atoi(n)
|
||||
return i, err == nil
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// ExtraIntSlice 从 extra map 中安全提取整数切片。
|
||||
//
|
||||
// 兼容输入:
|
||||
// 1) []any(JSON 数组反序列化后的常见类型);
|
||||
// 2) []int;
|
||||
// 3) []float64;
|
||||
// 4) 逗号分隔字符串(例如 "1,2,3")。
|
||||
//
|
||||
// 返回语义:
|
||||
// 1) ok=true:至少成功解析出一个整数;
|
||||
// 2) ok=false:字段不存在或全部解析失败。
|
||||
func ExtraIntSlice(extra map[string]any, key string) ([]int, bool) {
|
||||
v, exists := extra[key]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
parseOne := func(raw any) (int, error) {
|
||||
switch n := raw.(type) {
|
||||
case int:
|
||||
return n, nil
|
||||
case float64:
|
||||
return int(n), nil
|
||||
case string:
|
||||
i, err := strconv.Atoi(n)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return i, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported type: %T", raw)
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]int, 0)
|
||||
switch arr := v.(type) {
|
||||
case []int:
|
||||
for _, item := range arr {
|
||||
out = append(out, item)
|
||||
}
|
||||
case []float64:
|
||||
for _, item := range arr {
|
||||
out = append(out, int(item))
|
||||
}
|
||||
case []any:
|
||||
for _, item := range arr {
|
||||
if parsed, err := parseOne(item); err == nil {
|
||||
out = append(out, parsed)
|
||||
}
|
||||
}
|
||||
case string:
|
||||
parts := strings.Split(arr, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if parsed, err := strconv.Atoi(part); err == nil {
|
||||
out = append(out, parsed)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if len(out) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package schedulerefine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
)
|
||||
|
||||
func TestRefineToolSpreadEvenRespectsCanonicalRouteFilters(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 1, Name: "任务1", Type: "task", Status: "suggested", Week: 16, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, ContextTag: "A"},
|
||||
// 1. 这里放一个更早周次的 existing 条目,用来把可查询窗口拉到 W11;
|
||||
// 2. 若复合工具内部丢了 week_filter/day_of_week,就会优先落到更早的 W11D1,而不是目标 W12D3。
|
||||
{TaskItemID: 99, Name: "课程", Type: "course", Status: "existing", Week: 11, DayOfWeek: 5, SectionFrom: 11, SectionTo: 12, BlockForSuggested: true},
|
||||
}
|
||||
params := map[string]any{
|
||||
"task_item_ids": []int{1},
|
||||
"week_filter": []int{12},
|
||||
"day_of_week": []int{3},
|
||||
"allow_embed": false,
|
||||
}
|
||||
|
||||
nextEntries, result := refineToolSpreadEven(entries, params, planningWindow{Enabled: false}, refineToolPolicy{
|
||||
OriginOrderMap: map[int]int{1: 1},
|
||||
})
|
||||
if !result.Success {
|
||||
t.Fatalf("SpreadEven 执行失败: %s", result.Result)
|
||||
}
|
||||
|
||||
idx := findSuggestedByID(nextEntries, 1)
|
||||
if idx < 0 {
|
||||
t.Fatalf("未找到 task_item_id=1")
|
||||
}
|
||||
got := nextEntries[idx]
|
||||
if got.Week != 12 || got.DayOfWeek != 3 {
|
||||
t.Fatalf("期望复合工具严格遵守 week_filter/day_of_week,实际落点=W%dD%d", got.Week, got.DayOfWeek)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCompositeRouteNodeAllowsHandoffWithoutDeterministicObjective(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 11, Name: "任务11", Type: "task", Status: "suggested", Week: 16, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, ContextTag: "数学"},
|
||||
{TaskItemID: 12, Name: "任务12", Type: "task", Status: "suggested", Week: 16, DayOfWeek: 1, SectionFrom: 3, SectionTo: 4, ContextTag: "算法"},
|
||||
{TaskItemID: 13, Name: "任务13", Type: "task", Status: "suggested", Week: 16, DayOfWeek: 1, SectionFrom: 5, SectionTo: 6, ContextTag: "数学"},
|
||||
}
|
||||
st := &ScheduleRefineState{
|
||||
UserMessage: "把这些任务按最少上下文切换整理一下",
|
||||
HybridEntries: cloneHybridEntries(entries),
|
||||
InitialHybridEntries: cloneHybridEntries(entries),
|
||||
WorksetTaskIDs: []int{11, 12, 13},
|
||||
RequiredCompositeTool: "MinContextSwitch",
|
||||
CompositeRetryMax: 0,
|
||||
ExecuteMax: 4,
|
||||
OriginOrderMap: map[int]int{11: 1, 12: 2, 13: 3},
|
||||
CompositeToolCalled: map[string]bool{
|
||||
"SpreadEven": false,
|
||||
"MinContextSwitch": false,
|
||||
},
|
||||
CompositeToolSuccess: map[string]bool{
|
||||
"SpreadEven": false,
|
||||
"MinContextSwitch": false,
|
||||
},
|
||||
}
|
||||
|
||||
stageLogs := make([]string, 0, 8)
|
||||
nextState, err := runCompositeRouteNode(context.Background(), st, func(stage, detail string) {
|
||||
stageLogs = append(stageLogs, stage+"|"+detail)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runCompositeRouteNode 返回错误: %v", err)
|
||||
}
|
||||
if nextState == nil {
|
||||
t.Fatalf("runCompositeRouteNode 返回 nil state")
|
||||
}
|
||||
if !nextState.CompositeRouteSucceeded {
|
||||
t.Fatalf("期望复合分支在缺少 deterministic objective 时直接出站,实际 CompositeRouteSucceeded=false, stages=%v, action_logs=%v", stageLogs, nextState.ActionLogs)
|
||||
}
|
||||
if nextState.DisableCompositeTools {
|
||||
t.Fatalf("期望复合分支直接进入终审,不应降级为禁复合 ReAct")
|
||||
}
|
||||
if !nextState.CompositeToolSuccess["MinContextSwitch"] {
|
||||
t.Fatalf("期望 MinContextSwitch 成功状态被记录")
|
||||
}
|
||||
}
|
||||
@@ -1,179 +0,0 @@
|
||||
package schedulerefine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
)
|
||||
|
||||
func TestRefineToolSpreadEvenSuccess(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 1, Name: "任务1", Type: "task", Status: "suggested", Week: 16, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, ContextTag: "A"},
|
||||
{TaskItemID: 2, Name: "任务2", Type: "task", Status: "suggested", Week: 16, DayOfWeek: 1, SectionFrom: 3, SectionTo: 4, ContextTag: "B"},
|
||||
{TaskItemID: 99, Name: "课程", Type: "course", Status: "existing", Week: 12, DayOfWeek: 1, SectionFrom: 5, SectionTo: 6, BlockForSuggested: true},
|
||||
}
|
||||
params := map[string]any{
|
||||
"task_item_ids": []any{1.0, 2.0},
|
||||
"week": 12,
|
||||
"day_of_week": []any{1.0, 2.0, 3.0},
|
||||
"allow_embed": false,
|
||||
}
|
||||
policy := refineToolPolicy{OriginOrderMap: map[int]int{1: 1, 2: 2}}
|
||||
|
||||
nextEntries, result := refineToolSpreadEven(entries, params, planningWindow{Enabled: false}, policy)
|
||||
if !result.Success {
|
||||
t.Fatalf("SpreadEven 执行失败: %s", result.Result)
|
||||
}
|
||||
if result.Tool != "SpreadEven" {
|
||||
t.Fatalf("工具名错误,期望 SpreadEven,实际=%s", result.Tool)
|
||||
}
|
||||
|
||||
idx1 := findSuggestedByID(nextEntries, 1)
|
||||
idx2 := findSuggestedByID(nextEntries, 2)
|
||||
if idx1 < 0 || idx2 < 0 {
|
||||
t.Fatalf("移动后未找到目标任务: idx1=%d idx2=%d", idx1, idx2)
|
||||
}
|
||||
task1 := nextEntries[idx1]
|
||||
task2 := nextEntries[idx2]
|
||||
if task1.Week != 12 || task2.Week != 12 {
|
||||
t.Fatalf("期望任务被移动到 W12,实际 task1=%d task2=%d", task1.Week, task2.Week)
|
||||
}
|
||||
if task1.DayOfWeek < 1 || task1.DayOfWeek > 3 || task2.DayOfWeek < 1 || task2.DayOfWeek > 3 {
|
||||
t.Fatalf("期望任务被移动到周一到周三,实际 task1=%d task2=%d", task1.DayOfWeek, task2.DayOfWeek)
|
||||
}
|
||||
if task1.DayOfWeek == task2.DayOfWeek && sectionsOverlap(task1.SectionFrom, task1.SectionTo, task2.SectionFrom, task2.SectionTo) {
|
||||
t.Fatalf("复合工具不应产出重叠坑位: task1=%+v task2=%+v", task1, task2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefineToolMinContextSwitchGroupsContext(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 11, Name: "任务11", Type: "task", Status: "suggested", Week: 16, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, ContextTag: "数学"},
|
||||
{TaskItemID: 12, Name: "任务12", Type: "task", Status: "suggested", Week: 16, DayOfWeek: 1, SectionFrom: 3, SectionTo: 4, ContextTag: "算法"},
|
||||
{TaskItemID: 13, Name: "任务13", Type: "task", Status: "suggested", Week: 16, DayOfWeek: 1, SectionFrom: 5, SectionTo: 6, ContextTag: "数学"},
|
||||
{TaskItemID: 99, Name: "课程", Type: "course", Status: "existing", Week: 12, DayOfWeek: 1, SectionFrom: 11, SectionTo: 12, BlockForSuggested: true},
|
||||
}
|
||||
params := map[string]any{
|
||||
"task_item_ids": []any{11.0, 12.0, 13.0},
|
||||
"week": 12,
|
||||
"day_of_week": []any{1.0},
|
||||
}
|
||||
policy := refineToolPolicy{OriginOrderMap: map[int]int{11: 1, 12: 2, 13: 3}}
|
||||
|
||||
nextEntries, result := refineToolMinContextSwitch(entries, params, planningWindow{Enabled: false}, policy)
|
||||
if !result.Success {
|
||||
t.Fatalf("MinContextSwitch 执行失败: %s", result.Result)
|
||||
}
|
||||
if result.Tool != "MinContextSwitch" {
|
||||
t.Fatalf("工具名错误,期望 MinContextSwitch,实际=%s", result.Tool)
|
||||
}
|
||||
|
||||
selected := make([]model.HybridScheduleEntry, 0, 3)
|
||||
for _, id := range []int{11, 12, 13} {
|
||||
idx := findSuggestedByID(nextEntries, id)
|
||||
if idx < 0 {
|
||||
t.Fatalf("未找到任务 id=%d", id)
|
||||
}
|
||||
selected = append(selected, nextEntries[idx])
|
||||
}
|
||||
sort.SliceStable(selected, func(i, j int) bool {
|
||||
if selected[i].Week != selected[j].Week {
|
||||
return selected[i].Week < selected[j].Week
|
||||
}
|
||||
if selected[i].DayOfWeek != selected[j].DayOfWeek {
|
||||
return selected[i].DayOfWeek < selected[j].DayOfWeek
|
||||
}
|
||||
return selected[i].SectionFrom < selected[j].SectionFrom
|
||||
})
|
||||
|
||||
switches := 0
|
||||
for i := 1; i < len(selected); i++ {
|
||||
if selected[i].ContextTag != selected[i-1].ContextTag {
|
||||
switches++
|
||||
}
|
||||
}
|
||||
if switches > 1 {
|
||||
t.Fatalf("期望最少上下文切换(<=1),实际 switches=%d, tasks=%+v", switches, selected)
|
||||
}
|
||||
if selected[0].TaskItemID != 11 || selected[1].TaskItemID != 13 || selected[2].TaskItemID != 12 {
|
||||
t.Fatalf("期望在原坑位集合内重排为 11,13,12,实际=%+v", selected)
|
||||
}
|
||||
for _, task := range selected {
|
||||
if task.Week != 16 || task.DayOfWeek != 1 {
|
||||
t.Fatalf("MinContextSwitch 不应跳出原坑位集合,实际 task=%+v", task)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefineToolMinContextSwitchKeepsCurrentSlotSet(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 21, Name: "随机事件与概率基础概念复习", Type: "task", Status: "suggested", Week: 14, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, ContextTag: "General"},
|
||||
{TaskItemID: 22, Name: "数制、码制与逻辑代数基础", Type: "task", Status: "suggested", Week: 14, DayOfWeek: 1, SectionFrom: 11, SectionTo: 12, ContextTag: "General"},
|
||||
{TaskItemID: 23, Name: "第二章 条件概率与全概率公式", Type: "task", Status: "suggested", Week: 14, DayOfWeek: 3, SectionFrom: 3, SectionTo: 4, ContextTag: "General"},
|
||||
}
|
||||
params := map[string]any{
|
||||
"task_item_ids": []any{21.0, 22.0, 23.0},
|
||||
"week": 14,
|
||||
"limit": 48,
|
||||
"allow_embed": true,
|
||||
}
|
||||
policy := refineToolPolicy{OriginOrderMap: map[int]int{21: 1, 22: 2, 23: 3}}
|
||||
|
||||
nextEntries, result := refineToolMinContextSwitch(entries, params, planningWindow{Enabled: false}, policy)
|
||||
if !result.Success {
|
||||
t.Fatalf("MinContextSwitch 执行失败: %s", result.Result)
|
||||
}
|
||||
|
||||
selected := make([]model.HybridScheduleEntry, 0, 3)
|
||||
for _, id := range []int{21, 22, 23} {
|
||||
idx := findSuggestedByID(nextEntries, id)
|
||||
if idx < 0 {
|
||||
t.Fatalf("未找到任务 id=%d", id)
|
||||
}
|
||||
selected = append(selected, nextEntries[idx])
|
||||
}
|
||||
sort.SliceStable(selected, func(i, j int) bool {
|
||||
if selected[i].Week != selected[j].Week {
|
||||
return selected[i].Week < selected[j].Week
|
||||
}
|
||||
if selected[i].DayOfWeek != selected[j].DayOfWeek {
|
||||
return selected[i].DayOfWeek < selected[j].DayOfWeek
|
||||
}
|
||||
return selected[i].SectionFrom < selected[j].SectionFrom
|
||||
})
|
||||
|
||||
if selected[0].TaskItemID != 21 || selected[1].TaskItemID != 23 || selected[2].TaskItemID != 22 {
|
||||
t.Fatalf("期望按原坑位集合重排为概率, 概率, 数电,实际=%+v", selected)
|
||||
}
|
||||
expectedSlots := map[int]string{
|
||||
21: "14-1-1-2",
|
||||
23: "14-1-11-12",
|
||||
22: "14-3-3-4",
|
||||
}
|
||||
for _, task := range selected {
|
||||
got := fmt.Sprintf("%d-%d-%d-%d", task.Week, task.DayOfWeek, task.SectionFrom, task.SectionTo)
|
||||
if got != expectedSlots[task.TaskItemID] {
|
||||
t.Fatalf("任务 id=%d 应仅在原坑位集合内换位,期望=%s 实际=%s", task.TaskItemID, expectedSlots[task.TaskItemID], got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListTaskIDsFromToolCallComposite(t *testing.T) {
|
||||
call := reactToolCall{
|
||||
Tool: "SpreadEven",
|
||||
Params: map[string]any{
|
||||
"task_item_ids": []any{1.0, 2.0, 2.0},
|
||||
"task_item_id": 3,
|
||||
},
|
||||
}
|
||||
ids := listTaskIDsFromToolCall(call)
|
||||
if len(ids) != 3 {
|
||||
t.Fatalf("期望提取 3 个去重 ID,实际=%v", ids)
|
||||
}
|
||||
sort.Ints(ids)
|
||||
if ids[0] != 1 || ids[1] != 2 || ids[2] != 3 {
|
||||
t.Fatalf("提取结果错误,实际=%v", ids)
|
||||
}
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
package schedulerefine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
const (
|
||||
graphNodeContract = "schedule_refine_contract"
|
||||
graphNodePlan = "schedule_refine_plan"
|
||||
graphNodeSlice = "schedule_refine_slice"
|
||||
graphNodeRoute = "schedule_refine_route"
|
||||
graphNodeReact = "schedule_refine_react"
|
||||
graphNodeHardCheck = "schedule_refine_hard_check"
|
||||
graphNodeSummary = "schedule_refine_summary"
|
||||
)
|
||||
|
||||
// ScheduleRefineGraphRunInput 是“连续微调图”运行参数。
|
||||
//
|
||||
// 字段语义:
|
||||
// 1. Model:本轮图运行使用的聊天模型。
|
||||
// 2. State:预先注入的微调状态(通常来自上一版预览快照)。
|
||||
// 3. EmitStage:SSE 阶段回调,允许服务层把阶段进度透传给前端。
|
||||
type ScheduleRefineGraphRunInput struct {
|
||||
Model *ark.ChatModel
|
||||
State *ScheduleRefineState
|
||||
EmitStage func(stage, detail string)
|
||||
}
|
||||
|
||||
// RunScheduleRefineGraph 执行“连续微调”独立图链路。
|
||||
//
|
||||
// 链路顺序:
|
||||
// START -> contract -> plan -> slice -> route -> react -> hard_check -> summary -> END
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. 当前链路采用线性图,确保可读性优先;
|
||||
// 2. “终审失败后单次修复”在 hard_check 节点内部闭环处理,避免图连线分叉过多;
|
||||
// 3. 若后续需要引入多分支策略(例如大改动转重排),可在 contract 后追加 branch 节点。
|
||||
func RunScheduleRefineGraph(ctx context.Context, input ScheduleRefineGraphRunInput) (*ScheduleRefineState, error) {
|
||||
if input.Model == nil {
|
||||
return nil, fmt.Errorf("schedule refine graph: model is nil")
|
||||
}
|
||||
if input.State == nil {
|
||||
return nil, fmt.Errorf("schedule refine graph: state is nil")
|
||||
}
|
||||
|
||||
emitStage := func(stage, detail string) {
|
||||
if input.EmitStage != nil {
|
||||
input.EmitStage(stage, detail)
|
||||
}
|
||||
}
|
||||
runner := newScheduleRefineRunner(input.Model, emitStage)
|
||||
|
||||
graph := compose.NewGraph[*ScheduleRefineState, *ScheduleRefineState]()
|
||||
if err := graph.AddLambdaNode(graphNodeContract, compose.InvokableLambda(runner.contractNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(graphNodePlan, compose.InvokableLambda(runner.planNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(graphNodeSlice, compose.InvokableLambda(runner.sliceNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(graphNodeRoute, compose.InvokableLambda(runner.routeNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(graphNodeReact, compose.InvokableLambda(runner.reactNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(graphNodeHardCheck, compose.InvokableLambda(runner.hardCheckNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddLambdaNode(graphNodeSummary, compose.InvokableLambda(runner.summaryNode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := graph.AddEdge(compose.START, graphNodeContract); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(graphNodeContract, graphNodePlan); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(graphNodePlan, graphNodeSlice); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(graphNodeSlice, graphNodeRoute); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(graphNodeRoute, graphNodeReact); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(graphNodeReact, graphNodeHardCheck); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(graphNodeHardCheck, graphNodeSummary); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := graph.AddEdge(graphNodeSummary, compose.END); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
runnable, err := graph.Compile(ctx,
|
||||
compose.WithGraphName("ScheduleRefineGraph"),
|
||||
compose.WithMaxRunSteps(20),
|
||||
compose.WithNodeTriggerMode(compose.AnyPredecessor),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return runnable.Invoke(ctx, input.State)
|
||||
}
|
||||
@@ -1,637 +0,0 @@
|
||||
package schedulerefine
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/model"
|
||||
)
|
||||
|
||||
func TestQueryTargetTasksWeekFilterAndTaskID(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 1, Name: "task-w12", Week: 12, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, Status: "suggested", Type: "task"},
|
||||
{TaskItemID: 2, Name: "task-w13", Week: 13, DayOfWeek: 1, SectionFrom: 3, SectionTo: 4, Status: "suggested", Type: "task"},
|
||||
{TaskItemID: 3, Name: "task-w14", Week: 14, DayOfWeek: 1, SectionFrom: 5, SectionTo: 6, Status: "suggested", Type: "task"},
|
||||
}
|
||||
policy := refineToolPolicy{OriginOrderMap: map[int]int{1: 1, 2: 2, 3: 3}}
|
||||
|
||||
paramsWeek := map[string]any{
|
||||
"week_filter": []any{13.0, 14.0},
|
||||
}
|
||||
_, resultWeek := refineToolQueryTargetTasks(entries, paramsWeek, policy)
|
||||
if !resultWeek.Success {
|
||||
t.Fatalf("week_filter 查询失败: %s", resultWeek.Result)
|
||||
}
|
||||
var payloadWeek struct {
|
||||
Count int `json:"count"`
|
||||
Items []struct {
|
||||
TaskItemID int `json:"task_item_id"`
|
||||
Week int `json:"week"`
|
||||
} `json:"items"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(resultWeek.Result), &payloadWeek); err != nil {
|
||||
t.Fatalf("解析 week_filter 结果失败: %v", err)
|
||||
}
|
||||
if payloadWeek.Count != 2 {
|
||||
t.Fatalf("week_filter 期望返回 2 条,实际=%d", payloadWeek.Count)
|
||||
}
|
||||
for _, item := range payloadWeek.Items {
|
||||
if item.Week != 13 && item.Week != 14 {
|
||||
t.Fatalf("week_filter 过滤失败,出现非法周次=%d", item.Week)
|
||||
}
|
||||
}
|
||||
|
||||
paramsTaskID := map[string]any{
|
||||
"week_filter": []any{13.0, 14.0},
|
||||
"task_item_id": 2,
|
||||
}
|
||||
_, resultTaskID := refineToolQueryTargetTasks(entries, paramsTaskID, policy)
|
||||
if !resultTaskID.Success {
|
||||
t.Fatalf("task_item_id 查询失败: %s", resultTaskID.Result)
|
||||
}
|
||||
var payloadTaskID struct {
|
||||
Count int `json:"count"`
|
||||
Items []struct {
|
||||
TaskItemID int `json:"task_item_id"`
|
||||
Week int `json:"week"`
|
||||
} `json:"items"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(resultTaskID.Result), &payloadTaskID); err != nil {
|
||||
t.Fatalf("解析 task_item_id 结果失败: %v", err)
|
||||
}
|
||||
if payloadTaskID.Count != 1 {
|
||||
t.Fatalf("task_item_id 期望返回 1 条,实际=%d", payloadTaskID.Count)
|
||||
}
|
||||
if payloadTaskID.Items[0].TaskItemID != 2 || payloadTaskID.Items[0].Week != 13 {
|
||||
t.Fatalf("task_item_id 过滤错误: %+v", payloadTaskID.Items[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryAvailableSlotsExactSectionAlias(t *testing.T) {
|
||||
params := map[string]any{
|
||||
"week": 13,
|
||||
"section_duration": 2,
|
||||
"section_from": 1,
|
||||
"section_to": 2,
|
||||
"limit": 5,
|
||||
}
|
||||
_, result := refineToolQueryAvailableSlots(nil, params, planningWindow{Enabled: false})
|
||||
if !result.Success {
|
||||
t.Fatalf("QueryAvailableSlots 失败: %s", result.Result)
|
||||
}
|
||||
var payload struct {
|
||||
Count int `json:"count"`
|
||||
Slots []struct {
|
||||
Week int `json:"week"`
|
||||
SectionFrom int `json:"section_from"`
|
||||
SectionTo int `json:"section_to"`
|
||||
} `json:"slots"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(result.Result), &payload); err != nil {
|
||||
t.Fatalf("解析 QueryAvailableSlots 结果失败: %v", err)
|
||||
}
|
||||
if payload.Count == 0 {
|
||||
t.Fatalf("期望至少返回一个可用时段,实际=0")
|
||||
}
|
||||
for _, slot := range payload.Slots {
|
||||
if slot.Week != 13 {
|
||||
t.Fatalf("返回了错误周次: %+v", slot)
|
||||
}
|
||||
if slot.SectionFrom != 1 || slot.SectionTo != 2 {
|
||||
t.Fatalf("精确节次过滤失败: %+v", slot)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryAvailableSlotsWeekFilterDayFilterAlias(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 1, Name: "task-w12", Week: 12, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, Status: "suggested", Type: "task"},
|
||||
{TaskItemID: 2, Name: "task-w17", Week: 17, DayOfWeek: 4, SectionFrom: 3, SectionTo: 4, Status: "suggested", Type: "task"},
|
||||
}
|
||||
params := map[string]any{
|
||||
"week_filter": []any{17.0},
|
||||
"day_filter": []any{1.0, 2.0, 3.0},
|
||||
"limit": 20,
|
||||
}
|
||||
|
||||
_, result := refineToolQueryAvailableSlots(entries, params, planningWindow{Enabled: false})
|
||||
if !result.Success {
|
||||
t.Fatalf("QueryAvailableSlots 别名查询失败: %s", result.Result)
|
||||
}
|
||||
var payload struct {
|
||||
Count int `json:"count"`
|
||||
Slots []struct {
|
||||
Week int `json:"week"`
|
||||
DayOfWeek int `json:"day_of_week"`
|
||||
} `json:"slots"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(result.Result), &payload); err != nil {
|
||||
t.Fatalf("解析 week/day 过滤结果失败: %v", err)
|
||||
}
|
||||
if payload.Count == 0 {
|
||||
t.Fatalf("week_filter/day_filter 查询应返回 W17 周一到周三空位,实际为空")
|
||||
}
|
||||
for _, slot := range payload.Slots {
|
||||
if slot.Week != 17 {
|
||||
t.Fatalf("week_filter 失效,出现 week=%d", slot.Week)
|
||||
}
|
||||
if slot.DayOfWeek < 1 || slot.DayOfWeek > 3 {
|
||||
t.Fatalf("day_filter 失效,出现 day_of_week=%d", slot.DayOfWeek)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectWorksetTaskIDsSourceWeekOnly(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 1, Week: 12, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2, Status: "suggested", Type: "task"},
|
||||
{TaskItemID: 2, Week: 14, DayOfWeek: 1, SectionFrom: 3, SectionTo: 4, Status: "suggested", Type: "task"},
|
||||
{TaskItemID: 3, Week: 13, DayOfWeek: 1, SectionFrom: 5, SectionTo: 6, Status: "suggested", Type: "task"},
|
||||
{TaskItemID: 4, Week: 14, DayOfWeek: 2, SectionFrom: 7, SectionTo: 8, Status: "suggested", Type: "task"},
|
||||
}
|
||||
slice := RefineSlicePlan{WeekFilter: []int{14, 13}}
|
||||
originOrder := map[int]int{1: 1, 2: 2, 3: 3, 4: 4}
|
||||
|
||||
got := collectWorksetTaskIDs(entries, slice, originOrder)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("来源周收敛失败,期望 2 条,实际=%d, got=%v", len(got), got)
|
||||
}
|
||||
if got[0] != 2 || got[1] != 4 {
|
||||
t.Fatalf("来源周结果错误,期望 [2 4],实际=%v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSlicePlanDirectionalSourceTarget(t *testing.T) {
|
||||
st := &ScheduleRefineState{
|
||||
UserMessage: "帮我把第17周周四到周五的任务都收敛到17周的周一到周三,优先放空位,空位不够了再嵌入",
|
||||
}
|
||||
plan := buildSlicePlan(st)
|
||||
if len(plan.WeekFilter) == 0 || plan.WeekFilter[0] != 17 {
|
||||
t.Fatalf("week_filter 解析错误: %+v", plan.WeekFilter)
|
||||
}
|
||||
expectSource := []int{4, 5}
|
||||
expectTarget := []int{1, 2, 3}
|
||||
if len(plan.SourceDays) != len(expectSource) {
|
||||
t.Fatalf("source_days 长度错误: got=%v", plan.SourceDays)
|
||||
}
|
||||
for i := range expectSource {
|
||||
if plan.SourceDays[i] != expectSource[i] {
|
||||
t.Fatalf("source_days 错误: got=%v", plan.SourceDays)
|
||||
}
|
||||
}
|
||||
if len(plan.TargetDays) != len(expectTarget) {
|
||||
t.Fatalf("target_days 长度错误: got=%v", plan.TargetDays)
|
||||
}
|
||||
for i := range expectTarget {
|
||||
if plan.TargetDays[i] != expectTarget[i] {
|
||||
t.Fatalf("target_days 错误: got=%v", plan.TargetDays)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyTaskCoordinateMismatch(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 28, Name: "task-w17-d4", Week: 17, DayOfWeek: 4, SectionFrom: 5, SectionTo: 6, Status: "suggested", Type: "task"},
|
||||
}
|
||||
policy := refineToolPolicy{OriginOrderMap: map[int]int{28: 1}}
|
||||
params := map[string]any{
|
||||
"task_item_id": 28,
|
||||
"week": 17,
|
||||
"day_of_week": 1,
|
||||
"section_from": 1,
|
||||
"section_to": 2,
|
||||
}
|
||||
|
||||
_, result := refineToolVerify(entries, params, policy)
|
||||
if result.Success {
|
||||
t.Fatalf("期望 Verify 在任务坐标不匹配时失败,实际 success=true, result=%s", result.Result)
|
||||
}
|
||||
if result.ErrorCode != "VERIFY_FAILED" {
|
||||
t.Fatalf("期望错误码 VERIFY_FAILED,实际=%s", result.ErrorCode)
|
||||
}
|
||||
if !strings.Contains(result.Result, "不匹配") {
|
||||
t.Fatalf("期望结果包含“不匹配”提示,实际=%s", result.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveRejectsSuggestedCourseEntry(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{
|
||||
TaskItemID: 39,
|
||||
Name: "面向对象程序设计-C++",
|
||||
Type: "course",
|
||||
Status: "suggested",
|
||||
Week: 17,
|
||||
DayOfWeek: 4,
|
||||
SectionFrom: 7,
|
||||
SectionTo: 8,
|
||||
},
|
||||
}
|
||||
params := map[string]any{
|
||||
"task_item_id": 39,
|
||||
"to_week": 17,
|
||||
"to_day": 1,
|
||||
"to_section_from": 7,
|
||||
"to_section_to": 8,
|
||||
}
|
||||
_, result := refineToolMove(entries, params, planningWindow{Enabled: false}, refineToolPolicy{OriginOrderMap: map[int]int{39: 1}})
|
||||
if result.Success {
|
||||
t.Fatalf("期望 course 类型的 suggested 条目不可移动,实际 success=true, result=%s", result.Result)
|
||||
}
|
||||
if !strings.Contains(result.Result, "可移动 suggested 任务") {
|
||||
t.Fatalf("期望返回不可移动提示,实际=%s", result.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryAvailableSlotsSlotTypePureDisablesEmbed(t *testing.T) {
|
||||
entries := []model.HybridScheduleEntry{
|
||||
{
|
||||
Name: "可嵌入课程",
|
||||
Type: "course",
|
||||
Status: "existing",
|
||||
Week: 17,
|
||||
DayOfWeek: 1,
|
||||
SectionFrom: 1,
|
||||
SectionTo: 2,
|
||||
BlockForSuggested: false,
|
||||
},
|
||||
}
|
||||
|
||||
pureParams := map[string]any{
|
||||
"week": 17,
|
||||
"day_of_week": 1,
|
||||
"section_from": 1,
|
||||
"section_to": 2,
|
||||
"slot_type": "pure",
|
||||
}
|
||||
_, pureResult := refineToolQueryAvailableSlots(entries, pureParams, planningWindow{Enabled: false})
|
||||
if !pureResult.Success {
|
||||
t.Fatalf("pure 查询失败: %s", pureResult.Result)
|
||||
}
|
||||
var purePayload struct {
|
||||
Count int `json:"count"`
|
||||
EmbeddedCount int `json:"embedded_count"`
|
||||
FallbackUsed bool `json:"fallback_used"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(pureResult.Result), &purePayload); err != nil {
|
||||
t.Fatalf("解析 pure 查询结果失败: %v", err)
|
||||
}
|
||||
if purePayload.Count != 0 || purePayload.EmbeddedCount != 0 || purePayload.FallbackUsed {
|
||||
t.Fatalf("slot_type=pure 应禁用嵌入兜底,实际 payload=%+v", purePayload)
|
||||
}
|
||||
|
||||
defaultParams := map[string]any{
|
||||
"week": 17,
|
||||
"day_of_week": 1,
|
||||
"section_from": 1,
|
||||
"section_to": 2,
|
||||
}
|
||||
_, defaultResult := refineToolQueryAvailableSlots(entries, defaultParams, planningWindow{Enabled: false})
|
||||
if !defaultResult.Success {
|
||||
t.Fatalf("default 查询失败: %s", defaultResult.Result)
|
||||
}
|
||||
var defaultPayload struct {
|
||||
Count int `json:"count"`
|
||||
EmbeddedCount int `json:"embedded_count"`
|
||||
FallbackUsed bool `json:"fallback_used"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(defaultResult.Result), &defaultPayload); err != nil {
|
||||
t.Fatalf("解析 default 查询结果失败: %v", err)
|
||||
}
|
||||
if defaultPayload.Count == 0 || defaultPayload.EmbeddedCount == 0 || !defaultPayload.FallbackUsed {
|
||||
t.Fatalf("默认查询应允许嵌入候选,实际 payload=%+v", defaultPayload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileObjectiveAndEvaluateMoveAllPass(t *testing.T) {
|
||||
initial := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 39, Name: "任务39", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 4, SectionFrom: 7, SectionTo: 8},
|
||||
{TaskItemID: 51, Name: "任务51", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 5, SectionFrom: 9, SectionTo: 10},
|
||||
}
|
||||
final := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 39, Name: "任务39", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 1, SectionFrom: 7, SectionTo: 8},
|
||||
{TaskItemID: 51, Name: "任务51", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 2, SectionFrom: 9, SectionTo: 10},
|
||||
}
|
||||
st := &ScheduleRefineState{
|
||||
UserMessage: "把17周周四到周五任务收敛到周一到周三",
|
||||
InitialHybridEntries: initial,
|
||||
HybridEntries: final,
|
||||
SlicePlan: RefineSlicePlan{
|
||||
WeekFilter: []int{17},
|
||||
SourceDays: []int{4, 5},
|
||||
TargetDays: []int{1, 2, 3},
|
||||
},
|
||||
}
|
||||
st.Objective = compileRefineObjective(st, st.SlicePlan)
|
||||
if st.Objective.Mode != "move_all" {
|
||||
t.Fatalf("期望目标模式 move_all,实际=%s", st.Objective.Mode)
|
||||
}
|
||||
|
||||
pass, _, unmet, applied := evaluateObjectiveDeterministic(st)
|
||||
if !applied {
|
||||
t.Fatalf("期望命中确定性终审")
|
||||
}
|
||||
if !pass {
|
||||
t.Fatalf("期望确定性终审通过,unmet=%v", unmet)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileObjectiveAndEvaluateMoveAllFail(t *testing.T) {
|
||||
initial := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 26, Name: "任务26", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 5, SectionFrom: 7, SectionTo: 8},
|
||||
}
|
||||
final := []model.HybridScheduleEntry{
|
||||
{TaskItemID: 26, Name: "任务26", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 5, SectionFrom: 7, SectionTo: 8},
|
||||
}
|
||||
st := &ScheduleRefineState{
|
||||
UserMessage: "把17周周四到周五任务收敛到周一到周三",
|
||||
InitialHybridEntries: initial,
|
||||
HybridEntries: final,
|
||||
SlicePlan: RefineSlicePlan{
|
||||
WeekFilter: []int{17},
|
||||
SourceDays: []int{4, 5},
|
||||
TargetDays: []int{1, 2, 3},
|
||||
},
|
||||
}
|
||||
st.Objective = compileRefineObjective(st, st.SlicePlan)
|
||||
|
||||
pass, _, unmet, applied := evaluateObjectiveDeterministic(st)
|
||||
if !applied {
|
||||
t.Fatalf("期望命中确定性终审")
|
||||
}
|
||||
if pass {
|
||||
t.Fatalf("期望确定性终审失败")
|
||||
}
|
||||
if len(unmet) == 0 {
|
||||
t.Fatalf("期望返回未满足项")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileObjectiveMoveRatioFromContractAndEvaluatePass(t *testing.T) {
|
||||
initial, final := buildHalfTransferEntries(10, 5)
|
||||
st := &ScheduleRefineState{
|
||||
UserMessage: "17周任务太多,帮我调整到16周",
|
||||
InitialHybridEntries: initial,
|
||||
HybridEntries: final,
|
||||
SlicePlan: RefineSlicePlan{
|
||||
WeekFilter: []int{17, 16},
|
||||
},
|
||||
Contract: RefineContract{
|
||||
Intent: "将第17周任务匀一半到第16周",
|
||||
HardRequirements: []string{"原第17周任务数调整为原来的一半", "调整到第16周的任务数为原第17周任务数的一半"},
|
||||
},
|
||||
}
|
||||
st.Objective = compileRefineObjective(st, st.SlicePlan)
|
||||
if st.Objective.Mode != "move_ratio" {
|
||||
t.Fatalf("期望目标模式 move_ratio,实际=%s", st.Objective.Mode)
|
||||
}
|
||||
if st.Objective.RequiredMoveMin != 5 || st.Objective.RequiredMoveMax != 5 {
|
||||
t.Fatalf("半数迁移阈值错误: min=%d max=%d", st.Objective.RequiredMoveMin, st.Objective.RequiredMoveMax)
|
||||
}
|
||||
|
||||
pass, _, unmet, applied := evaluateObjectiveDeterministic(st)
|
||||
if !applied {
|
||||
t.Fatalf("期望命中确定性终审")
|
||||
}
|
||||
if !pass {
|
||||
t.Fatalf("期望半数迁移通过,unmet=%v", unmet)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileObjectiveMoveRatioFromContractAndEvaluateFail(t *testing.T) {
|
||||
initial, final := buildHalfTransferEntries(10, 4)
|
||||
st := &ScheduleRefineState{
|
||||
UserMessage: "17周任务太多,帮我调整到16周",
|
||||
InitialHybridEntries: initial,
|
||||
HybridEntries: final,
|
||||
SlicePlan: RefineSlicePlan{
|
||||
WeekFilter: []int{17, 16},
|
||||
},
|
||||
Contract: RefineContract{
|
||||
Intent: "将第17周任务匀一半到第16周",
|
||||
HardRequirements: []string{"原第17周任务数调整为原来的一半", "调整到第16周的任务数为原第17周任务数的一半"},
|
||||
},
|
||||
}
|
||||
st.Objective = compileRefineObjective(st, st.SlicePlan)
|
||||
|
||||
pass, _, unmet, applied := evaluateObjectiveDeterministic(st)
|
||||
if !applied {
|
||||
t.Fatalf("期望命中确定性终审")
|
||||
}
|
||||
if pass {
|
||||
t.Fatalf("期望半数迁移失败")
|
||||
}
|
||||
if len(unmet) == 0 {
|
||||
t.Fatalf("期望返回未满足项")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileObjectiveMoveRatioFromStructuredAssertion(t *testing.T) {
|
||||
initial, final := buildHalfTransferEntries(10, 5)
|
||||
st := &ScheduleRefineState{
|
||||
UserMessage: "请把任务重新分配",
|
||||
InitialHybridEntries: initial,
|
||||
HybridEntries: final,
|
||||
SlicePlan: RefineSlicePlan{
|
||||
WeekFilter: []int{17, 16},
|
||||
},
|
||||
Contract: RefineContract{
|
||||
Intent: "任务重新分配",
|
||||
HardAssertions: []RefineAssertion{
|
||||
{
|
||||
Metric: "source_move_ratio_percent",
|
||||
Operator: "==",
|
||||
Value: 50,
|
||||
Week: 17,
|
||||
TargetWeek: 16,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
st.Objective = compileRefineObjective(st, st.SlicePlan)
|
||||
if st.Objective.Mode != "move_ratio" {
|
||||
t.Fatalf("结构化断言未生效,期望 move_ratio,实际=%s", st.Objective.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func buildHalfTransferEntries(total int, moved int) ([]model.HybridScheduleEntry, []model.HybridScheduleEntry) {
|
||||
initial := make([]model.HybridScheduleEntry, 0, total)
|
||||
final := make([]model.HybridScheduleEntry, 0, total)
|
||||
for i := 1; i <= total; i++ {
|
||||
initial = append(initial, model.HybridScheduleEntry{
|
||||
TaskItemID: i,
|
||||
Name: "task",
|
||||
Type: "task",
|
||||
Status: "suggested",
|
||||
Week: 17,
|
||||
DayOfWeek: 1,
|
||||
SectionFrom: 1,
|
||||
SectionTo: 2,
|
||||
})
|
||||
week := 17
|
||||
if i <= moved {
|
||||
week = 16
|
||||
}
|
||||
final = append(final, model.HybridScheduleEntry{
|
||||
TaskItemID: i,
|
||||
Name: "task",
|
||||
Type: "task",
|
||||
Status: "suggested",
|
||||
Week: week,
|
||||
DayOfWeek: 1,
|
||||
SectionFrom: 1,
|
||||
SectionTo: 2,
|
||||
})
|
||||
}
|
||||
return initial, final
|
||||
}
|
||||
|
||||
func TestNormalizeMovableTaskOrderByOrigin(t *testing.T) {
|
||||
st := &ScheduleRefineState{
|
||||
OriginOrderMap: map[int]int{
|
||||
101: 1,
|
||||
202: 2,
|
||||
},
|
||||
HybridEntries: []model.HybridScheduleEntry{
|
||||
{TaskItemID: 202, Name: "task-202", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2},
|
||||
{TaskItemID: 101, Name: "task-101", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 3, SectionFrom: 1, SectionTo: 2},
|
||||
},
|
||||
}
|
||||
changed := normalizeMovableTaskOrderByOrigin(st)
|
||||
if !changed {
|
||||
t.Fatalf("期望发生顺序归位")
|
||||
}
|
||||
sortHybridEntries(st.HybridEntries)
|
||||
if st.HybridEntries[0].TaskItemID != 101 || st.HybridEntries[1].TaskItemID != 202 {
|
||||
t.Fatalf("顺序归位失败: %+v", st.HybridEntries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryNormalizeMovableTaskOrderByOriginSkipsAfterMinContextSwitch(t *testing.T) {
|
||||
st := &ScheduleRefineState{
|
||||
OriginOrderMap: map[int]int{
|
||||
101: 1,
|
||||
202: 2,
|
||||
},
|
||||
CompositeToolSuccess: map[string]bool{
|
||||
"SpreadEven": false,
|
||||
"MinContextSwitch": true,
|
||||
},
|
||||
HybridEntries: []model.HybridScheduleEntry{
|
||||
{TaskItemID: 202, Name: "task-202", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2},
|
||||
{TaskItemID: 101, Name: "task-101", Type: "task", Status: "suggested", Week: 17, DayOfWeek: 3, SectionFrom: 1, SectionTo: 2},
|
||||
},
|
||||
}
|
||||
changed, skipped := tryNormalizeMovableTaskOrderByOrigin(st)
|
||||
if !skipped {
|
||||
t.Fatalf("期望 MinContextSwitch 成功后跳过顺序归位")
|
||||
}
|
||||
if changed {
|
||||
t.Fatalf("跳过顺序归位时不应报告 changed=true")
|
||||
}
|
||||
if st.HybridEntries[0].TaskItemID != 202 || st.HybridEntries[1].TaskItemID != 101 {
|
||||
t.Fatalf("跳过顺序归位后不应改写任务顺序: %+v", st.HybridEntries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvaluateHardChecksSkipsOrderConstraintAfterMinContextSwitch(t *testing.T) {
|
||||
st := &ScheduleRefineState{
|
||||
UserMessage: "减少第15周科目切换",
|
||||
OriginOrderMap: map[int]int{
|
||||
101: 1,
|
||||
202: 2,
|
||||
},
|
||||
CompositeToolSuccess: map[string]bool{
|
||||
"SpreadEven": false,
|
||||
"MinContextSwitch": true,
|
||||
},
|
||||
InitialHybridEntries: []model.HybridScheduleEntry{
|
||||
{TaskItemID: 101, Name: "概率任务", Type: "task", Status: "suggested", Week: 15, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2},
|
||||
{TaskItemID: 202, Name: "数电任务", Type: "task", Status: "suggested", Week: 15, DayOfWeek: 1, SectionFrom: 3, SectionTo: 4},
|
||||
},
|
||||
HybridEntries: []model.HybridScheduleEntry{
|
||||
{TaskItemID: 202, Name: "数电任务", Type: "task", Status: "suggested", Week: 15, DayOfWeek: 1, SectionFrom: 1, SectionTo: 2},
|
||||
{TaskItemID: 101, Name: "概率任务", Type: "task", Status: "suggested", Week: 15, DayOfWeek: 1, SectionFrom: 3, SectionTo: 4},
|
||||
},
|
||||
Objective: RefineObjective{
|
||||
Mode: "move_all",
|
||||
SourceWeeks: []int{15},
|
||||
TargetWeeks: []int{15},
|
||||
BaselineSourceTaskCount: 2,
|
||||
RequiredMoveMin: 2,
|
||||
RequiredMoveMax: 2,
|
||||
},
|
||||
SlicePlan: RefineSlicePlan{
|
||||
WeekFilter: []int{15},
|
||||
},
|
||||
}
|
||||
report := evaluateHardChecks(nil, nil, st, nil)
|
||||
if !report.OrderPassed {
|
||||
t.Fatalf("期望 MinContextSwitch 成功后跳过顺序终审,实际 issues=%v", report.OrderIssues)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrecheckToolCallPolicyRejectsRedundantSlotQuery(t *testing.T) {
|
||||
st := &ScheduleRefineState{
|
||||
SeenSlotQueries: make(map[string]struct{}),
|
||||
EntriesVersion: 0,
|
||||
}
|
||||
call := reactToolCall{
|
||||
Tool: "QueryAvailableSlots",
|
||||
Params: map[string]any{
|
||||
"week": 16,
|
||||
"day_of_week": 1,
|
||||
},
|
||||
}
|
||||
|
||||
if blockedResult, blocked := precheckToolCallPolicy(st, call, nil); blocked {
|
||||
t.Fatalf("首次查询不应被拒绝: %+v", blockedResult)
|
||||
}
|
||||
if blockedResult, blocked := precheckToolCallPolicy(st, call, nil); !blocked {
|
||||
t.Fatalf("重复查询应被拒绝")
|
||||
} else if blockedResult.ErrorCode != "QUERY_REDUNDANT" {
|
||||
t.Fatalf("错误码不符合预期: %+v", blockedResult)
|
||||
}
|
||||
st.EntriesVersion++
|
||||
if blockedResult, blocked := precheckToolCallPolicy(st, call, nil); blocked {
|
||||
t.Fatalf("排程版本变化后应允许再次查询: %+v", blockedResult)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanonicalizeMoveParamsFromRepairAliases(t *testing.T) {
|
||||
call := reactToolCall{
|
||||
Tool: "Move",
|
||||
Params: map[string]any{
|
||||
"task_item_id": 16,
|
||||
"new_week": 16,
|
||||
"day_of_week": 1,
|
||||
"section_from": 1,
|
||||
"section_to": 2,
|
||||
},
|
||||
}
|
||||
normalized := canonicalizeToolCall(call)
|
||||
if _, ok := paramIntAny(normalized.Params, "to_week"); !ok {
|
||||
t.Fatalf("to_week 规范化失败: %+v", normalized.Params)
|
||||
}
|
||||
if _, ok := paramIntAny(normalized.Params, "to_day"); !ok {
|
||||
t.Fatalf("to_day 规范化失败: %+v", normalized.Params)
|
||||
}
|
||||
if _, ok := paramIntAny(normalized.Params, "to_section_from"); !ok {
|
||||
t.Fatalf("to_section_from 规范化失败: %+v", normalized.Params)
|
||||
}
|
||||
if _, ok := paramIntAny(normalized.Params, "to_section_to"); !ok {
|
||||
t.Fatalf("to_section_to 规范化失败: %+v", normalized.Params)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectOrderIntentDefaultsToKeep(t *testing.T) {
|
||||
if !detectOrderIntent("16周总体任务太多了,帮我移动一半到12周") {
|
||||
t.Fatalf("未显式放宽顺序时,默认应保持顺序")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectOrderIntentExplicitAllowReorder(t *testing.T) {
|
||||
if detectOrderIntent("这次顺序无所谓,可以打乱顺序") {
|
||||
t.Fatalf("用户明确允许乱序时,应关闭顺序约束")
|
||||
}
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
package schedulerefine
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
)
|
||||
|
||||
// scheduleRefineRunner 是“单次图运行”的请求级依赖容器。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责收口模型与阶段回调,避免 graph.go 出现大量闭包;
|
||||
// 2. 负责把节点函数适配为统一签名;
|
||||
// 3. 不负责分支决策(当前链路为线性图)。
|
||||
type scheduleRefineRunner struct {
|
||||
chatModel *ark.ChatModel
|
||||
emitStage func(stage, detail string)
|
||||
}
|
||||
|
||||
func newScheduleRefineRunner(chatModel *ark.ChatModel, emitStage func(stage, detail string)) *scheduleRefineRunner {
|
||||
return &scheduleRefineRunner{
|
||||
chatModel: chatModel,
|
||||
emitStage: emitStage,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *scheduleRefineRunner) contractNode(ctx context.Context, st *ScheduleRefineState) (*ScheduleRefineState, error) {
|
||||
return runContractNode(ctx, r.chatModel, st, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *scheduleRefineRunner) planNode(ctx context.Context, st *ScheduleRefineState) (*ScheduleRefineState, error) {
|
||||
return runPlanNode(ctx, r.chatModel, st, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *scheduleRefineRunner) sliceNode(ctx context.Context, st *ScheduleRefineState) (*ScheduleRefineState, error) {
|
||||
return runSliceNode(ctx, st, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *scheduleRefineRunner) routeNode(ctx context.Context, st *ScheduleRefineState) (*ScheduleRefineState, error) {
|
||||
return runCompositeRouteNode(ctx, st, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *scheduleRefineRunner) reactNode(ctx context.Context, st *ScheduleRefineState) (*ScheduleRefineState, error) {
|
||||
return runReactLoopNode(ctx, r.chatModel, st, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *scheduleRefineRunner) hardCheckNode(ctx context.Context, st *ScheduleRefineState) (*ScheduleRefineState, error) {
|
||||
return runHardCheckNode(ctx, r.chatModel, st, r.emitStage)
|
||||
}
|
||||
|
||||
func (r *scheduleRefineRunner) summaryNode(ctx context.Context, st *ScheduleRefineState) (*ScheduleRefineState, error) {
|
||||
return runSummaryNode(ctx, r.chatModel, st, r.emitStage)
|
||||
}
|
||||
95
backend/agent/shared/clone.go
Normal file
95
backend/agent/shared/clone.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package agentshared
|
||||
|
||||
import "github.com/LoveLosita/smartflow/backend/model"
|
||||
|
||||
// CloneWeekSchedules 深拷贝周视图排程结果。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责断开 []UserWeekSchedule 与内部 Events 切片的引用共享;
|
||||
// 2. 负责服务于“缓存 DTO / graph state / API 响应”之间的安全复制;
|
||||
// 3. 不负责业务过滤,不负责排序。
|
||||
func CloneWeekSchedules(src []model.UserWeekSchedule) []model.UserWeekSchedule {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
dst := make([]model.UserWeekSchedule, 0, len(src))
|
||||
for _, week := range src {
|
||||
eventsCopy := make([]model.WeeklyEventBrief, len(week.Events))
|
||||
copy(eventsCopy, week.Events)
|
||||
dst = append(dst, model.UserWeekSchedule{
|
||||
Week: week.Week,
|
||||
Events: eventsCopy,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// CloneHybridEntries 深拷贝混合排程条目切片。
|
||||
func CloneHybridEntries(src []model.HybridScheduleEntry) []model.HybridScheduleEntry {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
dst := make([]model.HybridScheduleEntry, len(src))
|
||||
copy(dst, src)
|
||||
return dst
|
||||
}
|
||||
|
||||
// CloneTaskClassItems 深拷贝任务块切片。
|
||||
//
|
||||
// 这里不能直接 copy:
|
||||
// 1. 因为 TaskClassItem 内部带若干指针字段;
|
||||
// 2. 如果浅拷贝,后续某一步修改 EmbeddedTime / Status,会污染原状态;
|
||||
// 3. 排程 graph 连续微调时,这种共享引用会非常难查,所以必须在公共层兜住。
|
||||
func CloneTaskClassItems(src []model.TaskClassItem) []model.TaskClassItem {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
dst := make([]model.TaskClassItem, 0, len(src))
|
||||
for _, item := range src {
|
||||
copied := item
|
||||
if item.CategoryID != nil {
|
||||
v := *item.CategoryID
|
||||
copied.CategoryID = &v
|
||||
}
|
||||
if item.Order != nil {
|
||||
v := *item.Order
|
||||
copied.Order = &v
|
||||
}
|
||||
if item.Content != nil {
|
||||
v := *item.Content
|
||||
copied.Content = &v
|
||||
}
|
||||
if item.Status != nil {
|
||||
v := *item.Status
|
||||
copied.Status = &v
|
||||
}
|
||||
if item.EmbeddedTime != nil {
|
||||
t := *item.EmbeddedTime
|
||||
copied.EmbeddedTime = &t
|
||||
}
|
||||
dst = append(dst, copied)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// CloneInts 深拷贝 int 切片。
|
||||
func CloneInts(src []int) []int {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
dst := make([]int, len(src))
|
||||
copy(dst, src)
|
||||
return dst
|
||||
}
|
||||
|
||||
// CloneStrings 深拷贝 string 切片。
|
||||
func CloneStrings(src []string) []string {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
dst := make([]string, len(src))
|
||||
copy(dst, src)
|
||||
return dst
|
||||
}
|
||||
85
backend/agent/shared/retry.go
Normal file
85
backend/agent/shared/retry.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package agentshared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RetryOptions 描述公共重试策略。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 这里只定义“是否重试、最多几次、间隔多久”;
|
||||
// 2. 不关心具体业务是工具调用失败、模型 JSON 失败还是 DB 暂时不可用;
|
||||
// 3. 真正的业务兜底文案仍应由上层 node 决定。
|
||||
type RetryOptions struct {
|
||||
MaxAttempts int
|
||||
Interval time.Duration
|
||||
ShouldRetry func(err error) bool
|
||||
OnRetry func(attempt int, err error)
|
||||
}
|
||||
|
||||
// Do 执行一个只返回 error 的重试任务。
|
||||
//
|
||||
// 执行规则:
|
||||
// 1. 第一次执行也算一次 attempt;
|
||||
// 2. 任意一次成功即立即返回;
|
||||
// 3. 上下文取消、达到最大次数、或 ShouldRetry=false 时立即停止。
|
||||
func Do(ctx context.Context, options RetryOptions, fn func(attempt int) error) error {
|
||||
_, err := DoValue[struct{}](ctx, options, func(attempt int) (struct{}, error) {
|
||||
return struct{}{}, fn(attempt)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// DoValue 执行一个带返回值的通用重试任务。
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. 旧 agent 里后续很多地方都会出现“失败重试 2~3 次”的模式;
|
||||
// 2. 这里先把循环骨架统一,避免每个 skill 自己写 for + sleep + ctx.Done;
|
||||
// 3. 上层只需关心“本轮失败要不要继续”,而不是重复造轮子。
|
||||
func DoValue[T any](ctx context.Context, options RetryOptions, fn func(attempt int) (T, error)) (T, error) {
|
||||
var zero T
|
||||
|
||||
maxAttempts := options.MaxAttempts
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = 1
|
||||
}
|
||||
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
value, err := fn(attempt)
|
||||
if err == nil {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// 1. 到最后一次了,直接返回原错误,避免无意义等待。
|
||||
if attempt >= maxAttempts {
|
||||
return zero, err
|
||||
}
|
||||
// 2. 业务显式声明“不值得重试”时,立刻停止。
|
||||
if options.ShouldRetry != nil && !options.ShouldRetry(err) {
|
||||
return zero, err
|
||||
}
|
||||
// 3. 把重试钩子留给上层,用于打点或阶段提示。
|
||||
if options.OnRetry != nil {
|
||||
options.OnRetry(attempt, err)
|
||||
}
|
||||
// 4. 没有配置间隔则马上下一轮;配置了则等待,同时尊重 ctx 取消。
|
||||
if options.Interval <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
timer := time.NewTimer(options.Interval)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return zero, ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
|
||||
return zero, nil
|
||||
}
|
||||
49
backend/agent/shared/time.go
Normal file
49
backend/agent/shared/time.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package agentshared
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// MinuteLayout 是 Agent 内部统一的分钟级时间文本格式。
|
||||
//
|
||||
// 设计原因:
|
||||
// 1. agent 里大量场景只需要精确到分钟;
|
||||
// 2. 秒级精度会增加提示词噪声,也容易让“同一请求内的当前时间”出现抖动;
|
||||
// 3. 先统一成一份常量,后续 quicknote / schedule 都直接复用。
|
||||
MinuteLayout = "2006-01-02 15:04"
|
||||
)
|
||||
|
||||
var (
|
||||
shanghaiLocOnce sync.Once
|
||||
shanghaiLoc *time.Location
|
||||
)
|
||||
|
||||
// ShanghaiLocation 返回 Agent 内部统一使用的东八区时区。
|
||||
func ShanghaiLocation() *time.Location {
|
||||
shanghaiLocOnce.Do(func() {
|
||||
loc, err := time.LoadLocation("Asia/Shanghai")
|
||||
if err != nil {
|
||||
// 兜底使用固定东八区,避免极端环境下因为系统时区文件缺失导致整个链路失败。
|
||||
loc = time.FixedZone("CST", 8*3600)
|
||||
}
|
||||
shanghaiLoc = loc
|
||||
})
|
||||
return shanghaiLoc
|
||||
}
|
||||
|
||||
// NowToMinute 返回当前北京时间,并截断到分钟级。
|
||||
func NowToMinute() time.Time {
|
||||
return time.Now().In(ShanghaiLocation()).Truncate(time.Minute)
|
||||
}
|
||||
|
||||
// NormalizeToMinute 把任意时间统一到北京时间分钟粒度。
|
||||
func NormalizeToMinute(t time.Time) time.Time {
|
||||
return t.In(ShanghaiLocation()).Truncate(time.Minute)
|
||||
}
|
||||
|
||||
// FormatMinute 把时间格式化为统一分钟级文本。
|
||||
func FormatMinute(t time.Time) string {
|
||||
return NormalizeToMinute(t).Format(MinuteLayout)
|
||||
}
|
||||
115
backend/agent/stream/emitter.go
Normal file
115
backend/agent/stream/emitter.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package agentstream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PayloadEmitter 是真正向外层 SSE 管道写 chunk 的最小接口。
|
||||
//
|
||||
// 说明:
|
||||
// 1. 这里刻意不用 chan/string 绑死实现;
|
||||
// 2. 上层既可以传“写 channel”的函数,也可以传“写 gin stream”的函数;
|
||||
// 3. 只要签名是 `func(string) error`,都能接进来。
|
||||
type PayloadEmitter func(payload string) error
|
||||
|
||||
// StageEmitter 是 graph/node 对“当前阶段”进行推送的最小接口。
|
||||
type StageEmitter func(stage, detail string)
|
||||
|
||||
// NoopPayloadEmitter 返回一个空实现,便于骨架期安全占位。
|
||||
func NoopPayloadEmitter() PayloadEmitter {
|
||||
return func(string) error { return nil }
|
||||
}
|
||||
|
||||
// NoopStageEmitter 返回一个空实现,避免 graph 在没有接前端时处处判空。
|
||||
func NoopStageEmitter() StageEmitter {
|
||||
return func(stage, detail string) {}
|
||||
}
|
||||
|
||||
// WrapStageEmitter 把可空函数包装成稳定的 StageEmitter。
|
||||
func WrapStageEmitter(fn func(stage, detail string)) StageEmitter {
|
||||
if fn == nil {
|
||||
return NoopStageEmitter()
|
||||
}
|
||||
return fn
|
||||
}
|
||||
|
||||
// EmitStageAsReasoning 把“阶段提示”伪装成 reasoning chunk 推给前端。
|
||||
//
|
||||
// 设计背景:
|
||||
// 1. 你当前 Apifox 只认思考块和正文块,因此阶段提示需要先借 reasoning_content 走通;
|
||||
// 2. 这样后续真正前端上线时,只需要在这一层换协议,而不必回到各 skill 重改 graph;
|
||||
// 3. 这里不拼花哨格式,只给出稳定、可读、可 grep 的文本。
|
||||
func EmitStageAsReasoning(emit PayloadEmitter, requestID, modelName string, created int64, stage, detail string, includeRole bool) error {
|
||||
if emit == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
text := BuildStageReasoningText(stage, detail)
|
||||
payload, err := ToOpenAIReasoningChunk(requestID, modelName, created, text, includeRole)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if payload == "" {
|
||||
return nil
|
||||
}
|
||||
return emit(payload)
|
||||
}
|
||||
|
||||
// EmitAssistantReply 把一段完整正文作为 assistant chunk 推出。
|
||||
//
|
||||
// 注意:
|
||||
// 1. 这里是“整段发”,不是把文本强行拆碎;
|
||||
// 2. 这样后续如果某条链路不需要真流式,也可以复用统一出口;
|
||||
// 3. 真正按 token/chunk 细粒度流式输出,应由 llm.Stream + 上层循环处理。
|
||||
func EmitAssistantReply(emit PayloadEmitter, requestID, modelName string, created int64, content string, includeRole bool) error {
|
||||
if emit == nil {
|
||||
return nil
|
||||
}
|
||||
payload, err := ToOpenAIAssistantChunk(requestID, modelName, created, content, includeRole)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if payload == "" {
|
||||
return nil
|
||||
}
|
||||
return emit(payload)
|
||||
}
|
||||
|
||||
// EmitFinish 统一输出 stop 结束块。
|
||||
func EmitFinish(emit PayloadEmitter, requestID, modelName string, created int64) error {
|
||||
if emit == nil {
|
||||
return nil
|
||||
}
|
||||
payload, err := ToOpenAIFinishStream(requestID, modelName, created)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if payload == "" {
|
||||
return nil
|
||||
}
|
||||
return emit(payload)
|
||||
}
|
||||
|
||||
// EmitDone 统一输出 OpenAI 兼容流式结束标记。
|
||||
func EmitDone(emit PayloadEmitter) error {
|
||||
if emit == nil {
|
||||
return nil
|
||||
}
|
||||
return emit("[DONE]")
|
||||
}
|
||||
|
||||
// BuildStageReasoningText 生成统一阶段提示文本。
|
||||
func BuildStageReasoningText(stage, detail string) string {
|
||||
stage = strings.TrimSpace(stage)
|
||||
detail = strings.TrimSpace(detail)
|
||||
|
||||
switch {
|
||||
case stage != "" && detail != "":
|
||||
return fmt.Sprintf("阶段:%s\n%s", stage, detail)
|
||||
case stage != "":
|
||||
return fmt.Sprintf("阶段:%s", stage)
|
||||
default:
|
||||
return detail
|
||||
}
|
||||
}
|
||||
102
backend/agent/stream/openai.go
Normal file
102
backend/agent/stream/openai.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package agentstream
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// OpenAIChunkResponse 是 OpenAI 兼容的流式 chunk DTO。
|
||||
//
|
||||
// 之所以单独放到 Agent/stream:
|
||||
// 1. 未来无论 quicknote、taskquery 还是 schedule,只要需要 SSE 都会复用这套协议壳;
|
||||
// 2. 这样 node/graph 层只关注“我要推什么内容”,不再自己拼 JSON;
|
||||
// 3. 后续如果前端协议升级,也能在这里集中改。
|
||||
type OpenAIChunkResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []OpenAIChunkChoice `json:"choices"`
|
||||
}
|
||||
|
||||
// OpenAIChunkChoice 对应 OpenAI choices[0]。
|
||||
type OpenAIChunkChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta OpenAIChunkDelta `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// OpenAIChunkDelta 是真正承载 role/content/reasoning 的位置。
|
||||
type OpenAIChunkDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
}
|
||||
|
||||
// ToOpenAIStream 把 Eino message 转成 OpenAI 兼容 chunk。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责把 chunk.Content / chunk.ReasoningContent 映射到协议字段;
|
||||
// 2. 负责按 includeRole 决定是否在首块带上 assistant 角色;
|
||||
// 3. 不负责发送,也不负责决定“这个 chunk 该不该推”。
|
||||
func ToOpenAIStream(chunk *schema.Message, requestID, modelName string, created int64, includeRole bool) (string, error) {
|
||||
delta := OpenAIChunkDelta{}
|
||||
if includeRole {
|
||||
delta.Role = "assistant"
|
||||
}
|
||||
if chunk != nil {
|
||||
delta.Content = chunk.Content
|
||||
delta.ReasoningContent = chunk.ReasoningContent
|
||||
}
|
||||
return buildOpenAIChunkPayload(requestID, modelName, created, delta, nil)
|
||||
}
|
||||
|
||||
// ToOpenAIReasoningChunk 直接构造一个 reasoning chunk。
|
||||
func ToOpenAIReasoningChunk(requestID, modelName string, created int64, reasoning string, includeRole bool) (string, error) {
|
||||
delta := OpenAIChunkDelta{ReasoningContent: reasoning}
|
||||
if includeRole {
|
||||
delta.Role = "assistant"
|
||||
}
|
||||
return buildOpenAIChunkPayload(requestID, modelName, created, delta, nil)
|
||||
}
|
||||
|
||||
// ToOpenAIAssistantChunk 直接构造一个正文 chunk。
|
||||
func ToOpenAIAssistantChunk(requestID, modelName string, created int64, content string, includeRole bool) (string, error) {
|
||||
delta := OpenAIChunkDelta{Content: content}
|
||||
if includeRole {
|
||||
delta.Role = "assistant"
|
||||
}
|
||||
return buildOpenAIChunkPayload(requestID, modelName, created, delta, nil)
|
||||
}
|
||||
|
||||
// ToOpenAIFinishStream 生成流式结束 chunk(finish_reason=stop)。
|
||||
func ToOpenAIFinishStream(requestID, modelName string, created int64) (string, error) {
|
||||
stop := "stop"
|
||||
return buildOpenAIChunkPayload(requestID, modelName, created, OpenAIChunkDelta{}, &stop)
|
||||
}
|
||||
|
||||
func buildOpenAIChunkPayload(requestID, modelName string, created int64, delta OpenAIChunkDelta, finishReason *string) (string, error) {
|
||||
// 1. 若既没有 role,也没有正文/思考,也没有 finish_reason,则视为“空块”,直接跳过。
|
||||
// 2. 这样可以避免上层每次都自己写一遍空块判断。
|
||||
if delta.Role == "" && delta.Content == "" && delta.ReasoningContent == "" && finishReason == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
dto := OpenAIChunkResponse{
|
||||
ID: requestID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: created,
|
||||
Model: modelName,
|
||||
Choices: []OpenAIChunkChoice{{
|
||||
Index: 0,
|
||||
Delta: delta,
|
||||
FinishReason: finishReason,
|
||||
}},
|
||||
}
|
||||
data, err := json.Marshal(dto)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
@@ -1,183 +0,0 @@
|
||||
package taskquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
const (
|
||||
// 图节点:意图规划(一次模型调用,产出结构化查询计划)
|
||||
taskQueryGraphNodePlan = "task_query_plan"
|
||||
// 图节点:象限归一化(不调模型,只做参数规整)
|
||||
taskQueryGraphNodeQuadrant = "task_query_quadrant"
|
||||
// 图节点:时间锚定(不调模型,锁定绝对时间边界)
|
||||
taskQueryGraphNodeTime = "task_query_time_anchor"
|
||||
// 图节点:工具查询(调用 query_tasks 工具)
|
||||
taskQueryGraphNodeQuery = "task_query_tool_query"
|
||||
// 图节点:结果反思与回复(模型判断是否满足并产出回复/重试补丁)
|
||||
taskQueryGraphNodeReflect = "task_query_reflect"
|
||||
)
|
||||
|
||||
// QueryGraphRunInput 是任务查询图运行输入。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. Model/Deps 提供图运行依赖;
|
||||
// 2. UserMessage/RequestNowText 提供本次请求上下文;
|
||||
// 3. MaxReflectRetry 控制“反思重试”上限;
|
||||
// 4. EmitStage 是可选阶段推送钩子,不影响主链路成功与否。
|
||||
type QueryGraphRunInput struct {
|
||||
Model *ark.ChatModel
|
||||
UserMessage string
|
||||
RequestNowText string
|
||||
Deps TaskQueryToolDeps
|
||||
MaxReflectRetry int
|
||||
EmitStage func(stage, detail string)
|
||||
}
|
||||
|
||||
// RunTaskQueryGraph 执行“任务查询图编排”。
|
||||
//
|
||||
// 关键策略:
|
||||
// 1. 规划节点只调用一次模型,统一产出查询计划;
|
||||
// 2. 查询节点优先按计划查,若为空先自动放宽一次(无额外模型调用);
|
||||
// 3. 反思节点最多重试 2 次,每次决定“是否满足、是否继续、如何补丁”。
|
||||
func RunTaskQueryGraph(ctx context.Context, input QueryGraphRunInput) (string, error) {
|
||||
// 1. 启动前硬校验。
|
||||
if input.Model == nil {
|
||||
return "", errors.New("task query graph: model is nil")
|
||||
}
|
||||
if err := input.Deps.validate(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 2. 构建工具包,并拿到 query_tasks 可执行工具。
|
||||
toolBundle, err := BuildTaskQueryToolBundle(ctx, input.Deps)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
toolMap, err := buildInvokableToolMap(toolBundle)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
queryTool, exists := toolMap[ToolNameTaskQueryTasks]
|
||||
if !exists {
|
||||
return "", fmt.Errorf("task query graph: tool %s not found", ToolNameTaskQueryTasks)
|
||||
}
|
||||
|
||||
// 3. 初始化状态:请求时间为空时做本地兜底。
|
||||
requestNow := strings.TrimSpace(input.RequestNowText)
|
||||
if requestNow == "" {
|
||||
requestNow = time.Now().In(time.Local).Format("2006-01-02 15:04")
|
||||
}
|
||||
state := NewTaskQueryState(strings.TrimSpace(input.UserMessage), requestNow, input.MaxReflectRetry)
|
||||
|
||||
// 4. 封装 runner,把“依赖注入”和“节点逻辑”解耦。
|
||||
runner := newTaskQueryGraphRunner(input, queryTool)
|
||||
|
||||
// 5. 只在本次请求内构图并执行,避免跨请求共享状态。
|
||||
graph := compose.NewGraph[*TaskQueryState, *TaskQueryState]()
|
||||
|
||||
if err = graph.AddLambdaNode(taskQueryGraphNodePlan, compose.InvokableLambda(runner.planNode)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddLambdaNode(taskQueryGraphNodeQuadrant, compose.InvokableLambda(runner.quadrantNode)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddLambdaNode(taskQueryGraphNodeTime, compose.InvokableLambda(runner.timeAnchorNode)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddLambdaNode(taskQueryGraphNodeQuery, compose.InvokableLambda(runner.queryNode)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddLambdaNode(taskQueryGraphNodeReflect, compose.InvokableLambda(runner.reflectNode)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 连线:START -> plan -> quadrant -> time -> query -> reflect
|
||||
if err = graph.AddEdge(compose.START, taskQueryGraphNodePlan); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddEdge(taskQueryGraphNodePlan, taskQueryGraphNodeQuadrant); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddEdge(taskQueryGraphNodeQuadrant, taskQueryGraphNodeTime); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddEdge(taskQueryGraphNodeTime, taskQueryGraphNodeQuery); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err = graph.AddEdge(taskQueryGraphNodeQuery, taskQueryGraphNodeReflect); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 分支:reflect 后要么结束,要么回到 query 重试。
|
||||
if err = graph.AddBranch(taskQueryGraphNodeReflect, compose.NewGraphBranch(
|
||||
runner.nextAfterReflect,
|
||||
map[string]bool{
|
||||
taskQueryGraphNodeQuery: true,
|
||||
compose.END: true,
|
||||
},
|
||||
)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
maxRunSteps := 24 + state.MaxReflectRetry*4
|
||||
if maxRunSteps < 24 {
|
||||
maxRunSteps = 24
|
||||
}
|
||||
runnable, err := graph.Compile(ctx,
|
||||
compose.WithGraphName("TaskQueryGraph"),
|
||||
compose.WithMaxRunSteps(maxRunSteps),
|
||||
compose.WithNodeTriggerMode(compose.AnyPredecessor),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
finalState, err := runnable.Invoke(ctx, state)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if finalState == nil {
|
||||
return "", errors.New("task query graph: final state is nil")
|
||||
}
|
||||
|
||||
reply := strings.TrimSpace(finalState.FinalReply)
|
||||
if reply == "" {
|
||||
reply = buildTaskQueryFallbackReply(finalState.LastQueryItems)
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
type taskQueryGraphRunner struct {
|
||||
input QueryGraphRunInput
|
||||
queryTool tool.InvokableTool
|
||||
}
|
||||
|
||||
func newTaskQueryGraphRunner(input QueryGraphRunInput, queryTool tool.InvokableTool) *taskQueryGraphRunner {
|
||||
return &taskQueryGraphRunner{
|
||||
input: input,
|
||||
queryTool: queryTool,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *taskQueryGraphRunner) emit(stage, detail string) {
|
||||
if r.input.EmitStage == nil {
|
||||
return
|
||||
}
|
||||
r.input.EmitStage(stage, detail)
|
||||
}
|
||||
|
||||
func (r *taskQueryGraphRunner) nextAfterReflect(ctx context.Context, st *TaskQueryState) (string, error) {
|
||||
_ = ctx
|
||||
if st != nil && st.NeedRetry {
|
||||
return taskQueryGraphNodeQuery, nil
|
||||
}
|
||||
return compose.END, nil
|
||||
}
|
||||
@@ -1,839 +0,0 @@
|
||||
package taskquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
einoModel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
type taskQueryPlanOutput struct {
|
||||
UserGoal string `json:"user_goal"`
|
||||
Quadrants []int `json:"quadrants"`
|
||||
SortBy string `json:"sort_by"`
|
||||
Order string `json:"order"`
|
||||
Limit int `json:"limit"`
|
||||
IncludeCompleted *bool `json:"include_completed"`
|
||||
Keyword string `json:"keyword"`
|
||||
DeadlineBefore string `json:"deadline_before"`
|
||||
DeadlineAfter string `json:"deadline_after"`
|
||||
}
|
||||
|
||||
type taskQueryReflectOutput struct {
|
||||
Satisfied bool `json:"satisfied"`
|
||||
NeedRetry bool `json:"need_retry"`
|
||||
Reason string `json:"reason"`
|
||||
Reply string `json:"reply"`
|
||||
RetryPatch taskQueryRetryPatch `json:"retry_patch"`
|
||||
}
|
||||
|
||||
type taskQueryRetryPatch struct {
|
||||
Quadrants *[]int `json:"quadrants,omitempty"`
|
||||
SortBy *string `json:"sort_by,omitempty"`
|
||||
Order *string `json:"order,omitempty"`
|
||||
Limit *int `json:"limit,omitempty"`
|
||||
IncludeCompleted *bool `json:"include_completed,omitempty"`
|
||||
Keyword *string `json:"keyword,omitempty"`
|
||||
DeadlineBefore *string `json:"deadline_before,omitempty"`
|
||||
DeadlineAfter *string `json:"deadline_after,omitempty"`
|
||||
}
|
||||
|
||||
var (
|
||||
// explicitLimitPatterns 用于从用户原话提取“显式数量要求”。
|
||||
//
|
||||
// 例子:
|
||||
// 1. 前3个任务
|
||||
// 2. 给我5条
|
||||
// 3. top 10
|
||||
explicitLimitPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)\btop\s*(\d{1,2})\b`),
|
||||
regexp.MustCompile(`前\s*(\d{1,2})\s*(个|条|项)?`),
|
||||
regexp.MustCompile(`(\d{1,2})\s*(个|条|项)\s*任务?`),
|
||||
regexp.MustCompile(`给我\s*(\d{1,2})\s*(个|条|项)?`),
|
||||
}
|
||||
// chineseDigitMap 支持常见中文数字(用于“前五个”“来三个”这类口语)。
|
||||
chineseDigitMap = map[rune]int{
|
||||
'一': 1, '二': 2, '两': 2, '三': 3, '四': 4, '五': 5,
|
||||
'六': 6, '七': 7, '八': 8, '九': 9, '十': 10,
|
||||
}
|
||||
)
|
||||
|
||||
func (r *taskQueryGraphRunner) planNode(ctx context.Context, st *TaskQueryState) (*TaskQueryState, error) {
|
||||
// 1. 防御校验:state 为空时直接返回,避免后续节点空指针。
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in plan node")
|
||||
}
|
||||
|
||||
// 2. 规划节点只调用一次模型,把查询意图打包成结构化计划。
|
||||
r.emit("task_query.plan.generating", "正在一次性规划查询范围、排序和时间条件。")
|
||||
prompt := fmt.Sprintf(`当前时间(北京时间,精确到分钟):%s
|
||||
用户输入:%s
|
||||
|
||||
请输出任务查询计划 JSON。`, st.RequestNowText, st.UserMessage)
|
||||
|
||||
raw, err := callTaskQueryModelForJSON(ctx, r.input.Model, TaskQueryPlanPrompt, prompt, 260)
|
||||
if err != nil {
|
||||
// 3. 模型失败时不直接终止:回退到默认计划,保证可用性。
|
||||
st.UserGoal = "查询任务"
|
||||
st.Plan = defaultTaskQueryPlan()
|
||||
return st, nil
|
||||
}
|
||||
|
||||
planned, parseErr := parseTaskQueryJSON[taskQueryPlanOutput](raw)
|
||||
if parseErr != nil {
|
||||
// 4. JSON 异常同样回退默认计划,避免用户请求直接失败。
|
||||
st.UserGoal = "查询任务"
|
||||
st.Plan = defaultTaskQueryPlan()
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 5. 规划结果统一规范化,保证后续节点拿到稳定参数。
|
||||
st.UserGoal = strings.TrimSpace(planned.UserGoal)
|
||||
if st.UserGoal == "" {
|
||||
st.UserGoal = "查询任务"
|
||||
}
|
||||
st.Plan = normalizePlan(taskQueryPlanOutput{
|
||||
UserGoal: planned.UserGoal,
|
||||
Quadrants: planned.Quadrants,
|
||||
SortBy: planned.SortBy,
|
||||
Order: planned.Order,
|
||||
Limit: planned.Limit,
|
||||
IncludeCompleted: planned.IncludeCompleted,
|
||||
Keyword: planned.Keyword,
|
||||
DeadlineBefore: planned.DeadlineBefore,
|
||||
DeadlineAfter: planned.DeadlineAfter,
|
||||
})
|
||||
|
||||
// 6. 若用户原话里有明确数量要求(例如“给我3个”),强制覆盖 plan.limit。
|
||||
// 这样即使规划模型漏掉 limit,也不会影响最终返回条数预期。
|
||||
if explicitLimit, found := extractExplicitLimitFromUser(st.UserMessage); found {
|
||||
st.ExplicitLimit = explicitLimit
|
||||
st.Plan.Limit = explicitLimit
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func (r *taskQueryGraphRunner) quadrantNode(ctx context.Context, st *TaskQueryState) (*TaskQueryState, error) {
|
||||
_ = ctx
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in quadrant node")
|
||||
}
|
||||
|
||||
// 1. 象限节点不调用模型,只做“象限参数兜底与去重”。
|
||||
// 2. 为空表示全象限,非空表示指定象限。
|
||||
r.emit("task_query.quadrant.routing", "正在归一化象限筛选范围。")
|
||||
st.Plan.Quadrants = normalizeQuadrants(st.Plan.Quadrants)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func (r *taskQueryGraphRunner) timeAnchorNode(ctx context.Context, st *TaskQueryState) (*TaskQueryState, error) {
|
||||
_ = ctx
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in time anchor node")
|
||||
}
|
||||
|
||||
// 1. 时间节点不再调用模型,只负责把规划中的时间文本解析为绝对时间对象。
|
||||
// 2. 解析失败时清空该边界,避免非法时间导致整条查询失败。
|
||||
r.emit("task_query.time.anchoring", "正在锁定时间过滤边界。")
|
||||
applyTimeAnchorOnPlan(&st.Plan)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func (r *taskQueryGraphRunner) queryNode(ctx context.Context, st *TaskQueryState) (*TaskQueryState, error) {
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in query node")
|
||||
}
|
||||
|
||||
// 1. 按当前计划执行工具查询。
|
||||
r.emit("task_query.tool.querying", "正在查询任务数据。")
|
||||
items, err := r.executePlanByTool(ctx, st.Plan)
|
||||
if err != nil {
|
||||
// 查询失败不抛出硬错误,交给反思节点决定如何回复用户。
|
||||
st.LastQueryItems = make([]TaskQueryToolRecord, 0)
|
||||
st.LastQueryTotal = 0
|
||||
st.ReflectReason = "查询工具执行失败"
|
||||
return st, nil
|
||||
}
|
||||
st.LastQueryItems = items
|
||||
st.LastQueryTotal = len(items)
|
||||
|
||||
// 2. 额外优化:若结果为空且还没自动放宽过,则先放宽一次再查询(无额外模型调用)。
|
||||
if st.LastQueryTotal == 0 && !st.AutoBroadenApplied {
|
||||
plan, broadened := autoBroadenPlan(st.Plan)
|
||||
if broadened {
|
||||
st.AutoBroadenApplied = true
|
||||
st.Plan = plan
|
||||
r.emit("task_query.tool.broadened", "首次查询为空,已自动放宽条件再试一次。")
|
||||
retryItems, retryErr := r.executePlanByTool(ctx, st.Plan)
|
||||
if retryErr == nil {
|
||||
st.LastQueryItems = retryItems
|
||||
st.LastQueryTotal = len(retryItems)
|
||||
}
|
||||
}
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func (r *taskQueryGraphRunner) reflectNode(ctx context.Context, st *TaskQueryState) (*TaskQueryState, error) {
|
||||
if st == nil {
|
||||
return nil, fmt.Errorf("task query graph: nil state in reflect node")
|
||||
}
|
||||
|
||||
// 1. 反思节点负责三件事:
|
||||
// 1.1 判断当前结果是否满足用户诉求;
|
||||
// 1.2 需要重试时给出最小 patch;
|
||||
// 1.3 同时给出可直接返回用户的中文回复。
|
||||
r.emit("task_query.reflecting", "正在判断结果是否贴合你的需求。")
|
||||
reflectPrompt := buildReflectUserPrompt(st)
|
||||
raw, err := callTaskQueryModelForJSON(ctx, r.input.Model, TaskQueryReflectPrompt, reflectPrompt, 380)
|
||||
if err != nil {
|
||||
// 2. 反思调用失败时直接收束,避免无限等待。
|
||||
st.NeedRetry = false
|
||||
st.FinalReply = buildTaskQueryFallbackReply(st.LastQueryItems)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
reflectResult, parseErr := parseTaskQueryJSON[taskQueryReflectOutput](raw)
|
||||
if parseErr != nil {
|
||||
st.NeedRetry = false
|
||||
st.FinalReply = buildTaskQueryFallbackReply(st.LastQueryItems)
|
||||
return st, nil
|
||||
}
|
||||
|
||||
st.ReflectReason = strings.TrimSpace(reflectResult.Reason)
|
||||
|
||||
// 3. 满足需求时直接结束。
|
||||
if reflectResult.Satisfied {
|
||||
st.NeedRetry = false
|
||||
st.FinalReply = buildTaskQueryFinalReply(st.LastQueryItems, st.Plan, strings.TrimSpace(reflectResult.Reply))
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 4. 不满足且允许重试时,应用 patch 并回到查询节点。
|
||||
if reflectResult.NeedRetry && st.RetryCount < st.MaxReflectRetry {
|
||||
st.Plan = applyRetryPatch(st.Plan, reflectResult.RetryPatch, st.ExplicitLimit)
|
||||
st.RetryCount++
|
||||
st.NeedRetry = true
|
||||
if strings.TrimSpace(reflectResult.Reply) != "" {
|
||||
// 4.1 这里先缓存中间回复,最终是否使用取决于后续是否成功命中。
|
||||
st.FinalReply = strings.TrimSpace(reflectResult.Reply)
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// 5. 不再重试:输出最终回复并结束。
|
||||
st.NeedRetry = false
|
||||
st.FinalReply = buildTaskQueryFinalReply(st.LastQueryItems, st.Plan, strings.TrimSpace(reflectResult.Reply))
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func (r *taskQueryGraphRunner) executePlanByTool(ctx context.Context, plan QueryPlan) ([]TaskQueryToolRecord, error) {
|
||||
// 1. 这里强制通过工具执行查询,而不是直接读 DAO。
|
||||
// 目的:保持“工具边界”一致,后续迁移多工具编排时可复用同一协议。
|
||||
if r.queryTool == nil {
|
||||
return nil, fmt.Errorf("task query tool is nil")
|
||||
}
|
||||
|
||||
merged := make([]TaskQueryToolRecord, 0, plan.Limit)
|
||||
seen := make(map[int]struct{}, plan.Limit*2)
|
||||
|
||||
runOne := func(quadrant *int) error {
|
||||
input := TaskQueryToolInput{
|
||||
Quadrant: quadrant,
|
||||
SortBy: plan.SortBy,
|
||||
Order: plan.Order,
|
||||
Limit: plan.Limit,
|
||||
Keyword: plan.Keyword,
|
||||
DeadlineBefore: plan.DeadlineBeforeText,
|
||||
DeadlineAfter: plan.DeadlineAfterText,
|
||||
}
|
||||
includeCompleted := plan.IncludeCompleted
|
||||
input.IncludeCompleted = &includeCompleted
|
||||
|
||||
rawInput, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rawOutput, err := r.queryTool.InvokableRun(ctx, string(rawInput))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
parsed, err := parseTaskQueryJSON[TaskQueryToolOutput](rawOutput)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, item := range parsed.Items {
|
||||
if _, exists := seen[item.ID]; exists {
|
||||
continue
|
||||
}
|
||||
seen[item.ID] = struct{}{}
|
||||
merged = append(merged, item)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 2. Quadrants 为空表示全象限,执行一次无象限过滤查询。
|
||||
if len(plan.Quadrants) == 0 {
|
||||
if err := runOne(nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// 3. 指定象限时逐个调用工具并合并去重。
|
||||
for _, quadrant := range plan.Quadrants {
|
||||
q := quadrant
|
||||
if err := runOne(&q); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 合并后再按计划统一排序,保证跨象限结果顺序稳定。
|
||||
sortTaskQueryToolRecords(merged, plan)
|
||||
if len(merged) > plan.Limit {
|
||||
merged = merged[:plan.Limit]
|
||||
}
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
func normalizePlan(raw taskQueryPlanOutput) QueryPlan {
|
||||
plan := defaultTaskQueryPlan()
|
||||
plan.Quadrants = normalizeQuadrants(raw.Quadrants)
|
||||
|
||||
sortBy := strings.ToLower(strings.TrimSpace(raw.SortBy))
|
||||
switch sortBy {
|
||||
case "deadline", "priority", "id":
|
||||
plan.SortBy = sortBy
|
||||
}
|
||||
|
||||
order := strings.ToLower(strings.TrimSpace(raw.Order))
|
||||
switch order {
|
||||
case "asc", "desc":
|
||||
plan.Order = order
|
||||
}
|
||||
|
||||
if raw.Limit > 0 {
|
||||
plan.Limit = raw.Limit
|
||||
}
|
||||
if plan.Limit > MaxTaskQueryLimit {
|
||||
plan.Limit = MaxTaskQueryLimit
|
||||
}
|
||||
if plan.Limit <= 0 {
|
||||
plan.Limit = DefaultTaskQueryLimit
|
||||
}
|
||||
|
||||
if raw.IncludeCompleted != nil {
|
||||
plan.IncludeCompleted = *raw.IncludeCompleted
|
||||
}
|
||||
plan.Keyword = strings.TrimSpace(raw.Keyword)
|
||||
plan.DeadlineBeforeText = strings.TrimSpace(raw.DeadlineBefore)
|
||||
plan.DeadlineAfterText = strings.TrimSpace(raw.DeadlineAfter)
|
||||
applyTimeAnchorOnPlan(&plan)
|
||||
return plan
|
||||
}
|
||||
|
||||
func defaultTaskQueryPlan() QueryPlan {
|
||||
return QueryPlan{
|
||||
Quadrants: nil,
|
||||
SortBy: "deadline",
|
||||
Order: "asc",
|
||||
Limit: DefaultTaskQueryLimit,
|
||||
IncludeCompleted: false,
|
||||
Keyword: "",
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeQuadrants(quadrants []int) []int {
|
||||
if len(quadrants) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[int]struct{}, len(quadrants))
|
||||
result := make([]int, 0, len(quadrants))
|
||||
for _, q := range quadrants {
|
||||
if q < 1 || q > 4 {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[q]; exists {
|
||||
continue
|
||||
}
|
||||
seen[q] = struct{}{}
|
||||
result = append(result, q)
|
||||
}
|
||||
sort.Ints(result)
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(result) == 4 {
|
||||
// 指定了全部象限时与“空=全象限”等价,统一归一化为 nil。
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func applyTimeAnchorOnPlan(plan *QueryPlan) {
|
||||
if plan == nil {
|
||||
return
|
||||
}
|
||||
before, errBefore := parseOptionalBoundaryTime(plan.DeadlineBeforeText, true)
|
||||
after, errAfter := parseOptionalBoundaryTime(plan.DeadlineAfterText, false)
|
||||
|
||||
if errBefore != nil {
|
||||
plan.DeadlineBefore = nil
|
||||
plan.DeadlineBeforeText = ""
|
||||
} else {
|
||||
plan.DeadlineBefore = before
|
||||
}
|
||||
if errAfter != nil {
|
||||
plan.DeadlineAfter = nil
|
||||
plan.DeadlineAfterText = ""
|
||||
} else {
|
||||
plan.DeadlineAfter = after
|
||||
}
|
||||
|
||||
// 边界冲突时清空,防止构造出“必为空结果”的死条件。
|
||||
if plan.DeadlineBefore != nil && plan.DeadlineAfter != nil && plan.DeadlineAfter.After(*plan.DeadlineBefore) {
|
||||
plan.DeadlineBefore = nil
|
||||
plan.DeadlineAfter = nil
|
||||
plan.DeadlineBeforeText = ""
|
||||
plan.DeadlineAfterText = ""
|
||||
}
|
||||
}
|
||||
|
||||
func autoBroadenPlan(plan QueryPlan) (QueryPlan, bool) {
|
||||
// 1. 仅允许自动放宽一次,且放宽必须“可解释”:
|
||||
// 1.1 清空关键词;
|
||||
// 1.2 放开完成状态;
|
||||
// 1.3 清空时间边界;
|
||||
// 1.4 不主动改象限和 limit,避免语义漂移(例如“简单任务”被放宽成全象限)。
|
||||
changed := false
|
||||
broadened := plan
|
||||
|
||||
if strings.TrimSpace(broadened.Keyword) != "" {
|
||||
broadened.Keyword = ""
|
||||
changed = true
|
||||
}
|
||||
if !broadened.IncludeCompleted {
|
||||
broadened.IncludeCompleted = true
|
||||
changed = true
|
||||
}
|
||||
if broadened.DeadlineBefore != nil || broadened.DeadlineAfter != nil ||
|
||||
broadened.DeadlineBeforeText != "" || broadened.DeadlineAfterText != "" {
|
||||
broadened.DeadlineBefore = nil
|
||||
broadened.DeadlineAfter = nil
|
||||
broadened.DeadlineBeforeText = ""
|
||||
broadened.DeadlineAfterText = ""
|
||||
changed = true
|
||||
}
|
||||
return broadened, changed
|
||||
}
|
||||
|
||||
func applyRetryPatch(plan QueryPlan, patch taskQueryRetryPatch, explicitLimit int) QueryPlan {
|
||||
next := plan
|
||||
changed := false
|
||||
|
||||
if patch.Quadrants != nil {
|
||||
next.Quadrants = normalizeQuadrants(*patch.Quadrants)
|
||||
changed = true
|
||||
}
|
||||
if patch.SortBy != nil {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(*patch.SortBy))
|
||||
if sortBy == "deadline" || sortBy == "priority" || sortBy == "id" {
|
||||
next.SortBy = sortBy
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if patch.Order != nil {
|
||||
order := strings.ToLower(strings.TrimSpace(*patch.Order))
|
||||
if order == "asc" || order == "desc" {
|
||||
next.Order = order
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if patch.Limit != nil {
|
||||
// 用户显式指定数量时,锁定 limit,不允许反思补丁改写。
|
||||
if explicitLimit <= 0 {
|
||||
limit := *patch.Limit
|
||||
if limit <= 0 {
|
||||
limit = DefaultTaskQueryLimit
|
||||
}
|
||||
if limit > MaxTaskQueryLimit {
|
||||
limit = MaxTaskQueryLimit
|
||||
}
|
||||
next.Limit = limit
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if patch.IncludeCompleted != nil {
|
||||
next.IncludeCompleted = *patch.IncludeCompleted
|
||||
changed = true
|
||||
}
|
||||
if patch.Keyword != nil {
|
||||
next.Keyword = strings.TrimSpace(*patch.Keyword)
|
||||
changed = true
|
||||
}
|
||||
if patch.DeadlineBefore != nil {
|
||||
next.DeadlineBeforeText = strings.TrimSpace(*patch.DeadlineBefore)
|
||||
changed = true
|
||||
}
|
||||
if patch.DeadlineAfter != nil {
|
||||
next.DeadlineAfterText = strings.TrimSpace(*patch.DeadlineAfter)
|
||||
changed = true
|
||||
}
|
||||
|
||||
if changed {
|
||||
applyTimeAnchorOnPlan(&next)
|
||||
}
|
||||
// 双保险:显式数量存在时再次锁定,避免其他路径误改。
|
||||
if explicitLimit > 0 {
|
||||
next.Limit = explicitLimit
|
||||
}
|
||||
return next
|
||||
}
|
||||
|
||||
func buildReflectUserPrompt(st *TaskQueryState) string {
|
||||
planSummary := summarizePlan(st.Plan)
|
||||
resultSummary := summarizeQueryItems(st.LastQueryItems, 6)
|
||||
return fmt.Sprintf(`当前时间:%s
|
||||
用户原话:%s
|
||||
用户目标:%s
|
||||
当前查询计划:%s
|
||||
当前重试:%d/%d
|
||||
查询结果摘要:
|
||||
%s`,
|
||||
st.RequestNowText,
|
||||
st.UserMessage,
|
||||
st.UserGoal,
|
||||
planSummary,
|
||||
st.RetryCount,
|
||||
st.MaxReflectRetry,
|
||||
resultSummary,
|
||||
)
|
||||
}
|
||||
|
||||
func summarizePlan(plan QueryPlan) string {
|
||||
quadrants := "全部象限"
|
||||
if len(plan.Quadrants) > 0 {
|
||||
parts := make([]string, 0, len(plan.Quadrants))
|
||||
for _, q := range plan.Quadrants {
|
||||
parts = append(parts, strconv.Itoa(q))
|
||||
}
|
||||
quadrants = strings.Join(parts, ",")
|
||||
}
|
||||
return fmt.Sprintf("quadrants=%s sort=%s/%s limit=%d include_completed=%t keyword=%s before=%s after=%s",
|
||||
quadrants, plan.SortBy, plan.Order, plan.Limit, plan.IncludeCompleted,
|
||||
emptyToDash(plan.Keyword), emptyToDash(plan.DeadlineBeforeText), emptyToDash(plan.DeadlineAfterText))
|
||||
}
|
||||
|
||||
func summarizeQueryItems(items []TaskQueryToolRecord, max int) string {
|
||||
if len(items) == 0 {
|
||||
return "无结果"
|
||||
}
|
||||
if max <= 0 {
|
||||
max = 5
|
||||
}
|
||||
if len(items) > max {
|
||||
items = items[:max]
|
||||
}
|
||||
lines := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
line := fmt.Sprintf("- #%d %s | 象限=%d | 完成=%t | 截止=%s",
|
||||
item.ID, item.Title, item.PriorityGroup, item.IsCompleted, emptyToDash(item.DeadlineAt))
|
||||
lines = append(lines, line)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func buildTaskQueryFallbackReply(items []TaskQueryToolRecord) string {
|
||||
if len(items) == 0 {
|
||||
return "我这边暂时没找到匹配的任务。你可以再补一句,比如“按截止时间最早的前3个”或“只看简单不重要”。"
|
||||
}
|
||||
// 1. 用最多 3 条摘要拼一个稳态回复,避免模型异常时空白返回。
|
||||
preview := items
|
||||
if len(preview) > 3 {
|
||||
preview = preview[:3]
|
||||
}
|
||||
lines := make([]string, 0, len(preview))
|
||||
for _, item := range preview {
|
||||
lines = append(lines, fmt.Sprintf("%s(%s)", item.Title, item.PriorityLabel))
|
||||
}
|
||||
return fmt.Sprintf("我先给你筛到这些:%s。要不要我再按“更紧急”或“更简单”继续细化?", strings.Join(lines, "、"))
|
||||
}
|
||||
|
||||
// buildTaskQueryFinalReply 构建“确定性条数”的最终回复。
|
||||
//
|
||||
// 设计目的:
|
||||
// 1. 让返回条数严格受 plan.limit 约束,避免 LLM 自由发挥导致“只说1条”;
|
||||
// 2. 仍可保留 LLM 的语气前缀,但清单主体由后端稳定渲染;
|
||||
// 3. 无结果时统一走兜底文案。
|
||||
func buildTaskQueryFinalReply(items []TaskQueryToolRecord, plan QueryPlan, llmReply string) string {
|
||||
if len(items) == 0 {
|
||||
base := buildTaskQueryFallbackReply(items)
|
||||
if strings.TrimSpace(llmReply) == "" {
|
||||
return base
|
||||
}
|
||||
return strings.TrimSpace(llmReply) + "\n" + base
|
||||
}
|
||||
|
||||
desired := plan.Limit
|
||||
if desired <= 0 {
|
||||
desired = DefaultTaskQueryLimit
|
||||
}
|
||||
if desired > MaxTaskQueryLimit {
|
||||
desired = MaxTaskQueryLimit
|
||||
}
|
||||
showCount := desired
|
||||
if len(items) < showCount {
|
||||
showCount = len(items)
|
||||
}
|
||||
|
||||
preview := items[:showCount]
|
||||
lines := make([]string, 0, len(preview))
|
||||
for idx, item := range preview {
|
||||
deadline := strings.TrimSpace(item.DeadlineAt)
|
||||
if deadline == "" {
|
||||
deadline = "无明确截止时间"
|
||||
}
|
||||
status := "未完成"
|
||||
if item.IsCompleted {
|
||||
status = "已完成"
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s(%s,%s,截止:%s)",
|
||||
idx+1, item.Title, item.PriorityLabel, status, deadline))
|
||||
}
|
||||
|
||||
header := fmt.Sprintf("给你整理了 %d 条任务:", showCount)
|
||||
if lead := extractSafeReplyLead(llmReply); lead != "" {
|
||||
header = lead + "\n" + header
|
||||
}
|
||||
|
||||
reply := header + "\n" + strings.Join(lines, "\n")
|
||||
if len(items) > showCount {
|
||||
reply += fmt.Sprintf("\n另外还有 %d 条匹配任务,要不要我继续往下列?", len(items)-showCount)
|
||||
}
|
||||
return reply
|
||||
}
|
||||
|
||||
// extractSafeReplyLead 从 LLM 回复中提取“安全前缀句”。
|
||||
//
|
||||
// 目的:
|
||||
// 1. 防止 LLM 已经输出一整段列表时再次和后端列表拼接,造成双重输出;
|
||||
// 2. 仅保留单行短句语气前缀,正文列表始终以后端确定性渲染为准。
|
||||
func extractSafeReplyLead(llmReply string) string {
|
||||
text := strings.TrimSpace(llmReply)
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
// 有明显列表迹象时直接丢弃,避免重复列举。
|
||||
lower := strings.ToLower(text)
|
||||
if strings.Contains(text, "\n") || strings.Contains(text, "#") ||
|
||||
strings.Contains(lower, "1.") || strings.Contains(text, "1、") || strings.Contains(text, "以下是") {
|
||||
return ""
|
||||
}
|
||||
// 太长也不保留,避免把冗长模型输出混进最终回复。
|
||||
if len([]rune(text)) > 30 {
|
||||
return ""
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func sortTaskQueryToolRecords(items []TaskQueryToolRecord, plan QueryPlan) {
|
||||
if len(items) <= 1 {
|
||||
return
|
||||
}
|
||||
sortBy := strings.ToLower(strings.TrimSpace(plan.SortBy))
|
||||
order := strings.ToLower(strings.TrimSpace(plan.Order))
|
||||
if order != "desc" {
|
||||
order = "asc"
|
||||
}
|
||||
|
||||
sort.SliceStable(items, func(i, j int) bool {
|
||||
left := items[i]
|
||||
right := items[j]
|
||||
switch sortBy {
|
||||
case "priority":
|
||||
if left.PriorityGroup != right.PriorityGroup {
|
||||
if order == "desc" {
|
||||
return left.PriorityGroup > right.PriorityGroup
|
||||
}
|
||||
return left.PriorityGroup < right.PriorityGroup
|
||||
}
|
||||
return left.ID > right.ID
|
||||
case "id":
|
||||
if order == "desc" {
|
||||
return left.ID > right.ID
|
||||
}
|
||||
return left.ID < right.ID
|
||||
default:
|
||||
lTime, lOK := parseRecordDeadline(left.DeadlineAt)
|
||||
rTime, rOK := parseRecordDeadline(right.DeadlineAt)
|
||||
if lOK && rOK {
|
||||
if !lTime.Equal(rTime) {
|
||||
if order == "desc" {
|
||||
return lTime.After(rTime)
|
||||
}
|
||||
return lTime.Before(rTime)
|
||||
}
|
||||
return left.ID > right.ID
|
||||
}
|
||||
if lOK && !rOK {
|
||||
return true
|
||||
}
|
||||
if !lOK && rOK {
|
||||
return false
|
||||
}
|
||||
return left.ID > right.ID
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func parseRecordDeadline(raw string) (time.Time, bool) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
t, err := time.ParseInLocation("2006-01-02 15:04", text, time.Local)
|
||||
if err != nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return t, true
|
||||
}
|
||||
|
||||
func emptyToDash(text string) string {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return "-"
|
||||
}
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
// extractExplicitLimitFromUser 从用户原话提取显式数量诉求。
|
||||
//
|
||||
// 解析策略:
|
||||
// 1. 先匹配阿拉伯数字(前3个/top 5/给我2条);
|
||||
// 2. 再匹配常见中文数字(前五个/来三个);
|
||||
// 3. 统一限制在 1~20 之间。
|
||||
func extractExplicitLimitFromUser(userMessage string) (int, bool) {
|
||||
text := strings.TrimSpace(userMessage)
|
||||
if text == "" {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
for _, pattern := range explicitLimitPatterns {
|
||||
matches := pattern.FindStringSubmatch(text)
|
||||
if len(matches) < 2 {
|
||||
continue
|
||||
}
|
||||
number, err := strconv.Atoi(strings.TrimSpace(matches[1]))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
return normalizeExplicitLimit(number)
|
||||
}
|
||||
|
||||
// 中文数字兜底:覆盖高频口语模式。
|
||||
chinesePatterns := []string{"前", "来", "给我"}
|
||||
for _, prefix := range chinesePatterns {
|
||||
for digitRune, number := range chineseDigitMap {
|
||||
token := prefix + string(digitRune)
|
||||
if strings.Contains(text, token) {
|
||||
return normalizeExplicitLimit(number)
|
||||
}
|
||||
// “前五个”“来三个”这类再补一个“个/条/项”尾缀判断。
|
||||
for _, suffix := range []string{"个", "条", "项"} {
|
||||
if strings.Contains(text, token+suffix) {
|
||||
return normalizeExplicitLimit(number)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func normalizeExplicitLimit(number int) (int, bool) {
|
||||
if number <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
if number > MaxTaskQueryLimit {
|
||||
number = MaxTaskQueryLimit
|
||||
}
|
||||
return number, true
|
||||
}
|
||||
|
||||
func callTaskQueryModelForJSON(
|
||||
ctx context.Context,
|
||||
model *ark.ChatModel,
|
||||
systemPrompt string,
|
||||
userPrompt string,
|
||||
maxTokens int,
|
||||
) (string, error) {
|
||||
if model == nil {
|
||||
return "", fmt.Errorf("task query model is nil")
|
||||
}
|
||||
messages := []*schema.Message{
|
||||
schema.SystemMessage(systemPrompt),
|
||||
schema.UserMessage(userPrompt),
|
||||
}
|
||||
|
||||
opts := []einoModel.Option{
|
||||
ark.WithThinking(&arkModel.Thinking{Type: arkModel.ThinkingTypeDisabled}),
|
||||
einoModel.WithTemperature(0),
|
||||
}
|
||||
if maxTokens > 0 {
|
||||
opts = append(opts, einoModel.WithMaxTokens(maxTokens))
|
||||
}
|
||||
|
||||
resp, err := model.Generate(ctx, messages, opts...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp == nil {
|
||||
return "", fmt.Errorf("task query model returned nil")
|
||||
}
|
||||
text := strings.TrimSpace(resp.Content)
|
||||
if text == "" {
|
||||
return "", fmt.Errorf("task query model returned empty content")
|
||||
}
|
||||
return text, nil
|
||||
}
|
||||
|
||||
func parseTaskQueryJSON[T any](raw string) (*T, error) {
|
||||
clean := strings.TrimSpace(raw)
|
||||
if clean == "" {
|
||||
return nil, fmt.Errorf("empty response")
|
||||
}
|
||||
|
||||
// 1. 兼容 ```json 包裹格式。
|
||||
if strings.HasPrefix(clean, "```") {
|
||||
clean = strings.TrimPrefix(clean, "```json")
|
||||
clean = strings.TrimPrefix(clean, "```")
|
||||
clean = strings.TrimSuffix(clean, "```")
|
||||
clean = strings.TrimSpace(clean)
|
||||
}
|
||||
|
||||
// 2. 先尝试整体解析。
|
||||
var out T
|
||||
if err := json.Unmarshal([]byte(clean), &out); err == nil {
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
// 3. 若模型前后带了额外文本,则提取最外层对象再解析。
|
||||
start := strings.Index(clean, "{")
|
||||
end := strings.LastIndex(clean, "}")
|
||||
if start == -1 || end == -1 || end <= start {
|
||||
return nil, fmt.Errorf("no json object found")
|
||||
}
|
||||
obj := clean[start : end+1]
|
||||
if err := json.Unmarshal([]byte(obj), &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
package taskquery
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestExtractExplicitLimitFromUser_Number
|
||||
// 目的:验证用户原话里的阿拉伯数字数量诉求可以被正确提取。
|
||||
func TestExtractExplicitLimitFromUser_Number(t *testing.T) {
|
||||
limit, ok := extractExplicitLimitFromUser("给我3个优先级低的任务")
|
||||
if !ok {
|
||||
t.Fatalf("期望识别到显式数量")
|
||||
}
|
||||
if limit != 3 {
|
||||
t.Fatalf("数量识别错误,期望=3 实际=%d", limit)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractExplicitLimitFromUser_ChineseNumber
|
||||
// 目的:验证常见中文数字(如“前五个”)也能识别数量。
|
||||
func TestExtractExplicitLimitFromUser_ChineseNumber(t *testing.T) {
|
||||
limit, ok := extractExplicitLimitFromUser("前五个简单任务给我看看")
|
||||
if !ok {
|
||||
t.Fatalf("期望识别到中文数量")
|
||||
}
|
||||
if limit != 5 {
|
||||
t.Fatalf("数量识别错误,期望=5 实际=%d", limit)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractExplicitLimitFromUser_LaiYiGe
|
||||
// 目的:验证“来一个...”这种口语数量表达也能识别为 1。
|
||||
func TestExtractExplicitLimitFromUser_LaiYiGe(t *testing.T) {
|
||||
limit, ok := extractExplicitLimitFromUser("来一个我的简单任务")
|
||||
if !ok {
|
||||
t.Fatalf("期望识别到“来一个”的显式数量")
|
||||
}
|
||||
if limit != 1 {
|
||||
t.Fatalf("数量识别错误,期望=1 实际=%d", limit)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildTaskQueryFinalReply_RespectsLimit
|
||||
// 目的:验证最终回复会按 plan.limit 输出对应条数,而不是由 LLM 自由决定条数。
|
||||
func TestBuildTaskQueryFinalReply_RespectsLimit(t *testing.T) {
|
||||
items := []TaskQueryToolRecord{
|
||||
{ID: 1, Title: "任务1", PriorityLabel: "简单不重要", DeadlineAt: "2026-03-16 10:00"},
|
||||
{ID: 2, Title: "任务2", PriorityLabel: "简单不重要", DeadlineAt: "2026-03-17 10:00"},
|
||||
{ID: 3, Title: "任务3", PriorityLabel: "简单不重要", DeadlineAt: "2026-03-18 10:00"},
|
||||
}
|
||||
reply := buildTaskQueryFinalReply(items, QueryPlan{Limit: 2}, "好的")
|
||||
if !strings.Contains(reply, "整理了 2 条任务") {
|
||||
t.Fatalf("回复未体现 limit=2,reply=%s", reply)
|
||||
}
|
||||
if strings.Contains(reply, "3. ") {
|
||||
t.Fatalf("回复不应出现第3条,reply=%s", reply)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildTaskQueryFinalReply_NoDuplicateList
|
||||
// 目的:验证当 llmReply 已带列表内容时,不会和后端确定性列表重复拼接。
|
||||
func TestBuildTaskQueryFinalReply_NoDuplicateList(t *testing.T) {
|
||||
items := []TaskQueryToolRecord{
|
||||
{ID: 1, Title: "任务1", PriorityLabel: "简单不重要", DeadlineAt: "2026-03-16 10:00"},
|
||||
}
|
||||
llmReply := "以下是你的任务:\n#1 任务1"
|
||||
reply := buildTaskQueryFinalReply(items, QueryPlan{Limit: 1}, llmReply)
|
||||
if strings.Contains(reply, "以下是你的任务") {
|
||||
t.Fatalf("不应保留 llm 列表头,reply=%s", reply)
|
||||
}
|
||||
if !strings.Contains(reply, "整理了 1 条任务") {
|
||||
t.Fatalf("应保留后端确定性列表头,reply=%s", reply)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyRetryPatch_RespectExplicitLimit
|
||||
// 目的:验证用户显式数量存在时,反思补丁不能改写 limit。
|
||||
func TestApplyRetryPatch_RespectExplicitLimit(t *testing.T) {
|
||||
plan := QueryPlan{Limit: 1, SortBy: "deadline", Order: "asc"}
|
||||
limit := 10
|
||||
next := applyRetryPatch(plan, taskQueryRetryPatch{Limit: &limit}, 1)
|
||||
if next.Limit != 1 {
|
||||
t.Fatalf("显式数量锁应生效,期望=1 实际=%d", next.Limit)
|
||||
}
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
package taskquery
|
||||
|
||||
const (
|
||||
// TaskQueryAssistantPrompt 是“任务查询”分支的系统提示词。
|
||||
//
|
||||
// 设计目标:
|
||||
// 1. 把“先查工具再回答”的约束写死,减少模型直接编造任务的风险;
|
||||
// 2. 约束输出风格:简洁、可执行、可追问;
|
||||
// 3. 当用户需求不完整时,引导模型先做合理默认,再补充可选澄清。
|
||||
TaskQueryAssistantPrompt = `你是 SmartFlow 的任务查询助手。
|
||||
你的职责是:根据用户的问题,从任务工具中检索真实任务,再给出中文回复。
|
||||
|
||||
强约束:
|
||||
1) 只要用户在“查任务/筛任务/排序任务/找任务”,必须优先调用 query_tasks 工具,不要凭空回答。
|
||||
2) 工具返回为空时,直接说明“当前没有匹配任务”,并给一个简短下一步建议。
|
||||
3) 结果较多时,默认展示前 3~5 条关键信息(标题、象限、截止时间、完成状态)。
|
||||
4) 用户指令不完整时可先用默认参数查一次,再补一句澄清建议,不要反复追问。
|
||||
5) 回复必须自然口语化,禁止输出 markdown 表格。`
|
||||
|
||||
// TaskQueryPlanPrompt 是“任务查询规划节点”的系统提示词。
|
||||
//
|
||||
// 设计目标:
|
||||
// 1. 只调用一次模型,把“象限选择 + 排序 + 时间过滤 + 结果规模”统一规划出来;
|
||||
// 2. 输出强约束 JSON,便于后端节点稳定解析;
|
||||
// 3. 不要求模型直接生成最终回复,避免规划阶段混入废话。
|
||||
TaskQueryPlanPrompt = `你是 SmartFlow 的任务查询规划器。
|
||||
请根据用户原话,输出“结构化查询计划”JSON,供后端直接执行。
|
||||
|
||||
输出字段(只允许 JSON,不要解释):
|
||||
{
|
||||
"user_goal": "一句话总结用户诉求",
|
||||
"quadrants": [1,2,3,4],
|
||||
"sort_by": "deadline|priority|id",
|
||||
"order": "asc|desc",
|
||||
"limit": 1-20,
|
||||
"include_completed": false,
|
||||
"keyword": "可选关键词,或空字符串",
|
||||
"deadline_before": "yyyy-MM-dd HH:mm 或空字符串",
|
||||
"deadline_after": "yyyy-MM-dd HH:mm 或空字符串"
|
||||
}
|
||||
|
||||
规则:
|
||||
1) quadrants 为空数组表示“全部象限”。
|
||||
2) 若用户没提排序,默认 deadline + asc。
|
||||
3) 若用户没提数量,limit 默认 5。
|
||||
4) 时间字段必须是绝对时间或空字符串,不得输出相对时间。
|
||||
5) 只有用户的语义偏向"我还有啥事要做",即了解自己待办的请求,才优先1,2象限,即重要并紧急或者重要不紧急,若1,2象限没任务,则自动退至3,4象限;如果用户语义偏向"来点事情做做",那就说明用户需要无关紧要的事情做做,则优先3,4象限,即简单不重要或者不简单不重要。
|
||||
6) 允许多选象限。`
|
||||
|
||||
// TaskQueryReflectPrompt 是“查询结果反思节点”的系统提示词。
|
||||
//
|
||||
// 设计目标:
|
||||
// 1. 让模型判断“当前结果是否满足用户诉求”;
|
||||
// 2. 若不满足,给出可执行的轻量 patch(最多改几个关键条件);
|
||||
// 3. 同时输出可直接返回给用户的 reply,减少额外生成调用。
|
||||
TaskQueryReflectPrompt = `你是 SmartFlow 的任务查询结果审阅器。
|
||||
你会看到:用户原话、当前查询计划、查询结果摘要、当前重试次数。
|
||||
|
||||
请仅输出 JSON:
|
||||
{
|
||||
"satisfied": true/false,
|
||||
"need_retry": true/false,
|
||||
"reason": "一句话原因",
|
||||
"reply": "可直接给用户看的中文回复",
|
||||
"retry_patch": {
|
||||
"quadrants": [1,2,3,4],
|
||||
"sort_by": "deadline|priority|id",
|
||||
"order": "asc|desc",
|
||||
"limit": 1-20,
|
||||
"include_completed": true/false,
|
||||
"keyword": "字符串",
|
||||
"deadline_before": "yyyy-MM-dd HH:mm 或空字符串",
|
||||
"deadline_after": "yyyy-MM-dd HH:mm 或空字符串"
|
||||
}
|
||||
}
|
||||
|
||||
规则:
|
||||
1) 若结果已满足,satisfied=true 且 need_retry=false。
|
||||
2) 若结果不满足且仍可尝试,need_retry=true,并给最小必要 patch。
|
||||
3) 若不建议再试,need_retry=false,并在 reply 中说明当前最接近结果。`
|
||||
)
|
||||
@@ -1,88 +0,0 @@
|
||||
package taskquery
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
// DefaultTaskQueryLimit 是任务查询默认返回条数。
|
||||
DefaultTaskQueryLimit = 5
|
||||
// MaxTaskQueryLimit 是任务查询最大返回条数。
|
||||
MaxTaskQueryLimit = 20
|
||||
// DefaultReflectRetryMax 是反思重试默认上限。
|
||||
DefaultReflectRetryMax = 2
|
||||
)
|
||||
|
||||
// TaskQueryState 是任务查询图在节点间传递的统一状态容器。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 保存“规划参数、查询结果、反思决策、最终回复”;
|
||||
// 2. 控制“是否重试 + 已重试次数”状态机;
|
||||
// 3. 不负责真正查库,查库由工具执行。
|
||||
type TaskQueryState struct {
|
||||
// 请求上下文
|
||||
UserMessage string
|
||||
RequestNowText string
|
||||
|
||||
// 规划结果
|
||||
UserGoal string
|
||||
Plan QueryPlan
|
||||
// ExplicitLimit 表示“用户原话中明确指定的数量”。
|
||||
//
|
||||
// 语义说明:
|
||||
// 1. 0 代表未显式指定;
|
||||
// 2. >0 时应锁定该数量,不允许反思补丁或自动放宽改写。
|
||||
ExplicitLimit int
|
||||
|
||||
// 上一轮查询结果
|
||||
LastQueryItems []TaskQueryToolRecord
|
||||
LastQueryTotal int
|
||||
|
||||
// 自动放宽状态
|
||||
AutoBroadenApplied bool
|
||||
|
||||
// 反思状态
|
||||
RetryCount int
|
||||
MaxReflectRetry int
|
||||
NeedRetry bool
|
||||
ReflectReason string
|
||||
|
||||
// 最终输出
|
||||
FinalReply string
|
||||
}
|
||||
|
||||
// QueryPlan 是“任务查询计划”的统一结构。
|
||||
//
|
||||
// 语义说明:
|
||||
// 1. Quadrants 为空表示“查全部象限”;非空表示“只查这些象限”;
|
||||
// 2. DeadlineBefore/AfterText 保留原始文本,方便日志和反思 prompt;
|
||||
// 3. DeadlineBefore/After 是解析后的时间对象,供工具调用使用。
|
||||
type QueryPlan struct {
|
||||
Quadrants []int
|
||||
|
||||
SortBy string
|
||||
Order string
|
||||
Limit int
|
||||
|
||||
IncludeCompleted bool
|
||||
Keyword string
|
||||
|
||||
DeadlineBeforeText string
|
||||
DeadlineAfterText string
|
||||
DeadlineBefore *time.Time
|
||||
DeadlineAfter *time.Time
|
||||
}
|
||||
|
||||
// NewTaskQueryState 创建任务查询初始状态。
|
||||
func NewTaskQueryState(userMessage, requestNowText string, maxReflectRetry int) *TaskQueryState {
|
||||
if maxReflectRetry <= 0 {
|
||||
maxReflectRetry = DefaultReflectRetryMax
|
||||
}
|
||||
return &TaskQueryState{
|
||||
UserMessage: userMessage,
|
||||
RequestNowText: requestNowText,
|
||||
MaxReflectRetry: maxReflectRetry,
|
||||
LastQueryItems: make([]TaskQueryToolRecord, 0),
|
||||
NeedRetry: false,
|
||||
ReflectReason: "",
|
||||
AutoBroadenApplied: false,
|
||||
}
|
||||
}
|
||||
@@ -1,344 +0,0 @@
|
||||
package taskquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
toolutils "github.com/cloudwego/eino/components/tool/utils"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
const (
|
||||
// ToolNameTaskQueryTasks 是“任务查询工具”对模型暴露的标准名称。
|
||||
ToolNameTaskQueryTasks = "query_tasks"
|
||||
// ToolDescTaskQueryTasks 是工具职责说明,给模型理解参数语义。
|
||||
ToolDescTaskQueryTasks = "按象限/关键字/截止时间筛选并排序任务,返回结构化任务列表"
|
||||
)
|
||||
|
||||
var (
|
||||
// taskQueryTimeLayouts 是任务查询工具允许的时间输入格式白名单。
|
||||
taskQueryTimeLayouts = []string{
|
||||
time.RFC3339,
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02 15:04",
|
||||
"2006-01-02",
|
||||
}
|
||||
)
|
||||
|
||||
// TaskQueryToolDeps 描述任务查询工具依赖的外部能力。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. QueryTasks 负责真实数据读取;
|
||||
// 2. 工具层只负责参数校验与结果封装,不直接耦合 DAO 实现。
|
||||
type TaskQueryToolDeps struct {
|
||||
QueryTasks func(ctx context.Context, req TaskQueryRequest) ([]TaskRecord, error)
|
||||
}
|
||||
|
||||
func (d TaskQueryToolDeps) validate() error {
|
||||
// 1. 工具没有 QueryTasks 依赖就无法提供任何真实结果,启动时直接失败。
|
||||
if d.QueryTasks == nil {
|
||||
return errors.New("task query tool deps: QueryTasks is nil")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TaskQueryToolBundle 是任务查询工具包输出。
|
||||
//
|
||||
// 说明:
|
||||
// 1. Tools 用于实际执行;
|
||||
// 2. ToolInfos 用于模型注册工具 schema。
|
||||
type TaskQueryToolBundle struct {
|
||||
Tools []tool.BaseTool
|
||||
ToolInfos []*schema.ToolInfo
|
||||
}
|
||||
|
||||
// TaskQueryRequest 是工具层到业务层的内部查询请求。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只承载“查询条件”,不承载数据库/缓存实现细节;
|
||||
// 2. UserID 不由模型提供,必须由服务层上下文注入。
|
||||
type TaskQueryRequest struct {
|
||||
UserID int
|
||||
Quadrant *int
|
||||
SortBy string
|
||||
Order string
|
||||
Limit int
|
||||
IncludeCompleted bool
|
||||
Keyword string
|
||||
DeadlineBefore *time.Time
|
||||
DeadlineAfter *time.Time
|
||||
}
|
||||
|
||||
// TaskRecord 是业务层返回给工具层的任务记录。
|
||||
type TaskRecord struct {
|
||||
ID int
|
||||
Title string
|
||||
PriorityGroup int
|
||||
IsCompleted bool
|
||||
DeadlineAt *time.Time
|
||||
UrgencyThresholdAt *time.Time
|
||||
}
|
||||
|
||||
// TaskQueryToolInput 是对模型暴露的工具输入结构。
|
||||
//
|
||||
// 参数语义:
|
||||
// 1. quadrant 可选:1~4;
|
||||
// 2. sort_by 可选:deadline/priority/id;
|
||||
// 3. order 可选:asc/desc;
|
||||
// 4. limit 可选:默认 5,上限 20;
|
||||
// 5. include_completed 可选:默认 false。
|
||||
type TaskQueryToolInput struct {
|
||||
Quadrant *int `json:"quadrant,omitempty" jsonschema:"description=可选象限(1~4)"`
|
||||
SortBy string `json:"sort_by,omitempty" jsonschema:"description=排序字段(deadline|priority|id)"`
|
||||
Order string `json:"order,omitempty" jsonschema:"description=排序方向(asc|desc)"`
|
||||
Limit int `json:"limit,omitempty" jsonschema:"description=返回条数,默认5,上限20"`
|
||||
IncludeCompleted *bool `json:"include_completed,omitempty" jsonschema:"description=是否包含已完成任务,默认false"`
|
||||
Keyword string `json:"keyword,omitempty" jsonschema:"description=可选标题关键词,模糊匹配"`
|
||||
DeadlineBefore string `json:"deadline_before,omitempty" jsonschema:"description=可选截止上界,支持RFC3339或yyyy-MM-dd HH:mm"`
|
||||
DeadlineAfter string `json:"deadline_after,omitempty" jsonschema:"description=可选截止下界,支持RFC3339或yyyy-MM-dd HH:mm"`
|
||||
}
|
||||
|
||||
// TaskQueryToolOutput 是返回给模型的结构化结果。
|
||||
type TaskQueryToolOutput struct {
|
||||
Total int `json:"total"`
|
||||
Items []TaskQueryToolRecord `json:"items"`
|
||||
}
|
||||
|
||||
// TaskQueryToolRecord 是单条任务输出结构。
|
||||
type TaskQueryToolRecord struct {
|
||||
ID int `json:"id"`
|
||||
Title string `json:"title"`
|
||||
PriorityGroup int `json:"priority_group"`
|
||||
PriorityLabel string `json:"priority_label"`
|
||||
IsCompleted bool `json:"is_completed"`
|
||||
DeadlineAt string `json:"deadline_at,omitempty"`
|
||||
UrgencyThresholdAt string `json:"urgency_threshold_at,omitempty"`
|
||||
}
|
||||
|
||||
// BuildTaskQueryToolBundle 构建任务查询工具包。
|
||||
//
|
||||
// 步骤化说明:
|
||||
// 1. 先校验依赖,确保工具具备真实查询能力;
|
||||
// 2. 通过 InferTool 声明工具 schema,并在闭包内做全部参数校验;
|
||||
// 3. 输出 Tools + ToolInfos,供模型与执行器分别使用。
|
||||
func BuildTaskQueryToolBundle(ctx context.Context, deps TaskQueryToolDeps) (*TaskQueryToolBundle, error) {
|
||||
if err := deps.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queryTool, err := toolutils.InferTool(
|
||||
ToolNameTaskQueryTasks,
|
||||
ToolDescTaskQueryTasks,
|
||||
func(ctx context.Context, input *TaskQueryToolInput) (*TaskQueryToolOutput, error) {
|
||||
// 1. 允许 input 为空,统一按默认参数执行一次查询。
|
||||
normalized, normalizeErr := normalizeToolInput(input)
|
||||
if normalizeErr != nil {
|
||||
return nil, normalizeErr
|
||||
}
|
||||
|
||||
// 2. 执行真实查询。
|
||||
records, queryErr := deps.QueryTasks(ctx, normalized)
|
||||
if queryErr != nil {
|
||||
return nil, queryErr
|
||||
}
|
||||
|
||||
// 3. 把业务记录映射成模型友好的结构化输出。
|
||||
items := make([]TaskQueryToolRecord, 0, len(records))
|
||||
for _, record := range records {
|
||||
items = append(items, TaskQueryToolRecord{
|
||||
ID: record.ID,
|
||||
Title: record.Title,
|
||||
PriorityGroup: record.PriorityGroup,
|
||||
PriorityLabel: priorityLabelCN(record.PriorityGroup),
|
||||
IsCompleted: record.IsCompleted,
|
||||
DeadlineAt: formatOptionalTime(record.DeadlineAt),
|
||||
UrgencyThresholdAt: formatOptionalTime(record.UrgencyThresholdAt),
|
||||
})
|
||||
}
|
||||
|
||||
return &TaskQueryToolOutput{
|
||||
Total: len(items),
|
||||
Items: items,
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("构建任务查询工具失败: %w", err)
|
||||
}
|
||||
|
||||
tools := []tool.BaseTool{queryTool}
|
||||
infos, err := collectToolInfos(ctx, tools)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TaskQueryToolBundle{
|
||||
Tools: tools,
|
||||
ToolInfos: infos,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// normalizeToolInput 负责参数清洗、默认值填充与合法性校验。
|
||||
//
|
||||
// 失败策略:
|
||||
// 1. 参数非法直接返回 error,阻止错误查询落到数据层;
|
||||
// 2. 参数缺失走默认值,优先保证“可用”。
|
||||
func normalizeToolInput(input *TaskQueryToolInput) (TaskQueryRequest, error) {
|
||||
// 1. 先准备默认值,保证“空参数”也能查到结果。
|
||||
req := TaskQueryRequest{
|
||||
SortBy: "deadline",
|
||||
Order: "asc",
|
||||
Limit: 5,
|
||||
IncludeCompleted: false,
|
||||
}
|
||||
if input == nil {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// 2. 象限校验:若提供则必须在 1~4。
|
||||
if input.Quadrant != nil {
|
||||
if *input.Quadrant < 1 || *input.Quadrant > 4 {
|
||||
return TaskQueryRequest{}, fmt.Errorf("quadrant=%d 非法,必须在 1~4", *input.Quadrant)
|
||||
}
|
||||
quadrant := *input.Quadrant
|
||||
req.Quadrant = &quadrant
|
||||
}
|
||||
|
||||
// 3. 排序字段校验。
|
||||
if strings.TrimSpace(input.SortBy) != "" {
|
||||
req.SortBy = strings.ToLower(strings.TrimSpace(input.SortBy))
|
||||
}
|
||||
switch req.SortBy {
|
||||
case "deadline", "priority", "id":
|
||||
// 允许字段。
|
||||
default:
|
||||
return TaskQueryRequest{}, fmt.Errorf("sort_by=%s 非法,仅支持 deadline|priority|id", req.SortBy)
|
||||
}
|
||||
|
||||
// 4. 排序方向校验。
|
||||
if strings.TrimSpace(input.Order) != "" {
|
||||
req.Order = strings.ToLower(strings.TrimSpace(input.Order))
|
||||
}
|
||||
switch req.Order {
|
||||
case "asc", "desc":
|
||||
// 允许方向。
|
||||
default:
|
||||
return TaskQueryRequest{}, fmt.Errorf("order=%s 非法,仅支持 asc|desc", req.Order)
|
||||
}
|
||||
|
||||
// 5. limit 校验与上限保护。
|
||||
if input.Limit > 0 {
|
||||
req.Limit = input.Limit
|
||||
}
|
||||
if req.Limit > 20 {
|
||||
req.Limit = 20
|
||||
}
|
||||
if req.Limit <= 0 {
|
||||
req.Limit = 5
|
||||
}
|
||||
|
||||
// 6. include_completed 默认 false;明确传入时才覆盖。
|
||||
if input.IncludeCompleted != nil {
|
||||
req.IncludeCompleted = *input.IncludeCompleted
|
||||
}
|
||||
|
||||
// 7. keyword 清洗:去首尾空格,空串视为未设置。
|
||||
req.Keyword = strings.TrimSpace(input.Keyword)
|
||||
|
||||
// 8. 截止时间上下界解析。
|
||||
before, err := parseOptionalBoundaryTime(input.DeadlineBefore, true)
|
||||
if err != nil {
|
||||
return TaskQueryRequest{}, err
|
||||
}
|
||||
after, err := parseOptionalBoundaryTime(input.DeadlineAfter, false)
|
||||
if err != nil {
|
||||
return TaskQueryRequest{}, err
|
||||
}
|
||||
req.DeadlineBefore = before
|
||||
req.DeadlineAfter = after
|
||||
|
||||
// 9. 上下界合法性检查:after 不能晚于 before。
|
||||
if req.DeadlineBefore != nil && req.DeadlineAfter != nil && req.DeadlineAfter.After(*req.DeadlineBefore) {
|
||||
return TaskQueryRequest{}, errors.New("deadline_after 不能晚于 deadline_before")
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func collectToolInfos(ctx context.Context, tools []tool.BaseTool) ([]*schema.ToolInfo, error) {
|
||||
infos := make([]*schema.ToolInfo, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
info, err := t.Info(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取工具信息失败: %w", err)
|
||||
}
|
||||
infos = append(infos, info)
|
||||
}
|
||||
return infos, nil
|
||||
}
|
||||
|
||||
// parseOptionalBoundaryTime 解析时间上下界。
|
||||
//
|
||||
// 参数语义:
|
||||
// 1. isUpper=true:按“上界”解析,若输入仅日期则补到 23:59;
|
||||
// 2. isUpper=false:按“下界”解析,若输入仅日期则补到 00:00。
|
||||
func parseOptionalBoundaryTime(raw string, isUpper bool) (*time.Time, error) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
loc := time.Local
|
||||
for _, layout := range taskQueryTimeLayouts {
|
||||
var (
|
||||
t time.Time
|
||||
err error
|
||||
)
|
||||
if layout == time.RFC3339 {
|
||||
t, err = time.Parse(layout, text)
|
||||
if err == nil {
|
||||
t = t.In(loc)
|
||||
}
|
||||
} else {
|
||||
t, err = time.ParseInLocation(layout, text, loc)
|
||||
}
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 仅日期输入时,按上下界补齐时分。
|
||||
if layout == "2006-01-02" {
|
||||
if isUpper {
|
||||
t = time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 0, loc)
|
||||
} else {
|
||||
t = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
|
||||
}
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
return nil, fmt.Errorf("时间格式不支持: %s", text)
|
||||
}
|
||||
|
||||
func priorityLabelCN(priority int) string {
|
||||
switch priority {
|
||||
case 1:
|
||||
return "重要且紧急"
|
||||
case 2:
|
||||
return "重要不紧急"
|
||||
case 3:
|
||||
return "简单不重要"
|
||||
case 4:
|
||||
return "不简单不重要"
|
||||
default:
|
||||
return "未知优先级"
|
||||
}
|
||||
}
|
||||
|
||||
func formatOptionalTime(t *time.Time) string {
|
||||
if t == nil {
|
||||
return ""
|
||||
}
|
||||
return t.In(time.Local).Format("2006-01-02 15:04")
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package taskquery
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
)
|
||||
|
||||
// buildInvokableToolMap 把工具包转换成“工具名 -> 可执行工具”映射。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只做工具元数据到执行器的映射,不做业务逻辑;
|
||||
// 2. 若工具包结构异常(数量不一致/信息缺失)直接返回 error;
|
||||
// 3. 供图节点在运行时快速按工具名取执行器。
|
||||
func buildInvokableToolMap(bundle *TaskQueryToolBundle) (map[string]tool.InvokableTool, error) {
|
||||
if bundle == nil || len(bundle.Tools) == 0 || len(bundle.ToolInfos) == 0 {
|
||||
return nil, fmt.Errorf("task query tool bundle is empty")
|
||||
}
|
||||
if len(bundle.Tools) != len(bundle.ToolInfos) {
|
||||
return nil, fmt.Errorf("task query tool bundle mismatch")
|
||||
}
|
||||
|
||||
result := make(map[string]tool.InvokableTool, len(bundle.Tools))
|
||||
for idx, baseTool := range bundle.Tools {
|
||||
info := bundle.ToolInfos[idx]
|
||||
if info == nil || strings.TrimSpace(info.Name) == "" {
|
||||
return nil, fmt.Errorf("task query tool info is invalid")
|
||||
}
|
||||
invokableTool, ok := baseTool.(tool.InvokableTool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("task query tool %s is not invokable", info.Name)
|
||||
}
|
||||
result[info.Name] = invokableTool
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
package taskquery
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestNormalizeToolInput_Default
|
||||
// 目的:验证空入参会回填默认查询参数,保证工具在“参数缺失”场景仍可执行。
|
||||
func TestNormalizeToolInput_Default(t *testing.T) {
|
||||
req, err := normalizeToolInput(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("不应报错: %v", err)
|
||||
}
|
||||
if req.SortBy != "deadline" || req.Order != "asc" || req.Limit != 5 || req.IncludeCompleted {
|
||||
t.Fatalf("默认值异常: %+v", req)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNormalizeToolInput_InvalidQuadrant
|
||||
// 目的:验证 quadrant 越界时会被拦截,避免无效过滤条件进入业务层。
|
||||
func TestNormalizeToolInput_InvalidQuadrant(t *testing.T) {
|
||||
invalid := 6
|
||||
_, err := normalizeToolInput(&TaskQueryToolInput{
|
||||
Quadrant: &invalid,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("期望 quadrant 越界时报错")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNormalizeToolInput_DateRange
|
||||
// 目的:验证时间上下界可解析并正确落入请求结构。
|
||||
func TestNormalizeToolInput_DateRange(t *testing.T) {
|
||||
req, err := normalizeToolInput(&TaskQueryToolInput{
|
||||
DeadlineAfter: "2026-03-01 08:00",
|
||||
DeadlineBefore: "2026-03-31",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("不应报错: %v", err)
|
||||
}
|
||||
if req.DeadlineAfter == nil || req.DeadlineBefore == nil {
|
||||
t.Fatalf("时间上下界不应为空: %+v", req)
|
||||
}
|
||||
if req.DeadlineAfter.After(*req.DeadlineBefore) {
|
||||
t.Fatalf("时间上下界关系异常: after=%v before=%v", req.DeadlineAfter, req.DeadlineBefore)
|
||||
}
|
||||
}
|
||||
176
backend/agent/通用能力接入文档.md
Normal file
176
backend/agent/通用能力接入文档.md
Normal file
@@ -0,0 +1,176 @@
|
||||
# agent 通用能力接入文档
|
||||
|
||||
## 1. 文档目的
|
||||
|
||||
本文用于说明 `backend/agent` 目录下“通用能力”的职责边界、放置位置和接入约束,避免后续继续出现“同一类能力复制三份、四份”的情况。
|
||||
|
||||
这里的“通用能力”特指:
|
||||
|
||||
1. 会被两个及以上能力域复用,或者已经明确会继续扩散的基础能力。
|
||||
2. 与具体业务语义弱耦合,抽出来后不会把某个 skill 的 prompt、状态字段、业务规则污染到其它模块。
|
||||
3. 抽出后能显著减少样板代码、降低迁移成本,或者统一链路行为。
|
||||
|
||||
本文不负责描述某个具体 skill 的业务流程。业务流程、状态机、prompt 细节,仍应放在对应能力域自己的文件中。
|
||||
|
||||
## 2. 当前目录分层
|
||||
|
||||
```text
|
||||
backend/agent/
|
||||
entrance.go
|
||||
chat/
|
||||
graph/
|
||||
llm/
|
||||
model/
|
||||
node/
|
||||
prompt/
|
||||
router/
|
||||
shared/
|
||||
stream/
|
||||
```
|
||||
|
||||
### 2.1 `entrance.go`
|
||||
|
||||
职责:
|
||||
|
||||
1. 作为 `agent` 模块对上层 service 的统一入口。
|
||||
2. 负责装配路由器与各能力 handler。
|
||||
3. 不负责具体 graph 逻辑、不负责直接调模型、不负责工具执行。
|
||||
|
||||
### 2.2 `router/`
|
||||
|
||||
职责:
|
||||
|
||||
1. 负责一级分流,把请求映射到具体能力链路。
|
||||
2. 维护统一的请求/响应结构和 action 定义。
|
||||
3. 不承载具体 skill 的业务判断细节。
|
||||
|
||||
适合放入这里的能力:
|
||||
|
||||
1. 路由请求结构。
|
||||
2. action 解析与分发。
|
||||
3. 对上层稳定暴露的最小门面。
|
||||
|
||||
### 2.3 `graph/`
|
||||
|
||||
职责:
|
||||
|
||||
1. 只负责组图、连线和节点编排。
|
||||
2. 文件里应尽量只出现节点挂载、分支和边定义。
|
||||
3. 不直接写复杂业务逻辑、不直接调 DAO、不直接拼 prompt。
|
||||
|
||||
### 2.4 `node/`
|
||||
|
||||
职责:
|
||||
|
||||
1. 承接能力域的核心业务节点实现。
|
||||
2. 按“节点逻辑文件 + 工具文件”的双文件格局组织复杂能力域。
|
||||
3. 在确实存在多节点复用时,可下沉少量带业务语义的 node 内部公共 helper。
|
||||
|
||||
当前约定:
|
||||
|
||||
1. `schedule_plan.go` / `schedule_plan_tool.go` 为一组。
|
||||
2. `schedule_refine.go` / `schedule_refine_tool.go` 为一组。
|
||||
3. `quicknote.go` / `quicknote_tool.go`、`taskquery.go` / `taskquery_tool.go` 同理。
|
||||
|
||||
补充说明:
|
||||
|
||||
1. `node/tool_common.go` 是 node 层内部通用工具聚合点。
|
||||
2. 这里只放“被两个及以上节点复用、但仍带一点节点上下文语义”的 helper。
|
||||
3. 如果某个能力已经弱化到与业务无关,应继续下沉到 `shared/`,而不是长期堆在 `tool_common.go`。
|
||||
|
||||
### 2.5 `llm/`
|
||||
|
||||
职责:
|
||||
|
||||
1. 统一封装模型调用、JSON 解析、推理参数和模型侧协议。
|
||||
2. 让上层节点尽量只关心“要什么结果”,不重复实现 SDK 样板代码。
|
||||
3. 不承载具体业务状态流转。
|
||||
|
||||
### 2.6 `model/`
|
||||
|
||||
职责:
|
||||
|
||||
1. 统一放置 agent 内部状态结构、输入输出 DTO、默认预算等模型无关定义。
|
||||
2. 不在这里写业务执行逻辑。
|
||||
|
||||
### 2.7 `prompt/`
|
||||
|
||||
职责:
|
||||
|
||||
1. 维护系统提示词、结构化输出模板、路由提示词等文本资产。
|
||||
2. 不在 prompt 文件中写节点控制流和工具编排。
|
||||
|
||||
### 2.8 `stream/`
|
||||
|
||||
职责:
|
||||
|
||||
1. 统一承接 SSE chunk 包装、阶段推送、OpenAI/Ark 流式适配。
|
||||
2. 保证上层 service 不需要重复拼装流协议。
|
||||
|
||||
### 2.9 `shared/`
|
||||
|
||||
职责:
|
||||
|
||||
1. 放置跨能力域复用的纯工具能力,例如时间、重试、深拷贝等。
|
||||
2. 要求业务语义尽量弱、依赖尽量少。
|
||||
3. 一旦某类逻辑已经被第二处复用,必须优先评估是否放到这里。
|
||||
|
||||
## 3. 什么该抽成通用能力
|
||||
|
||||
满足以下任一条件时,必须优先评估抽公共层:
|
||||
|
||||
1. 同类逻辑已经出现第二份实现。
|
||||
2. 不同 skill 的实现只有参数不同,控制流基本一致。
|
||||
3. 上层 service 已经开始出现重复胶水代码。
|
||||
4. 继续复制会增加迁移、测试或回归排查成本。
|
||||
|
||||
常见应优先考虑抽取的方向:
|
||||
|
||||
1. 模型调用门面。
|
||||
2. JSON 容错解析。
|
||||
3. SSE 阶段推送与 chunk 包装。
|
||||
4. 深拷贝与快照转换。
|
||||
5. 缓存快照读写辅助逻辑。
|
||||
|
||||
## 4. 什么不该抽成通用能力
|
||||
|
||||
以下内容默认不应抽到公共层:
|
||||
|
||||
1. 某个 skill 独有的 prompt 片段。
|
||||
2. 只服务单一业务的状态字段映射。
|
||||
3. 带强业务语义的 ReAct 决策规则。
|
||||
4. 只在一个节点里短期使用、且没有第二处复用证据的 helper。
|
||||
|
||||
判断原则:
|
||||
|
||||
1. 若抽出来后名字仍然需要带明显业务词,通常说明它还不够通用。
|
||||
2. 若抽出来会让其它模块被迫理解某个 skill 的内部规则,说明抽取层级过早。
|
||||
|
||||
## 5. 新增能力时的落点规则
|
||||
|
||||
1. 纯工具、弱业务语义、跨域复用:优先放 `shared/`。
|
||||
2. 只在路由阶段复用:放 `router/`。
|
||||
3. 只与模型协议相关:放 `llm/`。
|
||||
4. 只与流式输出相关:放 `stream/`。
|
||||
5. 只在 node 层内被多个节点复用,且带少量业务上下文:放 `node/tool_common.go` 或同层 helper。
|
||||
6. 仍然明显属于某个能力域:留在对应 `node/`、`prompt/`、`model/` 文件中,不要硬抽。
|
||||
|
||||
## 6. 变更要求
|
||||
|
||||
后续若在 `backend/agent` 中新增、下沉、替换任何通用能力,必须同步完成以下动作:
|
||||
|
||||
1. 更新本文档,说明新能力放在哪一层、为什么放这里。
|
||||
2. 说明是否替代了旧实现,旧实现是否已经删除。
|
||||
3. 检查是否还残留第三份及以上重复实现。
|
||||
4. 若本轮只是暂时无法抽公共层,必须在代码注释或文档里写明原因。
|
||||
|
||||
## 7. 当前结构结论
|
||||
|
||||
截至当前版本,`backend/agent` 已是唯一正式实现目录,`backend/service/agentsvc` 也已与历史旧路径完全解耦。
|
||||
|
||||
后续重构优先级建议:
|
||||
|
||||
1. 继续收口 node 层内部重复的查询/校验/移动辅助逻辑。
|
||||
2. 持续把 service 层里可复用的旁路读写逻辑下沉到更稳定的公共层。
|
||||
3. 保持 graph 只做编排、node 只做业务、shared 只做弱语义公共能力,避免重新堆回大杂烩结构。
|
||||
|
||||
Reference in New Issue
Block a user