Version: 0.9.76.dev.260505

后端:
1.阶段 6 agent / memory 服务化收口
- 新增 cmd/agent 独立进程入口,承载 agent zrpc server、agent outbox relay / consumer 和运行时依赖初始化
- 补齐 services/agent/rpc 的 Chat stream 与 conversation meta/list/timeline、schedule-preview、context-stats、schedule-state unary RPC
- 新增 gateway/client/agent 与 shared/contracts/agent,将 /api/v1/agent chat 和非 chat 门面切到 agent zrpc
- 收缩 gateway 本地 AgentService 装配,双 RPC 开关开启时不再初始化本地 agent 编排、LLM、RAG 和 memory reader fallback
- 将 backend/memory 物理迁入 services/memory,私有实现收入 internal,保留 module/model/observe 作为 memory 服务门面
- 调整 memory outbox、memory reader 和 agent 记忆渲染链路的 import 与服务边界,cmd/memory 独占 memory worker / consumer
- 关闭 gateway 侧 agent outbox worker 所有权,agent relay / consumer 由 cmd/agent 独占,gateway 仅保留 HTTP/SSE 门面与迁移期开关回退
- 更新阶段 6 文档,记录 agent / memory 当前切流点、smoke 结果,以及 backend/client 与 gateway/shared 的目录收口口径
This commit is contained in:
Losita
2026-05-05 19:31:39 +08:00
parent d7184b776b
commit 2a96f4c6f9
72 changed files with 2775 additions and 291 deletions

View File

@@ -0,0 +1,149 @@
package service
import (
"strings"
"github.com/LoveLosita/smartflow/backend/model"
memoryutils "github.com/LoveLosita/smartflow/backend/services/memory/internal/utils"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
)
// toItemDTO maps a persistence-layer MemoryItem onto the service-facing
// ItemDTO. Pointer-typed string columns are flattened via strValue, and the
// content hash is backfilled through fallbackContentHash when the stored hash
// is blank.
func toItemDTO(item model.MemoryItem) memorymodel.ItemDTO {
	var dto memorymodel.ItemDTO
	dto.ID = item.ID
	dto.UserID = item.UserID
	dto.ConversationID = strValue(item.ConversationID)
	dto.AssistantID = strValue(item.AssistantID)
	dto.RunID = strValue(item.RunID)
	dto.MemoryType = item.MemoryType
	dto.Title = item.Title
	dto.Content = item.Content
	dto.ContentHash = fallbackContentHash(item.MemoryType, item.Content, strValue(item.ContentHash))
	dto.Confidence = item.Confidence
	dto.Importance = item.Importance
	dto.SensitivityLevel = item.SensitivityLevel
	dto.IsExplicit = item.IsExplicit
	dto.Status = item.Status
	dto.TTLAt = item.TTLAt
	dto.CreatedAt = item.CreatedAt
	dto.UpdatedAt = item.UpdatedAt
	return dto
}
// toItemDTOs converts a batch of MemoryItem rows into DTOs, preserving order.
// An empty input yields nil so JSON encoders and callers see "no items".
func toItemDTOs(items []model.MemoryItem) []memorymodel.ItemDTO {
	if len(items) == 0 {
		return nil
	}
	dtos := make([]memorymodel.ItemDTO, len(items))
	for i := range items {
		dtos[i] = toItemDTO(items[i])
	}
	return dtos
}
// toUserSettingDTO maps a stored user-setting row onto its service DTO.
func toUserSettingDTO(setting model.MemoryUserSetting) memorymodel.UserSettingDTO {
	var dto memorymodel.UserSettingDTO
	dto.UserID = setting.UserID
	dto.MemoryEnabled = setting.MemoryEnabled
	dto.ImplicitMemoryEnabled = setting.ImplicitMemoryEnabled
	dto.SensitiveMemoryEnabled = setting.SensitiveMemoryEnabled
	dto.UpdatedAt = setting.UpdatedAt
	return dto
}
// normalizeMemoryTypes canonicalizes a list of memory-type strings via
// memorymodel.NormalizeMemoryType, dropping unrecognized entries and
// duplicates while keeping first-seen order. Empty input yields nil.
func normalizeMemoryTypes(raw []string) []string {
	if len(raw) == 0 {
		return nil
	}
	dedup := make(map[string]struct{}, len(raw))
	normalized := make([]string, 0, len(raw))
	for _, candidate := range raw {
		t := memorymodel.NormalizeMemoryType(candidate)
		if t == "" {
			continue
		}
		if _, dup := dedup[t]; dup {
			continue
		}
		dedup[t] = struct{}{}
		normalized = append(normalized, t)
	}
	return normalized
}
// normalizeManageStatuses lower-cases and trims the requested statuses,
// keeping only active/archived/deleted without duplicates (first-seen order).
// When nothing valid is requested it falls back to the management default of
// active + archived, so deleted rows only show up when asked for explicitly.
func normalizeManageStatuses(raw []string) []string {
	defaults := []string{
		model.MemoryItemStatusActive,
		model.MemoryItemStatusArchived,
	}
	if len(raw) == 0 {
		return defaults
	}
	seen := make(map[string]struct{}, len(raw))
	statuses := make([]string, 0, len(raw))
	for _, candidate := range raw {
		status := strings.ToLower(strings.TrimSpace(candidate))
		switch status {
		case model.MemoryItemStatusActive, model.MemoryItemStatusArchived, model.MemoryItemStatusDeleted:
			// recognized status, keep going
		default:
			continue
		}
		if _, dup := seen[status]; dup {
			continue
		}
		seen[status] = struct{}{}
		statuses = append(statuses, status)
	}
	if len(statuses) == 0 {
		return defaults
	}
	return statuses
}
// normalizeLimit resolves a caller-supplied page size: non-positive values
// fall back to defaultValue, and a positive maxValue caps the result.
func normalizeLimit(limit, defaultValue, maxValue int) int {
	result := limit
	if result <= 0 {
		result = defaultValue
	}
	if maxValue > 0 && result > maxValue {
		result = maxValue
	}
	return result
}
// strValue dereferences an optional string column, returning the trimmed
// value or "" for nil.
func strValue(v *string) string {
	if v != nil {
		return strings.TrimSpace(*v)
	}
	return ""
}
// fallbackContentHash returns the content hash an item can use for
// service-level deduplication.
//
// Notes:
//  1. the content_hash already persisted in the DB is reused first, so a
//     single row never mixes hash algorithms;
//  2. when legacy rows or RAG metadata carry no hash, it is recomputed from
//     "type + normalized content";
//  3. an invalid type or empty body yields "", letting upstream fall back to
//     plain-text dedup.
func fallbackContentHash(memoryType, content, currentHash string) string {
	if existing := strings.TrimSpace(currentHash); existing != "" {
		return existing
	}
	normalizedType := memorymodel.NormalizeMemoryType(memoryType)
	normalizedContent := normalizeContentForHash(content)
	if normalizedType == "" || normalizedContent == "" {
		return ""
	}
	return memoryutils.HashContent(normalizedType, normalizedContent)
}
// normalizeContentForHash canonicalizes memory text for hashing: whitespace
// runs collapse to single spaces and the result is lower-cased. Blank input
// yields "".
func normalizeContentForHash(content string) string {
	words := strings.Fields(content)
	if len(words) == 0 {
		return ""
	}
	return strings.ToLower(strings.Join(words, " "))
}

View File

@@ -0,0 +1,91 @@
package service
import (
"time"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
"github.com/spf13/viper"
)
// LoadConfigFromViper reads the memory-module configuration and applies
// default-value fallbacks.
//
// Default policy:
//  1. temperature/top_p use low-randomness values to improve reproducibility;
//  2. Day1 only reserves the parameter slots; not every knob must take effect
//     immediately;
//  3. polling and retry parameters get conservative defaults to avoid
//     pressuring the main request path.
func LoadConfigFromViper() memorymodel.Config {
	cfg := memorymodel.Config{
		Enabled:             viper.GetBool("memory.enabled"),
		RAGEnabled:          viper.GetBool("memory.rag.enabled"),
		ReadMode:            memorymodel.NormalizeReadMode(viper.GetString("memory.read.mode")),
		InjectRenderMode:    memorymodel.NormalizeInjectRenderMode(viper.GetString("memory.inject.renderMode")),
		ExtractPrompt:       viper.GetString("memory.prompt.extract"),
		DecisionPrompt:      viper.GetString("memory.prompt.decision"),
		Threshold:           viper.GetFloat64("memory.threshold"),
		EnableReranker:      viper.GetBool("memory.enableReranker"),
		LLMTemperature:      viper.GetFloat64("memory.llm.temperature"),
		LLMTopP:             viper.GetFloat64("memory.llm.topP"),
		JobMaxRetry:         viper.GetInt("memory.job.maxRetry"),
		WorkerPollEvery:     viper.GetDuration("memory.worker.pollEvery"),
		WorkerClaimBatch:    viper.GetInt("memory.worker.claimBatch"),
		ReadConstraintLimit: viper.GetInt("memory.read.constraintLimit"),
		ReadPreferenceLimit: viper.GetInt("memory.read.preferenceLimit"),
		ReadFactLimit:       viper.GetInt("memory.read.factLimit"),
		// Decision-layer settings: disabled by default and only effective
		// once the gradual rollout turns them on.
		DecisionEnabled:           viper.GetBool("memory.decision.enabled"),
		DecisionCandidateTopK:     viper.GetInt("memory.decision.candidateTopK"),
		DecisionCandidateMinScore: viper.GetFloat64("memory.decision.candidateMinScore"),
		DecisionFallbackMode:      viper.GetString("memory.decision.fallbackMode"),
		WriteMode:                 viper.GetString("memory.write.mode"),
		WriteMinConfidence:        viper.GetFloat64("memory.write.minConfidence"),
		LLMThinking:               viper.GetBool("agent.thinking.memory"),
	}
	if cfg.Threshold <= 0 {
		cfg.Threshold = 0.55
	}
	if cfg.LLMTemperature <= 0 {
		cfg.LLMTemperature = 0.1
	}
	if cfg.LLMTopP <= 0 {
		cfg.LLMTopP = 0.2
	}
	if cfg.JobMaxRetry <= 0 {
		cfg.JobMaxRetry = 6
	}
	if cfg.WorkerPollEvery <= 0 {
		cfg.WorkerPollEvery = 2 * time.Second
	}
	if cfg.WorkerClaimBatch <= 0 {
		cfg.WorkerClaimBatch = 1
	}
	// Normalize through the Effective* helpers so the stored values match
	// what the rest of the module will actually observe at read time.
	cfg.ReadConstraintLimit = cfg.EffectiveReadConstraintLimit()
	cfg.ReadPreferenceLimit = cfg.EffectiveReadPreferenceLimit()
	cfg.ReadFactLimit = cfg.EffectiveReadFactLimit()
	cfg.ReadMode = cfg.EffectiveReadMode()
	cfg.InjectRenderMode = cfg.EffectiveInjectRenderMode()
	// Decision-layer default fallbacks.
	// Notes:
	//  1. TopK and MinScore are Milvus recall parameters; conservative
	//     defaults avoid recalling too many noisy candidates;
	//  2. FallbackMode defaults to the legacy "add" path so no data is lost
	//     when the decision flow misbehaves;
	//  3. WriteMode is implicitly governed by DecisionEnabled; no forced
	//     coupling is applied here.
	if cfg.DecisionCandidateTopK <= 0 {
		cfg.DecisionCandidateTopK = 5
	}
	if cfg.DecisionCandidateMinScore <= 0 {
		cfg.DecisionCandidateMinScore = 0.6
	}
	if cfg.DecisionFallbackMode == "" {
		cfg.DecisionFallbackMode = "legacy_add"
	}
	if cfg.WriteMode == "" {
		cfg.WriteMode = "legacy"
	}
	if cfg.WriteMinConfidence <= 0 {
		cfg.WriteMinConfidence = 0.5
	}
	return cfg
}

