diff --git a/backend/config.example.yaml b/backend/config.example.yaml index 88b2e59..aaa7c83 100644 --- a/backend/config.example.yaml +++ b/backend/config.example.yaml @@ -49,15 +49,15 @@ time: semesterEndDate: "2026-07-19" #学期结束日期,一定要设定为周日,确保最后一周完整 agent: - workerModel: "doubao-seed-1-6-lite-251015" # 智能体使用的Worker模型,需根据实际情况调整 - strategistModel: "deepseek-v3-2-251201" # 策略师使用的Worker模型,需根据实际情况调整 + workerModel: "doubao-seed-2-0-code-preview-260215" # 智能体使用的Worker模型,需根据实际情况调整 + strategistModel: "doubao-seed-2-0-code-preview-260215" # 策略师使用的Worker模型,需根据实际情况调整 baseURL: "https://ark.cn-beijing.volces.com/api/v3" # Worker服务的基础URL,需根据实际情况调整 - dailyRefineConcurrency: 3 # 日内并发优化并发度,建议按模型配额调整 + dailyRefineConcurrency: 7 # 日内并发优化并发度,建议按模型配额调整 weeklyAdjustBudget: 5 # 周级跨天配平额度上限,防止过度调整 rag: - enabled: false - store: "inmemory" # 可选:inmemory / milvus + enabled: true + store: "milvus" # 可选:inmemory / milvus topK: 8 threshold: 0.55 retrieve: @@ -66,16 +66,14 @@ rag: chunkSize: 400 chunkOverlap: 80 embed: - provider: "mock" # 可选:mock / eino - model: "" # 例如 Ark/OpenAI 兼容 embedding 模型名 - baseURL: "https://ark.cn-beijing.volces.com/api/v3" - apiKeyEnv: "ARK_API_KEY" + provider: "eino" # 可选:mock / eino + model: "doubao-embedding-vision-251215" # 例如 Ark/OpenAI 兼容 embedding 模型名 + baseURL: "https://ark.cn-beijing.volces.com/api/v3" # 这里填服务根路径,SDK 会自动拼接 /embeddings;API Key 统一从环境变量 ARK_API_KEY 读取 timeoutMs: 1200 dimension: 1024 reranker: enabled: false provider: "noop" # 当前默认 noop,后续可扩展 - timeoutMs: 1200 milvus: address: "http://localhost:19530" # Milvus REST 入口,当前联调确认不要填 9091 健康检查口 token: "root:Milvus" @@ -87,7 +85,7 @@ rag: memory: enabled: true rag: - enabled: false + enabled: true prompt: extract: "" decision: "" @@ -103,7 +101,7 @@ memory: claimBatch: 1 websearch: - provider: mock # 可选:mock | bocha(mock 为空实现,跑通链路用) + provider: bocha # 可选:mock | bocha(mock 为空实现,跑通链路用) apiKey: "" # 搜索供应商 API Key(bocha 模式必填,否则降级为 mock) timeout: 10s # 单次搜索请求超时 fetchTimeout: 15s # 单次 URL 抓取超时 diff --git a/backend/infra/rag/config/config.go b/backend/infra/rag/config/config.go index 378a532..9277ec1 100644 --- a/backend/infra/rag/config/config.go +++ b/backend/infra/rag/config/config.go @@ -13,7 +13,6 @@ type Config struct { EmbedProvider string EmbedModel string EmbedBaseURL string - EmbedAPIKeyEnv string EmbedTimeoutMS int EmbedDimension int @@ -44,7 +43,6 @@ func LoadFromViper() Config { EmbedProvider: viper.GetString("rag.embed.provider"), EmbedModel: viper.GetString("rag.embed.model"), EmbedBaseURL: viper.GetString("rag.embed.baseURL"), - EmbedAPIKeyEnv: viper.GetString("rag.embed.apiKeyEnv"), EmbedTimeoutMS: viper.GetInt("rag.embed.timeoutMs"), EmbedDimension: viper.GetInt("rag.embed.dimension"), RerankerEnabled: viper.GetBool("rag.reranker.enabled"), @@ -75,9 +73,6 @@ func LoadFromViper() Config { if cfg.EmbedBaseURL == "" { cfg.EmbedBaseURL = viper.GetString("agent.baseURL") } - if cfg.EmbedAPIKeyEnv == "" { - cfg.EmbedAPIKeyEnv = "ARK_API_KEY" - } if cfg.EmbedTimeoutMS <= 0 { cfg.EmbedTimeoutMS = 1200 } diff --git a/backend/infra/rag/core/pipeline.go b/backend/infra/rag/core/pipeline.go index 68bb21e..52e5745 100644 --- a/backend/infra/rag/core/pipeline.go +++ b/backend/infra/rag/core/pipeline.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log" + "runtime/debug" "strings" "time" ) @@ -69,7 +70,9 @@ func (p *Pipeline) Ingest( corpus CorpusAdapter, input any, opt IngestOption, -) (*IngestResult, error) { +) (result *IngestResult, err error) { + defer p.recoverExecutionPanic(ctx, "ingest", &err) + if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil { return nil, ErrNilDependency } @@ -95,7 +98,9 @@ func (p *Pipeline) IngestDocuments( corpusName string, docs []SourceDocument, opt IngestOption, -) (*IngestResult, error) { +) (result *IngestResult, err error) { + defer p.recoverExecutionPanic(ctx, "ingest_documents", &err) + if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil { return nil, ErrNilDependency } @@ -170,7 +175,9 @@ func (p *Pipeline) Retrieve( ctx context.Context, corpus CorpusAdapter, req RetrieveRequest, -) (*RetrieveResult, error) { +) (result *RetrieveResult, err error) { + defer p.recoverExecutionPanic(ctx, "retrieve", &err) + if p == nil || p.embedder == nil || p.store == nil { return nil, ErrNilDependency } @@ -236,7 +243,7 @@ func (p *Pipeline) Retrieve( }) } - result := &RetrieveResult{ + result = &RetrieveResult{ Items: candidates, RawCount: rawCount, FallbackUsed: false, @@ -273,6 +280,39 @@ func (p *Pipeline) Retrieve( return result, nil } +func (p *Pipeline) recoverExecutionPanic(ctx context.Context, operation string, errPtr *error) { + recovered := recover() + if recovered == nil || errPtr == nil { + return + } + + panicErr := fmt.Errorf("rag pipeline panic recovered: operation=%s panic=%v", operation, recovered) + *errPtr = panicErr + + // 1. Pipeline 是 chunk/embed/store/rerank 的统一编排边界,第三方依赖异常不应直接杀掉上层请求。 + // 2. 这里统一 recover 后继续走 error 语义,让 runtime/service 决定降级、回退或记日志。 + // 3. stack 只写观测层,不塞进返回值,避免把超长堆栈直接暴露给上层业务错误文案。 + if p != nil && p.observer != nil { + p.observer.Observe(ctx, ObserveEvent{ + Level: ObserveLevelError, + Component: "pipeline", + Operation: operation + "_panic_recovered", + Fields: map[string]any{ + "status": "failed", + "panic": fmt.Sprintf("%v", recovered), + "panic_type": fmt.Sprintf("%T", recovered), + "error": panicErr, + "error_code": ClassifyErrorCode(panicErr), + "stack": string(debug.Stack()), + }, + }) + return + } + if p != nil && p.logger != nil { + p.logger.Printf("rag pipeline panic recovered: operation=%s panic=%v stack=%s", operation, recovered, string(debug.Stack())) + } +} + func normalizeChunkOption(opt ChunkOption) ChunkOption { if opt.ChunkSize <= 0 { opt.ChunkSize = defaultChunkSize diff --git a/backend/infra/rag/embed/eino_embedder.go b/backend/infra/rag/embed/eino_embedder.go index 7505673..4587b07 100644 --- a/backend/infra/rag/embed/eino_embedder.go +++ b/backend/infra/rag/embed/eino_embedder.go @@ -3,11 +3,15 @@ package embed import ( "context" "errors" + "fmt" + "net/http" "strings" "time" openaiembedding "github.com/cloudwego/eino-ext/libs/acl/openai" einoembedding "github.com/cloudwego/eino/components/embedding" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime" + arkmodel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" ) // EinoConfig 描述 Eino embedding 运行参数。 @@ -22,14 +26,15 @@ type EinoConfig struct { // EinoEmbedder 是基于 Eino 的 embedding 适配器。 // // 说明: -// 1. 对 infra/rag 暴露统一 []float32 结果,屏蔽 Eino/OpenAI 兼容实现细节; -// 2. 超时由该适配器自身收口,避免业务侧每次调用都手写超时控制; -// 3. 当前底层走 Eino Ext 的 OpenAI 兼容 embedding client,便于接 Ark/OpenAI 兼容接口。 +// 1. 对 infra/rag 暴露统一 []float32 结果,屏蔽底层 SDK 的实现差异。 +// 2. 文本 embedding 继续走当前稳定的 OpenAI 兼容链路,避免无关模型受影响。 +// 3. 多模态 embedding 模型单独走 Ark 原生 `/embeddings/multimodal`,解决 vision 模型与标准 `/embeddings` 不兼容的问题。 type EinoEmbedder struct { - client einoembedding.Embedder - model string - timeout time.Duration - dimension int + textClient einoembedding.Embedder + multimodalClient *arkruntime.Client + model string + timeout time.Duration + dimension int } func NewEinoEmbedder(ctx context.Context, cfg EinoConfig) (*EinoEmbedder, error) { @@ -40,10 +45,42 @@ func NewEinoEmbedder(ctx context.Context, cfg EinoConfig) (*EinoEmbedder, error) return nil, errors.New("eino embedder model is empty") } + timeout := 1200 * time.Millisecond + if cfg.TimeoutMS > 0 { + timeout = time.Duration(cfg.TimeoutMS) * time.Millisecond + } + + baseURL := normalizeEmbeddingBaseURL(cfg.BaseURL) + model := strings.TrimSpace(cfg.Model) + httpClient := &http.Client{Timeout: timeout} + + // 1. `doubao-embedding-vision-*` 这类模型不支持标准 `/embeddings`。 + // 2. 这里直接切到 Ark 原生多模态 embedding API,避免再依赖错误 endpoint 拼接。 + // 3. 之所以仍保留文本链路,是为了不影响普通 text embedding 模型的既有行为。 + if isMultimodalEmbeddingModel(model) { + arkOptions := []arkruntime.ConfigOption{ + arkruntime.WithHTTPClient(httpClient), + } + if baseURL != "" { + arkOptions = append(arkOptions, arkruntime.WithBaseUrl(baseURL)) + } + + return &EinoEmbedder{ + multimodalClient: arkruntime.NewClientWithApiKey( + strings.TrimSpace(cfg.APIKey), + arkOptions..., + ), + model: model, + timeout: timeout, + dimension: cfg.Dimension, + }, nil + } + clientCfg := &openaiembedding.EmbeddingConfig{ - APIKey: strings.TrimSpace(cfg.APIKey), - BaseURL: strings.TrimSpace(cfg.BaseURL), - Model: strings.TrimSpace(cfg.Model), + APIKey: strings.TrimSpace(cfg.APIKey), + BaseURL: baseURL, + Model: model, + HTTPClient: httpClient, } if cfg.Dimension > 0 { clientCfg.Dimensions = &cfg.Dimension @@ -54,21 +91,16 @@ func NewEinoEmbedder(ctx context.Context, cfg EinoConfig) (*EinoEmbedder, error) return nil, err } - timeout := 1200 * time.Millisecond - if cfg.TimeoutMS > 0 { - timeout = time.Duration(cfg.TimeoutMS) * time.Millisecond - } - return &EinoEmbedder{ - client: client, - model: strings.TrimSpace(cfg.Model), - timeout: timeout, - dimension: cfg.Dimension, + textClient: client, + model: model, + timeout: timeout, + dimension: cfg.Dimension, }, nil } -func (e *EinoEmbedder) Embed(ctx context.Context, texts []string, _ string) ([][]float32, error) { - if e == nil || e.client == nil { +func (e *EinoEmbedder) Embed(ctx context.Context, texts []string, _ string) (result [][]float32, err error) { + if e == nil { return nil, errors.New("eino embedder is not initialized") } if len(texts) == 0 { @@ -82,12 +114,29 @@ func (e *EinoEmbedder) Embed(ctx context.Context, texts []string, _ string) ([][ } defer cancel() - vectors, err := e.client.EmbedStrings(callCtx, texts, einoembedding.WithModel(e.model)) + // 1. 第三方 SDK 一旦 panic,不应该穿透到 RAG 主链路。 + // 2. 这里统一在模型调用边界 recover,并转成 error 交给上层做降级。 + // 3. 这样 memory 主写链路和 agent 主回复链路都不会因为向量同步失败被直接打崩。 + defer func() { + if recovered := recover(); recovered != nil { + err = fmt.Errorf("eino embedder panic recovered: %v", recovered) + result = nil + } + }() + + if e.multimodalClient != nil { + return e.embedTextsWithMultimodalAPI(callCtx, texts) + } + if e.textClient == nil { + return nil, errors.New("eino embedder client is not initialized") + } + + vectors, err := e.textClient.EmbedStrings(callCtx, texts, einoembedding.WithModel(e.model)) if err != nil { return nil, err } - result := make([][]float32, 0, len(vectors)) + result = make([][]float32, 0, len(vectors)) for _, vector := range vectors { converted := make([]float32, len(vector)) for i, value := range vector { @@ -97,3 +146,63 @@ func (e *EinoEmbedder) Embed(ctx context.Context, texts []string, _ string) ([][ } return result, nil } + +func (e *EinoEmbedder) embedTextsWithMultimodalAPI(ctx context.Context, texts []string) ([][]float32, error) { + if e.multimodalClient == nil { + return nil, errors.New("eino multimodal embedder client is not initialized") + } + + vectors := make([][]float32, 0, len(texts)) + for _, text := range texts { + text := text + req := arkmodel.MultiModalEmbeddingRequest{ + Model: e.model, + Input: []arkmodel.MultimodalEmbeddingInput{ + { + Type: arkmodel.MultiModalEmbeddingInputTypeText, + Text: &text, + }, + }, + } + if e.dimension > 0 { + req.Dimensions = &e.dimension + } + + // 1. Ark 的多模态 embedding 请求体是“单条内容由多个 part 组成”。 + // 2. 当前 RAG 这里只传文本,因此每段文本单独发一次,避免把多段文本错误拼成同一个 multimodal sample。 + // 3. 一旦后续真的要做批量多模态 embedding,再单独扩展 batch 接口,而不是在这里偷改语义。 + resp, err := e.multimodalClient.CreateMultiModalEmbeddings(ctx, req) + if err != nil { + return nil, err + } + + converted := make([]float32, len(resp.Data.Embedding)) + copy(converted, resp.Data.Embedding) + vectors = append(vectors, converted) + } + return vectors, nil +} + +func isMultimodalEmbeddingModel(model string) bool { + return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "doubao-embedding-vision-") +} + +func normalizeEmbeddingBaseURL(raw string) string { + baseURL := strings.TrimRight(strings.TrimSpace(raw), "/") + if baseURL == "" { + return "" + } + + lowerBaseURL := strings.ToLower(baseURL) + + // 1. 配置里应填写 Ark 服务根路径,而不是具体 embedding endpoint。 + // 2. 这里兼容两类常见误配:`/embeddings` 和 `/embeddings/multimodal`。 + // 3. 统一回退到 `/api/v3` 根路径后,再由对应 SDK 自己追加正确后缀,避免最终 URL 重复拼接。 + if strings.HasSuffix(lowerBaseURL, "/embeddings/multimodal") { + return strings.TrimSuffix(baseURL, baseURL[len(baseURL)-len("/embeddings/multimodal"):]) + } + if strings.HasSuffix(lowerBaseURL, "/embeddings") { + return strings.TrimSuffix(baseURL, baseURL[len(baseURL)-len("/embeddings"):]) + } + return baseURL +} diff --git a/backend/infra/rag/factory.go b/backend/infra/rag/factory.go index 2a64573..c1e8eb0 100644 --- a/backend/infra/rag/factory.go +++ b/backend/infra/rag/factory.go @@ -73,9 +73,12 @@ func buildEmbedder(ctx context.Context, cfg ragconfig.Config) (core.Embedder, er case "", "mock": return ragembed.NewMockEmbedder(cfg.EmbedDimension), nil case "eino": - apiKey := strings.TrimSpace(os.Getenv(cfg.EmbedAPIKeyEnv)) + // 1. RAG embedding 与普通 LLM 链路保持同一套密钥来源,统一直接读取 ARK_API_KEY; + // 2. 这样可以避免再维护一层 “env 名称配置 -> 再读环境变量” 的间接映射,减少配置分叉; + // 3. 若后续真的需要多套 embedding 凭据,再显式设计独立字段,而不是继续隐式透传 env 名称。 + apiKey := strings.TrimSpace(os.Getenv("ARK_API_KEY")) if apiKey == "" { - return nil, fmt.Errorf("rag embed api key is empty: env=%s", cfg.EmbedAPIKeyEnv) + return nil, fmt.Errorf("rag embed api key is empty: env=%s", "ARK_API_KEY") } return ragembed.NewEinoEmbedder(ctx, ragembed.EinoConfig{ APIKey: apiKey, diff --git a/backend/infra/rag/retrieve/vector_retriever.go b/backend/infra/rag/retrieve/vector_retriever.go index c04d6c1..0d00950 100644 --- a/backend/infra/rag/retrieve/vector_retriever.go +++ b/backend/infra/rag/retrieve/vector_retriever.go @@ -21,7 +21,15 @@ func NewVectorRetriever(embedder core.Embedder, store core.VectorStore) *VectorR } } -func (r *VectorRetriever) Retrieve(ctx context.Context, req core.RetrieveRequest) ([]core.ScoredChunk, error) { +func (r *VectorRetriever) Retrieve(ctx context.Context, req core.RetrieveRequest) (result []core.ScoredChunk, err error) { + defer func() { + recovered := recover() + if recovered == nil { + return + } + err = fmt.Errorf("vector retriever panic recovered: %v", recovered) + }() + if r == nil || r.embedder == nil || r.store == nil { return nil, core.ErrNilDependency } @@ -55,7 +63,7 @@ func (r *VectorRetriever) Retrieve(ctx context.Context, req core.RetrieveRequest return nil, err } - result := make([]core.ScoredChunk, 0, len(rows)) + result = make([]core.ScoredChunk, 0, len(rows)) for _, row := range rows { if row.Score < req.Threshold { continue diff --git a/backend/infra/rag/runtime.go b/backend/infra/rag/runtime.go index 830c90c..b995ecc 100644 --- a/backend/infra/rag/runtime.go +++ b/backend/infra/rag/runtime.go @@ -3,6 +3,7 @@ package rag import ( "context" "fmt" + "runtime/debug" "strings" "time" @@ -33,7 +34,9 @@ func newRuntime(cfg ragconfig.Config, pipeline *core.Pipeline, observer Observer } // IngestMemory 统一承接记忆语料入库。 -func (r *runtime) IngestMemory(ctx context.Context, req MemoryIngestRequest) (*IngestResult, error) { +func (r *runtime) IngestMemory(ctx context.Context, req MemoryIngestRequest) (result *IngestResult, err error) { + defer r.recoverPublicPanic(ctx, req.TraceID, "memory", normalizeAction(req.Action, "add"), "ingest", &err) + items := make([]corpus.MemoryIngestItem, 0, len(req.Items)) for _, item := range req.Items { items = append(items, corpus.MemoryIngestItem{ @@ -58,7 +61,9 @@ func (r *runtime) IngestMemory(ctx context.Context, req MemoryIngestRequest) (*I } // RetrieveMemory 统一承接记忆语料检索。 -func (r *runtime) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (*RetrieveResult, error) { +func (r *runtime) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (result *RetrieveResult, err error) { + defer r.recoverPublicPanic(ctx, req.TraceID, "memory", normalizeAction(req.Action, "search"), "retrieve", &err) + corpusInput := corpus.MemoryRetrieveInput{ UserID: req.UserID, ConversationID: req.ConversationID, @@ -69,7 +74,7 @@ func (r *runtime) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) corpusInput.MemoryType = req.MemoryTypes[0] } - result, err := r.retrieveWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, core.RetrieveRequest{ + result, err = r.retrieveWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, core.RetrieveRequest{ Query: req.Query, TopK: normalizeTopK(req.TopK, r.cfg.TopK), Threshold: normalizeThreshold(req.Threshold, r.cfg.Threshold), @@ -113,7 +118,9 @@ func (r *runtime) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) } // IngestWeb 统一承接网页语料入库。 -func (r *runtime) IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestResult, error) { +func (r *runtime) IngestWeb(ctx context.Context, req WebIngestRequest) (result *IngestResult, err error) { + defer r.recoverPublicPanic(ctx, req.TraceID, "web", normalizeAction(req.Action, "add"), "ingest", &err) + items := make([]corpus.WebIngestItem, 0, len(req.Items)) for _, item := range req.Items { items = append(items, corpus.WebIngestItem{ @@ -133,7 +140,9 @@ func (r *runtime) IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestR } // RetrieveWeb 统一承接网页语料检索。 -func (r *runtime) RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (*RetrieveResult, error) { +func (r *runtime) RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (result *RetrieveResult, err error) { + defer r.recoverPublicPanic(ctx, req.TraceID, "web", normalizeAction(req.Action, "search"), "retrieve", &err) + return r.retrieveWithCorpus(ctx, req.TraceID, "web", r.webCorpus, core.RetrieveRequest{ Query: req.Query, TopK: normalizeTopK(req.TopK, r.cfg.TopK), @@ -311,6 +320,41 @@ func (r *runtime) observe(ctx context.Context, event ObserveEvent) { r.observer.Observe(ctx, event) } +func (r *runtime) recoverPublicPanic( + ctx context.Context, + traceID string, + corpusName string, + action string, + operation string, + errPtr *error, +) { + recovered := recover() + if recovered == nil || errPtr == nil { + return + } + + // 1. runtime 是 RAG Infra 对业务侧暴露的最终方法面,任何下层 panic 都不应再穿透到业务协程。 + // 2. 这里统一把 panic 转成 error,并补一条结构化观测,方便继续排查是哪一层依赖失控。 + // 3. 保留 stack 是为了在“进程不崩”的前提下仍能定位根因,避免只剩一句 recovered 无法复盘。 + panicErr := fmt.Errorf("rag runtime panic recovered: corpus=%s operation=%s panic=%v", corpusName, operation, recovered) + *errPtr = panicErr + + observeCtx := newObserveContext(ctx, traceID, corpusName, action) + r.observe(observeCtx, ObserveEvent{ + Level: ObserveLevelError, + Component: "runtime", + Operation: operation + "_panic_recovered", + Fields: map[string]any{ + "status": "failed", + "panic": fmt.Sprintf("%v", recovered), + "panic_type": fmt.Sprintf("%T", recovered), + "error": panicErr, + "error_code": core.ClassifyErrorCode(panicErr), + "stack": string(debug.Stack()), + }, + }) +} + func newObserveContext(ctx context.Context, traceID string, corpusName string, action string) context.Context { fields := map[string]any{ "corpus": corpusName, diff --git a/backend/infra/rag/store/inmemory_store.go b/backend/infra/rag/store/inmemory_store.go index 25c5ac8..f15ba28 100644 --- a/backend/infra/rag/store/inmemory_store.go +++ b/backend/infra/rag/store/inmemory_store.go @@ -2,6 +2,7 @@ package store import ( "context" + "errors" "fmt" "math" "sort" @@ -29,12 +30,18 @@ func NewInMemoryVectorStore() *InMemoryVectorStore { } func (s *InMemoryVectorStore) Upsert(_ context.Context, rows []core.VectorRow) error { + if s == nil { + return errors.New("inmemory vector store is nil") + } if len(rows) == 0 { return nil } now := time.Now() s.mu.Lock() defer s.mu.Unlock() + if s.rows == nil { + s.rows = make(map[string]core.VectorRow) + } for _, row := range rows { current, exists := s.rows[row.ID] if exists { @@ -52,6 +59,9 @@ func (s *InMemoryVectorStore) Upsert(_ context.Context, rows []core.VectorRow) e } func (s *InMemoryVectorStore) Search(_ context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) { + if s == nil { + return nil, errors.New("inmemory vector store is nil") + } topK := req.TopK if topK <= 0 { topK = 8 @@ -81,6 +91,9 @@ func (s *InMemoryVectorStore) Search(_ context.Context, req core.VectorSearchReq } func (s *InMemoryVectorStore) Delete(_ context.Context, ids []string) error { + if s == nil { + return errors.New("inmemory vector store is nil") + } if len(ids) == 0 { return nil } @@ -93,6 +106,9 @@ func (s *InMemoryVectorStore) Delete(_ context.Context, ids []string) error { } func (s *InMemoryVectorStore) Get(_ context.Context, ids []string) ([]core.VectorRow, error) { + if s == nil { + return nil, errors.New("inmemory vector store is nil") + } if len(ids) == 0 { return nil, nil } diff --git a/backend/infra/rag/store/milvus_store.go b/backend/infra/rag/store/milvus_store.go index fa914a4..61f9782 100644 --- a/backend/infra/rag/store/milvus_store.go +++ b/backend/infra/rag/store/milvus_store.go @@ -108,6 +108,9 @@ func NewMilvusStore(cfg MilvusConfig) (*MilvusStore, error) { } func (s *MilvusStore) Upsert(ctx context.Context, rows []core.VectorRow) error { + if err := s.ensureReady(); err != nil { + return err + } start := time.Now() if len(rows) == 0 { return nil @@ -171,6 +174,9 @@ func (s *MilvusStore) Upsert(ctx context.Context, rows []core.VectorRow) error { } func (s *MilvusStore) Search(ctx context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) { + if err := s.ensureReady(); err != nil { + return nil, err + } start := time.Now() if len(req.QueryVector) == 0 { return nil, nil @@ -314,6 +320,9 @@ func (s *MilvusStore) Search(ctx context.Context, req core.VectorSearchRequest) } func (s *MilvusStore) Delete(ctx context.Context, ids []string) error { + if err := s.ensureReady(); err != nil { + return err + } start := time.Now() if len(ids) == 0 { return nil @@ -356,6 +365,9 @@ func (s *MilvusStore) Delete(ctx context.Context, ids []string) error { } func (s *MilvusStore) Get(ctx context.Context, ids []string) ([]core.VectorRow, error) { + if err := s.ensureReady(); err != nil { + return nil, err + } start := time.Now() if len(ids) == 0 { return nil, nil @@ -438,6 +450,9 @@ func (s *MilvusStore) Get(ctx context.Context, ids []string) ([]core.VectorRow, } func (s *MilvusStore) ensureCollection(ctx context.Context, dimension int) error { + if err := s.ensureReady(); err != nil { + return err + } start := time.Now() if dimension <= 0 { dimension = s.cfg.Dimension @@ -537,6 +552,9 @@ func (s *MilvusStore) ensureCollection(ctx context.Context, dimension int) error } func (s *MilvusStore) postJSON(ctx context.Context, path string, payload map[string]any) ([]byte, error) { + if err := s.ensureReady(); err != nil { + return nil, err + } body, err := json.Marshal(payload) if err != nil { return nil, err @@ -577,6 +595,19 @@ func (s *MilvusStore) postJSON(ctx context.Context, path string, payload map[str return respBody, nil } +func (s *MilvusStore) ensureReady() error { + if s == nil || s.client == nil { + return errors.New("milvus store is not initialized") + } + if strings.TrimSpace(s.cfg.Address) == "" { + return errors.New("milvus address is empty") + } + if strings.TrimSpace(s.cfg.CollectionName) == "" { + return errors.New("milvus collection name is empty") + } + return nil +} + func (s *MilvusStore) observe(ctx context.Context, event core.ObserveEvent) { if s == nil || s.observer == nil { return