Version: 0.9.65.dev.260503
后端: 1. 阶段 1.5/1.6 收口 llm-service / rag-service,统一模型出口与检索基础设施入口,清退 backend/infra/llm 与 backend/infra/rag 旧实现; 2. 同步更新相关调用链与微服务迁移计划文档
This commit is contained in:
118
backend/services/rag/api.go
Normal file
118
backend/services/rag/api.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Runtime is the single stable method surface the RAG service exposes to
// business callers.
//
// Responsibility boundaries:
//  1. Owns the unified ingest and retrieve entry points for the memory and
//     web corpora;
//  2. Hides the wiring of the underlying Pipeline / Store / Embedder /
//     Reranker;
//  3. Does NOT own provider search, HTML fetching, prompt injection, or any
//     other business-level semantics.
type Runtime interface {
	IngestMemory(ctx context.Context, req MemoryIngestRequest) (*IngestResult, error)
	RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (*RetrieveResult, error)
	DeleteMemory(ctx context.Context, documentIDs []string) error

	IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestResult, error)
	RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (*RetrieveResult, error)
}

// IngestResult summarizes one unified ingest execution.
type IngestResult struct {
	DocumentCount int      // number of source documents processed
	ChunkCount    int      // number of chunks produced and stored
	DocumentIDs   []string // IDs of the ingested documents
}

// RetrieveHit is the unified hit item exposed to business callers.
type RetrieveHit struct {
	ChunkID    string
	DocumentID string
	Text       string
	Score      float64
	Metadata   map[string]any
}

// RetrieveResult summarizes one retrieval execution.
type RetrieveResult struct {
	Items          []RetrieveHit
	RawCount       int    // candidate count before threshold filtering
	FallbackUsed   bool   // true when rerank failed and the raw ordering was kept
	FallbackReason string // machine-readable reason, set only when FallbackUsed is true
}
|
||||
|
||||
// MemoryIngestItem is one memory-corpus ingest item.
type MemoryIngestItem struct {
	MemoryID         int64
	UserID           int
	ConversationID   string
	AssistantID      string
	RunID            string
	MemoryType       string
	Title            string
	Content          string
	Confidence       float64
	Importance       float64
	SensitivityLevel int
	IsExplicit       bool
	Status           string
	TTLAt            *time.Time // optional expiry time
	CreatedAt        *time.Time // optional creation time; ingestion may default it
}

// MemoryIngestRequest describes one memory vector-ingest request.
type MemoryIngestRequest struct {
	TraceID string
	Action  string // embedding action hint (e.g. add/update)
	Items   []MemoryIngestItem
}

// MemoryRetrieveRequest describes one memory retrieval request.
type MemoryRetrieveRequest struct {
	TraceID        string
	Query          string
	TopK           int
	Threshold      float64
	Action         string
	UserID         int
	ConversationID string
	AssistantID    string
	RunID          string
	MemoryTypes    []string
}
|
||||
|
||||
// WebIngestItem is one web-corpus ingest item.
type WebIngestItem struct {
	URL         string
	Title       string
	Content     string
	Snippet     string
	Domain      string
	QueryID     string
	SessionID   string
	PublishedAt *time.Time // optional publish time of the page
	FetchedAt   *time.Time // optional time the page was fetched
	SourceRank  int
}

// WebIngestRequest describes one web-corpus ingest request.
type WebIngestRequest struct {
	TraceID string
	Action  string // embedding action hint (e.g. add/update)
	Items   []WebIngestItem
}

// WebRetrieveRequest describes one web retrieval request.
type WebRetrieveRequest struct {
	TraceID   string
	Query     string
	TopK      int
	Threshold float64
	Action    string
	QueryID   string
	SessionID string
	Domain    string
}
|
||||
85
backend/services/rag/chunk/text_chunker.go
Normal file
85
backend/services/rag/chunk/text_chunker.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package chunk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// TextChunker is the default text chunker (fixed-window splitting over runes).
type TextChunker struct{}