View File

@@ -0,0 +1,33 @@
package service
import (
"context"
"errors"
memoryrepo "github.com/LoveLosita/smartflow/backend/services/memory/internal/repo"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
)
// EnqueueService is the Day1 "job enqueue facade".
//
// Responsibility boundary:
//  1. it only enqueues extraction requests into memory_jobs;
//  2. it neither runs the extraction nor writes memory_items.
type EnqueueService struct {
	// jobRepo persists pending extract jobs; nil disables the service.
	jobRepo *memoryrepo.JobRepo
}
// NewEnqueueService wires an EnqueueService around the given job repository.
func NewEnqueueService(jobRepo *memoryrepo.JobRepo) *EnqueueService {
	svc := new(EnqueueService)
	svc.jobRepo = jobRepo
	return svc
}
// EnqueueExtractJob records one pending extract job keyed by sourceEventID.
// It only enqueues; execution happens elsewhere.
func (s *EnqueueService) EnqueueExtractJob(
	ctx context.Context,
	payload memorymodel.ExtractJobPayload,
	sourceEventID string,
) error {
	if s == nil {
		return errors.New("memory enqueue service is nil")
	}
	if s.jobRepo == nil {
		return errors.New("memory enqueue service is nil")
	}
	return s.jobRepo.CreatePendingExtractJob(ctx, payload, sourceEventID)
}

View File

