package dao

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strconv"
	"time"

	"github.com/cloudwego/eino/schema"
	"github.com/go-redis/redis/v8"
)

// AgentCache stores per-session agent state in Redis: the message history
// list, the per-session history window size, a conversation-status marker and
// the compaction summary cache.
type AgentCache struct {
	client *redis.Client
	// windowSize is the default history window; a session-level window stored
	// in Redis overrides it.
	windowSize int
	// expiration is the TTL applied to every key this cache writes.
	expiration time.Duration
}

// Bounds applied to any session-level history window size.
const (
	minHistoryWindowSize = 16
	maxHistoryWindowSize = 4096
)

// compactionCacheEntry is the JSON payload stored under a compaction key.
type compactionCacheEntry struct {
	Summary   string `json:"summary"`
	Watermark int    `json:"watermark"`
}

// NewAgentCache returns an AgentCache backed by client, with a default window
// of 128 messages and a one-hour key expiration.
func NewAgentCache(client *redis.Client) *AgentCache {
	return &AgentCache{
		client:     client,
		windowSize: 128,
		expiration: 1 * time.Hour,
	}
}

// historyKey returns the Redis key of the message list for sessionID.
func (m *AgentCache) historyKey(sessionID string) string {
	return fmt.Sprintf("smartflow:history:%s", sessionID)
}

// historyWindowKey returns the Redis key holding the session-level window size.
func (m *AgentCache) historyWindowKey(sessionID string) string {
	return fmt.Sprintf("smartflow:history_window:%s", sessionID)
}

// conversationStatusKey returns the Redis key of the conversation-status marker.
func (m *AgentCache) conversationStatusKey(sessionID string) string {
	return fmt.Sprintf("smartflow:conversation_status:%s", sessionID)
}

// normalizeWindowSize clamps size into
// [minHistoryWindowSize, maxHistoryWindowSize].
func (m *AgentCache) normalizeWindowSize(size int) int {
	if size < minHistoryWindowSize {
		return minHistoryWindowSize
	}
	if size > maxHistoryWindowSize {
		return maxHistoryWindowSize
	}
	return size
}

// getSessionWindowSize returns the window size for sessionID. It falls back to
// the default window when no session-level value is stored or when the stored
// value is not an integer; other Redis errors are returned to the caller.
func (m *AgentCache) getSessionWindowSize(ctx context.Context, sessionID string) (int, error) {
	val, err := m.client.Get(ctx, m.historyWindowKey(sessionID)).Result()
	if errors.Is(err, redis.Nil) {
		return m.windowSize, nil
	}
	if err != nil {
		return 0, err
	}
	size, convErr := strconv.Atoi(val)
	if convErr != nil {
		// A malformed stored value is treated like a missing one.
		return m.windowSize, nil
	}
	return m.normalizeWindowSize(size), nil
}

// queueTrimAndExpire queues the commands that cap the list at key to its first
// size entries and refresh its TTL.
func (m *AgentCache) queueTrimAndExpire(ctx context.Context, pipe redis.Pipeliner, key string, size int) {
	pipe.LTrim(ctx, key, 0, int64(size-1))
	pipe.Expire(ctx, key, m.expiration)
}

// SetSessionWindowSize stores the session-level window limit, clamped to the
// allowed range, with the cache expiration as its TTL.
func (m *AgentCache) SetSessionWindowSize(ctx context.Context, sessionID string, size int) error {
	normalized := m.normalizeWindowSize(size)
	return m.client.Set(ctx, m.historyWindowKey(sessionID), normalized, m.expiration).Err()
}

// EnforceHistoryWindow trims the history list to the session's current window
// and refreshes the list's TTL.
func (m *AgentCache) EnforceHistoryWindow(ctx context.Context, sessionID string) error {
	size, err := m.getSessionWindowSize(ctx, sessionID)
	if err != nil {
		return err
	}

	// MULTI/EXEC keeps trim and TTL refresh together; a plain Pipeline only
	// batches the round trip.
	pipe := m.client.TxPipeline()
	m.queueTrimAndExpire(ctx, pipe, m.historyKey(sessionID), size)
	_, err = pipe.Exec(ctx)
	return err
}

// PushMessage serializes msg as JSON, pushes it to the head of the session's
// history list, trims the list to the session window and refreshes its TTL.
func (m *AgentCache) PushMessage(ctx context.Context, sessionID string, msg *schema.Message) error {
	size, err := m.getSessionWindowSize(ctx, sessionID)
	if err != nil {
		return err
	}

	// 1. Serialize the Eino message.
	data, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshal message failed: %w", err)
	}

	// 2. Run push + trim + TTL refresh in one MULTI/EXEC transaction so they
	// are applied atomically; a plain Pipeline gives no such guarantee.
	key := m.historyKey(sessionID)
	pipe := m.client.TxPipeline()
	pipe.LPush(ctx, key, data)
	m.queueTrimAndExpire(ctx, pipe, key, size)
	_, err = pipe.Exec(ctx)
	return err
}