// NewTextChunker returns a ready-to-use TextChunker; the type is stateless.
func NewTextChunker() *TextChunker {
	return &TextChunker{}
}
|
||||
|
||||
// Chunk 对文本执行固定窗口切块。
|
||||
//
|
||||
// 步骤化说明:
|
||||
// 1. 先做空白归一,避免无效块进入向量库;
|
||||
// 2. 再按 chunk_size/overlap 滑窗切割;
|
||||
// 3. 每块继承原文 metadata,并补充 chunk 序号。
|
||||
func (c *TextChunker) Chunk(_ context.Context, doc core.SourceDocument, opt core.ChunkOption) ([]core.Chunk, error) {
|
||||
if strings.TrimSpace(doc.ID) == "" {
|
||||
return nil, fmt.Errorf("empty document id")
|
||||
}
|
||||
text := strings.TrimSpace(doc.Text)
|
||||
if text == "" {
|
||||
return nil, nil
|
||||
}
|
||||
if opt.ChunkSize <= 0 {
|
||||
opt.ChunkSize = 400
|
||||
}
|
||||
if opt.ChunkOverlap < 0 {
|
||||
opt.ChunkOverlap = 0
|
||||
}
|
||||
if opt.ChunkOverlap >= opt.ChunkSize {
|
||||
opt.ChunkOverlap = opt.ChunkSize / 5
|
||||
}
|
||||
|
||||
runes := []rune(text)
|
||||
step := opt.ChunkSize - opt.ChunkOverlap
|
||||
if step <= 0 {
|
||||
step = opt.ChunkSize
|
||||
}
|
||||
|
||||
result := make([]core.Chunk, 0, len(runes)/step+1)
|
||||
order := 0
|
||||
for start := 0; start < len(runes); start += step {
|
||||
end := start + opt.ChunkSize
|
||||
if end > len(runes) {
|
||||
end = len(runes)
|
||||
}
|
||||
chunkText := strings.TrimSpace(string(runes[start:end]))
|
||||
if chunkText == "" {
|
||||
continue
|
||||
}
|
||||
metadata := cloneMap(doc.Metadata)
|
||||
metadata["chunk_order"] = order
|
||||
result = append(result, core.Chunk{
|
||||
ID: fmt.Sprintf("%s#%d", doc.ID, order),
|
||||
DocumentID: doc.ID,
|
||||
Text: chunkText,
|
||||
Order: order,
|
||||
Metadata: metadata,
|
||||
})
|
||||
order++
|
||||
if end == len(runes) {
|
||||
break
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// cloneMap returns a shallow copy of src; a nil or empty input yields a
// fresh, writable empty map so callers can mutate the result safely.
func cloneMap(src map[string]any) map[string]any {
	out := make(map[string]any, len(src))
	for key, value := range src {
		out[key] = value
	}
	return out
}
|
||||
113
backend/services/rag/config/config.go
Normal file
113
backend/services/rag/config/config.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package config
|
||||
|
||||
import "github.com/spf13/viper"
|
||||
|
||||
// Config is the runtime configuration for the RAG core.
// Values are read from the "rag.*" viper keys by LoadFromViper, which also
// applies the defaults noted below.
type Config struct {
	Enabled bool   // rag.enabled
	Store   string // rag.store; defaults to "inmemory"
	TopK    int    // rag.topK; defaults to 8

	Threshold float64 // rag.threshold; negative values are clamped to 0

	EmbedProvider  string // rag.embed.provider; defaults to "mock"
	EmbedModel     string // rag.embed.model
	EmbedBaseURL   string // rag.embed.baseURL; falls back to agent.baseURL
	EmbedTimeoutMS int    // rag.embed.timeoutMs; defaults to 1200
	EmbedDimension int    // rag.embed.dimension; defaults to 1024

	RerankerEnabled   bool   // rag.reranker.enabled
	RerankerProvider  string // rag.reranker.provider; defaults to "noop"
	RerankerTimeoutMS int    // rag.reranker.timeoutMs; defaults to 1200

	ChunkSize    int // rag.ingest.chunkSize; defaults to 400
	ChunkOverlap int // rag.ingest.chunkOverlap; negative values become 80

	RetrieveTimeoutMS int // rag.retrieve.timeoutMs; defaults to 1500

	MilvusAddress          string // rag.milvus.address; defaults to http://localhost:19530
	MilvusToken            string // rag.milvus.token; defaults to root:Milvus
	MilvusDBName           string // rag.milvus.dbName; no default applied
	MilvusCollectionName   string // rag.milvus.collectionName; defaults to smartflow_rag_chunks
	MilvusMetricType       string // rag.milvus.metricType; defaults to COSINE
	MilvusRequestTimeoutMS int    // rag.milvus.requestTimeoutMs; defaults to 1500
}
|
||||
|
||||
// LoadFromViper 读取 rag 配置并补默认值。
|
||||
func LoadFromViper() Config {
|
||||
cfg := Config{
|
||||
Enabled: viper.GetBool("rag.enabled"),
|
||||
Store: viper.GetString("rag.store"),
|
||||
TopK: viper.GetInt("rag.topK"),
|
||||
Threshold: viper.GetFloat64("rag.threshold"),
|
||||
EmbedProvider: viper.GetString("rag.embed.provider"),
|
||||
EmbedModel: viper.GetString("rag.embed.model"),
|
||||
EmbedBaseURL: viper.GetString("rag.embed.baseURL"),
|
||||
EmbedTimeoutMS: viper.GetInt("rag.embed.timeoutMs"),
|
||||
EmbedDimension: viper.GetInt("rag.embed.dimension"),
|
||||
RerankerEnabled: viper.GetBool("rag.reranker.enabled"),
|
||||
RerankerProvider: viper.GetString("rag.reranker.provider"),
|
||||
RerankerTimeoutMS: viper.GetInt("rag.reranker.timeoutMs"),
|
||||
ChunkSize: viper.GetInt("rag.ingest.chunkSize"),
|
||||
ChunkOverlap: viper.GetInt("rag.ingest.chunkOverlap"),
|
||||
RetrieveTimeoutMS: viper.GetInt("rag.retrieve.timeoutMs"),
|
||||
MilvusAddress: viper.GetString("rag.milvus.address"),
|
||||
MilvusToken: viper.GetString("rag.milvus.token"),
|
||||
MilvusDBName: viper.GetString("rag.milvus.dbName"),
|
||||
MilvusCollectionName: viper.GetString("rag.milvus.collectionName"),
|
||||
MilvusMetricType: viper.GetString("rag.milvus.metricType"),
|
||||
MilvusRequestTimeoutMS: viper.GetInt("rag.milvus.requestTimeoutMs"),
|
||||
}
|
||||
if cfg.Store == "" {
|
||||
cfg.Store = "inmemory"
|
||||
}
|
||||
if cfg.TopK <= 0 {
|
||||
cfg.TopK = 8
|
||||
}
|
||||
if cfg.Threshold < 0 {
|
||||
cfg.Threshold = 0
|
||||
}
|
||||
if cfg.EmbedProvider == "" {
|
||||
cfg.EmbedProvider = "mock"
|
||||
}
|
||||
if cfg.EmbedBaseURL == "" {
|
||||
cfg.EmbedBaseURL = viper.GetString("agent.baseURL")
|
||||
}
|
||||
if cfg.EmbedTimeoutMS <= 0 {
|
||||
cfg.EmbedTimeoutMS = 1200
|
||||
}
|
||||
if cfg.EmbedDimension <= 0 {
|
||||
cfg.EmbedDimension = 1024
|
||||
}
|
||||
if cfg.RerankerProvider == "" {
|
||||
cfg.RerankerProvider = "noop"
|
||||
}
|
||||
if cfg.RerankerTimeoutMS <= 0 {
|
||||
cfg.RerankerTimeoutMS = 1200
|
||||
}
|
||||
if cfg.ChunkSize <= 0 {
|
||||
cfg.ChunkSize = 400
|
||||
}
|
||||
if cfg.ChunkOverlap < 0 {
|
||||
cfg.ChunkOverlap = 80
|
||||
}
|
||||
if cfg.RetrieveTimeoutMS <= 0 {
|
||||
cfg.RetrieveTimeoutMS = 1500
|
||||
}
|
||||
if cfg.MilvusAddress == "" {
|
||||
cfg.MilvusAddress = "http://localhost:19530"
|
||||
}
|
||||
if cfg.MilvusToken == "" {
|
||||
cfg.MilvusToken = "root:Milvus"
|
||||
}
|
||||
if cfg.MilvusCollectionName == "" {
|
||||
cfg.MilvusCollectionName = "smartflow_rag_chunks"
|
||||
}
|
||||
if cfg.MilvusMetricType == "" {
|
||||
cfg.MilvusMetricType = "COSINE"
|
||||
}
|
||||
if cfg.MilvusRequestTimeoutMS <= 0 {
|
||||
cfg.MilvusRequestTimeoutMS = 1500
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
17
backend/services/rag/core/errors.go
Normal file
17
backend/services/rag/core/errors.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package core
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
	// ErrInvalidQuery indicates a retrieve request without a usable query.
	ErrInvalidQuery = errors.New("invalid query")
	// ErrInvalidTopK indicates an illegal topK value.
	ErrInvalidTopK = errors.New("invalid top_k")
	// ErrNilDependency indicates a required pipeline dependency was not injected.
	ErrNilDependency = errors.New("nil dependency")
)

const (
	// FallbackReasonRerankFailed marks a degradation after a rerank failure.
	FallbackReasonRerankFailed = "RERANK_FAILED"
)
|
||||
38
backend/services/rag/core/interfaces.go
Normal file
38
backend/services/rag/core/interfaces.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package core
|
||||
|
||||
import "context"
|
||||
|
||||
// Chunker splits a source document into chunks.
type Chunker interface {
	Chunk(ctx context.Context, doc SourceDocument, opt ChunkOption) ([]Chunk, error)
}

// Embedder turns texts into vectors; action hints the embedding intent
// (e.g. add/update/search).
type Embedder interface {
	Embed(ctx context.Context, texts []string, action string) ([][]float32, error)
}

// Retriever recalls candidate chunks for a request.
type Retriever interface {
	Retrieve(ctx context.Context, req RetrieveRequest) ([]ScoredChunk, error)
}

// Reranker reorders candidates; topK bounds the desired result size.
type Reranker interface {
	Rerank(ctx context.Context, query string, candidates []ScoredChunk, topK int) ([]ScoredChunk, error)
}

// VectorStore reads and writes the vector database.
type VectorStore interface {
	Upsert(ctx context.Context, rows []VectorRow) error
	Search(ctx context.Context, req VectorSearchRequest) ([]ScoredVectorRow, error)
	Delete(ctx context.Context, ids []string) error
	Get(ctx context.Context, ids []string) ([]VectorRow, error)
}

// CorpusAdapter maps business corpora into unified documents and filter
// conditions.
type CorpusAdapter interface {
	Name() string
	BuildIngestDocuments(ctx context.Context, input any) ([]SourceDocument, error)
	BuildRetrieveFilter(ctx context.Context, req any) (map[string]any, error)
}
|
||||
190
backend/services/rag/core/observer.go
Normal file
190
backend/services/rag/core/observer.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ObserveLevel is the severity of an observation event.
type ObserveLevel string

const (
	ObserveLevelInfo  ObserveLevel = "info"
	ObserveLevelWarn  ObserveLevel = "warn"
	ObserveLevelError ObserveLevel = "error"
)

// ObserveEvent describes one unified observation event.
//
// Responsibility boundaries:
//  1. Carries only structured runtime information from the RAG service;
//  2. Binds to no concrete logging, metrics, or tracing implementation;
//  3. Field content should stay stable so a future global observability
//     platform can consume it unchanged.
type ObserveEvent struct {
	Level     ObserveLevel
	Component string // emitting component, e.g. "pipeline"
	Operation string // logical operation, e.g. "rerank_fallback"
	Fields    map[string]any
}
|
||||
|
||||
// Observer is the minimal observation interface of the RAG service.
//
// Responsibility boundaries:
//  1. Consumes structured events;
//  2. Never decides whether business logic proceeds;
//  3. No implementation may destabilize the main request path.
type Observer interface {
	Observe(ctx context.Context, event ObserveEvent)
}

// ObserverFunc adapts a plain function to the Observer interface.
type ObserverFunc func(ctx context.Context, event ObserveEvent)

// Observe forwards the event to the wrapped function; a nil function is a
// safe no-op.
func (f ObserverFunc) Observe(ctx context.Context, event ObserveEvent) {
	if f == nil {
		return
	}
	f(ctx, event)
}

// NewNopObserver returns an empty implementation, suitable as a fallback
// before a unified observability platform is wired in.
func NewNopObserver() Observer {
	return ObserverFunc(func(context.Context, ObserveEvent) {})
}
|
||||
|
||||
// NewLoggerObserver returns a standard-library log adapter.
//
// Notes:
//  1. Until the project has a unified logging platform, this prints the
//     structured fields in a stable form;
//  2. When a unified logger/metrics/tracing stack arrives, only this
//     Observer implementation needs to be swapped;
//  3. Output stays on a single line to match the existing log style.
func NewLoggerObserver(logger *log.Logger) Observer {
	if logger == nil {
		logger = log.Default()
	}
	return &loggerObserver{logger: logger}
}

// loggerObserver renders events through a *log.Logger.
type loggerObserver struct {
	logger *log.Logger
}
|
||||
|
||||
func (o *loggerObserver) Observe(ctx context.Context, event ObserveEvent) {
|
||||
if o == nil || o.logger == nil {
|
||||
return
|
||||
}
|
||||
|
||||
level := strings.TrimSpace(string(event.Level))
|
||||
if level == "" {
|
||||
level = string(ObserveLevelInfo)
|
||||
}
|
||||
component := strings.TrimSpace(event.Component)
|
||||
if component == "" {
|
||||
component = "unknown"
|
||||
}
|
||||
operation := strings.TrimSpace(event.Operation)
|
||||
if operation == "" {
|
||||
operation = "unknown"
|
||||
}
|
||||
|
||||
fields := ObserveFieldsFromContext(ctx)
|
||||
for key, value := range event.Fields {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" || !shouldKeepObserveField(value) {
|
||||
continue
|
||||
}
|
||||
fields[key] = value
|
||||
}
|
||||
|
||||
parts := []string{
|
||||
"rag",
|
||||
fmt.Sprintf("level=%s", level),
|
||||
fmt.Sprintf("component=%s", component),
|
||||
fmt.Sprintf("operation=%s", operation),
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(fields))
|
||||
for key := range fields {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, key := range keys {
|
||||
parts = append(parts, fmt.Sprintf("%s=%v", key, fields[key]))
|
||||
}
|
||||
|
||||
o.logger.Print(strings.Join(parts, " "))
|
||||
}
|
||||
|
||||
type observeFieldsContextKey struct{}

// WithObserveFields attaches shared observation fields to the context so that
// downstream components (Runtime / Pipeline / Store) can reuse them.
//
// Semantics:
//  1. Fields already on the context are read first, so each layer extends
//     what its callers recorded;
//  2. A later field with the same key overrides the earlier value;
//  3. Only "meaningful" fields are kept (see shouldKeepObserveField), which
//     prevents logs from accumulating empty values.
func WithObserveFields(ctx context.Context, fields map[string]any) context.Context {
	if len(fields) == 0 {
		return ctx
	}
	if ctx == nil {
		ctx = context.Background()
	}

	accumulated := ObserveFieldsFromContext(ctx)
	for rawKey, value := range fields {
		key := strings.TrimSpace(rawKey)
		if key == "" || !shouldKeepObserveField(value) {
			continue
		}
		accumulated[key] = value
	}
	if len(accumulated) == 0 {
		return ctx
	}
	return context.WithValue(ctx, observeFieldsContextKey{}, accumulated)
}

// ObserveFieldsFromContext returns a mutable copy of the observation fields
// accumulated on ctx; the result is never nil.
func ObserveFieldsFromContext(ctx context.Context) map[string]any {
	if ctx == nil {
		return map[string]any{}
	}
	stored, ok := ctx.Value(observeFieldsContextKey{}).(map[string]any)
	if !ok || len(stored) == 0 {
		return map[string]any{}
	}
	copied := make(map[string]any, len(stored))
	for key, value := range stored {
		copied[key] = value
	}
	return copied
}

// ClassifyErrorCode compresses common errors into stable codes so a future
// global observability platform can aggregate them.
func ClassifyErrorCode(err error) string {
	if err == nil {
		return ""
	}
	if errors.Is(err, context.DeadlineExceeded) {
		return "DEADLINE_EXCEEDED"
	}
	if errors.Is(err, context.Canceled) {
		return "CANCELED"
	}
	return "RAG_ERROR"
}

// shouldKeepObserveField reports whether a field value carries information:
// nil and blank strings are dropped, everything else is kept.
func shouldKeepObserveField(value any) bool {
	switch typed := value.(type) {
	case nil:
		return false
	case string:
		return strings.TrimSpace(typed) != ""
	default:
		return true
	}
}
|
||||
366
backend/services/rag/core/pipeline.go
Normal file
366
backend/services/rag/core/pipeline.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Package-level fallback values applied when callers pass zero or invalid
// options (see normalizeChunkOption and Retrieve).
const (
	defaultTopK       = 8
	defaultThreshold  = 0
	defaultChunkSize  = 400
	defaultChunkOvLap = 80
)
|
||||
|
||||
// Pipeline is the RAG core orchestrator.
//
// Responsibility boundaries:
//  1. Owns the unified chunk/embed/retrieve/rerank flow;
//  2. Owns failure-degradation semantics (e.g. rerank fallback);
//  3. Carries no concrete business semantics — those live in CorpusAdapter.
type Pipeline struct {
	chunker  Chunker
	embedder Embedder
	store    VectorStore
	reranker Reranker
	logger   *log.Logger
	observer Observer
}

// NewPipeline wires the four stage dependencies together. A nil reranker is
// tolerated (rerank is then skipped in Retrieve); logger and observer start
// with safe defaults and can be replaced via SetLogger / SetObserver.
func NewPipeline(chunker Chunker, embedder Embedder, store VectorStore, reranker Reranker) *Pipeline {
	return &Pipeline{
		chunker:  chunker,
		embedder: embedder,
		store:    store,
		reranker: reranker,
		logger:   log.Default(),
		observer: NewNopObserver(),
	}
}
|
||||
|
||||
// SetLogger replaces the pipeline's logger. Nil receivers or nil arguments
// are ignored, keeping the default in place.
func (p *Pipeline) SetLogger(logger *log.Logger) {
	if p == nil || logger == nil {
		return
	}
	p.logger = logger
}

// SetObserver replaces the pipeline's unified observer. Nil receivers or nil
// arguments are ignored, keeping the default in place.
func (p *Pipeline) SetObserver(observer Observer) {
	if p == nil || observer == nil {
		return
	}
	p.observer = observer
}
|
||||
|
||||
// Ingest runs the unified ingest flow.
//
// Steps:
//  1. The CorpusAdapter builds normalized documents first, so every corpus
//     enters through the same shape;
//  2. Chunking and embedding are then shared, avoiding per-business
//     re-implementations;
//  3. A single Upsert writes everything; on error we return immediately and
//     let the caller decide about retries.
func (p *Pipeline) Ingest(
	ctx context.Context,
	corpus CorpusAdapter,
	input any,
	opt IngestOption,
) (result *IngestResult, err error) {
	// Panics from any stage become a regular error (see recoverExecutionPanic).
	defer p.recoverExecutionPanic(ctx, "ingest", &err)

	if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil {
		return nil, ErrNilDependency
	}
	if corpus == nil {
		return nil, errors.New("nil corpus adapter")
	}

	docs, err := corpus.BuildIngestDocuments(ctx, input)
	if err != nil {
		return nil, err
	}
	return p.IngestDocuments(ctx, corpus.Name(), docs, opt)
}
|
||||
|
||||
// IngestDocuments runs the unified ingest flow for already-normalized
// documents.
//
// Responsibility boundaries:
//  1. Handles documents that have already been mapped by a CorpusAdapter;
//  2. Owns the shared chunking, embedding, and Upsert steps;
//  3. Does not re-parse business input — this lets the Runtime reuse built
//     documents instead of building them twice just to learn document IDs.
func (p *Pipeline) IngestDocuments(
	ctx context.Context,
	corpusName string,
	docs []SourceDocument,
	opt IngestOption,
) (result *IngestResult, err error) {
	// Panics from any stage become a regular error (see recoverExecutionPanic).
	defer p.recoverExecutionPanic(ctx, "ingest_documents", &err)

	if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil {
		return nil, ErrNilDependency
	}
	if len(docs) == 0 {
		return &IngestResult{DocumentCount: 0, ChunkCount: 0}, nil
	}

	chunkOpt := normalizeChunkOption(opt.Chunk)
	chunks := make([]Chunk, 0, len(docs)*2)
	for _, doc := range docs {
		// 1. Chunk each document independently; abort on the first failure so
		//    no half-finished batch is written.
		docChunks, chunkErr := p.chunker.Chunk(ctx, doc, chunkOpt)
		if chunkErr != nil {
			return nil, chunkErr
		}
		chunks = append(chunks, docChunks...)
	}
	if len(chunks) == 0 {
		return &IngestResult{DocumentCount: len(docs), ChunkCount: 0}, nil
	}

	texts := make([]string, 0, len(chunks))
	for _, chunk := range chunks {
		texts = append(texts, chunk.Text)
	}

	action := strings.TrimSpace(opt.Action)
	if action == "" {
		action = "add"
	}
	vectors, err := p.embedder.Embed(ctx, texts, action)
	if err != nil {
		return nil, err
	}
	if len(vectors) != len(chunks) {
		return nil, fmt.Errorf("embedding result length mismatch: chunks=%d vectors=%d", len(chunks), len(vectors))
	}

	rows := make([]VectorRow, 0, len(chunks))
	now := time.Now()
	for i, chunk := range chunks {
		// Stamp corpus/document bookkeeping into metadata so retrieval can
		// filter on it later.
		metadata := cloneMap(chunk.Metadata)
		metadata["corpus"] = corpusName
		metadata["document_id"] = chunk.DocumentID
		metadata["chunk_order"] = chunk.Order
		rows = append(rows, VectorRow{
			ID:        chunk.ID,
			Vector:    vectors[i],
			Text:      chunk.Text,
			Metadata:  metadata,
			CreatedAt: now,
			UpdatedAt: now,
		})
	}

	if err = p.store.Upsert(ctx, rows); err != nil {
		return nil, err
	}
	return &IngestResult{
		DocumentCount: len(docs),
		ChunkCount:    len(chunks),
	}, nil
}
|
||||
|
||||
// Retrieve runs the unified retrieval flow.
//
// Steps:
//  1. Embed the query and run the vector search;
//  2. Apply the score threshold to drop low-quality candidates;
//  3. Optionally rerank; on rerank failure degrade back to the original
//     ordering and report it via the observer (or logger as fallback).
func (p *Pipeline) Retrieve(
	ctx context.Context,
	corpus CorpusAdapter,
	req RetrieveRequest,
) (result *RetrieveResult, err error) {
	// Panics from any stage become a regular error (see recoverExecutionPanic).
	defer p.recoverExecutionPanic(ctx, "retrieve", &err)

	if p == nil || p.embedder == nil || p.store == nil {
		return nil, ErrNilDependency
	}
	query := strings.TrimSpace(req.Query)
	if query == "" {
		return nil, ErrInvalidQuery
	}

	topK := req.TopK
	if topK <= 0 {
		topK = defaultTopK
	}
	threshold := req.Threshold
	if threshold < 0 {
		threshold = defaultThreshold
	}

	filter := cloneMap(req.Filter)
	if corpus != nil {
		// 1. Merge the corpus filter first so retrieval never crosses corpora.
		corpusFilter, err := corpus.BuildRetrieveFilter(ctx, req.CorpusInput)
		if err != nil {
			return nil, err
		}
		filter = mergeMap(filter, corpusFilter)
		filter["corpus"] = corpus.Name()
	}

	action := strings.TrimSpace(req.Action)
	if action == "" {
		action = "search"
	}

	vectors, err := p.embedder.Embed(ctx, []string{query}, action)
	if err != nil {
		return nil, err
	}
	if len(vectors) != 1 {
		return nil, fmt.Errorf("embedding query length mismatch: %d", len(vectors))
	}

	scoredRows, err := p.store.Search(ctx, VectorSearchRequest{
		QueryVector: vectors[0],
		TopK:        topK,
		Filter:      filter,
	})
	if err != nil {
		return nil, err
	}

	// RawCount keeps the pre-threshold candidate count for observability.
	rawCount := len(scoredRows)
	candidates := make([]ScoredChunk, 0, len(scoredRows))
	for _, row := range scoredRows {
		if row.Score < threshold {
			continue
		}
		candidates = append(candidates, ScoredChunk{
			ChunkID:    row.Row.ID,
			DocumentID: asString(row.Row.Metadata["document_id"]),
			Text:       row.Row.Text,
			Score:      row.Score,
			Metadata:   cloneMap(row.Row.Metadata),
		})
	}

	result = &RetrieveResult{
		Items:        candidates,
		RawCount:     rawCount,
		FallbackUsed: false,
	}
	if len(candidates) == 0 || p.reranker == nil {
		return result, nil
	}

	reranked, rerankErr := p.reranker.Rerank(ctx, query, candidates, topK)
	if rerankErr != nil {
		// 2. A rerank failure must not kill the request: degrade to the
		//    original vector-score ordering.
		result.FallbackUsed = true
		result.FallbackReason = FallbackReasonRerankFailed
		if p.observer != nil {
			p.observer.Observe(ctx, ObserveEvent{
				Level:     ObserveLevelWarn,
				Component: "pipeline",
				Operation: "rerank_fallback",
				Fields: map[string]any{
					"status":          "fallback",
					"fallback_reason": FallbackReasonRerankFailed,
					"candidate_count": len(candidates),
					"top_k":           topK,
					"error":           rerankErr,
					"error_code":      ClassifyErrorCode(rerankErr),
				},
			})
		} else if p.logger != nil {
			p.logger.Printf("rag rerank fallback: reason=%s err=%v", FallbackReasonRerankFailed, rerankErr)
		}
		return result, nil
	}
	result.Items = reranked
	return result, nil
}
|
||||
|
||||
// Delete removes the vectors with the given IDs. A nil pipeline or missing
// store is treated as a no-op (returns nil) rather than an error.
func (p *Pipeline) Delete(ctx context.Context, ids []string) error {
	if p == nil || p.store == nil {
		return nil
	}
	return p.store.Delete(ctx, ids)
}
|
||||
|
||||
// recoverExecutionPanic is installed via defer by the public entry points; it
// turns a panic from any pipeline stage into a regular error written through
// errPtr. It is a no-op when nothing panicked.
func (p *Pipeline) recoverExecutionPanic(ctx context.Context, operation string, errPtr *error) {
	recovered := recover()
	if recovered == nil || errPtr == nil {
		return
	}

	panicErr := fmt.Errorf("rag pipeline panic recovered: operation=%s panic=%v", operation, recovered)
	*errPtr = panicErr

	// 1. The pipeline is the shared boundary around chunk/embed/store/rerank;
	//    a third-party failure must not kill the calling request outright.
	// 2. Recovering into error semantics lets runtime/service decide whether
	//    to degrade, fall back, or just log.
	// 3. The stack goes only to the observability layer — not into the
	//    returned error — to keep huge stacks out of business-facing text.
	if p != nil && p.observer != nil {
		p.observer.Observe(ctx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "pipeline",
			Operation: operation + "_panic_recovered",
			Fields: map[string]any{
				"status":     "failed",
				"panic":      fmt.Sprintf("%v", recovered),
				"panic_type": fmt.Sprintf("%T", recovered),
				"error":      panicErr,
				"error_code": ClassifyErrorCode(panicErr),
				"stack":      string(debug.Stack()),
			},
		})
		return
	}
	if p != nil && p.logger != nil {
		p.logger.Printf("rag pipeline panic recovered: operation=%s panic=%v stack=%s", operation, recovered, string(debug.Stack()))
	}
}
|
||||
|
||||
func normalizeChunkOption(opt ChunkOption) ChunkOption {
|
||||
if opt.ChunkSize <= 0 {
|
||||
opt.ChunkSize = defaultChunkSize
|
||||
}
|
||||
if opt.ChunkOverlap < 0 {
|
||||
opt.ChunkOverlap = 0
|
||||
}
|
||||
if opt.ChunkOverlap >= opt.ChunkSize {
|
||||
opt.ChunkOverlap = defaultChunkOvLap
|
||||
if opt.ChunkOverlap >= opt.ChunkSize {
|
||||
opt.ChunkOverlap = opt.ChunkSize / 5
|
||||
}
|
||||
}
|
||||
return opt
|
||||
}
|
||||
|
||||
// cloneMap returns an independent shallow copy of src; a nil or empty input
// yields a fresh empty map that callers may freely mutate.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for k, v := range src {
		dst[k] = v
	}
	return dst
}
|
||||
|
||||
// mergeMap folds ext into base in place and returns it; a nil base is
// replaced with a fresh map so the result is always writable. NOTE: callers
// aliasing a non-nil base will observe the mutation — in this package base is
// always a cloneMap result, so that is safe.
func mergeMap(base map[string]any, ext map[string]any) map[string]any {
	merged := base
	if merged == nil {
		merged = make(map[string]any, len(ext))
	}
	for k, v := range ext {
		merged[k] = v
	}
	return merged
}
|
||||
|
||||
// asString renders any value with default fmt formatting; nil maps to the
// empty string instead of "<nil>".
func asString(value any) string {
	if value == nil {
		return ""
	}
	return fmt.Sprint(value)
}
|
||||
94
backend/services/rag/core/types.go
Normal file
94
backend/services/rag/core/types.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package core
|
||||
|
||||
import "time"
|
||||
|
||||
// SourceDocument is the unified corpus document model.
//
// Responsibility boundaries:
//  1. Describes only a raw document that can be chunked and indexed;
//  2. Carries no business workflow state.
type SourceDocument struct {
	ID        string
	Text      string
	Title     string
	Metadata  map[string]any
	CreatedAt time.Time
}

// Chunk is the standard chunking result.
type Chunk struct {
	ID         string // typically "<document id>#<order>" (see TextChunker)
	DocumentID string
	Text       string
	Order      int // zero-based position of the chunk within its document
	Metadata   map[string]any
}

// ChunkOption controls chunking parameters.
type ChunkOption struct {
	ChunkSize    int // window size in runes; non-positive falls back to a default
	ChunkOverlap int // overlap between consecutive windows, in runes
}

// IngestOption controls ingest parameters.
type IngestOption struct {
	Chunk ChunkOption
	// Action hints the embedding intent (add/update/search).
	Action string
}

// IngestResult summarizes one ingest execution.
type IngestResult struct {
	DocumentCount int
	ChunkCount    int
}

// RetrieveRequest is the unified retrieval request.
type RetrieveRequest struct {
	Query       string
	TopK        int
	Threshold   float64
	Action      string
	Filter      map[string]any
	CorpusInput any // corpus-specific filter input, interpreted by the CorpusAdapter
}

// ScoredChunk is the unified recall result.
type ScoredChunk struct {
	ChunkID    string
	DocumentID string
	Text       string
	Score      float64
	Metadata   map[string]any
}

// RetrieveResult summarizes one retrieval execution.
type RetrieveResult struct {
	Items          []ScoredChunk
	RawCount       int  // candidate count before threshold filtering
	FallbackUsed   bool // true when rerank failed and raw ordering was kept
	FallbackReason string
}

// VectorRow is the standard vector-store row.
type VectorRow struct {
	ID        string
	Vector    []float32
	Text      string
	Metadata  map[string]any
	CreatedAt time.Time
	UpdatedAt time.Time
}

// VectorSearchRequest is a vector search request.
type VectorSearchRequest struct {
	QueryVector []float32
	TopK        int
	Filter      map[string]any
}

// ScoredVectorRow is one vector search hit.
type ScoredVectorRow struct {
	Row   VectorRow
	Score float64
}
|
||||
13
backend/services/rag/corpus/common.go
Normal file
13
backend/services/rag/corpus/common.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package corpus
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// hashLikeText returns a short, stable fingerprint (first 8 bytes of the
// SHA-256 digest, hex-encoded: 16 characters) of text normalized by trimming
// whitespace and lowercasing, so near-identical inputs share one ID fragment.
func hashLikeText(text string) string {
	canonical := strings.ToLower(strings.TrimSpace(text))
	digest := sha256.Sum256([]byte(canonical))
	return hex.EncodeToString(digest[:8])
}
|
||||
158
backend/services/rag/corpus/memory_corpus.go
Normal file
158
backend/services/rag/corpus/memory_corpus.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package corpus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// memoryCorpusName is the stable corpus identifier stamped into vector
// metadata and used as a retrieval filter.
const memoryCorpusName = "memory"

// MemoryIngestItem is one memory-corpus ingest item.
type MemoryIngestItem struct {
	MemoryID         int64
	UserID           int
	ConversationID   string
	AssistantID      string
	RunID            string
	MemoryType       string
	Title            string
	Content          string
	Confidence       float64
	Importance       float64
	SensitivityLevel int
	IsExplicit       bool
	Status           string
	TTLAt            *time.Time // optional expiry; stored as RFC3339 metadata when set
	CreatedAt        *time.Time // optional; ingestion falls back to time.Now()
}

// MemoryRetrieveInput carries the filter inputs for memory retrieval.
type MemoryRetrieveInput struct {
	UserID         int
	ConversationID string
	AssistantID    string
	RunID          string
	MemoryType     string
}

// MemoryCorpus is the memory corpus adapter (stateless).
type MemoryCorpus struct{}

// NewMemoryCorpus returns a ready-to-use MemoryCorpus.
func NewMemoryCorpus() *MemoryCorpus {
	return &MemoryCorpus{}
}

// Name returns the stable corpus identifier.
func (c *MemoryCorpus) Name() string {
	return memoryCorpusName
}
|
||||
|
||||
// BuildIngestDocuments converts loosely-typed memory ingest input into
// core.SourceDocument values ready for the pipeline.
//
// Behavior:
//  1. input may be a single item, a pointer, or a slice of either (see toMemoryItems);
//  2. items without a positive UserID are rejected outright;
//  3. items with blank content are silently skipped;
//  4. the document ID is "memory:<id>" for persisted rows, otherwise a
//     deterministic content-hash fallback scoped to the user.
func (c *MemoryCorpus) BuildIngestDocuments(_ context.Context, input any) ([]core.SourceDocument, error) {
	items, err := toMemoryItems(input)
	if err != nil {
		return nil, err
	}
	result := make([]core.SourceDocument, 0, len(items))
	for _, item := range items {
		if item.UserID <= 0 {
			return nil, errors.New("memory ingest item user_id is invalid")
		}
		text := strings.TrimSpace(item.Content)
		if text == "" {
			continue
		}
		docID := fmt.Sprintf("memory:%d", item.MemoryID)
		if item.MemoryID <= 0 {
			// Unpersisted memories get a fallback ID derived from the user
			// plus the normalized content hash, so re-ingest stays stable.
			docID = fmt.Sprintf("memory:uid:%d:%s", item.UserID, hashLikeText(text))
		}
		metadata := map[string]any{
			"user_id":           item.UserID,
			"conversation_id":   strings.TrimSpace(item.ConversationID),
			"assistant_id":      strings.TrimSpace(item.AssistantID),
			"run_id":            strings.TrimSpace(item.RunID),
			"memory_type":       strings.TrimSpace(strings.ToLower(item.MemoryType)),
			"title":             strings.TrimSpace(item.Title),
			"confidence":        item.Confidence,
			"importance":        item.Importance,
			"sensitivity_level": item.SensitivityLevel,
			"is_explicit":       item.IsExplicit,
			"status":            strings.TrimSpace(item.Status),
		}
		if item.TTLAt != nil {
			metadata["ttl_at"] = item.TTLAt.Format(time.RFC3339)
		}
		createdAt := time.Now()
		if item.CreatedAt != nil {
			createdAt = *item.CreatedAt
		}
		result = append(result, core.SourceDocument{
			ID:        docID,
			Text:      text,
			Title:     strings.TrimSpace(item.Title),
			Metadata:  metadata,
			CreatedAt: createdAt,
		})
	}
	return result, nil
}
|
||||
|
||||
func (c *MemoryCorpus) BuildRetrieveFilter(_ context.Context, req any) (map[string]any, error) {
|
||||
input, ok := req.(MemoryRetrieveInput)
|
||||
if !ok {
|
||||
if ptr, isPtr := req.(*MemoryRetrieveInput); isPtr && ptr != nil {
|
||||
input = *ptr
|
||||
} else if req == nil {
|
||||
return nil, errors.New("memory retrieve input is nil")
|
||||
} else {
|
||||
return nil, errors.New("invalid memory retrieve input")
|
||||
}
|
||||
}
|
||||
if input.UserID <= 0 {
|
||||
return nil, errors.New("memory retrieve user_id is invalid")
|
||||
}
|
||||
filter := map[string]any{
|
||||
"user_id": input.UserID,
|
||||
}
|
||||
if v := strings.TrimSpace(input.ConversationID); v != "" {
|
||||
filter["conversation_id"] = v
|
||||
}
|
||||
if v := strings.TrimSpace(input.AssistantID); v != "" {
|
||||
filter["assistant_id"] = v
|
||||
}
|
||||
if v := strings.TrimSpace(input.RunID); v != "" {
|
||||
filter["run_id"] = v
|
||||
}
|
||||
if v := strings.TrimSpace(strings.ToLower(input.MemoryType)); v != "" {
|
||||
filter["memory_type"] = v
|
||||
}
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
func toMemoryItems(input any) ([]MemoryIngestItem, error) {
|
||||
switch value := input.(type) {
|
||||
case MemoryIngestItem:
|
||||
return []MemoryIngestItem{value}, nil
|
||||
case *MemoryIngestItem:
|
||||
if value == nil {
|
||||
return nil, errors.New("memory ingest item is nil")
|
||||
}
|
||||
return []MemoryIngestItem{*value}, nil
|
||||
case []MemoryIngestItem:
|
||||
return value, nil
|
||||
case []*MemoryIngestItem:
|
||||
items := make([]MemoryIngestItem, 0, len(value))
|
||||
for _, ptr := range value {
|
||||
if ptr == nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, *ptr)
|
||||
}
|
||||
return items, nil
|
||||
default:
|
||||
return nil, errors.New("invalid memory ingest input")
|
||||
}
|
||||
}
|
||||
163
backend/services/rag/corpus/web_corpus.go
Normal file
163
backend/services/rag/corpus/web_corpus.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package corpus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
const webCorpusName = "web"

// WebIngestItem is a single web-corpus ingest item.
type WebIngestItem struct {
	URL         string     // page URL; required and part of the document identity hash
	Title       string
	Content     string     // extracted main body text
	Snippet     string     // search-result snippet
	Domain      string
	QueryID     string     // search query this page was fetched for
	SessionID   string
	PublishedAt *time.Time // optional publication time; stored as RFC3339 metadata when set
	FetchedAt   *time.Time // optional fetch time; doubles as the document CreatedAt
	SourceRank  int        // rank of this result in the original search listing
}

// WebRetrieveInput is the filter input for web retrieval.
type WebRetrieveInput struct {
	QueryID   string // at least one of QueryID/SessionID must be set
	SessionID string
	Domain    string // optional additional narrowing
}

// WebCorpus adapts fetched web pages to the generic corpus contract.
type WebCorpus struct{}

// NewWebCorpus constructs the stateless web-corpus adapter.
func NewWebCorpus() *WebCorpus {
	return &WebCorpus{}
}

// Name returns the stable corpus identifier ("web").
func (c *WebCorpus) Name() string {
	return webCorpusName
}
|
||||
|
||||
// BuildIngestDocuments converts loosely-typed web ingest input into
// core.SourceDocument values.
//
// Behavior:
//  1. input may be a single item, a pointer, or a slice of either (see toWebItems);
//  2. items without a URL are rejected; items whose combined text is blank
//     are silently skipped;
//  3. the document ID hashes url + text, so re-ingesting identical content
//     yields the same document.
func (c *WebCorpus) BuildIngestDocuments(_ context.Context, input any) ([]core.SourceDocument, error) {
	items, err := toWebItems(input)
	if err != nil {
		return nil, err
	}
	result := make([]core.SourceDocument, 0, len(items))
	for _, item := range items {
		url := strings.TrimSpace(item.URL)
		if url == "" {
			return nil, errors.New("web ingest item url is empty")
		}

		mainText := buildWebText(item)
		if strings.TrimSpace(mainText) == "" {
			continue
		}

		docID := fmt.Sprintf("web:%s", hashLikeText(url+"|"+mainText))
		metadata := map[string]any{
			"url":         url,
			"domain":      strings.TrimSpace(item.Domain),
			"query_id":    strings.TrimSpace(item.QueryID),
			"session_id":  strings.TrimSpace(item.SessionID),
			"source_rank": item.SourceRank,
		}
		if item.PublishedAt != nil {
			metadata["published_at"] = item.PublishedAt.Format(time.RFC3339)
		}
		if item.FetchedAt != nil {
			metadata["fetched_at"] = item.FetchedAt.Format(time.RFC3339)
		}

		// Prefer the fetch time as the document timestamp; fall back to now.
		createdAt := time.Now()
		if item.FetchedAt != nil {
			createdAt = *item.FetchedAt
		}
		result = append(result, core.SourceDocument{
			ID:        docID,
			Text:      mainText,
			Title:     strings.TrimSpace(item.Title),
			Metadata:  metadata,
			CreatedAt: createdAt,
		})
	}
	return result, nil
}
|
||||
|
||||
func (c *WebCorpus) BuildRetrieveFilter(_ context.Context, req any) (map[string]any, error) {
|
||||
input, ok := req.(WebRetrieveInput)
|
||||
if !ok {
|
||||
if ptr, isPtr := req.(*WebRetrieveInput); isPtr && ptr != nil {
|
||||
input = *ptr
|
||||
} else if req == nil {
|
||||
return nil, errors.New("web retrieve input is nil")
|
||||
} else {
|
||||
return nil, errors.New("invalid web retrieve input")
|
||||
}
|
||||
}
|
||||
|
||||
// 1. query_id/session_id 至少要有一个,避免跨问题串数据。
|
||||
queryID := strings.TrimSpace(input.QueryID)
|
||||
sessionID := strings.TrimSpace(input.SessionID)
|
||||
if queryID == "" && sessionID == "" {
|
||||
return nil, errors.New("web retrieve filter requires query_id or session_id")
|
||||
}
|
||||
|
||||
filter := map[string]any{}
|
||||
if queryID != "" {
|
||||
filter["query_id"] = queryID
|
||||
}
|
||||
if sessionID != "" {
|
||||
filter["session_id"] = sessionID
|
||||
}
|
||||
if domain := strings.TrimSpace(input.Domain); domain != "" {
|
||||
filter["domain"] = domain
|
||||
}
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
func toWebItems(input any) ([]WebIngestItem, error) {
|
||||
switch value := input.(type) {
|
||||
case WebIngestItem:
|
||||
return []WebIngestItem{value}, nil
|
||||
case *WebIngestItem:
|
||||
if value == nil {
|
||||
return nil, errors.New("web ingest item is nil")
|
||||
}
|
||||
return []WebIngestItem{*value}, nil
|
||||
case []WebIngestItem:
|
||||
return value, nil
|
||||
case []*WebIngestItem:
|
||||
items := make([]WebIngestItem, 0, len(value))
|
||||
for _, ptr := range value {
|
||||
if ptr == nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, *ptr)
|
||||
}
|
||||
return items, nil
|
||||
default:
|
||||
return nil, errors.New("invalid web ingest input")
|
||||
}
|
||||
}
|
||||
|
||||
func buildWebText(item WebIngestItem) string {
|
||||
parts := make([]string, 0, 3)
|
||||
if title := strings.TrimSpace(item.Title); title != "" {
|
||||
parts = append(parts, title)
|
||||
}
|
||||
if snippet := strings.TrimSpace(item.Snippet); snippet != "" {
|
||||
parts = append(parts, snippet)
|
||||
}
|
||||
if content := strings.TrimSpace(item.Content); content != "" {
|
||||
parts = append(parts, content)
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
208
backend/services/rag/embed/eino_embedder.go
Normal file
208
backend/services/rag/embed/eino_embedder.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
openaiembedding "github.com/cloudwego/eino-ext/libs/acl/openai"
|
||||
einoembedding "github.com/cloudwego/eino/components/embedding"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
||||
arkmodel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
// EinoConfig carries the runtime parameters for the Eino embedding client.
type EinoConfig struct {
	APIKey    string // Ark API key; required
	BaseURL   string // service root URL; common endpoint-suffix misconfigurations are normalized away
	Model     string // embedding model name; required
	TimeoutMS int    // per-call timeout in milliseconds; <= 0 falls back to 1200ms
	Dimension int    // requested vector dimension; <= 0 lets the provider decide
}

// EinoEmbedder is the Eino-backed embedding adapter.
//
// Notes:
//  1. It exposes a uniform [][]float32 result to the RAG layer, hiding the
//     differences between the underlying SDKs.
//  2. Text embedding keeps using the current stable OpenAI-compatible route
//     so unrelated models are unaffected.
//  3. Multimodal embedding models go through Ark's native
//     `/embeddings/multimodal` API, because vision models are incompatible
//     with the standard `/embeddings` endpoint.
type EinoEmbedder struct {
	textClient       einoembedding.Embedder // set for text models only
	multimodalClient *arkruntime.Client     // set for doubao-embedding-vision-* models only
	model            string
	timeout          time.Duration
	dimension        int
}

// NewEinoEmbedder validates cfg and builds either a text or a multimodal
// embedder depending on the model family.
func NewEinoEmbedder(ctx context.Context, cfg EinoConfig) (*EinoEmbedder, error) {
	if strings.TrimSpace(cfg.APIKey) == "" {
		return nil, errors.New("eino embedder api key is empty")
	}
	if strings.TrimSpace(cfg.Model) == "" {
		return nil, errors.New("eino embedder model is empty")
	}

	timeout := 1200 * time.Millisecond
	if cfg.TimeoutMS > 0 {
		timeout = time.Duration(cfg.TimeoutMS) * time.Millisecond
	}

	baseURL := normalizeEmbeddingBaseURL(cfg.BaseURL)
	model := strings.TrimSpace(cfg.Model)
	httpClient := &http.Client{Timeout: timeout}

	// 1. Models like `doubao-embedding-vision-*` do not support the standard
	//    `/embeddings` endpoint.
	// 2. They are routed directly to Ark's native multimodal embedding API,
	//    avoiding any incorrect endpoint concatenation.
	// 3. The text route is kept so ordinary text embedding models retain
	//    their existing behavior.
	if isMultimodalEmbeddingModel(model) {
		arkOptions := []arkruntime.ConfigOption{
			arkruntime.WithHTTPClient(httpClient),
		}
		if baseURL != "" {
			arkOptions = append(arkOptions, arkruntime.WithBaseUrl(baseURL))
		}

		return &EinoEmbedder{
			multimodalClient: arkruntime.NewClientWithApiKey(
				strings.TrimSpace(cfg.APIKey),
				arkOptions...,
			),
			model:     model,
			timeout:   timeout,
			dimension: cfg.Dimension,
		}, nil
	}

	clientCfg := &openaiembedding.EmbeddingConfig{
		APIKey:     strings.TrimSpace(cfg.APIKey),
		BaseURL:    baseURL,
		Model:      model,
		HTTPClient: httpClient,
	}
	if cfg.Dimension > 0 {
		clientCfg.Dimensions = &cfg.Dimension
	}

	client, err := openaiembedding.NewEmbeddingClient(ctx, clientCfg)
	if err != nil {
		return nil, err
	}

	return &EinoEmbedder{
		textClient: client,
		model:      model,
		timeout:    timeout,
		dimension:  cfg.Dimension,
	}, nil
}

// Embed turns texts into vectors via the configured backend. The trailing
// string parameter is accepted for interface compatibility but unused here.
func (e *EinoEmbedder) Embed(ctx context.Context, texts []string, _ string) (result [][]float32, err error) {
	if e == nil {
		return nil, errors.New("eino embedder is not initialized")
	}
	if len(texts) == 0 {
		return nil, nil
	}

	callCtx := ctx
	cancel := func() {}
	if e.timeout > 0 {
		callCtx, cancel = context.WithTimeout(ctx, e.timeout)
	}
	defer cancel()

	// 1. A panic inside a third-party SDK must not leak into the main RAG path.
	// 2. Recover at the model-call boundary and turn the panic into an error
	//    so upper layers can degrade gracefully.
	// 3. This keeps the memory write path and the agent reply path from being
	//    crashed by a failed vector sync.
	defer func() {
		if recovered := recover(); recovered != nil {
			err = fmt.Errorf("eino embedder panic recovered: %v", recovered)
			result = nil
		}
	}()

	if e.multimodalClient != nil {
		return e.embedTextsWithMultimodalAPI(callCtx, texts)
	}
	if e.textClient == nil {
		return nil, errors.New("eino embedder client is not initialized")
	}

	vectors, err := e.textClient.EmbedStrings(callCtx, texts, einoembedding.WithModel(e.model))
	if err != nil {
		return nil, err
	}

	// The SDK returns float64 vectors; convert to the float32 contract.
	result = make([][]float32, 0, len(vectors))
	for _, vector := range vectors {
		converted := make([]float32, len(vector))
		for i, value := range vector {
			converted[i] = float32(value)
		}
		result = append(result, converted)
	}
	return result, nil
}

// embedTextsWithMultimodalAPI embeds each text through Ark's native
// multimodal endpoint, one request per text.
func (e *EinoEmbedder) embedTextsWithMultimodalAPI(ctx context.Context, texts []string) ([][]float32, error) {
	if e.multimodalClient == nil {
		return nil, errors.New("eino multimodal embedder client is not initialized")
	}

	vectors := make([][]float32, 0, len(texts))
	for _, text := range texts {
		text := text // per-iteration copy; the request holds a pointer to it
		req := arkmodel.MultiModalEmbeddingRequest{
			Model: e.model,
			Input: []arkmodel.MultimodalEmbeddingInput{
				{
					Type: arkmodel.MultiModalEmbeddingInputTypeText,
					Text: &text,
				},
			},
		}
		if e.dimension > 0 {
			req.Dimensions = &e.dimension
		}

		// 1. Ark's multimodal embedding request body models "one sample made
		//    of multiple parts".
		// 2. Only text is passed here, so each text is sent as its own
		//    request to avoid wrongly fusing several texts into one
		//    multimodal sample.
		// 3. If true batched multimodal embedding is ever needed, add a
		//    dedicated batch API instead of quietly changing semantics here.
		resp, err := e.multimodalClient.CreateMultiModalEmbeddings(ctx, req)
		if err != nil {
			return nil, err
		}

		converted := make([]float32, len(resp.Data.Embedding))
		copy(converted, resp.Data.Embedding)
		vectors = append(vectors, converted)
	}
	return vectors, nil
}
|
||||
|
||||
// isMultimodalEmbeddingModel reports whether the model name belongs to the
// Ark multimodal (vision) embedding family, which requires the dedicated
// multimodal endpoint instead of the standard /embeddings API.
func isMultimodalEmbeddingModel(model string) bool {
	normalized := strings.TrimSpace(model)
	normalized = strings.ToLower(normalized)
	return strings.HasPrefix(normalized, "doubao-embedding-vision-")
}
|
||||
|
||||
// normalizeEmbeddingBaseURL sanitizes a configured embedding base URL.
//
// The config should contain the Ark service root (e.g. ".../api/v3"), not a
// concrete embedding endpoint. Two common misconfigurations are tolerated:
// a trailing "/embeddings" or "/embeddings/multimodal" segment is stripped
// (detected case-insensitively, preserving the casing of the remaining URL)
// so each SDK can append its own correct suffix without duplication.
//
// Blank input yields "".
func normalizeEmbeddingBaseURL(raw string) string {
	baseURL := strings.TrimRight(strings.TrimSpace(raw), "/")
	if baseURL == "" {
		return ""
	}

	lowerBaseURL := strings.ToLower(baseURL)
	// Check the longer suffix first; the original expressed this trim as a
	// convoluted TrimSuffix over a slice of itself — plain slicing is the
	// equivalent, clearer form.
	for _, suffix := range []string{"/embeddings/multimodal", "/embeddings"} {
		if strings.HasSuffix(lowerBaseURL, suffix) {
			return baseURL[:len(baseURL)-len(suffix)]
		}
	}
	return baseURL
}
|
||||
46
backend/services/rag/embed/mock_embedder.go
Normal file
46
backend/services/rag/embed/mock_embedder.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const defaultDim = 16

// MockEmbedder is a deterministic, dependency-free embedding stand-in.
//
// Notes:
//  1. It exists to exercise the pipeline during development and carries no
//     real semantic signal;
//  2. It can later be swapped for the Eino embedder without any interface
//     change.
type MockEmbedder struct {
	dim int
}

// NewMockEmbedder builds a mock embedder emitting vectors of the given
// dimension; non-positive dimensions fall back to defaultDim.
func NewMockEmbedder(dim int) *MockEmbedder {
	if dim > 0 {
		return &MockEmbedder{dim: dim}
	}
	return &MockEmbedder{dim: defaultDim}
}

// Embed maps every input text to a pseudo-random but deterministic vector.
// Context and action are ignored; the call never fails.
func (e *MockEmbedder) Embed(_ context.Context, texts []string, _ string) ([][]float32, error) {
	vectors := make([][]float32, 0, len(texts))
	for _, text := range texts {
		vectors = append(vectors, e.embedOne(text))
	}
	return vectors, nil
}

// embedOne hashes the normalized text and expands the SHA-256 digest into
// e.dim float components, each quantized into [0, 1).
func (e *MockEmbedder) embedOne(text string) []float32 {
	digest := sha256.Sum256([]byte(strings.ToLower(strings.TrimSpace(text))))
	vec := make([]float32, e.dim)
	for i := range vec {
		start := (i * 4) % len(digest)
		word := binary.BigEndian.Uint32(digest[start : start+4])
		vec[i] = float32(word%1000) / 1000
	}
	return vec
}
|
||||
142
backend/services/rag/factory.go
Normal file
142
backend/services/rag/factory.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
ragchunk "github.com/LoveLosita/smartflow/backend/services/rag/chunk"
|
||||
ragconfig "github.com/LoveLosita/smartflow/backend/services/rag/config"
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
ragembed "github.com/LoveLosita/smartflow/backend/services/rag/embed"
|
||||
ragrerank "github.com/LoveLosita/smartflow/backend/services/rag/rerank"
|
||||
ragstore "github.com/LoveLosita/smartflow/backend/services/rag/store"
|
||||
)
|
||||
|
||||
// FactoryDeps describes the optional dependencies of the Runtime factory.
//
// Notes:
//  1. Logger is only the default sink while the project has no unified
//     logging system yet;
//  2. Observer is the official observation slot and can later be replaced by
//     a project-level logger/metrics/tracing adapter;
//  3. Business code still only receives a Runtime and never touches the
//     low-level assembly details.
type FactoryDeps struct {
	Logger   *log.Logger
	Observer Observer
}

// NewRuntimeFromConfig assembles the RAG Runtime from configuration.
//
// Design notes:
//  1. Every implementation choice is funneled through this factory; business
//     code must not construct stores/embedders/rerankers itself;
//  2. When new providers are added, extend this factory instead of spreading
//     selection logic into business modules;
//  3. Observation is injected here as well, so runtime/store/pipeline never
//     log on their own.
func NewRuntimeFromConfig(ctx context.Context, cfg ragconfig.Config, deps FactoryDeps) (Runtime, error) {
	logger, observer := normalizeFactoryDeps(deps)

	embedder, err := buildEmbedder(ctx, cfg)
	if err != nil {
		return nil, err
	}

	store, err := buildStore(cfg, logger, observer)
	if err != nil {
		return nil, err
	}

	reranker, err := buildReranker(cfg, observer)
	if err != nil {
		return nil, err
	}

	pipeline := core.NewPipeline(ragchunk.NewTextChunker(), embedder, store, reranker)
	pipeline.SetLogger(logger)
	pipeline.SetObserver(observer)
	return newRuntime(cfg, pipeline, observer), nil
}

// normalizeFactoryDeps fills in the default logger and observer when the
// caller supplied none.
func normalizeFactoryDeps(deps FactoryDeps) (*log.Logger, Observer) {
	logger := deps.Logger
	if logger == nil {
		logger = log.Default()
	}
	observer := deps.Observer
	if observer == nil {
		observer = NewLoggerObserver(logger)
	}
	return logger, observer
}

// buildEmbedder selects the embedding implementation by provider name:
// empty/"mock" -> MockEmbedder, "eino" -> EinoEmbedder.
func buildEmbedder(ctx context.Context, cfg ragconfig.Config) (core.Embedder, error) {
	switch strings.ToLower(strings.TrimSpace(cfg.EmbedProvider)) {
	case "", "mock":
		return ragembed.NewMockEmbedder(cfg.EmbedDimension), nil
	case "eino":
		// 1. RAG embedding shares the key source of the regular LLM route and
		//    reads ARK_API_KEY directly;
		// 2. This avoids an extra "env-name config -> read env" indirection
		//    and keeps configuration from forking;
		// 3. If separate embedding credentials are ever needed, design an
		//    explicit field instead of implicitly passing env names around.
		apiKey := strings.TrimSpace(os.Getenv("ARK_API_KEY"))
		if apiKey == "" {
			return nil, fmt.Errorf("rag embed api key is empty: env=%s", "ARK_API_KEY")
		}
		return ragembed.NewEinoEmbedder(ctx, ragembed.EinoConfig{
			APIKey:    apiKey,
			BaseURL:   cfg.EmbedBaseURL,
			Model:     cfg.EmbedModel,
			TimeoutMS: cfg.EmbedTimeoutMS,
			Dimension: cfg.EmbedDimension,
		})
	default:
		return nil, fmt.Errorf("unsupported rag embed provider: %s", cfg.EmbedProvider)
	}
}

// buildStore selects the vector store backend: empty/"inmemory" or "milvus".
func buildStore(cfg ragconfig.Config, logger *log.Logger, observer Observer) (core.VectorStore, error) {
	switch strings.ToLower(strings.TrimSpace(cfg.Store)) {
	case "", "inmemory":
		return ragstore.NewInMemoryVectorStore(), nil
	case "milvus":
		return ragstore.NewMilvusStore(ragstore.MilvusConfig{
			Address:          cfg.MilvusAddress,
			Token:            cfg.MilvusToken,
			DBName:           cfg.MilvusDBName,
			CollectionName:   cfg.MilvusCollectionName,
			RequestTimeoutMS: cfg.MilvusRequestTimeoutMS,
			Dimension:        cfg.EmbedDimension,
			MetricType:       cfg.MilvusMetricType,
			Logger:           logger,
			Observer:         observer,
		})
	default:
		return nil, fmt.Errorf("unsupported rag store: %s", cfg.Store)
	}
}

// buildReranker selects the reranker. The "eino" provider is not implemented
// yet: it falls back to the noop reranker and emits a warn-level observation
// event so the degradation is visible.
func buildReranker(cfg ragconfig.Config, observer Observer) (core.Reranker, error) {
	if !cfg.RerankerEnabled {
		return ragrerank.NewNoopReranker(), nil
	}

	switch strings.ToLower(strings.TrimSpace(cfg.RerankerProvider)) {
	case "", "noop":
		return ragrerank.NewNoopReranker(), nil
	case "eino":
		if observer != nil {
			observer.Observe(context.Background(), ObserveEvent{
				Level:     ObserveLevelWarn,
				Component: "factory",
				Operation: "reranker_fallback",
				Fields: map[string]any{
					"provider":        "eino",
					"status":          "fallback",
					"fallback_target": "noop",
					"reason":          "reranker_not_implemented",
				},
			})
		}
		return ragrerank.NewNoopReranker(), nil
	default:
		return nil, fmt.Errorf("unsupported rag reranker provider: %s", cfg.RerankerProvider)
	}
}
|
||||
32
backend/services/rag/observe.go
Normal file
32
backend/services/rag/observe.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// ObserveLevel re-exports the unified observation level so the bootstrap
// layer does not have to depend on core details directly.
type ObserveLevel = core.ObserveLevel

// Re-exported observation levels.
const (
	ObserveLevelInfo  = core.ObserveLevelInfo
	ObserveLevelWarn  = core.ObserveLevelWarn
	ObserveLevelError = core.ObserveLevelError
)

// ObserveEvent re-exports the unified observation event type.
type ObserveEvent = core.ObserveEvent

// Observer re-exports the unified observation interface.
type Observer = core.Observer

// NewNopObserver returns the no-op observer implementation.
func NewNopObserver() Observer {
	return core.NewNopObserver()
}

// NewLoggerObserver returns an observer adapter backed by the standard
// library logger.
func NewLoggerObserver(logger *log.Logger) Observer {
	return core.NewLoggerObserver(logger)
}
|
||||
23
backend/services/rag/rag.go
Normal file
23
backend/services/rag/rag.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/chunk"
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/embed"
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/rerank"
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/store"
|
||||
)
|
||||
|
||||
// NewDefaultPipeline builds a RAG pipeline that runs out of the box.
//
// Current policy:
//  1. Defaults to the local MockEmbedder + InMemoryStore, so the pipeline
//     works with zero external dependencies;
//  2. Switching to Milvus / Eino later only swaps dependencies without
//     changing how business code calls the pipeline.
func NewDefaultPipeline() *core.Pipeline {
	return core.NewPipeline(
		chunk.NewTextChunker(),
		embed.NewMockEmbedder(16),
		store.NewInMemoryVectorStore(),
		rerank.NewNoopReranker(),
	)
}
|
||||
19
backend/services/rag/rerank/eino_reranker.go
Normal file
19
backend/services/rag/rerank/eino_reranker.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// EinoReranker is a placeholder for the Eino-backed reranker.
type EinoReranker struct{}

// NewEinoReranker constructs the placeholder reranker.
func NewEinoReranker() *EinoReranker {
	return &EinoReranker{}
}

// Rerank always fails: the Eino reranker is not implemented yet. The factory
// currently falls back to the noop reranker rather than wiring this type in.
func (r *EinoReranker) Rerank(_ context.Context, _ string, _ []core.ScoredChunk, _ int) ([]core.ScoredChunk, error) {
	return nil, errors.New("eino reranker is not implemented yet")
}
|
||||
30
backend/services/rag/rerank/noop_reranker.go
Normal file
30
backend/services/rag/rerank/noop_reranker.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// NoopReranker 是默认重排器(仅按原 score 排序)。
|
||||
type NoopReranker struct{}
|
||||
|
||||
func NewNoopReranker() *NoopReranker {
|
||||
return &NoopReranker{}
|
||||
}
|
||||
|
||||
func (r *NoopReranker) Rerank(_ context.Context, _ string, candidates []core.ScoredChunk, topK int) ([]core.ScoredChunk, error) {
|
||||
if len(candidates) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
sorted := make([]core.ScoredChunk, len(candidates))
|
||||
copy(sorted, candidates)
|
||||
sort.SliceStable(sorted, func(i, j int) bool {
|
||||
return sorted[i].Score > sorted[j].Score
|
||||
})
|
||||
if topK <= 0 || topK >= len(sorted) {
|
||||
return sorted, nil
|
||||
}
|
||||
return sorted[:topK], nil
|
||||
}
|
||||
98
backend/services/rag/retrieve/vector_retriever.go
Normal file
98
backend/services/rag/retrieve/vector_retriever.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package retrieve
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// VectorRetriever is the generic retriever: it embeds the query text and
// then runs a vector search against the store.
type VectorRetriever struct {
	embedder core.Embedder    // turns the query text into a vector
	store    core.VectorStore // executes the nearest-neighbour search
}

// NewVectorRetriever wires an embedder and a vector store together.
func NewVectorRetriever(embedder core.Embedder, store core.VectorStore) *VectorRetriever {
	return &VectorRetriever{
		embedder: embedder,
		store:    store,
	}
}
|
||||
|
||||
// Retrieve embeds req.Query, searches the store, and converts the hits into
// ScoredChunk values.
//
// Behavior:
//  1. panics from dependencies are recovered and surfaced as errors (named
//     returns make the deferred recover effective);
//  2. TopK defaults to 8 and Action to "search" when unset;
//  3. hits scoring below req.Threshold are dropped here so stores can stay
//     threshold-agnostic.
func (r *VectorRetriever) Retrieve(ctx context.Context, req core.RetrieveRequest) (result []core.ScoredChunk, err error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		err = fmt.Errorf("vector retriever panic recovered: %v", recovered)
	}()

	if r == nil || r.embedder == nil || r.store == nil {
		return nil, core.ErrNilDependency
	}
	query := strings.TrimSpace(req.Query)
	if query == "" {
		return nil, core.ErrInvalidQuery
	}
	topK := req.TopK
	if topK <= 0 {
		topK = 8
	}
	action := strings.TrimSpace(req.Action)
	if action == "" {
		action = "search"
	}

	vectors, err := r.embedder.Embed(ctx, []string{query}, action)
	if err != nil {
		return nil, err
	}
	if len(vectors) != 1 {
		return nil, fmt.Errorf("embedding query length mismatch: %d", len(vectors))
	}

	rows, err := r.store.Search(ctx, core.VectorSearchRequest{
		QueryVector: vectors[0],
		TopK:        topK,
		Filter:      req.Filter,
	})
	if err != nil {
		return nil, err
	}

	result = make([]core.ScoredChunk, 0, len(rows))
	for _, row := range rows {
		if row.Score < req.Threshold {
			continue
		}
		result = append(result, core.ScoredChunk{
			ChunkID:    row.Row.ID,
			DocumentID: asString(row.Row.Metadata["document_id"]),
			Text:       row.Row.Text,
			Score:      row.Score,
			// Metadata is cloned so callers can mutate their copy safely.
			Metadata: cloneMap(row.Row.Metadata),
		})
	}
	return result, nil
}
|
||||
|
||||
// cloneMap returns a shallow copy of src so callers can mutate the result
// without touching the stored metadata. A nil or empty source still yields
// a fresh, non-nil empty map.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for key, value := range src {
		dst[key] = value
	}
	return dst
}
|
||||
|
||||
// asString renders an arbitrary metadata value as text: nil becomes the
// empty string, everything else goes through fmt's default %v formatting.
func asString(v any) string {
	if v == nil {
		return ""
	}
	return fmt.Sprint(v)
}
|
||||
434
backend/services/rag/runtime.go
Normal file
434
backend/services/rag/runtime.go
Normal file
@@ -0,0 +1,434 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
ragconfig "github.com/LoveLosita/smartflow/backend/services/rag/config"
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/corpus"
|
||||
)
|
||||
|
||||
// runtime is the concrete Runtime implementation: a thin façade that maps
// the public memory/web requests onto the shared pipeline through the
// corpus adapters.
type runtime struct {
	cfg          ragconfig.Config
	pipeline     *core.Pipeline
	memoryCorpus *corpus.MemoryCorpus
	webCorpus    *corpus.WebCorpus
	observer     Observer
}