@@ -0,0 +1,659 @@
package service
import (
"context"
"errors"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/respond"
memoryrepo "github.com/LoveLosita/smartflow/backend/services/memory/internal/repo"
memoryutils "github.com/LoveLosita/smartflow/backend/services/memory/internal/utils"
memoryvectorsync "github.com/LoveLosita/smartflow/backend/services/memory/internal/vectorsync"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
memoryobserve "github.com/LoveLosita/smartflow/backend/services/memory/observe"
"gorm.io/gorm"
)
const (
	// defaultManageListLimit applies when ListItems receives no explicit limit.
	defaultManageListLimit = 20
	// maxManageListLimit is the hard cap for one management list page.
	maxManageListLimit = 100
	// defaultManualConfidence / defaultManualImportance are the scores given
	// to manually created items when the request omits them.
	defaultManualConfidence = 0.95
	defaultManualImportance = 0.90
)
// ManageService owns the management-plane capabilities inside the memory
// module.
//
// Responsibility boundary:
//  1. handles maintenance actions such as listing/deleting memories and
//     reading/updating user switches;
//  2. records user-initiated management actions into memory_audit_logs;
//  3. it is not responsible for prompt injection, vector recall, or running
//     background extraction jobs.
type ManageService struct {
	db           *gorm.DB
	itemRepo     *memoryrepo.ItemRepo
	auditRepo    *memoryrepo.AuditRepo
	settingsRepo *memoryrepo.SettingsRepo
	vectorSyncer *memoryvectorsync.Syncer
	observer     memoryobserve.Observer
	metrics      memoryobserve.MetricsRecorder
}
// NewManageService assembles a ManageService. A nil observer or metrics
// recorder is replaced with a no-op implementation so callers never need
// nil checks.
func NewManageService(
	db *gorm.DB,
	itemRepo *memoryrepo.ItemRepo,
	auditRepo *memoryrepo.AuditRepo,
	settingsRepo *memoryrepo.SettingsRepo,
	vectorSyncer *memoryvectorsync.Syncer,
	observer memoryobserve.Observer,
	metrics memoryobserve.MetricsRecorder,
) *ManageService {
	svc := &ManageService{
		db:           db,
		itemRepo:     itemRepo,
		auditRepo:    auditRepo,
		settingsRepo: settingsRepo,
		vectorSyncer: vectorSyncer,
		observer:     observer,
		metrics:      metrics,
	}
	if svc.observer == nil {
		svc.observer = memoryobserve.NewNopObserver()
	}
	if svc.metrics == nil {
		svc.metrics = memoryobserve.NewNopMetrics()
	}
	return svc
}
// ListItems lists the memory items a user can currently manage.
//
// Notes:
//  1. this is the "management view": it does not re-filter by the user's
//     memory switches;
//  2. even when the user has memory disabled, the overview page still needs
//     to show existing items so they can be checked or deleted manually;
//  3. only active/archived are returned by default unless deleted is passed
//     explicitly.
func (s *ManageService) ListItems(ctx context.Context, req memorymodel.ListItemsRequest) ([]memorymodel.ItemDTO, error) {
	if s == nil || s.itemRepo == nil {
		return nil, errors.New("memory manage service is nil")
	}
	if req.UserID <= 0 {
		return nil, nil
	}
	conversationID := strings.TrimSpace(req.ConversationID)
	query := memorymodel.ItemQuery{
		UserID:         req.UserID,
		ConversationID: conversationID,
		Statuses:       normalizeManageStatuses(req.Statuses),
		MemoryTypes:    normalizeMemoryTypes(req.MemoryTypes),
		// Global (conversation-less) items are included only when a specific
		// conversation scope was requested.
		IncludeGlobal: conversationID != "",
		OnlyUnexpired: false,
		Limit:         normalizeLimit(req.Limit, defaultManageListLimit, maxManageListLimit),
	}
	items, err := s.itemRepo.FindByQuery(ctx, query)
	if err != nil {
		return nil, err
	}
	return toItemDTOs(items), nil
}
// GetItem returns the detail of one memory item owned by the requesting user.
func (s *ManageService) GetItem(ctx context.Context, req model.MemoryGetItemRequest) (*memorymodel.ItemDTO, error) {
	if s == nil || s.itemRepo == nil {
		return nil, errors.New("memory manage service is nil")
	}
	switch {
	case req.UserID <= 0:
		return nil, respond.WrongUserID
	case req.MemoryID <= 0:
		return nil, respond.WrongParamType
	}
	item, err := s.itemRepo.GetByIDForUser(ctx, req.UserID, req.MemoryID)
	if err != nil {
		return nil, translateManageError(err)
	}
	dto := toItemDTO(*item)
	return &dto, nil
}
// CreateItem manually adds one user memory, writing the audit record and
// bridging vector synchronization.
func (s *ManageService) CreateItem(ctx context.Context, req model.MemoryCreateItemRequest) (*memorymodel.ItemDTO, error) {
	if s == nil || s.db == nil || s.itemRepo == nil || s.auditRepo == nil {
		return nil, errors.New("memory manage service is not initialized")
	}
	if req.UserID <= 0 {
		return nil, respond.WrongUserID
	}
	fields, err := buildCreateItemFields(req)
	if err != nil {
		s.recordManageAction(ctx, "create", req.UserID, 0, fields.MemoryType, false, err)
		return nil, err
	}
	var createdItem model.MemoryItem
	// The item insert and its audit log commit in one transaction so a
	// created item always has a matching audit entry.
	err = s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		itemRepo := s.itemRepo.WithTx(tx)
		auditRepo := s.auditRepo.WithTx(tx)
		created, createErr := itemRepo.Create(ctx, fields)
		if createErr != nil {
			return createErr
		}
		createdItem = *created
		audit := memoryutils.BuildItemAuditLog(
			createdItem.ID,
			createdItem.UserID,
			memoryutils.AuditOperationCreate,
			memoryutils.NormalizeOperatorType(req.OperatorType),
			normalizeManageReason(req.Reason, "用户手动新增记忆"),
			nil, // no before-image on create
			&createdItem,
		)
		return auditRepo.Create(ctx, audit)
	})
	if err != nil {
		err = translateManageError(err)
		s.recordManageAction(ctx, "create", req.UserID, 0, fields.MemoryType, false, err)
		return nil, err
	}
	// Vector upsert runs outside the transaction as a post-commit bridge.
	s.vectorSyncer.Upsert(ctx, "", []model.MemoryItem{createdItem})
	s.recordManageAction(ctx, "create", req.UserID, createdItem.ID, createdItem.MemoryType, true, nil)
	dto := toItemDTO(createdItem)
	return &dto, nil
}
// UpdateItem manually edits one user memory, writing the audit record and
// re-bridging vector synchronization.
func (s *ManageService) UpdateItem(ctx context.Context, req model.MemoryUpdateItemRequest) (*memorymodel.ItemDTO, error) {
	if s == nil || s.db == nil || s.itemRepo == nil || s.auditRepo == nil {
		return nil, errors.New("memory manage service is not initialized")
	}
	if req.UserID <= 0 {
		return nil, respond.WrongUserID
	}
	if req.MemoryID <= 0 {
		return nil, respond.WrongParamType
	}
	var updatedItem model.MemoryItem
	err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		itemRepo := s.itemRepo.WithTx(tx)
		auditRepo := s.auditRepo.WithTx(tx)
		// Read the current row first so the audit before-image matches the
		// row actually updated.
		current, getErr := itemRepo.GetByIDForUser(ctx, req.UserID, req.MemoryID)
		if getErr != nil {
			return getErr
		}
		fields, afterItem, buildErr := buildUpdateItemFields(req, *current)
		if buildErr != nil {
			return buildErr
		}
		now := time.Now()
		afterItem.UpdatedAt = &now
		// Mark the vector stale; the syncer below re-pushes it post-commit.
		afterItem.VectorStatus = "pending"
		if updateErr := itemRepo.UpdateManagedFieldsByIDAt(ctx, req.UserID, req.MemoryID, fields, now); updateErr != nil {
			return updateErr
		}
		audit := memoryutils.BuildItemAuditLog(
			current.ID,
			current.UserID,
			memoryutils.AuditOperationUpdate,
			memoryutils.NormalizeOperatorType(req.OperatorType),
			normalizeManageReason(req.Reason, "用户手动修改记忆"),
			current,
			&afterItem,
		)
		if auditErr := auditRepo.Create(ctx, audit); auditErr != nil {
			return auditErr
		}
		updatedItem = afterItem
		return nil
	})
	if err != nil {
		err = translateManageError(err)
		s.recordManageAction(ctx, "update", req.UserID, req.MemoryID, resolveUpdateMemoryType(req), false, err)
		return nil, err
	}
	s.vectorSyncer.Upsert(ctx, "", []model.MemoryItem{updatedItem})
	s.recordManageAction(ctx, "update", req.UserID, updatedItem.ID, updatedItem.MemoryType, true, nil)
	dto := toItemDTO(updatedItem)
	return &dto, nil
}
// DeleteItem soft-deletes one memory item and writes the audit log.
//
// Step-by-step:
//  1. read the current snapshot inside the transaction so the audit
//     before-image matches the row actually deleted;
//  2. if the item is already deleted, return idempotently instead of writing
//     another delete audit;
//  3. the audit log is written only after the status update succeeds, so
//     "every delete has an audit"; on failure the whole transaction rolls
//     back.
func (s *ManageService) DeleteItem(ctx context.Context, req model.MemoryDeleteItemRequest) (*memorymodel.ItemDTO, error) {
	if s == nil || s.db == nil || s.itemRepo == nil || s.auditRepo == nil {
		return nil, errors.New("memory manage service is not initialized")
	}
	if req.UserID <= 0 {
		return nil, respond.WrongUserID
	}
	if req.MemoryID <= 0 {
		return nil, respond.WrongParamType
	}
	now := time.Now()
	operatorType := memoryutils.NormalizeOperatorType(req.OperatorType)
	reason := normalizeDeleteReason(req.Reason)
	var deletedItem model.MemoryItem
	err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		itemRepo := s.itemRepo.WithTx(tx)
		auditRepo := s.auditRepo.WithTx(tx)
		current, err := itemRepo.GetByIDForUser(ctx, req.UserID, req.MemoryID)
		if err != nil {
			return err
		}
		// Idempotent path: already deleted means nothing to update or audit.
		if current.Status == model.MemoryItemStatusDeleted {
			deletedItem = *current
			return nil
		}
		before := *current
		after := before
		after.Status = model.MemoryItemStatusDeleted
		after.UpdatedAt = &now
		// Mark the vector stale; the syncer below removes it post-commit.
		after.VectorStatus = "pending"
		if err = itemRepo.SoftDeleteByID(ctx, req.UserID, req.MemoryID); err != nil {
			return err
		}
		audit := memoryutils.BuildItemAuditLog(
			req.MemoryID,
			req.UserID,
			memoryutils.AuditOperationDelete,
			operatorType,
			reason,
			&before,
			&after,
		)
		if err = auditRepo.Create(ctx, audit); err != nil {
			return err
		}
		deletedItem = after
		return nil
	})
	if err != nil {
		err = translateManageError(err)
		s.recordManageAction(ctx, "delete", req.UserID, req.MemoryID, "", false, err)
		return nil, err
	}
	if deletedItem.ID <= 0 {
		return nil, nil
	}
	if deletedItem.Status == model.MemoryItemStatusDeleted {
		s.vectorSyncer.Delete(ctx, "", []int64{deletedItem.ID})
	}
	s.recordManageAction(ctx, "delete", req.UserID, deletedItem.ID, deletedItem.MemoryType, true, nil)
	result := toItemDTO(deletedItem)
	return &result, nil
}
// RestoreItem flips an archived/deleted memory back to active, writing the
// audit record and bridging vector synchronization.
func (s *ManageService) RestoreItem(ctx context.Context, req model.MemoryRestoreItemRequest) (*memorymodel.ItemDTO, error) {
	if s == nil || s.db == nil || s.itemRepo == nil || s.auditRepo == nil {
		return nil, errors.New("memory manage service is not initialized")
	}
	if req.UserID <= 0 {
		return nil, respond.WrongUserID
	}
	if req.MemoryID <= 0 {
		return nil, respond.WrongParamType
	}
	var restoredItem model.MemoryItem
	err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		itemRepo := s.itemRepo.WithTx(tx)
		auditRepo := s.auditRepo.WithTx(tx)
		current, getErr := itemRepo.GetByIDForUser(ctx, req.UserID, req.MemoryID)
		if getErr != nil {
			return getErr
		}
		// Idempotent path: already active means nothing to restore or audit.
		if current.Status == model.MemoryItemStatusActive {
			restoredItem = *current
			return nil
		}
		now := time.Now()
		before := *current
		after := before
		after.Status = model.MemoryItemStatusActive
		after.UpdatedAt = &now
		// Mark the vector stale; the syncer below re-pushes it post-commit.
		after.VectorStatus = "pending"
		if restoreErr := itemRepo.RestoreByIDAt(ctx, req.UserID, req.MemoryID, now); restoreErr != nil {
			return restoreErr
		}
		audit := memoryutils.BuildItemAuditLog(
			before.ID,
			before.UserID,
			memoryutils.AuditOperationRestore,
			memoryutils.NormalizeOperatorType(req.OperatorType),
			normalizeManageReason(req.Reason, "用户恢复记忆"),
			&before,
			&after,
		)
		if auditErr := auditRepo.Create(ctx, audit); auditErr != nil {
			return auditErr
		}
		restoredItem = after
		return nil
	})
	if err != nil {
		err = translateManageError(err)
		s.recordManageAction(ctx, "restore", req.UserID, req.MemoryID, "", false, err)
		return nil, err
	}
	s.vectorSyncer.Upsert(ctx, "", []model.MemoryItem{restoredItem})
	s.recordManageAction(ctx, "restore", req.UserID, restoredItem.ID, restoredItem.MemoryType, true, nil)
	dto := toItemDTO(restoredItem)
	return &dto, nil
}
// GetUserSetting returns the memory switches currently in effect for a user.
//
// Return semantics:
//  1. when no row exists yet, the system default switches are returned
//     instead of nil;
//  2. callers therefore always receive a complete structure without an extra
//     nil-check/default layer;
//  3. this only reads settings; it performs no modification.
func (s *ManageService) GetUserSetting(ctx context.Context, userID int) (memorymodel.UserSettingDTO, error) {
	if s == nil || s.settingsRepo == nil {
		return memorymodel.UserSettingDTO{}, errors.New("memory manage service is nil")
	}
	if userID <= 0 {
		return memorymodel.UserSettingDTO{}, nil
	}
	setting, err := s.settingsRepo.GetByUserID(ctx, userID)
	if err != nil {
		return memorymodel.UserSettingDTO{}, err
	}
	return toUserSettingDTO(memoryutils.EffectiveUserSetting(setting, userID)), nil
}
// UpsertUserSetting writes the user's memory switches.
//
// Notes:
//  1. at this stage all three switches are overwritten directly; there is no
//     patch semantics;
//  2. that lets the frontend submit the whole settings form in one request
//     with a stable API contract;
//  3. if setting-change auditing becomes necessary later, a dedicated setting
//     audit will be added instead of reusing the item audit.
func (s *ManageService) UpsertUserSetting(ctx context.Context, req memorymodel.UpdateUserSettingRequest) (memorymodel.UserSettingDTO, error) {
	if s == nil || s.settingsRepo == nil {
		return memorymodel.UserSettingDTO{}, errors.New("memory manage service is nil")
	}
	if req.UserID <= 0 {
		return memorymodel.UserSettingDTO{}, nil
	}
	now := time.Now()
	setting := model.MemoryUserSetting{
		UserID:                 req.UserID,
		MemoryEnabled:          req.MemoryEnabled,
		ImplicitMemoryEnabled:  req.ImplicitMemoryEnabled,
		SensitiveMemoryEnabled: req.SensitiveMemoryEnabled,
		UpdatedAt:              &now,
	}
	if err := s.settingsRepo.Upsert(ctx, setting); err != nil {
		return memorymodel.UserSettingDTO{}, err
	}
	return toUserSettingDTO(setting), nil
}
func normalizeDeleteReason(reason string) string {
reason = strings.TrimSpace(reason)
if reason == "" {
return "用户删除记忆"
}
return reason
}
// normalizeManageReason trims the caller-supplied audit reason and substitutes
// fallback when nothing meaningful remains.
func normalizeManageReason(reason string, fallback string) string {
	if trimmed := strings.TrimSpace(reason); trimmed != "" {
		return trimmed
	}
	return fallback
}
// translateManageError maps storage-layer errors onto API-facing ones:
// a gorm record-not-found becomes respond.MemoryItemNotFound, everything
// else passes through unchanged.
func translateManageError(err error) error {
	if err == nil {
		return nil
	}
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return respond.MemoryItemNotFound
	}
	return err
}
// buildCreateItemFields validates a manual-create request and maps it onto
// the repo-level CreateItemFields, applying the manual-entry score defaults.
// It returns a validation error when the memory type or content is invalid.
func buildCreateItemFields(req model.MemoryCreateItemRequest) (memorymodel.CreateItemFields, error) {
	memoryType, err := normalizeManagedMemoryType(req.MemoryType)
	if err != nil {
		return memorymodel.CreateItemFields{}, err
	}
	content, normalizedContent, err := normalizeManagedContent(req.Content)
	if err != nil {
		return memorymodel.CreateItemFields{}, err
	}
	// A blank title is derived from the content (see normalizeManagedTitle).
	title := normalizeManagedTitle(req.Title, content)
	return memorymodel.CreateItemFields{
		UserID:            req.UserID,
		ConversationID:    strings.TrimSpace(req.ConversationID),
		AssistantID:       strings.TrimSpace(req.AssistantID),
		RunID:             strings.TrimSpace(req.RunID),
		MemoryType:        memoryType,
		Title:             title,
		Content:           content,
		NormalizedContent: normalizedContent,
		ContentHash:       memoryutils.HashContent(memoryType, normalizedContent),
		Confidence:        normalizeManageScore(req.Confidence, defaultManualConfidence),
		Importance:        normalizeManageScore(req.Importance, defaultManualImportance),
		SensitivityLevel:  normalizeManageSensitivity(req.SensitivityLevel, 0),
		IsExplicit:        normalizeManageBool(req.IsExplicit, true),
		Status:            model.MemoryItemStatusActive,
		TTLAt:             req.TTLAt,
		// Fresh items start with a stale vector; the syncer pushes them after
		// the transaction commits.
		VectorStatus: "pending",
	}, nil
}
// buildUpdateItemFields merges a partial update request onto the current item.
//
// It returns the repo-level update fields plus the post-update snapshot used
// as the audit after-image. Nil pointer fields in the request mean "keep the
// current value".
func buildUpdateItemFields(
	req model.MemoryUpdateItemRequest,
	current model.MemoryItem,
) (memorymodel.UpdateItemFields, model.MemoryItem, error) {
	memoryType := current.MemoryType
	if req.MemoryType != nil {
		normalizedType, err := normalizeManagedMemoryType(*req.MemoryType)
		if err != nil {
			return memorymodel.UpdateItemFields{}, model.MemoryItem{}, err
		}
		memoryType = normalizedType
	}
	content := current.Content
	if req.Content != nil {
		normalizedContentValue, _, err := normalizeManagedContent(*req.Content)
		if err != nil {
			return memorymodel.UpdateItemFields{}, model.MemoryItem{}, err
		}
		content = normalizedContentValue
	}
	// Re-normalize the (possibly unchanged) effective content; an item whose
	// content normalizes to empty is rejected.
	normalizedContent := normalizeContentForHash(content)
	if normalizedContent == "" {
		return memorymodel.UpdateItemFields{}, model.MemoryItem{}, respond.MemoryInvalidContent
	}
	title := current.Title
	if req.Title != nil {
		title = normalizeManagedTitle(*req.Title, content)
	}
	ttlAt := current.TTLAt
	// ClearTTL takes precedence over a supplied TTLAt value.
	if req.ClearTTL {
		ttlAt = nil
	} else if req.TTLAt != nil {
		ttlAt = req.TTLAt
	}
	fields := memorymodel.UpdateItemFields{
		MemoryType:        memoryType,
		Title:             title,
		Content:           content,
		NormalizedContent: normalizedContent,
		ContentHash:       memoryutils.HashContent(memoryType, normalizedContent),
		Confidence:        normalizeManageScore(req.Confidence, current.Confidence),
		Importance:        normalizeManageScore(req.Importance, current.Importance),
		SensitivityLevel:  normalizeManageSensitivity(req.SensitivityLevel, current.SensitivityLevel),
		IsExplicit:        normalizeManageBool(req.IsExplicit, current.IsExplicit),
		TTLAt:             ttlAt,
	}
	// Build the audit after-image by overlaying the resolved fields onto the
	// current snapshot.
	after := current
	after.MemoryType = fields.MemoryType
	after.Title = fields.Title
	after.Content = fields.Content
	after.NormalizedContent = strPtr(fields.NormalizedContent)
	after.ContentHash = strPtr(fields.ContentHash)
	after.Confidence = fields.Confidence
	after.Importance = fields.Importance
	after.SensitivityLevel = fields.SensitivityLevel
	after.IsExplicit = fields.IsExplicit
	after.TTLAt = fields.TTLAt
	return fields, after, nil
}
// normalizeManagedMemoryType canonicalizes a caller-supplied memory type and
// rejects anything the model layer does not recognize.
func normalizeManagedMemoryType(raw string) (string, error) {
	if normalized := memorymodel.NormalizeMemoryType(raw); normalized != "" {
		return normalized, nil
	}
	return "", respond.MemoryInvalidType
}
func normalizeManagedContent(raw string) (string, string, error) {
content := strings.TrimSpace(raw)
if content == "" {
return "", "", respond.MemoryInvalidContent
}
normalized := normalizeContentForHash(content)
if normalized == "" {
return "", "", respond.MemoryInvalidContent
}
return content, normalized, nil
}
// normalizeManagedTitle picks a display title: the trimmed caller title when
// present, otherwise the first 24 runes of the content, otherwise the
// placeholder title for empty content.
func normalizeManagedTitle(raw string, content string) string {
	if title := strings.TrimSpace(raw); title != "" {
		return title
	}
	trimmed := strings.TrimSpace(content)
	if trimmed == "" {
		return "未命名记忆"
	}
	const maxTitleRunes = 24
	runes := []rune(trimmed)
	if len(runes) <= maxTitleRunes {
		return trimmed
	}
	return string(runes[:maxTitleRunes])
}
// normalizeManageScore resolves an optional score pointer, substituting
// defaultValue for nil, and clamps the result into [0, 1].
func normalizeManageScore(value *float64, defaultValue float64) float64 {
	score := defaultValue
	if value != nil {
		score = *value
	}
	return clamp01(score)
}
// normalizeManageSensitivity resolves an optional sensitivity level; nil or
// negative input falls back to defaultValue.
func normalizeManageSensitivity(value *int, defaultValue int) int {
	if value == nil || *value < 0 {
		return defaultValue
	}
	return *value
}
// normalizeManageBool resolves an optional boolean flag, substituting
// defaultValue when the pointer is nil.
func normalizeManageBool(value *bool, defaultValue bool) bool {
	if value != nil {
		return *value
	}
	return defaultValue
}
// resolveUpdateMemoryType extracts the trimmed memory type from an update
// request for telemetry; "" when the request does not touch the type.
func resolveUpdateMemoryType(req model.MemoryUpdateItemRequest) string {
	if req.MemoryType != nil {
		return strings.TrimSpace(*req.MemoryType)
	}
	return ""
}
// strPtr converts a string into an optional column value: nil when the
// trimmed input is empty, otherwise a pointer to the trimmed copy.
func strPtr(value string) *string {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return nil
	}
	return &trimmed
}
// recordManageAction emits one metrics counter and one observer event for a
// management action ("create" / "update" / "delete" / "restore").
// A non-nil err or success=false is reported as status "error" at warn level.
func (s *ManageService) recordManageAction(
	ctx context.Context,
	operation string,
	userID int,
	memoryID int64,
	memoryType string,
	success bool,
	err error,
) {
	if s == nil {
		return
	}
	status := "success"
	level := memoryobserve.LevelInfo
	if !success || err != nil {
		status = "error"
		level = memoryobserve.LevelWarn
	}
	s.metrics.AddCounter(memoryobserve.MetricManageTotal, 1, map[string]string{
		"operation": strings.TrimSpace(operation),
		"status":    status,
	})
	s.observer.Observe(ctx, memoryobserve.Event{
		Level:     level,
		Component: memoryobserve.ComponentManage,
		Operation: memoryobserve.OperationManage,
		Fields: map[string]any{
			"user_id":     userID,
			"memory_id":   memoryID,
			"action":      strings.TrimSpace(operation),
			"memory_type": strings.TrimSpace(memoryType),
			"success":     success && err == nil,
			"error":       err,
			"error_code":  memoryobserve.ClassifyError(err),
		},
	})
}

