diff --git a/.gitignore b/.gitignore index 72b6cbe..75f51a7 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,5 @@ backend/config.yaml .idea/ .vscode/ .DS_Store # Mac 用户必加 -.gocache/ \ No newline at end of file +.gocache/ +.gomodcache/ \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md index d052ecf..91c7c92 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -4,6 +4,7 @@ 1. 默认语言规则:所有注释、接口文案、说明、评审反馈均使用中文。 2. 请勤加注释,尤其是复杂逻辑部分,确保代码易于理解和维护。 +3. 每次在本地执行测试命令(如 `go test`)后,必须清理项目根目录下的 `.gocache` 目录,避免缓存文件长期堆积。 ## 注释规范(强制) diff --git a/README.md b/README.md index 7dc3ac9..8852ee4 100644 --- a/README.md +++ b/README.md @@ -388,7 +388,7 @@ flowchart TD B --> C[确保会话存在
Redis会话状态检查
必要时回源DB创建] C --> D[模型控制码路由
action=quick_note/chat] D --> E{route是否命中quick_note} - E -- 否 --> X[普通聊天链路
StreamChat流式输出] + E -- 否 --> X[普通聊天链路
StreamChat流式输出
或者其它分支] E -- 是 --> F[quick_note.request.accepted
推送reasoning状态块] F --> G[跳过二次意图判定
直接进入聚合规划] G --> H[单请求聚合规划
生成title/deadline/priority/banter] diff --git a/backend/agent/chat/stream.go b/backend/agent/chat/stream.go index 4c2093d..aa52117 100644 --- a/backend/agent/chat/stream.go +++ b/backend/agent/chat/stream.go @@ -103,7 +103,7 @@ func StreamChat( traceID string, chatID string, requestStart time.Time, -) (string, error) { +) (string, *schema.TokenUsage, error) { /*callStart := time.Now()*/ messages := make([]*schema.Message, 0) @@ -123,7 +123,7 @@ func StreamChat( /*connectStart := time.Now()*/ reader, err := llm.Stream(ctx, messages, ark.WithThinking(thinking)) if err != nil { - return "", err + return "", nil, err } defer reader.Close() @@ -134,6 +134,7 @@ func StreamChat( created := time.Now().Unix() firstChunk := true chunkCount := 0 + var tokenUsage *schema.TokenUsage /*streamRecvStart := time.Now() log.Printf("打点|流连接建立|trace_id=%s|chat_id=%s|request_id=%s|本步耗时_ms=%d|请求累计_ms=%d|history_len=%d", @@ -152,14 +153,19 @@ func StreamChat( break } if err != nil { - return "", err + return "", nil, err + } + + // 优先记录模型真实 usage(通常在尾块返回,部分模型也可能中途返回)。 + if chunk != nil && chunk.ResponseMeta != nil && chunk.ResponseMeta.Usage != nil { + tokenUsage = mergeTokenUsage(tokenUsage, chunk.ResponseMeta.Usage) } fullText.WriteString(chunk.Content) payload, err := ToOpenAIStream(chunk, requestID, modelName, created, firstChunk) if err != nil { - return "", err + return "", nil, err } if payload != "" { outChan <- payload @@ -179,7 +185,7 @@ func StreamChat( finishChunk, err := ToOpenAIFinishStream(requestID, modelName, created) if err != nil { - return "", err + return "", nil, err } outChan <- finishChunk outChan <- "[DONE]" @@ -194,5 +200,39 @@ func StreamChat( time.Since(requestStart).Milliseconds(), )*/ - return fullText.String(), nil + return fullText.String(), tokenUsage, nil +} + +// mergeTokenUsage 合并流式分片中的 usage。 +// +// 设计说明: +// 1. 不同模型的 usage 回传时机不同(中间块/尾块); +// 2. 这里按“更大值覆盖”合并,确保最终拿到完整统计; +// 3. 只用于统计,不影响流式正文输出。 +func mergeTokenUsage(base *schema.TokenUsage, incoming *schema.TokenUsage) *schema.TokenUsage { + if incoming == nil { + return base + } + if base == nil { + copied := *incoming + return &copied + } + + merged := *base + if incoming.PromptTokens > merged.PromptTokens { + merged.PromptTokens = incoming.PromptTokens + } + if incoming.CompletionTokens > merged.CompletionTokens { + merged.CompletionTokens = incoming.CompletionTokens + } + if incoming.TotalTokens > merged.TotalTokens { + merged.TotalTokens = incoming.TotalTokens + } + if incoming.PromptTokenDetails.CachedTokens > merged.PromptTokenDetails.CachedTokens { + merged.PromptTokenDetails.CachedTokens = incoming.PromptTokenDetails.CachedTokens + } + if incoming.CompletionTokensDetails.ReasoningTokens > merged.CompletionTokensDetails.ReasoningTokens { + merged.CompletionTokensDetails.ReasoningTokens = incoming.CompletionTokensDetails.ReasoningTokens + } + return &merged } diff --git a/backend/api/task.go b/backend/api/task.go index dd80e8d..91b805c 100644 --- a/backend/api/task.go +++ b/backend/api/task.go @@ -62,3 +62,36 @@ func (th *TaskHandler) GetUserTasks(c *gin.Context) { //3. 返回响应 c.JSON(http.StatusOK, respond.RespWithData(respond.Ok, resp)) } + +// CompleteTask 标记任务为已完成。 +// +// 职责边界: +// 1. 负责解析请求与读取 user_id; +// 2. 负责调用 Service 执行业务; +// 3. 不负责幂等校验(幂等由路由中间件处理)。 +func (th *TaskHandler) CompleteTask(c *gin.Context) { + // 1. 绑定请求参数。 + var req model.UserCompleteTaskRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, respond.WrongParamType) + fmt.Println(err) + return + } + + // 2. 从鉴权上下文获取 user_id,保证只能操作自己的任务。 + userID := c.GetInt("user_id") + + // 3. 设置短超时,避免该写接口长期占用连接。 + ctx, cancel := context.WithTimeout(c.Request.Context(), 1*time.Second) + defer cancel() + + // 4. 调用 Service 执行“标记完成”逻辑。 + resp, err := th.svc.CompleteTask(ctx, &req, userID) + if err != nil { + respond.DealWithError(c, err) + return + } + + // 5. 返回统一响应结构。 + c.JSON(http.StatusOK, respond.RespWithData(respond.Ok, resp)) +} diff --git a/backend/cmd/start.go b/backend/cmd/start.go index 3e5940e..87b9ee2 100644 --- a/backend/cmd/start.go +++ b/backend/cmd/start.go @@ -82,6 +82,9 @@ func Start() { if err = eventsvc.RegisterTaskUrgencyPromoteHandler(eventBus, outboxRepo, manager); err != nil { log.Fatalf("Failed to register task urgency promote event handler: %v", err) } + if err = eventsvc.RegisterChatTokenUsageAdjustHandler(eventBus, outboxRepo, manager); err != nil { + log.Fatalf("Failed to register chat token usage adjust event handler: %v", err) + } eventBus.Start(context.Background()) defer eventBus.Close() log.Println("Outbox event bus started") diff --git a/backend/dao/agent.go b/backend/dao/agent.go index 3a26268..d497092 100644 --- a/backend/dao/agent.go +++ b/backend/dao/agent.go @@ -33,22 +33,32 @@ func (r *AgentDAO) WithTx(tx *gorm.DB) *AgentDAO { // 失败处理: // 1. 任一步骤失败都返回 error; // 2. 若调用方处于事务中,返回 error 会触发事务回滚。 -func (a *AgentDAO) saveChatHistoryCore(ctx context.Context, userID int, conversationID string, role, message string) error { +func (a *AgentDAO) saveChatHistoryCore(ctx context.Context, userID int, conversationID string, role, message string, tokensConsumed int) error { + // 0. token 入库前兜底:负数统一归零,避免异常值污染累计统计。 + if tokensConsumed < 0 { + tokensConsumed = 0 + } + // 1. 先写 chat_histories 原始消息。 userChat := model.ChatHistory{ UserID: userID, MessageContent: &message, Role: &role, ChatID: conversationID, + TokensConsumed: tokensConsumed, } if err := a.db.WithContext(ctx).Create(&userChat).Error; err != nil { return err } - // 2. 再更新会话统计(message_count +1, last_message_at=now)。 + // 2. 再更新会话统计: + // 2.1 message_count +1,保持和 chat_histories 行数口径一致; + // 2.2 tokens_total 累加本条消息 token; + // 2.3 last_message_at 刷新为当前时间,供会话排序使用。 now := time.Now() updates := map[string]interface{}{ "message_count": gorm.Expr("message_count + ?", 1), + "tokens_total": gorm.Expr("tokens_total + ?", tokensConsumed), "last_message_at": &now, } result := a.db.WithContext(ctx).Model(&model.AgentChat{}). @@ -61,6 +71,23 @@ func (a *AgentDAO) saveChatHistoryCore(ctx context.Context, userID int, conversa // 会话不存在时直接失败,避免出现“孤儿历史消息”。 return fmt.Errorf("conversation not found when updating stats: user_id=%d chat_id=%s", userID, conversationID) } + + // 3. 最后更新 users.token_usage(同一事务内): + // 3.1 只在 tokensConsumed>0 时执行,避免无意义写入; + // 3.2 和 chat_histories/agent_chats 放在同一事务里,保证统计口径原子一致; + // 3.3 若用户行不存在则返回错误,触发事务回滚,防止出现“会话统计成功但用户统计丢失”。 + if tokensConsumed > 0 { + userUpdate := a.db.WithContext(ctx). + Model(&model.User{}). + Where("id = ?", userID). + Update("token_usage", gorm.Expr("token_usage + ?", tokensConsumed)) + if userUpdate.Error != nil { + return userUpdate.Error + } + if userUpdate.RowsAffected == 0 { + return fmt.Errorf("user not found when updating token usage: user_id=%d", userID) + } + } return nil } @@ -69,8 +96,8 @@ func (a *AgentDAO) saveChatHistoryCore(ctx context.Context, userID int, conversa // 设计目的: // 1. 给服务层组合多个 DAO 操作时复用,避免嵌套事务; // 2. 让 outbox 消费处理器可以和业务写入共享同一个 tx。 -func (a *AgentDAO) SaveChatHistoryInTx(ctx context.Context, userID int, conversationID string, role, message string) error { - return a.saveChatHistoryCore(ctx, userID, conversationID, role, message) +func (a *AgentDAO) SaveChatHistoryInTx(ctx context.Context, userID int, conversationID string, role, message string, tokensConsumed int) error { + return a.saveChatHistoryCore(ctx, userID, conversationID, role, message, tokensConsumed) } // SaveChatHistory 在同步直写路径下写入聊天历史。 @@ -78,9 +105,58 @@ func (a *AgentDAO) SaveChatHistoryInTx(ctx context.Context, userID int, conversa // 说明: // 1. 该方法会自行开启事务; // 2. 内部复用 saveChatHistoryCore,确保和 SaveChatHistoryInTx 的业务口径完全一致。 -func (a *AgentDAO) SaveChatHistory(ctx context.Context, userID int, conversationID string, role, message string) error { +func (a *AgentDAO) SaveChatHistory(ctx context.Context, userID int, conversationID string, role, message string, tokensConsumed int) error { return a.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - return a.WithTx(tx).saveChatHistoryCore(ctx, userID, conversationID, role, message) + return a.WithTx(tx).saveChatHistoryCore(ctx, userID, conversationID, role, message, tokensConsumed) + }) +} + +// adjustTokenUsageCore 在同一事务语义下做“会话+用户”token 账本增量调整。 +// +// 职责边界: +// 1. 只更新 agent_chats.tokens_total 与 users.token_usage; +// 2. 不写 chat_histories(消息落库由 SaveChatHistory* 路径负责); +// 3. deltaTokens<=0 时视为无操作,直接返回。 +func (a *AgentDAO) adjustTokenUsageCore(ctx context.Context, userID int, conversationID string, deltaTokens int) error { + if deltaTokens <= 0 { + return nil + } + + // 1. 先更新会话累计 token。 + chatUpdate := a.db.WithContext(ctx). + Model(&model.AgentChat{}). + Where("user_id = ? AND chat_id = ?", userID, conversationID). + Update("tokens_total", gorm.Expr("tokens_total + ?", deltaTokens)) + if chatUpdate.Error != nil { + return chatUpdate.Error + } + if chatUpdate.RowsAffected == 0 { + return fmt.Errorf("conversation not found when adjusting tokens: user_id=%d chat_id=%s", userID, conversationID) + } + + // 2. 再更新用户累计 token。 + userUpdate := a.db.WithContext(ctx). + Model(&model.User{}). + Where("id = ?", userID). + Update("token_usage", gorm.Expr("token_usage + ?", deltaTokens)) + if userUpdate.Error != nil { + return userUpdate.Error + } + if userUpdate.RowsAffected == 0 { + return fmt.Errorf("user not found when adjusting token usage: user_id=%d", userID) + } + return nil +} + +// AdjustTokenUsageInTx 在调用方已开启事务时执行 token 账本增量调整。 +func (a *AgentDAO) AdjustTokenUsageInTx(ctx context.Context, userID int, conversationID string, deltaTokens int) error { + return a.adjustTokenUsageCore(ctx, userID, conversationID, deltaTokens) +} + +// AdjustTokenUsage 在同步路径下执行 token 账本增量调整(内部自带事务)。 +func (a *AgentDAO) AdjustTokenUsage(ctx context.Context, userID int, conversationID string, deltaTokens int) error { + return a.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return a.WithTx(tx).adjustTokenUsageCore(ctx, userID, conversationID, deltaTokens) }) } diff --git a/backend/dao/task.go b/backend/dao/task.go index 5ed4f75..56fc6d2 100644 --- a/backend/dao/task.go +++ b/backend/dao/task.go @@ -2,6 +2,7 @@ package dao import ( "context" + "errors" "time" "github.com/LoveLosita/smartflow/backend/model" @@ -45,6 +46,76 @@ func (dao *TaskDAO) GetTasksByUserID(userID int) ([]model.Task, error) { return tasks, nil } +// CompleteTaskByID 将指定任务标记为“已完成”。 +// +// 职责边界: +// 1. 只负责“当前用户 + 指定 task_id”的完成状态更新; +// 2. 不负责幂等中间件(由路由层统一挂载); +// 3. 不负责业务层响应包装(由 Service 层处理)。 +// +// 返回语义: +// 1. 第一个返回值 *model.Task:返回更新后的任务快照(至少含 ID/UserID/IsCompleted); +// 2. 第二个返回值 bool: +// 2.1 true:任务原本就已完成,本次属于幂等命中; +// 2.2 false:本次从未完成成功更新为已完成; +// 3. error: +// 3.1 gorm.ErrRecordNotFound:任务不存在或不属于当前用户; +// 3.2 其他 error:数据库异常。 +func (dao *TaskDAO) CompleteTaskByID(ctx context.Context, userID int, taskID int) (*model.Task, bool, error) { + // 1. 基础兜底:非法参数直接返回“记录不存在”语义,避免下游误写。 + if userID <= 0 || taskID <= 0 { + return nil, false, gorm.ErrRecordNotFound + } + + // 2. 先查询目标任务,明确区分“已完成”与“不存在”。 + var target model.Task + findErr := dao.db.WithContext(ctx). + Where("id = ? AND user_id = ?", taskID, userID). + First(&target).Error + if findErr != nil { + return nil, false, findErr + } + + // 3. 若任务已完成,直接按幂等成功返回,不再写库。 + if target.IsCompleted { + return &target, true, nil + } + + // 4. 若任务未完成,执行状态更新。 + // + // 4.1 使用 Model(&model.Task{UserID:userID}) 的目的: + // 让 cache_deleter 在 GORM Update 回调里拿到 user_id,从而正确删除任务缓存。 + // 4.2 更新条件继续限定 user_id + id,避免误更新其他用户数据。 + updateResult := dao.db.WithContext(ctx). + Model(&model.Task{UserID: userID}). + Where("id = ? AND user_id = ?", taskID, userID). + Update("is_completed", true) + if updateResult.Error != nil { + return nil, false, updateResult.Error + } + + // 5. 极端并发兜底: + // 5.1 若 RowsAffected=0,可能是并发请求已先一步更新; + // 5.2 此时二次读取任务状态,若已完成则按幂等成功返回,否则视为不存在/异常。 + if updateResult.RowsAffected == 0 { + var check model.Task + checkErr := dao.db.WithContext(ctx). + Where("id = ? AND user_id = ?", taskID, userID). + First(&check).Error + if checkErr != nil { + return nil, false, checkErr + } + if check.IsCompleted { + return &check, true, nil + } + return nil, false, errors.New("任务状态更新失败") + } + + // 6. 返回更新后的快照给 Service 层组装响应。 + target.IsCompleted = true + return &target, false, nil +} + // PromoteTaskUrgencyByIDs 批量执行“任务紧急性平移”。 // // 职责边界: diff --git a/backend/model/agent.go b/backend/model/agent.go index 61bc69b..523cada 100644 --- a/backend/model/agent.go +++ b/backend/model/agent.go @@ -20,6 +20,21 @@ type ChatHistoryPersistPayload struct { ConversationID string `json:"conversation_id"` Role string `json:"role"` Message string `json:"message"` + TokensConsumed int `json:"tokens_consumed"` +} + +// ChatTokenUsageAdjustPayload 是“会话 token 账本增量调整”事件载荷。 +// +// 职责边界: +// 1. 只表达“对哪个用户/会话增加多少 token”; +// 2. 不承载 chat_histories 落库语义(消息正文由聊天持久化事件负责); +// 3. 不包含 outbox/kafka 协议字段(由基础设施层统一封装)。 +type ChatTokenUsageAdjustPayload struct { + UserID int `json:"user_id"` + ConversationID string `json:"conversation_id"` + TokensDelta int `json:"tokens_delta"` + Reason string `json:"reason"` + TriggeredAt time.Time `json:"triggered_at"` } // GetConversationMetaResponse 是会话元信息查询接口的返回结构。 diff --git a/backend/model/task.go b/backend/model/task.go index 18a6af3..7b60b25 100644 --- a/backend/model/task.go +++ b/backend/model/task.go @@ -56,6 +56,31 @@ type UserAddTaskRequest struct { DeadlineAt *time.Time `json:"deadline_at"` } +// UserCompleteTaskRequest 是“标记任务完成”接口的请求体。 +// +// 职责边界: +// 1. 只承载目标任务 ID; +// 2. 不承载 user_id(user_id 一律由鉴权中间件注入,避免越权)。 +type UserCompleteTaskRequest struct { + TaskID int `json:"task_id"` +} + +// UserCompleteTaskResponse 是“标记任务完成”接口的响应体。 +// +// 字段语义: +// 1. TaskID:本次操作的目标任务; +// 2. IsCompleted:操作后的完成状态(成功时恒为 true); +// 3. AlreadyCompleted: +// 3.1 true:任务原本就已完成,本次请求命中幂等语义; +// 3.2 false:任务由未完成切换为完成; +// 4. Status:给前端的简短状态文案。 +type UserCompleteTaskResponse struct { + TaskID int `json:"task_id"` + IsCompleted bool `json:"is_completed"` + AlreadyCompleted bool `json:"already_completed"` + Status string `json:"status"` +} + type GetUserTaskResp struct { ID int `json:"id"` UserID int `json:"user_id"` diff --git a/backend/respond/respond.go b/backend/respond/respond.go index a4b1f7e..da7d34e 100644 --- a/backend/respond/respond.go +++ b/backend/respond/respond.go @@ -313,4 +313,9 @@ var ( //请求相关的响应 Status: "40049", Info: "task class item trying to insert out of time range", } + + WrongTaskID = Response{ //任务ID错误 + Status: "40050", + Info: "wrong task id", + } ) diff --git a/backend/routers/routers.go b/backend/routers/routers.go index 505d0ae..3ac667c 100644 --- a/backend/routers/routers.go +++ b/backend/routers/routers.go @@ -53,6 +53,7 @@ func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO, limiter *pk { taskGroup.Use(middleware.JWTTokenAuth(cache), middleware.RateLimitMiddleware(limiter, 20, 1)) taskGroup.POST("/create", middleware.IdempotencyMiddleware(cache), handlers.TaskHandler.AddTask) + taskGroup.PUT("/complete", middleware.IdempotencyMiddleware(cache), handlers.TaskHandler.CompleteTask) taskGroup.GET("/get", handlers.TaskHandler.GetUserTasks) } courseGroup := apiGroup.Group("/course") diff --git a/backend/service/agentsvc/agent.go b/backend/service/agentsvc/agent.go index d34bc31..5e9de2f 100644 --- a/backend/service/agentsvc/agent.go +++ b/backend/service/agentsvc/agent.go @@ -32,6 +32,11 @@ type AgentService struct { // 这里通过依赖注入把“模型、仓储、缓存、异步持久化通道”统一交给服务层管理, // 便于后续在单测中替换实现,或在启动流程中按环境切换配置。 func NewAgentService(aiHub *inits.AIHub, repo *dao.AgentDAO, taskRepo *dao.TaskDAO, agentRedis *dao.AgentCache, eventPublisher outboxinfra.EventPublisher) *AgentService { + // 全局注册一次 token 采集 callback: + // 1. 只注册一次,避免重复处理; + // 2. 只有带 RequestTokenMeter 的请求上下文才会真正累加。 + ensureTokenMeterCallbackRegistered() + return &AgentService{ AIHub: aiHub, repo: repo, @@ -76,7 +81,7 @@ func (s *AgentService) PersistChatHistory(ctx context.Context, payload model.Cha // 1. 未注入事件发布器时(例如本地极简环境),直接同步写 DB。 // 这样可以保证功能不依赖 Kafka 也能跑通。 if s.eventPublisher == nil { - return s.repo.SaveChatHistory(ctx, payload.UserID, payload.ConversationID, payload.Role, payload.Message) + return s.repo.SaveChatHistory(ctx, payload.UserID, payload.ConversationID, payload.Role, payload.Message, payload.TokensConsumed) } // 2. 已启用异步总线时,只发布“持久化请求事件”,不在请求路径阻塞 Kafka。 // 2.1 发布成功仅代表“事件安全入队”,实际落库由消费者异步完成。 @@ -167,12 +172,23 @@ func (s *AgentService) runNormalChatFlow( // 6. 执行真正的流式聊天。 // fullText 用于后续写 Redis/持久化,outChan 用于把流片段实时推给前端。 - fullText, streamErr := chat.StreamChat(ctx, selectedModel, resolvedModelName, userMessage, ifThinking, chatHistory, outChan, traceID, chatID, requestStart) + fullText, streamUsage, streamErr := chat.StreamChat(ctx, selectedModel, resolvedModelName, userMessage, ifThinking, chatHistory, outChan, traceID, chatID, requestStart) if streamErr != nil { pushErrNonBlocking(errChan, streamErr) return } + // 6.1 流式 usage 并入请求级 token 统计器: + // 6.1.1 route/quicknote/taskquery 等 Generate 调用由 callback 自动累加; + // 6.1.2 主对话 Stream usage 在这里手动补齐。 + addSchemaUsageIntoRequest(ctx, streamUsage) + requestTokenSnapshot := snapshotRequestTokenMeter(ctx) + requestTotalTokens := requestTokenSnapshot.TotalTokens + if requestTotalTokens <= 0 && streamUsage != nil { + // 兜底:若 callback/meter 未生效,至少使用流式 usage 保底记账。 + requestTotalTokens = normalizeUsageTotal(streamUsage.TotalTokens, streamUsage.PromptTokens, streamUsage.CompletionTokens) + } + // 7. 后置持久化(用户消息): // 7.1 先写 Redis,保证“最新会话上下文”可立即用于下一轮推理; // 7.2 再走可靠持久化入口(outbox 或同步 DB)。 @@ -185,6 +201,8 @@ func (s *AgentService) runNormalChatFlow( ConversationID: chatID, Role: "user", Message: userMessage, + // 口径B:用户消息固定记 0;本轮总 token 统一记在助手消息。 + TokensConsumed: 0, }); err != nil { pushErrNonBlocking(errChan, err) return @@ -204,6 +222,8 @@ func (s *AgentService) runNormalChatFlow( ConversationID: chatID, Role: "assistant", Message: fullText, + // 口径B:助手消息记录“本轮请求总 token”。 + TokensConsumed: requestTotalTokens, }); saveErr != nil { pushErrNonBlocking(errChan, saveErr) } @@ -223,13 +243,16 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin outChan := make(chan string, 8) errChan := make(chan error, 1) + // 0. 初始化“请求级 token 统计器”,用于聚合本次请求所有模型开销。 + requestCtx, _ := withRequestTokenMeter(ctx) + // 1) 规范会话 ID,选择模型。 chatID = normalizeConversationID(chatID) selectedModel, resolvedModelName := s.pickChatModel(modelName) // 2) 确保会话存在(优先缓存,必要时回源 DB 并创建)。 // 2.1 先查 Redis 会话标记,命中则可跳过 DB 存在性校验。 - result, err := s.agentCache.GetConversationStatus(ctx, chatID) + result, err := s.agentCache.GetConversationStatus(requestCtx, chatID) if err != nil { errChan <- err close(outChan) @@ -238,7 +261,7 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin } if !result { // 2.2 缓存未命中时回源 DB:确认会话是否存在。 - innerResult, ifErr := s.repo.IfChatExists(ctx, userID, chatID) + innerResult, ifErr := s.repo.IfChatExists(requestCtx, userID, chatID) if ifErr != nil { errChan <- ifErr close(outChan) @@ -255,7 +278,7 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin } } // 2.4 补写 Redis 会话标记,优化下次访问。 - if err = s.agentCache.SetConversationStatus(ctx, chatID); err != nil { + if err = s.agentCache.SetConversationStatus(requestCtx, chatID); err != nil { log.Printf("设置会话状态缓存失败 chat=%s: %v", chatID, err) } } @@ -269,11 +292,11 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin defer close(outChan) // 3.1 先走轻量路由,拿到统一 action。 - routing := s.decideActionRouting(ctx, selectedModel, userMessage) + routing := s.decideActionRouting(requestCtx, selectedModel, userMessage) // 3.2 chat:直接走普通聊天主链路。 if routing.Action == route.ActionChat { - s.runNormalChatFlow(ctx, selectedModel, resolvedModelName, userMessage, ifThinking, userID, chatID, traceID, requestStart, outChan, errChan) + s.runNormalChatFlow(requestCtx, selectedModel, resolvedModelName, userMessage, ifThinking, userID, chatID, traceID, requestStart, outChan, errChan) return } @@ -284,7 +307,7 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin // 3.4 quick_note_create:执行随口记 graph。 if routing.Action == route.ActionQuickNoteCreate { quickHandled, quickState, quickErr := s.tryHandleQuickNoteWithGraph( - ctx, + requestCtx, selectedModel, userMessage, userID, @@ -301,14 +324,15 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin if quickHandled { // 3.4.1 随口记处理成功:组织最终回复并按 OpenAI 兼容格式输出。 progress.Emit("quick_note.reply.polishing", "正在结合你的话题润色回复。") - quickReply := buildQuickNoteFinalReply(ctx, selectedModel, userMessage, quickState) + quickReply := buildQuickNoteFinalReply(requestCtx, selectedModel, userMessage, quickState) if emitErr := emitSingleAssistantCompletion(outChan, resolvedModelName, quickReply); emitErr != nil { pushErrNonBlocking(errChan, emitErr) return } // 3.4.2 对随口记回复执行统一后置持久化(Redis + outbox/DB)。 - s.persistChatAfterReply(ctx, userID, chatID, userMessage, quickReply, errChan) + requestTotalTokens := snapshotRequestTokenMeter(requestCtx).TotalTokens + s.persistChatAfterReply(requestCtx, userID, chatID, userMessage, quickReply, 0, requestTotalTokens, errChan) // 3.4.3 随口记链路同样异步生成会话标题(仅首次写入)。 s.ensureConversationTitleAsync(userID, chatID) return @@ -316,18 +340,18 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin // 3.4.4 路由误判或 graph 判定非随口记时,回落普通聊天,保证“能聊”。 progress.Emit("quick_note.fallback", "当前输入不是随口记请求,切换到普通对话。") - s.runNormalChatFlow(ctx, selectedModel, resolvedModelName, userMessage, ifThinking, userID, chatID, traceID, requestStart, outChan, errChan) + s.runNormalChatFlow(requestCtx, selectedModel, resolvedModelName, userMessage, ifThinking, userID, chatID, traceID, requestStart, outChan, errChan) return } // 3.5 task_query:执行任务查询 tool-calling。 if routing.Action == route.ActionTaskQuery { - reply, queryErr := s.runTaskQueryFlow(ctx, selectedModel, userMessage, userID, progress.Emit) + reply, queryErr := s.runTaskQueryFlow(requestCtx, selectedModel, userMessage, userID, progress.Emit) if queryErr != nil { // 3.5.1 任务查询失败时回退普通聊天,避免请求直接中断。 log.Printf("任务查询 tool-calling 执行失败,回退普通聊天 trace_id=%s chat_id=%s err=%v", traceID, chatID, queryErr) progress.Emit("task_query.fallback", "任务查询暂不可用,先切回普通对话。") - s.runNormalChatFlow(ctx, selectedModel, resolvedModelName, userMessage, ifThinking, userID, chatID, traceID, requestStart, outChan, errChan) + s.runNormalChatFlow(requestCtx, selectedModel, resolvedModelName, userMessage, ifThinking, userID, chatID, traceID, requestStart, outChan, errChan) return } @@ -336,13 +360,14 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin pushErrNonBlocking(errChan, emitErr) return } - s.persistChatAfterReply(ctx, userID, chatID, userMessage, reply, errChan) + requestTotalTokens := snapshotRequestTokenMeter(requestCtx).TotalTokens + s.persistChatAfterReply(requestCtx, userID, chatID, userMessage, reply, 0, requestTotalTokens, errChan) s.ensureConversationTitleAsync(userID, chatID) return } // 3.6 未知 action 兜底:走普通聊天,保证可用性。 - s.runNormalChatFlow(ctx, selectedModel, resolvedModelName, userMessage, ifThinking, userID, chatID, traceID, requestStart, outChan, errChan) + s.runNormalChatFlow(requestCtx, selectedModel, resolvedModelName, userMessage, ifThinking, userID, chatID, traceID, requestStart, outChan, errChan) }() return outChan, errChan diff --git a/backend/service/agentsvc/agent_meta.go b/backend/service/agentsvc/agent_meta.go index 95f2760..6ba60b6 100644 --- a/backend/service/agentsvc/agent_meta.go +++ b/backend/service/agentsvc/agent_meta.go @@ -10,6 +10,7 @@ import ( "github.com/LoveLosita/smartflow/backend/model" "github.com/LoveLosita/smartflow/backend/respond" + eventsvc "github.com/LoveLosita/smartflow/backend/service/events" "github.com/cloudwego/eino-ext/components/model/ark" einoModel "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/schema" @@ -32,6 +33,9 @@ const ( conversationListDefaultPageSize = 20 // conversationListMaxPageSize 是会话列表单页上限,避免超大分页压垮数据库。 conversationListMaxPageSize = 100 + // conversationTitleTokenAdjustReason 是“标题异步生成 token 账本调整”原因码。 + // 用于日志和后续审计归因。 + conversationTitleTokenAdjustReason = "conversation_title_async" ) const conversationTitlePrompt = `你是 SmartFlow 的会话标题生成器。 @@ -190,7 +194,7 @@ func (s *AgentService) ensureConversationTitleAsync(userID int, chatID string) { } // 4. 调用模型生成标题,并做格式清洗。 - generated, err := s.generateConversationTitle(ctx, history) + generated, titleTokens, err := s.generateConversationTitle(ctx, history) if err != nil { log.Printf("异步生成会话标题失败(模型生成失败) chat=%s err=%v", chatID, err) return @@ -199,6 +203,28 @@ func (s *AgentService) ensureConversationTitleAsync(userID int, chatID string) { return } + // 4.1 标题生成成功后,把本次异步模型 token 记账: + // 4.1.1 启用 outbox 时走 adjust 事件,异步可靠入账; + // 4.1.2 未启用 outbox 时走同步兜底,直接更新账本。 + if titleTokens > 0 { + if s.eventPublisher != nil { + publishErr := eventsvc.PublishChatTokenUsageAdjustRequested(ctx, s.eventPublisher, model.ChatTokenUsageAdjustPayload{ + UserID: userID, + ConversationID: chatID, + TokensDelta: titleTokens, + Reason: conversationTitleTokenAdjustReason, + TriggeredAt: time.Now(), + }) + if publishErr != nil { + log.Printf("异步标题 token 记账事件发布失败 chat=%s tokens=%d err=%v", chatID, titleTokens, publishErr) + } + } else { + if adjustErr := s.repo.AdjustTokenUsage(ctx, userID, chatID, titleTokens); adjustErr != nil { + log.Printf("异步标题 token 同步记账失败 chat=%s tokens=%d err=%v", chatID, titleTokens, adjustErr) + } + } + } + // 5. 只在标题仍为空时写入,保证并发幂等。 if err = s.repo.UpdateConversationTitleIfEmpty(ctx, userID, chatID, generated); err != nil { log.Printf("异步生成会话标题失败(写库失败) chat=%s err=%v", chatID, err) @@ -207,17 +233,17 @@ func (s *AgentService) ensureConversationTitleAsync(userID int, chatID string) { } // generateConversationTitle 使用聊天模型从近期历史生成标题。 -func (s *AgentService) generateConversationTitle(ctx context.Context, history []*schema.Message) (string, error) { +func (s *AgentService) generateConversationTitle(ctx context.Context, history []*schema.Message) (string, int, error) { modelInst := s.pickTitleModel() if modelInst == nil { - return "", fmt.Errorf("标题生成模型未初始化") + return "", 0, fmt.Errorf("标题生成模型未初始化") } // 1. 只取最近 N 条,降低 token 并聚焦当前会话主题。 trimmed := tailMessages(history, conversationTitleHistoryLimit) prompt := buildConversationTitleUserPrompt(trimmed) if strings.TrimSpace(prompt) == "" { - return "", fmt.Errorf("缺少可用历史内容") + return "", 0, fmt.Errorf("缺少可用历史内容") } messages := []*schema.Message{ @@ -232,12 +258,22 @@ func (s *AgentService) generateConversationTitle(ctx context.Context, history [] einoModel.WithMaxTokens(40), ) if err != nil { - return "", err + return "", 0, err } if resp == nil { - return "", fmt.Errorf("标题生成模型返回为空") + return "", 0, fmt.Errorf("标题生成模型返回为空") } - return normalizeConversationTitle(resp.Content), nil + + // 2.1 标题链路的 token 从模型响应 usage 中提取;缺失则按 0 处理,不影响主流程。 + titleTokens := 0 + if resp.ResponseMeta != nil && resp.ResponseMeta.Usage != nil { + titleTokens = normalizeUsageTotal( + resp.ResponseMeta.Usage.TotalTokens, + resp.ResponseMeta.Usage.PromptTokens, + resp.ResponseMeta.Usage.CompletionTokens, + ) + } + return normalizeConversationTitle(resp.Content), titleTokens, nil } // pickTitleModel 选择用于标题生成的模型。 diff --git a/backend/service/agentsvc/agent_quick_note.go b/backend/service/agentsvc/agent_quick_note.go index ad41b16..ddcc08f 100644 --- a/backend/service/agentsvc/agent_quick_note.go +++ b/backend/service/agentsvc/agent_quick_note.go @@ -345,6 +345,8 @@ func (s *AgentService) persistChatAfterReply( chatID string, userMessage string, assistantReply string, + userTokens int, + assistantTokens int, errChan chan error, ) { // 1. 先把用户消息写入 Redis,保证会话上下文“马上可见”。 @@ -358,6 +360,7 @@ func (s *AgentService) persistChatAfterReply( ConversationID: chatID, Role: "user", Message: userMessage, + TokensConsumed: userTokens, }); err != nil { pushErrNonBlocking(errChan, err) return @@ -374,6 +377,7 @@ func (s *AgentService) persistChatAfterReply( ConversationID: chatID, Role: "assistant", Message: assistantReply, + TokensConsumed: assistantTokens, }); err != nil { pushErrNonBlocking(errChan, err) } diff --git a/backend/service/agentsvc/token_meter.go b/backend/service/agentsvc/token_meter.go new file mode 100644 index 0000000..bb3add1 --- /dev/null +++ b/backend/service/agentsvc/token_meter.go @@ -0,0 +1,145 @@ +package agentsvc + +import ( + "context" + "sync" + + einoCallbacks "github.com/cloudwego/eino/callbacks" + einoModel "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" + templatecb "github.com/cloudwego/eino/utils/callbacks" +) + +type requestTokenMeterCtxKey struct{} + +// RequestTokenMeter 是“单次请求级”的 token 统计容器。 +// +// 设计目标: +// 1. 聚合本次请求内所有模型调用 token(路由/图节点/流式主对话); +// 2. 线程安全,允许在同一请求内被多个链路节点并发累加; +// 3. 最终由服务层一次性读取快照并写入持久化。 +type RequestTokenMeter struct { + mu sync.Mutex + + promptTokens int + completionTokens int + totalTokens int +} + +// RequestTokenMeterSnapshot 是 RequestTokenMeter 的只读快照。 +type RequestTokenMeterSnapshot struct { + PromptTokens int + CompletionTokens int + TotalTokens int +} + +var registerTokenMeterCallbackOnce sync.Once + +// ensureTokenMeterCallbackRegistered 注册一次全局 ChatModel callback。 +// +// 说明: +// 1. callback 只负责“采集并累加 token”,不做业务决策; +// 2. 仅当 ctx 里存在 RequestTokenMeter 时才会生效; +// 3. 采用 once,避免在测试/多次构造服务时重复注册。 +func ensureTokenMeterCallbackRegistered() { + registerTokenMeterCallbackOnce.Do(func() { + handler := templatecb.NewHandlerHelper(). + ChatModel(&templatecb.ModelCallbackHandler{ + OnEnd: func(ctx context.Context, _ *einoCallbacks.RunInfo, output *einoModel.CallbackOutput) context.Context { + if output == nil || output.TokenUsage == nil { + return ctx + } + addModelUsageIntoRequest(ctx, output.TokenUsage) + return ctx + }, + }). + Handler() + einoCallbacks.AppendGlobalHandlers(handler) + }) +} + +// withRequestTokenMeter 创建并挂载“请求级 token 统计器”。 +func withRequestTokenMeter(ctx context.Context) (context.Context, *RequestTokenMeter) { + meter := &RequestTokenMeter{} + return context.WithValue(ctx, requestTokenMeterCtxKey{}, meter), meter +} + +// getRequestTokenMeter 读取请求上下文中的 token 统计器。 +func getRequestTokenMeter(ctx context.Context) *RequestTokenMeter { + if ctx == nil { + return nil + } + meter, _ := ctx.Value(requestTokenMeterCtxKey{}).(*RequestTokenMeter) + return meter +} + +// addSchemaUsageIntoRequest 把 schema usage 累加到请求级统计器。 +func addSchemaUsageIntoRequest(ctx context.Context, usage *schema.TokenUsage) { + if usage == nil { + return + } + addTokenUsageValues(ctx, usage.PromptTokens, usage.CompletionTokens, normalizeUsageTotal(usage.TotalTokens, usage.PromptTokens, usage.CompletionTokens)) +} + +// addModelUsageIntoRequest 把 Eino model callback usage 累加到请求级统计器。 +func addModelUsageIntoRequest(ctx context.Context, usage *einoModel.TokenUsage) { + if usage == nil { + return + } + addTokenUsageValues(ctx, usage.PromptTokens, usage.CompletionTokens, normalizeUsageTotal(usage.TotalTokens, usage.PromptTokens, usage.CompletionTokens)) +} + +// addTokenUsageValues 统一累加 token 数值。 +func addTokenUsageValues(ctx context.Context, promptTokens, completionTokens, totalTokens int) { + meter := getRequestTokenMeter(ctx) + if meter == nil { + return + } + + if promptTokens < 0 { + promptTokens = 0 + } + if completionTokens < 0 { + completionTokens = 0 + } + if totalTokens < 0 { + totalTokens = 0 + } + + meter.mu.Lock() + defer meter.mu.Unlock() + meter.promptTokens += promptTokens + meter.completionTokens += completionTokens + meter.totalTokens += totalTokens +} + +// snapshotRequestTokenMeter 获取请求级 token 统计快照。 +func snapshotRequestTokenMeter(ctx context.Context) RequestTokenMeterSnapshot { + meter := getRequestTokenMeter(ctx) + if meter == nil { + return RequestTokenMeterSnapshot{} + } + meter.mu.Lock() + defer meter.mu.Unlock() + return RequestTokenMeterSnapshot{ + PromptTokens: meter.promptTokens, + CompletionTokens: meter.completionTokens, + TotalTokens: meter.totalTokens, + } +} + +// normalizeUsageTotal 统一 total token 口径。 +// +// 规则: +// 1. 模型返回 total>0 时优先使用 total; +// 2. total 缺失时使用 prompt+completion 回退。 +func normalizeUsageTotal(totalTokens, promptTokens, completionTokens int) int { + if totalTokens > 0 { + return totalTokens + } + sum := promptTokens + completionTokens + if sum < 0 { + return 0 + } + return sum +} diff --git a/backend/service/events/chat_history_persist.go b/backend/service/events/chat_history_persist.go index 60a0ac2..2964923 100644 --- a/backend/service/events/chat_history_persist.go +++ b/backend/service/events/chat_history_persist.go @@ -68,6 +68,7 @@ func RegisterChatHistoryPersistHandler( payload.ConversationID, payload.Role, payload.Message, + payload.TokensConsumed, ) }) } diff --git a/backend/service/events/chat_token_usage_adjust.go b/backend/service/events/chat_token_usage_adjust.go new file mode 100644 index 0000000..673ec17 --- /dev/null +++ b/backend/service/events/chat_token_usage_adjust.go @@ -0,0 +1,101 @@ +package events + +import ( + "context" + "encoding/json" + "errors" + "strconv" + "time" + + "github.com/LoveLosita/smartflow/backend/dao" + kafkabus "github.com/LoveLosita/smartflow/backend/infra/kafka" + outboxinfra "github.com/LoveLosita/smartflow/backend/infra/outbox" + "github.com/LoveLosita/smartflow/backend/model" + "gorm.io/gorm" +) + +const ( + // EventTypeChatTokenUsageAdjustRequested 是“会话 token 账本增量调整”事件类型。 + // + // 命名约束: + // 1. 仅表达业务语义,不泄露 outbox/kafka 实现细节; + // 2. 作为稳定路由键长期保留,后续演进优先通过 event_version。 + EventTypeChatTokenUsageAdjustRequested = "chat.token.usage.adjust.requested" +) + +// RegisterChatTokenUsageAdjustHandler 注册“会话 token 账本增量调整”消费者。 +// +// 职责边界: +// 1. 只处理 token 调整事件,不处理聊天正文落库; +// 2. 通过 outbox 统一消费事务入口,保证“业务成功 + consumed 推进”原子一致; +// 3. 非法载荷直接标记 dead,避免无意义重试。 +func RegisterChatTokenUsageAdjustHandler( + bus *outboxinfra.EventBus, + outboxRepo *outboxinfra.Repository, + repoManager *dao.RepoManager, +) error { + if bus == nil { + return errors.New("event bus is nil") + } + if outboxRepo == nil { + return errors.New("outbox repository is nil") + } + if repoManager == nil { + return errors.New("repo manager is nil") + } + + handler := func(ctx context.Context, envelope kafkabus.Envelope) error { + var payload model.ChatTokenUsageAdjustPayload + if unmarshalErr := json.Unmarshal(envelope.Payload, &payload); unmarshalErr != nil { + _ = outboxRepo.MarkDead(ctx, envelope.OutboxID, "解析会话 token 调整载荷失败: "+unmarshalErr.Error()) + return nil + } + + if payload.UserID <= 0 || payload.TokensDelta <= 0 || payload.ConversationID == "" { + _ = outboxRepo.MarkDead(ctx, envelope.OutboxID, "会话 token 调整载荷无效: user_id/conversation_id/tokens_delta 非法") + return nil + } + + return outboxRepo.ConsumeAndMarkConsumed(ctx, envelope.OutboxID, func(tx *gorm.DB) error { + txM := repoManager.WithTx(tx) + return txM.Agent.AdjustTokenUsageInTx(ctx, payload.UserID, payload.ConversationID, payload.TokensDelta) + }) + } + + return bus.RegisterEventHandler(EventTypeChatTokenUsageAdjustRequested, handler) +} + +// PublishChatTokenUsageAdjustRequested 发布“会话 token 账本增量调整”事件。 +// +// 说明: +// 1. 只保证“写入 outbox 成功”,不等待消费完成; +// 2. 业务层只传 DTO,不关心 outbox/kafka 协议细节。 +func PublishChatTokenUsageAdjustRequested( + ctx context.Context, + publisher outboxinfra.EventPublisher, + payload model.ChatTokenUsageAdjustPayload, +) error { + if publisher == nil { + return errors.New("event publisher is nil") + } + if payload.UserID <= 0 { + return errors.New("invalid user_id") + } + if payload.TokensDelta <= 0 { + return errors.New("invalid tokens_delta") + } + if payload.ConversationID == "" { + return errors.New("invalid conversation_id") + } + if payload.TriggeredAt.IsZero() { + payload.TriggeredAt = time.Now() + } + + return publisher.Publish(ctx, outboxinfra.PublishRequest{ + EventType: EventTypeChatTokenUsageAdjustRequested, + EventVersion: outboxinfra.DefaultEventVersion, + MessageKey: payload.ConversationID, + AggregateID: strconv.Itoa(payload.UserID) + ":" + payload.ConversationID, + Payload: payload, + }) +} diff --git a/backend/service/task.go b/backend/service/task.go index e50e18b..5986e9b 100644 --- a/backend/service/task.go +++ b/backend/service/task.go @@ -14,6 +14,7 @@ import ( "github.com/LoveLosita/smartflow/backend/respond" eventsvc "github.com/LoveLosita/smartflow/backend/service/events" "github.com/go-redis/redis/v8" + "gorm.io/gorm" ) const ( @@ -72,6 +73,46 @@ func (ts *TaskService) AddTask(ctx context.Context, req *model.UserAddTaskReques return response, nil } +// CompleteTask 将用户指定任务标记为“已完成”。 +// +// 职责边界: +// 1. 负责入参校验与业务错误映射; +// 2. 负责调用 DAO 执行状态更新; +// 3. 不负责幂等键校验(幂等由中间件处理); +// 4. 不负责缓存删除细节(缓存删除由 GORM cache_deleter 回调触发)。 +func (ts *TaskService) CompleteTask(ctx context.Context, req *model.UserCompleteTaskRequest, userID int) (*model.UserCompleteTaskResponse, error) { + // 1. 参数兜底:请求体为空、非法 user 或非法 task_id 直接返回业务错误。 + if req == nil || userID <= 0 || req.TaskID <= 0 { + return nil, respond.WrongTaskID + } + + // 2. 调用 DAO 执行“查询 + 必要时更新”。 + updatedTask, alreadyCompleted, err := ts.dao.CompleteTaskByID(ctx, userID, req.TaskID) + if err != nil { + // 2.1 任务不存在或不属于当前用户时,统一映射为 WrongTaskID。 + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, respond.WrongTaskID + } + // 2.2 其余数据库异常向上透传,交由统一错误处理器返回 500。 + return nil, err + } + if updatedTask == nil { + // 3. 极端防御:DAO 不应返回 nil,若发生则视为内部异常。 + return nil, errors.New("complete task succeeded but task is nil") + } + + // 4. 构造响应: + // 4.1 already_completed=true 表示本次命中幂等,不影响最终成功状态; + // 4.2 is_completed 始终为 true,便于前端直接刷新状态。 + resp := &model.UserCompleteTaskResponse{ + TaskID: updatedTask.ID, + IsCompleted: true, + AlreadyCompleted: alreadyCompleted, + Status: "completed", + } + return resp, nil +} + // GetUserTasks 获取用户任务列表(含“读时紧急性派生”与“异步平移触发”)。 // // 核心流程(步骤化):