Version: 0.9.77.dev.260505

后端:
1.阶段 6 CP4/CP5 目录收口与共享边界纯化
- 将 backend 根目录收口为 services、client、gateway、cmd、shared 五个一级目录
- 收拢 bootstrap、inits、infra/kafka、infra/outbox、conv、respond、pkg、middleware,移除根目录旧实现与空目录
- 将 utils 下沉到 services/userauth/internal/auth,将 logic 下沉到 services/schedule/core/planning
- 将迁移期 runtime 桥接实现统一收拢到 services/runtime/{conv,dao,eventsvc,model},删除 shared/legacy 与未再被 import 的旧 service 实现
- 将 gateway/shared/respond 收口为 HTTP/Gin 错误写回适配,shared/respond 仅保留共享错误语义与状态映射
- 将 HTTP IdempotencyMiddleware 与 RateLimitMiddleware 收口到 gateway/middleware
- 将 GormCachePlugin 下沉到 shared/infra/gormcache,将共享 RateLimiter 下沉到 shared/infra/ratelimit,将 agent token budget 下沉到 services/agent/shared
- 删除 InitEino 兼容壳,收缩 cmd/internal/coreinit 仅保留旧组合壳残留域初始化语义
- 更新微服务迁移计划与桌面 checklist,补齐 CP4/CP5 当前切流点、目录终态与验证结果
- 完成 go test ./...、git diff --check 与最终真实 smoke;health、register/login、task/create+get、schedule/today、task-class/list、memory/items、agent chat/meta/timeline/context-stats 全部 200,SSE 合并结果为 CP5_OK 且 [DONE] 只有 1 个
This commit is contained in:
Losita
2026-05-05 23:25:07 +08:00
parent 2a96f4c6f9
commit 3b6fca44a6
226 changed files with 731 additions and 3497 deletions

View File

@@ -6,10 +6,10 @@ import (
"fmt"
"time"
"github.com/LoveLosita/smartflow/backend/conv"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/ports"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/trigger"
"github.com/LoveLosita/smartflow/backend/services/runtime/conv"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
)

View File