// newRuntime wires the runtime together; a nil observer is replaced by the
// no-op implementation so call sites never need nil checks.
func newRuntime(cfg ragconfig.Config, pipeline *core.Pipeline, observer Observer) Runtime {
	if observer == nil {
		observer = NewNopObserver()
	}
	return &runtime{
		cfg:          cfg,
		pipeline:     pipeline,
		memoryCorpus: corpus.NewMemoryCorpus(),
		webCorpus:    corpus.NewWebCorpus(),
		observer:     observer,
	}
}

// IngestMemory is the unified entry point for memory-corpus ingestion.
func (r *runtime) IngestMemory(ctx context.Context, req MemoryIngestRequest) (result *IngestResult, err error) {
	defer r.recoverPublicPanic(ctx, req.TraceID, "memory", normalizeAction(req.Action, "add"), "ingest", &err)

	// Copy the public request items into the corpus-level type so the corpus
	// package stays decoupled from the public API surface.
	items := make([]corpus.MemoryIngestItem, 0, len(req.Items))
	for _, item := range req.Items {
		items = append(items, corpus.MemoryIngestItem{
			MemoryID:         item.MemoryID,
			UserID:           item.UserID,
			ConversationID:   item.ConversationID,
			AssistantID:      item.AssistantID,
			RunID:            item.RunID,
			MemoryType:       item.MemoryType,
			Title:            item.Title,
			Content:          item.Content,
			Confidence:       item.Confidence,
			Importance:       item.Importance,
			SensitivityLevel: item.SensitivityLevel,
			IsExplicit:       item.IsExplicit,
			Status:           item.Status,
			TTLAt:            item.TTLAt,
			CreatedAt:        item.CreatedAt,
		})
	}
	return r.ingestWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, items, req.Action)
}

