Version: 0.9.16.dev.260413

后端:
1. RAG embedding 接入修正,并兼容 Ark 多模态 embedding 链路
   - 更新 backend/infra/rag/embed/eino_embedder.go:文本 embedding 继续走 Eino OpenAI 兼容链路;`doubao-embedding-vision-*` 模型切到 Ark 原生 `/embeddings/multimodal`
   - 增加 embedding baseURL 归一化:兼容把 `.../embeddings` 或 `.../embeddings/multimodal` 误填进配置的情况,统一回退到 `/api/v3`
   - 为第三方 embedding 调用增加 panic recover,避免向量检索/写入异常直接打崩主进程

2. RAG runtime / pipeline / store 稳定性加固,统一降级为 error 语义
   - 更新 backend/infra/rag/runtime.go:runtime 对外入口增加 panic recover 与观测打点
   - 更新 backend/infra/rag/core/pipeline.go:ingest / retrieve 编排边界增加 panic recover
   - 更新 backend/infra/rag/retrieve/vector_retriever.go:向量检索边界补充 panic recover
   - 更新 backend/infra/rag/store/milvus_store.go、backend/infra/rag/store/inmemory_store.go:补齐未初始化保护,避免 nil 依赖直接异常退出

3. RAG embedding 配置口径与普通 LLM 链路对齐
   - 更新 backend/infra/rag/factory.go:RAG embedding API Key 不再走 `apiKeyEnv` 间接映射,统一直接读取 `ARK_API_KEY`
   - 更新 backend/infra/rag/config/config.go:删除 `rag.embed.apiKeyEnv` 配置字段,收敛配置分叉
   - 更新 backend/config.example.yaml:示例配置切到当前联调口径,保持 `rag.enabled=true`、`memory.rag.enabled=true`,并对齐 Milvus / embed 配置

4. Memory + RAG 联调链路可运行态修正
   - 当前已验证 memory 抽取写库、RAG ingest 写入 Milvus、后续语义召回链路可继续联调
   - 检索失败场景已从“直接 panic”收敛为“记录日志并降级”,不再阻断主聊天链路

前端:无
仓库:无

todo:
1. 增删改查的 mysql 记忆去重没实现
2. 提取用户话为记忆的过滤机制不足,有点无脑
3. RAG 召回也有问题
This commit is contained in:
Losita
2026-04-13 23:18:59 +08:00
parent 070d4c3459
commit 863cba4e4e
9 changed files with 297 additions and 53 deletions

View File

@@ -49,15 +49,15 @@ time:
semesterEndDate: "2026-07-19" #学期结束日期,一定要设定为周日,确保最后一周完整
agent:
workerModel: "doubao-seed-1-6-lite-251015" # 智能体使用的Worker模型需根据实际情况调整
strategistModel: "deepseek-v3-2-251201" # 策略师使用的Worker模型需根据实际情况调整
workerModel: "doubao-seed-2-0-code-preview-260215" # 智能体使用的Worker模型需根据实际情况调整
strategistModel: "doubao-seed-2-0-code-preview-260215" # 策略师使用的Worker模型需根据实际情况调整
baseURL: "https://ark.cn-beijing.volces.com/api/v3" # Worker服务的基础URL需根据实际情况调整
dailyRefineConcurrency: 3 # 日内并发优化并发度,建议按模型配额调整
dailyRefineConcurrency: 7 # 日内并发优化并发度,建议按模型配额调整
weeklyAdjustBudget: 5 # 周级跨天配平额度上限,防止过度调整
rag:
enabled: false
store: "inmemory" # 可选inmemory / milvus
enabled: true
store: "milvus" # 可选:inmemory / milvus
topK: 8
threshold: 0.55
retrieve:
@@ -66,16 +66,14 @@ rag:
chunkSize: 400
chunkOverlap: 80
embed:
provider: "mock" # 可选mock / eino
model: "" # 例如 Ark/OpenAI 兼容 embedding 模型名
baseURL: "https://ark.cn-beijing.volces.com/api/v3"
apiKeyEnv: "ARK_API_KEY"
provider: "eino" # 可选:mock / eino
model: "doubao-embedding-vision-251215" # 例如 Ark/OpenAI 兼容 embedding 模型名
baseURL: "https://ark.cn-beijing.volces.com/api/v3" # 这里填服务根路径,SDK 会自动拼接 /embeddings;API Key 统一从环境变量 ARK_API_KEY 读取
timeoutMs: 1200
dimension: 1024
reranker:
enabled: false
provider: "noop" # 当前默认 noop后续可扩展
timeoutMs: 1200
milvus:
address: "http://localhost:19530" # Milvus REST 入口,当前联调确认不要填 9091 健康检查口
token: "root:Milvus"
@@ -87,7 +85,7 @@ rag:
memory:
enabled: true
rag:
enabled: false
enabled: true
prompt:
extract: ""
decision: ""
@@ -103,7 +101,7 @@ memory:
claimBatch: 1
websearch:
provider: mock # 可选mock | bochamock 为空实现,跑通链路用)
provider: bocha # 可选:mock | bocha(mock 为空实现,跑通链路用)
apiKey: "" # 搜索供应商 API Keybocha 模式必填,否则降级为 mock
timeout: 10s # 单次搜索请求超时
fetchTimeout: 15s # 单次 URL 抓取超时