// GetHistory returns the cached messages for sessionID ordered oldest first.
// It returns (nil, nil) when the list is empty or missing.
func (m *AgentCache) GetHistory(ctx context.Context, sessionID string) ([]*schema.Message, error) {
	vals, err := m.client.LRange(ctx, m.historyKey(sessionID), 0, -1).Result()
	if err != nil {
		return nil, err
	}
	if len(vals) == 0 {
		return nil, nil
	}

	messages := make([]*schema.Message, len(vals))
	last := len(vals) - 1
	for i, val := range vals {
		msg := new(schema.Message)
		if err := json.Unmarshal([]byte(val), msg); err != nil {
			return nil, fmt.Errorf("unmarshal history entry %d: %w", i, err)
		}
		// LRANGE yields [newest ... oldest]; store reversed so the result is
		// [oldest ... newest].
		messages[last-i] = msg
	}
	return messages, nil
}

// BackfillHistory replaces the session's cached history with messages in one
// shot, for use after the cache has expired. Messages are pushed in slice
// order with LPUSH, so the last element ends up at the head of the list. An
// empty slice just deletes the history key.
func (m *AgentCache) BackfillHistory(ctx context.Context, sessionID string, messages []*schema.Message) error {
	key := m.historyKey(sessionID)
	if len(messages) == 0 {
		// Nothing to trim, so the window size lookup is not needed.
		return m.client.Del(ctx, key).Err()
	}

	size, err := m.getSessionWindowSize(ctx, sessionID)
	if err != nil {
		return err
	}

	values := make([]any, len(messages))
	for i, msg := range messages {
		data, err := json.Marshal(msg)
		if err != nil {
			return fmt.Errorf("marshal failed at index %d: %w", i, err)
		}
		values[i] = data
	}

	// Delete + refill + trim + TTL must not interleave with a concurrent push,
	// so run them inside a single MULTI/EXEC transaction.
	pipe := m.client.TxPipeline()
	pipe.Del(ctx, key)
	pipe.LPush(ctx, key, values...)
	m.queueTrimAndExpire(ctx, pipe, key, size)
	_, err = pipe.Exec(ctx)
	return err
}

// ClearHistory deletes both the history list and the session-level window
// size for sessionID.
func (m *AgentCache) ClearHistory(ctx context.Context, sessionID string) error {
	return m.client.Del(ctx, m.historyKey(sessionID), m.historyWindowKey(sessionID)).Err()
}

// GetConversationStatus reports whether the conversation-status marker exists
// for sessionID.
func (m *AgentCache) GetConversationStatus(ctx context.Context, sessionID string) (bool, error) {
	n, err := m.client.Exists(ctx, m.conversationStatusKey(sessionID)).Result()
	if err != nil {
		return false, err
	}
	return n == 1, nil
}

// SetConversationStatus creates the conversation-status marker for sessionID.
// The key is a pure existence flag: SETNX writes it only when absent, so an
// existing marker (and its TTL) is left untouched.
func (m *AgentCache) SetConversationStatus(ctx context.Context, sessionID string) error {
	return m.client.SetNX(ctx, m.conversationStatusKey(sessionID), 1, m.expiration).Err()
}

// DeleteConversationStatus removes the conversation-status marker for
// sessionID.
func (m *AgentCache) DeleteConversationStatus(ctx context.Context, sessionID string) error {
	return m.client.Del(ctx, m.conversationStatusKey(sessionID)).Err()
}

// ---- Compaction cache ----

// compactionKey returns the Redis key of the compaction cache for chatID.
func (m *AgentCache) compactionKey(chatID string) string {
	return fmt.Sprintf("smartflow:compaction:%s", chatID)
}

// SaveCompactionCache stores the compaction summary and its watermark for
// chatID as JSON, with the cache expiration as TTL.
func (m *AgentCache) SaveCompactionCache(ctx context.Context, chatID string, summary string, watermark int) error {
	data, err := json.Marshal(compactionCacheEntry{Summary: summary, Watermark: watermark})
	if err != nil {
		return fmt.Errorf("marshal compaction cache: %w", err)
	}
	return m.client.Set(ctx, m.compactionKey(chatID), data, m.expiration).Err()
}

// LoadCompactionCache reads the compaction summary cache for chatID. ok is
// false with a nil error when the key is missing or its payload cannot be
// decoded; any other Redis error is returned.
func (m *AgentCache) LoadCompactionCache(ctx context.Context, chatID string) (summary string, watermark int, ok bool, err error) {
	val, err := m.client.Get(ctx, m.compactionKey(chatID)).Result()
	if errors.Is(err, redis.Nil) {
		return "", 0, false, nil
	}
	if err != nil {
		return "", 0, false, err
	}

	var entry compactionCacheEntry
	if jsonErr := json.Unmarshal([]byte(val), &entry); jsonErr != nil {
		// An undecodable payload is reported as a cache miss, not an error.
		return "", 0, false, nil
	}
	return entry.Summary, entry.Watermark, true, nil
}