// RetrieveMemory is the unified entry point for memory-corpus retrieval.
func (r *runtime) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (result *RetrieveResult, err error) {
	defer r.recoverPublicPanic(ctx, req.TraceID, "memory", normalizeAction(req.Action, "search"), "retrieve", &err)

	corpusInput := corpus.MemoryRetrieveInput{
		UserID:         req.UserID,
		ConversationID: req.ConversationID,
		AssistantID:    req.AssistantID,
		RunID:          req.RunID,
	}
	// A single requested type can be pushed down as an equality filter;
	// multiple types are filtered after retrieval (see below).
	if len(req.MemoryTypes) == 1 {
		corpusInput.MemoryType = req.MemoryTypes[0]
	}

	result, err = r.retrieveWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, core.RetrieveRequest{
		Query:       req.Query,
		TopK:        normalizeTopK(req.TopK, r.cfg.TopK),
		Threshold:   normalizeThreshold(req.Threshold, r.cfg.Threshold),
		Action:      normalizeAction(req.Action, "search"),
		CorpusInput: corpusInput,
	})
	if err != nil {
		return nil, err
	}
	if len(req.MemoryTypes) <= 1 {
		return result, nil
	}

	// 1. The underlying filters are equality-based for now, so the Runtime
	//    does the secondary multi-type filtering itself;
	// 2. This keeps the "memory_type in (...)" implementation detail from
	//    leaking into every Store;
	// 3. Once the low-level filter capabilities are unified, consider pushing
	//    this logic further down.
	allowed := make(map[string]struct{}, len(req.MemoryTypes))
	for _, item := range req.MemoryTypes {
		value := strings.TrimSpace(strings.ToLower(item))
		if value == "" {
			continue
		}
		allowed[value] = struct{}{}
	}

	filtered := make([]RetrieveHit, 0, len(result.Items))
	for _, item := range result.Items {
		memoryType := strings.TrimSpace(strings.ToLower(asString(item.Metadata["memory_type"])))
		if len(allowed) > 0 {
			if _, ok := allowed[memoryType]; !ok {
				continue
			}
		}
		filtered = append(filtered, item)
	}
	result.Items = filtered
	if req.TopK > 0 && len(result.Items) > req.TopK {
		result.Items = result.Items[:req.TopK]
	}
	return result, nil
}

