diff --git a/backend/agent/README.md b/backend/agent/README.md new file mode 100644 index 0000000..f6e28a1 --- /dev/null +++ b/backend/agent/README.md @@ -0,0 +1,31 @@ +# backend/agent 目录说明 + +该目录当前按“聊天流式输出能力”和“可编排的随口记能力”拆分: + +1. `graph.go` +- 仅负责现有流式聊天输出封装(SSE/OpenAI 兼容 chunk 转换)。 +- 已有线上链路依赖,当前不改业务逻辑。 + +2. `prompt.go` +- 通用 Agent 提示词。 + +3. `quick_note_prompt.go` +- AI 随口记专用提示词(意图识别、优先级评估)。 + +4. `state.go` +- 随口记链路状态结构(意图标记、抽取结果、重试计数、持久化结果)。 + +5. `tool.go` +- 随口记工具打包入口: + - `BuildQuickNoteToolBundle` + - 工具输入输出 schema + - deadline 解析与优先级校验 + +6. `quick_note_graph.go` +- 随口记 graph 编排实现: + - 节点1:意图识别 + - 节点2:优先级评估 + - 节点3:调用写库工具 + - 分支:失败自动重试(最多 3 次) + +> 说明:服务层通过 `RunQuickNoteGraph` 调用该图;若判定为非随口记意图,会自动回落到原有普通流式聊天逻辑。 diff --git a/backend/agent/quick_note_graph.go b/backend/agent/quick_note_graph.go new file mode 100644 index 0000000..521f3b2 --- /dev/null +++ b/backend/agent/quick_note_graph.go @@ -0,0 +1,466 @@ +package agent + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/cloudwego/eino-ext/components/model/ark" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +const ( + quickNoteGraphNodeIntent = "quick_note_intent" + quickNoteGraphNodeRank = "quick_note_priority" + quickNoteGraphNodePersist = "quick_note_persist" + quickNoteGraphNodeExit = "quick_note_exit" +) + +type quickNoteIntentModelOutput struct { + IsQuickNote bool `json:"is_quick_note"` + Title string `json:"title"` + DeadlineAt string `json:"deadline_at"` + Reason string `json:"reason"` +} + +type quickNotePriorityModelOutput struct { + PriorityGroup int `json:"priority_group"` + Reason string `json:"reason"` +} + +// QuickNoteGraphRunInput 是运行“随口记 graph”所需的输入依赖。 +// 说明: +// - EmitStage 可选,用于把节点进度推送给外层(例如 SSE 状态块); +// - 不传 EmitStage 时,图逻辑保持静默执行。 +type QuickNoteGraphRunInput struct { + Model *ark.ChatModel + State *QuickNoteState + Deps QuickNoteToolDeps + + EmitStage func(stage, detail string) +} + +// RunQuickNoteGraph 执行“随口记”图编排。 +// 设计目标: +// 1) 意图识别和信息抽取与写库解耦; +// 2) 发生模型抖动或工具失败时,具备可控降级和重试; +// 3) 时间解析严格可控,避免把非法日期静默写成 NULL。 +func RunQuickNoteGraph(ctx context.Context, input QuickNoteGraphRunInput) (*QuickNoteState, error) { + if input.Model == nil { + return nil, errors.New("quick note graph: model is nil") + } + if input.State == nil { + return nil, errors.New("quick note graph: state is nil") + } + if err := input.Deps.validate(); err != nil { + return nil, err + } + + emitStage := func(stage, detail string) { + if input.EmitStage != nil { + input.EmitStage(stage, detail) + } + } + + // 统一初始化“当前时间基准”: + // - RequestNow 用于相对时间解析; + // - RequestNowText 用于拼接到提示词,让模型知道“现在是几点”。 + if input.State.RequestNow.IsZero() { + input.State.RequestNow = quickNoteNowToMinute() + } + if strings.TrimSpace(input.State.RequestNowText) == "" { + input.State.RequestNowText = formatQuickNoteTimeToMinute(input.State.RequestNow) + } + + toolBundle, err := BuildQuickNoteToolBundle(ctx, input.Deps) + if err != nil { + return nil, err + } + createTaskTool, err := getInvokableToolByName(toolBundle, ToolNameQuickNoteCreateTask) + if err != nil { + return nil, err + } + + graph := compose.NewGraph[*QuickNoteState, *QuickNoteState]() + + // 节点1:意图识别与信息抽取。 + if err = graph.AddLambdaNode(quickNoteGraphNodeIntent, compose.InvokableLambda( + func(ctx context.Context, st *QuickNoteState) (*QuickNoteState, error) { + if st == nil { + return nil, errors.New("quick note graph: nil state in intent node") + } + + emitStage("quick_note.intent.analyzing", "正在分析用户输入是否属于任务安排请求。") + + prompt := fmt.Sprintf(`当前时间(北京时间,精确到分钟):%s +用户输入:%s +请仅输出 JSON(不要 markdown,不要解释),字段如下: +{ + "is_quick_note": boolean, + "title": string, + "deadline_at": string, + "reason": string +} +字段约束: +1) deadline_at 只允许输出绝对时间,格式必须为 "yyyy-MM-dd HH:mm"。 +2) 如果用户说了“明天/后天/下周一/今晚”等相对时间,必须基于上面的当前时间换算成绝对时间。 +3) 如果用户没有提及时间,deadline_at 输出空字符串。`, + st.RequestNowText, + st.UserInput, + ) + raw, callErr := callModelForJSON(ctx, input.Model, QuickNoteIntentPrompt, prompt) + if callErr != nil { + st.IsQuickNoteIntent = false + st.IntentJudgeReason = "意图识别失败,回退普通聊天" + return st, nil + } + + parsed, parseErr := parseJSONPayload[quickNoteIntentModelOutput](raw) + if parseErr != nil { + st.IsQuickNoteIntent = false + st.IntentJudgeReason = "意图识别结果不可解析,回退普通聊天" + return st, nil + } + + st.IsQuickNoteIntent = parsed.IsQuickNote + st.IntentJudgeReason = strings.TrimSpace(parsed.Reason) + if !st.IsQuickNoteIntent { + return st, nil + } + + title := strings.TrimSpace(parsed.Title) + if title == "" { + title = strings.TrimSpace(st.UserInput) + } + st.ExtractedTitle = title + + emitStage("quick_note.deadline.validating", "正在校验并归一化任务时间。") + + // Step A:优先尝试解析模型抽取出来的 deadline。 + st.ExtractedDeadlineText = strings.TrimSpace(parsed.DeadlineAt) + if st.ExtractedDeadlineText != "" { + if deadline, deadlineErr := parseOptionalDeadlineWithNow(st.ExtractedDeadlineText, st.RequestNow); deadlineErr == nil { + st.ExtractedDeadline = deadline + } + } + + // Step B:基于用户原句执行“本地时间解析 + 合法性校验”。 + userDeadline, userHasTimeHint, userDeadlineErr := parseOptionalDeadlineFromUserInput(st.UserInput, st.RequestNow) + if userHasTimeHint && userDeadlineErr != nil { + st.DeadlineValidationError = userDeadlineErr.Error() + st.AssistantReply = "我识别到你给了时间信息,但这个时间格式我没法准确解析,请改成例如:2026-03-20 18:30、明天下午3点、下周一上午9点。" + emitStage("quick_note.failed", "时间校验失败,未执行写入。") + return st, nil + } + + if st.ExtractedDeadline == nil && userDeadline != nil { + st.ExtractedDeadline = userDeadline + if st.ExtractedDeadlineText == "" { + st.ExtractedDeadlineText = strings.TrimSpace(st.UserInput) + } + } + return st, nil + })); err != nil { + return nil, err + } + + // 节点2:优先级评估。 + if err = graph.AddLambdaNode(quickNoteGraphNodeRank, compose.InvokableLambda( + func(ctx context.Context, st *QuickNoteState) (*QuickNoteState, error) { + if st == nil { + return nil, errors.New("quick note graph: nil state in priority node") + } + if !st.IsQuickNoteIntent || strings.TrimSpace(st.DeadlineValidationError) != "" { + return st, nil + } + + emitStage("quick_note.priority.evaluating", "正在评估任务优先级。") + + deadlineText := "无" + if st.ExtractedDeadline != nil { + deadlineText = formatQuickNoteTimeToMinute(*st.ExtractedDeadline) + } + deadlineClue := strings.TrimSpace(st.ExtractedDeadlineText) + if deadlineClue == "" { + deadlineClue = "无" + } + + prompt := fmt.Sprintf(`当前时间(北京时间,精确到分钟):%s +请对以下任务评估优先级: +- 任务标题:%s +- 用户原始输入:%s +- 时间线索原文:%s +- 归一化截止时间:%s + +请仅输出 JSON(不要 markdown,不要解释): +{ + "priority_group": 1|2|3|4, + "reason": "简短理由" +}`, + st.RequestNowText, + st.ExtractedTitle, + st.UserInput, + deadlineClue, + deadlineText, + ) + + raw, callErr := callModelForJSON(ctx, input.Model, QuickNotePriorityPrompt, prompt) + if callErr != nil { + fallback := fallbackPriority(st) + st.ExtractedPriority = fallback + st.ExtractedPriorityReason = "优先级评估失败,使用兜底策略" + return st, nil + } + + parsed, parseErr := parseJSONPayload[quickNotePriorityModelOutput](raw) + if parseErr != nil || !IsValidTaskPriority(parsed.PriorityGroup) { + fallback := fallbackPriority(st) + st.ExtractedPriority = fallback + st.ExtractedPriorityReason = "优先级结果异常,使用兜底策略" + return st, nil + } + + st.ExtractedPriority = parsed.PriorityGroup + st.ExtractedPriorityReason = strings.TrimSpace(parsed.Reason) + return st, nil + })); err != nil { + return nil, err + } + + // 节点3:调用“写库工具”执行持久化。 + if err = graph.AddLambdaNode(quickNoteGraphNodePersist, compose.InvokableLambda( + func(ctx context.Context, st *QuickNoteState) (*QuickNoteState, error) { + if st == nil { + return nil, errors.New("quick note graph: nil state in persist node") + } + if !st.IsQuickNoteIntent || strings.TrimSpace(st.DeadlineValidationError) != "" { + return st, nil + } + + emitStage("quick_note.persisting", "正在写入任务数据。") + + priority := st.ExtractedPriority + if !IsValidTaskPriority(priority) { + priority = fallbackPriority(st) + st.ExtractedPriority = priority + } + + deadlineText := "" + if st.ExtractedDeadline != nil { + deadlineText = st.ExtractedDeadline.In(quickNoteLocation()).Format(time.RFC3339) + } + + toolInput := QuickNoteCreateTaskToolInput{ + Title: st.ExtractedTitle, + PriorityGroup: priority, + DeadlineAt: deadlineText, + } + rawInput, marshalErr := json.Marshal(toolInput) + if marshalErr != nil { + st.RecordToolError("构造工具参数失败: " + marshalErr.Error()) + if !st.CanRetryTool() { + st.AssistantReply = "抱歉,记录任务时参数处理失败,请稍后重试。" + emitStage("quick_note.failed", "参数构造失败,未完成写入。") + } + return st, nil + } + + rawOutput, invokeErr := createTaskTool.InvokableRun(ctx, string(rawInput)) + if invokeErr != nil { + st.RecordToolError(invokeErr.Error()) + if !st.CanRetryTool() { + st.AssistantReply = "抱歉,我尝试了多次仍未能成功记录这条任务,请稍后再试。" + emitStage("quick_note.failed", "多次重试后仍未完成写入。") + } + return st, nil + } + + toolOutput, parseErr := parseJSONPayload[QuickNoteCreateTaskToolOutput](rawOutput) + if parseErr != nil { + st.RecordToolError("解析工具返回失败: " + parseErr.Error()) + if !st.CanRetryTool() { + st.AssistantReply = "抱歉,我拿到了异常结果,没能确认任务是否记录成功,请稍后再试。" + emitStage("quick_note.failed", "结果解析异常,无法确认写入结果。") + } + return st, nil + } + + st.RecordToolSuccess(toolOutput.TaskID) + if strings.TrimSpace(toolOutput.Title) != "" { + st.ExtractedTitle = strings.TrimSpace(toolOutput.Title) + } + if IsValidTaskPriority(toolOutput.PriorityGroup) { + st.ExtractedPriority = toolOutput.PriorityGroup + } + reply := strings.TrimSpace(toolOutput.Message) + if reply == "" { + reply = fmt.Sprintf("已为你记录:%s(%s)", st.ExtractedTitle, PriorityLabelCN(st.ExtractedPriority)) + } + st.AssistantReply = reply + emitStage("quick_note.persisted", "任务写入成功,正在组织回复内容。") + return st, nil + })); err != nil { + return nil, err + } + + if err = graph.AddLambdaNode(quickNoteGraphNodeExit, compose.InvokableLambda( + func(ctx context.Context, st *QuickNoteState) (*QuickNoteState, error) { + return st, nil + })); err != nil { + return nil, err + } + + if err = graph.AddEdge(compose.START, quickNoteGraphNodeIntent); err != nil { + return nil, err + } + if err = graph.AddBranch(quickNoteGraphNodeIntent, compose.NewGraphBranch( + func(ctx context.Context, st *QuickNoteState) (string, error) { + if st == nil || !st.IsQuickNoteIntent { + return quickNoteGraphNodeExit, nil + } + if strings.TrimSpace(st.DeadlineValidationError) != "" { + return quickNoteGraphNodeExit, nil + } + return quickNoteGraphNodeRank, nil + }, + map[string]bool{quickNoteGraphNodeRank: true, quickNoteGraphNodeExit: true}, + )); err != nil { + return nil, err + } + if err = graph.AddEdge(quickNoteGraphNodeExit, compose.END); err != nil { + return nil, err + } + if err = graph.AddEdge(quickNoteGraphNodeRank, quickNoteGraphNodePersist); err != nil { + return nil, err + } + if err = graph.AddBranch(quickNoteGraphNodePersist, compose.NewGraphBranch( + func(ctx context.Context, st *QuickNoteState) (string, error) { + if st == nil { + return compose.END, nil + } + if st.Persisted { + return compose.END, nil + } + if st.CanRetryTool() { + return quickNoteGraphNodePersist, nil + } + if strings.TrimSpace(st.AssistantReply) == "" { + st.AssistantReply = "抱歉,我尝试了多次仍未能成功记录这条任务,请稍后再试。" + } + return compose.END, nil + }, + map[string]bool{quickNoteGraphNodePersist: true, compose.END: true}, + )); err != nil { + return nil, err + } + + maxSteps := input.State.MaxToolRetry + 10 + if maxSteps < 12 { + maxSteps = 12 + } + + runnable, err := graph.Compile(ctx, + compose.WithGraphName("QuickNoteGraph"), + compose.WithMaxRunSteps(maxSteps), + compose.WithNodeTriggerMode(compose.AnyPredecessor), + ) + if err != nil { + return nil, err + } + + return runnable.Invoke(ctx, input.State) +} + +func getInvokableToolByName(bundle *QuickNoteToolBundle, name string) (tool.InvokableTool, error) { + if bundle == nil { + return nil, errors.New("tool bundle is nil") + } + if len(bundle.Tools) == 0 || len(bundle.ToolInfos) == 0 { + return nil, errors.New("tool bundle is empty") + } + for idx, info := range bundle.ToolInfos { + if info == nil || info.Name != name { + continue + } + invokable, ok := bundle.Tools[idx].(tool.InvokableTool) + if !ok { + return nil, fmt.Errorf("tool %s is not invokable", name) + } + return invokable, nil + } + return nil, fmt.Errorf("tool %s not found", name) +} + +func callModelForJSON(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string) (string, error) { + messages := []*schema.Message{ + schema.SystemMessage(systemPrompt), + schema.UserMessage(userPrompt), + } + resp, err := chatModel.Generate(ctx, messages) + if err != nil { + return "", err + } + if resp == nil { + return "", errors.New("模型返回为空") + } + content := strings.TrimSpace(resp.Content) + if content == "" { + return "", errors.New("模型返回内容为空") + } + return content, nil +} + +func parseJSONPayload[T any](raw string) (*T, error) { + clean := strings.TrimSpace(raw) + if clean == "" { + return nil, errors.New("empty response") + } + + if strings.HasPrefix(clean, "```") { + clean = strings.TrimPrefix(clean, "```json") + clean = strings.TrimPrefix(clean, "```") + clean = strings.TrimSuffix(clean, "```") + clean = strings.TrimSpace(clean) + } + + var out T + if err := json.Unmarshal([]byte(clean), &out); err == nil { + return &out, nil + } + + obj := extractJSONObject(clean) + if obj == "" { + return nil, fmt.Errorf("no json object found in: %s", clean) + } + if err := json.Unmarshal([]byte(obj), &out); err != nil { + return nil, err + } + return &out, nil +} + +func extractJSONObject(text string) string { + start := strings.Index(text, "{") + end := strings.LastIndex(text, "}") + if start == -1 || end == -1 || end <= start { + return "" + } + return text[start : end+1] +} + +func fallbackPriority(st *QuickNoteState) int { + if st == nil { + return QuickNotePrioritySimpleNotImportant + } + if st.ExtractedDeadline != nil { + if time.Until(*st.ExtractedDeadline) <= 48*time.Hour { + return QuickNotePriorityImportantUrgent + } + return QuickNotePriorityImportantNotUrgent + } + return QuickNotePrioritySimpleNotImportant +} diff --git a/backend/agent/quick_note_prompt.go b/backend/agent/quick_note_prompt.go new file mode 100644 index 0000000..f09a958 --- /dev/null +++ b/backend/agent/quick_note_prompt.go @@ -0,0 +1,37 @@ +package agent + +const ( + // QuickNoteIntentPrompt 用于第一阶段:判断用户输入是否属于“随口记”。 + // 设计约束: + // 1) 只做识别与抽取,不允许模型宣称“已写库”; + // 2) 遇到相对时间必须先换算成绝对时间,减少后续工具层歧义; + // 3) 若无时间信息必须返回空字符串,避免幻觉时间污染数据库。 + QuickNoteIntentPrompt = `你是 SmartFlow 的“随口记分诊器”。 +请判断用户输入是否表达了“帮我记一个任务/日程”的需求。 +- 若是,请提取任务标题与时间线索。 +- 时间处理必须严谨:若出现相对时间(如明天/后天/下周一/今晚),必须基于上文给出的“当前时间”换算为绝对时间。 +- 若不是,请明确返回“非随口记意图”。 +- 不要声称已经写入数据库。` + + // QuickNotePriorityPrompt 用于第二阶段:将任务归类到四象限优先级。 + // 输出会直接映射到 tasks.priority(1~4),因此要求结果必须可解释。 + QuickNotePriorityPrompt = `你是 SmartFlow 的任务优先级评估器。 +根据任务内容、时间约束和执行成本,输出优先级 priority_group: +1=重要且紧急,2=重要不紧急,3=简单不重要,4=不简单不重要。 +请给出简短理由,理由必须可解释。` + + // QuickNoteReplyBanterPrompt 用于随口记成功后的“轻松跟进句”生成。 + // 约束重点: + // 1) 只输出一句自然中文; + // 2) 贴合用户原话题(例如吃早餐、开会、写报告); + // 3) 禁止新增事实(尤其不能改时间、优先级、任务内容); + // 4) 不要 markdown,不要列表,不要引号包裹。 + QuickNoteReplyBanterPrompt = `你是 SmartFlow 的中文口语化回复润色助手。 +请根据用户原话生成一句轻松自然的跟进话术,让回复更有温度。 +要求: +- 只输出一句中文,不超过30字。 +- 顺着用户创建提醒的主题延伸,就像聊天时友好的问候一样,记得动用你知道的对应领域的知识。例如(注意,只是例子):用户说提醒他明天早上吃麦当劳,你润色回复应该类似这样:"薯饼记得趁热吃哦~"。 +- 可以轻微调侃,但语气友好,不刻薄。 +- 不得新增或修改任务事实(任务名、时间、优先级)。 +- 不要输出 markdown、编号、引号。` +) diff --git a/backend/agent/state.go b/backend/agent/state.go index 4883155..d67dc88 100644 --- a/backend/agent/state.go +++ b/backend/agent/state.go @@ -1 +1,149 @@ package agent + +import "time" + +const ( + // QuickNoteDatetimeMinuteLayout 是“随口记”链路内部统一的分钟级时间格式。 + // 说明: + // 1) 用于把“当前时间基准”传给模型,避免模型在相对时间推断时出现秒级抖动。 + // 2) 用于日志和调试,读起来比 RFC3339 更直观。 + QuickNoteDatetimeMinuteLayout = "2006-01-02 15:04" + + // quickNoteTimezoneName 是随口记链路默认业务时区。 + // 这里固定为东八区,避免容器运行在 UTC 时把“明天/今晚”解释偏移到错误日期。 + quickNoteTimezoneName = "Asia/Shanghai" + + // QuickNotePriorityImportantUrgent 对应四象限里的“重要且紧急”。 + // 在当前 tasks 表中映射为 priority=1(数值越小优先级越高)。 + QuickNotePriorityImportantUrgent = 1 + // QuickNotePriorityImportantNotUrgent 对应“重要不紧急”。 + QuickNotePriorityImportantNotUrgent = 2 + // QuickNotePrioritySimpleNotImportant 对应“简单不重要”。 + QuickNotePrioritySimpleNotImportant = 3 + // QuickNotePriorityComplexNotImportant 对应“不简单不重要”。 + QuickNotePriorityComplexNotImportant = 4 +) + +// IsValidTaskPriority 判断优先级是否合法。 +// 目前后端任务模型限定为 1~4。 +func IsValidTaskPriority(priority int) bool { + return priority >= QuickNotePriorityImportantUrgent && priority <= QuickNotePriorityComplexNotImportant +} + +// PriorityLabelCN 把优先级数值转换为中文标签,便于拼接给用户的自然语言回复。 +func PriorityLabelCN(priority int) string { + switch priority { + case QuickNotePriorityImportantUrgent: + return "重要且紧急" + case QuickNotePriorityImportantNotUrgent: + return "重要不紧急" + case QuickNotePrioritySimpleNotImportant: + return "简单不重要" + case QuickNotePriorityComplexNotImportant: + return "不简单不重要" + default: + return "未知优先级" + } +} + +// QuickNoteState 是“AI随口记”链路在 graph 节点间传递的统一状态容器。 +// 设计目标: +// 1) 把本次请求的上下文收拢到一个结构里,降低节点函数参数散落; +// 2) 让“识别、评估、写库、重试、回复”每一步都可追踪; +// 3) 便于后续扩展打点和可观测字段(例如时间解析失败原因)。 +type QuickNoteState struct { + // 基础上下文:用于日志关联与用户隔离。 + TraceID string + UserID int + ConversationID string + + // RequestNow 记录“请求进入随口记链路时”的时间基准(分钟级)。 + // 所有相对时间(明天/后天/下周一)都必须基于这个时间计算, + // 这样同一次请求内不会因为时间流逝产生口径漂移。 + RequestNow time.Time + // RequestNowText 是 RequestNow 的字符串形式,主要用于 prompt 注入。 + RequestNowText string + + // 用户原始输入(例如:提醒我下周日之前完成大作业)。 + UserInput string + + // 意图判定结果。 + IsQuickNoteIntent bool + IntentJudgeReason string + + // 结构化抽取结果:由“意图识别/信息抽取”节点写入。 + ExtractedTitle string + ExtractedDeadline *time.Time + ExtractedDeadlineText string + ExtractedPriority int + + // ExtractedPriorityReason 记录优先级评估理由,便于后续排查模型判断是否符合预期。 + ExtractedPriorityReason string + // DeadlineValidationError 记录时间校验失败原因。 + // 只要该字段非空,就说明用户提供了无法解析的时间表达,本次请求不应落库。 + DeadlineValidationError string + + // 工具调用过程状态:用于重试与故障回溯。 + ToolAttemptCount int + MaxToolRetry int + LastToolError string + + // 最终持久化结果:由“写库工具”节点回填。 + PersistedTaskID int + Persisted bool + + // AssistantReply 是 graph 最终给用户的回复文案。 + AssistantReply string +} + +// NewQuickNoteState 创建随口记状态对象并初始化默认重试次数。 +func NewQuickNoteState(traceID string, userID int, conversationID, userInput string) *QuickNoteState { + requestNow := quickNoteNowToMinute() + return &QuickNoteState{ + TraceID: traceID, + UserID: userID, + ConversationID: conversationID, + RequestNow: requestNow, + RequestNowText: formatQuickNoteTimeToMinute(requestNow), + UserInput: userInput, + MaxToolRetry: 3, + } +} + +// CanRetryTool 判断当前是否还能继续重试工具调用。 +func (s *QuickNoteState) CanRetryTool() bool { + return s.ToolAttemptCount < s.MaxToolRetry +} + +// RecordToolError 记录一次工具调用失败。 +func (s *QuickNoteState) RecordToolError(errMsg string) { + s.ToolAttemptCount++ + s.LastToolError = errMsg +} + +// RecordToolSuccess 记录一次工具调用成功。 +func (s *QuickNoteState) RecordToolSuccess(taskID int) { + s.ToolAttemptCount++ + s.PersistedTaskID = taskID + s.Persisted = true + s.LastToolError = "" +} + +// quickNoteLocation 返回随口记链路使用的业务时区。 +func quickNoteLocation() *time.Location { + loc, err := time.LoadLocation(quickNoteTimezoneName) + if err != nil { + return time.Local + } + return loc +} + +// quickNoteNowToMinute 返回当前时间并截断到分钟级。 +func quickNoteNowToMinute() time.Time { + return time.Now().In(quickNoteLocation()).Truncate(time.Minute) +} + +// formatQuickNoteTimeToMinute 将时间格式化为分钟级字符串。 +func formatQuickNoteTimeToMinute(t time.Time) string { + return t.In(quickNoteLocation()).Format(QuickNoteDatetimeMinuteLayout) +} diff --git a/backend/agent/tool.go b/backend/agent/tool.go index 4883155..be7d75e 100644 --- a/backend/agent/tool.go +++ b/backend/agent/tool.go @@ -1 +1,618 @@ package agent + +import ( + "context" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + "time" + + "github.com/cloudwego/eino/components/tool" + toolutils "github.com/cloudwego/eino/components/tool/utils" + "github.com/cloudwego/eino/schema" +) + +const ( + // ToolNameQuickNoteCreateTask 是“AI随口记”写库工具的标准名称。 + // 该名称会直接暴露给大模型,因此建议保持稳定,避免后续提示词和历史上下文失配。 + ToolNameQuickNoteCreateTask = "quick_note_create_task" + // ToolDescQuickNoteCreateTask 是工具的简要职责说明。 + ToolDescQuickNoteCreateTask = "把用户随口提到的事项落库为任务,支持可选截止时间与优先级" +) + +var ( + // quickNoteDeadlineLayouts 是“绝对时间”白名单格式。 + // 只要命中任意一个 layout,就会被归一化为分钟级时间并进入写库流程。 + quickNoteDeadlineLayouts = []string{ + time.RFC3339, + "2006-01-02T15:04", + "2006-01-02 15:04:05", + "2006-01-02 15:04", + "2006/01/02 15:04:05", + "2006/01/02 15:04", + "2006.01.02 15:04:05", + "2006.01.02 15:04", + "2006-01-02", + "2006/01/02", + "2006.01.02", + } + quickNoteDateOnlyLayouts = map[string]struct{}{ + "2006-01-02": {}, + "2006/01/02": {}, + "2006.01.02": {}, + } + + // 正则区: + // 1) 用于解析明确时间表达; + // 2) 用于“是否存在时间线索”的判定(即使格式错误,也会触发校验失败而非静默忽略)。 + quickNoteClockHMRegex = regexp.MustCompile(`(\d{1,2})\s*[::]\s*(\d{1,2})`) + quickNoteClockCNRegex = regexp.MustCompile(`(\d{1,2})\s*点\s*(半|(\d{1,2})\s*分?)?`) + quickNoteYMDRegex = regexp.MustCompile(`(\d{4})\s*年\s*(\d{1,2})\s*月\s*(\d{1,2})\s*[日号]?`) + quickNoteMDRegex = regexp.MustCompile(`(\d{1,2})\s*月\s*(\d{1,2})\s*[日号]?`) + quickNoteDateSepRegex = regexp.MustCompile(`\d{1,4}\s*[-/.]\s*\d{1,2}(\s*[-/.]\s*\d{1,2})?`) + quickNoteWeekdayRegex = regexp.MustCompile(`(下周|下星期|下礼拜|本周|这周|本星期|这星期|周|星期|礼拜)([一二三四五六日天])`) + quickNoteRelativeTokens = []string{ + "今天", "今日", "今晚", "今早", "今晨", "明天", "明日", "后天", "大后天", "昨天", "昨日", + "早上", "早晨", "上午", "中午", "下午", "晚上", "傍晚", "夜里", "凌晨", + } +) + +// QuickNoteToolDeps 描述“随口记工具包”需要的外部依赖。 +// 这里采用函数注入的方式,避免 agent 包和 service/dao 强耦合,后续更容易演进为 mock 测试或多实现切换。 +type QuickNoteToolDeps struct { + // ResolveUserID 从上下文中解析当前登录用户 ID。 + ResolveUserID func(ctx context.Context) (int, error) + // CreateTask 执行真实写库动作。 + CreateTask func(ctx context.Context, req QuickNoteCreateTaskRequest) (*QuickNoteCreateTaskResult, error) +} + +func (d QuickNoteToolDeps) validate() error { + if d.ResolveUserID == nil { + return errors.New("quick note tool deps: ResolveUserID is nil") + } + if d.CreateTask == nil { + return errors.New("quick note tool deps: CreateTask is nil") + } + return nil +} + +// QuickNoteToolBundle 是随口记工具集合的打包结果。 +// - Tools: 给 ToolsNode 使用 +// - ToolInfos: 给 ChatModel 绑定工具 schema 使用 +// 两者分开返回,可以适配你后面用 chain、graph、react 的不同挂载姿势。 +type QuickNoteToolBundle struct { + Tools []tool.BaseTool + ToolInfos []*schema.ToolInfo +} + +// QuickNoteCreateTaskRequest 是工具层到业务层的内部请求结构。 +// 与模型输入解耦,避免模型字段变化直接影响业务签名。 +type QuickNoteCreateTaskRequest struct { + UserID int + Title string + PriorityGroup int + DeadlineAt *time.Time +} + +// QuickNoteCreateTaskResult 是业务层返回给工具层的结构化结果。 +type QuickNoteCreateTaskResult struct { + TaskID int + Title string + PriorityGroup int + DeadlineAt *time.Time +} + +// QuickNoteCreateTaskToolInput 是提供给大模型的工具参数定义。 +// 注意:user_id 不对模型暴露,统一从鉴权上下文提取,避免越权写入。 +type QuickNoteCreateTaskToolInput struct { + Title string `json:"title" jsonschema:"required,description=任务标题,简洁明确"` + // PriorityGroup 使用 1~4,和后端 tasks.priority 保持一致。 + PriorityGroup int `json:"priority_group" jsonschema:"required,enum=1,enum=2,enum=3,enum=4,description=优先级分组(1重要且紧急,2重要不紧急,3简单不重要,4不简单不重要)"` + // DeadlineAt 支持绝对时间与常见相对时间(如明天/后天/下周一/今晚),内部会归一化为绝对时间。 + DeadlineAt string `json:"deadline_at,omitempty" jsonschema:"description=可选截止时间,支持RFC3339、yyyy-MM-dd HH:mm:ss、yyyy-MM-dd HH:mm 以及常见中文相对时间"` +} + +// QuickNoteCreateTaskToolOutput 是返回给大模型的工具结果。 +// 该结构可直接给模型用于“向用户解释已记录到哪个优先级”。 +type QuickNoteCreateTaskToolOutput struct { + TaskID int `json:"task_id"` + Title string `json:"title"` + PriorityGroup int `json:"priority_group"` + PriorityLabel string `json:"priority_label"` + DeadlineAt string `json:"deadline_at,omitempty"` + Message string `json:"message"` +} + +// BuildQuickNoteToolBundle 构建“AI随口记”工具包。 +// 这是 agent 目录给上层编排层(chain/graph/react)提供的统一入口。 +func BuildQuickNoteToolBundle(ctx context.Context, deps QuickNoteToolDeps) (*QuickNoteToolBundle, error) { + if err := deps.validate(); err != nil { + return nil, err + } + + createTaskTool, err := toolutils.InferTool( + ToolNameQuickNoteCreateTask, + ToolDescQuickNoteCreateTask, + func(ctx context.Context, input *QuickNoteCreateTaskToolInput) (*QuickNoteCreateTaskToolOutput, error) { + if input == nil { + return nil, errors.New("工具参数不能为空") + } + + title := strings.TrimSpace(input.Title) + if title == "" { + return nil, errors.New("title 不能为空") + } + if !IsValidTaskPriority(input.PriorityGroup) { + return nil, fmt.Errorf("priority_group=%d 非法,必须在 1~4", input.PriorityGroup) + } + + // 这里对 deadline_at 做“强校验”: + // - 空值允许(代表没有截止时间); + // - 非空但无法解析直接报错,避免把有问题的时间静默写成 NULL。 + deadline, err := parseOptionalDeadline(input.DeadlineAt) + if err != nil { + return nil, err + } + + userID, err := deps.ResolveUserID(ctx) + if err != nil { + return nil, fmt.Errorf("解析用户身份失败: %w", err) + } + if userID <= 0 { + return nil, fmt.Errorf("非法 user_id=%d", userID) + } + + result, err := deps.CreateTask(ctx, QuickNoteCreateTaskRequest{ + UserID: userID, + Title: title, + PriorityGroup: input.PriorityGroup, + DeadlineAt: deadline, + }) + if err != nil { + return nil, err + } + if result == nil || result.TaskID <= 0 { + return nil, errors.New("写入任务后返回结果异常") + } + + finalTitle := title + if strings.TrimSpace(result.Title) != "" { + finalTitle = strings.TrimSpace(result.Title) + } + + finalPriority := input.PriorityGroup + if IsValidTaskPriority(result.PriorityGroup) { + finalPriority = result.PriorityGroup + } + + deadlineStr := "" + if result.DeadlineAt != nil { + deadlineStr = result.DeadlineAt.In(quickNoteLocation()).Format(time.RFC3339) + } else if deadline != nil { + deadlineStr = deadline.In(quickNoteLocation()).Format(time.RFC3339) + } + + return &QuickNoteCreateTaskToolOutput{ + TaskID: result.TaskID, + Title: finalTitle, + PriorityGroup: finalPriority, + PriorityLabel: PriorityLabelCN(finalPriority), + DeadlineAt: deadlineStr, + Message: fmt.Sprintf("已记录:%s(%s)", finalTitle, PriorityLabelCN(finalPriority)), + }, nil + }, + ) + if err != nil { + return nil, fmt.Errorf("构建随口记工具失败: %w", err) + } + + tools := []tool.BaseTool{createTaskTool} + infos, err := collectToolInfos(ctx, tools) + if err != nil { + return nil, err + } + + return &QuickNoteToolBundle{ + Tools: tools, + ToolInfos: infos, + }, nil +} + +func collectToolInfos(ctx context.Context, tools []tool.BaseTool) ([]*schema.ToolInfo, error) { + infos := make([]*schema.ToolInfo, 0, len(tools)) + for _, t := range tools { + info, err := t.Info(ctx) + if err != nil { + return nil, fmt.Errorf("读取工具信息失败: %w", err) + } + infos = append(infos, info) + } + return infos, nil +} + +// parseOptionalDeadline 解析工具输入中的可选截止时间。 +// 该入口用于“工具参数强校验”:只要调用方给了非空 deadline_at,就必须能被解析。 +func parseOptionalDeadline(raw string) (*time.Time, error) { + value := normalizeDeadlineInput(raw) + if value == "" { + return nil, nil + } + + deadline, hasHint, err := parseOptionalDeadlineFromText(value, quickNoteNowToMinute()) + if err != nil { + return nil, err + } + if deadline == nil { + if !hasHint { + return nil, fmt.Errorf("deadline_at 格式不支持: %s", value) + } + return nil, fmt.Errorf("deadline_at 无法解析: %s", value) + } + return deadline, nil +} + +// parseOptionalDeadlineWithNow 在给定时间基准下解析 deadline。 +// 该函数保持“严格模式”:非空字符串无法解析时会直接返回 error。 +func parseOptionalDeadlineWithNow(raw string, now time.Time) (*time.Time, error) { + value := normalizeDeadlineInput(raw) + if value == "" { + return nil, nil + } + + deadline, _, err := parseOptionalDeadlineFromText(value, now) + if err != nil { + return nil, err + } + if deadline == nil { + return nil, fmt.Errorf("deadline_at 格式不支持: %s", value) + } + return deadline, nil +} + +// parseOptionalDeadlineFromUserInput 是“用户原句解析”的宽松入口。 +// 返回值说明: +// - deadline != nil:成功解析出时间; +// - hasHint=false 且 err=nil:文本里没有明显时间线索,应视为“用户没给时间”; +// - hasHint=true 且 err!=nil:用户给了时间但格式非法,应提示用户修正,不应落库。 +func parseOptionalDeadlineFromUserInput(raw string, now time.Time) (*time.Time, bool, error) { + value := normalizeDeadlineInput(raw) + if value == "" { + return nil, false, nil + } + + deadline, hasHint, err := parseOptionalDeadlineFromText(value, now) + if err != nil { + if hasHint { + return nil, true, err + } + return nil, false, nil + } + if deadline == nil { + if hasHint { + return nil, true, fmt.Errorf("deadline_at 无法解析: %s", value) + } + return nil, false, nil + } + return deadline, true, nil +} + +// parseOptionalDeadlineFromText 是内部通用解析器。 +// 解析顺序: +// 1) 绝对时间(明确年月日时分); +// 2) 相对时间(明天/下周一/今晚); +// 3) 若识别到时间线索但仍失败,返回 hasHint=true + error,交给上层决定是否拦截。 +func parseOptionalDeadlineFromText(value string, now time.Time) (*time.Time, bool, error) { + if strings.TrimSpace(value) == "" { + return nil, false, nil + } + + loc := quickNoteLocation() + now = now.In(loc) + hasHint := hasDeadlineHint(value) + + if abs, ok := tryParseAbsoluteDeadline(value, loc); ok { + return abs, true, nil + } + + if rel, recognized, err := tryParseRelativeDeadline(value, now, loc); recognized { + if err != nil { + return nil, true, err + } + return rel, true, nil + } + + if hasHint { + return nil, true, fmt.Errorf("deadline_at 格式不支持: %s", value) + } + return nil, false, nil +} + +// normalizeDeadlineInput 把中文标点和空白先归一化,降低格式解析的噪声。 +func normalizeDeadlineInput(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + replacer := strings.NewReplacer( + ":", ":", + ",", ",", + "。", ".", + " ", " ", + ) + return strings.TrimSpace(replacer.Replace(trimmed)) +} + +// hasDeadlineHint 判断文本里是否存在“时间相关线索”。 +// 该函数的意义是区分两种情况: +// 1) 用户根本没给时间(允许 deadline 为空); +// 2) 用户给了时间但写错(必须提示修正,不能静默写 NULL)。 +func hasDeadlineHint(value string) bool { + if quickNoteClockHMRegex.MatchString(value) || + quickNoteClockCNRegex.MatchString(value) || + quickNoteYMDRegex.MatchString(value) || + quickNoteMDRegex.MatchString(value) || + quickNoteDateSepRegex.MatchString(value) || + quickNoteWeekdayRegex.MatchString(value) { + return true + } + for _, token := range quickNoteRelativeTokens { + if strings.Contains(value, token) { + return true + } + } + return false +} + +// tryParseAbsoluteDeadline 尝试按绝对时间格式解析。 +// 若只提供日期(无时分),默认归一到当天 23:59,表示“当日截止”。 +func tryParseAbsoluteDeadline(value string, loc *time.Location) (*time.Time, bool) { + for _, layout := range quickNoteDeadlineLayouts { + var ( + t time.Time + err error + ) + if layout == time.RFC3339 { + t, err = time.Parse(layout, value) + if err == nil { + t = t.In(loc) + } + } else { + t, err = time.ParseInLocation(layout, value, loc) + } + if err != nil { + continue + } + + if _, dateOnly := quickNoteDateOnlyLayouts[layout]; dateOnly { + t = time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 0, 0, loc) + } else { + t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), 0, 0, loc) + } + return &t, true + } + return nil, false +} + +// tryParseRelativeDeadline 尝试解析“相对时间 + 可选时刻”。 +// 例子: +// - 明天交报告(默认 23:59) +// - 下周一上午9点开会(解析为下周一 09:00) +func tryParseRelativeDeadline(value string, now time.Time, loc *time.Location) (*time.Time, bool, error) { + baseDate, recognized := inferBaseDate(value, now, loc) + if !recognized { + return nil, false, nil + } + + hour, minute, hasExplicitClock, err := extractClock(value) + if err != nil { + return nil, true, err + } + if !hasExplicitClock { + hour, minute = defaultClockByHint(value) + } + + deadline := time.Date(baseDate.Year(), baseDate.Month(), baseDate.Day(), hour, minute, 0, 0, loc) + return &deadline, true, nil +} + +// inferBaseDate 负责先确定“哪一天”。 +// 解析优先级: +// 1) 明确年月日; +// 2) 月日(自动推断年份); +// 3) 周几表达(本周/下周); +// 4) 明天/后天/今晚等相对词。 +func inferBaseDate(value string, now time.Time, loc *time.Location) (time.Time, bool) { + if matched := quickNoteYMDRegex.FindStringSubmatch(value); len(matched) == 4 { + year, _ := strconv.Atoi(matched[1]) + month, _ := strconv.Atoi(matched[2]) + day, _ := strconv.Atoi(matched[3]) + if isValidDate(year, month, day) { + return time.Date(year, time.Month(month), day, 0, 0, 0, 0, loc), true + } + } + + if matched := quickNoteMDRegex.FindStringSubmatch(value); len(matched) == 3 { + month, _ := strconv.Atoi(matched[1]) + day, _ := strconv.Atoi(matched[2]) + year := now.Year() + if !isValidDate(year, month, day) { + return time.Time{}, false + } + candidate := time.Date(year, time.Month(month), day, 0, 0, 0, 0, loc) + if candidate.Before(startOfDay(now)) { + year++ + if !isValidDate(year, month, day) { + return time.Time{}, false + } + candidate = time.Date(year, time.Month(month), day, 0, 0, 0, 0, loc) + } + return candidate, true + } + + if matched := quickNoteWeekdayRegex.FindStringSubmatch(value); len(matched) == 3 { + prefix := matched[1] + target, ok := toWeekday(matched[2]) + if ok { + return resolveWeekdayDate(now, prefix, target), true + } + } + + today := startOfDay(now) + switch { + case strings.Contains(value, "大后天"): + return today.AddDate(0, 0, 3), true + case strings.Contains(value, "后天"): + return today.AddDate(0, 0, 2), true + case strings.Contains(value, "明天") || strings.Contains(value, "明日"): + return today.AddDate(0, 0, 1), true + case strings.Contains(value, "今天") || strings.Contains(value, "今日") || strings.Contains(value, "今晚") || strings.Contains(value, "今早") || strings.Contains(value, "今晨"): + return today, true + case strings.Contains(value, "昨天") || strings.Contains(value, "昨日"): + return today.AddDate(0, 0, -1), true + default: + return time.Time{}, false + } +} + +// extractClock 从文本提取时刻(时/分)。 +// 支持: +// - 24h 表达:18:30 +// - 中文表达:3点、3点半、3点20分 +func extractClock(value string) (int, int, bool, error) { + hour := 0 + minute := 0 + hasClock := false + + if matched := quickNoteClockHMRegex.FindStringSubmatch(value); len(matched) == 3 { + h, errH := strconv.Atoi(matched[1]) + m, errM := strconv.Atoi(matched[2]) + if errH != nil || errM != nil { + return 0, 0, true, fmt.Errorf("deadline_at 时间解析失败: %s", value) + } + hour = h + minute = m + hasClock = true + } else if matched := quickNoteClockCNRegex.FindStringSubmatch(value); len(matched) >= 2 { + h, errH := strconv.Atoi(matched[1]) + if errH != nil { + return 0, 0, true, fmt.Errorf("deadline_at 时间解析失败: %s", value) + } + hour = h + minute = 0 + hasClock = true + if len(matched) >= 3 { + if matched[2] == "半" { + minute = 30 + } else if len(matched) >= 4 && strings.TrimSpace(matched[3]) != "" { + m, errM := strconv.Atoi(strings.TrimSpace(matched[3])) + if errM != nil { + return 0, 0, true, fmt.Errorf("deadline_at 时间解析失败: %s", value) + } + minute = m + } + } + } + + if !hasClock { + return 0, 0, false, nil + } + + if isPMHint(value) && hour < 12 { + hour += 12 + } + if isNoonHint(value) && hour >= 1 && hour <= 10 { + hour += 12 + } + if strings.Contains(value, "凌晨") && hour == 12 { + hour = 0 + } + + if hour < 0 || hour > 23 || minute < 0 || minute > 59 { + return 0, 0, true, fmt.Errorf("deadline_at 时间超出范围: %s", value) + } + return hour, minute, true, nil +} + +// defaultClockByHint 当文本只给了“日期/相对日”但没给具体时刻时,按语义兜底。 +func defaultClockByHint(value string) (int, int) { + switch { + case strings.Contains(value, "凌晨"): + return 1, 0 + case strings.Contains(value, "早上") || strings.Contains(value, "早晨") || strings.Contains(value, "上午") || strings.Contains(value, "今早") || strings.Contains(value, "明早"): + return 9, 0 + case strings.Contains(value, "中午"): + return 12, 0 + case strings.Contains(value, "下午"): + return 15, 0 + case strings.Contains(value, "晚上") || strings.Contains(value, "今晚") || strings.Contains(value, "傍晚") || strings.Contains(value, "夜里"): + return 20, 0 + default: + // 只给了日期没有具体时刻时,默认当天结束前。 + return 23, 59 + } +} + +func isPMHint(value string) bool { + return strings.Contains(value, "下午") || strings.Contains(value, "晚上") || strings.Contains(value, "今晚") || strings.Contains(value, "傍晚") +} + +func isNoonHint(value string) bool { + return strings.Contains(value, "中午") +} + +func startOfDay(t time.Time) time.Time { + loc := t.Location() + return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc) +} + +func isValidDate(year, month, day int) bool { + if month < 1 || month > 12 || day < 1 || day > 31 { + return false + } + candidate := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC) + return candidate.Year() == year && int(candidate.Month()) == month && candidate.Day() == day +} + +func toWeekday(chinese string) (time.Weekday, bool) { + switch chinese { + case "一": + return time.Monday, true + case "二": + return time.Tuesday, true + case "三": + return time.Wednesday, true + case "四": + return time.Thursday, true + case "五": + return time.Friday, true + case "六": + return time.Saturday, true + case "日", "天": + return time.Sunday, true + default: + return time.Sunday, false + } +} + +// resolveWeekdayDate 根据“本周/下周 + 周几”换算目标日期。 +func resolveWeekdayDate(now time.Time, prefix string, target time.Weekday) time.Time { + today := startOfDay(now) + weekdayOffset := (int(today.Weekday()) + 6) % 7 + weekStart := today.AddDate(0, 0, -weekdayOffset) + targetOffset := (int(target) + 6) % 7 + candidateThisWeek := weekStart.AddDate(0, 0, targetOffset) + + switch { + case strings.HasPrefix(prefix, "下"): + return candidateThisWeek.AddDate(0, 0, 7) + case strings.HasPrefix(prefix, "本"), strings.HasPrefix(prefix, "这"): + return candidateThisWeek + default: + if candidateThisWeek.Before(today) { + return candidateThisWeek.AddDate(0, 0, 7) + } + return candidateThisWeek + } +} diff --git a/backend/agent/tool_deadline_test.go b/backend/agent/tool_deadline_test.go new file mode 100644 index 0000000..df9f37c --- /dev/null +++ b/backend/agent/tool_deadline_test.go @@ -0,0 +1,123 @@ +package agent + +import ( + "testing" + "time" +) + +func TestParseOptionalDeadlineWithNow_Absolute(t *testing.T) { + loc := quickNoteLocation() + now := time.Date(2026, 3, 12, 10, 15, 0, 0, loc) + + deadline, err := parseOptionalDeadlineWithNow("2026-03-20 18:30", now) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if deadline == nil { + t.Fatalf("deadline should not be nil") + } + + want := time.Date(2026, 3, 20, 18, 30, 0, 0, loc) + if !deadline.Equal(want) { + t.Fatalf("unexpected deadline, got=%s want=%s", deadline.Format(time.RFC3339), want.Format(time.RFC3339)) + } +} + +func TestParseOptionalDeadlineWithNow_RelativeTomorrowWithoutClock(t *testing.T) { + loc := quickNoteLocation() + now := time.Date(2026, 3, 12, 10, 15, 0, 0, loc) + + deadline, err := parseOptionalDeadlineWithNow("明天交计网实验报告", now) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if deadline == nil { + t.Fatalf("deadline should not be nil") + } + + want := time.Date(2026, 3, 13, 23, 59, 0, 0, loc) + if !deadline.Equal(want) { + t.Fatalf("unexpected deadline, got=%s want=%s", deadline.Format(time.RFC3339), want.Format(time.RFC3339)) + } +} + +func TestParseOptionalDeadlineWithNow_RelativeTomorrowWithClock(t *testing.T) { + loc := quickNoteLocation() + now := time.Date(2026, 3, 12, 10, 15, 0, 0, loc) + + deadline, err := parseOptionalDeadlineWithNow("明天下午3点交计网实验报告", now) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if deadline == nil { + t.Fatalf("deadline should not be nil") + } + + want := time.Date(2026, 3, 13, 15, 0, 0, 0, loc) + if !deadline.Equal(want) { + t.Fatalf("unexpected deadline, got=%s want=%s", deadline.Format(time.RFC3339), want.Format(time.RFC3339)) + } +} + +func TestParseOptionalDeadlineWithNow_RelativeWeekday(t *testing.T) { + loc := quickNoteLocation() + now := time.Date(2026, 3, 12, 10, 15, 0, 0, loc) // 周四 + + deadline, err := parseOptionalDeadlineWithNow("下周一上午9点开组会", now) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if deadline == nil { + t.Fatalf("deadline should not be nil") + } + + want := time.Date(2026, 3, 16, 9, 0, 0, 0, loc) + if !deadline.Equal(want) { + t.Fatalf("unexpected deadline, got=%s want=%s", deadline.Format(time.RFC3339), want.Format(time.RFC3339)) + } +} + +func TestParseOptionalDeadlineFromUserInput_NoHint(t *testing.T) { + loc := quickNoteLocation() + now := time.Date(2026, 3, 12, 10, 15, 0, 0, loc) + + deadline, hasHint, err := parseOptionalDeadlineFromUserInput("帮我记一下要复习计网", now) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if hasHint { + t.Fatalf("expected no time hint") + } + if deadline != nil { + t.Fatalf("deadline should be nil when no time hint") + } +} + +func TestParseOptionalDeadlineFromUserInput_InvalidDate(t *testing.T) { + loc := quickNoteLocation() + now := time.Date(2026, 3, 12, 10, 15, 0, 0, loc) + + deadline, hasHint, err := parseOptionalDeadlineFromUserInput("2026-13-45 25:99 交实验", now) + if err == nil { + t.Fatalf("expected error but got nil") + } + if !hasHint { + t.Fatalf("expected hasHint=true") + } + if deadline != nil { + t.Fatalf("deadline should be nil for invalid date") + } +} + +func TestParseOptionalDeadlineWithNow_Invalid(t *testing.T) { + loc := quickNoteLocation() + now := time.Date(2026, 3, 12, 10, 15, 0, 0, loc) + + deadline, err := parseOptionalDeadlineWithNow("记得尽快处理", now) + if err == nil { + t.Fatalf("expected error but got nil") + } + if deadline != nil { + t.Fatalf("deadline should be nil for invalid input") + } +} diff --git a/backend/cmd/start.go b/backend/cmd/start.go index 534c893..7ef211c 100644 --- a/backend/cmd/start.go +++ b/backend/cmd/start.go @@ -83,7 +83,7 @@ func Start() { courseService := service.NewCourseService(courseRepo, scheduleRepo) taskClassService := service.NewTaskClassService(taskClassRepo, cacheRepo, scheduleRepo, manager) scheduleService := service.NewScheduleService(scheduleRepo, userRepo, taskClassRepo, manager, cacheRepo) - agentService := service.NewAgentService(aiHub, agentRepo, agentCacheRepo, asyncPipeline) + agentService := service.NewAgentService(aiHub, agentRepo, taskRepo, agentCacheRepo, asyncPipeline) // API 层初始化。 userApi := api.NewUserHandler(userService) diff --git a/backend/conv/task.go b/backend/conv/task.go index a7cbbae..3f43057 100644 --- a/backend/conv/task.go +++ b/backend/conv/task.go @@ -37,7 +37,12 @@ func ModelToGetUserTasksResp(tasks []model.Task) []model.GetUserTaskResp { if task.IsCompleted { status = "completed" } - deadline := task.DeadlineAt.Format("2006-01-02 15:04:05") + + deadline := "" + if task.DeadlineAt != nil { + deadline = task.DeadlineAt.Format("2006-01-02 15:04:05") + } + resp = append(resp, model.GetUserTaskResp{ ID: task.ID, UserID: task.UserID, diff --git a/backend/model/task.go b/backend/model/task.go index 0540f90..a208059 100644 --- a/backend/model/task.go +++ b/backend/model/task.go @@ -3,28 +3,29 @@ package model import "time" type Task struct { - ID int `gorm:"primaryKey;autoIncrement"` - UserID int `gorm:"column:user_id;index"` - Title string `gorm:"type:varchar(255)"` - Priority int `gorm:"not null"` - IsCompleted bool `gorm:"column:is_completed;default:false"` - DeadlineAt time.Time `gorm:"column:deadline_at"` + ID int `gorm:"primaryKey;autoIncrement"` + UserID int `gorm:"column:user_id;index"` + Title string `gorm:"type:varchar(255)"` + Priority int `gorm:"not null"` + IsCompleted bool `gorm:"column:is_completed;default:false"` + DeadlineAt *time.Time `gorm:"column:deadline_at"` } type UserAddTaskResponse struct { - ID int `json:"id"` - Title string `json:"title"` - PriorityGroup int `json:"priority_group"` - DeadlineAt time.Time `json:"deadline_at"` - Status string `json:"status"` - CreatedAt time.Time `json:"created_at"` + ID int `json:"id"` + Title string `json:"title"` + PriorityGroup int `json:"priority_group"` + DeadlineAt *time.Time `json:"deadline_at"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` } type UserAddTaskRequest struct { - Title string `json:"title"` - PriorityGroup int `json:"priority_group"` - DeadlineAt time.Time `json:"deadline_at"` + Title string `json:"title"` + PriorityGroup int `json:"priority_group"` + DeadlineAt *time.Time `json:"deadline_at"` } + type GetUserTaskResp struct { ID int `json:"id"` UserID int `json:"user_id"` diff --git a/backend/service/agent.go b/backend/service/agent.go index 3f0892e..35148dd 100644 --- a/backend/service/agent.go +++ b/backend/service/agent.go @@ -20,14 +20,16 @@ import ( type AgentService struct { AIHub *inits.AIHub repo *dao.AgentDAO + taskRepo *dao.TaskDAO agentCache *dao.AgentCache asyncPipeline *AgentAsyncPipeline } -func NewAgentService(aiHub *inits.AIHub, repo *dao.AgentDAO, agentRedis *dao.AgentCache, asyncPipeline *AgentAsyncPipeline) *AgentService { +func NewAgentService(aiHub *inits.AIHub, repo *dao.AgentDAO, taskRepo *dao.TaskDAO, agentRedis *dao.AgentCache, asyncPipeline *AgentAsyncPipeline) *AgentService { return &AgentService{ AIHub: aiHub, repo: repo, + taskRepo: taskRepo, agentCache: agentRedis, asyncPipeline: asyncPipeline, } @@ -67,18 +69,104 @@ func pushErrNonBlocking(errChan chan error, err error) { } } +// runNormalChatFlow 执行普通流式聊天链路(非随口记)。 +// 该函数被两处复用: +// 1) 用户输入本就不是随口记; +// 2) 开启随口记进度推送后,最终判定“非随口记”时回落到普通聊天。 +func (s *AgentService) runNormalChatFlow( + ctx context.Context, + selectedModel *ark.ChatModel, + resolvedModelName string, + userMessage string, + ifThinking bool, + userID int, + chatID string, + traceID string, + requestStart time.Time, + outChan chan<- string, + errChan chan error, +) { + chatHistory, err := s.agentCache.GetHistory(ctx, chatID) + if err != nil { + pushErrNonBlocking(errChan, err) + return + } + + cacheMiss := false + if chatHistory == nil { + cacheMiss = true + histories, hisErr := s.repo.GetUserChatHistories(ctx, userID, pkg.HistoryFetchLimitByModel(resolvedModelName), chatID) + if hisErr != nil { + pushErrNonBlocking(errChan, hisErr) + return + } + chatHistory = conv.ToEinoMessages(histories) + } + + historyBudget := pkg.HistoryTokenBudgetByModel(resolvedModelName, agent.SystemPrompt, userMessage) + trimmedHistory, totalHistoryTokens, keptHistoryTokens, droppedCount := pkg.TrimHistoryByTokenBudget(chatHistory, historyBudget) + chatHistory = trimmedHistory + + targetWindow := pkg.CalcSessionWindowSize(len(chatHistory)) + if err = s.agentCache.SetSessionWindowSize(ctx, chatID, targetWindow); err != nil { + log.Printf("设置历史窗口失败 chat=%s: %v", chatID, err) + } + if err = s.agentCache.EnforceHistoryWindow(ctx, chatID); err != nil { + log.Printf("执行历史窗口裁剪失败 chat=%s: %v", chatID, err) + } + + if droppedCount > 0 { + log.Printf("历史裁剪: chat=%s total_tokens=%d kept_tokens=%d dropped=%d budget=%d target_window=%d", + chatID, totalHistoryTokens, keptHistoryTokens, droppedCount, historyBudget, targetWindow) + } + + if cacheMiss { + if err = s.agentCache.BackfillHistory(ctx, chatID, chatHistory); err != nil { + pushErrNonBlocking(errChan, err) + return + } + } + + fullText, streamErr := agent.StreamChat(ctx, selectedModel, resolvedModelName, userMessage, ifThinking, chatHistory, outChan, traceID, chatID, requestStart) + if streamErr != nil { + pushErrNonBlocking(errChan, streamErr) + return + } + + if err = s.agentCache.PushMessage(ctx, chatID, &schema.Message{Role: schema.User, Content: userMessage}); err != nil { + log.Printf("写入用户消息到 Redis 失败: %v", err) + } + + if err = s.saveChatHistoryReliable(ctx, model.ChatHistoryPersistPayload{ + UserID: userID, + ConversationID: chatID, + Role: "user", + Message: userMessage, + }); err != nil { + pushErrNonBlocking(errChan, err) + return + } + + if saveErr := s.saveChatHistoryReliable(context.Background(), model.ChatHistoryPersistPayload{ + UserID: userID, + ConversationID: chatID, + Role: "assistant", + Message: fullText, + }); saveErr != nil { + pushErrNonBlocking(errChan, saveErr) + } +} + func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThinking bool, modelName string, userID int, chatID string) (<-chan string, <-chan error) { requestStart := time.Now() traceID := uuid.NewString() - outChan := make(chan string, 5) + outChan := make(chan string, 8) errChan := make(chan error, 1) // 1) 规范会话 ID,选择模型。 chatID = normalizeConversationID(chatID) selectedModel, resolvedModelName := s.pickChatModel(modelName) - /*log.Printf("打点|请求开始|trace_id=%s|chat_id=%s|user_id=%d|model=%s|请求累计_ms=%d", - traceID, chatID, userID, resolvedModelName, time.Since(requestStart).Milliseconds())*/ // 2) 确保会话存在(优先缓存,必要时回源 DB 并创建)。 result, err := s.agentCache.GetConversationStatus(ctx, chatID) @@ -109,121 +197,77 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin } } - // 3) 拉取并裁剪历史上下文。 - chatHistory, err := s.agentCache.GetHistory(ctx, chatID) - if err != nil { - errChan <- err - close(outChan) - close(errChan) + // 3) 如果命中“任务安排关键词”,开启随口记阶段推送(伪装成 reasoning chunk)。 + if shouldEmitQuickNoteProgress(userMessage) { + go func() { + defer close(outChan) + + progress := newQuickNoteProgressEmitter(outChan, resolvedModelName, true) + progress.Emit("request.accepted", "检测到任务安排请求,开始执行随口记流程。") + + quickHandled, quickState, quickErr := s.tryHandleQuickNoteWithGraph( + ctx, + selectedModel, + userMessage, + userID, + chatID, + traceID, + progress.Emit, + ) + if quickErr != nil { + log.Printf("随口记 graph 执行失败,回退普通聊天 trace_id=%s chat_id=%s err=%v", traceID, chatID, quickErr) + } + + if quickHandled { + progress.Emit("quick_note.reply.polishing", "正在结合你的话题润色回复。") + quickReply := buildQuickNoteFinalReply(ctx, selectedModel, userMessage, quickState) + if emitErr := emitSingleAssistantCompletion(outChan, resolvedModelName, quickReply); emitErr != nil { + pushErrNonBlocking(errChan, emitErr) + return + } + + s.persistChatAfterReply(ctx, userID, chatID, userMessage, quickReply, errChan) + return + } + + progress.Emit("quick_note.fallback", "当前输入不是随口记请求,切换到普通对话。") + s.runNormalChatFlow(ctx, selectedModel, resolvedModelName, userMessage, ifThinking, userID, chatID, traceID, requestStart, outChan, errChan) + }() return outChan, errChan } - cacheMiss := false - if chatHistory == nil { - cacheMiss = true - histories, hisErr := s.repo.GetUserChatHistories(ctx, userID, pkg.HistoryFetchLimitByModel(resolvedModelName), chatID) - if hisErr != nil { - errChan <- hisErr - close(outChan) - close(errChan) - return outChan, errChan - } - chatHistory = conv.ToEinoMessages(histories) - } - - historyBudget := pkg.HistoryTokenBudgetByModel(resolvedModelName, agent.SystemPrompt, userMessage) - trimmedHistory, totalHistoryTokens, keptHistoryTokens, droppedCount := pkg.TrimHistoryByTokenBudget(chatHistory, historyBudget) - chatHistory = trimmedHistory - - targetWindow := pkg.CalcSessionWindowSize(len(chatHistory)) - if err = s.agentCache.SetSessionWindowSize(ctx, chatID, targetWindow); err != nil { - log.Printf("设置历史窗口失败 chat=%s: %v", chatID, err) - } - if err = s.agentCache.EnforceHistoryWindow(ctx, chatID); err != nil { - log.Printf("执行历史窗口裁剪失败 chat=%s: %v", chatID, err) - } - - if droppedCount > 0 { - log.Printf("历史裁剪: chat=%s total_tokens=%d kept_tokens=%d dropped=%d budget=%d target_window=%d", - chatID, totalHistoryTokens, keptHistoryTokens, droppedCount, historyBudget, targetWindow) - } - - if cacheMiss { - if err = s.agentCache.BackfillHistory(ctx, chatID, chatHistory); err != nil { - errChan <- err - close(outChan) - close(errChan) - return outChan, errChan - } - } - - // 单请求主链路打点:开流前准备完成。 - /*log.Printf("打点|开流前准备完成|trace_id=%s|chat_id=%s|本步耗时_ms=%d|请求累计_ms=%d|history_len=%d|cache_miss=%t", - traceID, + // 4) 无阶段推送模式:保持原逻辑,先尝试随口记,不命中再走普通聊天。 + quickHandled, quickState, quickErr := s.tryHandleQuickNoteWithGraph( + ctx, + selectedModel, + userMessage, + userID, chatID, - time.Since(requestStart).Milliseconds(), - time.Since(requestStart).Milliseconds(), - len(chatHistory), - cacheMiss, - )*/ + traceID, + nil, + ) + if quickErr != nil { + log.Printf("随口记 graph 执行失败,回退普通聊天 trace_id=%s chat_id=%s err=%v", traceID, chatID, quickErr) + } + if quickHandled { + go func() { + defer close(outChan) - // 4) 启动流式输出,回答完成后执行后置持久化。 + quickReply := buildQuickNoteFinalReply(ctx, selectedModel, userMessage, quickState) + if emitErr := emitSingleAssistantCompletion(outChan, resolvedModelName, quickReply); emitErr != nil { + pushErrNonBlocking(errChan, emitErr) + return + } + + s.persistChatAfterReply(ctx, userID, chatID, userMessage, quickReply, errChan) + }() + return outChan, errChan + } + + // 5) 普通流式聊天。 go func() { defer close(outChan) - - /*streamStart := time.Now()*/ - fullText, streamErr := agent.StreamChat(ctx, selectedModel, resolvedModelName, userMessage, ifThinking, chatHistory, outChan, traceID, chatID, requestStart) - if streamErr != nil { - pushErrNonBlocking(errChan, streamErr) - return - } - /*log.Printf("打点|流式输出完成|trace_id=%s|chat_id=%s|本步耗时_ms=%d|请求累计_ms=%d|reply_chars=%d", - traceID, chatID, time.Since(streamStart).Milliseconds(), time.Since(requestStart).Milliseconds(), len(fullText)) - - postPersistStart := time.Now() - - stepStart := time.Now()*/ - if err = s.agentCache.PushMessage(ctx, chatID, &schema.Message{Role: schema.User, Content: userMessage}); err != nil { - log.Printf("写入用户消息到 Redis 失败: %v", err) - } - /*log.Printf("打点|后置持久化_用户_写Redis|trace_id=%s|chat_id=%s|本步耗时_ms=%d|请求累计_ms=%d", - traceID, chatID, time.Since(stepStart).Milliseconds(), time.Since(requestStart).Milliseconds()) - - stepStart = time.Now()*/ - if err = s.saveChatHistoryReliable(ctx, model.ChatHistoryPersistPayload{ - UserID: userID, - ConversationID: chatID, - Role: "user", - Message: userMessage, - }); err != nil { - errChan <- err - close(outChan) - close(errChan) - } - /*log.Printf("打点|后置持久化_用户_写持久化请求|trace_id=%s|chat_id=%s|本步耗时_ms=%d|请求累计_ms=%d", - traceID, chatID, time.Since(stepStart).Milliseconds(), time.Since(requestStart).Milliseconds()) - - stepStart = time.Now() - if cacheErr := s.agentCache.PushMessage(context.Background(), chatID, &schema.Message{Role: schema.Assistant, Content: fullText}); cacheErr != nil { - log.Printf("写入助手消息到 Redis 失败: %v", cacheErr) - } - log.Printf("打点|后置持久化_助手_写Redis|trace_id=%s|chat_id=%s|本步耗时_ms=%d|请求累计_ms=%d", - traceID, chatID, time.Since(stepStart).Milliseconds(), time.Since(requestStart).Milliseconds()) - - stepStart = time.Now()*/ - if saveErr := s.saveChatHistoryReliable(context.Background(), model.ChatHistoryPersistPayload{ - UserID: userID, - ConversationID: chatID, - Role: "assistant", - Message: fullText, - }); saveErr != nil { - pushErrNonBlocking(errChan, saveErr) - } - /*log.Printf("打点|后置持久化_助手_写持久化请求|trace_id=%s|chat_id=%s|本步耗时_ms=%d|请求累计_ms=%d", - traceID, chatID, time.Since(stepStart).Milliseconds(), time.Since(requestStart).Milliseconds()) - - log.Printf("打点|后置持久化完成|trace_id=%s|chat_id=%s|本步耗时_ms=%d|请求累计_ms=%d", - traceID, chatID, time.Since(postPersistStart).Milliseconds(), time.Since(requestStart).Milliseconds())*/ + s.runNormalChatFlow(ctx, selectedModel, resolvedModelName, userMessage, ifThinking, userID, chatID, traceID, requestStart, outChan, errChan) }() return outChan, errChan diff --git a/backend/service/agent_quick_note.go b/backend/service/agent_quick_note.go new file mode 100644 index 0000000..921cf8a --- /dev/null +++ b/backend/service/agent_quick_note.go @@ -0,0 +1,316 @@ +package service + +import ( + "context" + "fmt" + "log" + "strings" + "time" + + "github.com/LoveLosita/smartflow/backend/agent" + "github.com/LoveLosita/smartflow/backend/model" + "github.com/cloudwego/eino-ext/components/model/ark" + "github.com/cloudwego/eino/schema" + "github.com/google/uuid" +) + +// quickNoteProgressEmitter 负责把“链路阶段状态”伪装成 OpenAI 兼容的 reasoning_content chunk。 +// 设计目标: +// 1) 不改现有 OpenAI 兼容协议外壳; +// 2) 让 Apifox 在等待期间也能看到“思考块”,避免用户空等; +// 3) 该 emitter 只负责状态,不负责最终正文回复和 [DONE] 结束块。 +type quickNoteProgressEmitter struct { + outChan chan<- string + modelName string + requestID string + created int64 + enablePush bool +} + +func newQuickNoteProgressEmitter(outChan chan<- string, modelName string, enable bool) *quickNoteProgressEmitter { + resolvedModel := strings.TrimSpace(modelName) + if resolvedModel == "" { + resolvedModel = "worker" + } + return &quickNoteProgressEmitter{ + outChan: outChan, + modelName: resolvedModel, + requestID: "chatcmpl-" + uuid.NewString(), + created: time.Now().Unix(), + enablePush: enable, + } +} + +// Emit 按“阶段 + 说明”输出 reasoning_content。 +// 注意: +// - 这里不输出 role,避免和后续正文的 role 块冲突; +// - 即使发送失败,也只记录日志,不影响主流程继续执行。 +func (e *quickNoteProgressEmitter) Emit(stage, detail string) { + if e == nil || !e.enablePush || e.outChan == nil { + return + } + stage = strings.TrimSpace(stage) + detail = strings.TrimSpace(detail) + if stage == "" && detail == "" { + return + } + + reasoning := fmt.Sprintf("阶段:%s", stage) + if detail != "" { + reasoning += "\n" + detail + } + + chunk, err := agent.ToOpenAIStream(&schema.Message{ReasoningContent: reasoning}, e.requestID, e.modelName, e.created, false) + if err != nil { + log.Printf("输出随口记阶段状态失败 stage=%s err=%v", stage, err) + return + } + if chunk != "" { + e.outChan <- chunk + } +} + +// tryHandleQuickNoteWithGraph 尝试用“随口记 graph”处理本次用户输入。 +// 返回值语义: +// - handled=true:本次请求已在随口记链路处理完成(成功/失败都会返回文案); +// - handled=false:不是随口记意图,调用方应回落普通聊天链路; +// - state:用于拼接最终“一次性正文回复”。 +func (s *AgentService) tryHandleQuickNoteWithGraph( + ctx context.Context, + selectedModel *ark.ChatModel, + userMessage string, + userID int, + chatID string, + traceID string, + emitStage func(stage, detail string), +) (handled bool, state *agent.QuickNoteState, err error) { + if s.taskRepo == nil || selectedModel == nil { + return false, nil, nil + } + + state = agent.NewQuickNoteState(traceID, userID, chatID, userMessage) + finalState, runErr := agent.RunQuickNoteGraph(ctx, agent.QuickNoteGraphRunInput{ + Model: selectedModel, + State: state, + Deps: agent.QuickNoteToolDeps{ + ResolveUserID: func(ctx context.Context) (int, error) { + return userID, nil + }, + CreateTask: func(ctx context.Context, req agent.QuickNoteCreateTaskRequest) (*agent.QuickNoteCreateTaskResult, error) { + taskModel := &model.Task{ + UserID: req.UserID, + Title: req.Title, + Priority: req.PriorityGroup, + IsCompleted: false, + DeadlineAt: req.DeadlineAt, + } + created, createErr := s.taskRepo.AddTask(taskModel) + if createErr != nil { + return nil, createErr + } + return &agent.QuickNoteCreateTaskResult{ + TaskID: created.ID, + Title: created.Title, + PriorityGroup: created.Priority, + DeadlineAt: created.DeadlineAt, + }, nil + }, + }, + EmitStage: emitStage, + }) + if runErr != nil { + return false, nil, runErr + } + if finalState == nil || !finalState.IsQuickNoteIntent { + return false, nil, nil + } + + return true, finalState, nil +} + +// emitSingleAssistantCompletion 将单条完整回复包装成 OpenAI 兼容 chunk 流并写入 outChan。 +// 说明: +// - 保持现有 OpenAI 兼容格式不变; +// - 正文只发一次,不做伪分段。 +func emitSingleAssistantCompletion(outChan chan<- string, modelName, reply string) error { + if strings.TrimSpace(modelName) == "" { + modelName = "worker" + } + requestID := "chatcmpl-" + uuid.NewString() + created := time.Now().Unix() + + chunk, err := agent.ToOpenAIStream(&schema.Message{Role: schema.Assistant, Content: reply}, requestID, modelName, created, true) + if err != nil { + return err + } + if chunk != "" { + outChan <- chunk + } + + finishChunk, err := agent.ToOpenAIFinishStream(requestID, modelName, created) + if err != nil { + return err + } + outChan <- finishChunk + outChan <- "[DONE]" + return nil +} + +// buildQuickNoteFinalReply 生成最终的一次性正文回复。 +// 组合策略: +// 1) 任务事实(标题/优先级/截止时间)由后端拼接,确保准确; +// 2) 轻松跟进句交给 AI 生成,贴合用户话题(避免硬编码“薯饼”这类场景分支); +// 3) AI 生成失败时自动降级为固定友好文案,保证稳定可用。 +func buildQuickNoteFinalReply(ctx context.Context, selectedModel *ark.ChatModel, userMessage string, state *agent.QuickNoteState) string { + if state == nil { + return "我这次没成功记上,别急,再发我一次我马上补上。" + } + + if state.Persisted { + title := strings.TrimSpace(state.ExtractedTitle) + if title == "" { + title = "这条任务" + } + + priorityText := "已安排优先级" + if agent.IsValidTaskPriority(state.ExtractedPriority) { + priorityText = fmt.Sprintf("优先级:%s", agent.PriorityLabelCN(state.ExtractedPriority)) + } + + deadlineText := "" + if state.ExtractedDeadline != nil { + deadlineText = fmt.Sprintf(";截止时间 %s", state.ExtractedDeadline.In(time.Local).Format("2006-01-02 15:04")) + } + + factLine := fmt.Sprintf("好,给你安排上了:%s(%s%s)。", title, priorityText, deadlineText) + + banter, err := generateQuickNoteBanter(ctx, selectedModel, userMessage, title, priorityText, deadlineText) + if err != nil { + return factLine + " 这下可以先安心推进,不用等 ddl 来敲门了。" + } + if strings.TrimSpace(banter) == "" { + return factLine + " 这下可以先安心推进,不用等 ddl 来敲门了。" + } + return factLine + " " + banter + } + + if strings.TrimSpace(state.DeadlineValidationError) != "" { + return "我识别到你给了时间,但格式不够明确,暂时不敢乱记。你可以改成比如:2026-03-20 18:30、明天下午3点、下周一上午9点,我立刻帮你安排。" + } + + if strings.TrimSpace(state.AssistantReply) != "" { + return strings.TrimSpace(state.AssistantReply) + } + return "这次没成功写入任务,我没跑路,再给我一次我就把它稳稳记上。" +} + +// generateQuickNoteBanter 让模型根据用户原话生成一条“贴题轻松句”。 +// 约束: +// - 只生成跟进语气,不承担事实表达; +// - 不得改动任务事实; +// - 输出控制在一句,方便直接拼接在事实句后。 +func generateQuickNoteBanter( + ctx context.Context, + selectedModel *ark.ChatModel, + userMessage string, + title string, + priorityText string, + deadlineText string, +) (string, error) { + if selectedModel == nil { + return "", fmt.Errorf("model is nil") + } + + prompt := fmt.Sprintf(`用户原话:%s +已确认事实: +- 任务标题:%s +- %s +- %s + +请输出一句轻松自然的跟进话术(仅一句)。`, + strings.TrimSpace(userMessage), + strings.TrimSpace(title), + strings.TrimSpace(priorityText), + strings.TrimSpace(deadlineText), + ) + + messages := []*schema.Message{ + schema.SystemMessage(agent.QuickNoteReplyBanterPrompt), + schema.UserMessage(prompt), + } + + resp, err := selectedModel.Generate(ctx, messages) + if err != nil { + return "", err + } + if resp == nil { + return "", fmt.Errorf("empty response") + } + + text := strings.TrimSpace(resp.Content) + text = strings.Trim(text, "\"'“”‘’") + if text == "" { + return "", fmt.Errorf("empty content") + } + + // 简单兜底:只保留首行,避免模型输出多段。 + if idx := strings.Index(text, "\n"); idx >= 0 { + text = strings.TrimSpace(text[:idx]) + } + return text, nil +} + +// shouldEmitQuickNoteProgress 用于判断是否应在“等待阶段”推送状态块。 +// 规则偏保守:只要出现明显“记任务/提醒”语义,就开启阶段推送。 +func shouldEmitQuickNoteProgress(userMessage string) bool { + text := strings.TrimSpace(userMessage) + if text == "" { + return false + } + keywords := []string{"记一下", "帮我记", "提醒", "任务", "待办", "日程", "安排", "截止", "ddl"} + for _, kw := range keywords { + if strings.Contains(text, kw) { + return true + } + } + return false +} + +// persistChatAfterReply 在“随口记 graph”返回后,复用当前项目的后置持久化策略: +// 1) 用户消息写 Redis + outbox/DB; +// 2) 助手消息写 Redis + outbox/DB。 +func (s *AgentService) persistChatAfterReply( + ctx context.Context, + userID int, + chatID string, + userMessage string, + assistantReply string, + errChan chan error, +) { + if err := s.agentCache.PushMessage(ctx, chatID, &schema.Message{Role: schema.User, Content: userMessage}); err != nil { + log.Printf("写入用户消息到 Redis 失败: %v", err) + } + + if err := s.saveChatHistoryReliable(ctx, model.ChatHistoryPersistPayload{ + UserID: userID, + ConversationID: chatID, + Role: "user", + Message: userMessage, + }); err != nil { + pushErrNonBlocking(errChan, err) + return + } + + if err := s.agentCache.PushMessage(context.Background(), chatID, &schema.Message{Role: schema.Assistant, Content: assistantReply}); err != nil { + log.Printf("写入助手消息到 Redis 失败: %v", err) + } + + if err := s.saveChatHistoryReliable(context.Background(), model.ChatHistoryPersistPayload{ + UserID: userID, + ConversationID: chatID, + Role: "assistant", + Message: assistantReply, + }); err != nil { + pushErrNonBlocking(errChan, err) + } +}