@@ -6,7 +6,7 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
const (

View File

@@ -4,7 +4,7 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
// IsPreviewExpired 判断 preview 是否已经超过确认有效期。

View File

@@ -8,8 +8,8 @@ import (
"strconv"
"strings"
"github.com/LoveLosita/smartflow/backend/conv"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/conv"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)

View File

@@ -8,11 +8,11 @@ import (
"log"
"time"
"github.com/LoveLosita/smartflow/backend/dao"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/ports"
activesvc "github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/service"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/trigger"
"github.com/LoveLosita/smartflow/backend/services/runtime/dao"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
const (

View File

@@ -9,11 +9,11 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/candidate"
schedulercontext "github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/context"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/observe"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/ports"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
func candidateDTO(item candidate.Candidate) CandidateDTO {

View File

@@ -7,10 +7,10 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/candidate"
schedulercontext "github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/context"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/observe"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/google/uuid"
"gorm.io/gorm"
)

View File

@@ -6,11 +6,11 @@ import (
"errors"
"time"
"github.com/LoveLosita/smartflow/backend/dao"
"github.com/LoveLosita/smartflow/backend/model"
activeapply "github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/apply"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/applyadapter"
activepreview "github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/preview"
"github.com/LoveLosita/smartflow/backend/services/runtime/dao"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
)

View File

@@ -8,10 +8,10 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/dao"
"github.com/LoveLosita/smartflow/backend/model"
activepreview "github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/preview"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/selection"
"github.com/LoveLosita/smartflow/backend/services/runtime/dao"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/google/uuid"
"gorm.io/gorm"
)

View File

@@ -8,11 +8,11 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/dao"
outboxinfra "github.com/LoveLosita/smartflow/backend/infra/outbox"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/trigger"
"github.com/LoveLosita/smartflow/backend/services/runtime/dao"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
sharedevents "github.com/LoveLosita/smartflow/backend/shared/events"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
"github.com/google/uuid"
"gorm.io/gorm"
)

View File

@@ -8,9 +8,9 @@ import (
"strings"
"time"
outboxinfra "github.com/LoveLosita/smartflow/backend/infra/outbox"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
sharedevents "github.com/LoveLosita/smartflow/backend/shared/events"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
)
// EnqueueActiveScheduleTriggeredInTx 在事务内写入 active_schedule.triggered outbox 消息。

View File

@@ -7,14 +7,14 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/dao"
kafkabus "github.com/LoveLosita/smartflow/backend/infra/kafka"
outboxinfra "github.com/LoveLosita/smartflow/backend/infra/outbox"
"github.com/LoveLosita/smartflow/backend/model"
activegraph "github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/graph"
activepreview "github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/preview"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/trigger"
"github.com/LoveLosita/smartflow/backend/services/runtime/dao"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
sharedevents "github.com/LoveLosita/smartflow/backend/shared/events"
kafkabus "github.com/LoveLosita/smartflow/backend/shared/infra/kafka"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
"github.com/google/uuid"
"gorm.io/gorm"
"gorm.io/gorm/clause"

View File

@@ -3,10 +3,9 @@ package dao
import (
"fmt"
outboxinfra "github.com/LoveLosita/smartflow/backend/infra/outbox"
coremodel "github.com/LoveLosita/smartflow/backend/model"
"github.com/spf13/viper"
"gorm.io/driver/mysql"
coremodel "github.com/LoveLosita/smartflow/backend/services/runtime/model"
mysqlinfra "github.com/LoveLosita/smartflow/backend/shared/infra/mysql"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
"gorm.io/gorm"
)
@@ -17,18 +16,7 @@ import (
// 2. 不迁移 task、schedule、agent、notification 或 user/auth 表,避免独立进程越权管理其它服务模型;
// 3. 返回的 *gorm.DB 供服务内主链路、due job scanner 和 outbox consumer 复用。
func OpenDBFromConfig() (*gorm.DB, error) {
host := viper.GetString("database.host")
port := viper.GetString("database.port")
user := viper.GetString("database.user")
password := viper.GetString("database.password")
dbname := viper.GetString("database.dbname")
dsn := fmt.Sprintf(
"%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
user, password, host, port, dbname,
)
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
db, err := mysqlinfra.OpenDBFromConfig()
if err != nil {
return nil, err
}

View File

@@ -5,9 +5,9 @@ import (
"log"
"strings"
"github.com/LoveLosita/smartflow/backend/respond"
activeapply "github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/apply"
contracts "github.com/LoveLosita/smartflow/backend/shared/contracts/activescheduler"
"github.com/LoveLosita/smartflow/backend/shared/respond"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

View File

@@ -6,10 +6,10 @@ import (
"errors"
"time"
"github.com/LoveLosita/smartflow/backend/respond"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/rpc/pb"
activeschedulersv "github.com/LoveLosita/smartflow/backend/services/active_scheduler/sv"
contracts "github.com/LoveLosita/smartflow/backend/shared/contracts/activescheduler"
"github.com/LoveLosita/smartflow/backend/shared/respond"
)
type Handler struct {

View File

@@ -8,10 +8,6 @@ import (
"strings"
"time"
rootdao "github.com/LoveLosita/smartflow/backend/dao"
kafkabus "github.com/LoveLosita/smartflow/backend/infra/kafka"
outboxinfra "github.com/LoveLosita/smartflow/backend/infra/outbox"
eventsvc "github.com/LoveLosita/smartflow/backend/service/events"
activeadapters "github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/adapters"
activeapply "github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/apply"
activeapplyadapter "github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/applyadapter"
@@ -22,8 +18,12 @@ import (
activesvc "github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/service"
"github.com/LoveLosita/smartflow/backend/services/active_scheduler/core/trigger"
llmservice "github.com/LoveLosita/smartflow/backend/services/llm"
rootdao "github.com/LoveLosita/smartflow/backend/services/runtime/dao"
eventsvc "github.com/LoveLosita/smartflow/backend/services/runtime/eventsvc"
contracts "github.com/LoveLosita/smartflow/backend/shared/contracts/activescheduler"
sharedevents "github.com/LoveLosita/smartflow/backend/shared/events"
kafkabus "github.com/LoveLosita/smartflow/backend/shared/infra/kafka"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
"gorm.io/gorm"
)

View File

@@ -4,8 +4,8 @@ import (
"fmt"
"time"
"github.com/LoveLosita/smartflow/backend/model"
schedule "github.com/LoveLosita/smartflow/backend/services/agent/tools/schedule"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
// ScheduleStateToPreview 将 agent 的 ScheduleState 转换为前端预览缓存格式。

View File

@@ -6,10 +6,10 @@ import (
"sort"
"time"
baseconv "github.com/LoveLosita/smartflow/backend/conv"
"github.com/LoveLosita/smartflow/backend/dao"
"github.com/LoveLosita/smartflow/backend/model"
schedule "github.com/LoveLosita/smartflow/backend/services/agent/tools/schedule"
baseconv "github.com/LoveLosita/smartflow/backend/services/runtime/conv"
"github.com/LoveLosita/smartflow/backend/services/runtime/dao"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
// ScheduleProvider 实现 model.ScheduleStateProvider 接口。

View File

@@ -3,8 +3,8 @@ package agentconv
import (
"sort"
"github.com/LoveLosita/smartflow/backend/model"
schedule "github.com/LoveLosita/smartflow/backend/services/agent/tools/schedule"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
// WindowDay 表示排课窗口中的一天(相对周 + 周几)。

View File

@@ -1,9 +1,9 @@
package agentconv
import (
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/respond"
schedule "github.com/LoveLosita/smartflow/backend/services/agent/tools/schedule"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/LoveLosita/smartflow/backend/shared/respond"
)
// ApplyPlacedItems 将前端提交的绝对时间放置项应用到 ScheduleState。

View File

@@ -8,13 +8,13 @@ import (
"strings"
"time"
taskmodel "github.com/LoveLosita/smartflow/backend/model"
agentmodel "github.com/LoveLosita/smartflow/backend/services/agent/model"
agentprompt "github.com/LoveLosita/smartflow/backend/services/agent/prompt"
agentrouter "github.com/LoveLosita/smartflow/backend/services/agent/router"
agentshared "github.com/LoveLosita/smartflow/backend/services/agent/shared"
agentstream "github.com/LoveLosita/smartflow/backend/services/agent/stream"
llmservice "github.com/LoveLosita/smartflow/backend/services/llm"
taskmodel "github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/cloudwego/eino/schema"
)

View File

@@ -6,9 +6,9 @@ import (
"fmt"
"log"
"github.com/LoveLosita/smartflow/backend/pkg"
agentmodel "github.com/LoveLosita/smartflow/backend/services/agent/model"
agentprompt "github.com/LoveLosita/smartflow/backend/services/agent/prompt"
agentshared "github.com/LoveLosita/smartflow/backend/services/agent/shared"
agentstream "github.com/LoveLosita/smartflow/backend/services/agent/stream"
llmservice "github.com/LoveLosita/smartflow/backend/services/llm"
"github.com/cloudwego/eino/schema"
@@ -78,7 +78,7 @@ func compactUnifiedMessagesIfNeeded(
msg3 := messages[3].Content
// 3. Token 预算检查。
breakdown, overBudget, needCompactMsg1, needCompactMsg2 := pkg.CheckStageTokenBudget(msg0, msg1, msg2, msg3)
breakdown, overBudget, needCompactMsg1, needCompactMsg2 := agentshared.CheckStageTokenBudget(msg0, msg1, msg2, msg3)
log.Printf(
"[COMPACT:%s] token budget check: total=%d budget=%d over=%v compactMsg1=%v compactMsg2=%v (msg0=%d msg1=%d msg2=%d msg3=%d)",
@@ -97,14 +97,14 @@ func compactUnifiedMessagesIfNeeded(
msg1 = compactUnifiedMsg1(ctx, input, msg1)
messages[1].Content = msg1
// 压缩 msg1 后重算预算。
breakdown = pkg.EstimateStageMessagesTokens(msg0, msg1, msg2, msg3)
breakdown = agentshared.EstimateStageMessagesTokens(msg0, msg1, msg2, msg3)
}
// 6. msg2 压缩(阶段工作区 → LLM 摘要)。
if needCompactMsg2 || breakdown.Total > pkg.StageTokenBudget {
if needCompactMsg2 || breakdown.Total > agentshared.StageTokenBudget {
msg2 = compactUnifiedMsg2(ctx, input, msg2)
messages[2].Content = msg2
breakdown = pkg.EstimateStageMessagesTokens(msg0, msg1, msg2, msg3)
breakdown = agentshared.EstimateStageMessagesTokens(msg0, msg1, msg2, msg3)
}
// 7. 记录最终 token 分布。
@@ -124,8 +124,8 @@ func compactUnifiedMessagesIfNeeded(
// 1. 先按消息类型汇总 token保证总量准确
// 2. 再把最后一个 user 消息尽量视作 msg3保留阶段指令语义
// 3. 其他历史内容归入 msg1 / msg2确保上下文统计不会因为结构不标准而断更。
func estimateFallbackStageTokenBreakdown(messages []*schema.Message) pkg.StageTokenBreakdown {
breakdown := pkg.StageTokenBreakdown{Budget: pkg.StageTokenBudget}
func estimateFallbackStageTokenBreakdown(messages []*schema.Message) agentshared.StageTokenBreakdown {
breakdown := agentshared.StageTokenBreakdown{Budget: agentshared.StageTokenBudget}
if len(messages) == 0 {
return breakdown
}
@@ -146,7 +146,7 @@ func estimateFallbackStageTokenBreakdown(messages []*schema.Message) pkg.StageTo
if msg == nil {
continue
}
tokens := pkg.EstimateMessageTokens(msg)
tokens := agentshared.EstimateMessageTokens(msg)
breakdown.Total += tokens
switch msg.Role {
@@ -199,7 +199,7 @@ func compactUnifiedMsg1(
}
// 3. SSE: 压缩开始。
tokenBefore := pkg.EstimateTextTokens(msg1)
tokenBefore := agentshared.EstimateTextTokens(msg1)
_ = input.Emitter.EmitStatus(
input.StatusBlockID, input.StageName, "context_compact_start",
fmt.Sprintf("正在压缩对话历史(%d tokens...", tokenBefore),
@@ -219,7 +219,7 @@ func compactUnifiedMsg1(
}
// 5. SSE: 压缩完成。
tokenAfter := pkg.EstimateTextTokens(newSummary)
tokenAfter := agentshared.EstimateTextTokens(newSummary)
_ = input.Emitter.EmitStatus(
input.StatusBlockID, input.StageName, "context_compact_done",
fmt.Sprintf("对话历史已压缩:%d → %d tokens", tokenBefore, tokenAfter),
@@ -246,7 +246,7 @@ func compactUnifiedMsg2(
msg2 string,
) string {
// 1. SSE: 压缩开始。
tokenBefore := pkg.EstimateTextTokens(msg2)
tokenBefore := agentshared.EstimateTextTokens(msg2)
_ = input.Emitter.EmitStatus(
input.StatusBlockID, input.StageName, "context_compact_start",
fmt.Sprintf("正在压缩执行记录(%d tokens...", tokenBefore),
@@ -266,7 +266,7 @@ func compactUnifiedMsg2(
}
// 3. SSE: 压缩完成。
tokenAfter := pkg.EstimateTextTokens(compressed)
tokenAfter := agentshared.EstimateTextTokens(compressed)
_ = input.Emitter.EmitStatus(
input.StatusBlockID, input.StageName, "context_compact_done",
fmt.Sprintf("执行记录已压缩:%d → %d tokens", tokenBefore, tokenAfter),
@@ -285,7 +285,7 @@ func compactUnifiedMsg2(
func saveUnifiedTokenStats(
ctx context.Context,
input UnifiedCompactInput,
breakdown pkg.StageTokenBreakdown,
breakdown agentshared.StageTokenBreakdown,
) {
if input.CompactionStore == nil || input.FlowState == nil {
return

View File

@@ -5,7 +5,7 @@ import (
"log"
"strings"
"github.com/LoveLosita/smartflow/backend/respond"
"github.com/LoveLosita/smartflow/backend/shared/respond"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

View File

@@ -6,11 +6,11 @@ import (
"errors"
"strings"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/respond"
"github.com/LoveLosita/smartflow/backend/services/agent/rpc/pb"
agentsv "github.com/LoveLosita/smartflow/backend/services/agent/sv"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
agentcontracts "github.com/LoveLosita/smartflow/backend/shared/contracts/agent"
"github.com/LoveLosita/smartflow/backend/shared/respond"
)
type Handler struct {

View File

@@ -1,6 +1,6 @@
package agentshared
import "github.com/LoveLosita/smartflow/backend/model"
import "github.com/LoveLosita/smartflow/backend/services/runtime/model"
func CloneWeekSchedules(src []model.UserWeekSchedule) []model.UserWeekSchedule {
if len(src) == 0 {

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"log"
"github.com/LoveLosita/smartflow/backend/pkg"
agentmodel "github.com/LoveLosita/smartflow/backend/services/agent/model"
agentprompt "github.com/LoveLosita/smartflow/backend/services/agent/prompt"
agentstream "github.com/LoveLosita/smartflow/backend/services/agent/stream"
@@ -77,7 +76,7 @@ func CompactUnifiedMessagesIfNeeded(
msg3 := messages[3].Content
// 3. 执行 token 预算检查,判断是否需要压缩历史对话或阶段工作区。
breakdown, overBudget, needCompactMsg1, needCompactMsg2 := pkg.CheckStageTokenBudget(msg0, msg1, msg2, msg3)
breakdown, overBudget, needCompactMsg1, needCompactMsg2 := CheckStageTokenBudget(msg0, msg1, msg2, msg3)
log.Printf(
"[COMPACT:%s] token budget check: total=%d budget=%d over=%v compactMsg1=%v compactMsg2=%v (msg0=%d msg1=%d msg2=%d msg3=%d)",
@@ -95,14 +94,14 @@ func CompactUnifiedMessagesIfNeeded(
if needCompactMsg1 {
msg1 = compactUnifiedMsg1(ctx, input, msg1)
messages[1].Content = msg1
breakdown = pkg.EstimateStageMessagesTokens(msg0, msg1, msg2, msg3)
breakdown = EstimateStageMessagesTokens(msg0, msg1, msg2, msg3)
}
// 6. 若 msg1 压缩后仍超限,再压缩 msg2阶段工作区 / ReAct 记录)。
if needCompactMsg2 || breakdown.Total > pkg.StageTokenBudget {
if needCompactMsg2 || breakdown.Total > StageTokenBudget {
msg2 = compactUnifiedMsg2(ctx, input, msg2)
messages[2].Content = msg2
breakdown = pkg.EstimateStageMessagesTokens(msg0, msg1, msg2, msg3)
breakdown = EstimateStageMessagesTokens(msg0, msg1, msg2, msg3)
}
// 7. 记录最终 token 分布,供后续调试与监控使用。
@@ -122,8 +121,8 @@ func CompactUnifiedMessagesIfNeeded(
// 1. 先按消息类型汇总 token保证总量准确
// 2. 再把最后一个 user 消息尽量视作 msg3保留阶段指令语义
// 3. 其他历史内容归入 msg1 / msg2确保上下文统计不会因为结构不标准而断更。
func estimateFallbackStageTokenBreakdown(messages []*schema.Message) pkg.StageTokenBreakdown {
breakdown := pkg.StageTokenBreakdown{Budget: pkg.StageTokenBudget}
func estimateFallbackStageTokenBreakdown(messages []*schema.Message) StageTokenBreakdown {
breakdown := StageTokenBreakdown{Budget: StageTokenBudget}
if len(messages) == 0 {
return breakdown
}
@@ -144,7 +143,7 @@ func estimateFallbackStageTokenBreakdown(messages []*schema.Message) pkg.StageTo
if msg == nil {
continue
}
tokens := pkg.EstimateMessageTokens(msg)
tokens := EstimateMessageTokens(msg)
breakdown.Total += tokens
switch msg.Role {
@@ -194,7 +193,7 @@ func compactUnifiedMsg1(
log.Printf("[COMPACT:%s] load existing compaction failed: %v, proceed without cache", input.StageName, err)
}
tokenBefore := pkg.EstimateTextTokens(msg1)
tokenBefore := EstimateTextTokens(msg1)
_ = input.Emitter.EmitStatus(
input.StatusBlockID, input.StageName, "context_compact_start",
fmt.Sprintf("正在压缩对话历史(%d tokens...", tokenBefore),
@@ -212,7 +211,7 @@ func compactUnifiedMsg1(
return msg1
}
tokenAfter := pkg.EstimateTextTokens(newSummary)
tokenAfter := EstimateTextTokens(newSummary)
_ = input.Emitter.EmitStatus(
input.StatusBlockID, input.StageName, "context_compact_done",
fmt.Sprintf("对话历史已压缩:%d → %d tokens", tokenBefore, tokenAfter),
@@ -237,7 +236,7 @@ func compactUnifiedMsg2(
input UnifiedCompactInput,
msg2 string,
) string {
tokenBefore := pkg.EstimateTextTokens(msg2)
tokenBefore := EstimateTextTokens(msg2)
_ = input.Emitter.EmitStatus(
input.StatusBlockID, input.StageName, "context_compact_start",
fmt.Sprintf("正在压缩执行记录(%d tokens...", tokenBefore),
@@ -255,7 +254,7 @@ func compactUnifiedMsg2(
return msg2
}
tokenAfter := pkg.EstimateTextTokens(compressed)
tokenAfter := EstimateTextTokens(compressed)
_ = input.Emitter.EmitStatus(
input.StatusBlockID, input.StageName, "context_compact_done",
fmt.Sprintf("执行记录已压缩:%d → %d tokens", tokenBefore, tokenAfter),
@@ -274,7 +273,7 @@ func compactUnifiedMsg2(
func saveUnifiedTokenStats(
ctx context.Context,
input UnifiedCompactInput,
breakdown pkg.StageTokenBreakdown,
breakdown StageTokenBreakdown,
) {
if input.CompactionStore == nil || input.FlowState == nil {
return

View File

@@ -0,0 +1,209 @@
package agentshared
import (
"math"
"strings"
"unicode"
"github.com/cloudwego/eino/schema"
)
const (
	// WorkerMaxInputTokens is the maximum input context of the worker
	// model (value supplied by the operator).
	WorkerMaxInputTokens = 224000
	// ContextReserveTokens is headroom reserved for model output and
	// protocol overhead.
	ContextReserveTokens = 28000
	// DefaultHistoryFetchLimit caps how many history messages are pulled
	// from the database on a cache miss.
	DefaultHistoryFetchLimit = 1200
	// SessionWindowMin / SessionWindowMax bound the Redis session window;
	// SessionWindowBuffer is extra slack added on top of the kept history.
	SessionWindowMin    = 32
	SessionWindowMax    = 4096
	SessionWindowBuffer = 2
	// ---- Execute context-compaction budget ----
	// ExecuteTokenBudget is the total prompt token cap for the Execute stage.
	ExecuteTokenBudget = 80000
	// ExecuteReserveTokens covers the fixed msg0 + msg3 overhead plus a
	// safety margin.
	ExecuteReserveTokens = 8000
	// Stage* are the generic (stage-agnostic) names; they currently alias
	// the Execute-stage values.
	StageTokenBudget   = ExecuteTokenBudget
	StageReserveTokens = ExecuteReserveTokens
)
// MaxContextTokensByModel reports the maximum input-context token budget for
// the given model name. Every currently known model ("worker", "strategist")
// shares the worker budget, and unknown names fall back to the same value.
func MaxContextTokensByModel(modelName string) int {
	normalized := strings.ToLower(strings.TrimSpace(modelName))
	if normalized == "worker" || normalized == "strategist" {
		return WorkerMaxInputTokens
	}
	// Unrecognized models get the worker budget as a safe default.
	return WorkerMaxInputTokens
}
// HistoryFetchLimitByModel returns how many history messages to fetch from
// the database when the cache misses. The limit is currently the same for
// every model, so the model-name argument is ignored.
func HistoryFetchLimitByModel(_ string) int {
	return DefaultHistoryFetchLimit
}
// HistoryTokenBudgetByModel 计算“历史上下文”可使用的 token 预算。
func HistoryTokenBudgetByModel(modelName, systemPrompt, userInput string) int {
maxTokens := MaxContextTokensByModel(modelName)
baseTokens := EstimateTextTokens(systemPrompt) + EstimateTextTokens(userInput) + 64
budget := maxTokens - ContextReserveTokens - baseTokens
if budget < 0 {
return 0
}
return budget
}
// EstimateTextTokens gives a rough token estimate for a piece of text:
//   - CJK characters count about 1 token each,
//   - ASCII characters about 4 per token,
//   - all other characters about 2 per token.
//
// Whitespace is ignored entirely; any non-blank input yields at least 1.
func EstimateTextTokens(text string) int {
	if strings.TrimSpace(text) == "" {
		return 0
	}
	var cjk, ascii, other int
	for _, r := range text {
		if unicode.IsSpace(r) {
			continue
		}
		switch {
		case r <= unicode.MaxASCII:
			ascii++
		case isCJK(r):
			cjk++
		default:
			other++
		}
	}
	// Integer ceiling division: (n + d - 1) / d, equivalent to math.Ceil
	// on non-negative counts.
	total := cjk + (ascii+3)/4 + (other+1)/2
	if total <= 0 {
		return 1
	}
	return total
}
// EstimateMessageTokens estimates the token cost of a single message,
// including a small fixed per-message protocol overhead. A nil message
// costs nothing.
func EstimateMessageTokens(msg *schema.Message) int {
	if msg == nil {
		return 0
	}
	// 6 tokens approximate the role/framing overhead each message carries.
	const perMessageOverhead = 6
	total := perMessageOverhead
	total += EstimateTextTokens(msg.Content)
	total += EstimateTextTokens(msg.ReasoningContent)
	return total
}
// EstimateHistoryTokens sums the estimated token cost of every message in
// the history slice.
func EstimateHistoryTokens(history []*schema.Message) int {
	var sum int
	for i := range history {
		sum += EstimateMessageTokens(history[i])
	}
	return sum
}
// TrimHistoryByTokenBudget drops messages from the oldest end until the
// remaining history fits within historyBudget tokens.
// Returns: trimmed history, tokens before trimming, tokens after trimming,
// and the number of messages dropped.
func TrimHistoryByTokenBudget(history []*schema.Message, historyBudget int) ([]*schema.Message, int, int, int) {
	if len(history) == 0 {
		return history, 0, 0, 0
	}
	totalBefore := EstimateHistoryTokens(history)
	switch {
	case historyBudget <= 0:
		// No budget at all: everything is dropped.
		return []*schema.Message{}, totalBefore, 0, len(history)
	case totalBefore <= historyBudget:
		// Already fits; nothing to trim.
		return history, totalBefore, totalBefore, 0
	}
	// Per-message costs let us peel from the front without re-estimating.
	perMsg := make([]int, len(history))
	remaining := 0
	for i, msg := range history {
		perMsg[i] = EstimateMessageTokens(msg)
		remaining += perMsg[i]
	}
	dropped := 0
	for remaining > historyBudget && dropped < len(history) {
		remaining -= perMsg[dropped]
		dropped++
	}
	return history[dropped:], totalBefore, remaining, dropped
}
// CalcSessionWindowSize derives the Redis session-queue window size from the
// number of messages kept after trimming, clamped to the configured
// [SessionWindowMin, SessionWindowMax] range.
func CalcSessionWindowSize(trimmedHistoryLen int) int {
	window := trimmedHistoryLen + SessionWindowBuffer
	switch {
	case window < SessionWindowMin:
		return SessionWindowMin
	case window > SessionWindowMax:
		return SessionWindowMax
	default:
		return window
	}
}
// isCJK reports whether r belongs to one of the scripts counted at roughly
// one token per character: Han, Hiragana, Katakana, or Hangul.
func isCJK(r rune) bool {
	return unicode.In(r, unicode.Han, unicode.Hiragana, unicode.Katakana, unicode.Hangul)
}
// StageTokenBreakdown records the token distribution across the four stage
// messages (msg0..msg3), plus the running total and the configured budget.
type StageTokenBreakdown struct {
	Msg0   int `json:"msg0"`
	Msg1   int `json:"msg1"`
	Msg2   int `json:"msg2"`
	Msg3   int `json:"msg3"`
	Total  int `json:"total"`
	Budget int `json:"budget"`
}

// ExecuteTokenBreakdown is kept as a historical alias so older call sites do
// not need to change.
type ExecuteTokenBreakdown = StageTokenBreakdown
// EstimateStageMessagesTokens estimates the token distribution of the four
// stage messages and fills in the stage budget and total.
func EstimateStageMessagesTokens(msg0, msg1, msg2, msg3 string) StageTokenBreakdown {
	breakdown := StageTokenBreakdown{Budget: StageTokenBudget}
	breakdown.Msg0 = EstimateTextTokens(msg0)
	breakdown.Msg1 = EstimateTextTokens(msg1)
	breakdown.Msg2 = EstimateTextTokens(msg2)
	breakdown.Msg3 = EstimateTextTokens(msg3)
	breakdown.Total = breakdown.Msg0 + breakdown.Msg1 + breakdown.Msg2 + breakdown.Msg3
	return breakdown
}
// CheckStageTokenBudget checks whether the four stage messages exceed the
// stage budget and flags which messages should be compacted.
//
//  1. Compute the per-message token distribution for logging/statistics.
//  2. If the total is within budget, return immediately with no flags set.
//  3. Otherwise decide, from the relative share of msg1 / msg2, which of the
//     two should be compacted.
func CheckStageTokenBudget(msg0, msg1, msg2, msg3 string) (breakdown StageTokenBreakdown, overBudget bool, needCompactMsg1 bool, needCompactMsg2 bool) {
	breakdown = EstimateStageMessagesTokens(msg0, msg1, msg2, msg3)
	if breakdown.Total <= StageTokenBudget {
		return breakdown, false, false, false
	}
	overBudget = true
	available := StageTokenBudget - StageReserveTokens
	// A large msg1 means chat history dominates: compact it first.
	needCompactMsg1 = breakdown.Msg1 > available/2
	// If removing msg1 entirely would still leave us over budget (with a
	// quarter of the available space as slack), compact msg2 as well.
	needCompactMsg2 = breakdown.Total-breakdown.Msg1+available/4 > StageTokenBudget
	return breakdown, true, needCompactMsg1, needCompactMsg2
}
// EstimateExecuteMessagesTokens is kept under its legacy name for backward
// compatibility; it delegates to the stage-budget implementation.
//
// Deprecated: use EstimateStageMessagesTokens instead.
func EstimateExecuteMessagesTokens(msg0, msg1, msg2, msg3 string) StageTokenBreakdown {
	return EstimateStageMessagesTokens(msg0, msg1, msg2, msg3)
}
// CheckExecuteTokenBudget is kept under its legacy name for backward
// compatibility; it delegates to the stage-budget implementation.
//
// Deprecated: use CheckStageTokenBudget instead.
func CheckExecuteTokenBudget(msg0, msg1, msg2, msg3 string) (breakdown StageTokenBreakdown, overBudget bool, needCompactMsg1 bool, needCompactMsg2 bool) {
	return CheckStageTokenBudget(msg0, msg1, msg2, msg3)
}

View File

@@ -9,18 +9,18 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/conv"
"github.com/LoveLosita/smartflow/backend/dao"
outboxinfra "github.com/LoveLosita/smartflow/backend/infra/outbox"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/pkg"
eventsvc "github.com/LoveLosita/smartflow/backend/service/events"
agentmodel "github.com/LoveLosita/smartflow/backend/services/agent/model"
agentprompt "github.com/LoveLosita/smartflow/backend/services/agent/prompt"
agentshared "github.com/LoveLosita/smartflow/backend/services/agent/shared"
agenttools "github.com/LoveLosita/smartflow/backend/services/agent/tools"
llmservice "github.com/LoveLosita/smartflow/backend/services/llm"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
memoryobserve "github.com/LoveLosita/smartflow/backend/services/memory/observe"
"github.com/LoveLosita/smartflow/backend/services/runtime/conv"
"github.com/LoveLosita/smartflow/backend/services/runtime/dao"
eventsvc "github.com/LoveLosita/smartflow/backend/services/runtime/eventsvc"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
"github.com/cloudwego/eino/schema"
"github.com/google/uuid"
)
@@ -333,7 +333,7 @@ func (s *AgentService) runNormalChatFlow(
if chatHistory == nil {
// 2. 缓存未命中时回源 DB并转换为 Eino message 格式。
cacheMiss = true
histories, hisErr := s.repo.GetUserChatHistories(ctx, userID, pkg.HistoryFetchLimitByModel(resolvedModelName), chatID)
histories, hisErr := s.repo.GetUserChatHistories(ctx, userID, agentshared.HistoryFetchLimitByModel(resolvedModelName), chatID)
if hisErr != nil {
pushErrNonBlocking(errChan, hisErr)
return
@@ -343,12 +343,12 @@ func (s *AgentService) runNormalChatFlow(
// 3. 计算本次请求可用的历史 token 预算,并执行历史裁剪。
// 这样可以在上下文增长时稳定控制模型窗口,避免超长上下文引发报错或高延迟。
historyBudget := pkg.HistoryTokenBudgetByModel(resolvedModelName, agentprompt.SystemPrompt, userMessage)
trimmedHistory, totalHistoryTokens, keptHistoryTokens, droppedCount := pkg.TrimHistoryByTokenBudget(chatHistory, historyBudget)
historyBudget := agentshared.HistoryTokenBudgetByModel(resolvedModelName, agentprompt.SystemPrompt, userMessage)
trimmedHistory, totalHistoryTokens, keptHistoryTokens, droppedCount := agentshared.TrimHistoryByTokenBudget(chatHistory, historyBudget)
chatHistory = trimmedHistory
// 4. 根据裁剪后历史长度更新 Redis 会话窗口配置,并主动执行窗口收敛。
targetWindow := pkg.CalcSessionWindowSize(len(chatHistory))
targetWindow := agentshared.CalcSessionWindowSize(len(chatHistory))
if err = s.agentCache.SetSessionWindowSize(ctx, chatID, targetWindow); err != nil {
log.Printf("设置历史窗口失败 chat=%s: %v", chatID, err)
}

View File

@@ -7,8 +7,8 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
agentstream "github.com/LoveLosita/smartflow/backend/services/agent/stream"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/cloudwego/eino/schema"
)

View File

@@ -17,12 +17,12 @@ import (
"github.com/cloudwego/eino/schema"
"github.com/spf13/viper"
"github.com/LoveLosita/smartflow/backend/conv"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/pkg"
"github.com/LoveLosita/smartflow/backend/respond"
eventsvc "github.com/LoveLosita/smartflow/backend/service/events"
agentprompt "github.com/LoveLosita/smartflow/backend/services/agent/prompt"
agentshared "github.com/LoveLosita/smartflow/backend/services/agent/shared"
"github.com/LoveLosita/smartflow/backend/services/runtime/conv"
eventsvc "github.com/LoveLosita/smartflow/backend/services/runtime/eventsvc"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/LoveLosita/smartflow/backend/shared/respond"
)
const (
@@ -410,7 +410,7 @@ func (s *AgentService) loadConversationContext(ctx context.Context, chatID, user
// 缓存未命中时回源 DB。
if history == nil {
histories, hisErr := s.repo.GetUserChatHistories(ctx, 0, pkg.HistoryFetchLimitByModel("worker"), chatID)
histories, hisErr := s.repo.GetUserChatHistories(ctx, 0, agentshared.HistoryFetchLimitByModel("worker"), chatID)
if hisErr != nil {
log.Printf("从 DB 加载历史失败 chat=%s: %v", chatID, hisErr)
} else {

View File

@@ -8,10 +8,10 @@ import (
"time"
"unicode/utf8"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/respond"
eventsvc "github.com/LoveLosita/smartflow/backend/service/events"
llmservice "github.com/LoveLosita/smartflow/backend/services/llm"
eventsvc "github.com/LoveLosita/smartflow/backend/services/runtime/eventsvc"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/LoveLosita/smartflow/backend/shared/respond"
"github.com/cloudwego/eino/schema"
)

View File

@@ -7,9 +7,9 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/respond"
agentshared "github.com/LoveLosita/smartflow/backend/services/agent/shared"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/LoveLosita/smartflow/backend/shared/respond"
)
// GetSchedulePlanPreview 按 conversation_id 读取结构化排程预览。

View File

@@ -7,11 +7,11 @@ import (
"log"
"strings"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/respond"
agentconv "github.com/LoveLosita/smartflow/backend/services/agent/conv"
agentmodel "github.com/LoveLosita/smartflow/backend/services/agent/model"
agentshared "github.com/LoveLosita/smartflow/backend/services/agent/shared"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/LoveLosita/smartflow/backend/shared/respond"
)
// SaveScheduleState 处理前端拖拽后的“暂存排程状态”请求。

View File

@@ -7,9 +7,9 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/respond"
agentmodel "github.com/LoveLosita/smartflow/backend/services/agent/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/LoveLosita/smartflow/backend/shared/respond"
)
func (s *AgentService) QueryTasksForTool(ctx context.Context, req agentmodel.TaskQueryRequest) ([]agentmodel.TaskQueryTaskRecord, error) {

View File

@@ -8,9 +8,9 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
eventsvc "github.com/LoveLosita/smartflow/backend/service/events"
agentstream "github.com/LoveLosita/smartflow/backend/services/agent/stream"
eventsvc "github.com/LoveLosita/smartflow/backend/services/runtime/eventsvc"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
)

View File

@@ -13,7 +13,7 @@ import (
//
// 职责边界:
// 1. 只读取候选记忆,不暴露管理写接口;
// 2. 不要求调用方知道 gateway/client/memory 的具体实现;
// 2. 不要求调用方知道 backend/client/memory 的具体实现;
// 3. 错误原样返回给预取链路,由 agent 侧负责软降级和观测记录。
type MemoryRPCReaderClient interface {
Retrieve(ctx context.Context, req memorycontracts.RetrieveRequest) ([]memorycontracts.ItemDTO, error)

View File

@@ -8,9 +8,9 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
agentconv "github.com/LoveLosita/smartflow/backend/services/agent/conv"
scheduletool "github.com/LoveLosita/smartflow/backend/services/agent/tools/schedule"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
schedulecontracts "github.com/LoveLosita/smartflow/backend/shared/contracts/schedule"
taskclasscontracts "github.com/LoveLosita/smartflow/backend/shared/contracts/taskclass"
)

View File

@@ -7,8 +7,8 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
agenttools "github.com/LoveLosita/smartflow/backend/services/agent/tools"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
taskclasscontracts "github.com/LoveLosita/smartflow/backend/shared/contracts/taskclass"
)

View File

@@ -7,10 +7,10 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/respond"
agentmodel "github.com/LoveLosita/smartflow/backend/services/agent/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
taskcontracts "github.com/LoveLosita/smartflow/backend/shared/contracts/task"
"github.com/LoveLosita/smartflow/backend/shared/respond"
)
const quickTaskCreateRPCTimeout = 3 * time.Second

View File

@@ -6,7 +6,7 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
// TaskClassUpsertInput 描述任务类写库工具的标准化入参。

View File

@@ -3,9 +3,9 @@ package agenttools
import (
"strings"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/agent/tools/schedule"
taskclassresult "github.com/LoveLosita/smartflow/backend/services/agent/tools/taskclass_result"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
type taskClassUpsertExecutionInput struct {

View File

@@ -3,8 +3,7 @@ package dao
import (
"fmt"
"github.com/spf13/viper"
"gorm.io/driver/mysql"
mysqlinfra "github.com/LoveLosita/smartflow/backend/shared/infra/mysql"
"gorm.io/gorm"
)
@@ -15,18 +14,7 @@ import (
// 2. 本函数不 AutoMigrate schedule 表,避免 course 进程越权管理 schedule schema
// 3. 启动期只检查运行时依赖表是否存在,缺表时尽早失败。
func OpenDBFromConfig() (*gorm.DB, error) {
host := viper.GetString("database.host")
port := viper.GetString("database.port")
user := viper.GetString("database.user")
password := viper.GetString("database.password")
dbname := viper.GetString("database.dbname")
dsn := fmt.Sprintf(
"%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
user, password, host, port, dbname,
)
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
db, err := mysqlinfra.OpenDBFromConfig()
if err != nil {
return nil, err
}

View File

@@ -3,7 +3,7 @@ package dao
import (
"context"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
)

View File

@@ -5,8 +5,8 @@ import (
"log"
"strings"
"github.com/LoveLosita/smartflow/backend/respond"
coursesv "github.com/LoveLosita/smartflow/backend/services/course/sv"
"github.com/LoveLosita/smartflow/backend/shared/respond"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

View File

@@ -5,11 +5,11 @@ import (
"encoding/json"
"errors"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/respond"
"github.com/LoveLosita/smartflow/backend/services/course/rpc/pb"
coursesv "github.com/LoveLosita/smartflow/backend/services/course/sv"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
coursecontracts "github.com/LoveLosita/smartflow/backend/shared/contracts/course"
"github.com/LoveLosita/smartflow/backend/shared/respond"
)
type Handler struct {

View File

@@ -4,12 +4,12 @@ import (
"context"
"strings"
"github.com/LoveLosita/smartflow/backend/conv"
rootdao "github.com/LoveLosita/smartflow/backend/dao"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/respond"
coursedao "github.com/LoveLosita/smartflow/backend/services/course/dao"
llmservice "github.com/LoveLosita/smartflow/backend/services/llm"
"github.com/LoveLosita/smartflow/backend/services/runtime/conv"
rootdao "github.com/LoveLosita/smartflow/backend/services/runtime/dao"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/LoveLosita/smartflow/backend/shared/respond"
)
type CourseService struct {

View File

@@ -6,7 +6,7 @@ import (
"net/http"
"strings"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
const (

View File

@@ -8,8 +8,8 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
llmservice "github.com/LoveLosita/smartflow/backend/services/llm"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
// ParseCourseTableImage 使用 Ark SDK Responses 解析课程表图片。

View File

@@ -3,7 +3,7 @@ package llm
import (
"strings"
"github.com/LoveLosita/smartflow/backend/inits"
einoinfra "github.com/LoveLosita/smartflow/backend/shared/infra/eino"
)
// Service 只负责统一暴露已经构造好的模型客户端,不负责 prompt 和业务编排。
@@ -19,7 +19,7 @@ type Service struct {
// 2. CourseImageResponsesClient 允许外部预先注入,便于测试或特殊启动路径复用。
// 3. 某个字段为空时不报错,直接保留 nil交给上层继续走兼容降级。
type Options struct {
AIHub *inits.AIHub
AIHub *einoinfra.AIHub
APIKey string
BaseURL string
CourseVisionModel string

View File

@@ -1,14 +1,13 @@
package dao
import (
"context"
"fmt"
outboxinfra "github.com/LoveLosita/smartflow/backend/infra/outbox"
coremodel "github.com/LoveLosita/smartflow/backend/model"
coremodel "github.com/LoveLosita/smartflow/backend/services/runtime/model"
mysqlinfra "github.com/LoveLosita/smartflow/backend/shared/infra/mysql"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
redisinfra "github.com/LoveLosita/smartflow/backend/shared/infra/redis"
"github.com/go-redis/redis/v8"
"github.com/spf13/viper"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
@@ -19,19 +18,8 @@ import (
// 2. 不迁移 agent、task、schedule、active-scheduler、notification 等跨域表,避免独立进程越权管理别的领域;
// 3. 返回的 *gorm.DB 供 memory 服务内部 repo、worker 和 outbox consumer 复用。
func OpenDBFromConfig() (*gorm.DB, error) {
host := viper.GetString("database.host")
port := viper.GetString("database.port")
user := viper.GetString("database.user")
password := viper.GetString("database.password")
dbname := viper.GetString("database.dbname")
dsn := fmt.Sprintf(
"%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
user, password, host, port, dbname,
)
// 1. 先按统一配置建立 MySQL 连接;若连接失败,独立 memory 进程直接 fail fast。
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
db, err := mysqlinfra.OpenDBFromConfig()
if err != nil {
return nil, err
}
@@ -60,15 +48,7 @@ func OpenDBFromConfig() (*gorm.DB, error) {
// 2. 不创建、不预热、不清理任何 memory 业务 key
// 3. Ping 失败直接返回 error让入口在缓存、锁或幂等依赖异常时尽早暴露问题。
func OpenRedisFromConfig() (*redis.Client, error) {
client := redis.NewClient(&redis.Options{
Addr: viper.GetString("redis.host") + ":" + viper.GetString("redis.port"),
Password: viper.GetString("redis.password"),
DB: 0,
})
if _, err := client.Ping(context.Background()).Result(); err != nil {
return nil, err
}
return client, nil
return redisinfra.OpenRedisFromConfig()
}
// autoMigrateMemoryOutboxTable 只迁移 memory 服务自己的 outbox 物理表。

View File

@@ -4,7 +4,7 @@ import (
"sort"
"time"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
const dedupRecentTieWindow = 24 * time.Hour

View File

@@ -7,11 +7,11 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
memoryrepo "github.com/LoveLosita/smartflow/backend/services/memory/internal/repo"
memoryutils "github.com/LoveLosita/smartflow/backend/services/memory/internal/utils"
memoryvectorsync "github.com/LoveLosita/smartflow/backend/services/memory/internal/vectorsync"
memoryobserve "github.com/LoveLosita/smartflow/backend/services/memory/observe"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
)

View File

@@ -4,7 +4,7 @@ import (
"context"
"errors"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
)

View File

@@ -6,8 +6,8 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
)

View File

@@ -6,8 +6,8 @@ import (
"errors"
"time"
"github.com/LoveLosita/smartflow/backend/model"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)

View File

@@ -4,7 +4,7 @@ import (
"context"
"errors"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)

View File

@@ -3,9 +3,9 @@ package service
import (
"strings"
"github.com/LoveLosita/smartflow/backend/model"
memoryutils "github.com/LoveLosita/smartflow/backend/services/memory/internal/utils"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
func toItemDTO(item model.MemoryItem) memorymodel.ItemDTO {

View File

@@ -6,13 +6,13 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/respond"
memoryrepo "github.com/LoveLosita/smartflow/backend/services/memory/internal/repo"
memoryutils "github.com/LoveLosita/smartflow/backend/services/memory/internal/utils"
memoryvectorsync "github.com/LoveLosita/smartflow/backend/services/memory/internal/vectorsync"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
memoryobserve "github.com/LoveLosita/smartflow/backend/services/memory/observe"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/LoveLosita/smartflow/backend/shared/respond"
"gorm.io/gorm"
)

View File

@@ -8,12 +8,12 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
memoryrepo "github.com/LoveLosita/smartflow/backend/services/memory/internal/repo"
memoryutils "github.com/LoveLosita/smartflow/backend/services/memory/internal/utils"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
memoryobserve "github.com/LoveLosita/smartflow/backend/services/memory/observe"
ragservice "github.com/LoveLosita/smartflow/backend/services/rag"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
const (

View File

@@ -5,9 +5,9 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
memoryutils "github.com/LoveLosita/smartflow/backend/services/memory/internal/utils"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
// HybridRetrieve 统一承接读取侧 RAG-first 召回链路。

View File

@@ -4,7 +4,7 @@ import (
"encoding/json"
"strings"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
const (

View File

@@ -1,8 +1,8 @@
package utils
import (
"github.com/LoveLosita/smartflow/backend/model"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
// EffectiveUserSetting 返回用户记忆设置的生效值。

View File

@@ -6,10 +6,10 @@ import (
"log"
"strings"
"github.com/LoveLosita/smartflow/backend/model"
memoryrepo "github.com/LoveLosita/smartflow/backend/services/memory/internal/repo"
memoryobserve "github.com/LoveLosita/smartflow/backend/services/memory/observe"
ragservice "github.com/LoveLosita/smartflow/backend/services/rag"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
// Syncer 负责 memory_items 与向量库之间的最小桥接。

View File

@@ -5,10 +5,10 @@ import (
"fmt"
"strings"
"github.com/LoveLosita/smartflow/backend/model"
memoryrepo "github.com/LoveLosita/smartflow/backend/services/memory/internal/repo"
memoryutils "github.com/LoveLosita/smartflow/backend/services/memory/internal/utils"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
// ApplyActionOutcome 是单个决策动作的执行结果。

View File

@@ -4,11 +4,11 @@ import (
"context"
"fmt"
"github.com/LoveLosita/smartflow/backend/model"
memoryrepo "github.com/LoveLosita/smartflow/backend/services/memory/internal/repo"
memoryutils "github.com/LoveLosita/smartflow/backend/services/memory/internal/utils"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
ragservice "github.com/LoveLosita/smartflow/backend/services/rag"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
)

View File

@@ -9,7 +9,6 @@ import (
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
memoryorchestrator "github.com/LoveLosita/smartflow/backend/services/memory/internal/orchestrator"
memoryrepo "github.com/LoveLosita/smartflow/backend/services/memory/internal/repo"
memoryutils "github.com/LoveLosita/smartflow/backend/services/memory/internal/utils"
@@ -17,6 +16,7 @@ import (
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
memoryobserve "github.com/LoveLosita/smartflow/backend/services/memory/observe"
ragservice "github.com/LoveLosita/smartflow/backend/services/rag"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
)

View File

@@ -5,7 +5,6 @@ import (
"errors"
"log"
"github.com/LoveLosita/smartflow/backend/model"
llmservice "github.com/LoveLosita/smartflow/backend/services/llm"
memorycleanup "github.com/LoveLosita/smartflow/backend/services/memory/internal/cleanup"
memoryorchestrator "github.com/LoveLosita/smartflow/backend/services/memory/internal/orchestrator"
@@ -16,6 +15,7 @@ import (
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
memoryobserve "github.com/LoveLosita/smartflow/backend/services/memory/observe"
ragservice "github.com/LoveLosita/smartflow/backend/services/rag"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
)

View File

@@ -5,7 +5,7 @@ import (
"log"
"strings"
"github.com/LoveLosita/smartflow/backend/respond"
"github.com/LoveLosita/smartflow/backend/shared/respond"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

View File

@@ -4,10 +4,10 @@ import (
"context"
"encoding/json"
"github.com/LoveLosita/smartflow/backend/respond"
"github.com/LoveLosita/smartflow/backend/services/memory/rpc/pb"
memorysv "github.com/LoveLosita/smartflow/backend/services/memory/sv"
memorycontracts "github.com/LoveLosita/smartflow/backend/shared/contracts/memory"
"github.com/LoveLosita/smartflow/backend/shared/respond"
)
type Handler struct {

View File

@@ -5,13 +5,13 @@ import (
"errors"
"log"
kafkabus "github.com/LoveLosita/smartflow/backend/infra/kafka"
outboxinfra "github.com/LoveLosita/smartflow/backend/infra/outbox"
coremodel "github.com/LoveLosita/smartflow/backend/model"
eventsvc "github.com/LoveLosita/smartflow/backend/service/events"
memorymodule "github.com/LoveLosita/smartflow/backend/services/memory"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
eventsvc "github.com/LoveLosita/smartflow/backend/services/runtime/eventsvc"
coremodel "github.com/LoveLosita/smartflow/backend/services/runtime/model"
memorycontracts "github.com/LoveLosita/smartflow/backend/shared/contracts/memory"
kafkabus "github.com/LoveLosita/smartflow/backend/shared/infra/kafka"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
)
// Service 是 memory 独立进程的服务门面。

View File

@@ -3,11 +3,10 @@ package dao
import (
"fmt"
outboxinfra "github.com/LoveLosita/smartflow/backend/infra/outbox"
coremodel "github.com/LoveLosita/smartflow/backend/model"
notificationmodel "github.com/LoveLosita/smartflow/backend/services/notification/model"
"github.com/spf13/viper"
"gorm.io/driver/mysql"
coremodel "github.com/LoveLosita/smartflow/backend/services/runtime/model"
mysqlinfra "github.com/LoveLosita/smartflow/backend/shared/infra/mysql"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
"gorm.io/gorm"
)
@@ -18,18 +17,7 @@ import (
// 2. 不迁移主动调度、agent、userauth 或其它服务表;
// 3. 返回的 *gorm.DB 供 notification 服务内 DAO 和 outbox consumer 复用。
func OpenDBFromConfig() (*gorm.DB, error) {
host := viper.GetString("database.host")
port := viper.GetString("database.port")
user := viper.GetString("database.user")
password := viper.GetString("database.password")
dbname := viper.GetString("database.dbname")
dsn := fmt.Sprintf(
"%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
user, password, host, port, dbname,
)
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
db, err := mysqlinfra.OpenDBFromConfig()
if err != nil {
return nil, err
}

View File

@@ -5,7 +5,7 @@ import (
"log"
"strings"
"github.com/LoveLosita/smartflow/backend/respond"
"github.com/LoveLosita/smartflow/backend/shared/respond"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

View File

@@ -5,10 +5,10 @@ import (
"errors"
"time"
"github.com/LoveLosita/smartflow/backend/respond"
"github.com/LoveLosita/smartflow/backend/services/notification/rpc/pb"
notificationsv "github.com/LoveLosita/smartflow/backend/services/notification/sv"
contracts "github.com/LoveLosita/smartflow/backend/shared/contracts/notification"
"github.com/LoveLosita/smartflow/backend/shared/respond"
)
type Handler struct {

View File

@@ -5,10 +5,10 @@ import (
"errors"
"strings"
"github.com/LoveLosita/smartflow/backend/respond"
notificationfeishu "github.com/LoveLosita/smartflow/backend/services/notification/internal/feishu"
notificationmodel "github.com/LoveLosita/smartflow/backend/services/notification/model"
contracts "github.com/LoveLosita/smartflow/backend/shared/contracts/notification"
"github.com/LoveLosita/smartflow/backend/shared/respond"
"gorm.io/gorm"
)

View File

@@ -7,9 +7,9 @@ import (
"log"
"strings"
kafkabus "github.com/LoveLosita/smartflow/backend/infra/kafka"
outboxinfra "github.com/LoveLosita/smartflow/backend/infra/outbox"
sharedevents "github.com/LoveLosita/smartflow/backend/shared/events"
kafkabus "github.com/LoveLosita/smartflow/backend/shared/infra/kafka"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
)
// OutboxBus 是 notification 服务注册消费 handler 需要的最小总线接口。

View File

@@ -0,0 +1,52 @@
package conv
import (
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/cloudwego/eino/schema"
)
// ToEinoMessages converts persisted chat-history rows into Eino schema
// messages, carrying over role, content, reasoning content and per-message
// extra metadata (history id plus optional reasoning duration).
func ToEinoMessages(dbMsgs []model.ChatHistory) []*schema.Message {
	out := make([]*schema.Message, 0)
	for _, row := range dbMsgs {
		// Unknown/nil roles fall back to System, mirroring the DB contract.
		role := schema.System
		switch safeChatHistoryRole(row.Role) {
		case "user":
			role = schema.User
		case "assistant":
			role = schema.Assistant
		}
		// The retry mechanism has been fully retired: historical retry_*
		// columns are no longer replayed into the runtime context.
		meta := map[string]any{"history_id": row.ID}
		if row.ReasoningDurationSeconds > 0 {
			meta["reasoning_duration_seconds"] = row.ReasoningDurationSeconds
		}
		out = append(out, &schema.Message{
			Role:             role,
			Content:          safeChatHistoryText(row.MessageContent),
			ReasoningContent: safeChatHistoryText(row.ReasoningContent),
			Extra:            meta,
		})
	}
	return out
}
// safeChatHistoryRole dereferences an optional role column, mapping a nil
// pointer to the empty string.
func safeChatHistoryRole(role *string) string {
	if role != nil {
		return *role
	}
	return ""
}
// safeChatHistoryText dereferences an optional text column, mapping a nil
// pointer to the empty string.
func safeChatHistoryText(text *string) string {
	if text != nil {
		return *text
	}
	return ""
}

View File

@@ -0,0 +1,481 @@
package conv
import (
"fmt"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
import "sort"
// SchedulesToScheduleConflictDetail groups atomic schedule rows into conflict
// DTOs: rows sharing the same EventID+Week+DayOfWeek are merged into one
// detail whose Sections are sorted and collapsed into [StartSection,
// EndSection]. Results are ordered by week, day, then start section so the
// output is deterministic for the frontend.
func SchedulesToScheduleConflictDetail(schedules []model.Schedule) []model.ScheduleConflictDetail {
	if len(schedules) == 0 {
		return []model.ScheduleConflictDetail{}
	}
	// 1. Group rows by EventID-Week-Day so the same event occurring on
	//    different days is never merged into a single conflict entry.
	groups := make(map[string]*model.ScheduleConflictDetail)
	for _, s := range schedules {
		key := fmt.Sprintf("%d-%d-%d", s.EventID, s.Week, s.DayOfWeek)
		if _, ok := groups[key]; !ok {
			// Guard the optional location pointer instead of assuming it is
			// set; the previous unconditional deref panicked on a nil *string.
			location := ""
			if s.Event.Location != nil {
				location = *s.Event.Location
			}
			groups[key] = &model.ScheduleConflictDetail{
				EventID:   s.EventID,
				Name:      s.Event.Name,
				Location:  location,
				Type:      s.Event.Type,
				Week:      s.Week,
				DayOfWeek: s.DayOfWeek,
			}
		}
		// Record this atomic section under its group.
		groups[key].Sections = append(groups[key].Sections, s.Section)
	}
	// 2. Collapse each group's sections into a contiguous display range.
	res := make([]model.ScheduleConflictDetail, 0, len(groups))
	for _, detail := range groups {
		sort.Ints(detail.Sections)
		detail.StartSection = detail.Sections[0]
		detail.EndSection = detail.Sections[len(detail.Sections)-1]
		res = append(res, *detail)
	}
	// 3. Sort the result set chronologically (map iteration order is random).
	sort.Slice(res, func(i, j int) bool {
		if res[i].Week != res[j].Week {
			return res[i].Week < res[j].Week
		}
		if res[i].DayOfWeek != res[j].DayOfWeek {
			return res[i].DayOfWeek < res[j].DayOfWeek
		}
		return res[i].StartSection < res[j].StartSection
	})
	return res
}
// sectionTimeMap maps an atomic class-section number (1-12) to its
// {start, end} wall-clock times as "HH:MM" strings. The timetable here
// uses CQUPT (Chongqing University of Posts and Telecommunications) as
// the example institution.
var sectionTimeMap = map[int][2]string{
	1: {"08:00", "08:45"}, 2: {"08:55", "09:40"},
	3: {"10:15", "11:00"}, 4: {"11:10", "11:55"},
	5: {"14:00", "14:45"}, 6: {"14:55", "15:40"},
	7: {"16:15", "17:00"}, 8: {"17:10", "17:55"},
	9: {"19:00", "19:45"}, 10: {"19:55", "20:40"},
	11: {"20:50", "21:35"}, 12: {"21:45", "22:30"},
}
// SchedulesToUserTodaySchedule converts raw schedule rows into per-day DTOs.
// Rows are bucketed by week+day; within each day, sections 1..12 are swept
// linearly, consecutive sections of the same event are merged into one
// EventBrief, and free sections are merged by standard double periods
// (1-2, 3-4, ...) into "empty" blocks.
func SchedulesToUserTodaySchedule(schedules []model.Schedule) []model.UserTodaySchedule {
	if len(schedules) == 0 {
		return []model.UserTodaySchedule{}
	}
	// 1. Bucket rows by Week-Day; each bucket becomes one DTO.
	dayGroups := make(map[string][]model.Schedule)
	for _, s := range schedules {
		dayKey := fmt.Sprintf("%d-%d", s.Week, s.DayOfWeek)
		dayGroups[dayKey] = append(dayGroups[dayKey], s)
	}
	var result []model.UserTodaySchedule
	for _, daySchedules := range dayGroups {
		todayDTO := model.UserTodaySchedule{
			Week:      daySchedules[0].Week,
			DayOfWeek: daySchedules[0].DayOfWeek,
			Events:    []model.EventBrief{},
		}
		// Section lookup table: O(1) answer to "what occupies section N?".
		sectionMap := make(map[int]model.Schedule)
		for _, s := range daySchedules {
			sectionMap[s.Section] = s
		}
		order := 1
		// Linear sweep from section 1 through section 12.
		for curr := 1; curr <= 12; {
			if slot, ok := sectionMap[curr]; ok {
				// === Case A: the current section is occupied ===
				// Probe forward until the EventID changes or the run breaks,
				// e.g. sections 9-12 taught back-to-back become one block.
				end := curr
				for next := curr + 1; next <= 12; next++ {
					if nextSlot, exist := sectionMap[next]; exist && nextSlot.EventID == slot.EventID {
						end = next
					} else {
						break
					}
				}
				// Guard the optional location pointer; a nil *string
				// previously panicked here.
				location := ""
				if slot.Event.Location != nil {
					location = *slot.Event.Location
				}
				brief := model.EventBrief{
					ID:        slot.EventID,
					Order:     order,
					Name:      slot.Event.Name,
					Location:  location,
					Type:      slot.Event.Type,
					StartTime: sectionTimeMap[curr][0],
					EndTime:   sectionTimeMap[end][1],
					Span:      end - curr + 1,
				}
				// Attach the first embedded task found anywhere in the span.
				// Both pointers are checked, matching the guard used by
				// SchedulesToRecentCompletedSchedules.
				for i := curr; i <= end; i++ {
					if s, exist := sectionMap[i]; exist && s.EmbeddedTask != nil && s.EmbeddedTask.Content != nil {
						brief.EmbeddedTaskInfo = model.TaskBrief{
							ID:   s.EmbeddedTask.ID,
							Name: *s.EmbeddedTask.Content,
							Type: "task",
						}
						break
					}
				}
				todayDTO.Events = append(todayDTO.Events, brief)
				// Jump past the merged span.
				curr = end + 1
				order++
			} else {
				// === Case B: free section (Type = "empty") ===
				// Merge by standard double periods: an odd free section whose
				// following section is also free becomes one empty block.
				emptyEnd := curr
				if curr%2 != 0 && curr < 12 {
					if _, nextHasClass := sectionMap[curr+1]; !nextHasClass {
						emptyEnd = curr + 1
					}
				}
				todayDTO.Events = append(todayDTO.Events, model.EventBrief{
					ID:        0, // 0 marks a free slot
					Order:     order,
					Name:      "无课",
					Type:      "empty",
					StartTime: sectionTimeMap[curr][0],
					EndTime:   sectionTimeMap[emptyEnd][1],
					Location:  "休息时间",
				})
				curr = emptyEnd + 1
				order++
			}
		}
		result = append(result, todayDTO)
	}
	return result
}
// SchedulesToUserWeeklySchedule flattens one week of schedule rows into a
// single UserWeekSchedule DTO. For each day 1..7 it sweeps sections 1..12,
// merging consecutive sections of the same event into one brief and merging
// free odd+even section pairs into one "empty" brief. All input rows are
// assumed to belong to the same week (the week is taken from row 0).
func SchedulesToUserWeeklySchedule(schedules []model.Schedule) *model.UserWeekSchedule {
	if len(schedules) == 0 {
		// Empty input yields a zero-week DTO with an empty (non-nil) slice.
		return &model.UserWeekSchedule{
			Week:   0,
			Events: []model.WeeklyEventBrief{},
		}
	}
	// 1. Initialize the result; the week number comes from the first row.
	weekDTO := &model.UserWeekSchedule{
		Week:   schedules[0].Week,
		Events: []model.WeeklyEventBrief{},
	}
	// 2. Build a fast [day][section] index:
	// indexMap[day][section] -> model.Schedule
	indexMap := make(map[int]map[int]model.Schedule)
	for d := 1; d <= 7; d++ {
		indexMap[d] = make(map[int]model.Schedule)
	}
	for _, s := range schedules {
		indexMap[s.DayOfWeek][s.Section] = s
	}
	// 3. Linear sweep over days 1-7.
	for day := 1; day <= 7; day++ {
		order := 1 // per-day display order restarts at 1
		// 4. Linear sweep over sections 1-12 within the day.
		for curr := 1; curr <= 12; {
			// Case A: the current slot holds a class/task.
			if slot, hasClass := indexMap[day][curr]; hasClass {
				end := curr
				// Probe forward: merge consecutive sections sharing the same
				// EventID into one span (Span computation).
				for next := curr + 1; next <= 12; next++ {
					if nextSlot, exist := indexMap[day][next]; exist && nextSlot.EventID == slot.EventID {
						end = next
					} else {
						break
					}
				}
				span := end - curr + 1
				brief := model.WeeklyEventBrief{
					ID:        slot.EventID,
					Order:     order,
					DayOfWeek: day,
					Name:      slot.Event.Name,
					// NOTE(review): unguarded deref — panics if Location is
					// nil; sibling converters guard this. Confirm upstream
					// always populates Event.Location.
					Location:  *slot.Event.Location,
					Type:      slot.Event.Type,
					StartTime: sectionTimeMap[curr][0], // section -> wall-clock mapping table
					EndTime:   sectionTimeMap[end][1],
					Span:      span,
				}
				// Attach the first embedded task found anywhere in the span.
				for i := curr; i <= end; i++ {
					if s, exist := indexMap[day][i]; exist && s.EmbeddedTask != nil {
						brief.EmbeddedTaskInfo = model.TaskBrief{
							ID: s.EmbeddedTask.ID,
							// NOTE(review): Content is deref'd without a nil
							// check here, unlike
							// SchedulesToRecentCompletedSchedules — verify.
							Name: *s.EmbeddedTask.Content,
							Type: "task",
						}
						break
					}
				}
				weekDTO.Events = append(weekDTO.Events, brief)
				curr = end + 1 // jump past the merged span
				order++
			} else {
				// Case B: free slot (Type="empty"); merge by standard double
				// periods (1-2, 3-4, ...).
				emptyEnd := curr
				// An odd free section whose next section is also free merges
				// into one large empty block.
				if curr%2 != 0 && curr < 12 {
					if _, nextHasClass := indexMap[day][curr+1]; !nextHasClass {
						emptyEnd = curr + 1
					}
				}
				weekDTO.Events = append(weekDTO.Events, model.WeeklyEventBrief{
					ID:        0,
					Order:     order,
					DayOfWeek: day,
					Name:      "无课",
					Type:      "empty",
					StartTime: sectionTimeMap[curr][0],
					EndTime:   sectionTimeMap[emptyEnd][1],
					Span:      emptyEnd - curr + 1,
					Location:  "",
				})
				curr = emptyEnd + 1
				order++
			}
		}
	}
	return weekDTO
}
// SchedulesToRecentCompletedSchedules flattens completed schedule rows into a
// deduplicated response DTO: each logical event appears exactly once, and an
// event carrying an embedded task is surfaced as that task instead of its
// host course. The Events slice is always non-nil, even for empty input.
func SchedulesToRecentCompletedSchedules(schedules []model.Schedule) *model.UserRecentCompletedScheduleResponse {
	resp := &model.UserRecentCompletedScheduleResponse{
		Events: make([]model.RecentCompletedEventBrief, 0),
	}
	// Deduplicate on EventID so multi-section events produce one entry.
	handled := make(map[int]bool)
	for _, row := range schedules {
		if handled[row.EventID] {
			continue
		}
		handled[row.EventID] = true
		// Decide the displayed identity: by default the event itself…
		name, kind := row.Event.Name, row.Event.Type
		// …but an embedded task "takes over" the host: even when the carrier
		// is a course, the DTO reports the task instead.
		if row.EmbeddedTask != nil && row.EmbeddedTask.Content != nil {
			name = *row.EmbeddedTask.Content
			kind = "embedded_task"
		}
		resp.Events = append(resp.Events, model.RecentCompletedEventBrief{
			// EventID keeps the entry unique and lets the frontend track the
			// logical block.
			ID:   row.EventID,
			Name: name,
			Type: kind,
			// The event's end time doubles as its completion time.
			CompletedTime: row.Event.EndTime.Format("2006-01-02 15:04:05"),
		})
	}
	return resp
}
// SchedulesToUserOngoingSchedule picks the first schedule row as the
// currently-ongoing event and flattens it into an OngoingSchedule DTO.
// Returns nil when nothing is ongoing.
func SchedulesToUserOngoingSchedule(schedules []model.Schedule) *model.OngoingSchedule {
	if len(schedules) == 0 {
		return nil
	}
	// Take the first row's Event as the ongoing one.
	ongoing := schedules[0]
	// Guard the optional location pointer; the previous unconditional deref
	// panicked on a nil *string.
	location := ""
	if ongoing.Event.Location != nil {
		location = *ongoing.Event.Location
	}
	return &model.OngoingSchedule{
		ID:        ongoing.EventID,
		Name:      ongoing.Event.Name,
		Type:      ongoing.Event.Type,
		Location:  location,
		StartTime: ongoing.Event.StartTime,
		EndTime:   ongoing.Event.EndTime,
	}
}
// slotInfo is a transient internal grid cell used while merging "real"
// persisted schedule rows with "virtual" planner-suggested task items.
// Either pointer (or both) may be set for a given day/section slot.
type slotInfo struct {
	schedule *model.Schedule      // persisted schedule occupying this slot, if any
	plan     *model.TaskClassItem // planner-suggested task occupying this slot, if any
}
// PlanningResultToUserWeekSchedules merges persisted schedule rows ("real")
// with planner-suggested task items ("virtual") into one UserWeekSchedule per
// week. For each covered week it builds a [day][section] grid, then sweeps
// days 1..7 / sections 1..12, merging consecutive slots of the same course
// (by EventID) or the same planned task (by TaskClassItem ID) into single
// briefs, and collapsing free odd+even pairs into "empty" blocks.
func PlanningResultToUserWeekSchedules(userSchedule []model.Schedule, plans []model.TaskClassItem) []model.UserWeekSchedule {
	// 1. Week-range detection and bucketing of both inputs (O(N)).
	// NOTE(review): minW starts at 25 — presumably the maximum supported week
	// count; weeks above 25 still work because maxW grows, but confirm.
	minW, maxW := 25, 1
	weekMap := make(map[int][]model.Schedule)
	for _, s := range userSchedule {
		if s.Week < minW {
			minW = s.Week
		}
		if s.Week > maxW {
			maxW = s.Week
		}
		weekMap[s.Week] = append(weekMap[s.Week], s)
	}
	planMap := make(map[int][]model.TaskClassItem)
	for _, p := range plans {
		// Plans without a concrete placement cannot be laid on the grid.
		if p.EmbeddedTime == nil {
			continue
		}
		w := p.EmbeddedTime.Week
		if w < minW {
			minW = w
		}
		if w > maxW {
			maxW = w
		}
		planMap[w] = append(planMap[w], p)
	}
	var results []model.UserWeekSchedule
	for w := minW; w <= maxW; w++ {
		// Build the logical [day][section] grid for this week.
		indexMap := make(map[int]map[int]slotInfo)
		for d := 1; d <= 7; d++ {
			indexMap[d] = make(map[int]slotInfo)
		}
		// Address slice elements, not range loop variables: under Go < 1.22
		// every &loopVar aliases the same variable, so all grid cells would
		// have pointed at the last row. Indexing is correct on any version.
		weekRows := weekMap[w]
		for i := range weekRows {
			s := &weekRows[i]
			indexMap[s.DayOfWeek][s.Section] = slotInfo{schedule: s}
		}
		weekPlans := planMap[w]
		for i := range weekPlans {
			p := &weekPlans[i]
			for sec := p.EmbeddedTime.SectionFrom; sec <= p.EmbeddedTime.SectionTo; sec++ {
				info := indexMap[p.EmbeddedTime.DayOfWeek][sec]
				info.plan = p
				indexMap[p.EmbeddedTime.DayOfWeek][sec] = info
			}
		}
		weekDTO := &model.UserWeekSchedule{Week: w, Events: []model.WeeklyEventBrief{}}
		for day := 1; day <= 7; day++ {
			order := 1
			for curr := 1; curr <= 12; {
				slot := indexMap[day][curr]
				if slot.schedule != nil || slot.plan != nil {
					end := curr
					// Probe the exact merge boundary: courses merge only with
					// the same EventID; planned tasks merge only with the same
					// TaskClassItem ID (distinct tasks must not be fused).
					for next := curr + 1; next <= 12; next++ {
						nextSlot := indexMap[day][next]
						isSame := false
						if slot.schedule != nil && nextSlot.schedule != nil {
							// Both are classes: same course?
							isSame = slot.schedule.EventID == nextSlot.schedule.EventID
						} else if slot.schedule == nil && nextSlot.schedule == nil && slot.plan != nil && nextSlot.plan != nil {
							// Both are newly planned tasks: same TaskItem?
							isSame = slot.plan.ID == nextSlot.plan.ID
						}
						if isSame {
							end = next
						} else {
							break
						}
					}
					// Compute span once and pass it down.
					span := end - curr + 1
					brief := buildBrief(slot, day, curr, end, span, order)
					weekDTO.Events = append(weekDTO.Events, brief)
					curr = end + 1
					order++
				} else {
					// Free slot: merge odd+even pairs into one empty block.
					emptyEnd := curr
					if curr%2 != 0 && curr < 12 {
						if next := indexMap[day][curr+1]; next.schedule == nil && next.plan == nil {
							emptyEnd = curr + 1
						}
					}
					weekDTO.Events = append(weekDTO.Events, model.WeeklyEventBrief{
						Name: "无课", Type: "empty", DayOfWeek: day, Order: order,
						StartTime: sectionTimeMap[curr][0], EndTime: sectionTimeMap[emptyEnd][1],
						Span: emptyEnd - curr + 1,
					})
					curr = emptyEnd + 1
					order++
				}
			}
		}
		results = append(results, *weekDTO)
	}
	return results
}
// buildBrief renders one merged grid span into a WeeklyEventBrief. A slot
// backed by a persisted schedule keeps its event identity (optionally with an
// embedded planner task); a slot backed only by a plan becomes a "virtual"
// task block. Planner involvement sets Status to "suggested" so the frontend
// can highlight the block.
func buildBrief(slot slotInfo, day, start, end, span, order int) model.WeeklyEventBrief {
	brief := model.WeeklyEventBrief{
		DayOfWeek: day,
		Order:     order,
		StartTime: sectionTimeMap[start][0],
		EndTime:   sectionTimeMap[end][1],
		Span:      span,
		Status:    "normal", // default: a plain, confirmed slot
	}
	if slot.schedule != nil {
		// Case A: an existing class from the database (real schedule).
		brief.ID = slot.schedule.EventID
		brief.Name = slot.schedule.Event.Name
		// Guard optional pointers; the previous unconditional derefs panicked
		// on nil.
		if slot.schedule.Event.Location != nil {
			brief.Location = *slot.schedule.Event.Location
		}
		brief.Type = slot.schedule.Event.Type
		// The planner squeezed a task into this existing class span.
		if slot.plan != nil {
			brief.Status = "suggested" // frontend highlights the whole block
			taskName := ""
			if slot.plan.Content != nil {
				taskName = *slot.plan.Content
			}
			brief.EmbeddedTaskInfo = model.TaskBrief{
				ID:   slot.plan.ID,
				Name: taskName,
				Type: "task",
			}
		}
	} else if slot.plan != nil {
		// Case B: a brand-new task block the planner placed on free sections
		// (virtual schedule).
		if slot.plan.Content != nil {
			brief.Name = *slot.plan.Content
		}
		brief.Type = "task"
		brief.Status = "suggested"
		// Virtual blocks reuse the TaskClassItem ID so the frontend can track
		// and operate on them.
		brief.ID = slot.plan.ID
	}
	return brief
}

View File

@@ -0,0 +1,218 @@
package conv
import (
"errors"
"time"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/LoveLosita/smartflow/backend/shared/respond"
)
// dateLayout is the reference layout for plain calendar dates in this package.
const dateLayout = "2006-01-02"

// parseDatePtr parses s as a local-time calendar date.
// An empty string maps to (nil, nil) — "no date supplied"; malformed input
// returns the underlying parse error.
func parseDatePtr(s string) (*time.Time, error) {
	if len(s) == 0 {
		return nil, nil
	}
	parsed, err := time.ParseInLocation(dateLayout, s, time.Local)
	if err != nil {
		return nil, err
	}
	return &parsed, nil
}
// ProcessUserAddTaskClassRequest converts an add-task-class request into the
// TaskClass model plus its TaskClassItem children, stamping ownership with userID.
//
// Returns respond.WrongParamType when either date field fails to parse.
//
// Fixes: the superseded manual JSON packing of ExcludedSlots (kept as a dead
// comment block) is removed, and the item loop now takes addresses of the
// request's own elements instead of the range copy — the old form aliased a
// single loop variable on pre-1.22 toolchains.
func ProcessUserAddTaskClassRequest(req *model.UserAddTaskClassRequest, userID int) (*model.TaskClass, []model.TaskClassItem, error) {
	startDate, err := parseDatePtr(req.StartDate)
	if err != nil {
		return nil, nil, respond.WrongParamType
	}
	endDate, err := parseDatePtr(req.EndDate)
	if err != nil {
		return nil, nil, respond.WrongParamType
	}
	// 1. Fill form sections 1-2: identity, mode, dates, subject metadata.
	taskClass := model.TaskClass{
		Name:               &req.Name,
		Mode:               &req.Mode,
		StartDate:          startDate,
		EndDate:            endDate,
		SubjectType:        stringPtrOrNil(req.SubjectType),
		DifficultyLevel:    stringPtrOrNil(req.DifficultyLevel),
		CognitiveIntensity: stringPtrOrNil(req.CognitiveIntensity),
		UserID:             &userID,
	}
	// 2. Fill section 3: scheduling configuration.
	taskClass.TotalSlots = &req.Config.TotalSlots
	taskClass.AllowFillerCourse = &req.Config.AllowFillerCourse
	taskClass.Strategy = &req.Config.Strategy
	// ExcludedSlots / ExcludedDaysOfWeek reuse the IntSlice type directly, so
	// the frontend keeps receiving a plain []int.
	taskClass.ExcludedSlots = req.Config.ExcludedSlots
	taskClass.ExcludedDaysOfWeek = req.Config.ExcludedDaysOfWeek
	// 3. Build the child items; Status starts nil and is filled later in the flow.
	var items []model.TaskClassItem
	for i := range req.Items {
		itemReq := &req.Items[i]
		items = append(items, model.TaskClassItem{
			Order:        &itemReq.Order,
			Content:      &itemReq.Content,
			EmbeddedTime: itemReq.EmbeddedTime,
			Status:       nil,
		})
	}
	return &taskClass, items, nil
}
func timeOrZero(t *time.Time) time.Time {
if t == nil {
return time.Time{}
}
return *t
}
// TaskClassModelToResponse flattens TaskClass rows into the list-summary DTO.
//
// Fix: Name, Mode, TotalSlots and Strategy were dereferenced without nil
// checks while the sibling fields already went through safeStr — a NULL column
// would have panicked. All optional pointers are now unwrapped defensively.
func TaskClassModelToResponse(taskClasses []model.TaskClass) *model.UserGetTaskClassesResponse {
	var resp model.UserGetTaskClassesResponse
	for _, tc := range taskClasses {
		totalSlots := 0
		if tc.TotalSlots != nil {
			totalSlots = *tc.TotalSlots
		}
		resp.TaskClasses = append(resp.TaskClasses, model.TaskClassSummary{
			ID:                 tc.ID,
			Name:               safeStr(tc.Name),
			Mode:               safeStr(tc.Mode),
			StartDate:          timeOrZero(tc.StartDate),
			EndDate:            timeOrZero(tc.EndDate),
			TotalSlots:         totalSlots,
			Strategy:           safeStr(tc.Strategy),
			SubjectType:        safeStr(tc.SubjectType),
			DifficultyLevel:    safeStr(tc.DifficultyLevel),
			CognitiveIntensity: safeStr(tc.CognitiveIntensity),
		})
	}
	return &resp
}
// ProcessUserGetCompleteTaskClassRequest flattens a fully-preloaded TaskClass
// (including its Items) back into the request-shaped DTO used by the frontend
// editor, unwrapping optional pointers defensively.
//
// Fix: the superseded, commented-out JSON-string round trip for ExcludedSlots
// is removed; behavior is unchanged.
func ProcessUserGetCompleteTaskClassRequest(taskClass *model.TaskClass) (*model.UserAddTaskClassRequest, error) {
	if taskClass == nil {
		return nil, errors.New("源数据对象不可为空")
	}
	// 1. Map the basic fields (safe pointer dereference).
	req := &model.UserAddTaskClassRequest{
		Name:               safeStr(taskClass.Name),
		Mode:               safeStr(taskClass.Mode),
		StartDate:          formatTime(taskClass.StartDate),
		EndDate:            formatTime(taskClass.EndDate),
		SubjectType:        safeStr(taskClass.SubjectType),
		DifficultyLevel:    safeStr(taskClass.DifficultyLevel),
		CognitiveIntensity: safeStr(taskClass.CognitiveIntensity),
	}
	// 2. Map the config section.
	req.Config = model.UserAddTaskClassConfig{
		TotalSlots:        safeInt(taskClass.TotalSlots),
		AllowFillerCourse: safeBool(taskClass.AllowFillerCourse),
		Strategy:          safeStr(taskClass.Strategy),
	}
	// 3. ExcludedSlots / ExcludedDaysOfWeek reuse the IntSlice type directly;
	// the frontend parses them as []int.
	req.Config.ExcludedSlots = taskClass.ExcludedSlots
	req.Config.ExcludedDaysOfWeek = taskClass.ExcludedDaysOfWeek
	// 4. Map the child items; Items is expected to be preloaded by the caller.
	req.Items = make([]model.UserAddTaskClassItemRequest, 0, len(taskClass.Items))
	for _, item := range taskClass.Items {
		req.Items = append(req.Items, model.UserAddTaskClassItemRequest{
			ID:           item.ID, // DB primary key; frontend drag-and-drop relies on it
			Order:        safeInt(item.Order),
			Content:      safeStr(item.Content),
			EmbeddedTime: item.EmbeddedTime, // struct pointer reused as-is
		})
	}
	return req, nil
}
// UserInsertTaskItemRequestToModel converts an "insert task item into a free
// slot" request into one Schedule row per occupied section plus a single
// backing ScheduleEvent spanning [startSection, endSection].
func UserInsertTaskItemRequestToModel(req *model.UserInsertTaskClassItemToScheduleRequest, item *model.TaskClassItem, taskID *int, userID, startSection, endSection int) ([]model.Schedule, *model.ScheduleEvent, error) {
	var schedules []model.Schedule
	for section := startSection; section <= endSection; section++ {
		req1 := &model.Schedule{
			UserID:         userID,
			EmbeddedTaskID: taskID,
			Week:           req.Week,
			DayOfWeek:      req.DayOfWeek,
			Section:        section,
			Status:         "normal",
		}
		schedules = append(schedules, *req1)
	}
	// Resolve the absolute start/end timestamps for the whole section span.
	startTime, endTime, err := RelativeTimeToRealTime(req.Week, req.DayOfWeek, startSection, endSection)
	if err != nil {
		return nil, nil, err
	}
	req2 := &model.ScheduleEvent{
		UserID:        userID,                 // owner supplied by the caller
		Name:          safeStr(item.Content), // the task content doubles as the event name
		Type:          "task",
		RelID:         &item.ID, // links back to the originating TaskClassItem
		CanBeEmbedded: false,    // set to false here — this task event is NOT offered as an embed target (the old comment claimed the opposite; confirm intended semantics)
		StartTime:     startTime,
		EndTime:       endTime,
	}
	return schedules, req2, nil
}
// --- 🛡️ 辅助工具函数:保持代码清爽并防止 Panic ---
func safeStr(s *string) string {
if s == nil {
return ""
}
return *s
}
func safeInt(i *int) int {
if i == nil {
return 0
}
return *i
}
func stringPtrOrNil(value string) *string {
if value == "" {
return nil
}
return &value
}
func safeBool(b *bool) bool {
if b == nil {
return true
}
return *b
}
func formatTime(t *time.Time) string {
if t == nil {
return ""
}
// 务必使用 2006-01-02 格式以匹配前端校验
return t.Format("2006-01-02")
}

View File

@@ -0,0 +1,94 @@
package conv
import (
"time"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
// UserAddTaskRequestToModel maps an add-task request onto a persistence model
// owned by userID.
func UserAddTaskRequestToModel(request *model.UserAddTaskRequest, userID int) *model.Task {
	task := model.Task{
		Title:              request.Title,
		Priority:           request.PriorityGroup,
		EstimatedSections:  model.NormalizeEstimatedSections(&request.EstimatedSections),
		DeadlineAt:         request.DeadlineAt,
		UrgencyThresholdAt: request.UrgencyThresholdAt,
		UserID:             userID,
	}
	return &task
}
// ModelToUserAddTaskResponse builds the add-task response DTO from a persisted Task.
func ModelToUserAddTaskResponse(task *model.Task) *model.UserAddTaskResponse {
	resp := model.UserAddTaskResponse{
		ID:                task.ID,
		Title:             task.Title,
		PriorityGroup:     task.Priority,
		EstimatedSections: model.NormalizeEstimatedSections(&task.EstimatedSections),
		DeadlineAt:        task.DeadlineAt,
		Status:            "incomplete",
		CreatedAt:         time.Now(), // legacy semantics: creation time is the current server time
	}
	if task.IsCompleted {
		resp.Status = "completed"
	}
	return &resp
}
// ModelToGetUserTasksResp converts a list of Task models into their API
// response shape.
//
// Fix: the per-element conversion previously duplicated ModelToGetUserTaskResp
// line for line; it now delegates so the two converters cannot drift apart.
// An empty input still yields a nil slice, matching the previous behavior.
func ModelToGetUserTasksResp(tasks []model.Task) []model.GetUserTaskResp {
	var resp []model.GetUserTaskResp
	for i := range tasks {
		resp = append(resp, ModelToGetUserTaskResp(&tasks[i]))
	}
	return resp
}
// ModelToGetUserTaskResp converts a single Task model into its API response
// shape. Optional timestamps render as "" when absent.
func ModelToGetUserTaskResp(task *model.Task) model.GetUserTaskResp {
	const layout = "2006-01-02 15:04:05"
	status := "incomplete"
	if task.IsCompleted {
		status = "completed"
	}
	var deadline, urgencyThreshold string
	if task.DeadlineAt != nil {
		deadline = task.DeadlineAt.Format(layout)
	}
	if task.UrgencyThresholdAt != nil {
		urgencyThreshold = task.UrgencyThresholdAt.Format(layout)
	}
	return model.GetUserTaskResp{
		ID:                 task.ID,
		UserID:             task.UserID,
		Title:              task.Title,
		PriorityGroup:      task.Priority,
		EstimatedSections:  model.NormalizeEstimatedSections(&task.EstimatedSections),
		Status:             status,
		Deadline:           deadline,
		IsCompleted:        task.IsCompleted,
		UrgencyThresholdAt: urgencyThreshold,
	}
}

View File

@@ -0,0 +1,149 @@
package conv
import (
"errors"
"fmt"
"time"
"github.com/LoveLosita/smartflow/backend/shared/respond"
"github.com/spf13/viper"
)
// DateFormat is the single canonical calendar-date layout used for all date
// parsing and formatting in this package.
const DateFormat = "2006-01-02"
// RealDateToRelativeDate maps an absolute "2006-01-02" date onto the semester
// grid, returning (week, dayOfWeek) where the configured semester start date
// is week 1, day 1. Dates outside the configured semester are rejected.
func RealDateToRelativeDate(realDate string) (int, int, error) {
	semesterStart := viper.GetString("time.semesterStartDate") // semester start from config
	semesterEnd := viper.GetString("time.semesterEndDate")     // semester end from config
	target, err := time.Parse(DateFormat, realDate)
	if err != nil {
		return 0, 0, err
	}
	start, err := time.Parse(DateFormat, semesterStart)
	if err != nil {
		return 0, 0, err
	}
	end, err := time.Parse(DateFormat, semesterEnd)
	if err != nil {
		return 0, 0, err
	}
	// The date must fall inside the semester window.
	if target.Before(start) || target.After(end) {
		return 0, 0, errors.New("日期超出学期范围")
	}
	// Whole days elapsed since the semester started (24h per day); both
	// timestamps are parsed at midnight, so the division is exact.
	days := int(target.Sub(start).Hours() / 24)
	// The semester start itself is week 1, day 1.
	return days/7 + 1, days%7 + 1, nil
}
// RelativeDateToRealDate maps a (week, dayOfWeek) pair on the semester grid
// back to an absolute "2006-01-02" date string.
//
// Returns respond.TimeOutOfRangeOfThisSemester when the computed date lands
// after the configured semester end.
//
// Fix: parse errors of the configured semester dates were previously discarded
// with `_`, which would have silently produced dates anchored at the zero
// time; they are now surfaced to the caller.
func RelativeDateToRealDate(week, dayOfWeek int) (string, error) {
	semesterStart := viper.GetString("time.semesterStartDate") // semester start from config
	semesterEnd := viper.GetString("time.semesterEndDate")     // semester end from config
	start, err := time.Parse(DateFormat, semesterStart)
	if err != nil {
		return "", err
	}
	end, err := time.Parse(DateFormat, semesterEnd)
	if err != nil {
		return "", err
	}
	// Core mapping: (week-1)*7 + (day-1) days after the semester start.
	targetDate := start.AddDate(0, 0, (week-1)*7+(dayOfWeek-1))
	// Reject dates that run past the end of the semester.
	if targetDate.After(end) {
		return "", respond.TimeOutOfRangeOfThisSemester
	}
	return targetDate.Format(DateFormat), nil
}
// SectionTime holds the wall-clock boundaries of a single class section.
type SectionTime struct {
	Start string // section start, "HH:MM"
	End   string // section end, "HH:MM"
}

// SectionTimeMap2 maps a section number (1-12) to its daily time window.
// NOTE(review): values presumably mirror the campus timetable — confirm before
// editing.
var SectionTimeMap2 = map[int]SectionTime{
	1: {Start: "08:00", End: "08:45"},
	2: {Start: "08:55", End: "09:40"},
	3: {Start: "10:15", End: "11:00"},
	4: {Start: "11:10", End: "11:55"},
	5: {Start: "14:00", End: "14:45"},
	6: {Start: "14:55", End: "15:40"},
	7: {Start: "16:15", End: "17:00"},
	8: {Start: "17:10", End: "17:55"},
	9: {Start: "19:00", End: "19:45"},
	10: {Start: "19:55", End: "20:40"},
	11: {Start: "20:50", End: "21:35"},
	12: {Start: "21:45", End: "22:30"},
}
// RelativeTimeToRealTime converts (week, dayOfWeek, startSection, endSection)
// on the semester grid into absolute start/end timestamps in the configured
// time zone.
//
// Fix: the semester-start parse error and the time.LoadLocation error were
// previously discarded; a bad config would have produced wrong dates or a
// nil-Location panic inside time.ParseInLocation. Both are now returned.
func RelativeTimeToRealTime(week, dayOfWeek, startSection, endSection int) (time.Time, time.Time, error) {
	// 1. Validate the section range and calendar coordinates.
	if startSection > endSection {
		return time.Time{}, time.Time{}, respond.InvalidSectionRange
	}
	startTimeInfo, okStart := SectionTimeMap2[startSection]
	endTimeInfo, okEnd := SectionTimeMap2[endSection]
	if !okStart || !okEnd {
		return time.Time{}, time.Time{}, respond.InvalidSectionNumber
	}
	if week < 1 || dayOfWeek < 1 || dayOfWeek > 7 {
		return time.Time{}, time.Time{}, respond.InvalidWeekOrDayOfWeek
	}
	// 2. Resolve the target date: (week-1)*7 + (dayOfWeek-1) days after the
	// configured semester start.
	daysOffset := (week-1)*7 + (dayOfWeek - 1)
	termStartDate := viper.GetString("time.semesterStartDate") // semester start from config
	baseDate, err := time.Parse(DateFormat, termStartDate)
	if err != nil {
		return time.Time{}, time.Time{}, err
	}
	dateStr := baseDate.AddDate(0, 0, daysOffset).Format(DateFormat)
	// 3. Pin the configured time zone (e.g. Asia/Shanghai).
	loc, err := time.LoadLocation(viper.GetString("time.zone"))
	if err != nil {
		return time.Time{}, time.Time{}, err
	}
	// 4. Combine the date with the first section's start and the last section's end.
	startTime, err := time.ParseInLocation("2006-01-02 15:04", fmt.Sprintf("%s %s", dateStr, startTimeInfo.Start), loc)
	if err != nil {
		return time.Time{}, time.Time{}, err
	}
	endTime, err := time.ParseInLocation("2006-01-02 15:04", fmt.Sprintf("%s %s", dateStr, endTimeInfo.End), loc)
	if err != nil {
		return time.Time{}, time.Time{}, err
	}
	return startTime, endTime, nil
}
func CalculateFirstDayOfWeek(date time.Time) time.Time {
// 计算当前日期是周几0-60表示周日
weekday := int(date.Weekday())
if weekday == 0 {
weekday = 7 // 将周日调整为7方便计算
}
// 计算距离周一的天数偏移
offset := weekday - 1
// 计算本周一的日期
firstDayOfWeek := date.AddDate(0, 0, -offset)
return firstDayOfWeek
}
func CalculateLastDayOfWeek(date time.Time) time.Time {
// 计算当前日期是周几0-60表示周日
weekday := int(date.Weekday())
if weekday == 0 {
weekday = 7 // 将周日调整为7方便计算
}
// 计算距离周日的天数偏移
offset := 7 - weekday
// 计算本周日的日期
lastDayOfWeek := date.AddDate(0, 0, offset)
return lastDayOfWeek
}

View File

@@ -0,0 +1,310 @@
package dao
import (
"context"
"errors"
"time"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// ActiveScheduleDAO manages the tables owned by active-schedule phase 1.
//
// Responsibility boundaries:
//  1. Only basic reads/writes of active_schedule_jobs / triggers / previews;
//  2. It does not build candidates, call the LLM, dispatch to providers, or
//     write the official schedule;
//  3. Idempotency queries only read facts by persisted keys — whether a result
//     is reused is decided by the state machine above this layer.
type ActiveScheduleDAO struct {
	db *gorm.DB
}

// NewActiveScheduleDAO builds a DAO bound to the given database handle.
func NewActiveScheduleDAO(db *gorm.DB) *ActiveScheduleDAO {
	return &ActiveScheduleDAO{db: db}
}

// WithTx returns a DAO that runs inside the caller's transaction handle.
func (d *ActiveScheduleDAO) WithTx(tx *gorm.DB) *ActiveScheduleDAO {
	return &ActiveScheduleDAO{db: tx}
}

// ensureDB guards against use of a nil or uninitialized DAO.
func (d *ActiveScheduleDAO) ensureDB() error {
	if d == nil || d.db == nil {
		return errors.New("active schedule dao 未初始化")
	}
	return nil
}
// CreateOrUpdateJob idempotently creates or overwrites an active-schedule job
// keyed by job.ID.
//
// Responsibility boundaries:
//  1. Only upserts the job snapshot passed in, keyed by primary key;
//  2. It does not check whether the task still qualifies for active
//     scheduling — the job scanner makes that call after reading the task's
//     source of truth;
//  3. Callers must keep the ID stable, e.g. the task's current live job ID or
//     a generated asj_* identifier.
func (d *ActiveScheduleDAO) CreateOrUpdateJob(ctx context.Context, job *model.ActiveScheduleJob) error {
	if err := d.ensureDB(); err != nil {
		return err
	}
	if job == nil || job.ID == "" {
		return errors.New("active schedule job 不能为空且必须包含 id")
	}
	return d.db.WithContext(ctx).
		Clauses(clause.OnConflict{
			Columns:   []clause.Column{{Name: "id"}},
			UpdateAll: true,
		}).
		Create(job).Error
}
// UpdateJobFields updates selected columns of one job by job_id.
//
// Responsibility boundaries:
//  1. Performs only the partial field update — no implicit state changes;
//  2. Empty updates return nil immediately so callers can build the map
//     conditionally;
//  3. State-machine validity is not checked here; transitions are owned by
//     active_scheduler/job.
func (d *ActiveScheduleDAO) UpdateJobFields(ctx context.Context, jobID string, updates map[string]any) error {
	if err := d.ensureDB(); err != nil {
		return err
	}
	if jobID == "" {
		return errors.New("active schedule job id 不能为空")
	}
	if len(updates) == 0 {
		return nil
	}
	return d.db.WithContext(ctx).
		Model(&model.ActiveScheduleJob{}).
		Where("id = ?", jobID).
		Updates(updates).Error
}
// GetJobByID loads one active-schedule job by primary key. A blank id
// short-circuits to gorm.ErrRecordNotFound without touching the database.
func (d *ActiveScheduleDAO) GetJobByID(ctx context.Context, jobID string) (*model.ActiveScheduleJob, error) {
	if err := d.ensureDB(); err != nil {
		return nil, err
	}
	if jobID == "" {
		return nil, gorm.ErrRecordNotFound
	}
	job := new(model.ActiveScheduleJob)
	if err := d.db.WithContext(ctx).Where("id = ?", jobID).First(job).Error; err != nil {
		return nil, err
	}
	return job, nil
}
// FindPendingJobByTask returns the task's currently pending (not yet
// triggered) job.
//
// Notes:
//  1. Used on task create/update to decide between reusing and overwriting the
//     current live job;
//  2. Only pending rows are considered — triggered/canceled/skipped history is
//     kept for audit and never overwritten.
func (d *ActiveScheduleDAO) FindPendingJobByTask(ctx context.Context, userID int, taskID int) (*model.ActiveScheduleJob, error) {
	if err := d.ensureDB(); err != nil {
		return nil, err
	}
	if userID <= 0 || taskID <= 0 {
		return nil, gorm.ErrRecordNotFound
	}
	var job model.ActiveScheduleJob
	err := d.db.WithContext(ctx).
		Where("user_id = ? AND task_id = ? AND status = ?", userID, taskID, model.ActiveScheduleJobStatusPending).
		Order("trigger_at ASC, created_at ASC").
		First(&job).Error
	if err != nil {
		return nil, err
	}
	return &job, nil
}
// ListDueJobs reads jobs that are due (trigger_at <= now) and still pending.
//
// Failure handling:
//  1. Invalid arguments yield an empty list so a config hiccup never makes the
//     worker scan the whole table;
//  2. Database errors are returned as-is for the scanner to log and retry.
func (d *ActiveScheduleDAO) ListDueJobs(ctx context.Context, now time.Time, limit int) ([]model.ActiveScheduleJob, error) {
	if err := d.ensureDB(); err != nil {
		return nil, err
	}
	if limit <= 0 || now.IsZero() {
		return []model.ActiveScheduleJob{}, nil
	}
	var jobs []model.ActiveScheduleJob
	err := d.db.WithContext(ctx).
		Where("status = ? AND trigger_at <= ?", model.ActiveScheduleJobStatusPending, now).
		Order("trigger_at ASC, id ASC").
		Limit(limit).
		Find(&jobs).Error
	if err != nil {
		return nil, err
	}
	return jobs, nil
}
// CreateTrigger inserts a new trigger row. The caller must pre-assign a
// non-empty ID.
func (d *ActiveScheduleDAO) CreateTrigger(ctx context.Context, trigger *model.ActiveScheduleTrigger) error {
	if err := d.ensureDB(); err != nil {
		return err
	}
	if trigger == nil || trigger.ID == "" {
		return errors.New("active schedule trigger 不能为空且必须包含 id")
	}
	session := d.db.WithContext(ctx)
	return session.Create(trigger).Error
}
// UpdateTriggerFields partially updates a trigger's columns by trigger_id.
//
// Responsibility boundaries:
//  1. Provides field updates only; it does not validate whether
//     pending -> processing -> preview_generated transitions are legal;
//  2. Callers needing a CAS-style transition must add their own WHERE
//     conditions or extend this with a dedicated method;
//  3. Empty updates return nil immediately.
func (d *ActiveScheduleDAO) UpdateTriggerFields(ctx context.Context, triggerID string, updates map[string]any) error {
	if err := d.ensureDB(); err != nil {
		return err
	}
	if triggerID == "" {
		return errors.New("active schedule trigger id 不能为空")
	}
	if len(updates) == 0 {
		return nil
	}
	return d.db.WithContext(ctx).
		Model(&model.ActiveScheduleTrigger{}).
		Where("id = ?", triggerID).
		Updates(updates).Error
}
// GetTriggerByID loads one trigger row by primary key. A blank id maps to
// gorm.ErrRecordNotFound without touching the database.
func (d *ActiveScheduleDAO) GetTriggerByID(ctx context.Context, triggerID string) (*model.ActiveScheduleTrigger, error) {
	if err := d.ensureDB(); err != nil {
		return nil, err
	}
	if triggerID == "" {
		return nil, gorm.ErrRecordNotFound
	}
	row := new(model.ActiveScheduleTrigger)
	if err := d.db.WithContext(ctx).Where("id = ?", triggerID).First(row).Error; err != nil {
		return nil, err
	}
	return row, nil
}
// FindTriggerByDedupeKey returns the most recent trigger for a dedupe key.
//
// Notes:
//  1. important_urgent_task builds dedupe_key from user_id + trigger_type +
//     target + a 30-minute window;
//  2. unfinished_feedback may place its feedback idempotency key into
//     dedupe_key;
//  3. An empty statuses slice matches all statuses, letting callers decide
//     whether failed records are reusable.
func (d *ActiveScheduleDAO) FindTriggerByDedupeKey(ctx context.Context, dedupeKey string, statuses []string) (*model.ActiveScheduleTrigger, error) {
	if err := d.ensureDB(); err != nil {
		return nil, err
	}
	if dedupeKey == "" {
		return nil, gorm.ErrRecordNotFound
	}
	query := d.db.WithContext(ctx).
		Where("dedupe_key = ?", dedupeKey)
	if len(statuses) > 0 {
		query = query.Where("status IN ?", statuses)
	}
	var trigger model.ActiveScheduleTrigger
	err := query.Order("created_at DESC, id DESC").First(&trigger).Error
	if err != nil {
		return nil, err
	}
	return &trigger, nil
}
// FindTriggerByIdempotencyKey returns the newest trigger recorded for the
// given user / trigger type / API (user-feedback) idempotency key.
func (d *ActiveScheduleDAO) FindTriggerByIdempotencyKey(ctx context.Context, userID int, triggerType string, idempotencyKey string) (*model.ActiveScheduleTrigger, error) {
	if err := d.ensureDB(); err != nil {
		return nil, err
	}
	if userID <= 0 || triggerType == "" || idempotencyKey == "" {
		return nil, gorm.ErrRecordNotFound
	}
	row := new(model.ActiveScheduleTrigger)
	query := d.db.WithContext(ctx).
		Where("user_id = ? AND trigger_type = ? AND idempotency_key = ?", userID, triggerType, idempotencyKey).
		Order("created_at DESC, id DESC")
	if err := query.First(row).Error; err != nil {
		return nil, err
	}
	return row, nil
}
// CreatePreview inserts a new preview row. The caller must pre-assign a
// non-empty ID (preview_id).
func (d *ActiveScheduleDAO) CreatePreview(ctx context.Context, preview *model.ActiveSchedulePreview) error {
	if err := d.ensureDB(); err != nil {
		return err
	}
	if preview == nil || preview.ID == "" {
		return errors.New("active schedule preview 不能为空且必须包含 preview_id")
	}
	session := d.db.WithContext(ctx)
	return session.Create(preview).Error
}
// UpdatePreviewFields applies a partial column update to one preview row.
// Empty updates are a no-op so callers can assemble the map conditionally.
func (d *ActiveScheduleDAO) UpdatePreviewFields(ctx context.Context, previewID string, updates map[string]any) error {
	if err := d.ensureDB(); err != nil {
		return err
	}
	if previewID == "" {
		return errors.New("active schedule preview id 不能为空")
	}
	if len(updates) == 0 {
		return nil
	}
	query := d.db.WithContext(ctx).
		Model(&model.ActiveSchedulePreview{}).
		Where("preview_id = ?", previewID)
	return query.Updates(updates).Error
}
// GetPreviewByID loads one preview by preview_id. A blank id maps to
// gorm.ErrRecordNotFound without touching the database.
func (d *ActiveScheduleDAO) GetPreviewByID(ctx context.Context, previewID string) (*model.ActiveSchedulePreview, error) {
	if err := d.ensureDB(); err != nil {
		return nil, err
	}
	if previewID == "" {
		return nil, gorm.ErrRecordNotFound
	}
	row := new(model.ActiveSchedulePreview)
	if err := d.db.WithContext(ctx).Where("preview_id = ?", previewID).First(row).Error; err != nil {
		return nil, err
	}
	return row, nil
}
// GetPreviewByTriggerID returns the newest preview generated for a trigger.
func (d *ActiveScheduleDAO) GetPreviewByTriggerID(ctx context.Context, triggerID string) (*model.ActiveSchedulePreview, error) {
	if err := d.ensureDB(); err != nil {
		return nil, err
	}
	if triggerID == "" {
		return nil, gorm.ErrRecordNotFound
	}
	row := new(model.ActiveSchedulePreview)
	query := d.db.WithContext(ctx).
		Where("trigger_id = ?", triggerID).
		Order("created_at DESC")
	if err := query.First(row).Error; err != nil {
		return nil, err
	}
	return row, nil
}
// FindPreviewByApplyIdempotencyKey looks up the preview-apply state for a
// confirm retry, keyed by preview_id plus the apply idempotency key.
func (d *ActiveScheduleDAO) FindPreviewByApplyIdempotencyKey(ctx context.Context, previewID string, idempotencyKey string) (*model.ActiveSchedulePreview, error) {
	if err := d.ensureDB(); err != nil {
		return nil, err
	}
	if previewID == "" || idempotencyKey == "" {
		return nil, gorm.ErrRecordNotFound
	}
	row := new(model.ActiveSchedulePreview)
	query := d.db.WithContext(ctx).
		Where("preview_id = ? AND apply_idempotency_key = ?", previewID, idempotencyKey)
	if err := query.First(row).Error; err != nil {
		return nil, err
	}
	return row, nil
}

View File

@@ -0,0 +1,438 @@
package dao
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// activeScheduleSessionLiveStatuses lists the session statuses that are still
// in-flight (waiting on the user's reply, or re-running the plan).
// NOTE(review): its consumers are outside this chunk — confirm usage before
// changing the contents.
var activeScheduleSessionLiveStatuses = []string{
	model.ActiveScheduleSessionStatusWaitingUserReply,
	model.ActiveScheduleSessionStatusRerunning,
}
// ActiveScheduleSessionDAO owns database reads/writes for active-schedule
// conversation sessions.
//
// Responsibility boundaries:
//  1. Manages only the session table itself, not chat-entry interception policy;
//  2. Exposes reads/writes keyed by session_id / conversation_id only; it does
//     not orchestrate the graph;
//  3. Cache-hit policy belongs to the layer above — MySQL is always treated as
//     the final source of truth here.
type ActiveScheduleSessionDAO struct {
	db *gorm.DB
}

// NewActiveScheduleSessionDAO builds a session DAO bound to the given handle.
func NewActiveScheduleSessionDAO(db *gorm.DB) *ActiveScheduleSessionDAO {
	return &ActiveScheduleSessionDAO{db: db}
}

// WithTx returns a DAO that runs inside the caller's transaction handle.
func (d *ActiveScheduleSessionDAO) WithTx(tx *gorm.DB) *ActiveScheduleSessionDAO {
	return &ActiveScheduleSessionDAO{db: tx}
}

// ensureDB guards against use of a nil or uninitialized DAO.
func (d *ActiveScheduleSessionDAO) ensureDB() error {
	if d == nil || d.db == nil {
		return errors.New("active schedule session dao 未初始化")
	}
	return nil
}
// UpsertActiveScheduleSession idempotently writes or overwrites an
// active-schedule session keyed by session_id.
//
// Steps:
//  1. Validate the primary key, owning user and status first, so dirty
//     sessions never reach the table;
//  2. Serialize the lightweight state uniformly into state_json to keep the
//     database-side format stable;
//  3. Upsert via OnConflict: created_at is preserved while the business
//     columns and updated_at are refreshed.
func (d *ActiveScheduleSessionDAO) UpsertActiveScheduleSession(ctx context.Context, snapshot *model.ActiveScheduleSessionSnapshot) error {
	if err := d.ensureDB(); err != nil {
		return err
	}
	normalized, err := normalizeActiveScheduleSessionSnapshot(snapshot)
	if err != nil {
		return err
	}
	stateJSON, err := marshalActiveScheduleSessionState(normalized.State)
	if err != nil {
		return fmt.Errorf("marshal active schedule session state failed: %w", err)
	}
	now := time.Now()
	row := model.ActiveScheduleSession{
		SessionID:        normalized.SessionID,
		UserID:           normalized.UserID,
		ConversationID:   nullableStringPtr(normalized.ConversationID),
		TriggerID:        normalized.TriggerID,
		CurrentPreviewID: nullableStringPtr(normalized.CurrentPreviewID),
		Status:           normalized.Status,
		StateJSON:        stateJSON,
		CreatedAt:        normalized.CreatedAt,
		UpdatedAt:        now,
	}
	if row.CreatedAt.IsZero() {
		row.CreatedAt = now
	}
	return d.db.WithContext(ctx).Clauses(clause.OnConflict{
		Columns: []clause.Column{
			{Name: "session_id"},
		},
		DoUpdates: clause.Assignments(map[string]any{
			"user_id":            row.UserID,
			"conversation_id":    row.ConversationID,
			"trigger_id":         row.TriggerID,
			"current_preview_id": row.CurrentPreviewID,
			"status":             row.Status,
			"state_json":         row.StateJSON,
			"updated_at":         row.UpdatedAt,
		}),
	}).Create(&row).Error
}
// GetActiveScheduleSessionBySessionID reads a session record of any status by
// session_id.
//
// Return semantics:
//  1. Hit: (snapshot, nil);
//  2. Miss: (nil, nil) — the caller decides whether to fall back to the source
//     or create a new session;
//  3. Corrupt data: (nil, error), so broken state never reaches the
//     interception logic.
func (d *ActiveScheduleSessionDAO) GetActiveScheduleSessionBySessionID(ctx context.Context, sessionID string) (*model.ActiveScheduleSessionSnapshot, error) {
	if err := d.ensureDB(); err != nil {
		return nil, err
	}
	normalizedSessionID := strings.TrimSpace(sessionID)
	if normalizedSessionID == "" {
		return nil, errors.New("session_id is empty")
	}
	var row model.ActiveScheduleSession
	err := d.db.WithContext(ctx).
		Where("session_id = ?", normalizedSessionID).
		First(&row).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return activeScheduleSessionSnapshotFromRow(&row)
}
// GetActiveScheduleSessionByConversationID reads the latest session record for
// user_id + conversation_id.
//
// Responsibility boundaries:
//  1. Always returns the newest record of the conversation so callers can
//     inspect the current status directly;
//  2. Makes no "should we intercept" decision inside the DAO, keeping routing
//     rules out of the storage layer;
//  3. If duplicate rows were ever written for one conversation, the most
//     recently updated one wins.
func (d *ActiveScheduleSessionDAO) GetActiveScheduleSessionByConversationID(ctx context.Context, userID int, conversationID string) (*model.ActiveScheduleSessionSnapshot, error) {
	if err := d.ensureDB(); err != nil {
		return nil, err
	}
	if userID <= 0 {
		return nil, fmt.Errorf("invalid user_id: %d", userID)
	}
	normalizedConversationID := strings.TrimSpace(conversationID)
	if normalizedConversationID == "" {
		return nil, errors.New("conversation_id is empty")
	}
	var row model.ActiveScheduleSession
	err := d.db.WithContext(ctx).
		Where("user_id = ? AND conversation_id = ?", userID, normalizedConversationID).
		Order("updated_at DESC, created_at DESC, session_id DESC").
		First(&row).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return activeScheduleSessionSnapshotFromRow(&row)
}
// UpdateActiveScheduleSessionFieldsBySessionID applies a partial column update
// keyed by session_id.
//
// Notes:
//  1. It does not serialize state_json — callers supply final column values;
//  2. Empty updates return nil immediately to avoid a useless write;
//  3. updated_at is refreshed automatically so the timeline stays traceable.
func (d *ActiveScheduleSessionDAO) UpdateActiveScheduleSessionFieldsBySessionID(ctx context.Context, sessionID string, updates map[string]any) error {
	if err := d.ensureDB(); err != nil {
		return err
	}
	normalizedSessionID := strings.TrimSpace(sessionID)
	if normalizedSessionID == "" {
		return errors.New("session_id is empty")
	}
	if len(updates) == 0 {
		return nil
	}
	normalizedUpdates := cloneUpdateMap(updates)
	if _, ok := normalizedUpdates["updated_at"]; !ok {
		normalizedUpdates["updated_at"] = time.Now()
	}
	return d.db.WithContext(ctx).
		Model(&model.ActiveScheduleSession{}).
		Where("session_id = ?", normalizedSessionID).
		Updates(normalizedUpdates).Error
}
// TryTransitionActiveScheduleSessionStatusBySessionID atomically flips a
// session's status by session_id.
//
// Responsibility boundaries:
//  1. A lightweight CAS only — "switch to toStatus while the status is still
//     fromStatus"; it never writes state_json or preview_id;
//  2. true means this call won the transition and may proceed with the
//     follow-up rerun;
//  3. false means another request advanced the status first — the caller
//     should degrade to an informational reply instead of generating a
//     duplicate preview.
func (d *ActiveScheduleSessionDAO) TryTransitionActiveScheduleSessionStatusBySessionID(ctx context.Context, sessionID string, fromStatus string, toStatus string) (bool, error) {
	if err := d.ensureDB(); err != nil {
		return false, err
	}
	normalizedSessionID := strings.TrimSpace(sessionID)
	if normalizedSessionID == "" {
		return false, errors.New("session_id is empty")
	}
	normalizedFrom, err := normalizeActiveScheduleSessionStatus(fromStatus)
	if err != nil {
		return false, fmt.Errorf("invalid active schedule session from status: %w", err)
	}
	normalizedTo, err := normalizeActiveScheduleSessionStatus(toStatus)
	if err != nil {
		return false, fmt.Errorf("invalid active schedule session to status: %w", err)
	}
	result := d.db.WithContext(ctx).
		Model(&model.ActiveScheduleSession{}).
		Where("session_id = ? AND status = ?", normalizedSessionID, normalizedFrom).
		Updates(map[string]any{
			"status":     normalizedTo,
			"updated_at": time.Now(),
		})
	if result.Error != nil {
		return false, result.Error
	}
	return result.RowsAffected > 0, nil
}
// UpdateActiveScheduleSessionFieldsByConversationID updates the latest record
// of user_id + conversation_id with partial fields.
//
// Steps:
//  1. Locate the conversation's newest session first, then write back by
//     session_id, so one UPDATE never sweeps multiple historical rows;
//  2. Write the partial fields plus updated_at, keeping state changes
//     attributable per session;
//  3. When no session exists, return gorm.ErrRecordNotFound and let the caller
//     decide whether to create one or release normal chat.
func (d *ActiveScheduleSessionDAO) UpdateActiveScheduleSessionFieldsByConversationID(ctx context.Context, userID int, conversationID string, updates map[string]any) error {
	if err := d.ensureDB(); err != nil {
		return err
	}
	if userID <= 0 {
		return fmt.Errorf("invalid user_id: %d", userID)
	}
	normalizedConversationID := strings.TrimSpace(conversationID)
	if normalizedConversationID == "" {
		return errors.New("conversation_id is empty")
	}
	if len(updates) == 0 {
		return nil
	}
	row, err := d.GetActiveScheduleSessionByConversationID(ctx, userID, normalizedConversationID)
	if err != nil {
		return err
	}
	if row == nil {
		return gorm.ErrRecordNotFound
	}
	normalizedUpdates := cloneUpdateMap(updates)
	if _, ok := normalizedUpdates["updated_at"]; !ok {
		normalizedUpdates["updated_at"] = time.Now()
	}
	return d.db.WithContext(ctx).
		Model(&model.ActiveScheduleSession{}).
		Where("session_id = ?", row.SessionID).
		Updates(normalizedUpdates).Error
}
// normalizeActiveScheduleSessionSnapshot validates and canonicalizes a
// snapshot before persistence: required identifiers must be non-blank, the
// status must be one of the known values, and free-text fields are trimmed.
func normalizeActiveScheduleSessionSnapshot(snapshot *model.ActiveScheduleSessionSnapshot) (*model.ActiveScheduleSessionSnapshot, error) {
	if snapshot == nil {
		return nil, errors.New("active schedule session snapshot is nil")
	}
	sessionID := strings.TrimSpace(snapshot.SessionID)
	if sessionID == "" {
		return nil, errors.New("session_id is empty")
	}
	if snapshot.UserID <= 0 {
		return nil, fmt.Errorf("invalid user_id: %d", snapshot.UserID)
	}
	status, err := normalizeActiveScheduleSessionStatus(snapshot.Status)
	if err != nil {
		return nil, err
	}
	triggerID := strings.TrimSpace(snapshot.TriggerID)
	if triggerID == "" {
		return nil, errors.New("trigger_id is empty")
	}
	normalized := *snapshot
	normalized.SessionID = sessionID
	normalized.ConversationID = strings.TrimSpace(snapshot.ConversationID)
	normalized.TriggerID = triggerID
	normalized.CurrentPreviewID = strings.TrimSpace(snapshot.CurrentPreviewID)
	normalized.Status = status
	normalized.State = normalizeActiveScheduleSessionState(snapshot.State)
	return &normalized, nil
}
// normalizeActiveScheduleSessionStatus maps raw (possibly padded or
// mixed-case) input onto one of the known session-status constants, rejecting
// anything else.
func normalizeActiveScheduleSessionStatus(raw string) (string, error) {
	normalized := strings.ToLower(strings.TrimSpace(raw))
	for _, status := range []string{
		model.ActiveScheduleSessionStatusWaitingUserReply,
		model.ActiveScheduleSessionStatusRerunning,
		model.ActiveScheduleSessionStatusReadyPreview,
		model.ActiveScheduleSessionStatusApplied,
		model.ActiveScheduleSessionStatusIgnored,
		model.ActiveScheduleSessionStatusExpired,
		model.ActiveScheduleSessionStatusFailed,
	} {
		if normalized == status {
			return status, nil
		}
	}
	return "", fmt.Errorf("invalid active schedule session status: %s", raw)
}
// normalizeActiveScheduleSessionState trims free-text fields, drops a zero
// ExpiresAt, and dedupes MissingInfo so persisted state stays canonical.
func normalizeActiveScheduleSessionState(state model.ActiveScheduleSessionState) model.ActiveScheduleSessionState {
	for _, field := range []*string{
		&state.PendingQuestion,
		&state.LastCandidateID,
		&state.LastNotificationID,
		&state.FailedReason,
	} {
		*field = strings.TrimSpace(*field)
	}
	if state.ExpiresAt != nil && state.ExpiresAt.IsZero() {
		state.ExpiresAt = nil
	}
	if len(state.MissingInfo) > 0 {
		state.MissingInfo = dedupeAndTrimStrings(state.MissingInfo)
	}
	return state
}
// marshalActiveScheduleSessionState serializes a normalized session state,
// falling back to "{}" so the DB column never holds an empty string.
func marshalActiveScheduleSessionState(state model.ActiveScheduleSessionState) (string, error) {
	encoded, err := json.Marshal(normalizeActiveScheduleSessionState(state))
	if err != nil {
		return "", err
	}
	if text := strings.TrimSpace(string(encoded)); text != "" {
		return text, nil
	}
	return "{}", nil
}

// unmarshalActiveScheduleSessionState decodes a stored state blob; blank or
// "null" payloads decode to the zero-value state without error.
func unmarshalActiveScheduleSessionState(raw string) (model.ActiveScheduleSessionState, error) {
	var state model.ActiveScheduleSessionState
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" || trimmed == "null" {
		return state, nil
	}
	if err := json.Unmarshal([]byte(trimmed), &state); err != nil {
		return model.ActiveScheduleSessionState{}, err
	}
	return normalizeActiveScheduleSessionState(state), nil
}
// activeScheduleSessionSnapshotFromRow converts a DB row into its strongly
// typed snapshot, decoding the JSON state column along the way.
func activeScheduleSessionSnapshotFromRow(row *model.ActiveScheduleSession) (*model.ActiveScheduleSessionSnapshot, error) {
	if row == nil {
		return nil, errors.New("active schedule session row is nil")
	}
	state, err := unmarshalActiveScheduleSessionState(row.StateJSON)
	if err != nil {
		return nil, fmt.Errorf("unmarshal active schedule session state failed: %w", err)
	}
	snapshot := model.ActiveScheduleSessionSnapshot{
		SessionID:        row.SessionID,
		UserID:           row.UserID,
		ConversationID:   nullableStringValue(row.ConversationID),
		TriggerID:        row.TriggerID,
		CurrentPreviewID: nullableStringValue(row.CurrentPreviewID),
		Status:           row.Status,
		State:            state,
		CreatedAt:        row.CreatedAt,
		UpdatedAt:        row.UpdatedAt,
	}
	return &snapshot, nil
}
// nullableStringPtr returns a pointer to the trimmed value, or nil when the
// trimmed value is empty (so the column stores NULL instead of "").
func nullableStringPtr(raw string) *string {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return nil
	}
	return &trimmed
}

// nullableStringValue dereferences an optional string, mapping nil to "" and
// trimming surrounding whitespace.
func nullableStringValue(raw *string) string {
	if raw == nil {
		return ""
	}
	return strings.TrimSpace(*raw)
}

// cloneUpdateMap shallow-copies an update map; capacity is len+1 because
// callers typically add one extra column before issuing the update.
func cloneUpdateMap(updates map[string]any) map[string]any {
	out := make(map[string]any, len(updates)+1)
	for column, value := range updates {
		out[column] = value
	}
	return out
}

// dedupeAndTrimStrings trims each entry, drops blanks and duplicates, and
// keeps first-seen order; a fully-empty result collapses to nil.
func dedupeAndTrimStrings(values []string) []string {
	if len(values) == 0 {
		return nil
	}
	seen := make(map[string]struct{}, len(values))
	out := make([]string, 0, len(values))
	for _, raw := range values {
		entry := strings.TrimSpace(raw)
		if entry == "" {
			continue
		}
		if _, dup := seen[entry]; dup {
			continue
		}
		seen[entry] = struct{}{}
		out = append(out, entry)
	}
	if len(out) == 0 {
		return nil
	}
	return out
}

View File

@@ -0,0 +1,226 @@
package dao
import (
"context"
"encoding/json"
"fmt"
"strconv"
"time"
"github.com/cloudwego/eino/schema"
"github.com/go-redis/redis/v8"
)
// AgentCache wraps the Redis access for per-conversation chat history, the
// dynamic history window, conversation-status flags and compaction summaries.
type AgentCache struct {
	client *redis.Client
	// Default window size (overridden by the per-session dynamic window).
	windowSize int
	// TTL applied to every cache entry.
	expiration time.Duration
}

// Bounds for the per-session history window; out-of-range values are clamped.
const (
	minHistoryWindowSize = 16
	maxHistoryWindowSize = 4096
)

// NewAgentCache builds an AgentCache with the default window size and TTL.
func NewAgentCache(client *redis.Client) *AgentCache {
	return &AgentCache{
		client:     client,
		windowSize: 128,
		expiration: 1 * time.Hour,
	}
}
// historyKey returns the Redis list key holding a session's message history.
func (m *AgentCache) historyKey(sessionID string) string {
	return "smartflow:history:" + sessionID
}

// historyWindowKey returns the Redis key storing a session's window override.
func (m *AgentCache) historyWindowKey(sessionID string) string {
	return "smartflow:history_window:" + sessionID
}

// normalizeWindowSize clamps a requested window size into the allowed range.
func (m *AgentCache) normalizeWindowSize(size int) int {
	switch {
	case size < minHistoryWindowSize:
		return minHistoryWindowSize
	case size > maxHistoryWindowSize:
		return maxHistoryWindowSize
	default:
		return size
	}
}
// getSessionWindowSize resolves the effective window for a session: the
// per-session override when present and parseable, otherwise the default.
func (m *AgentCache) getSessionWindowSize(ctx context.Context, sessionID string) (int, error) {
	raw, err := m.client.Get(ctx, m.historyWindowKey(sessionID)).Result()
	if err == redis.Nil {
		return m.windowSize, nil
	}
	if err != nil {
		return 0, err
	}
	size, parseErr := strconv.Atoi(raw)
	if parseErr != nil {
		// An unparseable override is treated as absent rather than fatal.
		return m.windowSize, nil
	}
	return m.normalizeWindowSize(size), nil
}

// SetSessionWindowSize stores the per-session window cap (clamped) with TTL.
func (m *AgentCache) SetSessionWindowSize(ctx context.Context, sessionID string, size int) error {
	return m.client.Set(ctx, m.historyWindowKey(sessionID), m.normalizeWindowSize(size), m.expiration).Err()
}

// EnforceHistoryWindow trims the history list down to the session's current
// window and refreshes its TTL in a single pipeline round trip.
func (m *AgentCache) EnforceHistoryWindow(ctx context.Context, sessionID string) error {
	size, err := m.getSessionWindowSize(ctx, sessionID)
	if err != nil {
		return err
	}
	key := m.historyKey(sessionID)
	pipe := m.client.Pipeline()
	pipe.LTrim(ctx, key, 0, int64(size-1))
	pipe.Expire(ctx, key, m.expiration)
	_, err = pipe.Exec(ctx)
	return err
}
// PushMessage prepends one Eino message to the session's history list,
// trimming to the current window and refreshing the TTL in one pipeline.
func (m *AgentCache) PushMessage(ctx context.Context, sessionID string, msg *schema.Message) error {
	key := m.historyKey(sessionID)
	size, err := m.getSessionWindowSize(ctx, sessionID)
	if err != nil {
		return err
	}
	// 1. Serialize the Eino message.
	data, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshal message failed: %w", err)
	}
	// 2. Use a pipeline so "push + trim + TTL refresh" execute together.
	pipe := m.client.Pipeline()
	pipe.LPush(ctx, key, data)
	pipe.LTrim(ctx, key, 0, int64(size-1))
	pipe.Expire(ctx, key, m.expiration)
	_, err = pipe.Exec(ctx)
	return err
}
// GetHistory loads a session's full cached history in chronological order
// (oldest first). Returns nil when the cache holds nothing.
func (m *AgentCache) GetHistory(ctx context.Context, sessionID string) ([]*schema.Message, error) {
	entries, err := m.client.LRange(ctx, m.historyKey(sessionID), 0, -1).Result()
	if err != nil {
		return nil, err
	}
	total := len(entries)
	if total == 0 {
		return nil, nil
	}
	history := make([]*schema.Message, total)
	for i, entry := range entries {
		var msg schema.Message
		if err := json.Unmarshal([]byte(entry), &msg); err != nil {
			return nil, err
		}
		// LRANGE yields newest-first; mirror each index to restore oldest-first.
		history[total-1-i] = &msg
	}
	return history, nil
}
// BackfillHistory repopulates Redis with the full message history in one shot
// after a cache miss/expiry. An empty slice simply clears the key.
func (m *AgentCache) BackfillHistory(ctx context.Context, sessionID string, messages []*schema.Message) error {
	key := m.historyKey(sessionID)
	size, err := m.getSessionWindowSize(ctx, sessionID)
	if err != nil {
		return err
	}
	if len(messages) == 0 {
		return m.client.Del(ctx, key).Err()
	}
	values := make([]interface{}, len(messages))
	for i, msg := range messages {
		data, err := json.Marshal(msg)
		if err != nil {
			return fmt.Errorf("marshal failed at index %d: %w", i, err)
		}
		values[i] = data
	}
	// Rebuild atomically: wipe, LPush (leaves the last element — the newest
	// message — at the head), trim to window, refresh TTL.
	pipe := m.client.Pipeline()
	pipe.Del(ctx, key)
	pipe.LPush(ctx, key, values...)
	pipe.LTrim(ctx, key, 0, int64(size-1))
	pipe.Expire(ctx, key, m.expiration)
	_, err = pipe.Exec(ctx)
	return err
}
// ClearHistory removes both the history list and the window override.
func (m *AgentCache) ClearHistory(ctx context.Context, sessionID string) error {
	return m.client.Del(ctx, m.historyKey(sessionID), m.historyWindowKey(sessionID)).Err()
}

// GetConversationStatus reports whether the conversation marker key exists.
func (m *AgentCache) GetConversationStatus(ctx context.Context, sessionID string) (bool, error) {
	key := fmt.Sprintf("smartflow:conversation_status:%s", sessionID)
	count, err := m.client.Exists(ctx, key).Result()
	if err != nil {
		return false, err
	}
	return count == 1, nil
}

// SetConversationStatus writes the existence marker; SETNX only writes when
// the key is absent, avoiding redundant writes and TTL resets.
func (m *AgentCache) SetConversationStatus(ctx context.Context, sessionID string) error {
	key := fmt.Sprintf("smartflow:conversation_status:%s", sessionID)
	return m.client.SetNX(ctx, key, 1, m.expiration).Err()
}

// DeleteConversationStatus drops the existence marker.
func (m *AgentCache) DeleteConversationStatus(ctx context.Context, sessionID string) error {
	key := fmt.Sprintf("smartflow:conversation_status:%s", sessionID)
	return m.client.Del(ctx, key).Err()
}
// ---- Compaction cache ----

// compactionKey returns the Redis key for a chat's compaction summary.
func (m *AgentCache) compactionKey(chatID string) string {
	return fmt.Sprintf("smartflow:compaction:%s", chatID)
}

// SaveCompactionCache stores the compaction summary and its watermark in
// Redis with the standard TTL.
func (m *AgentCache) SaveCompactionCache(ctx context.Context, chatID string, summary string, watermark int) error {
	key := m.compactionKey(chatID)
	// The marshal error used to be discarded with `_`; surface it so a failed
	// serialization can never silently overwrite the cache with garbage.
	data, err := json.Marshal(map[string]any{
		"summary":   summary,
		"watermark": watermark,
	})
	if err != nil {
		return fmt.Errorf("marshal compaction cache failed: %w", err)
	}
	return m.client.Set(ctx, key, data, m.expiration).Err()
}
// LoadCompactionCache reads the compaction summary cache from Redis.
//
// Returns ok=false with a nil error when the key is missing or the stored
// payload cannot be decoded, so callers treat both as a cache miss.
func (m *AgentCache) LoadCompactionCache(ctx context.Context, chatID string) (summary string, watermark int, ok bool, err error) {
	key := m.compactionKey(chatID)
	val, err := m.client.Get(ctx, key).Result()
	if err != nil {
		if err == redis.Nil {
			return "", 0, false, nil
		}
		return "", 0, false, err
	}
	var data struct {
		Summary   string `json:"summary"`
		Watermark int    `json:"watermark"`
	}
	// NOTE(review): a decode failure is deliberately swallowed and reported
	// as a miss, so a corrupt entry degrades to recomputation, not an error.
	if jsonErr := json.Unmarshal([]byte(val), &data); jsonErr != nil {
		return "", 0, false, nil
	}
	return data.Summary, data.Watermark, true, nil
}

View File

@@ -0,0 +1,483 @@
package dao
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// AgentDAO bundles the MySQL access for agent conversations, chat histories,
// compaction summaries, schedule state snapshots and timeline events.
type AgentDAO struct {
	db *gorm.DB
}

// NewAgentDAO wires an AgentDAO onto the given GORM handle.
func NewAgentDAO(db *gorm.DB) *AgentDAO {
	return &AgentDAO{db: db}
}

// WithTx returns a copy of the DAO bound to the supplied transaction handle,
// so callers can compose several DAO calls inside one transaction.
// (Receiver renamed from "r" to "a" for consistency with every other method
// on this type.)
func (a *AgentDAO) WithTx(tx *gorm.DB) *AgentDAO {
	return &AgentDAO{db: tx}
}
// saveChatHistoryCore is the core implementation of "persist a chat message +
// update conversation statistics".
//
// Responsibilities:
//  1. performs database writes only on the current DAO handle;
//  2. never opens a transaction itself (that is the caller's decision);
//  3. keeps chat_histories and agent_chats.message_count consistent.
//
// Failure handling:
//  1. any failing step returns an error;
//  2. when the caller runs inside a transaction, the returned error triggers
//     a rollback.
//
// About the retry columns:
//  1. the retry mechanism has been retired; this function no longer writes
//     the retry_group_id / retry_index / retry_from_* columns;
//  2. those columns remain on the GORM ChatHistory model for now; they are
//     nullable, so historical rows are unaffected;
//  3. Step B will DROP COLUMN them in a migration.
func (a *AgentDAO) saveChatHistoryCore(ctx context.Context, userID int, conversationID string, role, message, reasoningContent string, reasoningDurationSeconds int, tokensConsumed int, sourceEventID string) error {
	// 0. Sanitize tokens before persisting: clamp negatives to zero so bad
	//    values cannot pollute the accumulated statistics.
	if tokensConsumed < 0 {
		tokensConsumed = 0
	}
	reasoningContent = strings.TrimSpace(reasoningContent)
	if reasoningDurationSeconds < 0 {
		reasoningDurationSeconds = 0
	}
	// Idempotency guard: when a source event ID is supplied, lock the
	// conversation row (SELECT ... FOR UPDATE) and skip the whole write if
	// that event was already applied.
	normalizedEventID := strings.TrimSpace(sourceEventID)
	var normalizedEventIDPtr *string
	if normalizedEventID != "" {
		normalizedEventIDPtr = &normalizedEventID
		var chat model.AgentChat
		err := a.db.WithContext(ctx).
			Clauses(clause.Locking{Strength: "UPDATE"}).
			Select("last_history_event_id").
			Where("user_id = ? AND chat_id = ?", userID, conversationID).
			First(&chat).Error
		if err != nil {
			return err
		}
		if chat.LastHistoryEventID != nil && strings.TrimSpace(*chat.LastHistoryEventID) == normalizedEventID {
			return nil
		}
	}
	// 1. Write the raw message into chat_histories first.
	var reasoningContentPtr *string
	if reasoningContent != "" {
		reasoningContentPtr = &reasoningContent
	}
	userChat := model.ChatHistory{
		SourceEventID:            normalizedEventIDPtr,
		UserID:                   userID,
		MessageContent:           &message,
		ReasoningContent:         reasoningContentPtr,
		ReasoningDurationSeconds: reasoningDurationSeconds,
		Role:                     &role,
		ChatID:                   conversationID,
		TokensConsumed:           tokensConsumed,
	}
	if err := a.db.WithContext(ctx).Create(&userChat).Error; err != nil {
		return err
	}
	// 2. Then update the conversation stats so message_count / tokens_total /
	//    last_message_at advance in lockstep with the message row.
	now := time.Now()
	updates := map[string]interface{}{
		"message_count":   gorm.Expr("message_count + ?", 1),
		"tokens_total":    gorm.Expr("tokens_total + ?", tokensConsumed),
		"last_message_at": &now,
	}
	if normalizedEventIDPtr != nil {
		updates["last_history_event_id"] = normalizedEventIDPtr
	}
	result := a.db.WithContext(ctx).Model(&model.AgentChat{}).
		Where("user_id = ? AND chat_id = ?", userID, conversationID).
		Updates(updates)
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected == 0 {
		return fmt.Errorf("conversation not found when updating stats: user_id=%d chat_id=%s", userID, conversationID)
	}
	return nil
}
// SaveChatHistoryInTx writes chat history when the caller has ALREADY opened
// a transaction.
//
// Purpose:
//  1. lets the service layer compose several DAO operations without nesting
//     transactions;
//  2. lets the outbox consumer share one tx with the business write.
func (a *AgentDAO) SaveChatHistoryInTx(ctx context.Context, userID int, conversationID string, role, message, reasoningContent string, reasoningDurationSeconds int, tokensConsumed int, sourceEventID string) error {
	return a.saveChatHistoryCore(ctx, userID, conversationID, role, message, reasoningContent, reasoningDurationSeconds, tokensConsumed, sourceEventID)
}

// SaveChatHistory writes chat history on the synchronous direct-write path.
//
// Notes:
//  1. this method opens its own transaction;
//  2. it reuses saveChatHistoryCore, so its business semantics exactly match
//     SaveChatHistoryInTx.
func (a *AgentDAO) SaveChatHistory(ctx context.Context, userID int, conversationID string, role, message, reasoningContent string, reasoningDurationSeconds int, tokensConsumed int, sourceEventID string) error {
	return a.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		return a.WithTx(tx).saveChatHistoryCore(ctx, userID, conversationID, role, message, reasoningContent, reasoningDurationSeconds, tokensConsumed, sourceEventID)
	})
}
// adjustTokenUsageCore applies an incremental adjustment to the per-
// conversation token ledger under the caller's transaction semantics.
//
// Responsibilities:
//  1. only updates agent_chats.tokens_total;
//  2. never writes chat_histories (message persistence goes through the
//     SaveChatHistory* paths);
//  3. deltaTokens <= 0 is a no-op and returns immediately.
func (a *AgentDAO) adjustTokenUsageCore(ctx context.Context, userID int, conversationID string, deltaTokens int, eventID string) error {
	if deltaTokens <= 0 {
		return nil
	}
	// Idempotency guard: lock the conversation row and skip the adjustment
	// when this event ID was already applied.
	normalizedEventID := strings.TrimSpace(eventID)
	var normalizedEventIDPtr *string
	if normalizedEventID != "" {
		normalizedEventIDPtr = &normalizedEventID
		var chat model.AgentChat
		err := a.db.WithContext(ctx).
			Clauses(clause.Locking{Strength: "UPDATE"}).
			Select("last_token_adjust_event_id").
			Where("user_id = ? AND chat_id = ?", userID, conversationID).
			First(&chat).Error
		if err != nil {
			return err
		}
		if chat.LastTokenAdjustEventID != nil && strings.TrimSpace(*chat.LastTokenAdjustEventID) == normalizedEventID {
			return nil
		}
	}
	chatUpdate := a.db.WithContext(ctx).
		Model(&model.AgentChat{}).
		Where("user_id = ? AND chat_id = ?", userID, conversationID).
		Updates(map[string]interface{}{
			"tokens_total":               gorm.Expr("tokens_total + ?", deltaTokens),
			"last_token_adjust_event_id": normalizedEventIDPtr,
		})
	if chatUpdate.Error != nil {
		return chatUpdate.Error
	}
	if chatUpdate.RowsAffected == 0 {
		return fmt.Errorf("conversation not found when adjusting tokens: user_id=%d chat_id=%s", userID, conversationID)
	}
	return nil
}

// AdjustTokenUsageInTx adjusts the token ledger when the caller already holds
// an open transaction.
func (a *AgentDAO) AdjustTokenUsageInTx(ctx context.Context, userID int, conversationID string, deltaTokens int, eventID string) error {
	return a.adjustTokenUsageCore(ctx, userID, conversationID, deltaTokens, eventID)
}

// AdjustTokenUsage adjusts the token ledger on the synchronous path (it opens
// its own transaction internally).
func (a *AgentDAO) AdjustTokenUsage(ctx context.Context, userID int, conversationID string, deltaTokens int, eventID string) error {
	return a.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		return a.WithTx(tx).adjustTokenUsageCore(ctx, userID, conversationID, deltaTokens, eventID)
	})
}
// CreateNewChat inserts a fresh conversation row and returns its primary key.
//
// NOTE(review): unlike the sibling methods this one takes no ctx and uses the
// bare a.db handle — confirm whether it should accept a context for
// cancellation/tracing parity.
func (a *AgentDAO) CreateNewChat(userID int, chatID string) (int64, error) {
	chat := model.AgentChat{
		ChatID:        chatID,
		UserID:        userID,
		MessageCount:  0,
		LastMessageAt: nil,
	}
	if err := a.db.Create(&chat).Error; err != nil {
		return 0, err
	}
	return chat.ID, nil
}
// GetUserChatHistories returns the most recent `limit` messages of a chat,
// delivered in chronological (oldest-first) order.
func (a *AgentDAO) GetUserChatHistories(ctx context.Context, userID, limit int, chatID string) ([]model.ChatHistory, error) {
	var histories []model.ChatHistory
	query := a.db.WithContext(ctx).
		Where("user_id = ? AND chat_id = ?", userID, chatID).
		Order("created_at desc").
		Limit(limit)
	if err := query.Find(&histories).Error; err != nil {
		return nil, err
	}
	// The query keeps the newest N rows; flip them in place so the model
	// consumes the conversation oldest-first.
	for left, right := 0, len(histories)-1; left < right; left, right = left+1, right-1 {
		histories[left], histories[right] = histories[right], histories[left]
	}
	return histories, nil
}

// IfChatExists reports whether the given conversation belongs to the user.
func (a *AgentDAO) IfChatExists(ctx context.Context, userID int, chatID string) (bool, error) {
	var chat model.AgentChat
	err := a.db.WithContext(ctx).Where("user_id = ? AND chat_id = ?", userID, chatID).First(&chat).Error
	switch {
	case err == nil:
		return true, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return false, nil
	default:
		return false, err
	}
}
// GetConversationMeta fetches a single conversation's metadata columns.
func (a *AgentDAO) GetConversationMeta(ctx context.Context, userID int, chatID string) (*model.AgentChat, error) {
	chat := model.AgentChat{}
	err := a.db.WithContext(ctx).
		Select("chat_id", "title", "message_count", "last_message_at", "status").
		Where("user_id = ? AND chat_id = ?", userID, chatID).
		First(&chat).Error
	if err != nil {
		return nil, err
	}
	return &chat, nil
}

// GetConversationTitle reads the current conversation title; exists=false
// means the conversation itself was not found.
func (a *AgentDAO) GetConversationTitle(ctx context.Context, userID int, chatID string) (title string, exists bool, err error) {
	var chat model.AgentChat
	queryErr := a.db.WithContext(ctx).
		Select("title").
		Where("user_id = ? AND chat_id = ?", userID, chatID).
		First(&chat).Error
	switch {
	case errors.Is(queryErr, gorm.ErrRecordNotFound):
		return "", false, nil
	case queryErr != nil:
		return "", false, queryErr
	case chat.Title == nil:
		return "", true, nil
	}
	return strings.TrimSpace(*chat.Title), true, nil
}

// UpdateConversationTitleIfEmpty sets the title only when none is set yet;
// blank input is ignored entirely.
func (a *AgentDAO) UpdateConversationTitleIfEmpty(ctx context.Context, userID int, chatID, title string) error {
	trimmed := strings.TrimSpace(title)
	if trimmed == "" {
		return nil
	}
	return a.db.WithContext(ctx).
		Model(&model.AgentChat{}).
		Where("user_id = ? AND chat_id = ? AND (title IS NULL OR title = '')", userID, chatID).
		Update("title", trimmed).Error
}
// GetConversationList pages through the given user's conversation list.
//
// Responsibilities:
//  1. read-only DB access — no caching;
//  2. enforces user_id isolation only; parameter validation is the service
//     layer's job;
//  3. returns the total row count so the caller can compute has_more.
func (a *AgentDAO) GetConversationList(ctx context.Context, userID, page, pageSize int, status string) ([]model.AgentChat, int64, error) {
	// 1. Build the shared filter first so total and list use the same scope.
	baseQuery := a.db.WithContext(ctx).Model(&model.AgentChat{}).Where("user_id = ?", userID)
	if strings.TrimSpace(status) != "" {
		baseQuery = baseQuery.Where("status = ?", status)
	}
	// 2. Count first so the frontend pager gets complete metadata.
	var total int64
	if err := baseQuery.Count(&total).Error; err != nil {
		return nil, 0, err
	}
	if total == 0 {
		return make([]model.AgentChat, 0), 0, nil
	}
	// 3. Then fetch the current page:
	//    3.1 most recent activity first so active chats surface on top;
	//    3.2 id DESC as a tiebreaker so paging order never jitters when
	//        timestamps collide.
	offset := (page - 1) * pageSize
	var chats []model.AgentChat
	query := a.db.WithContext(ctx).
		Model(&model.AgentChat{}).
		Select("id", "chat_id", "title", "message_count", "last_message_at", "status", "created_at").
		Where("user_id = ?", userID)
	if strings.TrimSpace(status) != "" {
		query = query.Where("status = ?", status)
	}
	if err := query.Order("last_message_at DESC").
		Order("id DESC").
		Offset(offset).
		Limit(pageSize).
		Find(&chats).Error; err != nil {
		return nil, 0, err
	}
	return chats, total, nil
}
// ---- Compaction summary persistence ----
//
// 1. The legacy SaveCompaction / LoadCompaction API is kept and reads/writes
//    only the "execute" stage.
// 2. The stage-keyed API buckets data per stageKey; everything still lives in
//    agent_chats.compaction_summary.
// 3. For backward compatibility, an old plain-string compaction_summary is
//    transparently read as legacy data.
func (a *AgentDAO) SaveCompaction(ctx context.Context, userID int, chatID string, summary string, watermark int) error {
	return a.SaveStageCompaction(ctx, userID, chatID, "execute", summary, watermark)
}
func (a *AgentDAO) LoadCompaction(ctx context.Context, userID int, chatID string) (summary string, watermark int, err error) {
	return a.LoadStageCompaction(ctx, userID, chatID, "execute")
}
// SaveContextTokenStats persists the context-window token distribution JSON.
func (a *AgentDAO) SaveContextTokenStats(ctx context.Context, userID int, chatID string, statsJSON string) error {
	return a.db.WithContext(ctx).
		Model(&model.AgentChat{}).
		Where("user_id = ? AND chat_id = ?", userID, chatID).
		Update("context_token_stats", statsJSON).Error
}

// LoadContextTokenStats reads the context-window token distribution JSON; a
// NULL column comes back as "".
func (a *AgentDAO) LoadContextTokenStats(ctx context.Context, userID int, chatID string) (string, error) {
	var chat model.AgentChat
	if err := a.db.WithContext(ctx).
		Select("context_token_stats").
		Where("user_id = ? AND chat_id = ?", userID, chatID).
		First(&chat).Error; err != nil {
		return "", err
	}
	if stats := chat.ContextTokenStats; stats != nil {
		return *stats, nil
	}
	return "", nil
}
type stageCompactionRecord struct {
Summary string `json:"summary"`
Watermark int `json:"watermark"`
}
type stageCompactionEnvelope struct {
Version int `json:"version"`
Stages map[string]stageCompactionRecord `json:"stages"`
}
// normalizeCompactionStageKey 统一 stageKey 的写法,避免 "Execute" 和 "execute" 被当成两个键。
func normalizeCompactionStageKey(stageKey string) string {
key := strings.ToLower(strings.TrimSpace(stageKey))
if key == "" {
return "execute"
}
return key
}
// loadStageCompactionStages 负责把数据库里的压缩摘要统一解包成 stage -> record。
//
// 1. 先处理空值,避免后续逻辑误判。
// 2. 如果已经是 JSON envelope就按 stage 逐项读取。
// 3. 如果还是旧版纯字符串,就把它当作 execute 阶段的兼容数据。
func loadStageCompactionStages(summary *string, watermark int) map[string]stageCompactionRecord {
stages := map[string]stageCompactionRecord{}
if summary == nil {
return stages
}
raw := strings.TrimSpace(*summary)
if raw == "" {
return stages
}
var env stageCompactionEnvelope
if err := json.Unmarshal([]byte(raw), &env); err == nil && len(env.Stages) > 0 {
for key, record := range env.Stages {
stages[normalizeCompactionStageKey(key)] = stageCompactionRecord{
Summary: strings.TrimSpace(record.Summary),
Watermark: record.Watermark,
}
}
return stages
}
stages["execute"] = stageCompactionRecord{
Summary: raw,
Watermark: watermark,
}
return stages
}
// marshalStageCompactionStages 负责把按阶段分桶后的摘要重新编码为 JSON envelope。
func marshalStageCompactionStages(stages map[string]stageCompactionRecord) (string, error) {
env := stageCompactionEnvelope{
Version: 1,
Stages: stages,
}
data, err := json.Marshal(env)
if err != nil {
return "", err
}
return string(data), nil
}
// LoadStageCompaction reads one stage's compaction summary and watermark;
// an unknown stage yields zero values without error.
func (a *AgentDAO) LoadStageCompaction(ctx context.Context, userID int, chatID string, stageKey string) (summary string, watermark int, err error) {
	var chat model.AgentChat
	err = a.db.WithContext(ctx).
		Select("compaction_summary", "compaction_watermark").
		Where("user_id = ? AND chat_id = ?", userID, chatID).
		First(&chat).Error
	if err != nil {
		return "", 0, err
	}
	stages := loadStageCompactionStages(chat.CompactionSummary, chat.CompactionWatermark)
	record, found := stages[normalizeCompactionStageKey(stageKey)]
	if !found {
		return "", 0, nil
	}
	return record.Summary, record.Watermark, nil
}
// SaveStageCompaction stores one stage's compaction summary and watermark.
//
//  1. Reads the current payload first so other stages' buckets survive.
//  2. Replaces only the bucket for the requested stage.
//  3. Writes the whole JSON envelope back, keeping the legacy
//     compaction_watermark column aligned with the execute stage.
//
// NOTE(review): this read-modify-write is not wrapped in a transaction or a
// row lock, so two concurrent saves for different stages of the same chat
// could lose one bucket — confirm whether callers serialize these writes.
func (a *AgentDAO) SaveStageCompaction(ctx context.Context, userID int, chatID string, stageKey string, summary string, watermark int) error {
	stageKey = normalizeCompactionStageKey(stageKey)
	var chat model.AgentChat
	err := a.db.WithContext(ctx).
		Select("compaction_summary", "compaction_watermark").
		Where("user_id = ? AND chat_id = ?", userID, chatID).
		First(&chat).Error
	if err != nil {
		return err
	}
	stages := loadStageCompactionStages(chat.CompactionSummary, chat.CompactionWatermark)
	stages[stageKey] = stageCompactionRecord{
		Summary:   strings.TrimSpace(summary),
		Watermark: watermark,
	}
	payload, err := marshalStageCompactionStages(stages)
	if err != nil {
		return err
	}
	// The legacy column mirrors the execute stage's watermark when present.
	legacyWatermark := watermark
	if executeRecord, ok := stages["execute"]; ok {
		legacyWatermark = executeRecord.Watermark
	}
	return a.db.WithContext(ctx).
		Model(&model.AgentChat{}).
		Where("user_id = ? AND chat_id = ?", userID, chatID).
		Updates(map[string]any{
			"compaction_summary":   payload,
			"compaction_watermark": legacyWatermark,
		}).Error
}

View File

@@ -0,0 +1,252 @@
package dao
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// UpsertScheduleStateSnapshot writes/overwrites the schedule state snapshot
// keyed by "user_id + conversation_id".
//
// Responsibilities:
//  1. serializes the strongly typed snapshot into agent_schedule_states;
//  2. upserts on conflict (same conversation overwrites) with an automatic
//     revision+1;
//  3. NOT responsible for Redis cache I/O, business routing, or persisting
//     the finalized schedule.
//
// Steps:
//  1. validate arguments and key semantics first, so a dirty snapshot never
//     reaches the database;
//  2. serialize every slice field to JSON to keep the table's columns in a
//     stable shape;
//  3. run the OnConflict upsert:
//     3.1 brand-new records insert directly;
//     3.2 existing records have their business columns overwritten and the
//         revision incremented;
//     3.3 any failure returns an error; the caller decides how to degrade.
func (a *AgentDAO) UpsertScheduleStateSnapshot(ctx context.Context, snapshot *model.SchedulePlanStateSnapshot) error {
	if a == nil || a.db == nil {
		return errors.New("agent dao is not initialized")
	}
	if snapshot == nil {
		return errors.New("schedule state snapshot is nil")
	}
	if snapshot.UserID <= 0 {
		return fmt.Errorf("invalid snapshot user_id: %d", snapshot.UserID)
	}
	conversationID := strings.TrimSpace(snapshot.ConversationID)
	if conversationID == "" {
		return errors.New("schedule state snapshot conversation_id is empty")
	}
	taskClassIDsJSON, err := marshalJSONOrDefault(snapshot.TaskClassIDs, "[]")
	if err != nil {
		return fmt.Errorf("marshal task_class_ids failed: %w", err)
	}
	constraintsJSON, err := marshalJSONOrDefault(snapshot.Constraints, "[]")
	if err != nil {
		return fmt.Errorf("marshal constraints failed: %w", err)
	}
	hybridEntriesJSON, err := marshalJSONOrDefault(snapshot.HybridEntries, "[]")
	if err != nil {
		return fmt.Errorf("marshal hybrid_entries failed: %w", err)
	}
	allocatedItemsJSON, err := marshalJSONOrDefault(snapshot.AllocatedItems, "[]")
	if err != nil {
		return fmt.Errorf("marshal allocated_items failed: %w", err)
	}
	candidatePlansJSON, err := marshalJSONOrDefault(snapshot.CandidatePlans, "[]")
	if err != nil {
		return fmt.Errorf("marshal candidate_plans failed: %w", err)
	}
	// Default the version and revision so brand-new snapshots start valid.
	stateVersion := snapshot.StateVersion
	if stateVersion <= 0 {
		stateVersion = model.SchedulePlanStateVersionV1
	}
	revision := snapshot.Revision
	if revision <= 0 {
		revision = 1
	}
	row := model.AgentScheduleState{
		UserID:             snapshot.UserID,
		ConversationID:     conversationID,
		Revision:           revision,
		StateVersion:       stateVersion,
		TaskClassIDsJSON:   taskClassIDsJSON,
		ConstraintsJSON:    constraintsJSON,
		HybridEntriesJSON:  hybridEntriesJSON,
		AllocatedItemsJSON: allocatedItemsJSON,
		CandidatePlansJSON: candidatePlansJSON,
		UserIntent:         strings.TrimSpace(snapshot.UserIntent),
		Strategy:           normalizeStrategy(snapshot.Strategy),
		AdjustmentScope:    normalizeAdjustmentScope(snapshot.AdjustmentScope),
		RestartRequested:   snapshot.RestartRequested,
		FinalSummary:       strings.TrimSpace(snapshot.FinalSummary),
		Completed:          snapshot.Completed,
		TraceID:            strings.TrimSpace(snapshot.TraceID),
	}
	now := time.Now()
	return a.db.WithContext(ctx).Clauses(clause.OnConflict{
		Columns: []clause.Column{
			{Name: "user_id"},
			{Name: "conversation_id"},
		},
		DoUpdates: clause.Assignments(map[string]any{
			"revision":          gorm.Expr("revision + 1"),
			"state_version":     row.StateVersion,
			"task_class_ids":    row.TaskClassIDsJSON,
			"constraints":       row.ConstraintsJSON,
			"hybrid_entries":    row.HybridEntriesJSON,
			"allocated_items":   row.AllocatedItemsJSON,
			"candidate_plans":   row.CandidatePlansJSON,
			"user_intent":       row.UserIntent,
			"strategy":          row.Strategy,
			"adjustment_scope":  row.AdjustmentScope,
			"restart_requested": row.RestartRequested,
			"final_summary":     row.FinalSummary,
			"completed":         row.Completed,
			"trace_id":          row.TraceID,
			"updated_at":        now,
		}),
	}).Create(&row).Error
}
// GetScheduleStateSnapshot reads one conversation's schedule state snapshot.
//
// Responsibilities:
//  1. query by user_id + conversation_id;
//  2. deserialize the JSON columns back into the strongly typed struct;
//  3. NOT responsible for Redis backfill or business routing.
//
// Return semantics:
//  1. hit: snapshot, nil;
//  2. miss: nil, nil (callers may fall through to other fallbacks);
//  3. decode failure: error — stored data is invalid and needs triage.
func (a *AgentDAO) GetScheduleStateSnapshot(ctx context.Context, userID int, conversationID string) (*model.SchedulePlanStateSnapshot, error) {
	if a == nil || a.db == nil {
		return nil, errors.New("agent dao is not initialized")
	}
	if userID <= 0 {
		return nil, fmt.Errorf("invalid user_id: %d", userID)
	}
	normalizedConversationID := strings.TrimSpace(conversationID)
	if normalizedConversationID == "" {
		return nil, errors.New("conversation_id is empty")
	}
	var row model.AgentScheduleState
	err := a.db.WithContext(ctx).
		Where("user_id = ? AND conversation_id = ?", userID, normalizedConversationID).
		First(&row).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	taskClassIDs := make([]int, 0)
	if err = unmarshalJSONOrDefault(row.TaskClassIDsJSON, &taskClassIDs, []int{}); err != nil {
		return nil, fmt.Errorf("unmarshal task_class_ids failed: %w", err)
	}
	constraints := make([]string, 0)
	if err = unmarshalJSONOrDefault(row.ConstraintsJSON, &constraints, []string{}); err != nil {
		return nil, fmt.Errorf("unmarshal constraints failed: %w", err)
	}
	hybridEntries := make([]model.HybridScheduleEntry, 0)
	if err = unmarshalJSONOrDefault(row.HybridEntriesJSON, &hybridEntries, []model.HybridScheduleEntry{}); err != nil {
		return nil, fmt.Errorf("unmarshal hybrid_entries failed: %w", err)
	}
	allocatedItems := make([]model.TaskClassItem, 0)
	if err = unmarshalJSONOrDefault(row.AllocatedItemsJSON, &allocatedItems, []model.TaskClassItem{}); err != nil {
		return nil, fmt.Errorf("unmarshal allocated_items failed: %w", err)
	}
	candidatePlans := make([]model.UserWeekSchedule, 0)
	if err = unmarshalJSONOrDefault(row.CandidatePlansJSON, &candidatePlans, []model.UserWeekSchedule{}); err != nil {
		return nil, fmt.Errorf("unmarshal candidate_plans failed: %w", err)
	}
	return &model.SchedulePlanStateSnapshot{
		UserID:           row.UserID,
		ConversationID:   row.ConversationID,
		Revision:         row.Revision,
		StateVersion:     row.StateVersion,
		TaskClassIDs:     taskClassIDs,
		Constraints:      constraints,
		HybridEntries:    hybridEntries,
		AllocatedItems:   allocatedItems,
		CandidatePlans:   candidatePlans,
		UserIntent:       row.UserIntent,
		Strategy:         normalizeStrategy(row.Strategy),
		AdjustmentScope:  normalizeAdjustmentScope(row.AdjustmentScope),
		RestartRequested: row.RestartRequested,
		FinalSummary:     row.FinalSummary,
		Completed:        row.Completed,
		TraceID:          row.TraceID,
		UpdatedAt:        row.UpdatedAt,
	}, nil
}
// marshalJSONOrDefault serializes v to a JSON string, substituting
// defaultJSON when v is nil or encodes to nothing/"null" — keeping DB columns
// on a stable, non-empty JSON footing and letting real encode errors bubble
// up instead of persisting a half-written snapshot.
func marshalJSONOrDefault(v any, defaultJSON string) (string, error) {
	if v == nil {
		return defaultJSON, nil
	}
	encoded, err := json.Marshal(v)
	if err != nil {
		return "", err
	}
	text := strings.TrimSpace(string(encoded))
	if text == "" || text == "null" {
		return defaultJSON, nil
	}
	return text, nil
}

// unmarshalJSONOrDefault decodes raw into target, substituting defaultValue
// when raw is blank or "null" so callers never need to null-check fields;
// decode errors are propagated to expose historical dirty data.
func unmarshalJSONOrDefault[T any](raw string, target *T, defaultValue T) error {
	switch trimmed := strings.TrimSpace(raw); trimmed {
	case "", "null":
		*target = defaultValue
		return nil
	default:
		return json.Unmarshal([]byte(trimmed), target)
	}
}
// normalizeStrategy canonicalizes the snapshot strategy field; anything other
// than "rapid" (case-insensitive, trimmed) falls back to "steady".
func normalizeStrategy(raw string) string {
	if strings.ToLower(strings.TrimSpace(raw)) == "rapid" {
		return "rapid"
	}
	return "steady"
}

// normalizeAdjustmentScope canonicalizes the fine-tune scope field; anything
// outside {"small", "medium"} falls back to "large".
func normalizeAdjustmentScope(raw string) string {
	scope := strings.ToLower(strings.TrimSpace(raw))
	if scope == "small" || scope == "medium" {
		return scope
	}
	return "large"
}

View File

@@ -0,0 +1,53 @@
package dao
import (
"context"
"errors"
agentmodel "github.com/LoveLosita/smartflow/backend/services/agent/model"
)
// AgentStateStoreAdapter adapts CacheDAO to the agent's AgentStateStore
// interface.
//
// Responsibilities:
//  1. CacheDAO.LoadAgentState uses an out-parameter style that must be
//     adapted to a return-value style;
//  2. CacheDAO.SaveAgentState accepts any and must be adapted to
//     *AgentStateSnapshot;
//  3. DeleteAgentState's signature already matches and is forwarded directly.
type AgentStateStoreAdapter struct {
	cache *CacheDAO
}

// NewAgentStateStoreAdapter creates the adapter.
func NewAgentStateStoreAdapter(cache *CacheDAO) *AgentStateStoreAdapter {
	return &AgentStateStoreAdapter{cache: cache}
}
// Save serializes and stores an agent state snapshot.
func (a *AgentStateStoreAdapter) Save(ctx context.Context, conversationID string, snapshot *agentmodel.AgentStateSnapshot) error {
	if a == nil || a.cache == nil {
		return errors.New("agent state store adapter is not initialized")
	}
	return a.cache.SaveAgentState(ctx, conversationID, snapshot)
}

// Load fetches and deserializes an agent state snapshot; the boolean reports
// whether a snapshot exists for the conversation.
func (a *AgentStateStoreAdapter) Load(ctx context.Context, conversationID string) (*agentmodel.AgentStateSnapshot, bool, error) {
	if a == nil || a.cache == nil {
		return nil, false, errors.New("agent state store adapter is not initialized")
	}
	var snapshot agentmodel.AgentStateSnapshot
	found, err := a.cache.LoadAgentState(ctx, conversationID, &snapshot)
	if err != nil || !found {
		return nil, found, err
	}
	return &snapshot, true, nil
}

// Delete removes an agent state snapshot.
func (a *AgentStateStoreAdapter) Delete(ctx context.Context, conversationID string) error {
	if a == nil || a.cache == nil {
		return errors.New("agent state store adapter is not initialized")
	}
	return a.cache.DeleteAgentState(ctx, conversationID)
}

View File

@@ -0,0 +1,86 @@
package dao
import (
"context"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
)
// SaveConversationTimelineEvent persists one conversation timeline event to
// MySQL and returns the new row's id and creation timestamp.
//
// Scope:
//  1. single-row insert only — seq allocation happens elsewhere;
//  2. only normalizes fields (trims whitespace, NULLs out blanks); it never
//     rewrites business semantics;
//  3. the error is returned so the caller decides whether to abort the chain.
func (a *AgentDAO) SaveConversationTimelineEvent(ctx context.Context, payload model.ChatTimelinePersistPayload) (int64, *time.Time, error) {
	// ptrIfNotBlank trims the value and maps empty strings to nil so the
	// corresponding nullable column stores NULL rather than "".
	ptrIfNotBlank := func(value string) *string {
		trimmed := strings.TrimSpace(value)
		if trimmed == "" {
			return nil
		}
		return &trimmed
	}
	event := model.AgentTimelineEvent{
		UserID:         payload.UserID,
		ChatID:         strings.TrimSpace(payload.ConversationID),
		Seq:            payload.Seq,
		Kind:           strings.TrimSpace(payload.Kind),
		Role:           ptrIfNotBlank(payload.Role),
		Content:        ptrIfNotBlank(payload.Content),
		Payload:        ptrIfNotBlank(payload.PayloadJSON),
		TokensConsumed: payload.TokensConsumed,
	}
	if err := a.db.WithContext(ctx).Create(&event).Error; err != nil {
		return 0, nil, err
	}
	return event.ID, event.CreatedAt, nil
}
// ListConversationTimelineEvents returns the conversation timeline ordered by
// seq ascending (id ascending as a tie-breaker).
func (a *AgentDAO) ListConversationTimelineEvents(ctx context.Context, userID int, chatID string) ([]model.AgentTimelineEvent, error) {
	var events []model.AgentTimelineEvent
	if err := a.db.WithContext(ctx).
		Where("user_id = ? AND chat_id = ?", userID, strings.TrimSpace(chatID)).
		Order("seq ASC").
		Order("id ASC").
		Find(&events).Error; err != nil {
		return nil, err
	}
	return events, nil
}
// GetConversationTimelineMaxSeq returns the current maximum seq for a
// conversation timeline.
//
// Notes:
//  1. mainly used as the DB fallback when the Redis sequence counter is unavailable;
//  2. no rows yields 0 (COALESCE), which is not treated as an error;
//  3. the caller must add 1 itself before writing a new event.
func (a *AgentDAO) GetConversationTimelineMaxSeq(ctx context.Context, userID int, chatID string) (int64, error) {
	var maxSeq int64
	if err := a.db.WithContext(ctx).
		Model(&model.AgentTimelineEvent{}).
		Where("user_id = ? AND chat_id = ?", userID, strings.TrimSpace(chatID)).
		Select("COALESCE(MAX(seq), 0)").
		Scan(&maxSeq).Error; err != nil {
		return 0, err
	}
	return maxSeq, nil
}

View File

@@ -0,0 +1,64 @@
package dao
import (
"context"
"gorm.io/gorm"
)
// RepoManager aggregates every DAO so the service layer can orchestrate
// cross-repository transactions through a single handle.
type RepoManager struct {
	db *gorm.DB
	Schedule *ScheduleDAO
	Task *TaskDAO
	Course *CourseDAO
	TaskClass *TaskClassDAO
	Agent *AgentDAO
	ActiveSchedule *ActiveScheduleDAO
	ActiveScheduleSession *ActiveScheduleSessionDAO
}
// NewManager builds a RepoManager with every DAO bound to the given *gorm.DB.
func NewManager(db *gorm.DB) *RepoManager {
	m := &RepoManager{db: db}
	m.Schedule = NewScheduleDAO(db)
	m.Task = NewTaskDAO(db)
	m.Course = NewCourseDAO(db)
	m.TaskClass = NewTaskClassDAO(db)
	m.Agent = NewAgentDAO(db)
	m.ActiveSchedule = NewActiveScheduleDAO(db)
	m.ActiveScheduleSession = NewActiveScheduleSessionDAO(db)
	return m
}
// WithTx builds a "same-transaction RepoManager" from an external tx handle.
//
// Responsibility boundaries:
//  1. only rebinds DAO dependencies; it never begins/commits/rolls back;
//  2. lets the service layer call multiple DAO methods inside one tx;
//  3. fits "infrastructure tx + business tx merged" flows such as outbox consumers.
func (m *RepoManager) WithTx(tx *gorm.DB) *RepoManager {
	return &RepoManager{
		db:                    tx,
		Schedule:              m.Schedule.WithTx(tx),
		Task:                  m.Task.WithTx(tx),
		Course:                m.Course.WithTx(tx),
		TaskClass:             m.TaskClass.WithTx(tx),
		Agent:                 m.Agent.WithTx(tx),
		ActiveSchedule:        m.ActiveSchedule.WithTx(tx),
		ActiveScheduleSession: m.ActiveScheduleSession.WithTx(tx),
	}
}
// Transaction opens a DB transaction and hands a transaction-scoped RepoManager
// to the callback.
//
// Contract:
//  1. the callback must only use DAOs hanging off txM to stay inside the tx;
//  2. returning an error rolls the whole transaction back;
//  3. returning nil commits it.
func (m *RepoManager) Transaction(ctx context.Context, fn func(txM *RepoManager) error) error {
	return m.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		return fn(m.WithTx(tx))
	})
}

View File

@@ -0,0 +1,823 @@
package dao
import (
"context"
"encoding/json"
"errors"
"fmt"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/go-redis/redis/v8"
)
type CacheDAO struct {
client *redis.Client
}
func NewCacheDAO(client *redis.Client) *CacheDAO {
return &CacheDAO{client: client}
}
func (d *CacheDAO) schedulePreviewKey(userID int, conversationID string) string {
return fmt.Sprintf("smartflow:schedule_preview:u:%d:c:%s", userID, conversationID)
}
func (d *CacheDAO) conversationTimelineKey(userID int, conversationID string) string {
return fmt.Sprintf("smartflow:conversation_timeline:u:%d:c:%s", userID, conversationID)
}
func (d *CacheDAO) conversationTimelineSeqKey(userID int, conversationID string) string {
return fmt.Sprintf("smartflow:conversation_timeline_seq:u:%d:c:%s", userID, conversationID)
}
// AddTaskClassList caches a user's task-class list as JSON for 30 minutes.
func (d *CacheDAO) AddTaskClassList(ctx context.Context, userID int, list *model.UserGetTaskClassesResponse) error {
	// Key is namespaced per user so lists never leak across accounts.
	key := fmt.Sprintf("smartflow:task_classes:%d", userID)
	payload, err := json.Marshal(list)
	if err != nil {
		return err
	}
	// 30-minute TTL; adjust as business needs change.
	const ttl = 30 * time.Minute
	return d.client.Set(ctx, key, payload, ttl).Err()
}
// GetTaskClassList reads a user's cached task-class list.
//
// Note: a redis.Nil error (cache miss) is passed through unchanged so the
// service layer can decide whether to fall back to the database.
func (d *CacheDAO) GetTaskClassList(ctx context.Context, userID int) (*model.UserGetTaskClassesResponse, error) {
	key := fmt.Sprintf("smartflow:task_classes:%d", userID)
	var resp model.UserGetTaskClassesResponse
	// 1. Fetch the raw JSON string from Redis.
	val, err := d.client.Get(ctx, key).Result()
	if err != nil {
		// If this is redis.Nil, the caller handles the DB fallback path.
		return &resp, err
	}
	// 2. Decode the JSON back into the response struct.
	err = json.Unmarshal([]byte(val), &resp)
	return &resp, err
}
// DeleteTaskClassList invalidates a user's cached task-class list.
func (d *CacheDAO) DeleteTaskClassList(ctx context.Context, userID int) error {
	return d.client.Del(ctx, fmt.Sprintf("smartflow:task_classes:%d", userID)).Err()
}
// GetRecord reads a raw string value; a plain cache miss yields ("", nil).
func (d *CacheDAO) GetRecord(ctx context.Context, key string) (string, error) {
	val, err := d.client.Get(ctx, key).Result()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return "", nil // normal miss, not an error
		}
		return "", err // genuine Redis failure
	}
	return val, nil
}
func (d *CacheDAO) SaveRecord(ctx context.Context, key string, val string, ttl time.Duration) error {
return d.client.Set(ctx, key, val, ttl).Err()
}
func (d *CacheDAO) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
return d.client.SetNX(ctx, key, "processing", ttl).Result()
}
func (d *CacheDAO) ReleaseLock(ctx context.Context, key string) error {
return d.client.Del(ctx, key).Err()
}
// GetUserTasksFromCache 读取用户任务缓存(内部模型版本)。
//
// 职责边界:
// 1. 负责从 Redis 读取 `[]model.Task`,供 Service 层做“读时派生优先级”;
// 2. 不负责把模型转换成对外 DTO该职责在 conv 层);
// 3. 不负责缓存回填和缓存失效(回填由 Service 控制,失效由 GORM cache_deleter 统一处理)。
//
// 输入输出语义:
// 1. 命中缓存时返回任务模型切片与 nil error
// 2. 未命中时返回 redis.Nil由上层决定是否回源 DB
// 3. 反序列化失败时返回 error避免把损坏缓存继续向后传播。
func (d *CacheDAO) GetUserTasksFromCache(ctx context.Context, userID int) ([]model.Task, error) {
key := fmt.Sprintf("smartflow:tasks:%d", userID)
var tasks []model.Task
val, err := d.client.Get(ctx, key).Result()
if err != nil {
return nil, err // 注意:若是 redis.Nil则交给 Service 层处理回源查询逻辑
}
err = json.Unmarshal([]byte(val), &tasks)
return tasks, err
}
// SetUserTasksToCache 写入用户任务缓存(内部模型版本)。
//
// 职责边界:
// 1. 负责把 DB 读取到的原始 `[]model.Task` 写入缓存;
// 2. 不负责对任务做“紧急性平移派生”,避免把派生结果写回缓存导致后续无法继续触发异步平移;
// 3. 不负责缓存删除,删除策略由 cache_deleter 在写库后触发。
//
// 步骤说明:
// 1. 先把模型序列化为 JSON确保 `urgency_threshold_at` 等字段完整保留;
// 2. 再写入固定 TTL 缓存,命中后可减少 DB 读取压力;
// 3. 若序列化失败立即返回 error避免写入半结构化垃圾数据。
func (d *CacheDAO) SetUserTasksToCache(ctx context.Context, userID int, tasks []model.Task) error {
key := fmt.Sprintf("smartflow:tasks:%d", userID)
data, err := json.Marshal(tasks)
if err != nil {
return err
}
return d.client.Set(ctx, key, data, 24*time.Hour).Err()
}
func (d *CacheDAO) DeleteUserTasksFromCache(ctx context.Context, userID int) error {
key := fmt.Sprintf("smartflow:tasks:%d", userID)
return d.client.Del(ctx, key).Err()
}
func (d *CacheDAO) GetUserTodayScheduleFromCache(ctx context.Context, userID int) ([]model.UserTodaySchedule, error) {
key := fmt.Sprintf("smartflow:today_schedule:%d", userID)
var schedules []model.UserTodaySchedule
val, err := d.client.Get(ctx, key).Result()
if err != nil {
return nil, err // 注意:若是 redis.Nil则交给 Service 层处理回源查询逻辑
}
err = json.Unmarshal([]byte(val), &schedules)
return schedules, err
}
// SetUserTodayScheduleToCache caches today's schedule for a user with a TTL of
// "time remaining today", so the cache naturally refreshes once per day.
func (d *CacheDAO) SetUserTodayScheduleToCache(ctx context.Context, userID int, schedules []model.UserTodaySchedule) error {
	key := fmt.Sprintf("smartflow:today_schedule:%d", userID)
	data, err := json.Marshal(schedules)
	if err != nil {
		return err
	}
	// Capture "now" once so the midnight computation cannot straddle a day
	// boundary between repeated time.Now() calls.
	now := time.Now()
	nextMidnight := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
	ttl := time.Until(nextMidnight)
	if ttl <= 0 {
		// Exactly at (or past) the boundary: skip the write rather than passing a
		// non-positive expiration, which would store the key without a TTL.
		return nil
	}
	return d.client.Set(ctx, key, data, ttl).Err()
}
func (d *CacheDAO) DeleteUserTodayScheduleFromCache(ctx context.Context, userID int) error {
key := fmt.Sprintf("smartflow:today_schedule:%d", userID)
return d.client.Del(ctx, key).Err()
}
func (d *CacheDAO) GetUserWeeklyScheduleFromCache(ctx context.Context, userID int, week int) (*model.UserWeekSchedule, error) {
key := fmt.Sprintf("smartflow:weekly_schedule:%d:%d", userID, week)
var schedules model.UserWeekSchedule
val, err := d.client.Get(ctx, key).Result()
if err != nil {
return nil, err // 注意:若是 redis.Nil则交给 Service 层处理回源查询逻辑
}
err = json.Unmarshal([]byte(val), &schedules)
return &schedules, err
}
func (d *CacheDAO) SetUserWeeklyScheduleToCache(ctx context.Context, userID int, schedules *model.UserWeekSchedule) error {
key := fmt.Sprintf("smartflow:weekly_schedule:%d:%d", userID, schedules.Week)
data, err := json.Marshal(schedules)
if err != nil {
return err
}
// 设置过期时间为一天。
return d.client.Set(ctx, key, data, 24*time.Hour).Err()
}
func (d *CacheDAO) DeleteUserWeeklyScheduleFromCache(ctx context.Context, userID int, week int) error {
key := fmt.Sprintf("smartflow:weekly_schedule:%d:%d", userID, week)
return d.client.Del(ctx, key).Err()
}
func (d *CacheDAO) GetUserRecentCompletedSchedulesFromCache(ctx context.Context, userID, index, limit int) (*model.UserRecentCompletedScheduleResponse, error) {
key := fmt.Sprintf("smartflow:recent_completed_schedules:%d:%d:%d", userID, index, limit)
var resp model.UserRecentCompletedScheduleResponse
val, err := d.client.Get(ctx, key).Result()
if err != nil {
return &resp, err // 注意:若是 redis.Nil则交给 Service 层处理回源查询逻辑
}
err = json.Unmarshal([]byte(val), &resp)
return &resp, err
}
func (d *CacheDAO) SetUserRecentCompletedSchedulesToCache(ctx context.Context, userID, index, limit int, resp *model.UserRecentCompletedScheduleResponse) error {
key := fmt.Sprintf("smartflow:recent_completed_schedules:%d:%d:%d", userID, index, limit)
data, err := json.Marshal(resp)
if err != nil {
return err
}
// 设置过期时间为 30 分钟。
return d.client.Set(ctx, key, data, 30*time.Minute).Err()
}
// DeleteUserRecentCompletedSchedulesFromCache removes every pagination variant of
// a user's "recent completed schedules" cache.
//
// The cache key embeds index/limit, so a wildcard SCAN is needed to find all
// variants for the user.
func (d *CacheDAO) DeleteUserRecentCompletedSchedulesFromCache(ctx context.Context, userID int) error {
	pattern := fmt.Sprintf("smartflow:recent_completed_schedules:%d:*", userID)
	var cursor uint64
	for {
		keys, next, err := d.client.Scan(ctx, cursor, pattern, 500).Result()
		if err != nil {
			return err
		}
		if len(keys) > 0 {
			// UNLINK deletes asynchronously to reduce blocking risk; switch to Del
			// if strongly consistent deletion is ever required.
			if err := d.client.Unlink(ctx, keys...).Err(); err != nil {
				return err
			}
		}
		cursor = next
		if cursor == 0 {
			break
		}
	}
	return nil
}
// GetUserOngoingScheduleFromCache reads the user's ongoing-schedule cache.
//
// Semantics:
//  1. a cached literal "null" means "no ongoing schedule right now" and yields (nil, nil);
//  2. redis.Nil (plain miss) is forwarded so the service layer can fall back to the DB.
func (d *CacheDAO) GetUserOngoingScheduleFromCache(ctx context.Context, userID int) (*model.OngoingSchedule, error) {
	key := fmt.Sprintf("smartflow:ongoing_schedule:%d", userID)
	var schedule model.OngoingSchedule
	val, err := d.client.Get(ctx, key).Result()
	if err != nil {
		return &schedule, err // if redis.Nil, the caller handles the DB fallback
	}
	if val == "null" {
		return nil, nil // negative-cache sentinel: no ongoing schedule was found earlier
	}
	err = json.Unmarshal([]byte(val), &schedule)
	return &schedule, err
}
// SetUserOngoingScheduleToCache caches the user's ongoing schedule.
//
// A nil schedule is negative-cached as the literal "null" for 5 minutes to
// avoid hammering the DB; otherwise the TTL is the time remaining until the
// schedule's end, and an already-ended schedule is not cached at all.
func (d *CacheDAO) SetUserOngoingScheduleToCache(ctx context.Context, userID int, schedule *model.OngoingSchedule) error {
	key := fmt.Sprintf("smartflow:ongoing_schedule:%d", userID)
	if schedule == nil {
		return d.client.Set(ctx, key, "null", 5*time.Minute).Err()
	}
	data, err := json.Marshal(schedule)
	if err != nil {
		return err
	}
	remaining := time.Until(schedule.EndTime)
	if remaining <= 0 {
		// Already past its end time: nothing worth caching.
		return nil
	}
	return d.client.Set(ctx, key, data, remaining).Err()
}
func (d *CacheDAO) DeleteUserOngoingScheduleFromCache(ctx context.Context, userID int) error {
key := fmt.Sprintf("smartflow:ongoing_schedule:%d", userID)
return d.client.Del(ctx, key).Err()
}
// SetSchedulePlanPreviewToCache 写入“排程预览”缓存。
//
// 职责边界:
// 1. 负责按 user_id + conversation_id 写入结构化预览快照;
// 2. 负责 preview 入库前的基础参数校验,避免无效 key
// 3. 不负责 DB 回源,不负责业务重试策略。
//
// 步骤化说明:
// 1. 先校验 user_id / conversation_id / preview防止脏写
// 2. 再序列化 preview 为 JSON保证缓存结构稳定
// 3. 最后按固定 TTL 写入 Redis超时后自动失效。
func (d *CacheDAO) SetSchedulePlanPreviewToCache(ctx context.Context, userID int, conversationID string, preview *model.SchedulePlanPreviewCache) error {
if d == nil || d.client == nil {
return errors.New("cache dao is not initialized")
}
if userID <= 0 {
return fmt.Errorf("invalid user_id: %d", userID)
}
normalizedConversationID := strings.TrimSpace(conversationID)
if normalizedConversationID == "" {
return errors.New("conversation_id is empty")
}
if preview == nil {
return errors.New("schedule preview is nil")
}
data, err := json.Marshal(preview)
if err != nil {
return fmt.Errorf("marshal schedule preview failed: %w", err)
}
return d.client.Set(ctx, d.schedulePreviewKey(userID, normalizedConversationID), data, 1*time.Hour).Err()
}
// GetSchedulePlanPreviewFromCache reads the "schedule plan preview" cache.
//
// Semantics:
//  1. hit: returns (*SchedulePlanPreviewCache, nil);
//  2. miss: returns (nil, nil);
//  3. Redis failure or unmarshal failure: returns an error.
func (d *CacheDAO) GetSchedulePlanPreviewFromCache(ctx context.Context, userID int, conversationID string) (*model.SchedulePlanPreviewCache, error) {
	if d == nil || d.client == nil {
		return nil, errors.New("cache dao is not initialized")
	}
	if userID <= 0 {
		return nil, fmt.Errorf("invalid user_id: %d", userID)
	}
	normalizedConversationID := strings.TrimSpace(conversationID)
	if normalizedConversationID == "" {
		return nil, errors.New("conversation_id is empty")
	}
	raw, err := d.client.Get(ctx, d.schedulePreviewKey(userID, normalizedConversationID)).Result()
	// errors.Is matches wrapped redis.Nil too, consistent with the other cache
	// readers in this file (previously compared with ==).
	if errors.Is(err, redis.Nil) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	var preview model.SchedulePlanPreviewCache
	if err = json.Unmarshal([]byte(raw), &preview); err != nil {
		return nil, fmt.Errorf("unmarshal schedule preview failed: %w", err)
	}
	return &preview, nil
}
// DeleteSchedulePlanPreviewFromCache 删除“排程预览”缓存。
//
// 说明:
// 1. 删除操作是幂等的key 不存在也视为成功;
// 2. 该方法用于新排程前清旧预览,或状态快照更新后触发失效。
func (d *CacheDAO) DeleteSchedulePlanPreviewFromCache(ctx context.Context, userID int, conversationID string) error {
if d == nil || d.client == nil {
return errors.New("cache dao is not initialized")
}
if userID <= 0 {
return fmt.Errorf("invalid user_id: %d", userID)
}
normalizedConversationID := strings.TrimSpace(conversationID)
if normalizedConversationID == "" {
return errors.New("conversation_id is empty")
}
return d.client.Del(ctx, d.schedulePreviewKey(userID, normalizedConversationID)).Err()
}
// IncrConversationTimelineSeq atomically increments and returns the timeline seq.
//
// Notes:
//  1. seq only increases within a single user_id + conversation_id scope;
//  2. Redis INCR guarantees no duplicate sequence numbers under concurrency;
//  3. the key carries a 24h TTL (refreshed on each increment) so long-tail
//     conversations do not occupy the cache forever.
func (d *CacheDAO) IncrConversationTimelineSeq(ctx context.Context, userID int, conversationID string) (int64, error) {
	if d == nil || d.client == nil {
		return 0, errors.New("cache dao is not initialized")
	}
	if userID <= 0 {
		return 0, fmt.Errorf("invalid user_id: %d", userID)
	}
	normalizedConversationID := strings.TrimSpace(conversationID)
	if normalizedConversationID == "" {
		return 0, errors.New("conversation_id is empty")
	}
	key := d.conversationTimelineSeqKey(userID, normalizedConversationID)
	// INCR and EXPIRE travel in one pipeline round-trip.
	pipe := d.client.Pipeline()
	incrCmd := pipe.Incr(ctx, key)
	pipe.Expire(ctx, key, 24*time.Hour)
	if _, err := pipe.Exec(ctx); err != nil {
		return 0, err
	}
	return incrCmd.Val(), nil
}
// SetConversationTimelineSeq 强制设置会话时间线当前 seqDB 回填 Redis 兜底场景)。
func (d *CacheDAO) SetConversationTimelineSeq(ctx context.Context, userID int, conversationID string, seq int64) error {
if d == nil || d.client == nil {
return errors.New("cache dao is not initialized")
}
if userID <= 0 {
return fmt.Errorf("invalid user_id: %d", userID)
}
normalizedConversationID := strings.TrimSpace(conversationID)
if normalizedConversationID == "" {
return errors.New("conversation_id is empty")
}
if seq < 0 {
seq = 0
}
return d.client.Set(ctx, d.conversationTimelineSeqKey(userID, normalizedConversationID), seq, 24*time.Hour).Err()
}
// AppendConversationTimelineEventToCache 追加单条时间线缓存事件。
func (d *CacheDAO) AppendConversationTimelineEventToCache(
ctx context.Context,
userID int,
conversationID string,
item model.GetConversationTimelineItem,
) error {
if d == nil || d.client == nil {
return errors.New("cache dao is not initialized")
}
if userID <= 0 {
return fmt.Errorf("invalid user_id: %d", userID)
}
normalizedConversationID := strings.TrimSpace(conversationID)
if normalizedConversationID == "" {
return errors.New("conversation_id is empty")
}
data, err := json.Marshal(item)
if err != nil {
return fmt.Errorf("marshal conversation timeline item failed: %w", err)
}
key := d.conversationTimelineKey(userID, normalizedConversationID)
pipe := d.client.Pipeline()
pipe.RPush(ctx, key, data)
pipe.Expire(ctx, key, 24*time.Hour)
_, err = pipe.Exec(ctx)
return err
}
// SetConversationTimelineToCache 全量回填时间线缓存。
func (d *CacheDAO) SetConversationTimelineToCache(ctx context.Context, userID int, conversationID string, items []model.GetConversationTimelineItem) error {
if d == nil || d.client == nil {
return errors.New("cache dao is not initialized")
}
if userID <= 0 {
return fmt.Errorf("invalid user_id: %d", userID)
}
normalizedConversationID := strings.TrimSpace(conversationID)
if normalizedConversationID == "" {
return errors.New("conversation_id is empty")
}
key := d.conversationTimelineKey(userID, normalizedConversationID)
pipe := d.client.Pipeline()
pipe.Del(ctx, key)
if len(items) > 0 {
values := make([]interface{}, 0, len(items))
for _, item := range items {
data, err := json.Marshal(item)
if err != nil {
return fmt.Errorf("marshal conversation timeline item failed: %w", err)
}
values = append(values, data)
}
pipe.RPush(ctx, key, values...)
}
pipe.Expire(ctx, key, 24*time.Hour)
_, err := pipe.Exec(ctx)
return err
}
// GetConversationTimelineFromCache reads the cached timeline (stored in seq order).
//
// Semantics:
//  1. a missing or empty list yields (nil, nil) — LRange on a non-existent key
//     returns an empty slice, so callers should treat nil as "fall back to DB";
//  2. Redis failures and corrupted entries return an error.
func (d *CacheDAO) GetConversationTimelineFromCache(ctx context.Context, userID int, conversationID string) ([]model.GetConversationTimelineItem, error) {
	if d == nil || d.client == nil {
		return nil, errors.New("cache dao is not initialized")
	}
	if userID <= 0 {
		return nil, fmt.Errorf("invalid user_id: %d", userID)
	}
	normalizedConversationID := strings.TrimSpace(conversationID)
	if normalizedConversationID == "" {
		return nil, errors.New("conversation_id is empty")
	}
	rawItems, err := d.client.LRange(ctx, d.conversationTimelineKey(userID, normalizedConversationID), 0, -1).Result()
	// errors.Is matches wrapped redis.Nil too, consistent with the other cache
	// readers in this file (previously compared with ==).
	if errors.Is(err, redis.Nil) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	if len(rawItems) == 0 {
		return nil, nil
	}
	items := make([]model.GetConversationTimelineItem, 0, len(rawItems))
	for _, raw := range rawItems {
		var item model.GetConversationTimelineItem
		if err := json.Unmarshal([]byte(raw), &item); err != nil {
			return nil, fmt.Errorf("unmarshal conversation timeline item failed: %w", err)
		}
		items = append(items, item)
	}
	return items, nil
}
// DeleteConversationTimelineFromCache 删除时间线缓存和 seq 缓存。
func (d *CacheDAO) DeleteConversationTimelineFromCache(ctx context.Context, userID int, conversationID string) error {
if d == nil || d.client == nil {
return errors.New("cache dao is not initialized")
}
if userID <= 0 {
return fmt.Errorf("invalid user_id: %d", userID)
}
normalizedConversationID := strings.TrimSpace(conversationID)
if normalizedConversationID == "" {
return errors.New("conversation_id is empty")
}
return d.client.Del(
ctx,
d.conversationTimelineKey(userID, normalizedConversationID),
d.conversationTimelineSeqKey(userID, normalizedConversationID),
).Err()
}
// agentStateKey 返回 agent 运行态快照的 Redis key。
//
// Key 设计:
// 1. 使用 smartflow:agent_state 前缀,与现有 key 命名空间隔离;
// 2. 使用 conversationID 作为唯一标识,因为 agent 状态是按会话维度持久化的。
const activeScheduleSessionCacheTTL = 2 * time.Hour
// activeScheduleSessionKey 生成 session_id 维度的主动调度会话缓存 key。
func (d *CacheDAO) activeScheduleSessionKey(sessionID string) string {
return fmt.Sprintf("smartflow:active_schedule_session:s:%s", strings.TrimSpace(sessionID))
}
// activeScheduleSessionConversationKey 生成 user_id + conversation_id 维度的主动调度会话缓存 key。
func (d *CacheDAO) activeScheduleSessionConversationKey(userID int, conversationID string) string {
return fmt.Sprintf("smartflow:active_schedule_session:u:%d:c:%s", userID, strings.TrimSpace(conversationID))
}
// SetActiveScheduleSessionToCache 同步写入主动调度会话缓存。
//
// 步骤化说明:
// 1. 先校验 snapshot 和主键,避免把无效会话写进 Redis
// 2. 再把同一份快照写入 session_id / conversation_id 两个维度的 key
// 3. 若 conversation_id 还没绑定,只写 session_id key避免生成空路由 key。
func (d *CacheDAO) SetActiveScheduleSessionToCache(ctx context.Context, snapshot *model.ActiveScheduleSessionSnapshot) error {
if d == nil || d.client == nil {
return errors.New("cache dao is not initialized")
}
if snapshot == nil {
return errors.New("active schedule session snapshot is nil")
}
sessionID := strings.TrimSpace(snapshot.SessionID)
if sessionID == "" {
return errors.New("session_id is empty")
}
data, err := json.Marshal(snapshot)
if err != nil {
return fmt.Errorf("marshal active schedule session cache failed: %w", err)
}
pipe := d.client.Pipeline()
pipe.Set(ctx, d.activeScheduleSessionKey(sessionID), data, activeScheduleSessionCacheTTL)
if conversationID := strings.TrimSpace(snapshot.ConversationID); conversationID != "" && snapshot.UserID > 0 {
pipe.Set(ctx, d.activeScheduleSessionConversationKey(snapshot.UserID, conversationID), data, activeScheduleSessionCacheTTL)
}
_, err = pipe.Exec(ctx)
return err
}
// GetActiveScheduleSessionFromCache 按 session_id 读取主动调度会话缓存。
func (d *CacheDAO) GetActiveScheduleSessionFromCache(ctx context.Context, sessionID string) (*model.ActiveScheduleSessionSnapshot, error) {
if d == nil || d.client == nil {
return nil, errors.New("cache dao is not initialized")
}
normalizedSessionID := strings.TrimSpace(sessionID)
if normalizedSessionID == "" {
return nil, errors.New("session_id is empty")
}
raw, err := d.client.Get(ctx, d.activeScheduleSessionKey(normalizedSessionID)).Result()
if errors.Is(err, redis.Nil) {
return nil, nil
}
if err != nil {
return nil, err
}
var snapshot model.ActiveScheduleSessionSnapshot
if err = json.Unmarshal([]byte(raw), &snapshot); err != nil {
return nil, fmt.Errorf("unmarshal active schedule session cache failed: %w", err)
}
return &snapshot, nil
}
// GetActiveScheduleSessionFromConversationCache 按 user_id + conversation_id 读取主动调度会话缓存。
func (d *CacheDAO) GetActiveScheduleSessionFromConversationCache(ctx context.Context, userID int, conversationID string) (*model.ActiveScheduleSessionSnapshot, error) {
if d == nil || d.client == nil {
return nil, errors.New("cache dao is not initialized")
}
if userID <= 0 {
return nil, fmt.Errorf("invalid user_id: %d", userID)
}
normalizedConversationID := strings.TrimSpace(conversationID)
if normalizedConversationID == "" {
return nil, errors.New("conversation_id is empty")
}
raw, err := d.client.Get(ctx, d.activeScheduleSessionConversationKey(userID, normalizedConversationID)).Result()
if errors.Is(err, redis.Nil) {
return nil, nil
}
if err != nil {
return nil, err
}
var snapshot model.ActiveScheduleSessionSnapshot
if err = json.Unmarshal([]byte(raw), &snapshot); err != nil {
return nil, fmt.Errorf("unmarshal active schedule session cache failed: %w", err)
}
return &snapshot, nil
}
// DeleteActiveScheduleSessionFromCache 删除主动调度会话缓存。
//
// 说明:
// 1. 会同时清理 session_id 和 conversation_id 两个维度,避免旧路由缓存残留;
// 2. conversation_id 为空时只清 session_id key
// 3. 删除操作本身幂等,即使 key 不存在也视为成功。
func (d *CacheDAO) DeleteActiveScheduleSessionFromCache(ctx context.Context, sessionID string, userID int, conversationID string) error {
if d == nil || d.client == nil {
return errors.New("cache dao is not initialized")
}
normalizedSessionID := strings.TrimSpace(sessionID)
if normalizedSessionID == "" {
return errors.New("session_id is empty")
}
keys := []string{d.activeScheduleSessionKey(normalizedSessionID)}
if userID > 0 {
if normalizedConversationID := strings.TrimSpace(conversationID); normalizedConversationID != "" {
keys = append(keys, d.activeScheduleSessionConversationKey(userID, normalizedConversationID))
}
}
return d.client.Del(ctx, keys...).Err()
}
func (d *CacheDAO) agentStateKey(conversationID string) string {
return fmt.Sprintf("smartflow:agent_state:%s", conversationID)
}
// SaveAgentState 序列化并保存 agent 运行态快照到 Redis。
//
// 职责边界:
// 1. 只负责 JSON 序列化 + Redis SET不做业务校验
// 2. TTL 默认 2h过期自动清理配合 MySQL outbox 异步持久化;
// 3. snapshot 为 nil 时直接返回,避免写入无效数据。
func (d *CacheDAO) SaveAgentState(ctx context.Context, conversationID string, snapshot any) error {
if d == nil || d.client == nil {
return errors.New("cache dao is not initialized")
}
normalizedID := strings.TrimSpace(conversationID)
if normalizedID == "" {
return errors.New("conversation_id is empty")
}
if snapshot == nil {
return nil
}
data, err := json.Marshal(snapshot)
if err != nil {
return fmt.Errorf("marshal agent state failed: %w", err)
}
return d.client.Set(ctx, d.agentStateKey(normalizedID), data, 2*time.Hour).Err()
}
// LoadAgentState 从 Redis 读取并反序列化 agent 运行态快照。
//
// 返回值语义:
// 1. (result, true, nil):命中快照,正常返回;
// 2. (nil, false, nil):未命中,不是错误,调用方应走新建对话路径;
// 3. (nil, false, error)Redis 或反序列化错误。
func (d *CacheDAO) LoadAgentState(ctx context.Context, conversationID string, result any) (bool, error) {
if d == nil || d.client == nil {
return false, errors.New("cache dao is not initialized")
}
normalizedID := strings.TrimSpace(conversationID)
if normalizedID == "" {
return false, errors.New("conversation_id is empty")
}
raw, err := d.client.Get(ctx, d.agentStateKey(normalizedID)).Result()
if errors.Is(err, redis.Nil) {
return false, nil
}
if err != nil {
return false, err
}
if err := json.Unmarshal([]byte(raw), result); err != nil {
return false, fmt.Errorf("unmarshal agent state failed: %w", err)
}
return true, nil
}
// DeleteAgentState 删除指定会话的 agent 运行态快照。
//
// 语义:
// 1. 删除操作是幂等的key 不存在也视为成功;
// 2. 典型调用时机Deliver 节点任务完成后清理。
func (d *CacheDAO) DeleteAgentState(ctx context.Context, conversationID string) error {
if d == nil || d.client == nil {
return errors.New("cache dao is not initialized")
}
normalizedID := strings.TrimSpace(conversationID)
if normalizedID == "" {
return errors.New("conversation_id is empty")
}
return d.client.Del(ctx, d.agentStateKey(normalizedID)).Err()
}
// --- 记忆预取缓存 ---
const (
memoryPrefetchTTL = 30 * time.Minute
)
// memoryPrefetchKey 生成用户+会话维度的记忆预取缓存 key。
//
// 1. 格式smartflow:memory_prefetch:u:{userID}:c:{chatID},与 conversationTimelineKey / schedulePreviewKey 命名风格一致;
// 2. chatID 为空时 key 为 smartflow:memory_prefetch:u:5:c:,仍然合法且唯一,不会与其他会话 key 冲突;
// 3. 加 chatID 隔离后,不同会话各自维护独立的预取缓存,避免会话间记忆上下文互相覆盖。
func (d *CacheDAO) memoryPrefetchKey(userID int, chatID string) string {
return fmt.Sprintf("smartflow:memory_prefetch:u:%d:c:%s", userID, chatID)
}
// GetMemoryPrefetchCache 读取用户记忆预取缓存。
//
// 输入输出语义:
// 1. 命中时返回 ItemDTO 切片与 nil error
// 2. 未命中时返回 nil, nil
// 3. Redis 异常或反序列化失败时返回 error。
func (d *CacheDAO) GetMemoryPrefetchCache(ctx context.Context, userID int, chatID string) ([]memorymodel.ItemDTO, error) {
if d == nil || d.client == nil {
return nil, errors.New("cache dao is not initialized")
}
if userID <= 0 {
return nil, nil
}
key := d.memoryPrefetchKey(userID, chatID)
raw, err := d.client.Get(ctx, key).Result()
if errors.Is(err, redis.Nil) {
return nil, nil
}
if err != nil {
return nil, err
}
var items []memorymodel.ItemDTO
if err = json.Unmarshal([]byte(raw), &items); err != nil {
return nil, fmt.Errorf("unmarshal memory prefetch cache failed: %w", err)
}
return items, nil
}
// SetMemoryPrefetchCache 写入用户记忆预取缓存。
//
// 职责边界:
// 1. 负责将检索后的记忆 DTO 写入 Redis供下一轮 Chat 节点即时消费;
// 2. TTL 30 分钟,靠自然过期淘汰,不需要显式 Invalidate
// 3. items 为空或 nil 时直接返回,避免写入无效数据。
func (d *CacheDAO) SetMemoryPrefetchCache(ctx context.Context, userID int, chatID string, items []memorymodel.ItemDTO) error {
if d == nil || d.client == nil {
return errors.New("cache dao is not initialized")
}
if userID <= 0 || len(items) == 0 {
return nil
}
data, err := json.Marshal(items)
if err != nil {
return fmt.Errorf("marshal memory prefetch cache failed: %w", err)
}
key := d.memoryPrefetchKey(userID, chatID)
return d.client.Set(ctx, key, data, memoryPrefetchTTL).Err()
}
// DeleteMemoryPrefetchCacheByUser 删除指定用户所有会话的记忆预取缓存。
//
// 步骤化说明:
// 1. 用 SCAN 遍历 smartflow:memory_prefetch:u:{userID}:c:* 匹配的所有 key
// 2. 用 UNLINK 异步删除,避免阻塞 Redis 主线程;
// 3. 复用 DeleteUserRecentCompletedSchedulesFromCache 的 SCAN+UNLINK 模式;
// 4. 该方法被 GORM cache deleter 和空检索清理两条链路共同调用,保证缓存一致性。
func (d *CacheDAO) DeleteMemoryPrefetchCacheByUser(ctx context.Context, userID int) error {
if d == nil || d.client == nil {
return errors.New("cache dao is not initialized")
}
if userID <= 0 {
return nil
}
pattern := fmt.Sprintf("smartflow:memory_prefetch:u:%d:c:*", userID)
var cursor uint64
for {
keys, next, err := d.client.Scan(ctx, cursor, pattern, 500).Result()
if err != nil {
return err
}
if len(keys) > 0 {
// 1. UNLINK 是 DEL 的异步版本,不会阻塞 Redis 主线程;
// 2. 即使 key 不存在也不会报错,幂等安全。
if err := d.client.Unlink(ctx, keys...).Err(); err != nil {
return err
}
}
cursor = next
if cursor == 0 {
break
}
}
return nil
}

View File

@@ -0,0 +1,50 @@
package dao
import (
"context"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"gorm.io/gorm"
)
type CourseDAO struct {
db *gorm.DB
}
// NewCourseDAO creates a CourseDAO instance.
func NewCourseDAO(db *gorm.DB) *CourseDAO {
	return &CourseDAO{
		db: db,
	}
}
// WithTx returns a CourseDAO bound to the given transaction handle.
func (r *CourseDAO) WithTx(tx *gorm.DB) *CourseDAO {
	bound := &CourseDAO{db: tx}
	return bound
}
// AddUserCoursesIntoSchedule batch-inserts course-derived schedule rows.
func (r *CourseDAO) AddUserCoursesIntoSchedule(ctx context.Context, courses []model.Schedule) error {
	return r.db.WithContext(ctx).Create(&courses).Error
}
// AddUserCoursesIntoScheduleEvents batch-inserts schedule events and returns
// the generated IDs in insertion order.
func (r *CourseDAO) AddUserCoursesIntoScheduleEvents(ctx context.Context, events []model.ScheduleEvent) ([]int, error) {
	if err := r.db.WithContext(ctx).Create(&events).Error; err != nil {
		return nil, err
	}
	ids := make([]int, len(events))
	for i := range events {
		ids[i] = events[i].ID
	}
	return ids, nil
}
// Transaction runs fn inside one database transaction (auto commit/rollback) so
// the service layer can reuse it.
// Rules: fn returns nil -> commit; fn returns an error or panics -> rollback.
// Note: gorm.(*DB).Transaction rolls back when fn returns an error, and on panic
// it rolls back and then re-raises the panic.
func (r *CourseDAO) Transaction(fn func(txDAO *CourseDAO) error) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		return fn(NewCourseDAO(tx))
	})
}

View File

@@ -0,0 +1,671 @@
package dao
import (
"context"
"errors"
"fmt"
"time"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/LoveLosita/smartflow/backend/shared/respond"
"gorm.io/gorm"
)
type ScheduleDAO struct {
db *gorm.DB
}
// NewScheduleDAO creates a ScheduleDAO instance.
func NewScheduleDAO(db *gorm.DB) *ScheduleDAO {
	return &ScheduleDAO{
		db: db,
	}
}
// WithTx returns a ScheduleDAO bound to the given transaction handle.
func (d *ScheduleDAO) WithTx(tx *gorm.DB) *ScheduleDAO {
	bound := &ScheduleDAO{db: tx}
	return bound
}
// AddSchedules inserts the schedules in one batch and returns their generated IDs.
func (d *ScheduleDAO) AddSchedules(schedules []model.Schedule) ([]int, error) {
	if err := d.db.Create(&schedules).Error; err != nil {
		return nil, err
	}
	ids := make([]int, 0, len(schedules))
	for i := range schedules {
		ids = append(ids, schedules[i].ID)
	}
	return ids, nil
}
// EmbedTaskIntoSchedule writes taskID into embedded_task_id for exactly the rows
// matching the given user/week/day and section range.
//
// NOTE(review): unlike its siblings this method takes no context.Context —
// consider threading ctx through when the call sites allow it.
func (d *ScheduleDAO) EmbedTaskIntoSchedule(startSection, endSection, dayOfWeek, week, userID, taskID int) error {
	// Only update rows for this user/week/day whose section lies in
	// [startSection, endSection]; precise write of embedded_task_id.
	res := d.db.
		Table("schedules").
		Where("user_id = ? AND week = ? AND day_of_week = ? AND section BETWEEN ? AND ?", userID, week, dayOfWeek, startSection, endSection).
		Update("embedded_task_id", taskID)
	return res.Error
}
// GetCourseUserIDByID resolves the owning user id of a schedule event.
// A missing row and a NULL user_id both map to respond.WrongCourseID.
func (d *ScheduleDAO) GetCourseUserIDByID(ctx context.Context, courseScheduleEventID int) (int, error) {
	var row struct {
		UserID *int `gorm:"column:user_id"`
	}
	err := d.db.WithContext(ctx).
		Table("schedule_events").
		Select("user_id").
		Where("id = ?", courseScheduleEventID).
		First(&row).Error
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return 0, respond.WrongCourseID
	case err != nil:
		return 0, err
	case row.UserID == nil:
		return 0, respond.WrongCourseID
	}
	return *row.UserID, nil
}
// IsCourseEmbeddedByOtherTaskBlock reports whether the course already has a
// task embedded in any slot inside [startSection, endSection] (used to block
// double embedding). An invalid range is treated as "no conflict".
func (d *ScheduleDAO) IsCourseEmbeddedByOtherTaskBlock(ctx context.Context, courseID, startSection, endSection int) (bool, error) {
	if startSection <= 0 || endSection <= 0 || startSection > endSection {
		return false, nil
	}
	var occupied int64
	if err := d.db.WithContext(ctx).
		Table("schedules").
		Where("id = ?", courseID).
		Where("section BETWEEN ? AND ?", startSection, endSection).
		Where("embedded_task_id IS NOT NULL AND embedded_task_id <> 0").
		Count(&occupied).Error; err != nil {
		return false, err
	}
	return occupied > 0, nil
}
// HasUserScheduleConflict reports whether the user already has any schedule
// row on the given week/day that overlaps the provided sections.
// An empty section list is treated as conflict-free.
// Columns used: user_id, week, day_of_week, section.
func (d *ScheduleDAO) HasUserScheduleConflict(ctx context.Context, userID, week, dayOfWeek int, sections []int) (bool, error) {
	if len(sections) == 0 {
		return false, nil
	}
	// Count rows for the same user/week/day whose section intersects the set.
	var overlapping int64
	if err := d.db.WithContext(ctx).
		Table("schedules").
		Where("user_id = ? AND week = ? AND day_of_week = ?", userID, week, dayOfWeek).
		Where("section IN ?", sections).
		Count(&overlapping).Error; err != nil {
		return false, err
	}
	return overlapping > 0, nil
}
// IsCourseTimeMatch verifies the course event occupies every section in
// [startSection, endSection] on the given week/day.
// schedules rows reference schedule_events via event_id; the match succeeds
// only when each section in the range has its own row (count == range size).
// An invalid range never matches.
func (d *ScheduleDAO) IsCourseTimeMatch(ctx context.Context, courseScheduleEventID, week, dayOfWeek, startSection, endSection int) (bool, error) {
	if startSection <= 0 || endSection <= 0 || startSection > endSection {
		return false, nil
	}
	var slotCount int64
	if err := d.db.WithContext(ctx).
		Table("schedules").
		Where("event_id = ?", courseScheduleEventID).
		Where("week = ? AND day_of_week = ?", week, dayOfWeek).
		Where("section BETWEEN ? AND ?", startSection, endSection).
		Count(&slotCount).Error; err != nil {
		return false, err
	}
	wanted := int64(endSection - startSection + 1)
	return slotCount == wanted, nil
}
// AddScheduleEvent inserts a single schedule event and returns its new ID.
func (d *ScheduleDAO) AddScheduleEvent(scheduleEvent *model.ScheduleEvent) (int, error) {
	// Pass the *model.ScheduleEvent directly: the original took the address of
	// the pointer (yielding **model.ScheduleEvent), relying on GORM to unwrap
	// the extra indirection.
	if err := d.db.Create(scheduleEvent).Error; err != nil {
		return 0, err
	}
	return scheduleEvent.ID, nil
}
// CheckScheduleConflict reports whether any of the given schedule slots
// collides with an existing *course* slot (same user, same week, same day,
// overlapping section). Non-course event types are deliberately ignored.
func (d *ScheduleDAO) CheckScheduleConflict(ctx context.Context, schedules []model.Schedule) (bool, error) {
	if len(schedules) == 0 {
		return false, nil
	}
	// Aggregate: dedupe sections per (user, week, day) so each group costs a
	// single query.
	type key struct {
		UserID    int
		Week      int
		DayOfWeek int
	}
	groups := make(map[key]map[int]struct{})
	for _, s := range schedules {
		// Rows with invalid base fields are skipped (treated as non-conflicting).
		if s.UserID <= 0 || s.Week <= 0 || s.DayOfWeek <= 0 || s.Section <= 0 {
			continue
		}
		k := key{UserID: s.UserID, Week: s.Week, DayOfWeek: s.DayOfWeek}
		if _, ok := groups[k]; !ok {
			groups[k] = make(map[int]struct{})
		}
		groups[k][s.Section] = struct{}{}
	}
	for k, set := range groups {
		if len(set) == 0 {
			continue
		}
		sections := make([]int, 0, len(set))
		for sec := range set {
			sections = append(sections, sec)
		}
		// Only rows whose event type is "course" count as conflicts:
		// schedules.event_id -> schedule_events.id, filtered on schedule_events.type.
		var cnt int64
		err := d.db.WithContext(ctx).
			Table("schedules s").
			Joins("JOIN schedule_events e ON e.id = s.event_id").
			Where("s.user_id = ? AND s.week = ? AND s.day_of_week = ?", k.UserID, k.Week, k.DayOfWeek).
			Where("s.section IN ?", sections).
			Where("e.type = ?", "course").
			Count(&cnt).Error
		if err != nil {
			return false, err
		}
		if cnt > 0 {
			return true, nil
		}
	}
	return false, nil
}
// GetNonCourseScheduleConflicts returns every existing non-course schedule
// row whose (week, day, section) coordinate collides with any of the new
// slots, expanded to the full set of atomic slots of each colliding event.
//
// NOTE(review): the user ID is taken from newSchedules[0]; the function
// assumes all input slots belong to one user — confirm at call sites.
func (d *ScheduleDAO) GetNonCourseScheduleConflicts(ctx context.Context, newSchedules []model.Schedule) ([]model.Schedule, error) {
	if len(newSchedules) == 0 {
		return nil, nil
	}
	// 1. Build a fingerprint set of the new coordinates for O(1) collision checks.
	userID := newSchedules[0].UserID
	weeksMap := make(map[int]bool)
	newSlotsFingerprints := make(map[string]bool)
	for _, s := range newSchedules {
		weeksMap[s.Week] = true
		key := fmt.Sprintf("%d-%d-%d", s.Week, s.DayOfWeek, s.Section)
		newSlotsFingerprints[key] = true
	}
	weeks := make([]int, 0, len(weeksMap))
	for w := range weeksMap {
		weeks = append(weeks, w)
	}
	// 2. First pass: fetch only coordinates plus EventID via a narrow struct.
	type simpleSlot struct {
		EventID   int
		Week      int
		DayOfWeek int
		Section   int
	}
	var candidates []simpleSlot
	// Selecting only these columns keeps the scan cheap (index-friendly).
	err := d.db.WithContext(ctx).
		Table("schedules").
		Select("schedules.event_id, schedules.week, schedules.day_of_week, schedules.section").
		Joins("JOIN schedule_events ON schedule_events.id = schedules.event_id").
		Where("schedules.user_id = ? AND schedules.week IN ? AND schedule_events.type != ?", userID, weeks, "course").
		Scan(&candidates).Error
	if err != nil {
		return nil, err
	}
	// 3. Keep only the event IDs that actually collide with a new slot.
	eventIDMap := make(map[int]bool)
	for _, s := range candidates {
		key := fmt.Sprintf("%d-%d-%d", s.Week, s.DayOfWeek, s.Section)
		if newSlotsFingerprints[key] {
			eventIDMap[s.EventID] = true
		}
	}
	if len(eventIDMap) == 0 {
		return nil, nil
	}
	// 4. Second pass: load every atomic slot of each colliding event.
	var ids []int
	for id := range eventIDMap {
		ids = append(ids, id)
	}
	var fullConflicts []model.Schedule
	// Preload("Event") is required so the DTO layer can show event names.
	err = d.db.WithContext(ctx).
		Preload("Event").
		Where("event_id IN ?", ids).
		Find(&fullConflicts).Error
	return fullConflicts, err
}
// GetUserTodaySchedule loads all of a user's slots for one week/day, ordered
// by section, with Event (base info) and EmbeddedTask (embedded task detail)
// preloaded.
func (d *ScheduleDAO) GetUserTodaySchedule(ctx context.Context, userID, week, dayOfWeek int) ([]model.Schedule, error) {
	var daySlots []model.Schedule
	query := d.db.WithContext(ctx).
		Preload("Event").
		Preload("EmbeddedTask").
		Where("user_id = ? AND week = ? AND day_of_week = ?", userID, week, dayOfWeek).
		Order("section ASC")
	if err := query.Find(&daySlots).Error; err != nil {
		return nil, err
	}
	return daySlots, nil
}
// GetUserWeeklySchedule loads all of a user's slots for one week, ordered by
// day then section, with Event and EmbeddedTask preloaded.
func (d *ScheduleDAO) GetUserWeeklySchedule(ctx context.Context, userID, week int) ([]model.Schedule, error) {
	var weekSlots []model.Schedule
	query := d.db.WithContext(ctx).
		Preload("Event").
		Preload("EmbeddedTask").
		Where("user_id = ? AND week = ?", userID, week).
		Order("day_of_week ASC, section ASC")
	if err := query.Find(&weekSlots).Error; err != nil {
		return nil, err
	}
	return weekSlots, nil
}
// DeleteScheduleEventAndSchedule deletes a schedule event and its child slot
// rows inside one transaction, scoped to the owning user.
// Child rows are loaded and deleted explicitly (rather than via a bulk
// conditional delete) so GORM delete hooks/plugins can read UserID/Week.
// Returns respond.WrongScheduleEventID when the event row does not exist or
// is not owned by userID.
func (d *ScheduleDAO) DeleteScheduleEventAndSchedule(ctx context.Context, eventID int, userID int) error {
	return d.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Load the slots first so Delete carries model fields for the hooks.
		var schedules []model.Schedule
		if err := tx.
			Where("event_id = ? AND user_id = ?", eventID, userID).
			Find(&schedules).Error; err != nil {
			return err
		}
		// Delete child table "schedules" explicitly (fires GORM delete callbacks).
		if len(schedules) > 0 {
			if err := tx.Delete(&schedules).Error; err != nil {
				return err
			}
		}
		// Then delete the parent "schedule_events" row (also fires callbacks).
		res := tx.Where("id = ? AND user_id = ?", eventID, userID).
			Delete(&model.ScheduleEvent{})
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected == 0 {
			return respond.WrongScheduleEventID
		}
		return nil
	})
}
// GetScheduleTypeByEventID returns the event type of an event owned by the
// given user. A missing row and a NULL type both map to
// respond.WrongScheduleEventID.
func (d *ScheduleDAO) GetScheduleTypeByEventID(ctx context.Context, eventID, userID int) (string, error) {
	var rec struct {
		Type *string `gorm:"column:type"`
	}
	err := d.db.WithContext(ctx).
		Table("schedule_events").
		Select("type").
		Where("id = ? AND user_id=?", eventID, userID).
		First(&rec).Error
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		// Event missing or owned by someone else: one uniform business error.
		return "", respond.WrongScheduleEventID
	case err != nil:
		return "", err
	case rec.Type == nil:
		return "", respond.WrongScheduleEventID
	}
	return *rec.Type, nil
}
// GetScheduleEmbeddedTaskID returns one non-empty embedded_task_id from the
// slots of the given event, or 0, nil when nothing is embedded.
// (embedded_task_id lives on the schedules table, keyed by event_id.)
func (d *ScheduleDAO) GetScheduleEmbeddedTaskID(ctx context.Context, eventID int) (int, error) {
	var rec struct {
		EmbeddedTaskID *int `gorm:"column:embedded_task_id"`
	}
	if err := d.db.WithContext(ctx).
		Table("schedules").
		Select("embedded_task_id").
		Where("event_id = ?", eventID).
		Where("embedded_task_id IS NOT NULL AND embedded_task_id <> 0").
		Order("id ASC").
		Limit(1).
		Scan(&rec).Error; err != nil {
		return 0, err
	}
	if rec.EmbeddedTaskID == nil {
		// No embedded task at all.
		return 0, nil
	}
	return *rec.EmbeddedTaskID, nil
}
// IfScheduleEventIDExists reports whether a schedule_events row with the
// given ID exists.
func (d *ScheduleDAO) IfScheduleEventIDExists(ctx context.Context, eventID int) (bool, error) {
	var n int64
	err := d.db.WithContext(ctx).
		Table("schedule_events").
		Where("id = ?", eventID).
		Count(&n).Error
	// n stays 0 on error, so the pair (false, err) matches the old behavior.
	return n > 0, err
}
// SetScheduleEmbeddedTaskIDToNull clears embedded_task_id on every slot of
// the event and returns the task ID that was embedded.
// Returns respond.TargetScheduleNotHaveEmbeddedTask when nothing is embedded
// (either on the initial read or when the update touches no rows).
func (d *ScheduleDAO) SetScheduleEmbeddedTaskIDToNull(ctx context.Context, eventID int) (int, error) {
	// Read the currently embedded task first so it can be reported back.
	embeddedTaskID, err := d.GetScheduleEmbeddedTaskID(ctx, eventID)
	if err != nil {
		return 0, err
	}
	if embeddedTaskID == 0 {
		return 0, respond.TargetScheduleNotHaveEmbeddedTask
	}
	// Null out embedded_task_id to detach the task from the event's slots.
	res := d.db.WithContext(ctx).
		Table("schedules").
		Where("event_id = ?", eventID).
		Where("embedded_task_id IS NOT NULL AND embedded_task_id <> 0").
		Update("embedded_task_id", nil)
	switch {
	case res.Error != nil:
		return 0, res.Error
	case res.RowsAffected == 0:
		return 0, respond.TargetScheduleNotHaveEmbeddedTask
	}
	return embeddedTaskID, nil
}
// FindEmbeddedTaskIDAndDeleteIt undoes the scheduling of a task block and
// returns the affected event ID.
//
// Two cases, with different undo strategies:
//  1. the task was scheduled as a standalone event (type="task" and
//     rel_id==taskID): the schedules rows must be deleted before the
//     schedule_events row (FK order);
//  2. the task was embedded into a course: the course event and its slots are
//     kept, only embedded_task_id is cleared.
//
// Returns respond.TargetTaskNotEmbeddedInAnySchedule when no slot references
// the task (or a concurrent undo already removed it).
func (d *ScheduleDAO) FindEmbeddedTaskIDAndDeleteIt(ctx context.Context, taskID int) (int, error) {
	// 1. Find the event_id of a slot carrying embedded_task_id = taskID.
	type row struct {
		EventID *int `gorm:"column:event_id"`
	}
	var r row
	err := d.db.WithContext(ctx).
		Table("schedules").
		Select("event_id").
		Where("embedded_task_id = ?", taskID).
		Order("id ASC").
		Limit(1).
		Scan(&r).Error
	if err != nil {
		return 0, err
	}
	if r.EventID == nil {
		return 0, respond.TargetTaskNotEmbeddedInAnySchedule
	}
	eventID := *r.EventID
	var event model.ScheduleEvent
	if err := d.db.WithContext(ctx).
		Where("id = ?", eventID).
		First(&event).Error; err != nil {
		return 0, err
	}
	if event.Type == "task" && event.RelID != nil && *event.RelID == taskID {
		// 2. Standalone task event: schedules.event_id is a foreign key, so the
		//    atomic slots must be removed before the event itself.
		if err := d.db.WithContext(ctx).
			Table("schedules").
			Where("event_id = ?", eventID).
			Delete(&model.Schedule{}).Error; err != nil {
			return 0, err
		}
		res := d.db.WithContext(ctx).
			Table("schedule_events").
			Where("id = ?", eventID).
			Delete(&model.ScheduleEvent{})
		if res.Error != nil {
			return 0, res.Error
		}
		if res.RowsAffected == 0 {
			return 0, respond.TargetTaskNotEmbeddedInAnySchedule
		}
		return eventID, nil
	}
	// 3. Embedded in a course: keep the course event and slots, clear the link.
	clearRes := d.db.WithContext(ctx).
		Table("schedules").
		Where("embedded_task_id = ?", taskID).
		Update("embedded_task_id", nil)
	if clearRes.Error != nil {
		return 0, clearRes.Error
	}
	if clearRes.RowsAffected == 0 {
		return 0, respond.TargetTaskNotEmbeddedInAnySchedule
	}
	return eventID, nil
}
// DeleteScheduleEventByTaskItemID removes the formal schedule event(s) whose
// type is "task" and rel_id equals taskItemID, plus their slot rows.
// Idempotent: returns nil when no matching event exists (e.g. a previous step
// already removed it).
func (d *ScheduleDAO) DeleteScheduleEventByTaskItemID(ctx context.Context, taskItemID int) error {
	// 1. Collect matching event IDs; nothing found means idempotent success.
	var staleEventIDs []int
	if err := d.db.WithContext(ctx).
		Table("schedule_events").
		Where("type = ? AND rel_id = ?", "task", taskItemID).
		Pluck("id", &staleEventIDs).Error; err != nil {
		return err
	}
	if len(staleEventIDs) == 0 {
		return nil
	}
	// 2. schedules.event_id references schedule_events.id: child rows first.
	if err := d.db.WithContext(ctx).
		Table("schedules").
		Where("event_id IN ?", staleEventIDs).
		Delete(&model.Schedule{}).Error; err != nil {
		return err
	}
	// 3. Parent rows last.
	return d.db.WithContext(ctx).
		Table("schedule_events").
		Where("id IN ?", staleEventIDs).
		Delete(&model.ScheduleEvent{}).Error
}
// GetUserRecentCompletedSchedules pages through the user's finished slots
// (event end_time before nowTime), newest first.
// A slot qualifies when the event itself is a task OR it is a course slot
// with an embedded task.
func (d *ScheduleDAO) GetUserRecentCompletedSchedules(ctx context.Context, nowTime time.Time, userID int, index, limit int) ([]model.Schedule, error) {
	var finished []model.Schedule
	query := d.db.WithContext(ctx).
		Preload("Event").
		Preload("EmbeddedTask").
		Joins("JOIN schedule_events ON schedule_events.id = schedules.event_id").
		// User match & already ended & (task event OR course with embedded task).
		Where("schedules.user_id = ? AND schedule_events.end_time < ? AND (schedule_events.type = ? OR schedules.embedded_task_id IS NOT NULL)",
			userID, nowTime, "task").
		Order("schedule_events.end_time DESC"). // hits the end_time index
		Offset(index).
		Limit(limit)
	if err := query.Find(&finished).Error; err != nil {
		return nil, err
	}
	return finished, nil
}
// GetScheduleEventWeekByID returns the week number of the first slot bound to
// the event; respond.WrongScheduleEventID when the event has no slots.
func (d *ScheduleDAO) GetScheduleEventWeekByID(ctx context.Context, eventID int) (int, error) {
	var rec struct {
		Week *int `gorm:"column:week"`
	}
	if err := d.db.WithContext(ctx).
		Table("schedules").
		Select("week").
		Where("event_id = ?", eventID).
		Order("id ASC").
		Limit(1).
		Scan(&rec).Error; err != nil {
		return 0, err
	}
	if rec.Week == nil {
		return 0, respond.WrongScheduleEventID
	}
	return *rec.Week, nil
}
// GetUserOngoingSchedule returns the user's slots whose event is currently in
// progress (start_time <= now <= end_time) OR has not started yet
// (start_time > now), ordered by start time, with Event and EmbeddedTask
// preloaded.
// NOTE(review): despite the name, upcoming events are included too (see the
// Or clause) — confirm callers expect that.
func (d *ScheduleDAO) GetUserOngoingSchedule(ctx context.Context, userID int, nowTime time.Time) ([]model.Schedule, error) {
	var schedules []model.Schedule
	err := d.db.WithContext(ctx).
		Preload("Event").
		Preload("EmbeddedTask").
		Joins("JOIN schedule_events ON schedule_events.id = schedules.event_id").
		Where("schedules.user_id = ? AND schedule_events.start_time <= ? AND schedule_events.end_time >= ?",
			userID, nowTime, nowTime).
		Or("schedules.user_id = ? AND schedule_events.start_time > ?",
			userID, nowTime).
		Order("schedule_events.start_time ASC"). // hits the start_time index
		Find(&schedules).Error
	if err != nil {
		return nil, err
	}
	return schedules, nil
}
// RevocateSchedulesByEventID marks every slot of the event as "interrupted"
// (schedule revocation).
//
// Fixes over the original:
//  1. res.Error is now checked BEFORE RowsAffected — previously a database
//     error (which leaves RowsAffected at 0) was misreported as
//     respond.WrongScheduleEventID;
//  2. the stale comment claiming this clears embedded_task_id is corrected:
//     the statement updates the status column.
func (d *ScheduleDAO) RevocateSchedulesByEventID(ctx context.Context, eventID int) error {
	res := d.db.WithContext(ctx).
		Table("schedules").
		Where("event_id = ?", eventID).
		Update("status", "interrupted")
	if res.Error != nil {
		return res.Error
	}
	// No rows touched: unknown event ID (business error).
	if res.RowsAffected == 0 {
		return respond.WrongScheduleEventID
	}
	return nil
}
// GetRelIDByScheduleEventID returns the rel_id of a schedule event.
// A missing event maps to respond.WrongScheduleEventID; a NULL rel_id is
// reported as 0, nil.
func (d *ScheduleDAO) GetRelIDByScheduleEventID(ctx context.Context, eventID int) (int, error) {
	var rec struct {
		RelID *int `gorm:"column:rel_id"`
	}
	err := d.db.WithContext(ctx).
		Table("schedule_events").
		Select("rel_id").
		Where("id = ?", eventID).
		First(&rec).Error
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return 0, respond.WrongScheduleEventID
	case err != nil:
		return 0, err
	case rec.RelID == nil:
		return 0, nil
	}
	return *rec.RelID, nil
}
// GetUserSchedulesByTimeRange returns the user's slots whose events fall
// entirely inside [startTime, endTime], ordered by start time, with Event and
// EmbeddedTask preloaded.
func (d *ScheduleDAO) GetUserSchedulesByTimeRange(ctx context.Context, userID int, startTime, endTime time.Time) ([]model.Schedule, error) {
	var inRange []model.Schedule
	query := d.db.WithContext(ctx).
		Preload("Event").
		Preload("EmbeddedTask").
		Joins("JOIN schedule_events ON schedule_events.id = schedules.event_id").
		Where("schedules.user_id = ? AND schedule_events.start_time >= ? AND schedule_events.end_time <= ?",
			userID, startTime, endTime).
		Order("schedule_events.start_time ASC") // hits the start_time index
	if err := query.Find(&inRange).Error; err != nil {
		return nil, err
	}
	return inRange, nil
}
// BatchEmbedTaskIntoSchedule writes taskItemIDs[i] into every slot of
// eventIDs[i], skipping any event whose type is not "course".
// The two slices are paired by index and must have equal length; an empty
// eventIDs is a no-op.
// NOTE(review): the per-event type lookup and update are not wrapped in a
// transaction, so a mid-loop failure leaves earlier writes in place — confirm
// callers run this inside an outer transaction if atomicity matters.
func (d *ScheduleDAO) BatchEmbedTaskIntoSchedule(ctx context.Context, eventIDs, taskItemIDs []int) error {
	if len(eventIDs) == 0 {
		return nil
	}
	if len(eventIDs) != len(taskItemIDs) {
		return fmt.Errorf("eventIDs length != taskItemIDs length")
	}
	db := d.db.WithContext(ctx)
	for i, eventID := range eventIDs {
		taskItemID := taskItemIDs[i]
		// 1) Only course events accept embedded tasks; a missing event leaves
		//    typ empty and is skipped as well.
		var typ string
		if err := db.
			Table("schedule_events").
			Select("type").
			Where("id = ?", eventID).
			Scan(&typ).Error; err != nil {
			return err
		}
		if typ != "course" {
			continue
		}
		// 2) One event owns many slots: batch-write embedded_task_id.
		if err := db.
			Table("schedules").
			Where("event_id = ?", eventID).
			Update("embedded_task_id", taskItemID).Error; err != nil {
			return err
		}
	}
	return nil
}
// InsertScheduleEvents batch-inserts events and returns the generated IDs in
// input order; an empty input returns nil, nil.
func (d *ScheduleDAO) InsertScheduleEvents(ctx context.Context, events []model.ScheduleEvent) ([]int, error) {
	if len(events) == 0 {
		return nil, nil
	}
	if err := d.db.WithContext(ctx).Create(&events).Error; err != nil {
		return nil, err
	}
	newIDs := make([]int, len(events))
	for i := range events {
		newIDs[i] = events[i].ID
	}
	return newIDs, nil
}

View File

@@ -0,0 +1,346 @@
package dao
import (
"context"
"errors"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/LoveLosita/smartflow/backend/shared/respond"
"gorm.io/gorm"
)
// TaskClassDAO holds the database handle used by task-class persistence.
type TaskClassDAO struct {
	// The injected database connection instance.
	db *gorm.DB
}

// NewTaskClassDAO wraps the given *gorm.DB in a TaskClassDAO.
func NewTaskClassDAO(db *gorm.DB) *TaskClassDAO {
	return &TaskClassDAO{db: db}
}
// WithTx returns a copy of the DAO bound to the given transaction.
func (dao *TaskClassDAO) WithTx(tx *gorm.DB) *TaskClassDAO {
	clone := &TaskClassDAO{}
	clone.db = tx
	return clone
}
// AddOrUpdateTaskClass inserts (ID == 0) or updates a task class for the
// given user. Ownership guard: updates are scoped by user_id so one user
// cannot touch another user's rows.
func (dao *TaskClassDAO) AddOrUpdateTaskClass(userID int, taskClass *model.TaskClass) (int, error) {
	// Never trust the UserID in the payload; force the logged-in user.
	taskClass.UserID = &userID
	// Insert path: ID == 0 means a brand-new row.
	if taskClass.ID == 0 {
		if err := dao.db.Create(taskClass).Error; err != nil {
			return 0, err
		}
		return taskClass.ID, nil
	}
	// Update path: must match both id AND user_id, otherwise no row is
	// updated — prevents overwriting other users' data.
	tx := dao.db.Model(&model.TaskClass{}).
		Where("id = ? AND user_id = ?", taskClass.ID, userID).
		Updates(taskClass)
	if tx.Error != nil {
		return 0, tx.Error
	}
	if tx.RowsAffected == 0 {
		// Nothing matched: the row does not exist or belongs to someone else.
		return 0, respond.UserTaskClassForbidden
	}
	return taskClass.ID, nil
}
// AddOrUpdateTaskClassItems inserts new items (ID == 0) and updates existing
// ones for the given user.
//
// Ownership guard: every referenced task class (category_id) must belong to
// userID, otherwise respond.UserTaskClassForbidden is returned.
//
// Fix over the original: it.CategoryID was dereferenced without a nil check,
// panicking on payloads missing category_id. A nil pointer is now treated
// like the existing zero-value case (gorm.ErrRecordNotFound).
func (dao *TaskClassDAO) AddOrUpdateTaskClassItems(userID int, items []model.TaskClassItem) error {
	if len(items) == 0 {
		return nil
	}
	// 1) Validate that all referenced task classes belong to the current user.
	categoryIDSet := make(map[int]struct{}, len(items))
	var categoryIDs []int
	for _, it := range items {
		if it.CategoryID == nil || *it.CategoryID == 0 {
			return gorm.ErrRecordNotFound
		}
		if _, ok := categoryIDSet[*it.CategoryID]; !ok {
			categoryIDSet[*it.CategoryID] = struct{}{}
			categoryIDs = append(categoryIDs, *it.CategoryID)
		}
	}
	var count int64
	if err := dao.db.Model(&model.TaskClass{}).
		Where("id IN ? AND user_id = ?", categoryIDs, userID).
		Count(&count).Error; err != nil {
		return err
	}
	if count != int64(len(categoryIDs)) {
		return respond.UserTaskClassForbidden
	}
	// 2) Split creates from updates; updates are constrained to the verified
	//    category set to prevent cross-user writes.
	var toCreate []model.TaskClassItem
	for _, it := range items {
		if it.ID == 0 {
			toCreate = append(toCreate, it)
			continue
		}
		tx := dao.db.Model(&model.TaskClassItem{}).
			Where("id = ? AND category_id IN ?", it.ID, categoryIDs).
			Updates(map[string]any{
				"category_id": it.CategoryID,
			})
		if tx.Error != nil {
			return tx.Error
		}
		if tx.RowsAffected == 0 {
			return respond.UserTaskClassForbidden
		}
	}
	if len(toCreate) > 0 {
		if err := dao.db.Create(&toCreate).Error; err != nil {
			return err
		}
	}
	return nil
}
// Transaction runs fn inside one database transaction for reuse by the
// service layer (auto commit/rollback): fn returning nil commits; an error
// or a panic rolls back.
func (dao *TaskClassDAO) Transaction(fn func(txDAO *TaskClassDAO) error) error {
	wrapped := func(tx *gorm.DB) error {
		return fn(NewTaskClassDAO(tx))
	}
	return dao.db.Transaction(wrapped)
}
// GetUserTaskClasses lists every task class owned by the user.
func (dao *TaskClassDAO) GetUserTaskClasses(userID int) ([]model.TaskClass, error) {
	var owned []model.TaskClass
	if err := dao.db.Where("user_id = ?", userID).Find(&owned).Error; err != nil {
		return nil, err
	}
	return owned, nil
}
// GetCompleteTaskClassByID loads one task class with its Items preloaded,
// scoped by both id and userID to prevent cross-user reads.
// GORM issues two statements: one against task_classes, then one against
// task_class_items keyed by the found class ID.
func (dao *TaskClassDAO) GetCompleteTaskClassByID(ctx context.Context, id int, userID int) (*model.TaskClass, error) {
	var result model.TaskClass
	if err := dao.db.WithContext(ctx).
		Preload("Items").
		Where("id = ? AND user_id = ?", id, userID).
		First(&result).Error; err != nil {
		return nil, err
	}
	return &result, nil
}
// GetCompleteTaskClassesByIDs batch-loads "complete task classes" (with
// Items preloaded).
//
// Responsibility boundaries:
//  1. filters by user_id + ids to guarantee data ownership;
//  2. preloads Items so the smart coarse-ranking can use them directly;
//  3. no ordering policy — result order is decided by the service layer;
//  4. if any id is missing or not owned by the user, returns WrongTaskClassID.
func (dao *TaskClassDAO) GetCompleteTaskClassesByIDs(ctx context.Context, userID int, ids []int) ([]model.TaskClass, error) {
	if len(ids) == 0 {
		return []model.TaskClass{}, nil
	}
	// 1. Dedupe and drop invalid values so bad IDs do not amplify DB load.
	uniqueIDs := make([]int, 0, len(ids))
	seen := make(map[int]struct{}, len(ids))
	for _, id := range ids {
		if id <= 0 {
			continue
		}
		if _, exists := seen[id]; exists {
			continue
		}
		seen[id] = struct{}{}
		uniqueIDs = append(uniqueIDs, id)
	}
	if len(uniqueIDs) == 0 {
		return nil, respond.WrongTaskClassID
	}
	// 2. Batch query with items preloaded.
	var taskClasses []model.TaskClass
	err := dao.db.WithContext(ctx).
		Preload("Items").
		Where("user_id = ? AND id IN ?", userID, uniqueIDs).
		Find(&taskClasses).Error
	if err != nil {
		return nil, err
	}
	// 3. Count check: any shortfall means an illegal/foreign ID — reported as
	//    one uniform business error.
	if len(taskClasses) != len(uniqueIDs) {
		return nil, respond.WrongTaskClassID
	}
	return taskClasses, nil
}
// GetTaskClassItemByID fetches one task class item by primary key.
func (dao *TaskClassDAO) GetTaskClassItemByID(ctx context.Context, id int) (*model.TaskClassItem, error) {
	var found model.TaskClassItem
	if err := dao.db.WithContext(ctx).
		Where("id = ?", id).
		First(&found).Error; err != nil {
		return nil, err
	}
	return &found, nil
}
// GetTaskClassIDByTaskItemID returns the category_id (owning task class ID)
// of the given item; respond.TaskClassItemNotFound when the item is missing.
//
// Fix over the original: item.CategoryID was dereferenced unconditionally,
// so a NULL category_id column would panic. A NULL value is now reported as
// TaskClassItemNotFound as well.
func (dao *TaskClassDAO) GetTaskClassIDByTaskItemID(ctx context.Context, itemID int) (int, error) {
	var item model.TaskClassItem
	res := dao.db.WithContext(ctx).
		Select("category_id").
		Where("id = ?", itemID).
		First(&item)
	if res.Error != nil {
		if errors.Is(res.Error, gorm.ErrRecordNotFound) {
			return 0, respond.TaskClassItemNotFound
		}
		return 0, res.Error
	}
	if item.CategoryID == nil {
		// Orphaned item with no category: same business error as "not found".
		return 0, respond.TaskClassItemNotFound
	}
	return *item.CategoryID, nil
}
// GetTaskClassUserIDByID returns the owner user_id of a task class.
//
// Fix over the original: taskClass.UserID was dereferenced without a nil
// check and would panic on a NULL user_id column; that case now maps to
// respond.WrongTaskClassID.
func (dao *TaskClassDAO) GetTaskClassUserIDByID(ctx context.Context, taskClassID int) (int, error) {
	var taskClass model.TaskClass
	err := dao.db.WithContext(ctx).
		Select("user_id").
		Where("id = ?", taskClassID).
		First(&taskClass).Error
	if err != nil {
		return 0, err
	}
	if taskClass.UserID == nil {
		return 0, respond.WrongTaskClassID
	}
	return *taskClass.UserID, nil
}
// UpdateTaskClassItemEmbeddedTime stores the scheduled time window on a task
// class item.
func (dao *TaskClassDAO) UpdateTaskClassItemEmbeddedTime(ctx context.Context, taskID int, embeddedTime *model.TargetTime) error {
	return dao.db.WithContext(ctx).
		Model(&model.TaskClassItem{}).
		Where("id = ?", taskID).
		Update("embedded_time", embeddedTime).
		Error
}
// DeleteTaskClassItemEmbeddedTime clears the scheduled time window on a task
// class item.
func (dao *TaskClassDAO) DeleteTaskClassItemEmbeddedTime(ctx context.Context, taskID int) error {
	return dao.db.WithContext(ctx).
		Model(&model.TaskClassItem{}).
		Where("id = ?", taskID).
		Update("embedded_time", nil).
		Error
}
// IfTaskClassItemArranged reports whether the item already has an embedded
// time (i.e. has been scheduled).
func (dao *TaskClassDAO) IfTaskClassItemArranged(ctx context.Context, taskID int) (bool, error) {
	var item model.TaskClassItem
	if err := dao.db.WithContext(ctx).
		Select("embedded_time").
		Where("id = ?", taskID).
		First(&item).Error; err != nil {
		return false, err
	}
	return item.EmbeddedTime != nil, nil
}
// BatchCheckIfTaskClassItemsArranged reports whether ANY of the given items
// already has an embedded time; an empty list yields false.
func (dao *TaskClassDAO) BatchCheckIfTaskClassItemsArranged(ctx context.Context, itemIDs []int) (bool, error) {
	if len(itemIDs) == 0 {
		return false, nil
	}
	var arranged int64
	if err := dao.db.WithContext(ctx).
		Model(&model.TaskClassItem{}).
		Where("id IN ? AND embedded_time IS NOT NULL", itemIDs).
		Count(&arranged).Error; err != nil {
		return false, err
	}
	return arranged > 0, nil
}
// DeleteTaskClassItemByID hard-deletes one task class item by primary key.
func (dao *TaskClassDAO) DeleteTaskClassItemByID(ctx context.Context, id int) error {
	return dao.db.WithContext(ctx).
		Where("id = ?", id).
		Delete(&model.TaskClassItem{}).
		Error
}
// DeleteTaskClassByID hard-deletes one task class; respond.WrongTaskClassID
// when no row matched.
func (dao *TaskClassDAO) DeleteTaskClassByID(ctx context.Context, id int) error {
	res := dao.db.WithContext(ctx).
		Where("id = ?", id).
		Delete(&model.TaskClass{})
	switch {
	case res.Error != nil:
		return res.Error
	case res.RowsAffected == 0:
		return respond.WrongTaskClassID
	}
	return nil
}
// BatchUpdateTaskClassItemEmbeddedTime sets embedded_time for many items in
// a single statement:
//
//	UPDATE ... SET embedded_time = CASE id WHEN ? THEN ? ... END WHERE id IN (?)
//
// itemIDs and updates are paired by index and must have equal length; an
// empty itemIDs is a no-op.
func (dao *TaskClassDAO) BatchUpdateTaskClassItemEmbeddedTime(ctx context.Context, itemIDs []int, updates []*model.TargetTime) error {
	if len(itemIDs) == 0 {
		return nil
	}
	if len(itemIDs) != len(updates) {
		return errors.New("itemIDs length mismatch updates length")
	}
	// Build the CASE expression. All ids/values are bound as placeholders, so
	// there is no SQL-injection surface in the concatenated text.
	caseSQL := "CASE id"
	args := make([]any, 0, len(itemIDs)*2)
	for i, id := range itemIDs {
		caseSQL += " WHEN ? THEN ?"
		args = append(args, id, updates[i])
	}
	caseSQL += " END"
	res := dao.db.WithContext(ctx).
		Model(&model.TaskClassItem{}).
		Where("id IN ?", itemIDs).
		Update("embedded_time", gorm.Expr(caseSQL, args...))
	return res.Error
}
// ValidateTaskItemIDsBelongToTaskClass reports whether every item in itemIDs
// belongs to the given task class. An empty list is vacuously valid.
// NOTE(review): duplicate IDs in itemIDs would skew the count comparison;
// callers are expected to pass unique IDs — confirm upstream.
func (dao *TaskClassDAO) ValidateTaskItemIDsBelongToTaskClass(ctx context.Context, taskClassID int, itemIDs []int) (bool, error) {
	if len(itemIDs) == 0 {
		return true, nil
	}
	var matched int64
	if err := dao.db.WithContext(ctx).
		Model(&model.TaskClassItem{}).
		Where("id IN ? AND category_id = ?", itemIDs, taskClassID).
		Count(&matched).Error; err != nil {
		return false, err
	}
	return matched == int64(len(itemIDs)), nil
}
// GetTaskClassItemsByIDs fetches the items with the given primary keys.
func (dao *TaskClassDAO) GetTaskClassItemsByIDs(ctx context.Context, itemIDs []int) ([]model.TaskClassItem, error) {
	var items []model.TaskClassItem
	if err := dao.db.WithContext(ctx).
		Where("id IN ?", itemIDs).
		Find(&items).Error; err != nil {
		return nil, err
	}
	return items, nil
}

View File

@@ -0,0 +1,341 @@
package dao
import (
"context"
"errors"
"time"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
"github.com/LoveLosita/smartflow/backend/shared/respond"
"gorm.io/gorm"
)
// TaskDAO holds the database handle used by task persistence.
type TaskDAO struct {
	// The injected database connection instance.
	db *gorm.DB
}

// NewTaskDAO wraps the given *gorm.DB in a TaskDAO.
func NewTaskDAO(db *gorm.DB) *TaskDAO {
	return &TaskDAO{db: db}
}
// WithTx returns a copy of the DAO bound to the given transaction.
func (r *TaskDAO) WithTx(tx *gorm.DB) *TaskDAO {
	clone := &TaskDAO{}
	clone.db = tx
	return clone
}
// AddTask inserts a task row and echoes back the (now ID-populated) task.
func (dao *TaskDAO) AddTask(req *model.Task) (*model.Task, error) {
	createErr := dao.db.Create(req).Error
	if createErr != nil {
		return nil, createErr
	}
	return req, nil
}
// GetTasksByUserID lists every task of a user; an empty result maps to the
// dedicated business error respond.UserTasksEmpty.
func (dao *TaskDAO) GetTasksByUserID(userID int) ([]model.Task, error) {
	var tasks []model.Task
	if err := dao.db.Where("user_id = ?", userID).Find(&tasks).Error; err != nil {
		return nil, err
	}
	if len(tasks) == 0 {
		// No tasks at all: surface the custom error instead of an empty slice.
		return nil, respond.UserTasksEmpty
	}
	return tasks, nil
}
// CompleteTaskByID marks the given task as "completed".
//
// Responsibility boundaries:
//  1. only updates the completion state of "current user + task_id";
//  2. idempotency middleware is mounted at the routing layer, not here;
//  3. response wrapping is handled by the service layer.
//
// Return semantics:
//  1. *model.Task: post-update snapshot (at least ID/UserID/IsCompleted);
//  2. bool:
//     2.1 true: the task was already completed (idempotent hit);
//     2.2 false: transitioned from incomplete to completed in this call;
//  3. error:
//     3.1 gorm.ErrRecordNotFound: task missing or not owned by the user;
//     3.2 anything else: database failure.
func (dao *TaskDAO) CompleteTaskByID(ctx context.Context, userID int, taskID int) (*model.Task, bool, error) {
	// 1. Guard: invalid arguments map to "record not found" to avoid stray writes.
	if userID <= 0 || taskID <= 0 {
		return nil, false, gorm.ErrRecordNotFound
	}
	// 2. Read first to distinguish "already completed" from "missing".
	var target model.Task
	findErr := dao.db.WithContext(ctx).
		Where("id = ? AND user_id = ?", taskID, userID).
		First(&target).Error
	if findErr != nil {
		return nil, false, findErr
	}
	// 3. Already completed: idempotent success, no write.
	if target.IsCompleted {
		return &target, true, nil
	}
	// 4. Not completed: flip the flag.
	//
	// 4.1 Model(&model.Task{UserID: userID}) lets cache_deleter see user_id in
	//     the GORM update callback so it can evict the task cache correctly.
	// 4.2 The WHERE still pins user_id + id to avoid touching other users.
	updateResult := dao.db.WithContext(ctx).
		Model(&model.Task{UserID: userID}).
		Where("id = ? AND user_id = ?", taskID, userID).
		Update("is_completed", true)
	if updateResult.Error != nil {
		return nil, false, updateResult.Error
	}
	// 5. Extreme-concurrency fallback:
	//    5.1 RowsAffected == 0 may mean a racing request already updated it;
	//    5.2 re-read the state: completed means idempotent success, otherwise
	//        treat as missing/abnormal.
	if updateResult.RowsAffected == 0 {
		var check model.Task
		checkErr := dao.db.WithContext(ctx).
			Where("id = ? AND user_id = ?", taskID, userID).
			First(&check).Error
		if checkErr != nil {
			return nil, false, checkErr
		}
		if check.IsCompleted {
			return &check, true, nil
		}
		return nil, false, errors.New("任务状态更新失败")
	}
	// 6. Return the updated snapshot for the service layer to wrap.
	target.IsCompleted = true
	return &target, false, nil
}
// UndoCompleteTaskByID restores a task from "completed" back to "incomplete".
//
// Responsibility boundaries:
//  1. only restores state for the current user's (user_id) task_id;
//  2. if the task is not completed, it returns an explicit business error as
//     required — deliberately NOT an idempotent success;
//  3. response copy is assembled by the service layer.
//
// Return semantics:
//  1. *model.Task: snapshot after the restore;
//  2. error:
//     2.1 gorm.ErrRecordNotFound: task missing or not owned by the user;
//     2.2 respond.TaskNotCompleted: the task is not currently "completed",
//         so un-checking is not allowed;
//     2.3 anything else: database failure.
func (dao *TaskDAO) UndoCompleteTaskByID(ctx context.Context, userID int, taskID int) (*model.Task, error) {
	// 1. Guard: invalid user/task arguments map to "record not found".
	if userID <= 0 || taskID <= 0 {
		return nil, gorm.ErrRecordNotFound
	}
	// 2. Read first to distinguish "missing" from "state does not allow undo".
	var target model.Task
	findErr := dao.db.WithContext(ctx).
		Where("id = ? AND user_id = ?", taskID, userID).
		First(&target).Error
	if findErr != nil {
		return nil, findErr
	}
	// 3. Strict business rule: an incomplete task cannot be un-completed.
	//    3.1 This is the key difference from the "mark complete" endpoint:
	//        no idempotent success here.
	if !target.IsCompleted {
		return nil, respond.TaskNotCompleted
	}
	// 4. Restore the state: is_completed true -> false.
	//
	// 4.1 Model(&model.Task{UserID: userID}) lets cache_deleter obtain user_id
	//     in the callback and evict this user's task cache.
	updateResult := dao.db.WithContext(ctx).
		Model(&model.Task{UserID: userID}).
		Where("id = ? AND user_id = ?", taskID, userID).
		Update("is_completed", false)
	if updateResult.Error != nil {
		return nil, updateResult.Error
	}
	// 5. Race fallback:
	//    5.1 RowsAffected == 0 may mean a racing request restored it first;
	//    5.2 re-read: if already incomplete, return "task not completed" per
	//        the business rule above.
	if updateResult.RowsAffected == 0 {
		var check model.Task
		checkErr := dao.db.WithContext(ctx).
			Where("id = ? AND user_id = ?", taskID, userID).
			First(&check).Error
		if checkErr != nil {
			return nil, checkErr
		}
		if !check.IsCompleted {
			return nil, respond.TaskNotCompleted
		}
		return nil, errors.New("取消任务完成状态失败")
	}
	// 6. Reflect the restored state in the snapshot and return it.
	target.IsCompleted = false
	return &target, nil
}
// PromoteTaskUrgencyByIDs batch-performs the "urgency shift":
//
// Responsibility boundaries:
//  1. only shifts qualifying tasks from a non-urgent quadrant to the urgent
//     one:
//     1.1 priority 2 -> 1 (important, not urgent -> important and urgent);
//     1.2 priority 4 -> 3 (not simple, not important -> simple, not important);
//  2. only touches the given user_id + task_ids range;
//  3. event publishing, retry dedup and cache policy belong to the
//     service/outbox layers.
//
// Idempotency and consistency:
//  1. the WHERE restricts to is_completed=0, urgency_threshold_at<=now and
//     priority IN (2,4);
//  2. re-running the same batch updates nothing (already-shifted rows no
//     longer match) — idempotent;
//  3. Model(&model.Task{UserID: userID}) feeds user_id into the GORM callback
//     so cache_deleter can evict the task cache.
func (dao *TaskDAO) PromoteTaskUrgencyByIDs(ctx context.Context, userID int, taskIDs []int, now time.Time) (int64, error) {
	// 1. Guard: invalid user or empty batch is a no-op.
	if userID <= 0 || len(taskIDs) == 0 {
		return 0, nil
	}
	// 2. Dedupe and drop non-positive IDs to keep the IN clause tight.
	validTaskIDs := compactPositiveIntIDs(taskIDs)
	if len(validTaskIDs) == 0 {
		return 0, nil
	}
	// 3. Conditional update: only tasks past the urgency threshold and still
	//    in a non-urgent quadrant are shifted.
	result := dao.db.WithContext(ctx).
		Model(&model.Task{UserID: userID}).
		Where("user_id = ?", userID).
		Where("id IN ?", validTaskIDs).
		Where("is_completed = ?", false).
		Where("urgency_threshold_at IS NOT NULL AND urgency_threshold_at <= ?", now).
		Where("priority IN ?", []int{2, 4}).
		Update("priority", gorm.Expr("CASE WHEN priority = 2 THEN 1 WHEN priority = 4 THEN 3 ELSE priority END"))
	if result.Error != nil {
		return 0, result.Error
	}
	return result.RowsAffected, nil
}
// UpdateTaskByID applies the given column updates to one task, scoped by
// task_id + user_id.
//
// Responsibility boundaries:
//  1. only executes the SET clause built from the updates map;
//  2. business rules (e.g. priority range validation) live in the service
//     layer;
//  3. Model(&model.Task{UserID: userID}) feeds user_id to the cache_deleter
//     callback.
//
// Return semantics:
//  1. *model.Task: the full post-update snapshot;
//  2. error:
//     2.1 gorm.ErrRecordNotFound: task missing or not owned by the user;
//     2.2 anything else: database failure.
func (dao *TaskDAO) UpdateTaskByID(ctx context.Context, userID int, taskID int, updates map[string]interface{}) (*model.Task, error) {
	// 1. Guard: invalid arguments map to "record not found".
	if userID <= 0 || taskID <= 0 {
		return nil, gorm.ErrRecordNotFound
	}
	// 2. Confirm the task exists and belongs to the current user.
	var target model.Task
	findErr := dao.db.WithContext(ctx).
		Where("id = ? AND user_id = ?", taskID, userID).
		First(&target).Error
	if findErr != nil {
		return nil, findErr
	}
	// 3. Partial-column update.
	//    3.1 Model(&model.Task{UserID: userID}) triggers cache_deleter.
	//    3.2 The WHERE pins id + user_id to avoid stray writes.
	updateResult := dao.db.WithContext(ctx).
		Model(&model.Task{UserID: userID}).
		Where("id = ? AND user_id = ?", taskID, userID).
		Updates(updates)
	if updateResult.Error != nil {
		return nil, updateResult.Error
	}
	// 4. Re-read after the write for a complete, consistent snapshot.
	var updated model.Task
	if err := dao.db.WithContext(ctx).
		Where("id = ? AND user_id = ?", taskID, userID).
		First(&updated).Error; err != nil {
		return nil, err
	}
	return &updated, nil
}
// DeleteTaskByID permanently (hard-)deletes one task.
//
// Responsibility boundaries:
//  1. only deletes the user_id + task_id row;
//  2. Model(&model.Task{UserID: userID}) triggers cache_deleter so the user's
//     task cache is evicted;
//  3. no cascading schedule cleanup (tasks and schedule_events have no direct
//     foreign-key link).
//
// Return semantics:
//  1. *model.Task: snapshot of the deleted task, for the API response;
//  2. error:
//     2.1 gorm.ErrRecordNotFound: task missing or not owned by the user;
//     2.2 anything else: database failure.
func (dao *TaskDAO) DeleteTaskByID(ctx context.Context, userID int, taskID int) (*model.Task, error) {
	// 1. Guard invalid arguments.
	if userID <= 0 || taskID <= 0 {
		return nil, gorm.ErrRecordNotFound
	}
	// 2. Read first: confirms ownership and captures the response snapshot.
	var target model.Task
	findErr := dao.db.WithContext(ctx).
		Where("id = ? AND user_id = ?", taskID, userID).
		First(&target).Error
	if findErr != nil {
		return nil, findErr
	}
	// 3. Hard delete.
	//    3.1 Model(&model.Task{UserID: userID}) triggers cache_deleter.
	deleteResult := dao.db.WithContext(ctx).
		Model(&model.Task{UserID: userID}).
		Where("id = ? AND user_id = ?", taskID, userID).
		Delete(&model.Task{})
	if deleteResult.Error != nil {
		return nil, deleteResult.Error
	}
	// 4. Race fallback: RowsAffected == 0 means a concurrent request deleted
	//    it first.
	if deleteResult.RowsAffected == 0 {
		return nil, gorm.ErrRecordNotFound
	}
	return &target, nil
}
// compactPositiveIntIDs returns ids with non-positive values dropped and
// duplicates removed (first occurrence wins).
//
// Notes:
//  1. This is an internal DAO argument-cleaning helper; it makes no business decisions.
//  2. Output order is not guaranteed stable, which is fine for the current
//     SQL "WHERE IN" usage.
func compactPositiveIntIDs(ids []int) []int {
	kept := make(map[int]struct{}, len(ids))
	out := make([]int, 0, len(ids))
	for _, candidate := range ids {
		if candidate <= 0 {
			continue
		}
		if _, dup := kept[candidate]; dup {
			continue
		}
		kept[candidate] = struct{}{}
		out = append(out, candidate)
	}
	return out
}

View File

@@ -0,0 +1,85 @@
package eventsvc
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
sharedevents "github.com/LoveLosita/smartflow/backend/shared/events"
kafkabus "github.com/LoveLosita/smartflow/backend/shared/infra/kafka"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
"gorm.io/gorm"
)
// ActiveScheduleTriggeredProcessor describes the minimal capabilities the
// active_schedule.triggered worker needs to execute real business logic.
//
// Responsibilities:
//  1. ProcessTriggeredInTx orchestrates trigger -> preview -> notification inside
//     the given transaction;
//  2. MarkTriggerFailedBestEffort writes the failure back outside the transaction,
//     so some queryable business state exists before the outbox retry;
//  3. The interface pins no concrete implementation, letting the active_scheduler
//     module evolve independently during the migration.
type ActiveScheduleTriggeredProcessor interface {
	ProcessTriggeredInTx(ctx context.Context, tx *gorm.DB, payload sharedevents.ActiveScheduleTriggeredPayload) error
	MarkTriggerFailedBestEffort(ctx context.Context, triggerID string, err error)
}
// RegisterActiveScheduleTriggeredHandler registers the active_schedule.triggered
// outbox handler.
//
// Step by step:
//  1. Parse the envelope into the contract DTO and check the version; clearly
//     broken messages are marked dead immediately;
//  2. ConsumeAndMarkConsumed keeps "business write + consumed advancement" in a
//     single transaction;
//  3. If the transaction errors, best-effort mark the trigger failed and hand the
//     error back to the outbox for retry;
//  4. The concrete active_scheduler implementation is not imported here, avoiding
//     reverse coupling between service/events and the orchestration layer.
func RegisterActiveScheduleTriggeredHandler(
	bus OutboxBus,
	outboxRepo *outboxinfra.Repository,
	processor ActiveScheduleTriggeredProcessor,
) error {
	if bus == nil {
		return errors.New("event bus is nil")
	}
	if outboxRepo == nil {
		return errors.New("outbox repository is nil")
	}
	if processor == nil {
		return errors.New("active schedule triggered processor is nil")
	}
	eventOutboxRepo, err := scopedOutboxRepoForEvent(outboxRepo, sharedevents.ActiveScheduleTriggeredEventType)
	if err != nil {
		return err
	}
	handler := func(ctx context.Context, envelope kafkabus.Envelope) error {
		// Unsupported versions cannot succeed on retry: mark dead and ack.
		if !isAllowedTriggeredEventVersion(envelope.EventVersion) {
			_ = eventOutboxRepo.MarkDead(ctx, envelope.OutboxID, fmt.Sprintf("active_schedule.triggered 版本不受支持: %s", envelope.EventVersion))
			return nil
		}
		var payload sharedevents.ActiveScheduleTriggeredPayload
		if unmarshalErr := json.Unmarshal(envelope.Payload, &payload); unmarshalErr != nil {
			_ = eventOutboxRepo.MarkDead(ctx, envelope.OutboxID, "解析 active_schedule.triggered 载荷失败: "+unmarshalErr.Error())
			return nil
		}
		if validateErr := payload.Validate(); validateErr != nil {
			_ = eventOutboxRepo.MarkDead(ctx, envelope.OutboxID, "active_schedule.triggered 载荷非法: "+validateErr.Error())
			return nil
		}
		// Business processing and consumed-marking share one transaction.
		err := eventOutboxRepo.ConsumeAndMarkConsumed(ctx, envelope.OutboxID, func(tx *gorm.DB) error {
			return processor.ProcessTriggeredInTx(ctx, tx, payload)
		})
		if err != nil {
			// Best-effort failure write-back outside the rolled-back transaction,
			// then surface the error so the outbox retries.
			processor.MarkTriggerFailedBestEffort(ctx, payload.TriggerID, err)
			return err
		}
		return nil
	}
	return bus.RegisterEventHandler(sharedevents.ActiveScheduleTriggeredEventType, handler)
}
// isAllowedTriggeredEventVersion reports whether the envelope's event version
// can be handled: an empty (legacy) version or the current contract version.
func isAllowedTriggeredEventVersion(version string) bool {
	switch strings.TrimSpace(version) {
	case "", sharedevents.ActiveScheduleTriggeredEventVersion:
		return true
	default:
		return false
	}
}

View File

@@ -0,0 +1,130 @@
package eventsvc
import (
"context"
"encoding/json"
"errors"
"log"
agentmodel "github.com/LoveLosita/smartflow/backend/services/agent/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/dao"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
kafkabus "github.com/LoveLosita/smartflow/backend/shared/infra/kafka"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
const (
	// EventTypeAgentStateSnapshotPersist is the business event type for
	// "persist an agent state snapshot".
	EventTypeAgentStateSnapshotPersist = "agent.state.snapshot.persist"
)
// AgentStateSnapshotPayload is the business payload carried by the outbox event.
type AgentStateSnapshotPayload struct {
	ConversationID string `json:"conversation_id"` // conversation the snapshot belongs to
	UserID         int    `json:"user_id"`         // owner of the conversation
	Phase          string `json:"phase"`           // agent runtime phase at snapshot time (may be empty)
	SnapshotJSON   string `json:"snapshot_json"`   // serialized snapshot body
}
// RegisterAgentStateSnapshotHandler registers the "persist agent state snapshot"
// consumer handler.
//
// Responsibilities:
//  1. Only writes snapshots into the agent_state_snapshot_records table;
//  2. Upsert semantics: each conversation_id keeps only its latest snapshot;
//  3. The generic outbox consume transaction keeps "business write + consumed
//     advancement" atomic.
func RegisterAgentStateSnapshotHandler(
	bus OutboxBus,
	outboxRepo *outboxinfra.Repository,
	repoManager *dao.RepoManager,
) error {
	if bus == nil {
		return errors.New("event bus is nil")
	}
	if outboxRepo == nil {
		return errors.New("outbox repository is nil")
	}
	if repoManager == nil {
		return errors.New("repo manager is nil")
	}
	eventOutboxRepo, err := scopedOutboxRepoForEvent(outboxRepo, EventTypeAgentStateSnapshotPersist)
	if err != nil {
		return err
	}
	handler := func(ctx context.Context, envelope kafkabus.Envelope) error {
		var payload AgentStateSnapshotPayload
		if unmarshalErr := json.Unmarshal(envelope.Payload, &payload); unmarshalErr != nil {
			// Undecodable payloads cannot succeed on retry: mark dead and ack.
			_ = eventOutboxRepo.MarkDead(ctx, envelope.OutboxID, "解析快照载荷失败: "+unmarshalErr.Error())
			return nil
		}
		return eventOutboxRepo.ConsumeAndMarkConsumed(ctx, envelope.OutboxID, func(tx *gorm.DB) error {
			record := model.AgentStateSnapshotRecord{
				ConversationID: payload.ConversationID,
				UserID:         payload.UserID,
				Phase:          payload.Phase,
				SnapshotJSON:   payload.SnapshotJSON,
			}
			// Upsert keyed on conversation_id so replays and newer snapshots
			// both converge to a single row per conversation.
			return tx.Clauses(clause.OnConflict{
				Columns:   []clause.Column{{Name: "conversation_id"}},
				DoUpdates: clause.AssignmentColumns([]string{"user_id", "phase", "snapshot_json", "updated_at"}),
			}).Create(&record).Error
		})
	}
	return bus.RegisterEventHandler(EventTypeAgentStateSnapshotPersist, handler)
}
// PublishAgentStateSnapshot publishes a "persist agent state snapshot" event to
// the outbox.
//
// Design notes:
//  1. The snapshot is serialized to JSON and written to MySQL asynchronously via
//     the outbox;
//  2. A nil publisher degrades silently (Kafka-disabled deployments);
//  3. Publish failures are only logged and never interrupt the main flow.
func PublishAgentStateSnapshot(
	ctx context.Context,
	publisher outboxinfra.EventPublisher,
	snapshot *agentmodel.AgentStateSnapshot,
	conversationID string,
	userID int,
) {
	if publisher == nil {
		return
	}
	if snapshot == nil {
		return
	}
	snapshotJSON, err := json.Marshal(snapshot)
	if err != nil {
		log.Printf("[WARN] 序列化 agent 状态快照失败 chat=%s: %v", conversationID, err)
		return
	}
	// Extract the current phase, tolerating a nil runtime/common state.
	phase := ""
	if snapshot.RuntimeState != nil {
		cs := snapshot.RuntimeState.EnsureCommonState()
		if cs != nil {
			phase = string(cs.Phase)
		}
	}
	payload := AgentStateSnapshotPayload{
		ConversationID: conversationID,
		UserID:         userID,
		Phase:          phase,
		SnapshotJSON:   string(snapshotJSON),
	}
	if err := publisher.Publish(ctx, outboxinfra.PublishRequest{
		EventType:    EventTypeAgentStateSnapshotPersist,
		EventVersion: outboxinfra.DefaultEventVersion,
		MessageKey:   conversationID,
		AggregateID:  conversationID,
		Payload:      payload,
	}); err != nil {
		log.Printf("[WARN] 发布 agent 状态快照事件失败 chat=%s: %v", conversationID, err)
	}
}

View File

@@ -0,0 +1,330 @@
package eventsvc
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"strings"
"github.com/LoveLosita/smartflow/backend/services/runtime/dao"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
kafkabus "github.com/LoveLosita/smartflow/backend/shared/infra/kafka"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
"gorm.io/gorm"
)
// EventTypeAgentTimelinePersistRequested is the business event type for
// "persist a conversation timeline event".
const EventTypeAgentTimelinePersistRequested = "agent.timeline.persist.requested"
// RegisterAgentTimelinePersistHandler registers the "persist conversation
// timeline" consumer handler.
//
// Responsibilities:
//  1. Handles timeline events only — never chat_history or other business messages;
//  2. Only registers the handler; bus start/stop is owned elsewhere;
//  3. Uses the generic outbox consume transaction so "timeline write + consumed
//     advancement" share one transaction;
//  4. On a seq unique-key conflict, it first checks whether this is an idempotent
//     replay, and only then allocates a fresh seq and backfills Redis.
func RegisterAgentTimelinePersistHandler(
	bus OutboxBus,
	outboxRepo *outboxinfra.Repository,
	agentRepo *dao.AgentDAO,
	cacheDAO *dao.CacheDAO,
) error {
	// 1. Dependency checks: consuming is unsafe without any of these.
	if bus == nil {
		return errors.New("event bus is nil")
	}
	if outboxRepo == nil {
		return errors.New("outbox repository is nil")
	}
	if agentRepo == nil {
		return errors.New("agent repo is nil")
	}
	eventOutboxRepo, err := scopedOutboxRepoForEvent(outboxRepo, EventTypeAgentTimelinePersistRequested)
	if err != nil {
		return err
	}
	handler := func(ctx context.Context, envelope kafkabus.Envelope) error {
		var payload model.ChatTimelinePersistPayload
		if unmarshalErr := json.Unmarshal(envelope.Payload, &payload); unmarshalErr != nil {
			// 1. An undecodable payload is unrecoverable: mark dead, skip retries.
			_ = eventOutboxRepo.MarkDead(ctx, envelope.OutboxID, "解析时间线持久化载荷失败: "+unmarshalErr.Error())
			return nil
		}
		payload = payload.Normalize()
		if !payload.HasValidIdentity() {
			// 2. Only the minimal field set needed to uniquely address a timeline
			//    record is validated here.
			// 3. Whether content / payload_json may be empty is decided per event
			//    kind, not enforced globally.
			_ = eventOutboxRepo.MarkDead(ctx, envelope.OutboxID, "时间线持久化载荷非法: user_id/conversation_id/seq/kind 非法")
			return nil
		}
		refreshCache := false
		finalSeq := payload.Seq
		// 4. Use the shared outbox consume entry so "business write succeeded ->
		//    consumed" stays atomic.
		err := eventOutboxRepo.ConsumeAndMarkConsumed(ctx, envelope.OutboxID, func(tx *gorm.DB) error {
			finalPayload, repaired, persistErr := persistConversationTimelineEventInTx(ctx, tx, agentRepo.WithTx(tx), payload)
			if persistErr != nil {
				return persistErr
			}
			refreshCache = repaired
			finalSeq = finalPayload.Seq
			return nil
		})
		if err != nil {
			return err
		}
		// 5. Rebuild the Redis timeline only when a fresh seq was allocated.
		//    5.1 The main path already wrote Redis; a normal success needs no rewrite.
		//    5.2 After a seq repair, stale cache entries under the old seq would
		//        corrupt ordering on the next refresh.
		//    5.3 A cache rebuild failure is only logged — it must not roll back an
		//        already-consumed outbox row.
		if refreshCache {
			if refreshErr := rebuildConversationTimelineCache(ctx, agentRepo, cacheDAO, payload.UserID, payload.ConversationID, finalSeq); refreshErr != nil {
				log.Printf("重建时间线缓存失败 user=%d chat=%s seq=%d err=%v", payload.UserID, payload.ConversationID, finalSeq, refreshErr)
			}
		}
		return nil
	}
	return bus.RegisterEventHandler(EventTypeAgentTimelinePersistRequested, handler)
}
// PublishAgentTimelinePersistRequested publishes a "persist conversation
// timeline event" request.
//
// Design goals:
//  1. Business code supplies only the DTO; event metadata is assembled here.
//  2. conversation_id serves as both MessageKey and AggregateID to minimize
//     per-conversation reordering.
//  3. Publish failures are returned explicitly so callers decide whether to
//     abort the main flow.
func PublishAgentTimelinePersistRequested(
	ctx context.Context,
	publisher outboxinfra.EventPublisher,
	payload model.ChatTimelinePersistPayload,
) error {
	if publisher == nil {
		return errors.New("event publisher is nil")
	}
	normalized := payload.Normalize()
	if !normalized.HasValidIdentity() {
		return errors.New("invalid timeline persist payload")
	}
	request := outboxinfra.PublishRequest{
		EventType:    EventTypeAgentTimelinePersistRequested,
		EventVersion: outboxinfra.DefaultEventVersion,
		MessageKey:   normalized.ConversationID,
		AggregateID:  normalized.ConversationID,
		Payload:      normalized,
	}
	return publisher.Publish(ctx, request)
}
// persistConversationTimelineEventInTx writes one timeline event inside a
// single transaction.
//
// Step by step:
//  1. Try to insert with the payload's original seq;
//  2. On a seq unique-key conflict, load the row holding that seq and check
//     whether this is a replay of the same event;
//  3. If it is not a replay but a collision caused by Redis seq drift,
//     re-allocate with max(seq)+1;
//  4. Repair at most 3 times so bad data cannot trap the consumer in an
//     infinite loop.
func persistConversationTimelineEventInTx(
	ctx context.Context,
	tx *gorm.DB,
	agentRepo *dao.AgentDAO,
	payload model.ChatTimelinePersistPayload,
) (model.ChatTimelinePersistPayload, bool, error) {
	if tx == nil {
		return payload, false, errors.New("transaction is nil")
	}
	if agentRepo == nil {
		return payload, false, errors.New("agent repo is nil")
	}
	working := payload.Normalize()
	repaired := false
	for attempt := 0; attempt < 3; attempt++ {
		if _, _, err := agentRepo.SaveConversationTimelineEvent(ctx, working); err == nil {
			return working, repaired, nil
		} else if !model.IsTimelineSeqConflictError(err) {
			return working, repaired, err
		}
		// 1. First decide whether this is "the same event consumed twice".
		// 2. If the stored row matches field-for-field, the earlier attempt
		//    already persisted it, so this pass is an idempotent success.
		// 3. Otherwise fall through to the "allocate a new seq" branch so a
		//    genuinely new event is never swallowed.
		existing, findErr := findConversationTimelineEventBySeq(ctx, tx, working.UserID, working.ConversationID, working.Seq)
		if findErr == nil && working.MatchesStoredEvent(existing) {
			return working, repaired, nil
		}
		if findErr != nil && !errors.Is(findErr, gorm.ErrRecordNotFound) {
			return working, repaired, findErr
		}
		maxSeq, maxErr := loadConversationTimelineMaxSeq(ctx, tx, working.UserID, working.ConversationID)
		if maxErr != nil {
			return working, repaired, maxErr
		}
		working.Seq = maxSeq + 1
		repaired = true
	}
	return working, repaired, fmt.Errorf("timeline seq repair exceeded limit user=%d chat=%s", working.UserID, working.ConversationID)
}
// findConversationTimelineEventBySeq loads the timeline event stored under the
// (user, conversation, seq) triple inside the current transaction.
func findConversationTimelineEventBySeq(
	ctx context.Context,
	tx *gorm.DB,
	userID int,
	conversationID string,
	seq int64,
) (model.AgentTimelineEvent, error) {
	var stored model.AgentTimelineEvent
	chatID := strings.TrimSpace(conversationID)
	lookupErr := tx.WithContext(ctx).
		Where("user_id = ? AND chat_id = ? AND seq = ?", userID, chatID, seq).
		Take(&stored).Error
	return stored, lookupErr
}
// loadConversationTimelineMaxSeq returns the highest persisted seq for the
// conversation, or 0 when no timeline rows exist yet.
func loadConversationTimelineMaxSeq(
	ctx context.Context,
	tx *gorm.DB,
	userID int,
	conversationID string,
) (int64, error) {
	var highest int64
	scanErr := tx.WithContext(ctx).
		Model(&model.AgentTimelineEvent{}).
		Where("user_id = ? AND chat_id = ?", userID, strings.TrimSpace(conversationID)).
		Select("COALESCE(MAX(seq), 0)").
		Scan(&highest).Error
	if scanErr != nil {
		return 0, scanErr
	}
	return highest, nil
}
// rebuildConversationTimelineCache rebuilds the Redis timeline cache after a
// seq repair.
//
// Notes:
//  1. Runs only when a cache DAO is wired; environments without Redis skip it;
//  2. A full rebuild (not a single append) is required because the old cache
//     still holds the event under its wrong seq;
//  3. Deliberately not extracted into agent/sv for reuse: events must not
//     depend on service, or an import cycle would form.
func rebuildConversationTimelineCache(
	ctx context.Context,
	agentRepo *dao.AgentDAO,
	cacheDAO *dao.CacheDAO,
	userID int,
	conversationID string,
	finalSeq int64,
) error {
	if cacheDAO == nil || agentRepo == nil {
		return nil
	}
	events, err := agentRepo.ListConversationTimelineEvents(ctx, userID, conversationID)
	if err != nil {
		return err
	}
	items := buildConversationTimelineCacheItems(events)
	if err = cacheDAO.SetConversationTimelineToCache(ctx, userID, conversationID, items); err != nil {
		return err
	}
	// Prefer the seq of the last cached item as the authoritative cursor.
	if len(items) > 0 {
		finalSeq = items[len(items)-1].Seq
	}
	return cacheDAO.SetConversationTimelineSeq(ctx, userID, conversationID, finalSeq)
}
// buildConversationTimelineCacheItems converts persisted timeline rows into
// cache item DTOs, then normalizes kind/role pairs for cache consumers.
func buildConversationTimelineCacheItems(events []model.AgentTimelineEvent) []model.GetConversationTimelineItem {
	converted := make([]model.GetConversationTimelineItem, 0, len(events))
	if len(events) == 0 {
		return converted
	}
	for _, row := range events {
		entry := model.GetConversationTimelineItem{
			ID:             row.ID,
			Seq:            row.Seq,
			Kind:           strings.TrimSpace(row.Kind),
			TokensConsumed: row.TokensConsumed,
			CreatedAt:      row.CreatedAt,
		}
		if row.Role != nil {
			entry.Role = strings.TrimSpace(*row.Role)
		}
		if row.Content != nil {
			entry.Content = strings.TrimSpace(*row.Content)
		}
		// Only attach a payload that decodes into a non-empty object; malformed
		// JSON is silently dropped, matching the original behavior.
		if row.Payload != nil {
			var decoded map[string]any
			if err := json.Unmarshal([]byte(strings.TrimSpace(*row.Payload)), &decoded); err == nil && len(decoded) > 0 {
				entry.Payload = decoded
			}
		}
		converted = append(converted, entry)
	}
	return normalizeConversationTimelineCacheItems(converted)
}
// normalizeConversationTimelineCacheItems canonicalizes the kind and role on
// each cache item so user/assistant text events always carry a matching pair.
func normalizeConversationTimelineCacheItems(items []model.GetConversationTimelineItem) []model.GetConversationTimelineItem {
	result := make([]model.GetConversationTimelineItem, 0, len(items))
	if len(items) == 0 {
		return result
	}
	for _, entry := range items {
		role := strings.ToLower(strings.TrimSpace(entry.Role))
		kind := canonicalizeConversationTimelineKind(entry.Kind, role)
		// Backfill kind from role when the stored kind is empty.
		if kind == "" && role == "user" {
			kind = model.AgentTimelineKindUserText
		} else if kind == "" && role == "assistant" {
			kind = model.AgentTimelineKindAssistantText
		}
		// Backfill role from kind when the stored role is empty.
		if role == "" && kind == model.AgentTimelineKindUserText {
			role = "user"
		} else if role == "" && kind == model.AgentTimelineKindAssistantText {
			role = "assistant"
		}
		entry.Kind = kind
		entry.Role = role
		result = append(result, entry)
	}
	return result
}
// canonicalizeConversationTimelineKind maps a raw kind (plus a role hint) onto
// the canonical timeline kind vocabulary.
//
// Behavior:
//  1. The legacy aliases "text" / "message" / "query" resolve to the user or
//     assistant text kind according to the role hint;
//  2. Every other kind — including the already-canonical model.AgentTimelineKind*
//     values, which are lowercase — passes through lowered and trimmed.
func canonicalizeConversationTimelineKind(kind string, role string) string {
	loweredKind := strings.ToLower(strings.TrimSpace(kind))
	if loweredKind == "text" || loweredKind == "message" || loweredKind == "query" {
		switch strings.ToLower(strings.TrimSpace(role)) {
		case "user":
			return model.AgentTimelineKindUserText
		case "assistant":
			return model.AgentTimelineKindAssistantText
		}
	}
	return loweredKind
}

View File

@@ -0,0 +1,115 @@
package eventsvc
import (
"context"
"encoding/json"
"errors"
"strconv"
"strings"
"github.com/LoveLosita/smartflow/backend/services/runtime/dao"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
contracts "github.com/LoveLosita/smartflow/backend/shared/contracts/userauth"
kafkabus "github.com/LoveLosita/smartflow/backend/shared/infra/kafka"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
"github.com/LoveLosita/smartflow/backend/shared/ports"
"gorm.io/gorm"
)
const (
	// EventTypeChatHistoryPersistRequested is the business event type of a
	// "persist chat message" request.
	EventTypeChatHistoryPersistRequested = "chat.history.persist.requested"
)
// RegisterChatHistoryPersistHandler registers the "persist chat message" consumer.
// Responsibilities:
//  1. Handles chat-history events only, no other business events;
//  2. Only registers; bus startup is owned elsewhere;
//  3. Writes the local chat tables first, then asks userauth to adjust the
//     token quota;
//  4. This version registers only the new routing key; the legacy compat key is
//     no longer registered.
func RegisterChatHistoryPersistHandler(
	bus OutboxBus,
	outboxRepo *outboxinfra.Repository,
	repoManager *dao.RepoManager,
	adjuster ports.TokenUsageAdjuster,
) error {
	if bus == nil {
		return errors.New("event bus is nil")
	}
	if outboxRepo == nil {
		return errors.New("outbox repository is nil")
	}
	if repoManager == nil {
		return errors.New("repo manager is nil")
	}
	eventOutboxRepo, err := scopedOutboxRepoForEvent(outboxRepo, EventTypeChatHistoryPersistRequested)
	if err != nil {
		return err
	}
	handler := func(ctx context.Context, envelope kafkabus.Envelope) error {
		var payload model.ChatHistoryPersistPayload
		if unmarshalErr := json.Unmarshal(envelope.Payload, &payload); unmarshalErr != nil {
			// Undecodable payloads cannot succeed on retry: mark dead and ack.
			_ = eventOutboxRepo.MarkDead(ctx, envelope.OutboxID, "解析聊天持久化载荷失败: "+unmarshalErr.Error())
			return nil
		}
		// Fall back to the outbox row id when the envelope carries no event id.
		eventID := strings.TrimSpace(envelope.EventID)
		if eventID == "" {
			eventID = strconv.FormatInt(envelope.OutboxID, 10)
		}
		// Step 1: persist the chat message inside the outbox consume transaction.
		if err := eventOutboxRepo.ConsumeInTx(ctx, envelope.OutboxID, func(tx *gorm.DB) error {
			txM := repoManager.WithTx(tx)
			return txM.Agent.SaveChatHistoryInTx(
				ctx,
				payload.UserID,
				payload.ConversationID,
				payload.Role,
				payload.Message,
				payload.ReasoningContent,
				payload.ReasoningDurationSeconds,
				payload.TokensConsumed,
				eventID,
			)
		}); err != nil {
			return err
		}
		// Step 2: sync the quota with userauth; an error here makes the outbox
		// retry the whole handler (eventID is passed to both writes — presumably
		// it keys their idempotence; confirm against the DAO/contract impls).
		if payload.TokensConsumed > 0 {
			if adjuster == nil {
				return errors.New("userauth token adjuster is nil")
			}
			if _, err := adjuster.AdjustTokenUsage(ctx, contracts.AdjustTokenUsageRequest{
				EventID:    eventID,
				UserID:     payload.UserID,
				TokenDelta: payload.TokensConsumed,
			}); err != nil {
				return err
			}
		}
		// Step 3: mark consumed only after both writes succeed.
		return eventOutboxRepo.MarkConsumed(ctx, envelope.OutboxID)
	}
	return bus.RegisterEventHandler(EventTypeChatHistoryPersistRequested, handler)
}
// PublishChatHistoryPersistRequested publishes a "persist chat message" event.
//
// Behavior:
//  1. Success only means the outbox write succeeded; consumption is async.
//  2. Obviously invalid payloads are rejected up front — consistent with the
//     sibling publishers in this package (PublishChatTokenUsageAdjustRequested
//     and PublishAgentTimelinePersistRequested both validate identity fields),
//     so bad messages never reach the outbox only to be marked dead later.
func PublishChatHistoryPersistRequested(
	ctx context.Context,
	publisher outboxinfra.EventPublisher,
	payload model.ChatHistoryPersistPayload,
) error {
	if publisher == nil {
		return errors.New("event publisher is nil")
	}
	// Minimal identity validation: the consumer persists by user + conversation,
	// so a payload missing either can never be consumed successfully.
	if payload.UserID <= 0 {
		return errors.New("invalid user_id")
	}
	if payload.ConversationID == "" {
		return errors.New("invalid conversation_id")
	}
	return publisher.Publish(ctx, outboxinfra.PublishRequest{
		EventType:    EventTypeChatHistoryPersistRequested,
		EventVersion: outboxinfra.DefaultEventVersion,
		MessageKey:   payload.ConversationID,
		AggregateID:  payload.ConversationID,
		Payload:      payload,
	})
}

View File

@@ -0,0 +1,126 @@
package eventsvc
import (
"context"
"encoding/json"
"errors"
"strconv"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/services/runtime/dao"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
contracts "github.com/LoveLosita/smartflow/backend/shared/contracts/userauth"
kafkabus "github.com/LoveLosita/smartflow/backend/shared/infra/kafka"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
"github.com/LoveLosita/smartflow/backend/shared/ports"
"gorm.io/gorm"
)
const (
	// EventTypeChatTokenUsageAdjustRequested is the "conversation token quota
	// adjustment" event type.
	// Naming constraints:
	//  1. Expresses business semantics only; no outbox/kafka implementation
	//     details leak into the name;
	//  2. Kept long-term as a stable routing key; future evolution goes through
	//     event_version instead.
	EventTypeChatTokenUsageAdjustRequested = "chat.token.usage.adjust.requested"
)
// RegisterChatTokenUsageAdjustHandler registers the "conversation token quota
// adjustment" consumer.
// Responsibilities:
//  1. Handles token adjustment events only; chat bodies are persisted elsewhere;
//  2. Writes the local ledger first, then syncs the quota on the userauth side;
//  3. Invalid payloads are marked dead immediately, avoiding pointless retries.
func RegisterChatTokenUsageAdjustHandler(
	bus OutboxBus,
	outboxRepo *outboxinfra.Repository,
	repoManager *dao.RepoManager,
	adjuster ports.TokenUsageAdjuster,
) error {
	if bus == nil {
		return errors.New("event bus is nil")
	}
	if outboxRepo == nil {
		return errors.New("outbox repository is nil")
	}
	if repoManager == nil {
		return errors.New("repo manager is nil")
	}
	eventOutboxRepo, err := scopedOutboxRepoForEvent(outboxRepo, EventTypeChatTokenUsageAdjustRequested)
	if err != nil {
		return err
	}
	handler := func(ctx context.Context, envelope kafkabus.Envelope) error {
		var payload model.ChatTokenUsageAdjustPayload
		if unmarshalErr := json.Unmarshal(envelope.Payload, &payload); unmarshalErr != nil {
			// Undecodable payloads cannot succeed on retry: mark dead and ack.
			_ = eventOutboxRepo.MarkDead(ctx, envelope.OutboxID, "解析会话 token 调整载荷失败: "+unmarshalErr.Error())
			return nil
		}
		// Identity validation: these payloads can never be consumed successfully.
		if payload.UserID <= 0 || payload.TokensDelta <= 0 || payload.ConversationID == "" {
			_ = eventOutboxRepo.MarkDead(ctx, envelope.OutboxID, "会话 token 调整载荷无效: user_id/conversation_id/tokens_delta 非法")
			return nil
		}
		// Fall back to the outbox row id when the envelope carries no event id.
		eventID := strings.TrimSpace(envelope.EventID)
		if eventID == "" {
			eventID = strconv.FormatInt(envelope.OutboxID, 10)
		}
		// Step 1: write the local ledger inside the outbox consume transaction.
		if err := eventOutboxRepo.ConsumeInTx(ctx, envelope.OutboxID, func(tx *gorm.DB) error {
			txM := repoManager.WithTx(tx)
			return txM.Agent.AdjustTokenUsageInTx(ctx, payload.UserID, payload.ConversationID, payload.TokensDelta, eventID)
		}); err != nil {
			return err
		}
		// Step 2: sync the quota with userauth; errors trigger an outbox retry
		// (the shared eventID presumably keys both writes' idempotence — confirm).
		if adjuster == nil {
			return errors.New("userauth token adjuster is nil")
		}
		if _, err := adjuster.AdjustTokenUsage(ctx, contracts.AdjustTokenUsageRequest{
			EventID:    eventID,
			UserID:     payload.UserID,
			TokenDelta: payload.TokensDelta,
		}); err != nil {
			return err
		}
		// Step 3: mark consumed only after both writes succeed.
		return eventOutboxRepo.MarkConsumed(ctx, envelope.OutboxID)
	}
	return bus.RegisterEventHandler(EventTypeChatTokenUsageAdjustRequested, handler)
}
// PublishChatTokenUsageAdjustRequested publishes a "conversation token quota
// adjustment" event.
//  1. Success only guarantees the outbox write; consumption is asynchronous.
//  2. Callers deal with the DTO alone and never see outbox/Kafka details.
func PublishChatTokenUsageAdjustRequested(
	ctx context.Context,
	publisher outboxinfra.EventPublisher,
	payload model.ChatTokenUsageAdjustPayload,
) error {
	// Guard clauses collapsed into one switch for readability.
	switch {
	case publisher == nil:
		return errors.New("event publisher is nil")
	case payload.UserID <= 0:
		return errors.New("invalid user_id")
	case payload.TokensDelta <= 0:
		return errors.New("invalid tokens_delta")
	case payload.ConversationID == "":
		return errors.New("invalid conversation_id")
	}
	// Default the trigger timestamp when the caller left it unset.
	if payload.TriggeredAt.IsZero() {
		payload.TriggeredAt = time.Now()
	}
	request := outboxinfra.PublishRequest{
		EventType:    EventTypeChatTokenUsageAdjustRequested,
		EventVersion: outboxinfra.DefaultEventVersion,
		MessageKey:   payload.ConversationID,
		AggregateID:  strconv.Itoa(payload.UserID) + ":" + payload.ConversationID,
		Payload:      payload,
	}
	return publisher.Publish(ctx, request)
}

View File

@@ -0,0 +1,185 @@
package eventsvc
import (
"errors"
"github.com/LoveLosita/smartflow/backend/services/memory"
"github.com/LoveLosita/smartflow/backend/services/runtime/dao"
sharedevents "github.com/LoveLosita/smartflow/backend/shared/events"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
"github.com/LoveLosita/smartflow/backend/shared/ports"
)
// RegisterCoreOutboxHandlers registers the outbox handlers still consumed by
// the agent boundary inside the monolith remnant.
//
// Responsibilities:
//  1. Only aggregates the handlers currently owned by the agent process;
//  2. Creates no eventBus/outboxRepo/DAO and never starts or stops the event bus;
//  3. Leaves each Register* function's responsibilities untouched: payload
//     parsing, idempotent consumption and persistence stay with each handler;
//  4. memory.extract.requested moved to cmd/memory in stage 6 CP1; only its
//     route is recorded here — no consume handler is registered.
func RegisterCoreOutboxHandlers(
	eventBus OutboxBus,
	outboxRepo *outboxinfra.Repository,
	repoManager *dao.RepoManager,
	agentRepo *dao.AgentDAO,
	cacheRepo *dao.CacheDAO,
	memoryModule *memory.Module,
	adjuster ports.TokenUsageAdjuster,
) error {
	if err := validateCoreOutboxHandlerDeps(eventBus, outboxRepo, repoManager, agentRepo, cacheRepo); err != nil {
		return err
	}
	// Publish-side route only: the memory consumer lives in cmd/memory.
	if err := RegisterMemoryExtractRoute(); err != nil {
		return err
	}
	return registerOutboxHandlerRoutes(coreOutboxHandlerRoutes(eventBus, outboxRepo, repoManager, agentRepo, cacheRepo, memoryModule, adjuster))
}
// RegisterAllOutboxHandlers wires every outbox handler of the current stage.
//
// Responsibilities:
//  1. Connects the core and active_scheduler routes of the current monolith
//     remnant in a single call;
//  2. Creates no dependencies and never starts the event bus;
//  3. notification now runs in cmd/notification with its own outbox consumer,
//     so it is no longer registered by the monolith;
//  4. Used by the current startup flow to finish explicit route registration
//     "before the bus starts".
func RegisterAllOutboxHandlers(
	eventBus OutboxBus,
	outboxRepo *outboxinfra.Repository,
	repoManager *dao.RepoManager,
	agentRepo *dao.AgentDAO,
	cacheRepo *dao.CacheDAO,
	memoryModule *memory.Module,
	activeTriggerWorkflow ActiveScheduleTriggeredProcessor,
	adjuster ports.TokenUsageAdjuster,
) error {
	if err := validateAllOutboxHandlerDeps(eventBus, outboxRepo, repoManager, agentRepo, cacheRepo, memoryModule, activeTriggerWorkflow); err != nil {
		return err
	}
	return registerOutboxHandlerRoutes(allOutboxHandlerRoutes(
		eventBus,
		outboxRepo,
		repoManager,
		agentRepo,
		cacheRepo,
		memoryModule,
		activeTriggerWorkflow,
		adjuster,
	))
}
// validateCoreOutboxHandlerDeps checks the dependencies required to register
// the core outbox handlers.
//
//  1. Only nil checks are performed — no DB/Redis/Kafka connectivity probing,
//     so this registration helper never doubles as a startup health check.
//  2. A non-nil error means a dependency is missing; nil means it is safe to
//     proceed with per-route registration.
func validateCoreOutboxHandlerDeps(
	eventBus OutboxBus,
	outboxRepo *outboxinfra.Repository,
	repoManager *dao.RepoManager,
	agentRepo *dao.AgentDAO,
	cacheRepo *dao.CacheDAO,
) error {
	switch {
	case eventBus == nil:
		return errors.New("event bus is nil")
	case outboxRepo == nil:
		return errors.New("outbox repository is nil")
	case repoManager == nil:
		return errors.New("repo manager is nil")
	case agentRepo == nil:
		return errors.New("agent repo is nil")
	case cacheRepo == nil:
		return errors.New("cache repo is nil")
	}
	return nil
}
// validateAllOutboxHandlerDeps validates the core dependencies plus the extra
// ones needed by the active_scheduler routes.
//
// NOTE(review): memoryModule is accepted but not nil-checked, matching the
// original behavior — confirm whether a check is intended.
func validateAllOutboxHandlerDeps(
	eventBus OutboxBus,
	outboxRepo *outboxinfra.Repository,
	repoManager *dao.RepoManager,
	agentRepo *dao.AgentDAO,
	cacheRepo *dao.CacheDAO,
	memoryModule *memory.Module,
	activeTriggerWorkflow ActiveScheduleTriggeredProcessor,
) error {
	if coreErr := validateCoreOutboxHandlerDeps(eventBus, outboxRepo, repoManager, agentRepo, cacheRepo); coreErr != nil {
		return coreErr
	}
	if activeTriggerWorkflow == nil {
		return errors.New("active schedule triggered processor is nil")
	}
	return nil
}
// coreOutboxHandlerRoutes describes only the core-stage outbox routes.
//
// NOTE(review): memoryModule is accepted but not referenced by any route below;
// it appears to be kept for signature symmetry with allOutboxHandlerRoutes —
// confirm before removing.
func coreOutboxHandlerRoutes(
	eventBus OutboxBus,
	outboxRepo *outboxinfra.Repository,
	repoManager *dao.RepoManager,
	agentRepo *dao.AgentDAO,
	cacheRepo *dao.CacheDAO,
	memoryModule *memory.Module,
	adjuster ports.TokenUsageAdjuster,
) []outboxHandlerRoute {
	return []outboxHandlerRoute{
		{
			// Chat message persistence (agent boundary).
			EventType: EventTypeChatHistoryPersistRequested,
			Service:   outboxHandlerServiceAgent,
			Register: func() error {
				return RegisterChatHistoryPersistHandler(eventBus, outboxRepo, repoManager, adjuster)
			},
		},
		{
			// Conversation token quota adjustment.
			EventType: EventTypeChatTokenUsageAdjustRequested,
			Service:   outboxHandlerServiceAgent,
			Register: func() error {
				return RegisterChatTokenUsageAdjustHandler(eventBus, outboxRepo, repoManager, adjuster)
			},
		},
		{
			// Agent state snapshot upsert.
			EventType: EventTypeAgentStateSnapshotPersist,
			Service:   outboxHandlerServiceAgent,
			Register: func() error {
				return RegisterAgentStateSnapshotHandler(eventBus, outboxRepo, repoManager)
			},
		},
		{
			// Conversation timeline persistence (with seq repair).
			EventType: EventTypeAgentTimelinePersistRequested,
			Service:   outboxHandlerServiceAgent,
			Register: func() error {
				return RegisterAgentTimelinePersistHandler(eventBus, outboxRepo, agentRepo, cacheRepo)
			},
		},
	}
}
// allOutboxHandlerRoutes expands every outbox route of the current stage so the
// startup entry point can wire them in a single pass.
func allOutboxHandlerRoutes(
	eventBus OutboxBus,
	outboxRepo *outboxinfra.Repository,
	repoManager *dao.RepoManager,
	agentRepo *dao.AgentDAO,
	cacheRepo *dao.CacheDAO,
	memoryModule *memory.Module,
	activeTriggerWorkflow ActiveScheduleTriggeredProcessor,
	adjuster ports.TokenUsageAdjuster,
) []outboxHandlerRoute {
	routes := coreOutboxHandlerRoutes(eventBus, outboxRepo, repoManager, agentRepo, cacheRepo, memoryModule, adjuster)
	// The active_scheduler route is appended on top of the core set.
	routes = append(routes,
		outboxHandlerRoute{
			EventType: sharedevents.ActiveScheduleTriggeredEventType,
			Service:   outboxHandlerServiceActiveScheduler,
			Register: func() error {
				return RegisterActiveScheduleTriggeredHandler(eventBus, outboxRepo, activeTriggerWorkflow)
			},
		},
	)
	return routes
}

View File

@@ -0,0 +1,262 @@
package eventsvc
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/services/memory"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
kafkabus "github.com/LoveLosita/smartflow/backend/shared/infra/kafka"
outboxinfra "github.com/LoveLosita/smartflow/backend/shared/infra/outbox"
"github.com/spf13/viper"
"gorm.io/gorm"
)
const (
	// EventTypeMemoryExtractRequested is the "memory extraction requested" event type.
	EventTypeMemoryExtractRequested = "memory.extract.requested"
	// maxMemorySourceTextLength caps the source text carried in the event
	// (applied via truncateByRune, so the limit is in runes, not bytes).
	maxMemorySourceTextLength = 1500
)
// RegisterMemoryExtractRoute only records the service ownership of
// memory.extract.requested.
//
// Responsibilities:
//  1. Guarantees only that the publish side can route the event into
//     memory_outbox_messages;
//  2. Registers no consume handler; since stage 6 CP1 the consume boundary
//     belongs to cmd/memory;
//  3. Repeated calls follow the idempotent semantics of outbox route
//     registration.
func RegisterMemoryExtractRoute() error {
	return outboxinfra.RegisterEventService(EventTypeMemoryExtractRequested, outboxinfra.ServiceMemory)
}
// RegisterMemoryExtractRequestedHandler registers the "memory extraction
// requested" consumer.
//
// Responsibilities:
//  1. Only turns events into memory_jobs tasks;
//  2. Performs no LLM recomputation inside the consume callback;
//  3. memory.Module.WithTx(tx) reuses the same access facade while the outbox
//     keeps owning the transaction boundary.
func RegisterMemoryExtractRequestedHandler(
	bus OutboxBus,
	outboxRepo *outboxinfra.Repository,
	memoryModule *memory.Module,
) error {
	if bus == nil {
		return errors.New("event bus is nil")
	}
	if outboxRepo == nil {
		return errors.New("outbox repository is nil")
	}
	if memoryModule == nil {
		return errors.New("memory module is nil")
	}
	eventOutboxRepo, err := scopedOutboxRepoForEvent(outboxRepo, EventTypeMemoryExtractRequested)
	if err != nil {
		return err
	}
	handler := func(ctx context.Context, envelope kafkabus.Envelope) error {
		var payload model.MemoryExtractRequestedPayload
		if unmarshalErr := json.Unmarshal(envelope.Payload, &payload); unmarshalErr != nil {
			// Undecodable payloads cannot succeed on retry: mark dead and ack.
			_ = eventOutboxRepo.MarkDead(ctx, envelope.OutboxID, "解析记忆抽取载荷失败: "+unmarshalErr.Error())
			return nil
		}
		if validateErr := validateMemoryExtractPayload(payload); validateErr != nil {
			_ = eventOutboxRepo.MarkDead(ctx, envelope.OutboxID, "记忆抽取载荷非法: "+validateErr.Error())
			return nil
		}
		// Enqueue the extraction job and advance consumed in one transaction.
		return eventOutboxRepo.ConsumeAndMarkConsumed(ctx, envelope.OutboxID, func(tx *gorm.DB) error {
			jobPayload := memorymodel.ExtractJobPayload{
				UserID:          payload.UserID,
				ConversationID:  strings.TrimSpace(payload.ConversationID),
				AssistantID:     strings.TrimSpace(payload.AssistantID),
				RunID:           strings.TrimSpace(payload.RunID),
				SourceMessageID: payload.SourceMessageID,
				SourceRole:      strings.TrimSpace(payload.SourceRole),
				SourceText:      strings.TrimSpace(payload.SourceText),
				OccurredAt:      payload.OccurredAt,
				TraceID:         strings.TrimSpace(payload.TraceID),
				IdempotencyKey:  strings.TrimSpace(payload.IdempotencyKey),
			}
			return memoryModule.WithTx(tx).EnqueueExtract(ctx, jobPayload, envelope.EventID)
		})
	}
	return bus.RegisterEventHandler(EventTypeMemoryExtractRequested, handler)
}
// EnqueueMemoryExtractRequestedInTx writes a memory.extract.requested outbox
// message inside the caller's transaction.
//
// Purpose:
//  1. "Chat message persisted" and "extraction event enqueued" commit in the
//     same transaction;
//  2. Any failure rolls both back, so the pipeline never breaks mid-way.
func EnqueueMemoryExtractRequestedInTx(
	ctx context.Context,
	outboxRepo *outboxinfra.Repository,
	maxRetry int,
	chatPayload model.ChatHistoryPersistPayload,
) error {
	// Feature flag: skip silently when memory writes are disabled.
	if !isMemoryWriteEnabled() {
		return nil
	}
	if outboxRepo == nil {
		return errors.New("outbox repository is nil")
	}
	// The payload builder decides whether this chat message warrants extraction.
	memoryPayload, shouldEnqueue := buildMemoryExtractPayloadFromChat(chatPayload)
	if !shouldEnqueue {
		return nil
	}
	payloadJSON, err := json.Marshal(memoryPayload)
	if err != nil {
		return err
	}
	// Default retry budget when the caller passes a non-positive value.
	if maxRetry <= 0 {
		maxRetry = 20
	}
	outboxPayload := outboxinfra.OutboxEventPayload{
		EventType:    EventTypeMemoryExtractRequested,
		EventVersion: outboxinfra.DefaultEventVersion,
		AggregateID:  strings.TrimSpace(chatPayload.ConversationID),
		Payload:      payloadJSON,
	}
	// 1. Only the eventType and message key are passed; service ownership, the
	//    outbox table and the Kafka topic are all resolved by the repository's
	//    routing layer.
	// 2. The chat-persist pipeline therefore never learns the memory service's
	//    physical topic, avoiding divergent double-writes when services split.
	_, err = outboxRepo.CreateMessage(
		ctx,
		EventTypeMemoryExtractRequested,
		strings.TrimSpace(chatPayload.ConversationID),
		outboxPayload,
		maxRetry,
	)
	return err
}
// PublishMemoryExtractFromGraph publishes a memory-extract event directly
// after a graph run completes.
//
// Design intent:
//  1. bypasses the chat-persist path; the agent service calls it on demand
//     once the graph finishes;
//  2. handles source-text truncation, idempotency-key generation and the
//     memory feature switch internally;
//  3. a publish failure is log-only and must not block the main path.
func PublishMemoryExtractFromGraph(
	ctx context.Context,
	publisher outboxinfra.EventPublisher,
	userID int,
	conversationID string,
	sourceText string,
) error {
	if !isMemoryWriteEnabled() {
		return nil
	}
	if publisher == nil {
		return errors.New("event publisher is nil")
	}
	text := strings.TrimSpace(sourceText)
	convID := strings.TrimSpace(conversationID)
	if text == "" || userID <= 0 || convID == "" {
		// Nothing worth extracting; skip silently.
		return nil
	}
	clipped := truncateByRune(text, maxMemorySourceTextLength)
	event := model.MemoryExtractRequestedPayload{
		UserID:         userID,
		ConversationID: convID,
		SourceRole:     "user",
		SourceText:     clipped,
		OccurredAt:     time.Now(),
		IdempotencyKey: buildMemoryExtractIdempotencyKey(userID, convID, clipped),
	}
	return publisher.Publish(ctx, outboxinfra.PublishRequest{
		EventType:    EventTypeMemoryExtractRequested,
		EventVersion: outboxinfra.DefaultEventVersion,
		MessageKey:   event.ConversationID,
		AggregateID:  event.ConversationID,
		Payload:      event,
	})
}
// buildMemoryExtractPayloadFromChat derives a memory-extract event from a
// persisted chat message. The boolean reports whether the message
// qualifies for extraction (non-empty user turns only).
func buildMemoryExtractPayloadFromChat(chatPayload model.ChatHistoryPersistPayload) (model.MemoryExtractRequestedPayload, bool) {
	normalizedRole := strings.ToLower(strings.TrimSpace(chatPayload.Role))
	text := strings.TrimSpace(chatPayload.Message)
	// Only non-empty user turns feed the memory pipeline.
	if normalizedRole != "user" || text == "" {
		return model.MemoryExtractRequestedPayload{}, false
	}
	clipped := truncateByRune(text, maxMemorySourceTextLength)
	event := model.MemoryExtractRequestedPayload{
		UserID:         chatPayload.UserID,
		ConversationID: strings.TrimSpace(chatPayload.ConversationID),
		// Day1 keeps assistant_id/run_id empty; they are filled later from
		// the main chain's context.
		AssistantID:     "",
		RunID:           "",
		SourceMessageID: 0,
		SourceRole:      normalizedRole,
		SourceText:      clipped,
		OccurredAt:      time.Now(),
		TraceID:         "",
		IdempotencyKey:  buildMemoryExtractIdempotencyKey(chatPayload.UserID, chatPayload.ConversationID, clipped),
	}
	return event, true
}
// validateMemoryExtractPayload rejects events missing any field required
// to enqueue an extraction job.
func validateMemoryExtractPayload(payload model.MemoryExtractRequestedPayload) error {
	if payload.UserID <= 0 {
		return errors.New("user_id is invalid")
	}
	// Required string fields, checked after trimming whitespace.
	checks := []struct {
		value   string
		message string
	}{
		{payload.ConversationID, "conversation_id is empty"},
		{payload.SourceRole, "source_role is empty"},
		{payload.SourceText, "source_text is empty"},
		{payload.IdempotencyKey, "idempotency_key is empty"},
	}
	for _, c := range checks {
		if strings.TrimSpace(c.value) == "" {
			return errors.New(c.message)
		}
	}
	return nil
}
// buildMemoryExtractIdempotencyKey derives a stable deduplication key from
// the user, conversation and (already truncated) source text. The key
// carries the user id in clear plus the first 8 bytes of a SHA-256 digest
// over all three trimmed parts, so identical requests collapse while
// different texts stay distinct.
func buildMemoryExtractIdempotencyKey(userID int, conversationID, sourceText string) string {
	digest := sha256.Sum256([]byte(
		fmt.Sprintf("%d|%s|%s", userID, strings.TrimSpace(conversationID), strings.TrimSpace(sourceText)),
	))
	var b strings.Builder
	b.WriteString("memory_extract_")
	b.WriteString(strconv.Itoa(userID))
	b.WriteByte('_')
	b.WriteString(hex.EncodeToString(digest[:8]))
	return b.String()
}
// truncateByRune clips raw to at most max runes (not bytes), so multi-byte
// UTF-8 characters are never split. A non-positive max yields "".
func truncateByRune(raw string, max int) string {
	if max <= 0 {
		return ""
	}
	// Range over the string to count runes without allocating a []rune.
	seen := 0
	for byteIdx := range raw {
		if seen == max {
			return raw[:byteIdx]
		}
		seen++
	}
	// Fewer than or exactly max runes: return the input unchanged.
	return raw
}
// isMemoryWriteEnabled reports whether the memory write path is on.
// When "memory.enabled" is absent from configuration the switch defaults
// to enabled, so only an explicit false turns the pipeline off.
func isMemoryWriteEnabled() bool {
	if viper.IsSet("memory.enabled") {
		return viper.GetBool("memory.enabled")
	}
	return true
}

Some files were not shown because too many files have changed in this diff Show More