View File

@@ -0,0 +1,83 @@
package service
import (
"time"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
ragservice "github.com/LoveLosita/smartflow/backend/services/rag"
)
// buildReadScopedItemQuery builds the MySQL query conditions shared by the
// read side.
//
// Responsibility boundary:
//  1. it only maps a RetrieveRequest into "read-scope" query parameters;
//  2. it does not query the database, sort, trim, or inject;
//  3. conversation_id deliberately does not participate in filtering here; it
//     is kept only in item metadata for audit and tracing.
//
// Step-by-step:
//  1. the read side always hard-isolates by user_id to prevent cross-user
//     memory leakage.
//  2. assistant_id / run_id may still filter: they express assistant-instance
//     and run boundaries, not "whether to recall across conversations".
//  3. conversation_id is explicitly cleared because the chat context window
//     already covers same-conversation information; the read side's value is
//     mainly cross-conversation supplementation.
func buildReadScopedItemQuery(
	req memorymodel.RetrieveRequest,
	now time.Time,
	statuses []string,
	limit int,
) memorymodel.ItemQuery {
	return memorymodel.ItemQuery{
		UserID:         req.UserID,
		ConversationID: "",
		AssistantID:    req.AssistantID,
		RunID:          req.RunID,
		Statuses:       statuses,
		MemoryTypes:    normalizeRetrieveMemoryTypes(req.MemoryTypes),
		IncludeGlobal:  true,
		OnlyUnexpired:  true,
		Limit:          limit,
		Now:            now,
	}
}
// buildReadScopedRAGRequest builds the RAG retrieval request shared by the
// read side.
//
// Responsibility boundary:
//  1. it only produces the memory retrieval request; it does not run the
//     vector search;
//  2. no re-ranking, fallback, or dedup beyond the threshold;
//  3. conversation_id likewise stays only in document metadata and is no
//     longer a hard filter on the chat read side.
//
// Step-by-step:
//  1. user_id is the only mandatory hard filter, keeping recall within the
//     current user.
//  2. conversation_id is explicitly cleared so older-conversation memories
//     are not pre-filtered out by metadata before similarity scoring.
//  3. assistant_id / run_id pass through for finer-grained isolation should a
//     multi-assistant scenario appear later.
func buildReadScopedRAGRequest(
	req memorymodel.RetrieveRequest,
	topK int,
	threshold float64,
) ragservice.MemoryRetrieveRequest {
	return ragservice.MemoryRetrieveRequest{
		Query:          req.Query,
		TopK:           topK,
		Threshold:      threshold,
		Action:         "search",
		UserID:         req.UserID,
		ConversationID: "",
		AssistantID:    req.AssistantID,
		RunID:          req.RunID,
		MemoryTypes:    normalizeRetrieveMemoryTypes(req.MemoryTypes),
	}
}
// shouldReturnSemanticRAGResult decides whether the RAG result may be used
// directly.
//
// Responsibility boundary:
//  1. it expresses exactly one rule: "is RAG sufficient to short-circuit the
//     MySQL fallback";
//  2. it performs no retrieval and writes no logs;
//  3. false does not mean error — only that the caller should continue with
//     the database fallback.
//
// Rules:
//  1. on a RAG error, never short-circuit: the MySQL fallback must run.
//  2. on zero hits, likewise: an empty success must not be treated as final.
//  3. only "no error and non-empty result" permits returning the RAG result.
func shouldReturnSemanticRAGResult(items []memorymodel.ItemDTO, err error) bool {
	if err != nil {
		return false
	}
	return len(items) > 0
}

