package api import ( "context" "encoding/json" "errors" "io" "net/http" "strconv" "strings" "time" "github.com/LoveLosita/smartflow/backend/model" "github.com/LoveLosita/smartflow/backend/respond" "github.com/LoveLosita/smartflow/backend/service" "github.com/gin-gonic/gin" "github.com/google/uuid" "gorm.io/gorm" ) type AgentHandler struct { svc *service.AgentService } // NewAgentHandler 组装 AgentHandler。 func NewAgentHandler(svc *service.AgentService) *AgentHandler { return &AgentHandler{ svc: svc, } } func writeSSEData(w io.Writer, payload string) error { _, err := io.WriteString(w, "data: "+payload+"\n\n") return err } func (api *AgentHandler) ChatAgent(c *gin.Context) { // 1) 设置 SSE 响应头 c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Connection", "keep-alive") c.Writer.Header().Set("Transfer-Encoding", "chunked") c.Writer.Header().Set("X-Accel-Buffering", "no") // 2) 解析请求体 var req model.UserSendMessageRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, respond.WrongParamType) return } // 3) 规范化会话 ID conversationID := strings.TrimSpace(req.ConversationID) if conversationID == "" { conversationID = uuid.NewString() } c.Writer.Header().Set("X-Conversation-ID", conversationID) userID := c.GetInt("user_id") outChan, errChan := api.svc.AgentChat(c.Request.Context(), req.Message, req.Thinking, req.Model, userID, conversationID, req.Extra) // 4) 转发 SSE 流 c.Stream(func(w io.Writer) bool { select { case err, ok := <-errChan: if ok && err != nil { errPayload, _ := json.Marshal(map[string]any{ "error": map[string]any{ "message": err.Error(), "type": "server_error", }, }) _ = writeSSEData(w, string(errPayload)) _ = writeSSEData(w, "[DONE]") } return false case msg, ok := <-outChan: if !ok { return false } if err := writeSSEData(w, msg); err != nil { return false } return true case <-c.Request.Context().Done(): return false } }) } // GetConversationMeta 返回单个会话的元信息(标题、消息数、最近消息时间等)。 // 设计说明: // 1) 该接口用于配合 SSE 聊天链路:标题异步生成后,前端可通过 conversation_id 拉取; // 2) 不依赖 SSE header 动态更新,避免“header 必须首包前写入”的协议限制; // 3) 会话不存在时返回 400,避免前端把无效会话当成系统错误。 func (api *AgentHandler) GetConversationMeta(c *gin.Context) { // 1. 读取 query 参数并做基础校验。 conversationID := strings.TrimSpace(c.Query("conversation_id")) if conversationID == "" { c.JSON(http.StatusBadRequest, respond.MissingParam) return } // 2. 统一透传 user_id,避免越权读取他人会话。 userID := c.GetInt("user_id") // 3. 设置短超时,避免该查询接口被慢查询长时间占用。 ctx, cancel := context.WithTimeout(c.Request.Context(), 1*time.Second) defer cancel() // 4. 调 service 查询会话元信息。 meta, err := api.svc.GetConversationMeta(ctx, userID, conversationID) if err != nil { // 会话不存在按参数错误处理,返回 400 给前端更直观。 if errors.Is(err, gorm.ErrRecordNotFound) { c.JSON(http.StatusBadRequest, respond.WrongParamType) return } respond.DealWithError(c, err) return } // 5. 返回统一响应结构。 c.JSON(http.StatusOK, respond.RespWithData(respond.Ok, meta)) } // GetConversationList 返回当前登录用户的会话列表(分页)。 // // 设计说明: // 1) 接口只返回“列表元信息”,不返回消息正文,避免列表接口过重; // 2) page/page_size 为可选参数,缺省值由 service 层统一兜底; // 3) status 可选,支持 active/archived,非法值直接返回 400。 func (api *AgentHandler) GetConversationList(c *gin.Context) { // 1. 从 JWT 上下文读取 user_id,保证只查“当前用户自己的会话”。 userID := c.GetInt("user_id") // 2. 解析分页参数(可选): // 2.1 参数不存在时保持 0,让 service 使用默认值; // 2.2 参数存在但格式非法时直接返回 400,避免脏参数下沉。 page := 0 if rawPage := strings.TrimSpace(c.Query("page")); rawPage != "" { parsedPage, err := strconv.Atoi(rawPage) if err != nil { c.JSON(http.StatusBadRequest, respond.WrongParamType) return } page = parsedPage } pageSize := 0 if rawPageSize := strings.TrimSpace(c.Query("page_size")); rawPageSize != "" { parsedPageSize, err := strconv.Atoi(rawPageSize) if err != nil { c.JSON(http.StatusBadRequest, respond.WrongParamType) return } pageSize = parsedPageSize } // 3. status 过滤器可选,最终合法性由 service 层统一校验。 status := strings.TrimSpace(c.Query("status")) // 4. 读接口设置短超时,避免慢查询占用连接。 ctx, cancel := context.WithTimeout(c.Request.Context(), 1*time.Second) defer cancel() // 5. 调 service 查询并返回统一响应结构。 resp, err := api.svc.GetConversationList(ctx, userID, page, pageSize, status) if err != nil { respond.DealWithError(c, err) return } c.JSON(http.StatusOK, respond.RespWithData(respond.Ok, resp)) }