diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..c1c1abf --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,5 @@ +# AGENTS.md + +## 协作偏好(逐条追加) + +1. 默认语言规则:所有注释、接口文案、说明、评审反馈均使用中文。 diff --git a/backend/agent/graph.go b/backend/agent/graph.go index 4181e7b..62ec996 100644 --- a/backend/agent/graph.go +++ b/backend/agent/graph.go @@ -13,7 +13,7 @@ import ( arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" ) -// StreamResponse 为 OpenAI/DeepSeek 兼容的流式 chunk 结构 +// StreamResponse 为 OpenAI/DeepSeek 兼容的流式 chunk 结构。 type StreamResponse struct { ID string `json:"id"` Object string `json:"object"` @@ -34,7 +34,7 @@ type StreamDelta struct { ReasoningContent string `json:"reasoning_content,omitempty"` } -// ToOpenAIStream 将单个 Eino chunk 转为 OpenAI 兼容 JSON +// ToOpenAIStream 将单个 Eino chunk 转为 OpenAI 兼容 JSON。 func ToOpenAIStream(chunk *schema.Message, requestID, modelName string, created int64, includeRole bool) (string, error) { delta := StreamDelta{} if includeRole { @@ -67,7 +67,7 @@ func ToOpenAIStream(chunk *schema.Message, requestID, modelName string, created return string(jsonBytes), nil } -// ToOpenAIFinishStream 生成结束 chunk(finish_reason=stop) +// ToOpenAIFinishStream 生成结束 chunk(finish_reason=stop)。 func ToOpenAIFinishStream(requestID, modelName string, created int64) (string, error) { stop := "stop" dto := StreamResponse{ diff --git a/backend/api/agent.go b/backend/api/agent.go index a6e8dab..917d404 100644 --- a/backend/api/agent.go +++ b/backend/api/agent.go @@ -17,7 +17,7 @@ type AgentHandler struct { svc *service.AgentService } -// NewAgentHandler 组装 AgentHandler +// NewAgentHandler 组装 AgentHandler。 func NewAgentHandler(svc *service.AgentService) *AgentHandler { return &AgentHandler{ svc: svc, diff --git a/backend/dao/agent-cache.go b/backend/dao/agent-cache.go index c30e723..f37f6f9 100644 --- a/backend/dao/agent-cache.go +++ b/backend/dao/agent-cache.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strconv" "time" "github.com/cloudwego/eino/schema" @@ -12,22 +13,86 @@ import ( type AgentCache struct { client *redis.Client - // 默认滑动窗口大小,比如 20 条消息 + // 默认窗口大小(会被会话级动态窗口覆盖) windowSize int // 缓存过期时间 expiration time.Duration } +const ( + minHistoryWindowSize = 16 + maxHistoryWindowSize = 4096 +) + func NewAgentCache(client *redis.Client) *AgentCache { return &AgentCache{ client: client, - windowSize: 20, // 后续更新:根据 Token 消耗灵活调整 - expiration: 1 * time.Hour, // 保持一小时的热记忆 + windowSize: 128, + expiration: 1 * time.Hour, } } +func (m *AgentCache) historyKey(sessionID string) string { + return fmt.Sprintf("smartflow:history:%s", sessionID) +} + +func (m *AgentCache) historyWindowKey(sessionID string) string { + return fmt.Sprintf("smartflow:history_window:%s", sessionID) +} + +func (m *AgentCache) normalizeWindowSize(size int) int { + if size < minHistoryWindowSize { + return minHistoryWindowSize + } + if size > maxHistoryWindowSize { + return maxHistoryWindowSize + } + return size +} + +func (m *AgentCache) getSessionWindowSize(ctx context.Context, sessionID string) (int, error) { + windowKey := m.historyWindowKey(sessionID) + val, err := m.client.Get(ctx, windowKey).Result() + if err == redis.Nil { + return m.windowSize, nil + } + if err != nil { + return 0, err + } + size, convErr := strconv.Atoi(val) + if convErr != nil { + return m.windowSize, nil + } + return m.normalizeWindowSize(size), nil +} + +// SetSessionWindowSize 设置会话级窗口上限。 +func (m *AgentCache) SetSessionWindowSize(ctx context.Context, sessionID string, size int) error { + normalized := m.normalizeWindowSize(size) + windowKey := m.historyWindowKey(sessionID) + return m.client.Set(ctx, windowKey, normalized, m.expiration).Err() +} + +// EnforceHistoryWindow 按当前会话窗口强制修剪历史队列。 +func (m *AgentCache) EnforceHistoryWindow(ctx context.Context, sessionID string) error { + size, err := m.getSessionWindowSize(ctx, sessionID) + if err != nil { + return err + } + key := m.historyKey(sessionID) + pipe := m.client.Pipeline() + pipe.LTrim(ctx, key, 0, int64(size-1)) + pipe.Expire(ctx, key, m.expiration) + _, err = pipe.Exec(ctx) + return err +} + func (m *AgentCache) PushMessage(ctx context.Context, sessionID string, msg *schema.Message) error { - key := fmt.Sprintf("smartflow:history:%s", sessionID) + key := m.historyKey(sessionID) + size, err := m.getSessionWindowSize(ctx, sessionID) + if err != nil { + return err + } // 1. 序列化 Eino 消息 data, err := json.Marshal(msg) @@ -37,15 +102,10 @@ func (m *AgentCache) PushMessage(ctx context.Context, sessionID string, msg *sch // 2. 利用 Pipeline 保证原子操作 pipe := m.client.Pipeline() - - // 往左侧推入最新消息 (LIFO 逻辑) + // 往左侧推入最新消息(LIFO) pipe.LPush(ctx, key, data) - - // 核心:强制修剪,只保留最新的 windowSize 条 - // 0 是最新的一条,windowSize-1 是最后一条 - pipe.LTrim(ctx, key, 0, int64(m.windowSize-1)) - - // 刷新过期时间 + // 只保留最新 size 条 + pipe.LTrim(ctx, key, 0, int64(size-1)) pipe.Expire(ctx, key, m.expiration) _, err = pipe.Exec(ctx) @@ -53,15 +113,12 @@ func (m *AgentCache) PushMessage(ctx context.Context, sessionID string, msg *sch } func (m *AgentCache) GetHistory(ctx context.Context, sessionID string) ([]*schema.Message, error) { - key := fmt.Sprintf("smartflow:history:%s", sessionID) + key := m.historyKey(sessionID) - // 获取所有缓存的消息 vals, err := m.client.LRange(ctx, key, 0, -1).Result() if err != nil { return nil, err } - - // 如果 Redis 为空,这里返回 nil 触发后续的 MySQL 捞取逻辑 if len(vals) == 0 { return nil, nil } @@ -72,25 +129,25 @@ func (m *AgentCache) GetHistory(ctx context.Context, sessionID string) ([]*schem if err := json.Unmarshal([]byte(val), &msg); err != nil { return nil, err } - - // 关键逻辑:反转顺序 - // LRANGE 返回顺序:[MsgN, MsgN-1, ... Msg1] - // 我们需要的顺序:[Msg1, ... MsgN-1, MsgN] + // LRANGE 返回 [最新..最旧],这里反转成 [最旧..最新] messages[len(vals)-1-i] = &msg } return messages, nil } -// BackfillHistory 用于缓存失效时,从数据库加载完数据后一次性回填 Redis +// BackfillHistory 在缓存失效时,把历史消息一次性回填到 Redis。 func (m *AgentCache) BackfillHistory(ctx context.Context, sessionID string, messages []*schema.Message) error { - if len(messages) == 0 { - return nil + key := m.historyKey(sessionID) + size, err := m.getSessionWindowSize(ctx, sessionID) + if err != nil { + return err } - key := fmt.Sprintf("smartflow:history:%s", sessionID) + if len(messages) == 0 { + return m.client.Del(ctx, key).Err() + } - // 1. 将所有 Eino 消息序列化为 []interface{} 供 redis 批量写入 values := make([]interface{}, len(messages)) for i, msg := range messages { data, err := json.Marshal(msg) @@ -100,29 +157,21 @@ func (m *AgentCache) BackfillHistory(ctx context.Context, sessionID string, mess values[i] = data } - // 2. 执行原子回填 pipe := m.client.Pipeline() - - // 先清理旧 Key(防止数据重复或残留) pipe.Del(ctx, key) - - // 批量写入:按照 [最旧 -> 最新] 的顺序 LPUSH - // 结果在 Redis 中:[最新, ..., 最旧] (符合我们 GetHistory 的反转逻辑) + // 输入是 [最旧..最新],LPUSH 后变成 [最新..最旧] pipe.LPush(ctx, key, values...) - - // 依然要进行修剪,确保不超过窗口大小 - pipe.LTrim(ctx, key, 0, int64(m.windowSize-1)) - - // 设置过期时间 + pipe.LTrim(ctx, key, 0, int64(size-1)) pipe.Expire(ctx, key, m.expiration) - _, err := pipe.Exec(ctx) + _, err = pipe.Exec(ctx) return err } func (m *AgentCache) ClearHistory(ctx context.Context, sessionID string) error { - key := fmt.Sprintf("smartflow:history:%s", sessionID) - return m.client.Del(ctx, key).Err() + historyKey := m.historyKey(sessionID) + windowKey := m.historyWindowKey(sessionID) + return m.client.Del(ctx, historyKey, windowKey).Err() } func (m *AgentCache) GetConversationStatus(ctx context.Context, sessionID string) (bool, error) { diff --git a/backend/pkg/token_budget.go b/backend/pkg/token_budget.go new file mode 100644 index 0000000..cb28fb2 --- /dev/null +++ b/backend/pkg/token_budget.go @@ -0,0 +1,146 @@ +package pkg + +import ( + "math" + "strings" + "unicode" + + "github.com/cloudwego/eino/schema" +) + +const ( + // Worker 模型最大输入上下文(用户提供) + WorkerMaxInputTokens = 224000 + // 给模型输出和协议开销预留的冗余 token + ContextReserveTokens = 28000 + + // 缓存未命中时,从数据库拉取的历史消息上限 + DefaultHistoryFetchLimit = 1200 + + // Redis 会话窗口上下限与缓冲 + SessionWindowMin = 32 + SessionWindowMax = 4096 + SessionWindowBuffer = 2 +) + +// MaxContextTokensByModel 返回指定模型的最大上下文 token。 +func MaxContextTokensByModel(modelName string) int { + switch strings.ToLower(strings.TrimSpace(modelName)) { + case "worker", "strategist": + return WorkerMaxInputTokens + default: + return WorkerMaxInputTokens + } +} + +// HistoryFetchLimitByModel 返回缓存未命中时的历史拉取条数。 +func HistoryFetchLimitByModel(_ string) int { + return DefaultHistoryFetchLimit +} + +// HistoryTokenBudgetByModel 计算“历史上下文”可使用的 token 预算。 +func HistoryTokenBudgetByModel(modelName, systemPrompt, userInput string) int { + maxTokens := MaxContextTokensByModel(modelName) + baseTokens := EstimateTextTokens(systemPrompt) + EstimateTextTokens(userInput) + 64 + budget := maxTokens - ContextReserveTokens - baseTokens + if budget < 0 { + return 0 + } + return budget +} + +// EstimateTextTokens 粗略估算文本 token: +// - CJK 字符约 1:1 +// - ASCII 字符约 4:1 +// - 其他字符约 2:1 +func EstimateTextTokens(text string) int { + if strings.TrimSpace(text) == "" { + return 0 + } + + var cjkCount, asciiCount, otherCount int + for _, r := range text { + switch { + case unicode.IsSpace(r): + continue + case r <= unicode.MaxASCII: + asciiCount++ + case isCJK(r): + cjkCount++ + default: + otherCount++ + } + } + + tokens := cjkCount + int(math.Ceil(float64(asciiCount)/4.0)) + int(math.Ceil(float64(otherCount)/2.0)) + if tokens <= 0 { + return 1 + } + return tokens +} + +// EstimateMessageTokens 估算单条消息 token(包含固定协议开销)。 +func EstimateMessageTokens(msg *schema.Message) int { + if msg == nil { + return 0 + } + const messageOverhead = 6 + return messageOverhead + EstimateTextTokens(msg.Content) + EstimateTextTokens(msg.ReasoningContent) +} + +// EstimateHistoryTokens 估算历史消息总 token。 +func EstimateHistoryTokens(history []*schema.Message) int { + total := 0 + for _, msg := range history { + total += EstimateMessageTokens(msg) + } + return total +} + +// TrimHistoryByTokenBudget 从最旧消息开始裁剪,直到历史 token 不超过预算。 +// 返回值:裁剪后历史、裁剪前 token、裁剪后 token、裁掉条数。 +func TrimHistoryByTokenBudget(history []*schema.Message, historyBudget int) ([]*schema.Message, int, int, int) { + if len(history) == 0 { + return history, 0, 0, 0 + } + + totalBefore := EstimateHistoryTokens(history) + if historyBudget <= 0 { + return []*schema.Message{}, totalBefore, 0, len(history) + } + if totalBefore <= historyBudget { + return history, totalBefore, totalBefore, 0 + } + + tokenPerMsg := make([]int, len(history)) + total := 0 + for i, msg := range history { + t := EstimateMessageTokens(msg) + tokenPerMsg[i] = t + total += t + } + + drop := 0 + for total > historyBudget && drop < len(history) { + total -= tokenPerMsg[drop] + drop++ + } + + return history[drop:], totalBefore, total, drop +} + +// CalcSessionWindowSize 根据裁剪后消息条数计算 Redis 队列窗口大小。 +func CalcSessionWindowSize(trimmedHistoryLen int) int { + size := trimmedHistoryLen + SessionWindowBuffer + if size < SessionWindowMin { + size = SessionWindowMin + } + if size > SessionWindowMax { + size = SessionWindowMax + } + return size +} + +func isCJK(r rune) bool { + return unicode.Is(unicode.Han, r) || unicode.Is(unicode.Hiragana, r) || unicode.Is(unicode.Katakana, r) || unicode.Is(unicode.Hangul, r) +} diff --git a/backend/service/agent.go b/backend/service/agent.go index 34a9673..78b6317 100644 --- a/backend/service/agent.go +++ b/backend/service/agent.go @@ -9,6 +9,7 @@ import ( "github.com/LoveLosita/smartflow/backend/conv" "github.com/LoveLosita/smartflow/backend/dao" "github.com/LoveLosita/smartflow/backend/inits" + "github.com/LoveLosita/smartflow/backend/pkg" "github.com/cloudwego/eino-ext/components/model/ark" "github.com/cloudwego/eino/schema" "github.com/google/uuid" @@ -82,7 +83,7 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin } } - // 4) 构建历史上下文 + // 4) 组装历史上下文(先读缓存,缓存未命中再读数据库) chatHistory, err := s.agentCache.GetHistory(ctx, chatID) if err != nil { errChan <- err @@ -90,8 +91,11 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin close(errChan) return outChan, errChan } + + cacheMiss := false if chatHistory == nil { - histories, err := s.repo.GetUserChatHistories(ctx, userID, 20, chatID) + cacheMiss = true + histories, err := s.repo.GetUserChatHistories(ctx, userID, pkg.HistoryFetchLimitByModel(resolvedModelName), chatID) if err != nil { errChan <- err close(outChan) @@ -99,7 +103,30 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin return outChan, errChan } chatHistory = conv.ToEinoMessages(histories) - if err = s.agentCache.BackfillHistory(ctx, chatID, chatHistory); err != nil { + } + + // 5) 按 token 预算裁剪历史:从最旧消息开始持续弹出,直到满足预算 + historyBudget := pkg.HistoryTokenBudgetByModel(resolvedModelName, agent.SystemPrompt, userMessage) + trimmedHistory, totalHistoryTokens, keptHistoryTokens, droppedCount := pkg.TrimHistoryByTokenBudget(chatHistory, historyBudget) + chatHistory = trimmedHistory + + // 6) 根据最新裁剪结果动态调整 Redis 会话窗口 + targetWindow := pkg.CalcSessionWindowSize(len(chatHistory)) + if err := s.agentCache.SetSessionWindowSize(ctx, chatID, targetWindow); err != nil { + log.Printf("failed to set history window for %s: %v", chatID, err) + } + if err := s.agentCache.EnforceHistoryWindow(ctx, chatID); err != nil { + log.Printf("failed to enforce history window for %s: %v", chatID, err) + } + + if droppedCount > 0 { + log.Printf("agent history trimmed: chat=%s total_tokens=%d kept_tokens=%d dropped=%d budget=%d target_window=%d", + chatID, totalHistoryTokens, keptHistoryTokens, droppedCount, historyBudget, targetWindow) + } + + // 缓存未命中时,把“裁剪后的历史”回填进缓存 + if cacheMiss { + if err := s.agentCache.BackfillHistory(ctx, chatID, chatHistory); err != nil { errChan <- err close(outChan) close(errChan) @@ -107,7 +134,7 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin } } - // 5) 异步落用户消息 + // 7) 异步落用户消息(先写缓存再写库) go func() { bg := context.Background() _ = s.agentCache.PushMessage(bg, chatID, &schema.Message{ @@ -117,7 +144,7 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin _ = s.repo.SaveChatHistory(bg, userID, chatID, "user", userMessage) }() - // 6) 流式输出模型回复 + // 8) 启动流式聊天 go func() { defer close(outChan) @@ -127,7 +154,7 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin return } - // 7) 异步落助手消息 + // 9) 异步落助手回复 go func() { bg := context.Background() _ = s.agentCache.PushMessage(bg, chatID, &schema.Message{