View File

@@ -13,7 +13,6 @@ type Config struct {
EmbedProvider string
EmbedModel string
EmbedBaseURL string
EmbedAPIKeyEnv string
EmbedTimeoutMS int
EmbedDimension int
@@ -44,7 +43,6 @@ func LoadFromViper() Config {
EmbedProvider: viper.GetString("rag.embed.provider"),
EmbedModel: viper.GetString("rag.embed.model"),
EmbedBaseURL: viper.GetString("rag.embed.baseURL"),
EmbedAPIKeyEnv: viper.GetString("rag.embed.apiKeyEnv"),
EmbedTimeoutMS: viper.GetInt("rag.embed.timeoutMs"),
EmbedDimension: viper.GetInt("rag.embed.dimension"),
RerankerEnabled: viper.GetBool("rag.reranker.enabled"),
@@ -75,9 +73,6 @@ func LoadFromViper() Config {
if cfg.EmbedBaseURL == "" {
cfg.EmbedBaseURL = viper.GetString("agent.baseURL")
}
if cfg.EmbedAPIKeyEnv == "" {
cfg.EmbedAPIKeyEnv = "ARK_API_KEY"
}
if cfg.EmbedTimeoutMS <= 0 {
cfg.EmbedTimeoutMS = 1200
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
"runtime/debug"
"strings"
"time"
)
@@ -69,7 +70,9 @@ func (p *Pipeline) Ingest(
corpus CorpusAdapter,
input any,
opt IngestOption,
) (*IngestResult, error) {
) (result *IngestResult, err error) {
defer p.recoverExecutionPanic(ctx, "ingest", &err)
if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil {
return nil, ErrNilDependency
}
@@ -95,7 +98,9 @@ func (p *Pipeline) IngestDocuments(
corpusName string,
docs []SourceDocument,
opt IngestOption,
) (*IngestResult, error) {
) (result *IngestResult, err error) {
defer p.recoverExecutionPanic(ctx, "ingest_documents", &err)
if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil {
return nil, ErrNilDependency
}
@@ -170,7 +175,9 @@ func (p *Pipeline) Retrieve(
ctx context.Context,
corpus CorpusAdapter,
req RetrieveRequest,
) (*RetrieveResult, error) {
) (result *RetrieveResult, err error) {
defer p.recoverExecutionPanic(ctx, "retrieve", &err)
if p == nil || p.embedder == nil || p.store == nil {
return nil, ErrNilDependency
}
@@ -236,7 +243,7 @@ func (p *Pipeline) Retrieve(
})
}
result := &RetrieveResult{
result = &RetrieveResult{
Items: candidates,
RawCount: rawCount,
FallbackUsed: false,
@@ -273,6 +280,39 @@ func (p *Pipeline) Retrieve(
return result, nil
}
// recoverExecutionPanic converts a panic raised inside a pipeline operation
// into an error written through errPtr, so failures in third-party
// chunk/embed/store/rerank dependencies surface as error semantics instead of
// killing the caller's request.
//
// The stack trace goes only to the observability layer (observer first,
// logger as fallback), never into the returned error, so upstream business
// error text stays short while the root cause remains diagnosable.
func (p *Pipeline) recoverExecutionPanic(ctx context.Context, operation string, errPtr *error) {
	recovered := recover()
	if recovered == nil || errPtr == nil {
		return
	}
	panicErr := fmt.Errorf("rag pipeline panic recovered: operation=%s panic=%v", operation, recovered)
	*errPtr = panicErr

	stack := string(debug.Stack())
	switch {
	case p != nil && p.observer != nil:
		p.observer.Observe(ctx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "pipeline",
			Operation: operation + "_panic_recovered",
			Fields: map[string]any{
				"status":     "failed",
				"panic":      fmt.Sprintf("%v", recovered),
				"panic_type": fmt.Sprintf("%T", recovered),
				"error":      panicErr,
				"error_code": ClassifyErrorCode(panicErr),
				"stack":      stack,
			},
		})
	case p != nil && p.logger != nil:
		p.logger.Printf("rag pipeline panic recovered: operation=%s panic=%v stack=%s", operation, recovered, stack)
	}
}
func normalizeChunkOption(opt ChunkOption) ChunkOption {
if opt.ChunkSize <= 0 {
opt.ChunkSize = defaultChunkSize

View File

@@ -3,11 +3,15 @@ package embed
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
openaiembedding "github.com/cloudwego/eino-ext/libs/acl/openai"
einoembedding "github.com/cloudwego/eino/components/embedding"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
arkmodel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)
// EinoConfig 描述 Eino embedding 运行参数。
@@ -22,14 +26,15 @@ type EinoConfig struct {
// EinoEmbedder 是基于 Eino 的 embedding 适配器。
//
// 说明:
// 1. 对 infra/rag 暴露统一 []float32 结果,屏蔽 Eino/OpenAI 兼容实现细节;
// 2. 超时由该适配器自身收口,避免业务侧每次调用都手写超时控制;
// 3. 当前底层走 Eino Ext 的 OpenAI 兼容 embedding client便于接 Ark/OpenAI 兼容接口
// 1. 对 infra/rag 暴露统一 []float32 结果,屏蔽底层 SDK 的实现差异。
// 2. 文本 embedding 继续走当前稳定的 OpenAI 兼容链路,避免无关模型受影响。
// 3. 多模态 embedding 模型单独走 Ark 原生 `/embeddings/multimodal`,解决 vision 模型与标准 `/embeddings` 不兼容的问题
type EinoEmbedder struct {
client einoembedding.Embedder
model string
timeout time.Duration
dimension int
textClient einoembedding.Embedder
multimodalClient *arkruntime.Client
model string
timeout time.Duration
dimension int
}
func NewEinoEmbedder(ctx context.Context, cfg EinoConfig) (*EinoEmbedder, error) {
@@ -40,10 +45,42 @@ func NewEinoEmbedder(ctx context.Context, cfg EinoConfig) (*EinoEmbedder, error)
return nil, errors.New("eino embedder model is empty")
}
timeout := 1200 * time.Millisecond
if cfg.TimeoutMS > 0 {
timeout = time.Duration(cfg.TimeoutMS) * time.Millisecond
}
baseURL := normalizeEmbeddingBaseURL(cfg.BaseURL)
model := strings.TrimSpace(cfg.Model)
httpClient := &http.Client{Timeout: timeout}
// 1. `doubao-embedding-vision-*` 这类模型不支持标准 `/embeddings`。
// 2. 这里直接切到 Ark 原生多模态 embedding API避免再依赖错误 endpoint 拼接。
// 3. 之所以仍保留文本链路,是为了不影响普通 text embedding 模型的既有行为。
if isMultimodalEmbeddingModel(model) {
arkOptions := []arkruntime.ConfigOption{
arkruntime.WithHTTPClient(httpClient),
}
if baseURL != "" {
arkOptions = append(arkOptions, arkruntime.WithBaseUrl(baseURL))
}
return &EinoEmbedder{
multimodalClient: arkruntime.NewClientWithApiKey(
strings.TrimSpace(cfg.APIKey),
arkOptions...,
),
model: model,
timeout: timeout,
dimension: cfg.Dimension,
}, nil
}
clientCfg := &openaiembedding.EmbeddingConfig{
APIKey: strings.TrimSpace(cfg.APIKey),
BaseURL: strings.TrimSpace(cfg.BaseURL),
Model: strings.TrimSpace(cfg.Model),
APIKey: strings.TrimSpace(cfg.APIKey),
BaseURL: baseURL,
Model: model,
HTTPClient: httpClient,
}
if cfg.Dimension > 0 {
clientCfg.Dimensions = &cfg.Dimension
@@ -54,21 +91,16 @@ func NewEinoEmbedder(ctx context.Context, cfg EinoConfig) (*EinoEmbedder, error)
return nil, err
}
timeout := 1200 * time.Millisecond
if cfg.TimeoutMS > 0 {
timeout = time.Duration(cfg.TimeoutMS) * time.Millisecond
}
return &EinoEmbedder{
client: client,
model: strings.TrimSpace(cfg.Model),
timeout: timeout,
dimension: cfg.Dimension,
textClient: client,
model: model,
timeout: timeout,
dimension: cfg.Dimension,
}, nil
}
func (e *EinoEmbedder) Embed(ctx context.Context, texts []string, _ string) ([][]float32, error) {
if e == nil || e.client == nil {
func (e *EinoEmbedder) Embed(ctx context.Context, texts []string, _ string) (result [][]float32, err error) {
if e == nil {
return nil, errors.New("eino embedder is not initialized")
}
if len(texts) == 0 {
@@ -82,12 +114,29 @@ func (e *EinoEmbedder) Embed(ctx context.Context, texts []string, _ string) ([][
}
defer cancel()
vectors, err := e.client.EmbedStrings(callCtx, texts, einoembedding.WithModel(e.model))
// 1. 第三方 SDK 一旦 panic不应该穿透到 RAG 主链路。
// 2. 这里统一在模型调用边界 recover并转成 error 交给上层做降级。
// 3. 这样 memory 主写链路和 agent 主回复链路都不会因为向量同步失败被直接打崩。
defer func() {
if recovered := recover(); recovered != nil {
err = fmt.Errorf("eino embedder panic recovered: %v", recovered)
result = nil
}
}()
if e.multimodalClient != nil {
return e.embedTextsWithMultimodalAPI(callCtx, texts)
}
if e.textClient == nil {
return nil, errors.New("eino embedder client is not initialized")
}
vectors, err := e.textClient.EmbedStrings(callCtx, texts, einoembedding.WithModel(e.model))
if err != nil {
return nil, err
}
result := make([][]float32, 0, len(vectors))
result = make([][]float32, 0, len(vectors))
for _, vector := range vectors {
converted := make([]float32, len(vector))
for i, value := range vector {
@@ -97,3 +146,63 @@ func (e *EinoEmbedder) Embed(ctx context.Context, texts []string, _ string) ([][
}
return result, nil
}
// embedTextsWithMultimodalAPI embeds each text via Ark's native multimodal
// embedding endpoint, issuing one request per text.
//
// Ark models one multimodal sample as a list of parts. RAG only feeds plain
// text here, so every string is sent as its own single-part sample — batching
// several chunks into one request would wrongly fuse them into a single
// sample. If true batched multimodal embedding is ever needed, add a
// dedicated batch API instead of changing the semantics of this loop.
func (e *EinoEmbedder) embedTextsWithMultimodalAPI(ctx context.Context, texts []string) ([][]float32, error) {
	if e.multimodalClient == nil {
		return nil, errors.New("eino multimodal embedder client is not initialized")
	}
	out := make([][]float32, 0, len(texts))
	for _, chunk := range texts {
		chunkCopy := chunk // stable address for the request's text pointer
		request := arkmodel.MultiModalEmbeddingRequest{
			Model: e.model,
			Input: []arkmodel.MultimodalEmbeddingInput{{
				Type: arkmodel.MultiModalEmbeddingInputTypeText,
				Text: &chunkCopy,
			}},
		}
		if e.dimension > 0 {
			request.Dimensions = &e.dimension
		}
		resp, err := e.multimodalClient.CreateMultiModalEmbeddings(ctx, request)
		if err != nil {
			return nil, err
		}
		vector := make([]float32, len(resp.Data.Embedding))
		copy(vector, resp.Data.Embedding)
		out = append(out, vector)
	}
	return out, nil
}
// isMultimodalEmbeddingModel reports whether the model name (ignoring case
// and surrounding whitespace) belongs to the doubao vision embedding family,
// which must call Ark's multimodal endpoint rather than the standard
// /embeddings path.
func isMultimodalEmbeddingModel(model string) bool {
	const visionPrefix = "doubao-embedding-vision-"
	normalized := strings.ToLower(strings.TrimSpace(model))
	return strings.HasPrefix(normalized, visionPrefix)
}
// normalizeEmbeddingBaseURL trims a configured embedding base URL down to the
// service root (e.g. ".../api/v3").
//
// The config should carry the Ark service root rather than a concrete
// embedding endpoint, but ".../embeddings" and ".../embeddings/multimodal"
// are common misconfigurations; both suffixes are stripped here
// (case-insensitively, longest first) so each SDK can append its own correct
// path without duplicated segments.
func normalizeEmbeddingBaseURL(raw string) string {
	baseURL := strings.TrimRight(strings.TrimSpace(raw), "/")
	if baseURL == "" {
		return ""
	}
	lower := strings.ToLower(baseURL)
	for _, suffix := range []string{"/embeddings/multimodal", "/embeddings"} {
		if strings.HasSuffix(lower, suffix) {
			return baseURL[:len(baseURL)-len(suffix)]
		}
	}
	return baseURL
}

View File

@@ -73,9 +73,12 @@ func buildEmbedder(ctx context.Context, cfg ragconfig.Config) (core.Embedder, er
case "", "mock":
return ragembed.NewMockEmbedder(cfg.EmbedDimension), nil
case "eino":
apiKey := strings.TrimSpace(os.Getenv(cfg.EmbedAPIKeyEnv))
// 1. RAG embedding 与普通 LLM 链路保持同一套密钥来源,统一直接读取 ARK_API_KEY
// 2. 这样可以避免再维护一层 “env 名称配置 -> 再读环境变量” 的间接映射,减少配置分叉;
// 3. 若后续真的需要多套 embedding 凭据,再显式设计独立字段,而不是继续隐式透传 env 名称。
apiKey := strings.TrimSpace(os.Getenv("ARK_API_KEY"))
if apiKey == "" {
return nil, fmt.Errorf("rag embed api key is empty: env=%s", cfg.EmbedAPIKeyEnv)
return nil, fmt.Errorf("rag embed api key is empty: env=%s", "ARK_API_KEY")
}
return ragembed.NewEinoEmbedder(ctx, ragembed.EinoConfig{
APIKey: apiKey,

View File

@@ -21,7 +21,15 @@ func NewVectorRetriever(embedder core.Embedder, store core.VectorStore) *VectorR
}
}
func (r *VectorRetriever) Retrieve(ctx context.Context, req core.RetrieveRequest) ([]core.ScoredChunk, error) {
func (r *VectorRetriever) Retrieve(ctx context.Context, req core.RetrieveRequest) (result []core.ScoredChunk, err error) {
defer func() {
recovered := recover()
if recovered == nil {
return
}
err = fmt.Errorf("vector retriever panic recovered: %v", recovered)
}()
if r == nil || r.embedder == nil || r.store == nil {
return nil, core.ErrNilDependency
}
@@ -55,7 +63,7 @@ func (r *VectorRetriever) Retrieve(ctx context.Context, req core.RetrieveRequest
return nil, err
}
result := make([]core.ScoredChunk, 0, len(rows))
result = make([]core.ScoredChunk, 0, len(rows))
for _, row := range rows {
if row.Score < req.Threshold {
continue

View File

@@ -3,6 +3,7 @@ package rag
import (
"context"
"fmt"
"runtime/debug"
"strings"
"time"
@@ -33,7 +34,9 @@ func newRuntime(cfg ragconfig.Config, pipeline *core.Pipeline, observer Observer
}
// IngestMemory 统一承接记忆语料入库。
func (r *runtime) IngestMemory(ctx context.Context, req MemoryIngestRequest) (*IngestResult, error) {
func (r *runtime) IngestMemory(ctx context.Context, req MemoryIngestRequest) (result *IngestResult, err error) {
defer r.recoverPublicPanic(ctx, req.TraceID, "memory", normalizeAction(req.Action, "add"), "ingest", &err)
items := make([]corpus.MemoryIngestItem, 0, len(req.Items))
for _, item := range req.Items {
items = append(items, corpus.MemoryIngestItem{
@@ -58,7 +61,9 @@ func (r *runtime) IngestMemory(ctx context.Context, req MemoryIngestRequest) (*I
}
// RetrieveMemory 统一承接记忆语料检索。
func (r *runtime) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (*RetrieveResult, error) {
func (r *runtime) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (result *RetrieveResult, err error) {
defer r.recoverPublicPanic(ctx, req.TraceID, "memory", normalizeAction(req.Action, "search"), "retrieve", &err)
corpusInput := corpus.MemoryRetrieveInput{
UserID: req.UserID,
ConversationID: req.ConversationID,
@@ -69,7 +74,7 @@ func (r *runtime) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest)
corpusInput.MemoryType = req.MemoryTypes[0]
}
result, err := r.retrieveWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, core.RetrieveRequest{
result, err = r.retrieveWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, core.RetrieveRequest{
Query: req.Query,
TopK: normalizeTopK(req.TopK, r.cfg.TopK),
Threshold: normalizeThreshold(req.Threshold, r.cfg.Threshold),
@@ -113,7 +118,9 @@ func (r *runtime) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest)
}
// IngestWeb 统一承接网页语料入库。
func (r *runtime) IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestResult, error) {
func (r *runtime) IngestWeb(ctx context.Context, req WebIngestRequest) (result *IngestResult, err error) {
defer r.recoverPublicPanic(ctx, req.TraceID, "web", normalizeAction(req.Action, "add"), "ingest", &err)
items := make([]corpus.WebIngestItem, 0, len(req.Items))
for _, item := range req.Items {
items = append(items, corpus.WebIngestItem{
@@ -133,7 +140,9 @@ func (r *runtime) IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestR
}
// RetrieveWeb 统一承接网页语料检索。
func (r *runtime) RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (*RetrieveResult, error) {
func (r *runtime) RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (result *RetrieveResult, err error) {
defer r.recoverPublicPanic(ctx, req.TraceID, "web", normalizeAction(req.Action, "search"), "retrieve", &err)
return r.retrieveWithCorpus(ctx, req.TraceID, "web", r.webCorpus, core.RetrieveRequest{
Query: req.Query,
TopK: normalizeTopK(req.TopK, r.cfg.TopK),
@@ -311,6 +320,41 @@ func (r *runtime) observe(ctx context.Context, event ObserveEvent) {
r.observer.Observe(ctx, event)
}
// recoverPublicPanic is the last-resort panic barrier on the runtime's public
// method surface: any panic escaping the lower RAG layers is converted into an
// error written through errPtr instead of crashing the calling goroutine.
//
// A structured observe event — including the stack trace — is emitted so the
// root cause stays diagnosable even though the process survives.
func (r *runtime) recoverPublicPanic(
	ctx context.Context,
	traceID string,
	corpusName string,
	action string,
	operation string,
	errPtr *error,
) {
	recovered := recover()
	if recovered == nil || errPtr == nil {
		return
	}
	panicErr := fmt.Errorf("rag runtime panic recovered: corpus=%s operation=%s panic=%v", corpusName, operation, recovered)
	*errPtr = panicErr

	r.observe(newObserveContext(ctx, traceID, corpusName, action), ObserveEvent{
		Level:     ObserveLevelError,
		Component: "runtime",
		Operation: operation + "_panic_recovered",
		Fields: map[string]any{
			"status":     "failed",
			"panic":      fmt.Sprintf("%v", recovered),
			"panic_type": fmt.Sprintf("%T", recovered),
			"error":      panicErr,
			"error_code": core.ClassifyErrorCode(panicErr),
			"stack":      string(debug.Stack()),
		},
	})
}
func newObserveContext(ctx context.Context, traceID string, corpusName string, action string) context.Context {
fields := map[string]any{
"corpus": corpusName,

View File

@@ -2,6 +2,7 @@ package store
import (
"context"
"errors"
"fmt"
"math"
"sort"
@@ -29,12 +30,18 @@ func NewInMemoryVectorStore() *InMemoryVectorStore {
}
func (s *InMemoryVectorStore) Upsert(_ context.Context, rows []core.VectorRow) error {
if s == nil {
return errors.New("inmemory vector store is nil")
}
if len(rows) == 0 {
return nil
}
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
if s.rows == nil {
s.rows = make(map[string]core.VectorRow)
}
for _, row := range rows {
current, exists := s.rows[row.ID]
if exists {
@@ -52,6 +59,9 @@ func (s *InMemoryVectorStore) Upsert(_ context.Context, rows []core.VectorRow) e
}
func (s *InMemoryVectorStore) Search(_ context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
if s == nil {
return nil, errors.New("inmemory vector store is nil")
}
topK := req.TopK
if topK <= 0 {
topK = 8
@@ -81,6 +91,9 @@ func (s *InMemoryVectorStore) Search(_ context.Context, req core.VectorSearchReq
}
func (s *InMemoryVectorStore) Delete(_ context.Context, ids []string) error {
if s == nil {
return errors.New("inmemory vector store is nil")
}
if len(ids) == 0 {
return nil
}
@@ -93,6 +106,9 @@ func (s *InMemoryVectorStore) Delete(_ context.Context, ids []string) error {
}
func (s *InMemoryVectorStore) Get(_ context.Context, ids []string) ([]core.VectorRow, error) {
if s == nil {
return nil, errors.New("inmemory vector store is nil")
}
if len(ids) == 0 {
return nil, nil
}

View File

@@ -108,6 +108,9 @@ func NewMilvusStore(cfg MilvusConfig) (*MilvusStore, error) {
}
func (s *MilvusStore) Upsert(ctx context.Context, rows []core.VectorRow) error {
if err := s.ensureReady(); err != nil {
return err
}
start := time.Now()
if len(rows) == 0 {
return nil
@@ -171,6 +174,9 @@ func (s *MilvusStore) Upsert(ctx context.Context, rows []core.VectorRow) error {
}
func (s *MilvusStore) Search(ctx context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
if err := s.ensureReady(); err != nil {
return nil, err
}
start := time.Now()
if len(req.QueryVector) == 0 {
return nil, nil
@@ -314,6 +320,9 @@ func (s *MilvusStore) Search(ctx context.Context, req core.VectorSearchRequest)
}
func (s *MilvusStore) Delete(ctx context.Context, ids []string) error {
if err := s.ensureReady(); err != nil {
return err
}
start := time.Now()
if len(ids) == 0 {
return nil
@@ -356,6 +365,9 @@ func (s *MilvusStore) Delete(ctx context.Context, ids []string) error {
}
func (s *MilvusStore) Get(ctx context.Context, ids []string) ([]core.VectorRow, error) {
if err := s.ensureReady(); err != nil {
return nil, err
}
start := time.Now()
if len(ids) == 0 {
return nil, nil
@@ -438,6 +450,9 @@ func (s *MilvusStore) Get(ctx context.Context, ids []string) ([]core.VectorRow,
}
func (s *MilvusStore) ensureCollection(ctx context.Context, dimension int) error {
if err := s.ensureReady(); err != nil {
return err
}
start := time.Now()
if dimension <= 0 {
dimension = s.cfg.Dimension
@@ -537,6 +552,9 @@ func (s *MilvusStore) ensureCollection(ctx context.Context, dimension int) error
}
func (s *MilvusStore) postJSON(ctx context.Context, path string, payload map[string]any) ([]byte, error) {
if err := s.ensureReady(); err != nil {
return nil, err
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
@@ -577,6 +595,19 @@ func (s *MilvusStore) postJSON(ctx context.Context, path string, payload map[str
return respBody, nil
}
// ensureReady validates the store and its mandatory configuration before any
// network call, turning a nil receiver/client or a blank address/collection
// into a plain error instead of a panic deeper in the call chain.
func (s *MilvusStore) ensureReady() error {
	switch {
	case s == nil || s.client == nil:
		return errors.New("milvus store is not initialized")
	case strings.TrimSpace(s.cfg.Address) == "":
		return errors.New("milvus address is empty")
	case strings.TrimSpace(s.cfg.CollectionName) == "":
		return errors.New("milvus collection name is empty")
	default:
		return nil
	}
}
func (s *MilvusStore) observe(ctx context.Context, event core.ObserveEvent) {
if s == nil || s.observer == nil {
return