Files
smartmate/backend/api/agent.go
Losita e06284d0b0 Version: 0.7.6.dev.260325
后端:
- ♻️ 将 `taskquery` 模块迁移至 `agent2`,并完成与 `agent2` 业务链路及整体结构的正式接入

前端:
- 🧱 已完成基础框架搭建,并完成了登录、注册、主页等页面并对接了对应接口;但整体功能实现仍在完善中
2026-03-25 00:49:16 +08:00

262 lines
8.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/respond"
"github.com/LoveLosita/smartflow/backend/service"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"gorm.io/gorm"
)
type AgentHandler struct {
svc *service.AgentService
}
// NewAgentHandler 组装 AgentHandler。
func NewAgentHandler(svc *service.AgentService) *AgentHandler {
return &AgentHandler{
svc: svc,
}
}
func writeSSEData(w io.Writer, payload string) error {
_, err := io.WriteString(w, "data: "+payload+"\n\n")
return err
}
func (api *AgentHandler) ChatAgent(c *gin.Context) {
// 1) 设置 SSE 响应头
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
// 2) 解析请求体
var req model.UserSendMessageRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, respond.WrongParamType)
return
}
// 3) 规范化会话 ID
conversationID := strings.TrimSpace(req.ConversationID)
if conversationID == "" {
conversationID = uuid.NewString()
}
c.Writer.Header().Set("X-Conversation-ID", conversationID)
userID := c.GetInt("user_id")
outChan, errChan := api.svc.AgentChat(c.Request.Context(), req.Message, req.Thinking, req.Model, userID, conversationID, req.Extra)
// 4) 转发 SSE 流
c.Stream(func(w io.Writer) bool {
select {
case err, ok := <-errChan:
if ok && err != nil {
// 4.1 统一 SSE 错误体:
// 4.1.1 默认按内部错误输出 message/type
// 4.1.2 若是 respond.Response含业务码额外透传 code便于前端识别 5xxxx 等自定义错误。
errorBody := map[string]any{
"message": err.Error(),
"type": "server_error",
}
var respErr respond.Response
if errors.As(err, &respErr) {
errorBody["code"] = respErr.Status
}
errPayload, _ := json.Marshal(map[string]any{
"error": errorBody,
})
_ = writeSSEData(w, string(errPayload))
_ = writeSSEData(w, "[DONE]")
}
return false
case msg, ok := <-outChan:
if !ok {
return false
}
if err := writeSSEData(w, msg); err != nil {
return false
}
return true
case <-c.Request.Context().Done():
return false
}
})
}
// GetConversationMeta 返回单个会话的元信息(标题、消息数、最近消息时间等)。
// 设计说明:
// 1) 该接口用于配合 SSE 聊天链路:标题异步生成后,前端可通过 conversation_id 拉取;
// 2) 不依赖 SSE header 动态更新避免“header 必须首包前写入”的协议限制;
// 3) 会话不存在时返回 400避免前端把无效会话当成系统错误。
func (api *AgentHandler) GetConversationMeta(c *gin.Context) {
// 1. 读取 query 参数并做基础校验。
conversationID := strings.TrimSpace(c.Query("conversation_id"))
if conversationID == "" {
c.JSON(http.StatusBadRequest, respond.MissingParam)
return
}
// 2. 统一透传 user_id避免越权读取他人会话。
userID := c.GetInt("user_id")
// 3. 设置短超时,避免该查询接口被慢查询长时间占用。
ctx, cancel := context.WithTimeout(c.Request.Context(), 1*time.Second)
defer cancel()
// 4. 调 service 查询会话元信息。
meta, err := api.svc.GetConversationMeta(ctx, userID, conversationID)
if err != nil {
// 会话不存在按参数错误处理,返回 400 给前端更直观。
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusBadRequest, respond.WrongParamType)
return
}
respond.DealWithError(c, err)
return
}
// 5. 返回统一响应结构。
c.JSON(http.StatusOK, respond.RespWithData(respond.Ok, meta))
}
// GetConversationList 返回当前登录用户的会话列表(分页)。
//
// 设计说明:
// 1) 接口只返回“列表元信息”,不返回消息正文,避免列表接口过重;
// 2) page/page_size 为可选参数,缺省值由 service 层统一兜底;
// 3) status 可选,支持 active/archived非法值直接返回 400。
func (api *AgentHandler) GetConversationList(c *gin.Context) {
// 1. 从 JWT 上下文读取 user_id保证只查“当前用户自己的会话”。
userID := c.GetInt("user_id")
// 2. 解析分页参数(可选):
// 2.1 参数不存在时保持 0让 service 使用默认值;
// 2.2 参数存在但格式非法时直接返回 400避免脏参数下沉。
page := 0
if rawPage := strings.TrimSpace(c.Query("page")); rawPage != "" {
parsedPage, err := strconv.Atoi(rawPage)
if err != nil {
c.JSON(http.StatusBadRequest, respond.WrongParamType)
return
}
page = parsedPage
}
pageSize := 0
if rawPageSize := strings.TrimSpace(c.Query("page_size")); rawPageSize != "" {
parsedPageSize, err := strconv.Atoi(rawPageSize)
if err != nil {
c.JSON(http.StatusBadRequest, respond.WrongParamType)
return
}
pageSize = parsedPageSize
}
// 2.3 limit 是 page_size 的懒加载别名:
// 2.3.1 前端若显式传 limit则以 limit 为准,避免前端再做字段转换;
// 2.3.2 若 limit 非法同样直接返回 400避免把脏参数下沉到 service
// 2.3.3 若未传 limit则继续沿用历史 page_size 行为,保持老前端兼容。
if rawLimit := strings.TrimSpace(c.Query("limit")); rawLimit != "" {
parsedLimit, err := strconv.Atoi(rawLimit)
if err != nil {
c.JSON(http.StatusBadRequest, respond.WrongParamType)
return
}
pageSize = parsedLimit
}
// 3. status 过滤器可选,最终合法性由 service 层统一校验。
status := strings.TrimSpace(c.Query("status"))
// 4. 读接口设置短超时,避免慢查询占用连接。
ctx, cancel := context.WithTimeout(c.Request.Context(), 1*time.Second)
defer cancel()
// 5. 调 service 查询并返回统一响应结构。
resp, err := api.svc.GetConversationList(ctx, userID, page, pageSize, status)
if err != nil {
respond.DealWithError(c, err)
return
}
c.JSON(http.StatusOK, respond.RespWithData(respond.Ok, resp))
}
// GetConversationHistory 返回指定会话的聊天历史记录。
//
// 设计说明:
// 1) 该接口只读历史,不负责改写 Redis/DB 中的会话状态;
// 2) 读取顺序复用现有服务层能力:先校验归属,再查 Redis未命中再回源 DB
// 3) 会话不存在时统一返回 400避免前端把无效会话误判成系统故障。
func (api *AgentHandler) GetConversationHistory(c *gin.Context) {
// 1. 参数校验conversation_id 必填。
conversationID := strings.TrimSpace(c.Query("conversation_id"))
if conversationID == "" {
c.JSON(http.StatusBadRequest, respond.MissingParam)
return
}
// 2. 从鉴权上下文取当前用户 ID确保查询范围只落在“本人会话”内。
userID := c.GetInt("user_id")
// 3. 设置短超时,避免缓存抖动或慢查询长期占用连接。
ctx, cancel := context.WithTimeout(c.Request.Context(), 2*time.Second)
defer cancel()
// 4. 调 service 查询聊天历史。
history, err := api.svc.GetConversationHistory(ctx, userID, conversationID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusBadRequest, respond.WrongParamType)
return
}
respond.DealWithError(c, err)
return
}
// 5. 返回统一响应结构。
c.JSON(http.StatusOK, respond.RespWithData(respond.Ok, history))
}
// GetSchedulePlanPreview 返回“指定会话”的排程结构化预览。
//
// 设计说明:
// 1) 该接口只读 Redis 预览快照,不修改聊天主链路协议;
// 2) 按 conversation_id + user_id 读取,避免跨用户越权访问;
// 3) 预览受 TTL 影响,若不存在会返回业务错误码。
func (api *AgentHandler) GetSchedulePlanPreview(c *gin.Context) {
// 1. 参数校验conversation_id 必填。
conversationID := strings.TrimSpace(c.Query("conversation_id"))
if conversationID == "" {
c.JSON(http.StatusBadRequest, respond.MissingParam)
return
}
// 2. 从鉴权上下文取当前用户 ID保证查询范围只在“本人会话”内。
userID := c.GetInt("user_id")
// 3. 设置短超时,防止缓存抖动时占用连接过久。
ctx, cancel := context.WithTimeout(c.Request.Context(), 1*time.Second)
defer cancel()
// 4. 调 service 查询并返回统一响应结构。
preview, err := api.svc.GetSchedulePlanPreview(ctx, userID, conversationID)
if err != nil {
respond.DealWithError(c, err)
return
}
c.JSON(http.StatusOK, respond.RespWithData(respond.Ok, preview))
}