diff --git a/backend/cmd/start.go b/backend/cmd/start.go index 64e4c57..8830479 100644 --- a/backend/cmd/start.go +++ b/backend/cmd/start.go @@ -44,7 +44,7 @@ func Start() { rdb := inits.InitRedis() //工具包 limiter := pkg.NewRateLimiter(rdb) - //初始化agent + //初始化eino aiHub, err := inits.InitEino() if err != nil { log.Fatalf("Failed to initialize Eino: %v", err) @@ -53,6 +53,7 @@ func Start() { //dao 层 cacheRepo := dao.NewCacheDAO(rdb) + agentCacheRepo := dao.NewAgentCache(rdb) _ = db.Use(middleware.NewGormCachePlugin(cacheRepo)) // 注册 GORM 插件 userRepo := dao.NewUserDAO(db) taskRepo := dao.NewTaskDAO(db) @@ -67,7 +68,7 @@ func Start() { courseService := service.NewCourseService(courseRepo, scheduleRepo) taskClassService := service.NewTaskClassService(taskClassRepo, cacheRepo, scheduleRepo, manager) scheduleService := service.NewScheduleService(scheduleRepo, userRepo, taskClassRepo, manager, cacheRepo) - agentService := service.NewAgentService(aiHub, agentRepo) + agentService := service.NewAgentService(aiHub, agentRepo, agentCacheRepo) //api 层 userApi := api.NewUserHandler(userService) taskApi := api.NewTaskHandler(taskSv) diff --git a/backend/dao/agent-cache.go b/backend/dao/agent-cache.go new file mode 100644 index 0000000..7c4a630 --- /dev/null +++ b/backend/dao/agent-cache.go @@ -0,0 +1,146 @@ +package dao + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/cloudwego/eino/schema" + "github.com/go-redis/redis/v8" +) + +type AgentCache struct { + client *redis.Client + // 默认滑动窗口大小,比如 20 条消息 + windowSize int + // 缓存过期时间 + expiration time.Duration +} + +func NewAgentCache(client *redis.Client) *AgentCache { + return &AgentCache{ + client: client, + windowSize: 20, // 后续更新:根据 Token 消耗灵活调整 + expiration: 1 * time.Hour, // 保持一小时的热记忆 + } +} + +func (m *AgentCache) PushMessage(ctx context.Context, sessionID string, msg *schema.Message) error { + key := fmt.Sprintf("smartflow:history:%s", sessionID) + + // 1. 序列化 Eino 消息 + data, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("marshal message failed: %w", err) + } + + // 2. 利用 Pipeline 保证原子操作 + pipe := m.client.Pipeline() + + // 往左侧推入最新消息 (LIFO 逻辑) + pipe.LPush(ctx, key, data) + + // 核心:强制修剪,只保留最新的 windowSize 条 + // 0 是最新的一条,windowSize-1 是最后一条 + pipe.LTrim(ctx, key, 0, int64(m.windowSize-1)) + + // 刷新过期时间 + pipe.Expire(ctx, key, m.expiration) + + _, err = pipe.Exec(ctx) + return err +} + +func (m *AgentCache) GetHistory(ctx context.Context, sessionID string) ([]*schema.Message, error) { + key := fmt.Sprintf("smartflow:history:%s", sessionID) + + // 获取所有缓存的消息 + vals, err := m.client.LRange(ctx, key, 0, -1).Result() + if err != nil { + return nil, err + } + + // 如果 Redis 为空,这里返回 nil 触发后续的 MySQL 捞取逻辑 + if len(vals) == 0 { + return nil, nil + } + + messages := make([]*schema.Message, len(vals)) + for i, val := range vals { + var msg schema.Message + if err := json.Unmarshal([]byte(val), &msg); err != nil { + return nil, err + } + + // 关键逻辑:反转顺序 + // LRANGE 返回顺序:[MsgN, MsgN-1, ... Msg1] + // 我们需要的顺序:[Msg1, ... MsgN-1, MsgN] + messages[len(vals)-1-i] = &msg + } + + return messages, nil +} + +// BackfillHistory 用于缓存失效时,从数据库加载完数据后一次性回填 Redis +func (m *AgentCache) BackfillHistory(ctx context.Context, sessionID string, messages []*schema.Message) error { + if len(messages) == 0 { + return nil + } + + key := fmt.Sprintf("smartflow:history:%s", sessionID) + + // 1. 将所有 Eino 消息序列化为 []interface{} 供 redis 批量写入 + values := make([]interface{}, len(messages)) + for i, msg := range messages { + data, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("marshal failed at index %d: %w", err) + } + values[i] = data + } + + // 2. 执行原子回填 + pipe := m.client.Pipeline() + + // 先清理旧 Key(防止数据重复或残留) + pipe.Del(ctx, key) + + // 批量写入:按照 [最旧 -> 最新] 的顺序 LPUSH + // 结果在 Redis 中:[最新, ..., 最旧] (符合我们 GetHistory 的反转逻辑) + pipe.LPush(ctx, key, values...) + + // 依然要进行修剪,确保不超过窗口大小 + pipe.LTrim(ctx, key, 0, int64(m.windowSize-1)) + + // 设置过期时间 + pipe.Expire(ctx, key, m.expiration) + + _, err := pipe.Exec(ctx) + return err +} + +func (m *AgentCache) ClearHistory(ctx context.Context, sessionID string) error { + key := fmt.Sprintf("smartflow:history:%s", sessionID) + return m.client.Del(ctx, key).Err() +} + +func (m *AgentCache) GetConversationStatus(ctx context.Context, sessionID string) (bool, error) { + key := fmt.Sprintf("smartflow:conversation_status:%s", sessionID) + n, err := m.client.Exists(ctx, key).Result() + if err != nil { + return false, err + } + return n == 1, nil +} + +func (m *AgentCache) SetConversationStatus(ctx context.Context, sessionID string) error { + key := fmt.Sprintf("smartflow:conversation_status:%s", sessionID) + // 仅用于“存在性”标记:只有不存在时才写入,避免重复写 + return m.client.SetNX(ctx, key, 1, m.expiration).Err() +} + +func (m *AgentCache) DeleteConversationStatus(ctx context.Context, sessionID string) error { + key := fmt.Sprintf("smartflow:conversation_status:%s", sessionID) + return m.client.Del(ctx, key).Err() +} diff --git a/backend/service/agent.go b/backend/service/agent.go index 2b23d53..b54d527 100644 --- a/backend/service/agent.go +++ b/backend/service/agent.go @@ -2,6 +2,7 @@ package service import ( "context" + "log" "github.com/LoveLosita/smartflow/backend/agent" "github.com/LoveLosita/smartflow/backend/conv" @@ -11,14 +12,16 @@ import ( ) type AgentService struct { - AIHub *inits.AIHub - repo *dao.AgentDAO + AIHub *inits.AIHub + repo *dao.AgentDAO + agentCache *dao.AgentCache } -func NewAgentService(aiHub *inits.AIHub, repo *dao.AgentDAO) *AgentService { +func NewAgentService(aiHub *inits.AIHub, repo *dao.AgentDAO, agentRedis *dao.AgentCache) *AgentService { return &AgentService{ - AIHub: aiHub, - repo: repo, + AIHub: aiHub, + repo: repo, + agentCache: agentRedis, } } @@ -27,16 +30,47 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin outChan := make(chan string, 5) errChan := make(chan error, 1) //2. 先确保这个会话存在(如果不存在就创建一个新的) - result, err := s.repo.IfChatExists(ctx, userID, chatID) + //先看看缓存里面有没有这个会话 + result, err := s.agentCache.GetConversationStatus(ctx, chatID) if err != nil { errChan <- err close(outChan) close(errChan) return outChan, errChan } + //如果缓存里面没有,就去查库 + if !result { + innerResult, err := s.repo.IfChatExists(ctx, userID, chatID) + if err != nil { + errChan <- err + close(outChan) + close(errChan) + return outChan, errChan + } + if !innerResult { + //如果会话不存在,先创建一个新的会话 + _, err := s.repo.CreateNewChat(userID, chatID) + if err != nil { + errChan <- err + close(outChan) + close(errChan) + return outChan, errChan + } + } + } + //能走到这里,要么缓存里有这个会话,要么数据库里有这个会话了 + //4. 提取出历史消息,构建上下文 + //先尝试从缓存里拿历史消息 var chatHistory []*schema.Message - if result { - //4. 提取出历史消息,构建上下文 + chatHistory, err = s.agentCache.GetHistory(ctx, chatID) + if err != nil { + errChan <- err + close(outChan) + close(errChan) + return outChan, errChan + } + //如果缓存里没有历史消息,就从数据库里拿 + if chatHistory == nil { //先从数据库拿到历史消息 histories, err := s.repo.GetUserChatHistories(ctx, userID, 20, chatID) if err != nil { @@ -47,9 +81,8 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin } //再转换成 Eino 的消息格式 chatHistory = conv.ToEinoMessages(histories) - } else { - //如果会话不存在,先创建一个新的会话 - _, err := s.repo.CreateNewChat(userID, chatID) + //把历史消息放到缓存里,方便下次直接拿 + err = s.agentCache.BackfillHistory(ctx, chatID, chatHistory) if err != nil { errChan <- err close(outChan) @@ -57,14 +90,16 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin return outChan, errChan } } - //3. 将用户消息落库 - err = s.repo.SaveChatHistory(ctx, userID, chatID, "user", userMessage) - if err != nil { - errChan <- err - close(outChan) - close(errChan) - return outChan, errChan - } + //3. 将用户消息异步落缓存和库 + go func() { + //这里先不管落库成功与否了,毕竟不想因为落库失败而影响用户的聊天体验 + _ = s.agentCache.PushMessage(ctx, chatID, &schema.Message{ + Role: "user", + Content: userMessage, + }) + _ = s.repo.SaveChatHistory(ctx, userID, chatID, "user", userMessage) + }() + //5. 启动一个 goroutine 来处理聊天逻辑 go func() { defer close(outChan) // 确保在函数结束时关闭通道 @@ -74,15 +109,18 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin errChan <- err return } - err = s.repo.SaveChatHistory(ctx, userID, chatID, "assistant", fullText) - if err != nil { - errChan <- err - return - } + //4. 将 AI 的回复异步落缓存和库 + go func() { + _ = s.agentCache.PushMessage(ctx, chatID, &schema.Message{ + Role: "assistant", + Content: fullText, + }) + err = s.repo.SaveChatHistory(context.Background(), userID, chatID, "assistant", fullText) + if err != nil { + log.Printf("Failed to save chat history to database: %v", err) + return + } + }() }() return outChan, errChan } - -func (s *AgentService) CreateNewChat(userID int, chatID string) (int64, error) { - return s.repo.CreateNewChat(userID, chatID) -}