// DeleteMemory removes the given document vectors from the memory corpus.
// A nil runtime, missing pipeline, or empty ID list is treated as a no-op.
func (r *runtime) DeleteMemory(ctx context.Context, documentIDs []string) (err error) {
	defer r.recoverPublicPanic(ctx, "", "memory", "delete", "delete", &err)

	if r == nil || r.pipeline == nil || len(documentIDs) == 0 {
		return nil
	}
	return r.pipeline.Delete(ctx, documentIDs)
}
|
||||
|
||||
// IngestWeb 统一承接网页语料入库。
|
||||
func (r *runtime) IngestWeb(ctx context.Context, req WebIngestRequest) (result *IngestResult, err error) {
|
||||
defer r.recoverPublicPanic(ctx, req.TraceID, "web", normalizeAction(req.Action, "add"), "ingest", &err)
|
||||
|
||||
items := make([]corpus.WebIngestItem, 0, len(req.Items))
|
||||
for _, item := range req.Items {
|
||||
items = append(items, corpus.WebIngestItem{
|
||||
URL: item.URL,
|
||||
Title: item.Title,
|
||||
Content: item.Content,
|
||||
Snippet: item.Snippet,
|
||||
Domain: item.Domain,
|
||||
QueryID: item.QueryID,
|
||||
SessionID: item.SessionID,
|
||||
PublishedAt: item.PublishedAt,
|
||||
FetchedAt: item.FetchedAt,
|
||||
SourceRank: item.SourceRank,
|
||||
})
|
||||
}
|
||||
return r.ingestWithCorpus(ctx, req.TraceID, "web", r.webCorpus, items, req.Action)
|
||||
}
|
||||
|
||||
// RetrieveWeb 统一承接网页语料检索。
|
||||
func (r *runtime) RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (result *RetrieveResult, err error) {
|
||||
defer r.recoverPublicPanic(ctx, req.TraceID, "web", normalizeAction(req.Action, "search"), "retrieve", &err)
|
||||
|
||||
return r.retrieveWithCorpus(ctx, req.TraceID, "web", r.webCorpus, core.RetrieveRequest{
|
||||
Query: req.Query,
|
||||
TopK: normalizeTopK(req.TopK, r.cfg.TopK),
|
||||
Threshold: normalizeThreshold(req.Threshold, r.cfg.Threshold),
|
||||
Action: normalizeAction(req.Action, "search"),
|
||||
CorpusInput: corpus.WebRetrieveInput{
|
||||
QueryID: req.QueryID,
|
||||
SessionID: req.SessionID,
|
||||
Domain: req.Domain,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (r *runtime) ingestWithCorpus(
|
||||
ctx context.Context,
|
||||
traceID string,
|
||||
corpusName string,
|
||||
adapter core.CorpusAdapter,
|
||||
input any,
|
||||
action string,
|
||||
) (*IngestResult, error) {
|
||||
start := time.Now()
|
||||
if r == nil || r.pipeline == nil || adapter == nil {
|
||||
return nil, core.ErrNilDependency
|
||||
}
|
||||
|
||||
action = normalizeAction(action, "add")
|
||||
observeCtx := newObserveContext(ctx, traceID, corpusName, action)
|
||||
|
||||
docs, err := adapter.BuildIngestDocuments(observeCtx, input)
|
||||
if err != nil {
|
||||
r.observe(observeCtx, ObserveEvent{
|
||||
Level: ObserveLevelError,
|
||||
Component: "runtime",
|
||||
Operation: "ingest",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"phase": "build_documents",
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
"input_count": estimateInputCount(input),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
docIDs := make([]string, 0, len(docs))
|
||||
for _, doc := range docs {
|
||||
docIDs = append(docIDs, doc.ID)
|
||||
}
|
||||
|
||||
result, err := r.pipeline.IngestDocuments(observeCtx, adapter.Name(), docs, core.IngestOption{
|
||||
Chunk: core.ChunkOption{
|
||||
ChunkSize: r.cfg.ChunkSize,
|
||||
ChunkOverlap: r.cfg.ChunkOverlap,
|
||||
},
|
||||
Action: action,
|
||||
})
|
||||
if err != nil {
|
||||
r.observe(observeCtx, ObserveEvent{
|
||||
Level: ObserveLevelError,
|
||||
Component: "runtime",
|
||||
Operation: "ingest",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"document_count": len(docs),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.observe(observeCtx, ObserveEvent{
|
||||
Level: ObserveLevelInfo,
|
||||
Component: "runtime",
|
||||
Operation: "ingest",
|
||||
Fields: map[string]any{
|
||||
"status": "success",
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"document_count": result.DocumentCount,
|
||||
"chunk_count": result.ChunkCount,
|
||||
},
|
||||
})
|
||||
return &IngestResult{
|
||||
DocumentCount: result.DocumentCount,
|
||||
ChunkCount: result.ChunkCount,
|
||||
DocumentIDs: docIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// retrieveWithCorpus runs the shared retrieval flow for one corpus adapter:
// normalize the action, apply the configured retrieve timeout, call the
// pipeline, and map pipeline hits onto the public RetrieveHit shape.
//
// Returns core.ErrNilDependency when the runtime, pipeline, or adapter is nil.
func (r *runtime) retrieveWithCorpus(
	ctx context.Context,
	traceID string,
	corpusName string,
	adapter core.CorpusAdapter,
	req core.RetrieveRequest,
) (*RetrieveResult, error) {
	start := time.Now()
	if r == nil || r.pipeline == nil || adapter == nil {
		return nil, core.ErrNilDependency
	}

	action := normalizeAction(req.Action, "search")
	req.Action = action
	observeCtx := newObserveContext(ctx, traceID, corpusName, action)

	// The timeout applies only to the pipeline call; observation events are
	// deliberately emitted on observeCtx so they are not lost to cancellation.
	timeoutCtx := observeCtx
	cancel := func() {}
	if r.cfg.RetrieveTimeoutMS > 0 {
		timeoutCtx, cancel = context.WithTimeout(observeCtx, time.Duration(r.cfg.RetrieveTimeoutMS)*time.Millisecond)
	}
	defer cancel()

	result, err := r.pipeline.Retrieve(timeoutCtx, adapter, req)
	if err != nil {
		r.observe(observeCtx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "runtime",
			Operation: "retrieve",
			Fields: map[string]any{
				"status":     "failed",
				"latency_ms": time.Since(start).Milliseconds(),
				"query_len":  len(strings.TrimSpace(req.Query)),
				"top_k":      req.TopK,
				"threshold":  req.Threshold,
				"error":      err,
				"error_code": core.ClassifyErrorCode(err),
			},
		})
		return nil, err
	}

	// Copy hits into the public shape; metadata is cloned so callers may
	// mutate it without affecting pipeline-owned state.
	items := make([]RetrieveHit, 0, len(result.Items))
	for _, item := range result.Items {
		items = append(items, RetrieveHit{
			ChunkID:    item.ChunkID,
			DocumentID: item.DocumentID,
			Text:       item.Text,
			Score:      item.Score,
			Metadata:   cloneMap(item.Metadata),
		})
	}

	r.observe(observeCtx, ObserveEvent{
		Level:     ObserveLevelInfo,
		Component: "runtime",
		Operation: "retrieve",
		Fields: map[string]any{
			"status":          "success",
			"latency_ms":      time.Since(start).Milliseconds(),
			"query_len":       len(strings.TrimSpace(req.Query)),
			"top_k":           req.TopK,
			"threshold":       req.Threshold,
			"raw_count":       result.RawCount,
			"hit_count":       len(result.Items),
			"fallback_used":   result.FallbackUsed,
			"fallback_reason": result.FallbackReason,
		},
	})
	return &RetrieveResult{
		Items:          items,
		RawCount:       result.RawCount,
		FallbackUsed:   result.FallbackUsed,
		FallbackReason: result.FallbackReason,
	}, nil
}
|
||||
|
||||
func (r *runtime) observe(ctx context.Context, event ObserveEvent) {
|
||||
if r == nil || r.observer == nil {
|
||||
return
|
||||
}
|
||||
r.observer.Observe(ctx, event)
|
||||
}
|
||||
|
||||
// recoverPublicPanic converts a panic escaping a public runtime method into an
// ordinary error written through errPtr, and emits one structured observation.
// It must be invoked directly via defer for recover() to take effect.
func (r *runtime) recoverPublicPanic(
	ctx context.Context,
	traceID string,
	corpusName string,
	action string,
	operation string,
	errPtr *error,
) {
	recovered := recover()
	if recovered == nil || errPtr == nil {
		return
	}

	// 1. runtime is the final method surface the RAG service exposes to
	//    business code; no lower-layer panic may leak into business goroutines.
	// 2. Convert the panic into an error and add one structured observation so
	//    the misbehaving dependency layer can still be tracked down.
	// 3. Keep the stack so the root cause stays diagnosable while the process
	//    keeps running; a bare "recovered" message alone is not reconstructible.
	panicErr := fmt.Errorf("rag runtime panic recovered: corpus=%s operation=%s panic=%v", corpusName, operation, recovered)
	*errPtr = panicErr

	observeCtx := newObserveContext(ctx, traceID, corpusName, action)
	r.observe(observeCtx, ObserveEvent{
		Level:     ObserveLevelError,
		Component: "runtime",
		Operation: operation + "_panic_recovered",
		Fields: map[string]any{
			"status":     "failed",
			"panic":      fmt.Sprintf("%v", recovered),
			"panic_type": fmt.Sprintf("%T", recovered),
			"error":      panicErr,
			"error_code": core.ClassifyErrorCode(panicErr),
			"stack":      string(debug.Stack()),
		},
	})
}
|
||||
|
||||
func newObserveContext(ctx context.Context, traceID string, corpusName string, action string) context.Context {
|
||||
fields := map[string]any{
|
||||
"corpus": corpusName,
|
||||
"action": action,
|
||||
}
|
||||
if traceID = strings.TrimSpace(traceID); traceID != "" {
|
||||
fields["trace_id"] = traceID
|
||||
}
|
||||
return core.WithObserveFields(ctx, fields)
|
||||
}
|
||||
|
||||
func estimateInputCount(input any) int {
|
||||
switch value := input.(type) {
|
||||
case []corpus.MemoryIngestItem:
|
||||
return len(value)
|
||||
case []corpus.WebIngestItem:
|
||||
return len(value)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeAction trims the action and substitutes fallback when nothing remains.
func normalizeAction(action string, fallback string) string {
	if trimmed := strings.TrimSpace(action); trimmed != "" {
		return trimmed
	}
	return fallback
}
|
||||
|
||||
// normalizeTopK returns the first positive value among topK and fallback,
// defaulting to 8 when neither is positive.
func normalizeTopK(topK int, fallback int) int {
	for _, candidate := range []int{topK, fallback} {
		if candidate > 0 {
			return candidate
		}
	}
	return 8
}
|
||||
|
||||
// normalizeThreshold returns the first non-negative value among threshold and
// fallback; zero when both are negative. Note zero is a valid threshold, so
// only negative values trigger the fallback.
func normalizeThreshold(threshold float64, fallback float64) float64 {
	switch {
	case threshold >= 0:
		return threshold
	case fallback >= 0:
		return fallback
	default:
		return 0
	}
}
|
||||
|
||||
// cloneMap returns a shallow copy of src. The result is always non-nil (an
// empty map for empty or nil input) so callers can mutate it safely.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for key, value := range src {
		dst[key] = value
	}
	return dst
}
|
||||
|
||||
// asString renders any value via fmt's default formatting and trims
// surrounding whitespace; nil maps to the empty string.
func asString(v any) string {
	if v == nil {
		return ""
	}
	return strings.TrimSpace(fmt.Sprint(v))
}
|
||||
111
backend/services/rag/service.go
Normal file
111
backend/services/rag/service.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
ragconfig "github.com/LoveLosita/smartflow/backend/services/rag/config"
|
||||
)
|
||||
|
||||
// Options describes the underlying runtime that rag-service needs to hold.
type Options struct {
	// Runtime is the assembled RAG runtime; may be nil (e.g. for a disabled
	// service), in which case Service methods degrade to no-ops.
	Runtime Runtime
}
|
||||
|
||||
// Service is the unified external entry point of rag-service.
//
// Responsibility boundaries:
//  1. Holds the runtime and funnels the memory / web capability lines through
//     the service layer.
//  2. Performs configuration-driven runtime assembly at the service entrance.
//  3. Does not carry chunk / embed / store implementation details directly;
//     those sink into sub-packages of the service tree.
type Service struct {
	// runtime may be nil for a disabled service; method bodies must tolerate it.
	runtime Runtime
}
|
||||
|
||||
// New 使用调用方传入的运行时构造服务。
|
||||
func New(opts Options) *Service {
|
||||
return &Service{runtime: opts.Runtime}
|
||||
}
|
||||
|
||||
// NewFromConfig 基于服务树内的配置与工厂能力构造自给自足的 RAG 服务。
|
||||
func NewFromConfig(ctx context.Context, cfg ragconfig.Config, deps FactoryDeps) (*Service, error) {
|
||||
if !cfg.Enabled {
|
||||
return New(Options{}), nil
|
||||
}
|
||||
runtime, err := NewRuntimeFromConfig(ctx, cfg, deps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewWithRuntime(runtime), nil
|
||||
}
|
||||
|
||||
// Runtime 返回当前服务持有的运行时。
|
||||
func (s *Service) Runtime() Runtime {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.runtime
|
||||
}
|
||||
|
||||
// IngestMemory 写入记忆语料。
|
||||
func (s *Service) IngestMemory(ctx context.Context, req MemoryIngestRequest) (*IngestResult, error) {
|
||||
if s == nil || s.runtime == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.runtime.IngestMemory(ctx, req)
|
||||
}
|
||||
|
||||
// RetrieveMemory 检索记忆语料。
|
||||
func (s *Service) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (*RetrieveResult, error) {
|
||||
if s == nil || s.runtime == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.runtime.RetrieveMemory(ctx, req)
|
||||
}
|
||||
|
||||
// DeleteMemory 删除指定记忆文档。
|
||||
func (s *Service) DeleteMemory(ctx context.Context, documentIDs []string) error {
|
||||
if s == nil || s.runtime == nil {
|
||||
return nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return s.runtime.DeleteMemory(ctx, documentIDs)
|
||||
}
|
||||
|
||||
// IngestWeb 写入网页语料。
|
||||
func (s *Service) IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestResult, error) {
|
||||
if s == nil || s.runtime == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.runtime.IngestWeb(ctx, req)
|
||||
}
|
||||
|
||||
// RetrieveWeb 检索网页语料。
|
||||
func (s *Service) RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (*RetrieveResult, error) {
|
||||
if s == nil || s.runtime == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.runtime.RetrieveWeb(ctx, req)
|
||||
}
|
||||
|
||||
// EnsureRuntime 返回一个可继续向下传递的运行时引用。
|
||||
func (s *Service) EnsureRuntime() Runtime {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.runtime
|
||||
}
|
||||
|
||||
// SetRuntime 允许在装配阶段延迟注入运行时。
|
||||
func (s *Service) SetRuntime(runtime Runtime) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.runtime = runtime
|
||||
}
|
||||
|
||||
// NewWithRuntime 用显式运行时构造服务。
|
||||
func NewWithRuntime(runtime Runtime) *Service {
|
||||
return New(Options{Runtime: runtime})
|
||||
}
|
||||
182
backend/services/rag/store/inmemory_store.go
Normal file
182
backend/services/rag/store/inmemory_store.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// InMemoryVectorStore is a vector store implementation for local development.
//
// Notes:
//  1. For development/debugging only; not recommended for production use;
//  2. Real environments can swap in MilvusStore — the interface is identical.
type InMemoryVectorStore struct {
	mu   sync.RWMutex              // guards rows
	rows map[string]core.VectorRow // keyed by row ID
}
|
||||
|
||||
func NewInMemoryVectorStore() *InMemoryVectorStore {
|
||||
return &InMemoryVectorStore{
|
||||
rows: make(map[string]core.VectorRow),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *InMemoryVectorStore) Upsert(_ context.Context, rows []core.VectorRow) error {
|
||||
if s == nil {
|
||||
return errors.New("inmemory vector store is nil")
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return nil
|
||||
}
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.rows == nil {
|
||||
s.rows = make(map[string]core.VectorRow)
|
||||
}
|
||||
for _, row := range rows {
|
||||
current, exists := s.rows[row.ID]
|
||||
if exists {
|
||||
row.CreatedAt = current.CreatedAt
|
||||
row.UpdatedAt = now
|
||||
} else {
|
||||
if row.CreatedAt.IsZero() {
|
||||
row.CreatedAt = now
|
||||
}
|
||||
row.UpdatedAt = now
|
||||
}
|
||||
s.rows[row.ID] = row
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *InMemoryVectorStore) Search(_ context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("inmemory vector store is nil")
|
||||
}
|
||||
topK := req.TopK
|
||||
if topK <= 0 {
|
||||
topK = 8
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
result := make([]core.ScoredVectorRow, 0, len(s.rows))
|
||||
for _, row := range s.rows {
|
||||
if !matchMetadataFilter(row.Metadata, req.Filter) {
|
||||
continue
|
||||
}
|
||||
score := cosineSimilarity(req.QueryVector, row.Vector)
|
||||
result = append(result, core.ScoredVectorRow{
|
||||
Row: row,
|
||||
Score: score,
|
||||
})
|
||||
}
|
||||
sort.SliceStable(result, func(i, j int) bool {
|
||||
return result[i].Score > result[j].Score
|
||||
})
|
||||
if len(result) <= topK {
|
||||
return result, nil
|
||||
}
|
||||
return result[:topK], nil
|
||||
}
|
||||
|
||||
func (s *InMemoryVectorStore) Delete(_ context.Context, ids []string) error {
|
||||
if s == nil {
|
||||
return errors.New("inmemory vector store is nil")
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, id := range ids {
|
||||
delete(s.rows, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *InMemoryVectorStore) Get(_ context.Context, ids []string) ([]core.VectorRow, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("inmemory vector store is nil")
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
result := make([]core.VectorRow, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
row, exists := s.rows[id]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
result = append(result, row)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// cosineSimilarity computes the cosine of the angle between a and b over
// their common prefix; zero when either vector is empty or has zero norm.
func cosineSimilarity(a, b []float32) float64 {
	length := len(a)
	if len(b) < length {
		length = len(b)
	}
	if length == 0 {
		return 0
	}
	var dot, squaredA, squaredB float64
	for i := 0; i < length; i++ {
		x, y := float64(a[i]), float64(b[i])
		dot += x * y
		squaredA += x * x
		squaredB += y * y
	}
	if squaredA == 0 || squaredB == 0 {
		return 0
	}
	return dot / (math.Sqrt(squaredA) * math.Sqrt(squaredB))
}
|
||||
|
||||
func matchMetadataFilter(metadata map[string]any, filter map[string]any) bool {
|
||||
if len(filter) == 0 {
|
||||
return true
|
||||
}
|
||||
for key, wanted := range filter {
|
||||
got, exists := metadata[key]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
if !equalAny(got, wanted) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func equalAny(left any, right any) bool {
|
||||
return toString(left) == toString(right)
|
||||
}
|
||||
|
||||
func toString(v any) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
return fmtAny(v)
|
||||
}
|
||||
|
||||
// fmtAny renders v with fmt's default %v formatting and trims whitespace.
func fmtAny(v any) string {
	rendered := fmt.Sprintf("%v", v)
	return strings.TrimSpace(rendered)
}
|
||||
927
backend/services/rag/store/milvus_store.go
Normal file
927
backend/services/rag/store/milvus_store.go
Normal file
@@ -0,0 +1,927 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// MilvusConfig describes the Milvus REST store configuration.
type MilvusConfig struct {
	// Address must point at the Milvus REST entry point.
	// This project's integration testing uses 19530; 9091 only serves
	// health/metrics and does not carry the REST API used by this file.
	Address string
	// Token is an optional credential — presumably sent as auth on each
	// request; confirm against postJSON.
	Token string
	// DBName selects a non-default Milvus database when non-empty.
	DBName string
	// CollectionName is the target collection; defaults to
	// "smartflow_rag_chunks" in NewMilvusStore.
	CollectionName string
	// RequestTimeoutMS bounds each HTTP request; defaults to 1500 when <= 0.
	RequestTimeoutMS int
	// Dimension is the fallback vector dimension used by ensureCollection
	// when no row/query vector is available to infer it.
	Dimension int
	// MetricType is the vector index metric; defaults to "COSINE".
	MetricType string
	// Logger receives fallback logging; defaults to log.Default().
	Logger *log.Logger
	// Observer receives structured events; defaults to a logger-backed observer.
	Observer core.Observer
}
|
||||
|
||||
// MilvusStore is a vector store implementation built on the Milvus REST API.
//
// Design notes:
//  1. This implementation prioritizes being pluggable, manageable, and
//     gradually rolled out within the project, without extra SDK dependencies;
//  2. Fixed scalar fields plus a metadata JSON field balance filtering
//     capability against metadata completeness;
//  3. The collection is created automatically on first write, avoiding a
//     separate startup initialization script.
type MilvusStore struct {
	cfg      MilvusConfig
	client   *http.Client // shared client with the configured request timeout
	observer core.Observer
	mu       sync.Mutex // guards ensured
	ensured  bool       // set once the collection has been created/verified
}
|
||||
|
||||
// Field names of the fixed Milvus collection schema. Frequently filtered
// attributes are promoted to dedicated scalar columns; everything else lives
// in the metadata JSON column.
const (
	milvusPrimaryField   = "id"
	milvusVectorField    = "vector"
	milvusTextField      = "text"
	milvusMetadataField  = "metadata"
	milvusCorpusField    = "corpus"
	milvusDocumentField  = "document_id"
	milvusUserIDField    = "user_id"
	milvusAssistantField = "assistant_id"
	milvusConvField      = "conversation_id"
	milvusRunField       = "run_id"
	milvusMemoryType     = "memory_type"
	milvusQueryIDField   = "query_id"
	milvusSessionField   = "session_id"
	milvusDomainField    = "domain"
	milvusChunkOrder     = "chunk_order"
	milvusUpdatedAtField = "updated_at"
)
|
||||
|
||||
// milvusFilterFieldMap maps the generic filter keys used by callers onto the
// promoted scalar schema columns declared above.
var milvusFilterFieldMap = map[string]string{
	"corpus":          milvusCorpusField,
	"document_id":     milvusDocumentField,
	"user_id":         milvusUserIDField,
	"assistant_id":    milvusAssistantField,
	"conversation_id": milvusConvField,
	"run_id":          milvusRunField,
	"memory_type":     milvusMemoryType,
	"query_id":        milvusQueryIDField,
	"session_id":      milvusSessionField,
	"domain":          milvusDomainField,
	"chunk_order":     milvusChunkOrder,
}
|
||||
|
||||
func NewMilvusStore(cfg MilvusConfig) (*MilvusStore, error) {
|
||||
cfg.Address = strings.TrimRight(strings.TrimSpace(cfg.Address), "/")
|
||||
if cfg.Address == "" {
|
||||
return nil, errors.New("milvus address is empty")
|
||||
}
|
||||
if cfg.CollectionName == "" {
|
||||
cfg.CollectionName = "smartflow_rag_chunks"
|
||||
}
|
||||
if cfg.MetricType == "" {
|
||||
cfg.MetricType = "COSINE"
|
||||
}
|
||||
if cfg.RequestTimeoutMS <= 0 {
|
||||
cfg.RequestTimeoutMS = 1500
|
||||
}
|
||||
if cfg.Logger == nil {
|
||||
cfg.Logger = log.Default()
|
||||
}
|
||||
if cfg.Observer == nil {
|
||||
cfg.Observer = core.NewLoggerObserver(cfg.Logger)
|
||||
}
|
||||
|
||||
return &MilvusStore{
|
||||
cfg: cfg,
|
||||
client: &http.Client{Timeout: time.Duration(cfg.RequestTimeoutMS) * time.Millisecond},
|
||||
observer: cfg.Observer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *MilvusStore) Upsert(ctx context.Context, rows []core.VectorRow) error {
|
||||
if err := s.ensureReady(); err != nil {
|
||||
return err
|
||||
}
|
||||
start := time.Now()
|
||||
if len(rows) == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := s.ensureCollection(ctx, len(rows[0].Vector)); err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "upsert",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"row_count": len(rows),
|
||||
"vector_dim": len(rows[0].Vector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
data := make([]map[string]any, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
item := mapRowToMilvusEntity(row)
|
||||
data = append(data, item)
|
||||
}
|
||||
|
||||
_, err := s.postJSON(ctx, "/v2/vectordb/entities/upsert", map[string]any{
|
||||
"collectionName": s.cfg.CollectionName,
|
||||
"data": data,
|
||||
"dbName": blankToNil(s.cfg.DBName),
|
||||
})
|
||||
if err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "upsert",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"row_count": len(rows),
|
||||
"vector_dim": len(rows[0].Vector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelInfo,
|
||||
Component: "store",
|
||||
Operation: "upsert",
|
||||
Fields: map[string]any{
|
||||
"status": "success",
|
||||
"row_count": len(rows),
|
||||
"vector_dim": len(rows[0].Vector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *MilvusStore) Search(ctx context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
|
||||
if err := s.ensureReady(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
start := time.Now()
|
||||
if len(req.QueryVector) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err := s.ensureCollection(ctx, len(req.QueryVector)); err != nil {
|
||||
if isMilvusCollectionMissing(err) {
|
||||
return nil, nil
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filterExpr, err := buildMilvusFilter(req.Filter)
|
||||
if err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body := map[string]any{
|
||||
"collectionName": s.cfg.CollectionName,
|
||||
"data": [][]float32{req.QueryVector},
|
||||
"annsField": milvusVectorField,
|
||||
"limit": normalizeMilvusTopK(req.TopK),
|
||||
"outputFields": milvusOutputFields(false),
|
||||
}
|
||||
if filterExpr != "" {
|
||||
body["filter"] = filterExpr
|
||||
}
|
||||
if s.cfg.DBName != "" {
|
||||
body["dbName"] = s.cfg.DBName
|
||||
}
|
||||
|
||||
respBody, err := s.postJSON(ctx, "/v2/vectordb/entities/search", body)
|
||||
if err != nil {
|
||||
if isMilvusCollectionMissing(err) {
|
||||
return nil, nil
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp milvusSearchResponse
|
||||
if err = json.Unmarshal(respBody, &resp); err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
if resp.Code != 0 && resp.Code != 200 {
|
||||
err = fmt.Errorf("milvus search failed: code=%d message=%s", resp.Code, resp.Message)
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]core.ScoredVectorRow, 0, len(resp.Data))
|
||||
for _, item := range resp.Data {
|
||||
row, score := item.toVectorRow()
|
||||
result = append(result, core.ScoredVectorRow{
|
||||
Row: row,
|
||||
Score: score,
|
||||
})
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelInfo,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "success",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"result_count": len(result),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
},
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *MilvusStore) Delete(ctx context.Context, ids []string) error {
|
||||
if err := s.ensureReady(); err != nil {
|
||||
return err
|
||||
}
|
||||
start := time.Now()
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
filter := fmt.Sprintf(`%s in [%s]`, milvusPrimaryField, joinQuotedStrings(ids))
|
||||
_, err := s.postJSON(ctx, "/v2/vectordb/entities/delete", map[string]any{
|
||||
"collectionName": s.cfg.CollectionName,
|
||||
"filter": filter,
|
||||
"dbName": blankToNil(s.cfg.DBName),
|
||||
})
|
||||
if isMilvusCollectionMissing(err) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "delete",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelInfo,
|
||||
Component: "store",
|
||||
Operation: "delete",
|
||||
Fields: map[string]any{
|
||||
"status": "success",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *MilvusStore) Get(ctx context.Context, ids []string) ([]core.VectorRow, error) {
|
||||
if err := s.ensureReady(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
start := time.Now()
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
respBody, err := s.postJSON(ctx, "/v2/vectordb/entities/get", map[string]any{
|
||||
"collectionName": s.cfg.CollectionName,
|
||||
"id": ids,
|
||||
"outputFields": milvusOutputFields(true),
|
||||
"dbName": blankToNil(s.cfg.DBName),
|
||||
})
|
||||
if err != nil {
|
||||
if isMilvusCollectionMissing(err) {
|
||||
return nil, nil
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "get",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp milvusGetResponse
|
||||
if err = json.Unmarshal(respBody, &resp); err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "get",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
if resp.Code != 0 && resp.Code != 200 {
|
||||
err = fmt.Errorf("milvus get failed: code=%d message=%s", resp.Code, resp.Message)
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "get",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows := make([]core.VectorRow, 0, len(resp.Data))
|
||||
for _, item := range resp.Data {
|
||||
rows = append(rows, mapMilvusRow(item, true))
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelInfo,
|
||||
Component: "store",
|
||||
Operation: "get",
|
||||
Fields: map[string]any{
|
||||
"status": "success",
|
||||
"id_count": len(ids),
|
||||
"row_count": len(rows),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
},
|
||||
})
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// ensureCollection lazily creates the Milvus collection (schema plus vector
// index) on first use. It is safe for concurrent callers: the s.ensured flag
// is checked and set while holding s.mu, so the create request is issued at
// most once per store instance.
//
// dimension <= 0 falls back to the configured s.cfg.Dimension; if that is
// also unset the vector field cannot be built and the call fails fast.
func (s *MilvusStore) ensureCollection(ctx context.Context, dimension int) error {
	if err := s.ensureReady(); err != nil {
		return err
	}
	start := time.Now()
	if dimension <= 0 {
		dimension = s.cfg.Dimension
	}
	if dimension <= 0 {
		return errors.New("milvus vector dimension is invalid")
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	if s.ensured {
		return nil
	}

	// Full column layout: primary key, float vector, raw text, a JSON
	// metadata blob, plus promoted scalar columns used for filtering.
	payload := map[string]any{
		"collectionName": s.cfg.CollectionName,
		"schema": map[string]any{
			"autoId":              false,
			"enabledDynamicField": false,
			"fields": []map[string]any{
				buildVarcharField(milvusPrimaryField, true, 256),
				buildVectorField(milvusVectorField, dimension),
				buildVarcharField(milvusTextField, false, 65535),
				{"fieldName": milvusMetadataField, "dataType": "JSON"},
				buildVarcharField(milvusCorpusField, false, 64),
				buildVarcharField(milvusDocumentField, false, 256),
				{"fieldName": milvusUserIDField, "dataType": "Int64"},
				buildVarcharField(milvusAssistantField, false, 128),
				buildVarcharField(milvusConvField, false, 128),
				buildVarcharField(milvusRunField, false, 128),
				buildVarcharField(milvusMemoryType, false, 64),
				buildVarcharField(milvusQueryIDField, false, 128),
				buildVarcharField(milvusSessionField, false, 128),
				buildVarcharField(milvusDomainField, false, 128),
				{"fieldName": milvusChunkOrder, "dataType": "Int64"},
				{"fieldName": milvusUpdatedAtField, "dataType": "Int64"},
			},
		},
		// AUTOINDEX lets Milvus choose the index type for the vector column;
		// only the metric type is taken from configuration.
		"indexParams": []map[string]any{
			{
				"fieldName":  milvusVectorField,
				"indexName":  milvusVectorField,
				"metricType": s.cfg.MetricType,
				"indexType":  "AUTOINDEX",
			},
		},
	}
	if s.cfg.DBName != "" {
		payload["dbName"] = s.cfg.DBName
	}

	_, err := s.postJSON(ctx, "/v2/vectordb/collections/create", payload)
	if err != nil {
		// Another process may have created the collection concurrently;
		// treat "already exists" as success and cache the result.
		if isMilvusAlreadyExists(err) {
			s.ensured = true
			s.observe(ctx, core.ObserveEvent{
				Level:     core.ObserveLevelInfo,
				Component: "store",
				Operation: "ensure_collection",
				Fields: map[string]any{
					"status":      "already_exists",
					"vector_dim":  dimension,
					"metric_type": s.cfg.MetricType,
					"latency_ms":  time.Since(start).Milliseconds(),
				},
			})
			return nil
		}
		s.observe(ctx, core.ObserveEvent{
			Level:     core.ObserveLevelError,
			Component: "store",
			Operation: "ensure_collection",
			Fields: map[string]any{
				"status":      "failed",
				"vector_dim":  dimension,
				"metric_type": s.cfg.MetricType,
				"latency_ms":  time.Since(start).Milliseconds(),
				"error":       err,
				"error_code":  core.ClassifyErrorCode(err),
			},
		})
		return err
	}
	s.ensured = true
	s.observe(ctx, core.ObserveEvent{
		Level:     core.ObserveLevelInfo,
		Component: "store",
		Operation: "ensure_collection",
		Fields: map[string]any{
			"status":      "created",
			"vector_dim":  dimension,
			"metric_type": s.cfg.MetricType,
			"latency_ms":  time.Since(start).Milliseconds(),
		},
	})
	return nil
}
|
||||
|
||||
// postJSON issues a POST to the Milvus REST API at path with a JSON-encoded
// payload and returns the raw response body for endpoint-specific decoding.
//
// Error handling is layered:
//  1. transport/read errors are returned as-is;
//  2. HTTP status >= 400 is surfaced with the body text for diagnosis;
//  3. a 2xx response whose JSON envelope carries a non-success code
//     (anything other than 0 or 200) is converted into an error.
//
// A body that does not parse as the basic envelope is passed through
// untouched — callers decode it with their own response types.
func (s *MilvusStore) postJSON(ctx context.Context, path string, payload map[string]any) ([]byte, error) {
	if err := s.ensureReady(); err != nil {
		return nil, err
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.cfg.Address+path, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	// Token auth is optional; Milvus accepts anonymous access when disabled.
	if token := strings.TrimSpace(s.cfg.Token); token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}

	resp, err := s.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// Read the full body before status checks so error responses can
	// include the server's message (and the connection can be reused).
	respBody, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, readErr
	}

	if resp.StatusCode >= 400 {
		return nil, fmt.Errorf("milvus http failed: status=%d body=%s", resp.StatusCode, string(respBody))
	}

	var basic milvusBasicResponse
	if jsonErr := json.Unmarshal(respBody, &basic); jsonErr == nil {
		if basic.Code != 0 && basic.Code != 200 {
			return nil, fmt.Errorf("milvus api failed: code=%d message=%s", basic.Code, basic.Message)
		}
	}

	return respBody, nil
}
|
||||
|
||||
func (s *MilvusStore) ensureReady() error {
|
||||
if s == nil || s.client == nil {
|
||||
return errors.New("milvus store is not initialized")
|
||||
}
|
||||
if strings.TrimSpace(s.cfg.Address) == "" {
|
||||
return errors.New("milvus address is empty")
|
||||
}
|
||||
if strings.TrimSpace(s.cfg.CollectionName) == "" {
|
||||
return errors.New("milvus collection name is empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MilvusStore) observe(ctx context.Context, event core.ObserveEvent) {
|
||||
if s == nil || s.observer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
fields := cloneMap(event.Fields)
|
||||
fields["store"] = "milvus"
|
||||
fields["collection"] = s.cfg.CollectionName
|
||||
if strings.TrimSpace(s.cfg.DBName) != "" {
|
||||
fields["db_name"] = s.cfg.DBName
|
||||
}
|
||||
|
||||
s.observer.Observe(ctx, core.ObserveEvent{
|
||||
Level: event.Level,
|
||||
Component: event.Component,
|
||||
Operation: event.Operation,
|
||||
Fields: fields,
|
||||
})
|
||||
}
|
||||
|
||||
// mapRowToMilvusEntity flattens a core.VectorRow into the column layout of
// the Milvus collection. The full metadata map is stored in the JSON column,
// while well-known keys are additionally promoted to dedicated scalar
// columns so they can be filtered on; blank/unparseable values are dropped
// by assignMilvusScalar. A zero UpdatedAt defaults to "now" in epoch ms.
func mapRowToMilvusEntity(row core.VectorRow) map[string]any {
	metadata := cloneMap(row.Metadata)
	entity := map[string]any{
		milvusPrimaryField:  row.ID,
		milvusVectorField:   row.Vector,
		milvusTextField:     row.Text,
		milvusMetadataField: metadata,
		milvusCorpusField:   asString(metadata["corpus"]),
		milvusDocumentField: asString(metadata["document_id"]),
		// Timestamps are stored as epoch milliseconds; fall back to the
		// current time when the row does not carry one.
		milvusUpdatedAtField: func() int64 {
			if row.UpdatedAt.IsZero() {
				return time.Now().UnixMilli()
			}
			return row.UpdatedAt.UnixMilli()
		}(),
	}
	// Promote optional metadata keys onto their dedicated scalar columns.
	assignMilvusScalar(entity, milvusUserIDField, metadata["user_id"])
	assignMilvusScalar(entity, milvusAssistantField, metadata["assistant_id"])
	assignMilvusScalar(entity, milvusConvField, metadata["conversation_id"])
	assignMilvusScalar(entity, milvusRunField, metadata["run_id"])
	assignMilvusScalar(entity, milvusMemoryType, metadata["memory_type"])
	assignMilvusScalar(entity, milvusQueryIDField, metadata["query_id"])
	assignMilvusScalar(entity, milvusSessionField, metadata["session_id"])
	assignMilvusScalar(entity, milvusDomainField, metadata["domain"])
	assignMilvusScalar(entity, milvusChunkOrder, metadata["chunk_order"])
	return entity
}
|
||||
|
||||
func assignMilvusScalar(target map[string]any, field string, value any) {
|
||||
if value == nil {
|
||||
return
|
||||
}
|
||||
switch field {
|
||||
case milvusUserIDField, milvusChunkOrder:
|
||||
if parsed, ok := toInt64(value); ok {
|
||||
target[field] = parsed
|
||||
}
|
||||
default:
|
||||
if text := asString(value); text != "" {
|
||||
target[field] = text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buildMilvusFilter translates a generic key/value filter into a Milvus
// boolean expression, e.g. `user_id == 7 and corpus == "memory"`. Keys must
// appear in milvusFilterFieldMap; unknown keys are rejected so callers
// cannot silently query a column that was never promoted to a scalar field.
// String values are escaped before being embedded in double quotes.
//
// NOTE(review): parts are emitted in map-iteration order, so the rendered
// expression is not deterministic across calls. Semantics are unaffected
// ("and" is commutative), but anything comparing or caching the string
// would be — consider sorting keys if that ever matters.
func buildMilvusFilter(filter map[string]any) (string, error) {
	if len(filter) == 0 {
		return "", nil
	}

	parts := make([]string, 0, len(filter))
	for key, value := range filter {
		field, ok := milvusFilterFieldMap[key]
		if !ok {
			return "", fmt.Errorf("unsupported milvus filter key: %s", key)
		}
		switch field {
		case milvusUserIDField, milvusChunkOrder:
			parsed, parseOK := toInt64(value)
			if !parseOK {
				return "", fmt.Errorf("milvus filter key=%s expects integer", key)
			}
			parts = append(parts, fmt.Sprintf("%s == %d", field, parsed))
		default:
			text := escapeMilvusString(asString(value))
			parts = append(parts, fmt.Sprintf(`%s == "%s"`, field, text))
		}
	}
	return strings.Join(parts, " and "), nil
}
|
||||
|
||||
// buildVarcharField assembles the schema descriptor for a VarChar column.
// maxLength maps to Milvus's max_length parameter; isPrimary marks the
// collection's primary-key field.
func buildVarcharField(name string, isPrimary bool, maxLength int) map[string]any {
	params := map[string]any{"max_length": maxLength}
	field := map[string]any{
		"fieldName":         name,
		"dataType":          "VarChar",
		"elementTypeParams": params,
	}
	if !isPrimary {
		return field
	}
	field["isPrimary"] = true
	return field
}
|
||||
|
||||
// buildVectorField assembles the schema descriptor for a FloatVector column
// with the given dimensionality.
func buildVectorField(name string, dimension int) map[string]any {
	field := make(map[string]any, 3)
	field["fieldName"] = name
	field["dataType"] = "FloatVector"
	field["elementTypeParams"] = map[string]any{"dim": dimension}
	return field
}
|
||||
|
||||
// milvusOutputFields lists the columns requested from Milvus on search/get
// responses. The vector column is only included when the caller explicitly
// asks for it, since returning embeddings is expensive and most retrieval
// paths never need them.
func milvusOutputFields(includeVector bool) []string {
	fields := []string{
		milvusTextField,
		milvusMetadataField,
		milvusCorpusField,
		milvusDocumentField,
		milvusUserIDField,
		milvusAssistantField,
		milvusConvField,
		milvusRunField,
		milvusMemoryType,
		milvusQueryIDField,
		milvusSessionField,
		milvusDomainField,
		milvusChunkOrder,
		milvusUpdatedAtField,
	}
	if includeVector {
		fields = append(fields, milvusVectorField)
	}
	return fields
}
|
||||
|
||||
// normalizeMilvusTopK clamps a non-positive topK to the default of 8 so a
// search request always asks for at least one hit.
func normalizeMilvusTopK(topK int) int {
	const defaultTopK = 8
	if topK > 0 {
		return topK
	}
	return defaultTopK
}
|
||||
|
||||
// blankToNil returns nil for a blank (empty or all-whitespace) string and
// the trimmed value otherwise, so empty scalars can be omitted from
// entity payloads.
func blankToNil(v string) any {
	trimmed := strings.TrimSpace(v)
	if trimmed == "" {
		return nil
	}
	return trimmed
}
|
||||
|
||||
// escapeMilvusString escapes backslashes and double quotes so the value can
// be embedded safely inside a double-quoted Milvus filter expression.
// A single-pass Replacer replaces the previous two sequential ReplaceAll
// calls (one fewer intermediate allocation, identical output: each source
// character is rewritten exactly once, so backslashes produced by escaping
// are never re-escaped).
func escapeMilvusString(v string) string {
	return strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(v)
}
|
||||
|
||||
func joinQuotedStrings(values []string) string {
|
||||
parts := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
parts = append(parts, fmt.Sprintf(`"%s"`, escapeMilvusString(value)))
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// cloneMap returns a shallow copy of src. A nil or empty input yields a
// fresh, non-nil empty map, so callers may always write to the result
// without aliasing the source.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for key, value := range src {
		dst[key] = value
	}
	return dst
}
|
||||
|
||||
// asString renders any value as a whitespace-trimmed string via the default
// %v formatting; nil maps to the empty string.
func asString(v any) string {
	if v == nil {
		return ""
	}
	text := fmt.Sprintf("%v", v)
	return strings.TrimSpace(text)
}
|
||||
|
||||
func toInt64(v any) (int64, bool) {
|
||||
switch value := v.(type) {
|
||||
case int:
|
||||
return int64(value), true
|
||||
case int32:
|
||||
return int64(value), true
|
||||
case int64:
|
||||
return value, true
|
||||
case float64:
|
||||
return int64(value), true
|
||||
case json.Number:
|
||||
parsed, err := value.Int64()
|
||||
return parsed, err == nil
|
||||
case string:
|
||||
parsed, err := strconv.ParseInt(strings.TrimSpace(value), 10, 64)
|
||||
return parsed, err == nil
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func isMilvusAlreadyExists(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
text := strings.ToLower(err.Error())
|
||||
return strings.Contains(text, "already exist") ||
|
||||
strings.Contains(text, "already exists") ||
|
||||
strings.Contains(text, "duplicate collection")
|
||||
}
|
||||
|
||||
func isMilvusCollectionMissing(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
text := strings.ToLower(err.Error())
|
||||
return strings.Contains(text, "can't find collection") || strings.Contains(text, "collection not found")
|
||||
}
|
||||
|
||||
// milvusBasicResponse is the common REST envelope (code + message) returned
// by every Milvus v2 endpoint; a code of 0 or 200 means success.
type milvusBasicResponse struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}

// milvusSearchResponse is the envelope for vector search responses; Data
// holds one raw item per hit.
type milvusSearchResponse struct {
	Code    int                `json:"code"`
	Message string             `json:"message"`
	Data    []milvusSearchItem `json:"data"`
}

// milvusSearchItem is one raw search hit: the requested output fields plus
// the engine-provided "distance" score, keyed by column name.
type milvusSearchItem map[string]any
|
||||
func (m milvusSearchItem) toVectorRow() (core.VectorRow, float64) {
|
||||
row := mapMilvusRow(map[string]any(m), false)
|
||||
score := 0.0
|
||||
if value, ok := m["distance"].(float64); ok {
|
||||
score = value
|
||||
}
|
||||
return row, score
|
||||
}
|
||||
|
||||
// milvusGetResponse is the envelope for get-by-ID responses; Data holds one
// raw column map per returned entity.
type milvusGetResponse struct {
	Code    int              `json:"code"`
	Message string           `json:"message"`
	Data    []map[string]any `json:"data"`
}
|
||||
|
||||
// mapMilvusRow converts a raw Milvus row (a column-name → value map) back
// into a core.VectorRow. Promoted scalar columns are folded back into the
// metadata map so callers see one flat metadata view regardless of the
// storage layout; nil/blank values are skipped by assignMetadataIfPresent.
func mapMilvusRow(raw map[string]any, includeVector bool) core.VectorRow {
	metadata := cloneMap(readMetadataMap(raw[milvusMetadataField]))
	assignMetadataIfPresent(metadata, "corpus", raw[milvusCorpusField])
	assignMetadataIfPresent(metadata, "document_id", raw[milvusDocumentField])
	assignMetadataIfPresent(metadata, "user_id", raw[milvusUserIDField])
	assignMetadataIfPresent(metadata, "assistant_id", raw[milvusAssistantField])
	assignMetadataIfPresent(metadata, "conversation_id", raw[milvusConvField])
	assignMetadataIfPresent(metadata, "run_id", raw[milvusRunField])
	assignMetadataIfPresent(metadata, "memory_type", raw[milvusMemoryType])
	assignMetadataIfPresent(metadata, "query_id", raw[milvusQueryIDField])
	assignMetadataIfPresent(metadata, "session_id", raw[milvusSessionField])
	assignMetadataIfPresent(metadata, "domain", raw[milvusDomainField])
	assignMetadataIfPresent(metadata, "chunk_order", raw[milvusChunkOrder])

	row := core.VectorRow{
		ID:       asString(raw[milvusPrimaryField]),
		Text:     asString(raw[milvusTextField]),
		Metadata: metadata,
	}
	// Some endpoints return the primary key under the generic "id" key
	// instead of the schema's primary field name.
	if row.ID == "" {
		row.ID = asString(raw["id"])
	}
	// Vectors are only decoded on request; most retrieval paths skip them.
	if includeVector {
		row.Vector = readFloat32Vector(raw[milvusVectorField])
	}
	return row
}
|
||||
|
||||
// readMetadataMap extracts the JSON metadata column as a map; any value that
// is not already a map[string]any collapses to a fresh empty map.
func readMetadataMap(value any) map[string]any {
	if data, ok := value.(map[string]any); ok {
		return data
	}
	return map[string]any{}
}
|
||||
|
||||
// readFloat32Vector converts a Milvus vector payload into []float32.
// JSON-decoded vectors arrive as []any of float64; non-numeric elements are
// skipped. A []float32 passes through unchanged; any other shape yields nil.
func readFloat32Vector(value any) []float32 {
	if vec, ok := value.([]float32); ok {
		return vec
	}
	items, ok := value.([]any)
	if !ok {
		return nil
	}
	out := make([]float32, 0, len(items))
	for _, item := range items {
		switch n := item.(type) {
		case float64:
			out = append(out, float32(n))
		case float32:
			out = append(out, n)
		}
	}
	return out
}
|
||||
|
||||
// assignMetadataIfPresent copies value into target under key, skipping nil
// values and blank strings; string values are stored trimmed, every other
// type is stored as-is.
func assignMetadataIfPresent(target map[string]any, key string, value any) {
	if value == nil {
		return
	}
	if text, ok := value.(string); ok {
		if trimmed := strings.TrimSpace(text); trimmed != "" {
			target[key] = trimmed
		}
		return
	}
	target[key] = value
}
|
||||
9
backend/services/rag/store/vector_store.go
Normal file
9
backend/services/rag/store/vector_store.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package store
|
||||
|
||||
import "github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
|
||||
// EnsureCompile 用于静态校验实现是否满足接口。
|
||||
func EnsureCompile() {
|
||||
var _ core.VectorStore = (*InMemoryVectorStore)(nil)
|
||||
var _ core.VectorStore = (*MilvusStore)(nil)
|
||||
}
|
||||
Reference in New Issue
Block a user