View File

@@ -0,0 +1,438 @@
package service
import (
"context"
"fmt"
"sort"
"strconv"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
memoryrepo "github.com/LoveLosita/smartflow/backend/services/memory/internal/repo"
memoryutils "github.com/LoveLosita/smartflow/backend/services/memory/internal/utils"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
memoryobserve "github.com/LoveLosita/smartflow/backend/services/memory/observe"
ragservice "github.com/LoveLosita/smartflow/backend/services/rag"
)
const (
	// defaultRetrieveLimit applies when the caller passes no limit.
	defaultRetrieveLimit = 5
	// maxRetrieveLimit caps how many items one Retrieve call may return.
	maxRetrieveLimit = 20
)
// ReadService owns reading, gating, and lightweight re-ranking inside the
// memory module.
//
// Responsibility boundary:
//  1. reads memory_items and applies user-setting filters;
//  2. provides minimal sorting and trimming as a stable entry point for later
//     prompt injection;
//  3. does not depend on agent directly and does not splice memories into the
//     prompt itself.
type ReadService struct {
	itemRepo     *memoryrepo.ItemRepo
	settingsRepo *memoryrepo.SettingsRepo
	ragRuntime   ragservice.Runtime
	cfg          memorymodel.Config
	observer     memoryobserve.Observer
	metrics      memoryobserve.MetricsRecorder
}
// retrieveTelemetry carries the per-call counters recorded after a Retrieve.
// PinnedHitCount and DedupDropCount appear to be filled by the hybrid path
// (HybridRetrieve is defined outside this view) — TODO confirm.
type retrieveTelemetry struct {
	ReadMode         string
	QueryLen         int
	LegacyHitCount   int
	PinnedHitCount   int
	SemanticHitCount int
	DedupDropCount   int
	FinalCount       int
	Degraded         bool
	RAGFallbackUsed  bool
}
// semanticRetrieveTelemetry summarizes the semantic (RAG) leg of a retrieve.
// Its consumer is not visible in this file — presumably the hybrid path;
// verify against HybridRetrieve.
type semanticRetrieveTelemetry struct {
	HitCount        int
	Degraded        bool
	RAGFallbackUsed bool
}
// NewReadService assembles a ReadService. A nil observer or metrics recorder
// is replaced with a no-op implementation so callers never need nil checks.
func NewReadService(
	itemRepo *memoryrepo.ItemRepo,
	settingsRepo *memoryrepo.SettingsRepo,
	ragRuntime ragservice.Runtime,
	cfg memorymodel.Config,
	observer memoryobserve.Observer,
	metrics memoryobserve.MetricsRecorder,
) *ReadService {
	svc := &ReadService{
		itemRepo:     itemRepo,
		settingsRepo: settingsRepo,
		ragRuntime:   ragRuntime,
		cfg:          cfg,
		observer:     observer,
		metrics:      metrics,
	}
	if svc.observer == nil {
		svc.observer = memoryobserve.NewNopObserver()
	}
	if svc.metrics == nil {
		svc.metrics = memoryobserve.NewNopMetrics()
	}
	return svc
}
// Retrieve reads candidate memories usable for later prompt injection.
// Path selection: hybrid mode delegates to HybridRetrieve; otherwise the RAG
// path is tried first (when enabled and a query exists) and degrades to the
// MySQL legacy path on error or empty result.
func (s *ReadService) Retrieve(ctx context.Context, req memorymodel.RetrieveRequest) ([]memorymodel.ItemDTO, error) {
	if s == nil || s.itemRepo == nil || s.settingsRepo == nil {
		return nil, nil
	}
	if req.UserID <= 0 {
		return nil, nil
	}
	now := req.Now
	if now.IsZero() {
		now = time.Now()
	}
	telemetry := retrieveTelemetry{
		ReadMode: s.cfg.EffectiveReadMode(),
		QueryLen: len(strings.TrimSpace(req.Query)),
	}
	setting, err := s.settingsRepo.GetByUserID(ctx, req.UserID)
	if err != nil {
		s.recordRetrieve(ctx, req, telemetry, err)
		return nil, err
	}
	effectiveSetting := memoryutils.EffectiveUserSetting(setting, req.UserID)
	// NOTE(review): this early return skips recordRetrieve — confirm the
	// telemetry gap for memory-disabled users is intentional.
	if !effectiveSetting.MemoryEnabled {
		return nil, nil
	}
	limit := normalizeLimit(req.Limit, defaultRetrieveLimit, maxRetrieveLimit)
	// Hybrid mode delegates entirely to HybridRetrieve (defined elsewhere),
	// which reports its own telemetry counters.
	if s.cfg.EffectiveReadMode() == memorymodel.MemoryReadModeHybrid {
		items, hybridTelemetry, hybridErr := s.HybridRetrieve(ctx, req, effectiveSetting, limit, now)
		hybridTelemetry.ReadMode = memorymodel.MemoryReadModeHybrid
		hybridTelemetry.QueryLen = telemetry.QueryLen
		s.recordRetrieve(ctx, req, hybridTelemetry, hybridErr)
		return items, hybridErr
	}
	// Semantic path first when RAG is enabled and the query is non-empty;
	// any failure or empty result falls through to the MySQL path below.
	if s.cfg.RAGEnabled && s.ragRuntime != nil && strings.TrimSpace(req.Query) != "" {
		items, ragErr := s.retrieveByRAG(ctx, req, effectiveSetting, limit, now)
		if ragErr == nil && len(items) > 0 {
			telemetry.SemanticHitCount = len(items)
			telemetry.FinalCount = len(items)
			s.recordRetrieve(ctx, req, telemetry, nil)
			return items, nil
		}
		telemetry.Degraded = true
		telemetry.RAGFallbackUsed = true
	}
	items, legacyErr := s.retrieveByLegacy(ctx, req, limit, now, effectiveSetting)
	telemetry.LegacyHitCount = len(items)
	telemetry.FinalCount = len(items)
	s.recordRetrieve(ctx, req, telemetry, legacyErr)
	return items, legacyErr
}
// retrieveByLegacy is the MySQL-only read path: scoped query, setting-based
// filtering, deterministic scoring, then a best-effort access-time touch.
func (s *ReadService) retrieveByLegacy(
	ctx context.Context,
	req memorymodel.RetrieveRequest,
	limit int,
	now time.Time,
	effectiveSetting model.MemoryUserSetting,
) ([]memorymodel.ItemDTO, error) {
	// Disabled memory means an empty (not erroneous) result.
	if !effectiveSetting.MemoryEnabled {
		return nil, nil
	}
	// Over-fetch 3x so the per-user filtering below still leaves enough rows.
	scopedQuery := buildReadScopedItemQuery(
		req,
		now,
		[]string{model.MemoryItemStatusActive},
		normalizeLimit(limit*3, limit*3, maxRetrieveLimit*3),
	)
	candidates, err := s.itemRepo.FindByQuery(ctx, scopedQuery)
	if err != nil {
		return nil, err
	}
	candidates = memoryutils.FilterItemsBySetting(candidates, effectiveSetting)
	if len(candidates) == 0 {
		return nil, nil
	}
	// Deterministic ordering: score descending, ties broken by ID descending.
	sort.SliceStable(candidates, func(a, b int) bool {
		scoreA := scoreRetrievedItem(candidates[a], now)
		scoreB := scoreRetrievedItem(candidates[b], now)
		if scoreA == scoreB {
			return candidates[a].ID > candidates[b].ID
		}
		return scoreA > scoreB
	})
	if len(candidates) > limit {
		candidates = candidates[:limit]
	}
	// Best effort: a failed access-time touch must not break the read path.
	_ = s.itemRepo.TouchLastAccessAt(ctx, collectMemoryIDs(candidates), now)
	return toItemDTOs(candidates), nil
}
// retrieveByRAG performs semantic recall via the RAG runtime, filters hits by
// the user's implicit/sensitive switches, trims to limit, and touches access
// times best-effort.
func (s *ReadService) retrieveByRAG(
	ctx context.Context,
	req memorymodel.RetrieveRequest,
	effectiveSetting model.MemoryUserSetting,
	limit int,
	now time.Time,
) ([]memorymodel.ItemDTO, error) {
	if !effectiveSetting.MemoryEnabled {
		return nil, nil
	}
	result, err := s.ragRuntime.RetrieveMemory(ctx, buildReadScopedRAGRequest(req, limit, s.cfg.Threshold))
	if err != nil || result == nil || len(result.Items) == 0 {
		return nil, err
	}
	dtos := make([]memorymodel.ItemDTO, 0, len(result.Items))
	touchIDs := make([]int64, 0, len(result.Items))
	for _, hit := range result.Items {
		dto, memoryID := buildMemoryDTOFromRetrieveHit(hit)
		// Respect the per-user switches for implicit and sensitive memories.
		if !effectiveSetting.ImplicitMemoryEnabled && !dto.IsExplicit {
			continue
		}
		if !effectiveSetting.SensitiveMemoryEnabled && dto.SensitivityLevel > 0 {
			continue
		}
		if dto.ID <= 0 && memoryID > 0 {
			dto.ID = memoryID
		}
		dtos = append(dtos, dto)
		if dto.ID > 0 {
			touchIDs = append(touchIDs, dto.ID)
		}
	}
	if len(dtos) > limit {
		dtos = dtos[:limit]
	}
	// Best effort: ignore touch failures on the read path. Note the IDs were
	// collected before trimming, matching the pre-existing behavior.
	_ = s.itemRepo.TouchLastAccessAt(ctx, touchIDs, now)
	return dtos, nil
}
// normalizeRetrieveMemoryTypes normalizes the requested memory types, falling
// back to the three durable kinds when nothing valid was requested.
func normalizeRetrieveMemoryTypes(raw []string) []string {
	if normalized := normalizeMemoryTypes(raw); len(normalized) > 0 {
		return normalized
	}
	return []string{
		memorymodel.MemoryTypeConstraint,
		memorymodel.MemoryTypePreference,
		memorymodel.MemoryTypeFact,
	}
}
// recordRetrieve emits one structured observe event plus counter metrics for a
// single retrieve attempt.
//
// The event is logged at warn level when err is non-nil, info otherwise; the
// hit / dedup-drop / RAG-fallback counters are only bumped when their
// corresponding telemetry values are non-zero.
func (s *ReadService) recordRetrieve(
	ctx context.Context,
	req memorymodel.RetrieveRequest,
	telemetry retrieveTelemetry,
	err error,
) {
	if s == nil {
		return
	}
	level := memoryobserve.LevelInfo
	if err != nil {
		level = memoryobserve.LevelWarn
	}
	s.observer.Observe(ctx, memoryobserve.Event{
		Level:     level,
		Component: memoryobserve.ComponentRead,
		Operation: memoryobserve.OperationRetrieve,
		Fields: map[string]any{
			"user_id":            req.UserID,
			"read_mode":          telemetry.ReadMode,
			"query_len":          telemetry.QueryLen,
			"legacy_hit_count":   telemetry.LegacyHitCount,
			"pinned_hit_count":   telemetry.PinnedHitCount,
			"semantic_hit_count": telemetry.SemanticHitCount,
			"dedup_drop_count":   telemetry.DedupDropCount,
			"final_count":        telemetry.FinalCount,
			"degraded":           telemetry.Degraded,
			"rag_fallback_used":  telemetry.RAGFallbackUsed,
			"success":            err == nil,
			"error":              err,
			"error_code":         memoryobserve.ClassifyError(err),
		},
	})
	if telemetry.FinalCount > 0 {
		s.metrics.AddCounter(memoryobserve.MetricRetrieveHitTotal, int64(telemetry.FinalCount), map[string]string{
			"read_mode": strings.TrimSpace(telemetry.ReadMode),
		})
	}
	if telemetry.DedupDropCount > 0 {
		s.metrics.AddCounter(memoryobserve.MetricRetrieveDedupDropTotal, int64(telemetry.DedupDropCount), map[string]string{
			"read_mode": strings.TrimSpace(telemetry.ReadMode),
		})
	}
	if telemetry.RAGFallbackUsed {
		s.metrics.AddCounter(memoryobserve.MetricRAGFallbackTotal, 1, map[string]string{
			"read_mode": strings.TrimSpace(telemetry.ReadMode),
		})
	}
}
// scoreRetrievedItem computes the deterministic ranking score for the legacy
// read path.
//
// Notes:
//  1. Only stable item features participate: importance / confidence /
//     recency / explicitness / memory type;
//  2. conversation_id is deliberately excluded from read-side scoring, since
//     same-conversation content is already inside the context window;
//  3. If a semantic score or reranker is introduced later, surface the extra
//     fields on the DTO first and fold them in here.
func scoreRetrievedItem(item model.MemoryItem, now time.Time) float64 {
	total := 0.35*clamp01(item.Importance) + 0.3*clamp01(item.Confidence) + 0.2*recencyScore(item, now)
	if item.IsExplicit {
		total += 0.1
	}
	switch item.MemoryType {
	case memorymodel.MemoryTypeConstraint:
		total += 0.12
	case memorymodel.MemoryTypePreference:
		total += 0.08
	}
	return total
}
// recencyScore buckets an item's age into a freshness score. Missing or
// future timestamps get a neutral 0.5.
func recencyScore(item model.MemoryItem, now time.Time) float64 {
	// Prefer the update timestamp; fall back to creation time.
	reference := item.UpdatedAt
	if reference == nil {
		reference = item.CreatedAt
	}
	if reference == nil || now.Before(*reference) {
		return 0.5
	}
	switch age := now.Sub(*reference); {
	case age <= 24*time.Hour:
		return 1
	case age <= 7*24*time.Hour:
		return 0.85
	case age <= 30*24*time.Hour:
		return 0.65
	case age <= 90*24*time.Hour:
		return 0.45
	default:
		return 0.25
	}
}
// clamp01 restricts v to the inclusive range [0, 1].
func clamp01(v float64) float64 {
	switch {
	case v < 0:
		return 0
	case v > 1:
		return 1
	default:
		return v
	}
}
// collectMemoryIDs gathers the positive IDs of the given items in order;
// non-positive IDs are skipped.
func collectMemoryIDs(items []model.MemoryItem) []int64 {
	if len(items) == 0 {
		return nil
	}
	ids := make([]int64, 0, len(items))
	for _, candidate := range items {
		if candidate.ID > 0 {
			ids = append(ids, candidate.ID)
		}
	}
	return ids
}
// buildMemoryDTOFromRetrieveHit converts a RAG retrieve hit back into a memory
// ItemDTO, reading scalar fields leniently out of the hit metadata via the
// read*Like helpers.
//
// It returns the DTO plus the memory ID parsed from the document ID
// ("memory:<id>"), which is also assigned to dto.ID; a missing content_hash is
// recomputed from type + content.
func buildMemoryDTOFromRetrieveHit(hit ragservice.RetrieveHit) (memorymodel.ItemDTO, int64) {
	memoryID := parseMemoryIDFromDocumentID(hit.DocumentID)
	metadata := hit.Metadata
	content := strings.TrimSpace(hit.Text)
	memoryType := readString(metadata["memory_type"])
	dto := memorymodel.ItemDTO{
		ID:               memoryID,
		UserID:           int(readFloatLike(metadata["user_id"])),
		ConversationID:   readString(metadata["conversation_id"]),
		AssistantID:      readString(metadata["assistant_id"]),
		RunID:            readString(metadata["run_id"]),
		MemoryType:       memoryType,
		Title:            readString(metadata["title"]),
		Content:          content,
		ContentHash:      fallbackContentHash(memoryType, content, readString(metadata["content_hash"])),
		Confidence:       readFloatLike(metadata["confidence"]),
		Importance:       readFloatLike(metadata["importance"]),
		SensitivityLevel: int(readFloatLike(metadata["sensitivity_level"])),
		IsExplicit:       readBoolLike(metadata["is_explicit"]),
		Status:           readString(metadata["status"]),
		TTLAt:            readTimeLike(metadata["ttl_at"]),
	}
	return dto, memoryID
}
// parseMemoryIDFromDocumentID extracts the numeric memory ID from a RAG
// document ID of the form "memory:<id>". User-scoped "memory:uid:..." IDs and
// anything non-numeric yield 0.
func parseMemoryIDFromDocumentID(documentID string) int64 {
	const prefix = "memory:"
	trimmed := strings.TrimSpace(documentID)
	if !strings.HasPrefix(trimmed, prefix) {
		return 0
	}
	raw := strings.TrimPrefix(trimmed, prefix)
	// "memory:uid:..." documents carry no numeric memory row ID.
	if strings.HasPrefix(raw, "uid:") {
		return 0
	}
	id, err := strconv.ParseInt(raw, 10, 64)
	if err != nil {
		return 0
	}
	return id
}
// readString renders an arbitrary metadata value as a trimmed string; nil
// becomes the empty string.
func readString(v any) string {
	if v == nil {
		return ""
	}
	rendered := fmt.Sprintf("%v", v)
	return strings.TrimSpace(rendered)
}
// readFloatLike coerces common numeric metadata representations to float64;
// unparseable or unsupported values yield 0.
func readFloatLike(v any) float64 {
	switch value := v.(type) {
	case float64:
		return value
	case float32:
		return float64(value)
	case int:
		return float64(value)
	case int64:
		return float64(value)
	case string:
		if parsed, err := strconv.ParseFloat(strings.TrimSpace(value), 64); err == nil {
			return parsed
		}
	}
	return 0
}
// readBoolLike coerces a metadata value to bool. Only bool values and the
// case-insensitive string "true" count as true; everything else is false.
func readBoolLike(v any) bool {
	if flag, ok := v.(bool); ok {
		return flag
	}
	if text, ok := v.(string); ok {
		return strings.EqualFold(strings.TrimSpace(text), "true")
	}
	return false
}
// readTimeLike parses an RFC3339 timestamp from a metadata value; empty or
// malformed input yields nil.
func readTimeLike(v any) *time.Time {
	raw := readString(v)
	if raw == "" {
		return nil
	}
	parsed, err := time.Parse(time.RFC3339, raw)
	if err != nil {
		return nil
	}
	return &parsed
}

