Version: 0.9.13.dev.260410
Backend:
1. Memory Day 1 pipeline wired end to end (chat_history -> outbox -> memory_jobs)
   - Updated service/events/chat_history_persist.go: append a memory.extract.requested event in the same transaction that persists the chat message (user messages only; on failure the transaction rolls back and the outbox retries)
   - Added service/events/memory_extract_requested.go: consume memory.extract.requested and idempotently enqueue memory_jobs, with payload validation, text truncation, and an idempotency key
   - Updated cmd/start.go: register RegisterMemoryExtractRequestedHandler
2. Memory module skeleton in place (run the state machine first, plug in real extraction later)
   - Added memory/model, repo, service, orchestrator, worker, utils directories and the Day 1 mock extraction chain
   - Added model/memory.go: models for memory_items / memory_jobs / memory_audit_logs / memory_user_settings plus the event payloads
   - Updated inits/mysql.go: AutoMigrate for the 4 memory tables
3. Reusable RAG infrastructure scaffolding (dependencies swappable)
   - Added infra/rag: core pipeline plus chunk/embed/retrieve/rerank/store/corpus/config layers
   - Wired MockEmbedder + InMemoryStore by default; Milvus / Eino adapter implementations reserved
   - Added infra/rag/RAG复用接口实施计划.md
4. Local dependencies and handoff docs synced
   - Updated docker-compose.yml: added etcd / minio / milvus / attu services and data volumes
   - Removed newAgent/HANDOFF_工具研究与运行态重置.md and newAgent/阶段3_上下文瘦身设计.md
   - Added newAgent/HANDOFF_WebSearch两阶段实施计划.md, memory/HANDOFF-RAG复用后续实施计划.md, memory/README.md
Frontend: none
Repo: none
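The Day 1 wiring in item 1 follows the transactional-outbox pattern: the chat message and the `memory.extract.requested` event commit together, so either both exist or neither does. A minimal conceptual sketch (hypothetical names and an in-memory stand-in for the transaction, not the repo's actual functions):

```go
package main

import "fmt"

// tx is a toy stand-in for a database transaction: rows buffered until commit.
type tx struct{ rows []string }

func (t *tx) insert(table, payload string) { t.rows = append(t.rows, table + ":" + payload) }

// persistUserMessage sketches the same-transaction append described above:
// a crash before commit rolls back both writes, and after commit the outbox
// relay delivers (and retries) the event to the memory_jobs consumer.
func persistUserMessage(t *tx, content string, isUser bool) {
	t.insert("chat_history", content)
	if isUser {
		// Only user messages trigger extraction.
		t.insert("outbox", "memory.extract.requested")
	}
}

func main() {
	t := &tx{}
	persistUserMessage(t, "hi", true)
	fmt.Println(len(t.rows)) // 2
}
```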
This commit is contained in:
191
backend/infra/rag/RAG复用接口实施计划.md
Normal file
@@ -0,0 +1,191 @@
# RAG Reuse Interface Implementation Plan (Unified Base for Memory + WebSearch)

## 1. Goals and Principles

1. Extract a shared RAG Core under `backend/infra/rag` that unifies the `chunk/embed/retrieve/rerank` capabilities.
2. Start with two adapters, `MemoryCorpus` and `WebCorpus`, to avoid reinventing the wheel later.
3. Keep a "parallel migration" strategy: old and new paths coexist. Integrate first, then canary, then cut over, and only then delete the old implementation.
4. Never block the existing main path; every RAG sub-capability must degrade gracefully on failure.

## 2. Scope and Non-Goals

### 2.1 In Scope

1. Define the RAG Core interfaces, standard data structures, error codes, and fallback semantics.
2. Design the `MemoryCorpus` and `WebCorpus` adapter layers.
3. Provide phased rollout steps, acceptance criteria, and risk controls.

### 2.2 Out of Scope

1. No production-grade vector retrieval details this round (the Milvus connector can stay a placeholder).
2. No wholesale migration of all callers; only the first batch of integration points.
3. No multi-provider factory yet (first make a single provider swappable).

## 3. Directory and Module Layout

Proposed directory (skeleton first, filled in over subsequent rounds):

```text
backend/infra/rag/
  core/
    types.go
    interfaces.go
    pipeline.go
    errors.go
  chunk/
    text_chunker.go
  embed/
    eino_embedder.go
  retrieve/
    vector_retriever.go
  rerank/
    eino_reranker.go
  store/
    vector_store.go
    milvus_store.go
  corpus/
    memory_corpus.go
    web_corpus.go
  config/
    config.go
```

## 4. Core Interface Design (Proposed Signatures)

```go
type Chunker interface {
	Chunk(ctx context.Context, doc SourceDocument, opt ChunkOption) ([]Chunk, error)
}

type Embedder interface {
	Embed(ctx context.Context, texts []string, action string) ([][]float32, error)
}

type Retriever interface {
	Retrieve(ctx context.Context, req RetrieveRequest) ([]ScoredChunk, error)
}

type Reranker interface {
	Rerank(ctx context.Context, query string, candidates []ScoredChunk, topK int) ([]ScoredChunk, error)
}

type VectorStore interface {
	Upsert(ctx context.Context, rows []VectorRow) error
	Search(ctx context.Context, req VectorSearchRequest) ([]ScoredVectorRow, error)
	Delete(ctx context.Context, ids []string) error
	Get(ctx context.Context, ids []string) ([]VectorRow, error)
}

type CorpusAdapter interface {
	Name() string
	BuildIngestDocuments(ctx context.Context, input any) ([]SourceDocument, error)
	BuildRetrieveFilter(ctx context.Context, req any) (map[string]any, error)
}
```

## 5. Unified Flow Conventions

### 5.1 Ingest Flow

1. `CorpusAdapter.BuildIngestDocuments` produces standard documents.
2. `Chunker.Chunk` splits them (fixed chunk_size + overlap).
3. `Embedder.Embed(action=add/update)` produces vectors.
4. `VectorStore.Upsert` writes them.
5. Any failed step is recorded as compensatable state and does not affect the main business path's success.

### 5.2 Retrieve Flow

1. `CorpusAdapter.BuildRetrieveFilter` builds the filter conditions.
2. `Embedder.Embed(action=search)` vectorizes the query.
3. `VectorStore.Search` recalls candidates.
4. Filter by `threshold`.
5. Optional `Reranker` re-ordering; on failure, fall back to the original order and record a reason code.
## 6. The Two Corpus Adapters

### 6.1 MemoryCorpus

1. Data source: `memory_items` (structured memory facts).
2. Hard filter constraints: `user_id + assistant_id + conversation_id`.
3. Metadata: `memory_type/confidence/sensitivity_level/ttl_at/source_event_id`.
4. Injection priority: `constraint/preference` above `fact/todo_hint`.

### 6.2 WebCorpus

1. Data source: websearch fetch results (`url/title/snippet/content`).
2. Hard filter constraints: `query_id/session_id`, to avoid cross-question contamination.
3. Metadata: `domain/published_at/fetched_at/language/source_rank`.
4. Retrieval strategy: vector recall first, then lightweight weighting by domain trustworthiness.

## 7. Eino Integration

1. `embed/eino_embedder.go`: wraps the Eino embedding call.
2. `rerank/eino_reranker.go`: wraps the Eino rerank call.
3. Unified config entry: `rag.enabled/top_k/threshold/reranker_enabled/timeout`.
4. Unified log fields: `trace_id/corpus/action/fallback_reason/latency_ms/hit_count`.

## 8. Phased Rollout (4 Rounds Proposed)

### Round 1: Skeleton (no traffic)

1. Create the `infra/rag` directory with interfaces, types, and error codes.
2. Provide `NoopReranker` and `MockEmbedder` fallback implementations.
3. Acceptance: compiles; main-path behavior unchanged.

### Round 2: MemoryCorpus Integration (canary)

1. Switch memory retrieval from in-module direct access to the RAG Core.
2. Keep the legacy-path switch `memory.rag.enabled`, off by default.
3. Acceptance: feature parity with the switch on; failures auto-degrade to the old path.

### Round 3: WebCorpus Integration (canary)

1. Route websearch recall through the RAG Core.
2. Add the `web.rag.enabled` canary switch.
3. Acceptance: retrieval reuses the same pipeline; quality no worse than the old implementation.

### Round 4: Cutover and Cleanup

1. Enable the RAG Core by default; keep the old path for an observation window.
2. Delete the old implementation once metrics stabilize.
3. Acceptance: both business paths go through the unified interface, with docs and monitoring in place.

## 9. Suggested Configuration

```yaml
rag:
  enabled: true
  topK: 8
  threshold: 0.55
  reranker:
    enabled: true
    timeoutMs: 1200
  ingest:
    chunkSize: 400
    chunkOverlap: 80
  retrieve:
    timeoutMs: 1500
```
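To make the ingest defaults concrete: with `chunkSize: 400` and `chunkOverlap: 80`, consecutive windows advance by `400 - 80 = 320` runes, so adjacent chunks share an 80-rune overlap. A small sketch of that arithmetic (helper name is illustrative, not from the diff):

```go
package main

import "fmt"

// chunkWindows computes the sliding-window boundaries used by the ingest
// config above: step = chunkSize - chunkOverlap, each window capped at the
// text length, stopping once a window reaches the end.
func chunkWindows(textLen, chunkSize, chunkOverlap int) [][2]int {
	step := chunkSize - chunkOverlap
	if step <= 0 {
		step = chunkSize
	}
	var windows [][2]int
	for start := 0; start < textLen; start += step {
		end := start + chunkSize
		if end > textLen {
			end = textLen
		}
		windows = append(windows, [2]int{start, end})
		if end == textLen {
			break
		}
	}
	return windows
}

func main() {
	// A 1000-rune text yields 3 overlapping chunks.
	fmt.Println(chunkWindows(1000, 400, 80)) // [[0 400] [320 720] [640 1000]]
}
```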

## 10. Acceptance Criteria (DoD)

1. One Core serves both Memory and WebSearch.
2. `rerank` failures degrade observably, without affecting main-feature availability.
3. Hit rate, latency, and degradation rate are visible per corpus.
4. Old and new paths are switchable, with a clear rollback path.

## 11. Risks and Mitigations

1. Risk: a one-shot cutover has a large blast radius.
   Mitigation: canary per corpus, Memory first, then Web.
2. Risk: vector-retrieval latency jitter.
   Mitigation: timeouts + fallback + local caching of hot queries.
3. Risk: retrieval leaking data across tenants or scopes.
   Mitigation: enforce filter validation; reject retrieval outright when required dimensions are missing.

## 12. Next Steps (Immediate Implementation)

1. First fill in `core/interfaces.go + core/types.go + core/pipeline.go`.
2. Then `corpus/memory_corpus.go` (the first adapter).
3. Then a placeholder `corpus/web_corpus.go` adapter for websearch.
4. Finally `store/milvus_store.go` and config wiring (docker compose already provisions the Milvus dependency).
85
backend/infra/rag/chunk/text_chunker.go
Normal file
@@ -0,0 +1,85 @@
package chunk

import (
	"context"
	"fmt"
	"strings"

	"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)

// TextChunker is the default text chunker.
type TextChunker struct{}

func NewTextChunker() *TextChunker {
	return &TextChunker{}
}

// Chunk splits text with a fixed sliding window.
//
// Steps:
//  1. Normalize whitespace first so empty chunks never reach the vector store;
//  2. Slide a window of chunk_size with the configured overlap;
//  3. Each chunk inherits the source metadata plus its chunk order.
func (c *TextChunker) Chunk(_ context.Context, doc core.SourceDocument, opt core.ChunkOption) ([]core.Chunk, error) {
	if strings.TrimSpace(doc.ID) == "" {
		return nil, fmt.Errorf("empty document id")
	}
	text := strings.TrimSpace(doc.Text)
	if text == "" {
		return nil, nil
	}
	if opt.ChunkSize <= 0 {
		opt.ChunkSize = 400
	}
	if opt.ChunkOverlap < 0 {
		opt.ChunkOverlap = 0
	}
	if opt.ChunkOverlap >= opt.ChunkSize {
		opt.ChunkOverlap = opt.ChunkSize / 5
	}

	runes := []rune(text)
	step := opt.ChunkSize - opt.ChunkOverlap
	if step <= 0 {
		step = opt.ChunkSize
	}

	result := make([]core.Chunk, 0, len(runes)/step+1)
	order := 0
	for start := 0; start < len(runes); start += step {
		end := start + opt.ChunkSize
		if end > len(runes) {
			end = len(runes)
		}
		chunkText := strings.TrimSpace(string(runes[start:end]))
		if chunkText == "" {
			continue
		}
		metadata := cloneMap(doc.Metadata)
		metadata["chunk_order"] = order
		result = append(result, core.Chunk{
			ID:         fmt.Sprintf("%s#%d", doc.ID, order),
			DocumentID: doc.ID,
			Text:       chunkText,
			Order:      order,
			Metadata:   metadata,
		})
		order++
		if end == len(runes) {
			break
		}
	}
	return result, nil
}

func cloneMap(src map[string]any) map[string]any {
	if len(src) == 0 {
		return map[string]any{}
	}
	dst := make(map[string]any, len(src))
	for k, v := range src {
		dst[k] = v
	}
	return dst
}
52
backend/infra/rag/config/config.go
Normal file
@@ -0,0 +1,52 @@
package config

import "github.com/spf13/viper"

// Config holds the RAG Core runtime configuration.
type Config struct {
	Enabled bool
	TopK    int

	Threshold float64

	RerankerEnabled   bool
	RerankerTimeoutMS int

	ChunkSize    int
	ChunkOverlap int

	RetrieveTimeoutMS int
}

// LoadFromViper reads the rag config and fills in defaults.
func LoadFromViper() Config {
	cfg := Config{
		Enabled:           viper.GetBool("rag.enabled"),
		TopK:              viper.GetInt("rag.topK"),
		Threshold:         viper.GetFloat64("rag.threshold"),
		RerankerEnabled:   viper.GetBool("rag.reranker.enabled"),
		RerankerTimeoutMS: viper.GetInt("rag.reranker.timeoutMs"),
		ChunkSize:         viper.GetInt("rag.ingest.chunkSize"),
		ChunkOverlap:      viper.GetInt("rag.ingest.chunkOverlap"),
		RetrieveTimeoutMS: viper.GetInt("rag.retrieve.timeoutMs"),
	}
	if cfg.TopK <= 0 {
		cfg.TopK = 8
	}
	if cfg.Threshold < 0 {
		cfg.Threshold = 0
	}
	if cfg.RerankerTimeoutMS <= 0 {
		cfg.RerankerTimeoutMS = 1200
	}
	if cfg.ChunkSize <= 0 {
		cfg.ChunkSize = 400
	}
	if cfg.ChunkOverlap < 0 {
		cfg.ChunkOverlap = 80
	}
	if cfg.RetrieveTimeoutMS <= 0 {
		cfg.RetrieveTimeoutMS = 1500
	}
	return cfg
}
17
backend/infra/rag/core/errors.go
Normal file
@@ -0,0 +1,17 @@
package core

import "errors"

var (
	// ErrInvalidQuery means the retrieve request lacks a usable query.
	ErrInvalidQuery = errors.New("invalid query")
	// ErrInvalidTopK means topK is out of range.
	ErrInvalidTopK = errors.New("invalid top_k")
	// ErrNilDependency means a required pipeline dependency was not injected.
	ErrNilDependency = errors.New("nil dependency")
)

const (
	// FallbackReasonRerankFailed marks a degradation after a rerank failure.
	FallbackReasonRerankFailed = "RERANK_FAILED"
)
38
backend/infra/rag/core/interfaces.go
Normal file
@@ -0,0 +1,38 @@
package core

import "context"

// Chunker splits text into chunks.
type Chunker interface {
	Chunk(ctx context.Context, doc SourceDocument, opt ChunkOption) ([]Chunk, error)
}

// Embedder turns text into vectors.
type Embedder interface {
	Embed(ctx context.Context, texts []string, action string) ([][]float32, error)
}

// Retriever recalls candidates.
type Retriever interface {
	Retrieve(ctx context.Context, req RetrieveRequest) ([]ScoredChunk, error)
}

// Reranker re-orders candidates.
type Reranker interface {
	Rerank(ctx context.Context, query string, candidates []ScoredChunk, topK int) ([]ScoredChunk, error)
}

// VectorStore reads and writes the vector store.
type VectorStore interface {
	Upsert(ctx context.Context, rows []VectorRow) error
	Search(ctx context.Context, req VectorSearchRequest) ([]ScoredVectorRow, error)
	Delete(ctx context.Context, ids []string) error
	Get(ctx context.Context, ids []string) ([]VectorRow, error)
}

// CorpusAdapter maps business corpora onto unified documents and filters.
type CorpusAdapter interface {
	Name() string
	BuildIngestDocuments(ctx context.Context, input any) ([]SourceDocument, error)
	BuildRetrieveFilter(ctx context.Context, req any) (map[string]any, error)
}
266
backend/infra/rag/core/pipeline.go
Normal file
@@ -0,0 +1,266 @@
package core

import (
	"context"
	"errors"
	"fmt"
	"log"
	"strings"
	"time"
)

const (
	defaultTopK       = 8
	defaultThreshold  = 0
	defaultChunkSize  = 400
	defaultChunkOvLap = 80
)

// Pipeline is the RAG Core orchestrator.
//
// Responsibility boundaries:
//  1. Owns the unified chunk/embed/retrieve/rerank flow;
//  2. Owns the degradation semantics on failure;
//  3. Carries no business semantics (those come from the CorpusAdapter).
type Pipeline struct {
	chunker  Chunker
	embedder Embedder
	store    VectorStore
	reranker Reranker
	logger   *log.Logger
}

func NewPipeline(chunker Chunker, embedder Embedder, store VectorStore, reranker Reranker) *Pipeline {
	return &Pipeline{
		chunker:  chunker,
		embedder: embedder,
		store:    store,
		reranker: reranker,
		logger:   log.Default(),
	}
}

// Ingest runs the unified ingest flow.
//
// Steps:
//  1. The CorpusAdapter produces unified documents, so every corpus enters the same way;
//  2. Chunking and embedding are done centrally, so callers don't re-implement them;
//  3. A single Upsert at the end; on failure return immediately and let the caller decide whether to retry.
func (p *Pipeline) Ingest(
	ctx context.Context,
	corpus CorpusAdapter,
	input any,
	opt IngestOption,
) (*IngestResult, error) {
	if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil {
		return nil, ErrNilDependency
	}
	if corpus == nil {
		return nil, errors.New("nil corpus adapter")
	}

	docs, err := corpus.BuildIngestDocuments(ctx, input)
	if err != nil {
		return nil, err
	}
	if len(docs) == 0 {
		return &IngestResult{DocumentCount: 0, ChunkCount: 0}, nil
	}

	chunkOpt := normalizeChunkOption(opt.Chunk)
	chunks := make([]Chunk, 0, len(docs)*2)
	for _, doc := range docs {
		// 1. Chunk each document independently; abort on failure so no half-written state is persisted.
		docChunks, chunkErr := p.chunker.Chunk(ctx, doc, chunkOpt)
		if chunkErr != nil {
			return nil, chunkErr
		}
		chunks = append(chunks, docChunks...)
	}
	if len(chunks) == 0 {
		return &IngestResult{DocumentCount: len(docs), ChunkCount: 0}, nil
	}

	texts := make([]string, 0, len(chunks))
	for _, chunk := range chunks {
		texts = append(texts, chunk.Text)
	}

	action := strings.TrimSpace(opt.Action)
	if action == "" {
		action = "add"
	}
	vectors, err := p.embedder.Embed(ctx, texts, action)
	if err != nil {
		return nil, err
	}
	if len(vectors) != len(chunks) {
		return nil, fmt.Errorf("embedding result length mismatch: chunks=%d vectors=%d", len(chunks), len(vectors))
	}

	rows := make([]VectorRow, 0, len(chunks))
	now := time.Now()
	for i, chunk := range chunks {
		metadata := cloneMap(chunk.Metadata)
		metadata["corpus"] = corpus.Name()
		metadata["document_id"] = chunk.DocumentID
		metadata["chunk_order"] = chunk.Order
		rows = append(rows, VectorRow{
			ID:        chunk.ID,
			Vector:    vectors[i],
			Text:      chunk.Text,
			Metadata:  metadata,
			CreatedAt: now,
			UpdatedAt: now,
		})
	}

	if err = p.store.Upsert(ctx, rows); err != nil {
		return nil, err
	}
	return &IngestResult{
		DocumentCount: len(docs),
		ChunkCount:    len(chunks),
	}, nil
}

// Retrieve runs the unified retrieve flow.
//
// Steps:
//  1. Vectorize the query and search the vector store;
//  2. Apply threshold filtering to drop low-quality candidates;
//  3. Optionally rerank; on failure, degrade to the original order and log it.
func (p *Pipeline) Retrieve(
	ctx context.Context,
	corpus CorpusAdapter,
	req RetrieveRequest,
) (*RetrieveResult, error) {
	if p == nil || p.embedder == nil || p.store == nil {
		return nil, ErrNilDependency
	}
	query := strings.TrimSpace(req.Query)
	if query == "" {
		return nil, ErrInvalidQuery
	}

	topK := req.TopK
	if topK <= 0 {
		topK = defaultTopK
	}
	threshold := req.Threshold
	if threshold < 0 {
		threshold = defaultThreshold
	}

	filter := cloneMap(req.Filter)
	if corpus != nil {
		// 1. Merge the corpus filter first so recall never crosses corpora.
		corpusFilter, err := corpus.BuildRetrieveFilter(ctx, req.CorpusInput)
		if err != nil {
			return nil, err
		}
		filter = mergeMap(filter, corpusFilter)
		filter["corpus"] = corpus.Name()
	}

	action := strings.TrimSpace(req.Action)
	if action == "" {
		action = "search"
	}

	vectors, err := p.embedder.Embed(ctx, []string{query}, action)
	if err != nil {
		return nil, err
	}
	if len(vectors) != 1 {
		return nil, fmt.Errorf("embedding query length mismatch: %d", len(vectors))
	}

	scoredRows, err := p.store.Search(ctx, VectorSearchRequest{
		QueryVector: vectors[0],
		TopK:        topK,
		Filter:      filter,
	})
	if err != nil {
		return nil, err
	}

	rawCount := len(scoredRows)
	candidates := make([]ScoredChunk, 0, len(scoredRows))
	for _, row := range scoredRows {
		if row.Score < threshold {
			continue
		}
		candidates = append(candidates, ScoredChunk{
			ChunkID:    row.Row.ID,
			DocumentID: asString(row.Row.Metadata["document_id"]),
			Text:       row.Row.Text,
			Score:      row.Score,
			Metadata:   cloneMap(row.Row.Metadata),
		})
	}

	result := &RetrieveResult{
		Items:        candidates,
		RawCount:     rawCount,
		FallbackUsed: false,
	}
	if len(candidates) == 0 || p.reranker == nil {
		return result, nil
	}

	reranked, rerankErr := p.reranker.Rerank(ctx, query, candidates, topK)
	if rerankErr != nil {
		// 2. A rerank failure must not abort the flow; degrade to the original order.
		result.FallbackUsed = true
		result.FallbackReason = FallbackReasonRerankFailed
		p.logger.Printf("rag rerank fallback: reason=%s err=%v", FallbackReasonRerankFailed, rerankErr)
		return result, nil
	}
	result.Items = reranked
	return result, nil
}

func normalizeChunkOption(opt ChunkOption) ChunkOption {
	if opt.ChunkSize <= 0 {
		opt.ChunkSize = defaultChunkSize
	}
	if opt.ChunkOverlap < 0 {
		opt.ChunkOverlap = 0
	}
	if opt.ChunkOverlap >= opt.ChunkSize {
		opt.ChunkOverlap = defaultChunkOvLap
		if opt.ChunkOverlap >= opt.ChunkSize {
			opt.ChunkOverlap = opt.ChunkSize / 5
		}
	}
	return opt
}

func cloneMap(src map[string]any) map[string]any {
	if len(src) == 0 {
		return map[string]any{}
	}
	dst := make(map[string]any, len(src))
	for key, value := range src {
		dst[key] = value
	}
	return dst
}

func mergeMap(base map[string]any, ext map[string]any) map[string]any {
	if base == nil {
		base = map[string]any{}
	}
	for key, value := range ext {
		base[key] = value
	}
	return base
}

func asString(v any) string {
	if v == nil {
		return ""
	}
	return fmt.Sprintf("%v", v)
}
94
backend/infra/rag/core/types.go
Normal file
@@ -0,0 +1,94 @@
package core

import "time"

// SourceDocument is the unified corpus document model.
//
// Responsibility boundaries:
//  1. Describes only the raw document that can be chunked and indexed;
//  2. Carries no business workflow state.
type SourceDocument struct {
	ID        string
	Text      string
	Title     string
	Metadata  map[string]any
	CreatedAt time.Time
}

// Chunk is the standard chunking result.
type Chunk struct {
	ID         string
	DocumentID string
	Text       string
	Order      int
	Metadata   map[string]any
}

// ChunkOption controls chunking parameters.
type ChunkOption struct {
	ChunkSize    int
	ChunkOverlap int
}

// IngestOption controls ingest parameters.
type IngestOption struct {
	Chunk ChunkOption
	// Action distinguishes the embedding mode (add/update/search).
	Action string
}

// IngestResult summarizes one ingest run.
type IngestResult struct {
	DocumentCount int
	ChunkCount    int
}

// RetrieveRequest is the unified retrieve request.
type RetrieveRequest struct {
	Query       string
	TopK        int
	Threshold   float64
	Action      string
	Filter      map[string]any
	CorpusInput any
}

// ScoredChunk is the unified recall result.
type ScoredChunk struct {
	ChunkID    string
	DocumentID string
	Text       string
	Score      float64
	Metadata   map[string]any
}

// RetrieveResult summarizes one retrieve run.
type RetrieveResult struct {
	Items          []ScoredChunk
	RawCount       int
	FallbackUsed   bool
	FallbackReason string
}

// VectorRow is the standard vector-store row.
type VectorRow struct {
	ID        string
	Vector    []float32
	Text      string
	Metadata  map[string]any
	CreatedAt time.Time
	UpdatedAt time.Time
}

// VectorSearchRequest is the vector search request.
type VectorSearchRequest struct {
	QueryVector []float32
	TopK        int
	Filter      map[string]any
}

// ScoredVectorRow is a vector search result.
type ScoredVectorRow struct {
	Row   VectorRow
	Score float64
}
13
backend/infra/rag/corpus/common.go
Normal file
@@ -0,0 +1,13 @@
package corpus

import (
	"crypto/sha256"
	"encoding/hex"
	"strings"
)

func hashLikeText(text string) string {
	normalized := strings.TrimSpace(strings.ToLower(text))
	sum := sha256.Sum256([]byte(normalized))
	return hex.EncodeToString(sum[:8])
}
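A quick note on `hashLikeText` above: because it lowercases and trims before hashing, casing and surrounding whitespace never change a derived document ID, and the first 8 bytes of the SHA-256 digest yield a 16-character hex string. A standalone copy to illustrate:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strings"
)

// Standalone copy of hashLikeText from corpus/common.go: normalize, hash,
// keep the first 8 digest bytes as hex.
func hashLikeText(text string) string {
	normalized := strings.TrimSpace(strings.ToLower(text))
	sum := sha256.Sum256([]byte(normalized))
	return hex.EncodeToString(sum[:8])
}

func main() {
	fmt.Println(hashLikeText("  Hello World ") == hashLikeText("hello world")) // true
	fmt.Println(len(hashLikeText("x"))) // 16
}
```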
149
backend/infra/rag/corpus/memory_corpus.go
Normal file
@@ -0,0 +1,149 @@
package corpus

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)

const memoryCorpusName = "memory"

// MemoryIngestItem is one memory-corpus ingest item.
type MemoryIngestItem struct {
	MemoryID         int64
	UserID           int
	ConversationID   string
	AssistantID      string
	RunID            string
	MemoryType       string
	Title            string
	Content          string
	SensitivityLevel int
	TTLAt            *time.Time
	CreatedAt        *time.Time
}

// MemoryRetrieveInput is the memory retrieve filter input.
type MemoryRetrieveInput struct {
	UserID         int
	ConversationID string
	AssistantID    string
	RunID          string
	MemoryType     string
}

// MemoryCorpus is the memory corpus adapter.
type MemoryCorpus struct{}

func NewMemoryCorpus() *MemoryCorpus {
	return &MemoryCorpus{}
}

func (c *MemoryCorpus) Name() string {
	return memoryCorpusName
}

func (c *MemoryCorpus) BuildIngestDocuments(_ context.Context, input any) ([]core.SourceDocument, error) {
	items, err := toMemoryItems(input)
	if err != nil {
		return nil, err
	}
	result := make([]core.SourceDocument, 0, len(items))
	for _, item := range items {
		if item.UserID <= 0 {
			return nil, errors.New("memory ingest item user_id is invalid")
		}
		text := strings.TrimSpace(item.Content)
		if text == "" {
			continue
		}
		docID := fmt.Sprintf("memory:%d", item.MemoryID)
		if item.MemoryID <= 0 {
			docID = fmt.Sprintf("memory:uid:%d:%s", item.UserID, hashLikeText(text))
		}
		metadata := map[string]any{
			"user_id":           item.UserID,
			"conversation_id":   strings.TrimSpace(item.ConversationID),
			"assistant_id":      strings.TrimSpace(item.AssistantID),
			"run_id":            strings.TrimSpace(item.RunID),
			"memory_type":       strings.TrimSpace(strings.ToLower(item.MemoryType)),
			"sensitivity_level": item.SensitivityLevel,
		}
		if item.TTLAt != nil {
			metadata["ttl_at"] = item.TTLAt.Format(time.RFC3339)
		}
		createdAt := time.Now()
		if item.CreatedAt != nil {
			createdAt = *item.CreatedAt
		}
		result = append(result, core.SourceDocument{
			ID:        docID,
			Text:      text,
			Title:     strings.TrimSpace(item.Title),
			Metadata:  metadata,
			CreatedAt: createdAt,
		})
	}
	return result, nil
}

func (c *MemoryCorpus) BuildRetrieveFilter(_ context.Context, req any) (map[string]any, error) {
	input, ok := req.(MemoryRetrieveInput)
	if !ok {
		if ptr, isPtr := req.(*MemoryRetrieveInput); isPtr && ptr != nil {
			input = *ptr
		} else if req == nil {
			return nil, errors.New("memory retrieve input is nil")
		} else {
			return nil, errors.New("invalid memory retrieve input")
		}
	}
	if input.UserID <= 0 {
		return nil, errors.New("memory retrieve user_id is invalid")
	}
	filter := map[string]any{
		"user_id": input.UserID,
	}
	if v := strings.TrimSpace(input.ConversationID); v != "" {
		filter["conversation_id"] = v
	}
	if v := strings.TrimSpace(input.AssistantID); v != "" {
		filter["assistant_id"] = v
	}
	if v := strings.TrimSpace(input.RunID); v != "" {
		filter["run_id"] = v
	}
	if v := strings.TrimSpace(strings.ToLower(input.MemoryType)); v != "" {
		filter["memory_type"] = v
	}
	return filter, nil
}

func toMemoryItems(input any) ([]MemoryIngestItem, error) {
	switch value := input.(type) {
	case MemoryIngestItem:
		return []MemoryIngestItem{value}, nil
	case *MemoryIngestItem:
		if value == nil {
			return nil, errors.New("memory ingest item is nil")
		}
		return []MemoryIngestItem{*value}, nil
	case []MemoryIngestItem:
		return value, nil
	case []*MemoryIngestItem:
		items := make([]MemoryIngestItem, 0, len(value))
		for _, ptr := range value {
			if ptr == nil {
				continue
			}
			items = append(items, *ptr)
		}
		return items, nil
	default:
		return nil, errors.New("invalid memory ingest input")
	}
}
163
backend/infra/rag/corpus/web_corpus.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package corpus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
|
||||
)
|
||||
|
||||
const webCorpusName = "web"
|
||||
|
||||
// WebIngestItem 是网页语料入库项。
|
||||
type WebIngestItem struct {
|
||||
URL string
|
||||
Title string
|
||||
Content string
|
||||
Snippet string
|
||||
Domain string
|
||||
QueryID string
|
||||
SessionID string
|
||||
PublishedAt *time.Time
|
||||
FetchedAt *time.Time
|
||||
SourceRank int
|
||||
}
|
||||
|
||||
// WebRetrieveInput 是网页检索过滤输入。
|
||||
type WebRetrieveInput struct {
|
||||
QueryID string
|
||||
SessionID string
|
||||
Domain string
|
||||
}
|
||||
|
||||
// WebCorpus 是网页语料适配器。
|
||||
type WebCorpus struct{}
|
||||
|
||||
func NewWebCorpus() *WebCorpus {
|
||||
return &WebCorpus{}
|
||||
}
|
||||
|
||||
func (c *WebCorpus) Name() string {
|
||||
return webCorpusName
|
||||
}
|
||||
|
||||
func (c *WebCorpus) BuildIngestDocuments(_ context.Context, input any) ([]core.SourceDocument, error) {
|
||||
items, err := toWebItems(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]core.SourceDocument, 0, len(items))
|
||||
for _, item := range items {
|
||||
url := strings.TrimSpace(item.URL)
|
||||
if url == "" {
|
||||
return nil, errors.New("web ingest item url is empty")
|
||||
}
|
||||
|
||||
mainText := buildWebText(item)
|
||||
if strings.TrimSpace(mainText) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
docID := fmt.Sprintf("web:%s", hashLikeText(url+"|"+mainText))
|
||||
metadata := map[string]any{
|
||||
"url": url,
|
||||
"domain": strings.TrimSpace(item.Domain),
|
||||
"query_id": strings.TrimSpace(item.QueryID),
|
||||
"session_id": strings.TrimSpace(item.SessionID),
|
||||
"source_rank": item.SourceRank,
|
||||
}
|
||||
if item.PublishedAt != nil {
|
||||
metadata["published_at"] = item.PublishedAt.Format(time.RFC3339)
|
||||
}
|
||||
if item.FetchedAt != nil {
|
||||
metadata["fetched_at"] = item.FetchedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
createdAt := time.Now()
|
||||
if item.FetchedAt != nil {
|
||||
createdAt = *item.FetchedAt
|
||||
}
|
||||
result = append(result, core.SourceDocument{
|
||||
ID: docID,
|
||||
Text: mainText,
|
||||
Title: strings.TrimSpace(item.Title),
|
||||
Metadata: metadata,
|
||||
CreatedAt: createdAt,
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *WebCorpus) BuildRetrieveFilter(_ context.Context, req any) (map[string]any, error) {
	input, ok := req.(WebRetrieveInput)
	if !ok {
		if ptr, isPtr := req.(*WebRetrieveInput); isPtr && ptr != nil {
			input = *ptr
		} else if req == nil {
			return nil, errors.New("web retrieve input is nil")
		} else {
			return nil, errors.New("invalid web retrieve input")
		}
	}

	// 1. At least one of query_id/session_id is required, to avoid mixing data across questions.
	queryID := strings.TrimSpace(input.QueryID)
	sessionID := strings.TrimSpace(input.SessionID)
	if queryID == "" && sessionID == "" {
		return nil, errors.New("web retrieve filter requires query_id or session_id")
	}

	filter := map[string]any{}
	if queryID != "" {
		filter["query_id"] = queryID
	}
	if sessionID != "" {
		filter["session_id"] = sessionID
	}
	if domain := strings.TrimSpace(input.Domain); domain != "" {
		filter["domain"] = domain
	}
	return filter, nil
}

func toWebItems(input any) ([]WebIngestItem, error) {
	switch value := input.(type) {
	case WebIngestItem:
		return []WebIngestItem{value}, nil
	case *WebIngestItem:
		if value == nil {
			return nil, errors.New("web ingest item is nil")
		}
		return []WebIngestItem{*value}, nil
	case []WebIngestItem:
		return value, nil
	case []*WebIngestItem:
		items := make([]WebIngestItem, 0, len(value))
		for _, ptr := range value {
			if ptr == nil {
				continue
			}
			items = append(items, *ptr)
		}
		return items, nil
	default:
		return nil, errors.New("invalid web ingest input")
	}
}

func buildWebText(item WebIngestItem) string {
	parts := make([]string, 0, 3)
	if title := strings.TrimSpace(item.Title); title != "" {
		parts = append(parts, title)
	}
	if snippet := strings.TrimSpace(item.Snippet); snippet != "" {
		parts = append(parts, snippet)
	}
	if content := strings.TrimSpace(item.Content); content != "" {
		parts = append(parts, content)
	}
	return strings.Join(parts, "\n\n")
}
21
backend/infra/rag/embed/eino_embedder.go
Normal file
@@ -0,0 +1,21 @@
package embed

import (
	"context"
	"errors"
)

// EinoEmbedder is a placeholder for the Eino embedding implementation.
//
// Notes:
// 1. This round only stubs the interface, to avoid coupling to a concrete provider too early.
// 2. When real embedding is wired in later, only the internals of this file change.
type EinoEmbedder struct{}

func NewEinoEmbedder() *EinoEmbedder {
	return &EinoEmbedder{}
}

func (e *EinoEmbedder) Embed(_ context.Context, _ []string, _ string) ([][]float32, error) {
	return nil, errors.New("eino embedder is not implemented yet")
}
46
backend/infra/rag/embed/mock_embedder.go
Normal file
@@ -0,0 +1,46 @@
package embed

import (
	"context"
	"crypto/sha256"
	"encoding/binary"
	"strings"
)

const defaultDim = 16

// MockEmbedder is a locally runnable placeholder embedding implementation.
//
// Notes:
// 1. It exists to wire the pipeline end to end during development; it does not
//    produce real semantic vectors.
// 2. It can later be replaced by an Eino embedding implementation without
//    changing the interface.
type MockEmbedder struct {
	dim int
}

func NewMockEmbedder(dim int) *MockEmbedder {
	if dim <= 0 {
		dim = defaultDim
	}
	return &MockEmbedder{dim: dim}
}

func (e *MockEmbedder) Embed(_ context.Context, texts []string, _ string) ([][]float32, error) {
	result := make([][]float32, 0, len(texts))
	for _, text := range texts {
		result = append(result, e.embedOne(text))
	}
	return result, nil
}

func (e *MockEmbedder) embedOne(text string) []float32 {
	normalized := strings.TrimSpace(strings.ToLower(text))
	sum := sha256.Sum256([]byte(normalized))
	vec := make([]float32, e.dim)
	for i := 0; i < e.dim; i++ {
		offset := (i * 4) % len(sum)
		v := binary.BigEndian.Uint32(sum[offset : offset+4])
		vec[i] = float32(v%1000) / 1000
	}
	return vec
}
23
backend/infra/rag/rag.go
Normal file
@@ -0,0 +1,23 @@
package rag

import (
	"github.com/LoveLosita/smartflow/backend/infra/rag/chunk"
	"github.com/LoveLosita/smartflow/backend/infra/rag/core"
	"github.com/LoveLosita/smartflow/backend/infra/rag/embed"
	"github.com/LoveLosita/smartflow/backend/infra/rag/rerank"
	"github.com/LoveLosita/smartflow/backend/infra/rag/store"
)

// NewDefaultPipeline builds the default runnable RAG pipeline.
//
// Current strategy:
// 1. Default to the local MockEmbedder + InMemoryStore, so the pipeline runs with
//    zero external dependencies.
// 2. When switching to Milvus / Eino later, only the dependencies are replaced;
//    callers are untouched.
func NewDefaultPipeline() *core.Pipeline {
	return core.NewPipeline(
		chunk.NewTextChunker(),
		embed.NewMockEmbedder(16),
		store.NewInMemoryVectorStore(),
		rerank.NewNoopReranker(),
	)
}
19
backend/infra/rag/rerank/eino_reranker.go
Normal file
@@ -0,0 +1,19 @@
package rerank

import (
	"context"
	"errors"

	"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)

// EinoReranker is a placeholder for the Eino reranker implementation.
type EinoReranker struct{}

func NewEinoReranker() *EinoReranker {
	return &EinoReranker{}
}

func (r *EinoReranker) Rerank(_ context.Context, _ string, _ []core.ScoredChunk, _ int) ([]core.ScoredChunk, error) {
	return nil, errors.New("eino reranker is not implemented yet")
}
30
backend/infra/rag/rerank/noop_reranker.go
Normal file
@@ -0,0 +1,30 @@
package rerank

import (
	"context"
	"sort"

	"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)

// NoopReranker is the default reranker: it only sorts by the original score.
type NoopReranker struct{}

func NewNoopReranker() *NoopReranker {
	return &NoopReranker{}
}

func (r *NoopReranker) Rerank(_ context.Context, _ string, candidates []core.ScoredChunk, topK int) ([]core.ScoredChunk, error) {
	if len(candidates) == 0 {
		return nil, nil
	}
	sorted := make([]core.ScoredChunk, len(candidates))
	copy(sorted, candidates)
	sort.SliceStable(sorted, func(i, j int) bool {
		return sorted[i].Score > sorted[j].Score
	})
	if topK <= 0 || topK >= len(sorted) {
		return sorted, nil
	}
	return sorted[:topK], nil
}
90
backend/infra/rag/retrieve/vector_retriever.go
Normal file
@@ -0,0 +1,90 @@
package retrieve

import (
	"context"
	"fmt"
	"strings"

	"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)

// VectorRetriever is the generic retriever (embed + vector search).
type VectorRetriever struct {
	embedder core.Embedder
	store    core.VectorStore
}

func NewVectorRetriever(embedder core.Embedder, store core.VectorStore) *VectorRetriever {
	return &VectorRetriever{
		embedder: embedder,
		store:    store,
	}
}

func (r *VectorRetriever) Retrieve(ctx context.Context, req core.RetrieveRequest) ([]core.ScoredChunk, error) {
	if r == nil || r.embedder == nil || r.store == nil {
		return nil, core.ErrNilDependency
	}
	query := strings.TrimSpace(req.Query)
	if query == "" {
		return nil, core.ErrInvalidQuery
	}
	topK := req.TopK
	if topK <= 0 {
		topK = 8
	}
	action := strings.TrimSpace(req.Action)
	if action == "" {
		action = "search"
	}

	vectors, err := r.embedder.Embed(ctx, []string{query}, action)
	if err != nil {
		return nil, err
	}
	if len(vectors) != 1 {
		return nil, fmt.Errorf("embedding query length mismatch: %d", len(vectors))
	}

	rows, err := r.store.Search(ctx, core.VectorSearchRequest{
		QueryVector: vectors[0],
		TopK:        topK,
		Filter:      req.Filter,
	})
	if err != nil {
		return nil, err
	}

	result := make([]core.ScoredChunk, 0, len(rows))
	for _, row := range rows {
		if row.Score < req.Threshold {
			continue
		}
		result = append(result, core.ScoredChunk{
			ChunkID:    row.Row.ID,
			DocumentID: asString(row.Row.Metadata["document_id"]),
			Text:       row.Row.Text,
			Score:      row.Score,
			Metadata:   cloneMap(row.Row.Metadata),
		})
	}
	return result, nil
}

func cloneMap(src map[string]any) map[string]any {
	if len(src) == 0 {
		return map[string]any{}
	}
	dst := make(map[string]any, len(src))
	for k, v := range src {
		dst[k] = v
	}
	return dst
}

func asString(v any) string {
	if v == nil {
		return ""
	}
	return fmt.Sprintf("%v", v)
}
166
backend/infra/rag/store/inmemory_store.go
Normal file
@@ -0,0 +1,166 @@
package store

import (
	"context"
	"fmt"
	"math"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)

// InMemoryVectorStore is a vector store implementation for local development.
//
// Notes:
// 1. It is meant for development and debugging only; do not use it in production.
// 2. In real environments it can be replaced by MilvusStore with the same interface.
type InMemoryVectorStore struct {
	mu   sync.RWMutex
	rows map[string]core.VectorRow
}

func NewInMemoryVectorStore() *InMemoryVectorStore {
	return &InMemoryVectorStore{
		rows: make(map[string]core.VectorRow),
	}
}

func (s *InMemoryVectorStore) Upsert(_ context.Context, rows []core.VectorRow) error {
	if len(rows) == 0 {
		return nil
	}
	now := time.Now()
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, row := range rows {
		current, exists := s.rows[row.ID]
		if exists {
			row.CreatedAt = current.CreatedAt
			row.UpdatedAt = now
		} else {
			if row.CreatedAt.IsZero() {
				row.CreatedAt = now
			}
			row.UpdatedAt = now
		}
		s.rows[row.ID] = row
	}
	return nil
}

func (s *InMemoryVectorStore) Search(_ context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
	topK := req.TopK
	if topK <= 0 {
		topK = 8
	}

	s.mu.RLock()
	defer s.mu.RUnlock()

	result := make([]core.ScoredVectorRow, 0, len(s.rows))
	for _, row := range s.rows {
		if !matchMetadataFilter(row.Metadata, req.Filter) {
			continue
		}
		score := cosineSimilarity(req.QueryVector, row.Vector)
		result = append(result, core.ScoredVectorRow{
			Row:   row,
			Score: score,
		})
	}
	sort.SliceStable(result, func(i, j int) bool {
		return result[i].Score > result[j].Score
	})
	if len(result) <= topK {
		return result, nil
	}
	return result[:topK], nil
}

func (s *InMemoryVectorStore) Delete(_ context.Context, ids []string) error {
	if len(ids) == 0 {
		return nil
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, id := range ids {
		delete(s.rows, id)
	}
	return nil
}

func (s *InMemoryVectorStore) Get(_ context.Context, ids []string) ([]core.VectorRow, error) {
	if len(ids) == 0 {
		return nil, nil
	}
	s.mu.RLock()
	defer s.mu.RUnlock()
	result := make([]core.VectorRow, 0, len(ids))
	for _, id := range ids {
		row, exists := s.rows[id]
		if !exists {
			continue
		}
		result = append(result, row)
	}
	return result, nil
}

func cosineSimilarity(a, b []float32) float64 {
	if len(a) == 0 || len(b) == 0 {
		return 0
	}
	n := len(a)
	if len(b) < n {
		n = len(b)
	}
	var dot, normA, normB float64
	for i := 0; i < n; i++ {
		av := float64(a[i])
		bv := float64(b[i])
		dot += av * bv
		normA += av * av
		normB += bv * bv
	}
	if normA == 0 || normB == 0 {
		return 0
	}
	return dot / (math.Sqrt(normA) * math.Sqrt(normB))
}

func matchMetadataFilter(metadata map[string]any, filter map[string]any) bool {
	if len(filter) == 0 {
		return true
	}
	for key, wanted := range filter {
		got, exists := metadata[key]
		if !exists {
			return false
		}
		if !equalAny(got, wanted) {
			return false
		}
	}
	return true
}

func equalAny(left any, right any) bool {
	return toString(left) == toString(right)
}

func toString(v any) string {
	if v == nil {
		return ""
	}
	return strings.TrimSpace(fmt.Sprintf("%v", v))
}
35
backend/infra/rag/store/milvus_store.go
Normal file
@@ -0,0 +1,35 @@
package store

import (
	"context"
	"errors"

	"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)

// MilvusStore is a placeholder Milvus connector.
//
// Notes:
// 1. This round only keeps the interface shape, so InMemoryStore can be swapped
//    out smoothly later.
// 2. A real integration still needs connection pooling, collection initialization,
//    metadata filtering, and error translation.
type MilvusStore struct{}

func NewMilvusStore() *MilvusStore {
	return &MilvusStore{}
}

func (s *MilvusStore) Upsert(_ context.Context, _ []core.VectorRow) error {
	return errors.New("milvus store is not implemented yet")
}

func (s *MilvusStore) Search(_ context.Context, _ core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
	return nil, errors.New("milvus store is not implemented yet")
}

func (s *MilvusStore) Delete(_ context.Context, _ []string) error {
	return errors.New("milvus store is not implemented yet")
}

func (s *MilvusStore) Get(_ context.Context, _ []string) ([]core.VectorRow, error) {
	return nil, errors.New("milvus store is not implemented yet")
}
8
backend/infra/rag/store/vector_store.go
Normal file
@@ -0,0 +1,8 @@
package store

import "github.com/LoveLosita/smartflow/backend/infra/rag/core"

// Compile-time check that InMemoryVectorStore satisfies core.VectorStore.
var _ core.VectorStore = (*InMemoryVectorStore)(nil)