Version: 0.9.65.dev.260503
后端: 1. 阶段 1.5/1.6 收口 llm-service / rag-service,统一模型出口与检索基础设施入口,清退 backend/infra/llm 与 backend/infra/rag 旧实现; 2. 同步更新相关调用链与微服务迁移计划文档
This commit is contained in:
98
backend/services/rag/retrieve/vector_retriever.go
Normal file
98
backend/services/rag/retrieve/vector_retriever.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package retrieve
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// VectorRetriever 是通用检索器(embed + vector search)。
|
||||
type VectorRetriever struct {
|
||||
embedder core.Embedder
|
||||
store core.VectorStore
|
||||
}
|
||||
|
||||
func NewVectorRetriever(embedder core.Embedder, store core.VectorStore) *VectorRetriever {
|
||||
return &VectorRetriever{
|
||||
embedder: embedder,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *VectorRetriever) Retrieve(ctx context.Context, req core.RetrieveRequest) (result []core.ScoredChunk, err error) {
|
||||
defer func() {
|
||||
recovered := recover()
|
||||
if recovered == nil {
|
||||
return
|
||||
}
|
||||
err = fmt.Errorf("vector retriever panic recovered: %v", recovered)
|
||||
}()
|
||||
|
||||
if r == nil || r.embedder == nil || r.store == nil {
|
||||
return nil, core.ErrNilDependency
|
||||
}
|
||||
query := strings.TrimSpace(req.Query)
|
||||
if query == "" {
|
||||
return nil, core.ErrInvalidQuery
|
||||
}
|
||||
topK := req.TopK
|
||||
if topK <= 0 {
|
||||
topK = 8
|
||||
}
|
||||
action := strings.TrimSpace(req.Action)
|
||||
if action == "" {
|
||||
action = "search"
|
||||
}
|
||||
|
||||
vectors, err := r.embedder.Embed(ctx, []string{query}, action)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(vectors) != 1 {
|
||||
return nil, fmt.Errorf("embedding query length mismatch: %d", len(vectors))
|
||||
}
|
||||
|
||||
rows, err := r.store.Search(ctx, core.VectorSearchRequest{
|
||||
QueryVector: vectors[0],
|
||||
TopK: topK,
|
||||
Filter: req.Filter,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = make([]core.ScoredChunk, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
if row.Score < req.Threshold {
|
||||
continue
|
||||
}
|
||||
result = append(result, core.ScoredChunk{
|
||||
ChunkID: row.Row.ID,
|
||||
DocumentID: asString(row.Row.Metadata["document_id"]),
|
||||
Text: row.Row.Text,
|
||||
Score: row.Score,
|
||||
Metadata: cloneMap(row.Row.Metadata),
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// cloneMap returns a shallow copy of src. The result is never nil: a nil or
// empty src yields a fresh empty map. Values are copied by assignment, so
// nested reference values remain shared with src.
func cloneMap(src map[string]any) map[string]any {
	out := make(map[string]any, len(src))
	for key, val := range src {
		out[key] = val
	}
	return out
}
|
||||
|
||||
// asString renders v using fmt's default formatting. A nil interface value
// yields the empty string instead of "<nil>".
func asString(v any) string {
	if v != nil {
		return fmt.Sprint(v)
	}
	return ""
}
|
||||
Reference in New Issue
Block a user