View File

@@ -0,0 +1,341 @@
package service
import (
"context"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/model"
memoryutils "github.com/LoveLosita/smartflow/backend/services/memory/internal/utils"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
)
// HybridRetrieve is the unified RAG-first recall pipeline on the read side.
//
// Steps:
//  1. Prefer RAG semantic search, recalling candidates by query relevance;
//  2. Fall back to MySQL when RAG errors or returns zero hits, keeping the
//     pipeline resilient;
//  3. Apply three-level dedup, ranking, and per-type budget trimming to the
//     recalled set (total never exceeds the caller's limit);
//  4. The old legacy pipeline is fully preserved so config can roll back fast.
func (s *ReadService) HybridRetrieve(
	ctx context.Context,
	req memorymodel.RetrieveRequest,
	effectiveSetting model.MemoryUserSetting,
	limit int,
	now time.Time,
) ([]memorymodel.ItemDTO, retrieveTelemetry, error) {
	telemetry := retrieveTelemetry{}
	if s == nil || s.itemRepo == nil {
		return nil, telemetry, nil
	}
	if !effectiveSetting.MemoryEnabled {
		return nil, telemetry, nil
	}
	// RAG-first: semantic recall only; no longer bulk-loads MySQL pinned rows.
	items, semanticTelemetry, err := s.retrieveSemanticCandidates(ctx, req, effectiveSetting, limit, now)
	if err != nil {
		// NOTE(review): on error the semantic telemetry is discarded and only
		// the caller-stamped fields get recorded — confirm this is intended.
		return nil, telemetry, err
	}
	telemetry.SemanticHitCount = semanticTelemetry.HitCount
	telemetry.Degraded = semanticTelemetry.Degraded
	telemetry.RAGFallbackUsed = semanticTelemetry.RAGFallbackUsed
	if len(items) == 0 {
		return nil, telemetry, nil
	}
	// Three-level dedup: ID, then content hash, then normalized text.
	beforeDedupCount := len(items)
	items = dedupByID(items)
	items = dedupByHash(items)
	items = dedupByText(items)
	telemetry.DedupDropCount = beforeDedupCount - len(items)
	items = RankItems(items, now)
	items = applyTypeBudget(items, s.cfg, limit)
	if len(items) == 0 {
		return nil, telemetry, nil
	}
	telemetry.FinalCount = len(items)
	// Best effort: access-time touch failures must not break the read path.
	_ = s.itemRepo.TouchLastAccessAt(ctx, collectItemDTOIDs(items), now)
	return items, telemetry, nil
}
// retrievePinnedCandidates loads the user's pinned memories from MySQL and
// applies the per-user visibility settings.
//
// NOTE(review): the RAG-first hybrid path above no longer pulls pinned rows;
// confirm whether this helper is still referenced before removing it.
func (s *ReadService) retrievePinnedCandidates(
	ctx context.Context,
	req memorymodel.RetrieveRequest,
	effectiveSetting model.MemoryUserSetting,
	now time.Time,
) ([]memorymodel.ItemDTO, error) {
	pinnedQuery := buildReadScopedItemQuery(req, now, nil, 0)
	pinned, err := s.itemRepo.FindPinnedByUser(ctx, pinnedQuery, s.cfg.EffectiveReadPreferenceLimit())
	if err != nil {
		return nil, err
	}
	pinned = memoryutils.FilterItemsBySetting(pinned, effectiveSetting)
	return toItemDTOs(pinned), nil
}
// retrieveSemanticCandidates produces the semantic candidate set for the
// hybrid path: RAG first, with MySQL as the resilience fallback.
func (s *ReadService) retrieveSemanticCandidates(
	ctx context.Context,
	req memorymodel.RetrieveRequest,
	effectiveSetting model.MemoryUserSetting,
	limit int,
	now time.Time,
) ([]memorymodel.ItemDTO, semanticRetrieveTelemetry, error) {
	var telemetry semanticRetrieveTelemetry
	// Semantic recall is meaningless without a query.
	if strings.TrimSpace(req.Query) == "" {
		return nil, telemetry, nil
	}
	candidateLimit := hybridSemanticTopK(s.cfg, limit)
	if s.cfg.RAGEnabled && s.ragRuntime != nil {
		ragItems, ragErr := s.retrieveSemanticCandidatesByRAG(ctx, req, effectiveSetting, candidateLimit, now)
		if shouldReturnSemanticRAGResult(ragItems, ragErr) {
			telemetry.HitCount = len(ragItems)
			return ragItems, telemetry, nil
		}
		// RAG failed or its result was judged unusable: degrade to MySQL.
		telemetry.Degraded = true
		telemetry.RAGFallbackUsed = true
	}
	fallbackItems, fallbackErr := s.retrieveSemanticCandidatesByMySQL(ctx, req, effectiveSetting, candidateLimit, now)
	telemetry.HitCount = len(fallbackItems)
	return fallbackItems, telemetry, fallbackErr
}
// retrieveSemanticCandidatesByRAG runs semantic recall through the RAG runtime
// and filters hits by the user's implicit/sensitive memory switches. The now
// parameter is unused here; it is kept for symmetry with the MySQL variant.
func (s *ReadService) retrieveSemanticCandidatesByRAG(
	ctx context.Context,
	req memorymodel.RetrieveRequest,
	effectiveSetting model.MemoryUserSetting,
	candidateLimit int,
	now time.Time,
) ([]memorymodel.ItemDTO, error) {
	result, err := s.ragRuntime.RetrieveMemory(ctx, buildReadScopedRAGRequest(req, candidateLimit, s.cfg.Threshold))
	if err != nil {
		return nil, err
	}
	if result == nil || len(result.Items) == 0 {
		return nil, nil
	}
	candidates := make([]memorymodel.ItemDTO, 0, len(result.Items))
	for _, hit := range result.Items {
		dto, memoryID := buildMemoryDTOFromRetrieveHit(hit)
		if !effectiveSetting.ImplicitMemoryEnabled && !dto.IsExplicit {
			continue
		}
		if !effectiveSetting.SensitiveMemoryEnabled && dto.SensitivityLevel > 0 {
			continue
		}
		if dto.ID <= 0 && memoryID > 0 {
			dto.ID = memoryID
		}
		candidates = append(candidates, dto)
	}
	return candidates, nil
}
// retrieveSemanticCandidatesByMySQL is the keyword-scoped MySQL fallback for
// semantic recall when RAG is unavailable or degraded.
func (s *ReadService) retrieveSemanticCandidatesByMySQL(
	ctx context.Context,
	req memorymodel.RetrieveRequest,
	effectiveSetting model.MemoryUserSetting,
	candidateLimit int,
	now time.Time,
) ([]memorymodel.ItemDTO, error) {
	scopedQuery := buildReadScopedItemQuery(
		req,
		now,
		[]string{model.MemoryItemStatusActive},
		normalizeLimit(candidateLimit, candidateLimit, maxRetrieveLimit),
	)
	rows, err := s.itemRepo.FindByQuery(ctx, scopedQuery)
	if err != nil {
		return nil, err
	}
	rows = memoryutils.FilterItemsBySetting(rows, effectiveSetting)
	return toItemDTOs(rows), nil
}
// dedupByID removes duplicate memory IDs, letting a later occurrence win over
// an earlier one while preserving the surviving items' relative order. Items
// without a positive ID are always kept.
func dedupByID(items []memorymodel.ItemDTO) []memorymodel.ItemDTO {
	if len(items) == 0 {
		return nil
	}
	kept := make([]memorymodel.ItemDTO, 0, len(items))
	seen := make(map[int64]struct{}, len(items))
	// Walk backwards so the last occurrence of each ID is the one retained.
	for index := len(items) - 1; index >= 0; index-- {
		candidate := items[index]
		if candidate.ID > 0 {
			if _, duplicate := seen[candidate.ID]; duplicate {
				continue
			}
			seen[candidate.ID] = struct{}{}
		}
		kept = append(kept, candidate)
	}
	// Undo the backward traversal to restore original ordering.
	reverseItemDTOs(kept)
	return kept
}
// dedupByHash deduplicates on content_hash (recomputed when missing); among
// colliding items the higher-importance one survives.
func dedupByHash(items []memorymodel.ItemDTO) []memorymodel.ItemDTO {
	hashKey := func(item memorymodel.ItemDTO) string {
		return fallbackContentHash(item.MemoryType, item.Content, item.ContentHash)
	}
	return dedupByKey(items, hashKey)
}
// dedupByText is the last-resort dedup covering legacy rows that never got a
// hash: it keys on "<type label>::<normalized text>", preferring content over
// title as the text source.
func dedupByText(items []memorymodel.ItemDTO) []memorymodel.ItemDTO {
	return dedupByKey(items, func(item memorymodel.ItemDTO) string {
		body := strings.TrimSpace(item.Content)
		if body == "" {
			body = strings.TrimSpace(item.Title)
		}
		if body == "" {
			return ""
		}
		return renderMemoryTypeLabelForDedup(item.MemoryType) + "::" + normalizeContentForHash(body)
	})
}
// dedupByKey collapses items that share a non-empty key produced by keyBuilder.
//
// Semantics:
//  1. Items whose key is empty are always kept;
//  2. Among items sharing a key, preferCurrentItem decides the survivor
//     (higher importance, then confidence, then the later occurrence);
//  3. The relative order of surviving items is preserved.
//
// keyBuilder is evaluated exactly once per item; the previous implementation
// called it twice (once per pass), repeating the content-hash normalization
// work for every element.
func dedupByKey(items []memorymodel.ItemDTO, keyBuilder func(item memorymodel.ItemDTO) string) []memorymodel.ItemDTO {
	if len(items) == 0 {
		return nil
	}
	// First pass: compute each item's key once and pick the surviving index
	// per key.
	keys := make([]string, len(items))
	selectedIndex := make(map[string]int, len(items))
	for index, item := range items {
		key := strings.TrimSpace(keyBuilder(item))
		keys[index] = key
		if key == "" {
			continue
		}
		if previous, exists := selectedIndex[key]; exists {
			if preferCurrentItem(items[previous], item) {
				selectedIndex[key] = index
			}
			continue
		}
		selectedIndex[key] = index
	}
	// Second pass: emit keyless items and each key's chosen survivor in the
	// original order.
	result := make([]memorymodel.ItemDTO, 0, len(items))
	for index, item := range items {
		if keys[index] == "" || selectedIndex[keys[index]] == index {
			result = append(result, item)
		}
	}
	return result
}
// preferCurrentItem reports whether current should replace previous when both
// map to the same dedup key: higher importance wins, then higher confidence;
// a full tie favors the more recent (current) occurrence.
func preferCurrentItem(previous memorymodel.ItemDTO, current memorymodel.ItemDTO) bool {
	switch {
	case current.Importance != previous.Importance:
		return current.Importance > previous.Importance
	case current.Confidence != previous.Confidence:
		return current.Confidence > previous.Confidence
	default:
		return true
	}
}
// applyTypeBudget applies the per-type memory budgets on top of the ranked
// result, with callerLimit acting as the overall hard cap.
//
// Notes:
//  1. Each type is capped by its own budget, so facts cannot crowd out the
//     constraint slots;
//  2. Trimming keeps the incoming order; no re-scoring happens here;
//  3. The total never exceeds min(callerLimit, cfg.TotalReadBudget()).
//
// NOTE(review): a normalized type outside the three budgeted kinds gets a
// zero budget here and is dropped — resolveBudgetMemoryType folds unknowns
// into the fact bucket, so confirm no other normalized types can appear.
func applyTypeBudget(items []memorymodel.ItemDTO, cfg memorymodel.Config, callerLimit int) []memorymodel.ItemDTO {
	if len(items) == 0 {
		return nil
	}
	hardCap := cfg.TotalReadBudget()
	if callerLimit > 0 && callerLimit < hardCap {
		hardCap = callerLimit
	}
	budgetByType := map[string]int{
		memorymodel.MemoryTypeConstraint: cfg.EffectiveReadConstraintLimit(),
		memorymodel.MemoryTypePreference: cfg.EffectiveReadPreferenceLimit(),
		memorymodel.MemoryTypeFact:       cfg.EffectiveReadFactLimit(),
	}
	usedByType := make(map[string]int, len(budgetByType))
	result := make([]memorymodel.ItemDTO, 0, minInt(len(items), hardCap))
	for _, item := range items {
		if len(result) >= hardCap {
			break
		}
		memoryType := resolveBudgetMemoryType(item.MemoryType)
		if usedByType[memoryType] >= budgetByType[memoryType] {
			continue
		}
		usedByType[memoryType]++
		result = append(result, item)
	}
	return result
}
// hybridSemanticTopK sizes the semantic candidate pool: twice the caller's
// limit, so dedup and budget trimming still have enough material afterwards.
// cfg is currently unused; it is kept so a config-driven TopK can be added
// later without a signature change.
func hybridSemanticTopK(cfg memorymodel.Config, limit int) int {
	return 2 * limit
}
// resolveBudgetMemoryType maps an arbitrary memory type onto a budget bucket;
// empty/unknown types fall back to the fact bucket.
func resolveBudgetMemoryType(memoryType string) string {
	if normalized := memorymodel.NormalizeMemoryType(memoryType); normalized != "" {
		return normalized
	}
	return memorymodel.MemoryTypeFact
}
// renderMemoryTypeLabelForDedup returns the Chinese display label used as the
// type prefix of text-dedup keys. These labels are runtime data: changing them
// would change dedup behavior, so they must stay stable.
func renderMemoryTypeLabelForDedup(memoryType string) string {
	normalized := memorymodel.NormalizeMemoryType(memoryType)
	switch normalized {
	case memorymodel.MemoryTypePreference:
		return "偏好"
	case memorymodel.MemoryTypeConstraint:
		return "约束"
	case memorymodel.MemoryTypeFact:
		return "事实"
	}
	return "记忆"
}
// collectItemDTOIDs gathers the positive IDs of the given DTOs in order;
// non-positive IDs are skipped.
func collectItemDTOIDs(items []memorymodel.ItemDTO) []int64 {
	if len(items) == 0 {
		return nil
	}
	ids := make([]int64, 0, len(items))
	for _, dto := range items {
		if dto.ID > 0 {
			ids = append(ids, dto.ID)
		}
	}
	return ids
}
// reverseItemDTOs flips the slice in place.
func reverseItemDTOs(items []memorymodel.ItemDTO) {
	length := len(items)
	for i := 0; i < length/2; i++ {
		opposite := length - 1 - i
		items[i], items[opposite] = items[opposite], items[i]
	}
}
// minInt returns the smaller of the two ints.
func minInt(left, right int) int {
	if right < left {
		return right
	}
	return left
}

View File

@@ -0,0 +1,76 @@
package service
import (
"sort"
"time"
memorymodel "github.com/LoveLosita/smartflow/backend/services/memory/model"
)
// RankItems re-ranks read results with a single deterministic policy.
//
// Steps:
//  1. Build a base score from importance / confidence / recency, mirroring
//     the legacy ordering intuition;
//  2. Add bonuses for explicit memories and for the constraint / preference
//     types so they land near the top consistently;
//  3. Break score ties by descending ID so logs and tests stay stable.
//
// The input slice is not mutated; a sorted copy is returned.
func RankItems(items []memorymodel.ItemDTO, now time.Time) []memorymodel.ItemDTO {
	if len(items) == 0 {
		return nil
	}
	ranked := append([]memorymodel.ItemDTO(nil), items...)
	sort.SliceStable(ranked, func(a, b int) bool {
		scoreA := scoreRankedItem(ranked[a], now)
		scoreB := scoreRankedItem(ranked[b], now)
		if scoreA == scoreB {
			return ranked[a].ID > ranked[b].ID
		}
		return scoreA > scoreB
	})
	return ranked
}
// scoreRankedItem computes the unified re-ranking score for the hybrid read
// path.
//
// Notes:
//  1. Only the item's own attributes participate — no conversation_id bonus;
//  2. Same-conversation content already lives in the context window, so the
//     read side should focus on cross-conversation supplements;
//  3. Type weighting is kept so the business priority of constraint /
//     preference stays reliably in effect.
func scoreRankedItem(item memorymodel.ItemDTO, now time.Time) float64 {
	total := 0.35*clamp01(item.Importance) + 0.3*clamp01(item.Confidence) + 0.2*recencyScoreDTO(item, now)
	if item.IsExplicit {
		total += 0.1
	}
	switch memorymodel.NormalizeMemoryType(item.MemoryType) {
	case memorymodel.MemoryTypeConstraint:
		total += 0.15
	case memorymodel.MemoryTypePreference:
		total += 0.10
	}
	return total
}
// recencyScoreDTO buckets a DTO's age into a freshness score. Missing or
// future timestamps get a neutral 0.5.
func recencyScoreDTO(item memorymodel.ItemDTO, now time.Time) float64 {
	// Prefer the update timestamp; fall back to creation time.
	reference := item.UpdatedAt
	if reference == nil {
		reference = item.CreatedAt
	}
	if reference == nil || now.Before(*reference) {
		return 0.5
	}
	switch age := now.Sub(*reference); {
	case age <= 24*time.Hour:
		return 1
	case age <= 7*24*time.Hour:
		return 0.85
	case age <= 30*24*time.Hour:
		return 0.65
	case age <= 90*24*time.Hour:
		return 0.45
	default:
		return 0.25
	}
}