Version: 0.9.14.dev.260410

后端:
  1. LLM 客户端从 newAgent/llm 提升为 infra/llm 基础设施层
     - 删除 backend/newAgent/llm/(ark.go / ark_adapter.go / client.go / json.go)
     - 等价迁移至 backend/infra/llm/,所有 newAgent node 与 service 统一改引用 infrallm
     - 消除 newAgent 对模型客户端的私有依赖,为 memory / websearch 等多模块复用铺路
  2. RAG 基础设施完成可运行态接入(factory / runtime / observer / service 四层成型)
     - 新建 backend/infra/rag/factory.go / runtime.go / observe.go / observer.go /
  service.go:工厂创建、运行时生命周期、轻量观测接口、检索服务门面
     - 更新 infra/rag/config/config.go:补齐 Milvus / Embed / Reranker 全部配置项与默认值
     - 更新 infra/rag/embed/eino_embedder.go:增强 Eino embedding 适配,支持 BaseURL / APIKey 环境变量 / 超时 /
  维度等参数
     - 更新 infra/rag/store/milvus_store.go:完整实现 Milvus 向量存储(建集合 / 建 Index / Upsert / Search /
  Delete),支持 COSINE / L2 / IP 度量
     - 更新 infra/rag/core/pipeline.go:适配 Runtime 接口,Pipeline 由 factory 注入而非手动拼装
     - 更新 infra/rag/corpus/memory_corpus.go / vector_store.go:对接 Memory 模块数据源与 Store 接口扩展
  3. Memory 模块从 Day1 骨架升级为 Day2 完整可运行态
     - 新建 memory/module.go:统一门面 Module,对外封装 EnqueueExtract / ReadService / ManageService / WithTx /
  StartWorker,启动层只依赖这一个入口
     - 新建 memory/orchestrator/llm_write_orchestrator.go:LLM 驱动的记忆抽取编排器,替代原 mock 抽取
     - 新建 memory/service/read_service.go:按用户开关过滤 + 轻量重排 + 访问时间刷新的读取链路
     - 新建 memory/service/manage_service.go:记忆管理面能力(列出 / 软删除 / 开关读写),删除同步写审计日志
     - 新建 memory/service/common.go:服务层公共工具
     - 新建 memory/worker/loop.go:后台轮询循环 RunPollingLoop,定时抢占 pending 任务并推进
     - 新建 memory/utils/audit.go / settings.go:审计日志构造、用户设置过滤等纯函数
     - 更新 memory/model/item.go / job.go / settings.go / config.go / status.go:补齐 DTO 字段与状态常量
     - 更新 memory/repo/item_repo.go / job_repo.go / audit_repo.go / settings_repo.go:补齐 CRUD 与查询能力
     - 更新 memory/worker/runner.go:Runner 对接 Module 与 LLM 抽取器,任务状态机完整化
     - 更新 memory/README.md:同步模块现状说明
  4. newAgent 接入 Memory 读取注入与工具注册依赖预埋
     - 新建 service/agentsvc/agent_memory.go:定义 MemoryReader 接口 + injectMemoryContext,在 graph
  执行前统一补充记忆上下文
     - 更新 service/agentsvc/agent.go:新增 memoryReader 字段与 SetMemoryReader 方法
     - 更新 service/agentsvc/agent_newagent.go:调用 injectMemoryContext 注入 pinned block,检索失败仅降级不阻断主链路
     - 更新 newAgent/tools/registry.go:新增 DefaultRegistryDeps(含 RAGRuntime),工具注册表支持依赖注入
  5. 启动流程与事件处理器接线更新
     - 更新 cmd/start.go:初始化 RAG Runtime → Memory Module → 注册事件处理器 → 启动 Worker 后台轮询
     - 更新 service/events/memory_extract_requested.go:改用 memory.Module.WithTx(tx) 统一门面,事件处理器不再直接依赖
  repo/service 内部包
  6. 缓存插件与配置同步
     - 更新 middleware/cache_deleter.go:静默忽略 MemoryJob / MemoryItem / MemoryAuditLog / MemoryUserSetting
  等新模型,避免日志刷屏;清理冗余注释
     - 更新 config.example.yaml:补齐 rag / memory / websearch 配置段及默认值
     - 更新 go.mod / go.sum:新增 eino-ext/openai / json-patch / go-openai 依赖
  前端:无 仓库:无
This commit is contained in:
Losita
2026-04-10 23:17:38 +08:00
parent fae162162a
commit bf1f1defa5
53 changed files with 5875 additions and 231 deletions

87
backend/infra/llm/ark.go Normal file
View File

@@ -0,0 +1,87 @@
// 过渡期统一 Ark 调用封装。
//
// 这里保留 CallArkText / CallArkJSON方便暂时还直接持有 *ark.ChatModel 的调用点
// 逐步迁移到统一 Client。后续 memory 也可以直接复用这套中立层。
package llm
import (
"context"
"errors"
"strings"
"github.com/cloudwego/eino-ext/components/model/ark"
einoModel "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)
// ArkCallOptions 是基于 ark.ChatModel 的通用调用选项。
//
// 设计目的:
// 1. 先把 Ark 调用样板抽成公共层;
// 2. 再由 WrapArkClient 提供统一 Client
// 3. 让上层尽量只关注业务 prompt 和结构化结果。
type ArkCallOptions struct {
Temperature float64
MaxTokens int
Thinking ThinkingMode
}
// CallArkText 调用 ark 模型并返回纯文本。
//
// 职责边界:
// 1. 负责拼 system + user 两段消息;
// 2. 负责统一配置 thinking / temperature / maxTokens
// 3. 负责拦截空响应;
// 4. 不负责 JSON 解析,不负责业务字段校验。
func CallArkText(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, options ArkCallOptions) (string, error) {
if chatModel == nil {
return "", errors.New("ark model is nil")
}
messages := []*schema.Message{
schema.SystemMessage(systemPrompt),
schema.UserMessage(userPrompt),
}
resp, err := chatModel.Generate(ctx, messages, buildArkOptions(options)...)
if err != nil {
return "", err
}
if resp == nil {
return "", errors.New("模型返回为空")
}
text := strings.TrimSpace(resp.Content)
if text == "" {
return "", errors.New("模型返回内容为空")
}
return text, nil
}
// CallArkJSON 调用 ark 模型并直接解析 JSON。
func CallArkJSON[T any](ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, options ArkCallOptions) (*T, string, error) {
raw, err := CallArkText(ctx, chatModel, systemPrompt, userPrompt, options)
if err != nil {
return nil, "", err
}
parsed, err := ParseJSONObject[T](raw)
if err != nil {
return nil, raw, err
}
return parsed, raw, nil
}
func buildArkOptions(options ArkCallOptions) []einoModel.Option {
thinkingType := arkModel.ThinkingTypeDisabled
if options.Thinking == ThinkingModeEnabled {
thinkingType = arkModel.ThinkingTypeEnabled
}
opts := []einoModel.Option{
ark.WithThinking(&arkModel.Thinking{Type: thinkingType}),
einoModel.WithTemperature(float32(options.Temperature)),
}
if options.MaxTokens > 0 {
opts = append(opts, einoModel.WithMaxTokens(options.MaxTokens))
}
return opts
}

View File

@@ -0,0 +1,111 @@
package llm
import (
"context"
"errors"
"io"
"github.com/cloudwego/eino-ext/components/model/ark"
einoModel "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)
// WrapArkClient 将 ark.ChatModel 适配为统一 Client。
//
// 职责边界:
// 1. generateText调用 ark.ChatModel.Generate非流式供 GenerateJSON 使用;
// 2. streamText调用 ark.ChatModel.Stream流式供需要流式输出的场景使用
// 3. 两者共用同一套 options 转换。
func WrapArkClient(arkChatModel *ark.ChatModel) *Client {
if arkChatModel == nil {
return nil
}
// 非流式文本生成,供 GenerateJSON / GenerateText 调用路径使用。
generateFunc := func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error) {
arkOpts := buildArkStreamOptions(options)
msg, err := arkChatModel.Generate(ctx, messages, arkOpts...)
if err != nil {
return nil, err
}
if msg == nil {
return nil, errors.New("ark model returned nil message")
}
return &TextResult{Text: msg.Content}, nil
}
// 流式文本生成。
streamFunc := func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error) {
arkOpts := buildArkStreamOptions(options)
reader, err := arkChatModel.Stream(ctx, messages, arkOpts...)
if err != nil {
return nil, err
}
return &arkStreamReaderAdapter{reader: reader}, nil
}
return NewClient(generateFunc, streamFunc)
}
// buildArkStreamOptions 将统一 GenerateOptions 转换为 ark 的流式调用选项。
func buildArkStreamOptions(options GenerateOptions) []einoModel.Option {
thinkingEnabled := options.Thinking == ThinkingModeEnabled
// Thinking
thinkingType := arkModel.ThinkingTypeDisabled
if thinkingEnabled {
thinkingType = arkModel.ThinkingTypeEnabled
}
opts := []einoModel.Option{
ark.WithThinking(&arkModel.Thinking{Type: thinkingType}),
}
// Temperaturethinking 模型强制要求 temperature=1否则 API 静默忽略 thinking。
if thinkingEnabled {
opts = append(opts, einoModel.WithTemperature(1.0))
} else if options.Temperature > 0 {
opts = append(opts, einoModel.WithTemperature(float32(options.Temperature)))
}
// MaxTokensthinking 模式下 thinking token 占用 max_tokens 预算,
// 调用方设定的值仅代表"期望输出长度",实际预算需留出思考空间。
// 最低保障 16000避免思考链被截断导致输出为空或非 JSON。
maxTokens := options.MaxTokens
if thinkingEnabled {
const minThinkingBudget = 16000
if maxTokens < minThinkingBudget {
maxTokens = minThinkingBudget
}
}
if maxTokens > 0 {
opts = append(opts, einoModel.WithMaxTokens(maxTokens))
}
return opts
}
// arkStreamReaderAdapter 适配 ark.ChatModel.Stream 返回的 reader。
// ark.Stream 返回 schema.StreamReader[*schema.Message],其 Close() 方法无返回值
// 而我们的 StreamReader 接口要求 Close() error
type arkStreamReaderAdapter struct {
reader *schema.StreamReader[*schema.Message]
}
// Recv 转发到 ark reader 的 Recv 方法。
func (r *arkStreamReaderAdapter) Recv() (*schema.Message, error) {
if r == nil || r.reader == nil {
return nil, io.EOF
}
return r.reader.Recv()
}
// Close 转发到 ark reader 的 Close 方法。
// ark 的 Close() 无返回值,我们适配为返回 nil
func (r *arkStreamReaderAdapter) Close() error {
if r == nil || r.reader == nil {
return nil
}
r.reader.Close()
return nil
}

215
backend/infra/llm/client.go Normal file
View File

@@ -0,0 +1,215 @@
package llm
import (
"context"
"errors"
"fmt"
"strings"
"github.com/cloudwego/eino/schema"
)
// ThinkingMode 描述本次模型调用对 thinking 的期望。
//
// 职责边界:
// 1. 这里只表达“调用方希望怎样配置推理模式”;
// 2. 不直接绑定某个具体模型厂商的参数枚举;
// 3. 真正如何把它翻译成 ark / OpenAI / 其他 provider 的 option由后续适配层负责。
type ThinkingMode string
const (
ThinkingModeDefault ThinkingMode = "default"
ThinkingModeEnabled ThinkingMode = "enabled"
ThinkingModeDisabled ThinkingMode = "disabled"
)
// GenerateOptions 是统一模型调用选项。
//
// 设计目的:
// 1. 先把“每个 skill / worker 都会反复传的参数”收敛成一份结构;
// 2. 让上层以后只表达“我要什么”,不再自己重复组织 option
// 3. 暂时不追求覆盖所有 provider 参数,先把最常用的几个公共位抽出来。
type GenerateOptions struct {
Temperature float64
MaxTokens int
Thinking ThinkingMode
Metadata map[string]any
}
// TextResult 是统一文本生成结果。
//
// 职责边界:
// 1. Text 保存模型最终返回的纯文本;
// 2. Usage 保存本次调用的 token 使用量,供后续统一统计;
// 3. 不负责 JSON 解析,不负责业务字段映射。
type TextResult struct {
Text string
Usage *schema.TokenUsage
}
// StreamReader 抽象了“可逐块 Recv 的流式返回器”。
//
// 之所以不直接依赖某个具体 SDK 的 reader 类型,是因为现在还处在骨架收敛阶段,
// 后续接 ark、OpenAI 兼容层还是别的 provider都可以往这个最小接口上适配。
type StreamReader interface {
Recv() (*schema.Message, error)
Close() error
}
// TextGenerateFunc 是文本生成的统一适配函数签名。
type TextGenerateFunc func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error)
// StreamGenerateFunc 是流式生成的统一适配函数签名。
type StreamGenerateFunc func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error)
// Client 是统一模型客户端门面。
//
// 职责边界:
// 1. 负责把调用方的“模型调用意图”收敛到统一入口;
// 2. 负责统一参数校验、空响应防御、GenerateJSON 复用;
// 3. 不负责写 prompt不负责业务 fallback也不直接持有具体厂商 SDK 细节。
type Client struct {
generateText TextGenerateFunc
streamText StreamGenerateFunc
}
// NewClient 创建统一模型客户端。
func NewClient(generateText TextGenerateFunc, streamText StreamGenerateFunc) *Client {
return &Client{
generateText: generateText,
streamText: streamText,
}
}
// GenerateText 执行一次统一文本生成。
//
// 职责边界:
// 1. 负责做最小必要的入参校验;
// 2. 负责统一拦截“模型空响应”这类公共问题;
// 3. 不负责业务 prompt 拼接,也不负责把文本再映射成业务结构。
func (c *Client) GenerateText(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error) {
if c == nil || c.generateText == nil {
return nil, errors.New("llm client is not ready")
}
if len(messages) == 0 {
return nil, errors.New("llm messages is empty")
}
result, err := c.generateText(ctx, messages, options)
if err != nil {
return nil, err
}
if result == nil {
return nil, errors.New("llm result is nil")
}
if strings.TrimSpace(result.Text) == "" {
return nil, errors.New("llm returned empty text")
}
return result, nil
}
// GenerateJSON 先走统一文本生成,再走统一 JSON 解析。
//
// 设计说明:
// 1. 把“Generate -> 提取 JSON -> 反序列化”这段公共链路收敛起来;
// 2. 上层只关心业务结构,不需要重复实现解析样板;
// 3. 返回 parsed + rawResult方便打点与回退时保留原文。
func GenerateJSON[T any](ctx context.Context, client *Client, messages []*schema.Message, options GenerateOptions) (*T, *TextResult, error) {
result, err := client.GenerateText(ctx, messages, options)
if err != nil {
return nil, nil, err
}
parsed, err := ParseJSONObject[T](result.Text)
if err != nil {
return nil, result, err
}
return parsed, result, nil
}
// Stream 打开统一流式调用入口。
//
// 职责边界:
// 1. 只负责把“流式生成能力”暴露给上层;
// 2. 不负责 chunk 到 OpenAI 协议的转换,那部分应放在 stream/
// 3. 不负责累计全文,也不负责 token 统计落库。
func (c *Client) Stream(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error) {
if c == nil || c.streamText == nil {
return nil, errors.New("llm stream client is not ready")
}
if len(messages) == 0 {
return nil, errors.New("llm messages is empty")
}
return c.streamText(ctx, messages, options)
}
// BuildSystemUserMessages 构造最常见的“system + history + user”消息列表。
//
// 设计说明:
// 1. 先把最稳定的消息编排方式沉淀下来,减少各业务域样板代码;
// 2. 只做消息切片装配,不做 prompt 生成;
// 3. 供 agent / memory 等多个能力域复用。
func BuildSystemUserMessages(systemPrompt string, history []*schema.Message, userPrompt string) []*schema.Message {
messages := make([]*schema.Message, 0, len(history)+2)
if strings.TrimSpace(systemPrompt) != "" {
messages = append(messages, schema.SystemMessage(systemPrompt))
}
if len(history) > 0 {
messages = append(messages, history...)
}
if strings.TrimSpace(userPrompt) != "" {
messages = append(messages, schema.UserMessage(userPrompt))
}
return messages
}
// CloneUsage 深拷贝 token usage避免后续多处累加时共享同一指针。
func CloneUsage(usage *schema.TokenUsage) *schema.TokenUsage {
if usage == nil {
return nil
}
copied := *usage
return &copied
}
// MergeUsage 合并两段 usage。
//
// 合并策略:
// 1. 对“同一次调用不同流分片”的场景,取更大值作为最终值;
// 2. 对“多次独立调用累计”的场景,应由上层显式做加法,而不是用这个函数;
// 3. 该函数只适用于“同一次调用的分块 usage 收敛”。
func MergeUsage(base *schema.TokenUsage, incoming *schema.TokenUsage) *schema.TokenUsage {
if incoming == nil {
return CloneUsage(base)
}
if base == nil {
return CloneUsage(incoming)
}
merged := *base
if incoming.PromptTokens > merged.PromptTokens {
merged.PromptTokens = incoming.PromptTokens
}
if incoming.CompletionTokens > merged.CompletionTokens {
merged.CompletionTokens = incoming.CompletionTokens
}
if incoming.TotalTokens > merged.TotalTokens {
merged.TotalTokens = incoming.TotalTokens
}
if incoming.PromptTokenDetails.CachedTokens > merged.PromptTokenDetails.CachedTokens {
merged.PromptTokenDetails.CachedTokens = incoming.PromptTokenDetails.CachedTokens
}
if incoming.CompletionTokensDetails.ReasoningTokens > merged.CompletionTokensDetails.ReasoningTokens {
merged.CompletionTokensDetails.ReasoningTokens = incoming.CompletionTokensDetails.ReasoningTokens
}
return &merged
}
// FormatEmptyResponseError 统一生成“模型返回空结果”的错误文案。
func FormatEmptyResponseError(scene string) error {
scene = strings.TrimSpace(scene)
if scene == "" {
scene = "unknown"
}
return fmt.Errorf("模型在 %s 场景返回空结果", scene)
}

112
backend/infra/llm/json.go Normal file
View File

@@ -0,0 +1,112 @@
package llm
import (
"encoding/json"
"errors"
"fmt"
"strings"
)
// ParseJSONObject 解析模型返回中的 JSON 对象。
//
// 职责边界:
// 1. 负责处理“模型输出前后夹杂解释文字 / markdown 代码块”的常见情况;
// 2. 负责提取最外层 JSON object 并反序列化为目标结构;
// 3. 不负责业务字段合法性校验,应由上层调用方自行校验。
func ParseJSONObject[T any](raw string) (*T, error) {
clean := strings.TrimSpace(raw)
if clean == "" {
return nil, errors.New("模型返回为空,无法解析 JSON")
}
objectText := ExtractJSONObject(clean)
if objectText == "" {
return nil, fmt.Errorf("模型返回中未找到 JSON 对象: %s", truncateForError(clean))
}
var out T
if err := json.Unmarshal([]byte(objectText), &out); err != nil {
return nil, fmt.Errorf("JSON 解析失败: %w", err)
}
return &out, nil
}
// ExtractJSONObject 从混合文本里提取第一个完整 JSON 对象。
//
// 设计说明:
// 1. LLM 很容易输出“这里是结果:{...}”这种半结构化文本;
// 2. 这里用括号计数而不是正则,避免嵌套对象一多就误截断;
// 3. 目前只提取 object不提取 array因为当前契约基本都是对象。
func ExtractJSONObject(text string) string {
clean := trimMarkdownCodeFence(strings.TrimSpace(text))
if clean == "" {
return ""
}
start := strings.Index(clean, "{")
if start < 0 {
return ""
}
depth := 0
inString := false
escaped := false
for idx := start; idx < len(clean); idx++ {
ch := clean[idx]
if escaped {
escaped = false
continue
}
if ch == '\\' && inString {
escaped = true
continue
}
if ch == '"' {
inString = !inString
continue
}
if inString {
continue
}
switch ch {
case '{':
depth++
case '}':
depth--
if depth == 0 {
return clean[start : idx+1]
}
}
}
return ""
}
func trimMarkdownCodeFence(text string) string {
trimmed := strings.TrimSpace(text)
if !strings.HasPrefix(trimmed, "```") {
return trimmed
}
lines := strings.Split(trimmed, "\n")
if len(lines) == 0 {
return trimmed
}
// 1. 去掉首行 ```json / ```
// 2. 若末行是 ```,一并去掉;
// 3. 中间正文保持原样,避免破坏 JSON 的换行结构。
body := lines[1:]
if len(body) > 0 && strings.TrimSpace(body[len(body)-1]) == "```" {
body = body[:len(body)-1]
}
return strings.TrimSpace(strings.Join(body, "\n"))
}
func truncateForError(text string) string {
if len(text) <= 160 {
return text
}
return text[:160] + "..."
}

View File

@@ -0,0 +1,640 @@
# HANDOFFRAG Infra 一步到位接入方案
## 1. 文档目的
本文用于把 `backend/infra/rag` 从“可运行骨架”推进到“可被业务正式接入的共享基础设施”。
本文重点回答 4 个问题:
1. 当前 `RAG Infra` 已经做到了什么,还缺什么。
2. 什么样的状态,才算“合格、可接入、可灰度、可回滚”的 `RAG Infra`
3. 如何以“依赖注入 + 对外只暴露方法入口”的方式收口,避免业务侧直接依赖底层实现细节。
4. 如何在不打断现有业务的前提下,把 `memory``websearch` 并行迁移到统一 `RAG Infra`
---
## 2. 当前现状
## 2.1 已完成部分
当前 `backend/infra/rag` 已经具备共享骨架,主要包括:
1. 通用接口与类型:
- `core/interfaces.go`
- `core/types.go`
- `core/errors.go`
2. 通用编排器:
- `core/pipeline.go`
3. 默认切块器:
- `chunk/text_chunker.go`
4. 语料适配器:
- `corpus/memory_corpus.go`
- `corpus/web_corpus.go`
5. 默认可运行实现:
- `embed/mock_embedder.go`
- `rerank/noop_reranker.go`
- `store/inmemory_store.go`
6. 配置骨架:
- `config/config.go`
这说明项目已经完成了“共享 RAG Core 的第一阶段搭骨架”,不再是单纯的设计想法。
## 2.2 当前存在的问题
虽然骨架已经有了,但距离“可正式接入的 Infra”还差关键几步
1. 运行时没有正式装配入口。
- 当前仍主要依赖 `rag.NewDefaultPipeline()`
- 启动阶段没有统一按配置组装 `embedder / store / reranker / corpus runtime`
2. 真实底层实现还是占位。
- `embed/eino_embedder.go` 未实现。
- `rerank/eino_reranker.go` 未实现。
- `store/milvus_store.go` 未实现。
3. 配置虽有结构,但还未真正接入运行链路。
- `rag/config/config.go` 定义了 `rag.*` 配置。
- `backend/cmd/start.go` 尚未实例化并注入 `RAG Runtime`
4. 业务尚未真正切流。
- `memory` 读取链路还没有正式走 `Pipeline.Retrieve`
- `websearch` 还没有通过 `WebCorpus + Pipeline` 形成正式 WebRAG 路径。
5. 工程化能力不完整。
- 缺统一 timeout。
- 缺统一日志字段。
- 缺基础指标。
- 缺单元测试与集成测试。
6. 还存在潜在重复实现风险。
- `retrieve/vector_retriever.go``core/pipeline.go` 都承载部分检索逻辑。
- 若后续两套逻辑并存,容易出现行为漂移与维护成本上升。
## 2.3 当前状态结论
当前 `RAG Infra` 的状态,更准确地说是:
1. 已经完成“共享骨架搭建”。
2. 还没有完成“统一装配、真实实现、正式接入、工程化收口”。
3. 目前适合继续扩展,但还不适合直接作为长期稳定的业务依赖面。
---
## 3. 目标定义:什么叫“合格的 RAG Infra”
本轮改造完成后,`backend/infra/rag` 应满足以下标准:
1. 启动时可统一构造并注入,不再靠业务模块自行拼装底层依赖。
2. 对外只暴露稳定方法入口,不暴露底层 `Pipeline / Store / Embedder / Reranker` 的装配细节。
3. 支持按配置切换实现:
- `inmemory / milvus`
- `mock / eino`
- `noop / eino`
4. 支持 `memory``websearch` 两类语料复用同一套 `chunk / embed / retrieve / rerank / fallback` 流程。
5. 支持灰度开关与回滚,不要求业务“一次性硬切流”。
6. 支持基础观测:
- 延迟
- 命中数
- fallback 原因
- 错误码
7. 具备最小可依赖测试集,保证公共层改动不会悄悄破坏业务。
---
## 4. 核心改造原则
## 4.1 原则一:依赖注入统一由 Infra 自己负责
`RAG Infra` 必须自己承接“底层实现装配”,业务侧不应感知:
1. 当前用的是 `Milvus` 还是 `InMemoryStore`
2. 当前用的是 `MockEmbedder` 还是 `EinoEmbedder`
3. 当前是否开启 `Reranker`
4. 当前超时、阈值、切块参数是多少。
业务只拿到一个已经注入好的 `RAG Runtime``RAG Service`,直接调用方法。
## 4.2 原则二:对外只暴露方法,不暴露底层零件
业务层不应直接依赖这些细粒度对象:
1. `core.Pipeline`
2. `core.VectorStore`
3. `core.Embedder`
4. `core.Reranker`
5. `corpus.MemoryCorpus`
6. `corpus.WebCorpus`
这些对象应被视为 `infra/rag` 内部拼装细节。
业务层只应调用诸如以下方法:
1. `IngestMemory`
2. `RetrieveMemory`
3. `IngestWeb`
4. `RetrieveWeb`
这样做的好处是:
1. 业务依赖面更稳定。
2. 后续替换底层实现时,不会把改动扩散到多个业务模块。
3. 便于统一日志、监控、降级和权限边界。
## 4.3 原则三:业务语义留在业务层,通用 RAG 工序下沉到 Infra
下沉到 `infra/rag` 的内容:
1. 切块
2. 向量化
3. 向量存储
4. 召回
5. rerank
6. threshold 过滤
7. fallback 语义
8. 统一日志与指标
留在业务层的内容:
1. `memory` 的注入优先级、门控规则、显式/隐式策略
2. `websearch` 的 provider 搜索、query 改写、时间过滤、domain 白名单、抓取策略
3. 最终给模型注入哪些证据、注入多少、如何组织引用
## 4.4 原则四:并行迁移,不一步删旧
本轮改造虽然目标是“一步到位把 Infra 做完整”,但切流必须保持并行迁移:
1. 新 Infra 建好后,先让 `memory` 接入并保留旧逻辑兜底。
2. 再让 `websearch` 接入并保留 V1 路径兜底。
3. 观察稳定后再删除旧分支。
---
## 5. 目标架构
## 5.1 推荐对外结构
建议在 `backend/infra/rag` 新增统一对外门面,例如:
1. `runtime.go`
2. `factory.go`
3. `service.go`
推荐把正式对外依赖面收敛为一个接口,例如:
```go
type Runtime interface {
IngestMemory(ctx context.Context, input MemoryIngestRequest) (*IngestResult, error)
RetrieveMemory(ctx context.Context, input MemoryRetrieveRequest) (*RetrieveResult, error)
IngestWeb(ctx context.Context, input WebIngestRequest) (*IngestResult, error)
RetrieveWeb(ctx context.Context, input WebRetrieveRequest) (*RetrieveResult, error)
}
```
说明:
1. 业务侧只依赖 `Runtime`
2. `Runtime` 内部再去调用 `Pipeline + CorpusAdapter + Store + Embedder + Reranker`
3. 这样可以保证业务不会直接 import `core` 包下的底层细节。
## 5.2 推荐内部结构
建议内部形成以下分工:
1. `factory.go`
- 负责按配置创建 `Embedder / Store / Reranker / Pipeline`
2. `runtime.go`
- 负责持有 `Pipeline + MemoryCorpus + WebCorpus + Logger + Metrics`
3. `service.go`
- 负责定义 `Runtime` 接口与对外方法
4. `core/`
- 保持底层通用编排逻辑
5. `corpus/`
- 只负责“语料 -> 标准文档”和“业务过滤 -> 标准 filter”
## 5.3 推荐依赖注入方式
`backend/cmd/start.go` 中,启动期统一创建 `RAG Runtime`,例如:
1. 读取 `rag.*` 配置
2. 构造 `RAGFactory`
3. 生成 `RAGRuntime`
4. 注入给:
- `memory service`
- `newAgent web tools`
业务侧只拿运行好的对象,不再自己 new 任何底层实现。
---
## 6. 对外方法面设计
## 6.1 Memory 对外方法
推荐对外暴露以下方法:
1. `IngestMemory`
- 输入:标准化后的记忆入库请求
- 输出文档数、chunk 数、同步结果
2. `RetrieveMemory`
- 输入用户、会话、助手、run、query、topK、threshold
- 输出:标准 `RetrieveResult`
注意:
1. `memory` 业务层不应直接调用 `MemoryCorpus`
2. `memory` 业务层不应自己拼向量过滤条件。
3. 所有过滤条件由 `RetrieveMemory` 内部统一转换。
## 6.2 Web 对外方法
推荐对外暴露以下方法:
1. `IngestWeb`
- 输入:抓取结果 `url/title/snippet/content/domain/query_id/session_id`
- 输出:统一入库摘要
2. `RetrieveWeb`
- 输入query、query_id/session_id、domain、topK、threshold
- 输出:标准 `RetrieveResult`
注意:
1. `websearch` 业务层不应直接持有 `WebCorpus`
2. `websearch` 业务层只负责“拿到页面内容”与“决定是否需要调用 RAG”。
3. 实际向量入库、检索、rerank 由 `infra/rag` 统一处理。
## 6.3 对外方法设计边界
方法层负责什么:
1. 参数合法性校验
2. 内部 filter 组装
3.`Pipeline.Ingest / Retrieve`
4. 统一日志、指标、fallback
方法层不负责什么:
1. 不负责 `websearch provider` 搜索
2. 不负责 HTML 抓取
3. 不负责 prompt 注入
4. 不负责业务排序偏好
---
## 7. 具体改造计划
## 7.1 第一部分:把 RAG Infra 自身做完整
### 目标
`backend/infra/rag` 成为“正式可注入、正式可切换、正式可依赖”的共享基础设施。
### 实施项
1. 新增正式运行时与工厂:
- `backend/infra/rag/runtime.go`
- `backend/infra/rag/factory.go`
- 如有需要,新增 `backend/infra/rag/service.go`
2. 扩展配置:
- `rag.enabled`
- `rag.store`
- `rag.embed.provider`
- `rag.embed.model`
- `rag.embed.timeoutMs`
- `rag.embed.dimension`
- `rag.reranker.provider`
- `rag.reranker.timeoutMs`
- `rag.retrieve.timeoutMs`
- `rag.ingest.chunkSize`
- `rag.ingest.chunkOverlap`
3. 收口运行入口:
- `rag.NewDefaultPipeline()` 保留为本地 fallback
- 正式业务接入走 `NewRuntimeFromConfig(...)`
4. 消除重复检索路径:
- 明确 `Pipeline` 是官方检索入口
- `retrieve/vector_retriever.go` 要么内聚为内部实现,要么后续删除,避免双轨
### 验收
1. 启动期可按配置成功构造 `RAG Runtime`
2. 业务侧不需要自己组装 `Pipeline / Store / Embedder / Reranker`
3. 对外暴露面稳定,底层实现可替换。
## 7.2 第二部分:补齐真实底层实现
### 目标
`RAG Infra` 具备真实可用的向量能力,而不是停留在 mock。
### 实施项
1. 实现 `embed/eino_embedder.go`
- 负责 embedding 调用
- 负责 embedding timeout
- 负责错误包装与统一日志
2. 实现 `rerank/eino_reranker.go`
- 负责 rerank 调用
- 负责 rerank timeout
- 负责失败降级到原排序
3. 实现 `store/milvus_store.go`
- `Upsert`
- `Search`
- `Delete`
- `Get`
4. Milvus 元数据设计建议:
- 高频过滤字段应做显式标量字段,不建议全部依赖大 JSON 过滤
- 重点字段包括:
- `corpus`
- `user_id`
- `assistant_id`
- `conversation_id`
- `run_id`
- `memory_type`
- `query_id`
- `session_id`
- `domain`
### 验收
1. `MilvusStore` 在已准备好的 Docker 环境中可稳定完成写入与检索。
2. `EinoEmbedder``EinoReranker` 可按配置启用。
3. provider 波动时,主链路仍能 fallback。
## 7.3 第三部分:补齐工程化能力
### 目标
`RAG Infra` 具备“可观测、可测试、可回滚”的基础设施属性。
### 实施项
1. timeout 接线:
- embedding timeout
- retrieve timeout
- rerank timeout
2. 统一日志字段:
- `trace_id`
- `corpus`
- `action`
- `provider`
- `latency_ms`
- `hit_count`
- `fallback_reason`
3. 指标补齐:
- `rag_ingest_count`
- `rag_retrieve_count`
- `rag_hit_count`
- `rag_fallback_rate`
- `rag_latency_ms`
4. 测试补齐:
- `chunker` 单测
- `corpus filter` 单测
- `pipeline fallback` 单测
- `MilvusStore` 集成测试
- `memory/web` 过滤隔离测试
### 验收
1. 出现检索问题时,可从日志定位是:
- 没命中
- 超时
- rerank 降级
- filter 过滤过严
2. 公共层测试可稳定覆盖关键路径。
## 7.4 第四部分:接入 Memory
### 目标
`memory` 成为第一个正式接入 `RAG Infra` 的业务域。
### 实施项
1. 写入链路接入:
- 在 memory worker 成功写入 `memory_items` 后,调用 `RAGRuntime.IngestMemory`
- 复用 `memory_items.vector_status/vector_id`
2. 读取链路接入:
-`memory/service/read_service.go` 中新增 `RetrieveMemory` 路径
- 强制过滤:
- `user_id`
- `assistant_id`
- `conversation_id`
- `run_id`
3. 开关控制:
- `memory.rag.enabled=false` 默认关闭
- 打开后先灰度使用新路径
4. 降级策略:
- `RAG` 检索失败 -> 回退旧读取链路
- `Reranker` 失败 -> 保留原始排序
### 验收
1. 开关关闭时行为与当前一致。
2. 开关开启时,记忆召回可稳定工作。
3. 失败时不会影响主链路回复。
## 7.5 第五部分:接入 WebSearch
### 目标
`websearch` 成为第二个正式接入 `RAG Infra` 的业务域,并复用 `WebCorpus`
### 实施项
1. 保留 V1 路径:
- `web_search` 做 provider 搜索
- `web_fetch` 做正文抓取与清洗
2. 新增 V2 路径:
- 把抓取结果映射为 `WebIngestItem`
-`RAGRuntime.IngestWeb`
- 再调 `RAGRuntime.RetrieveWeb`
3. 强约束过滤:
- `query_id``session_id` 至少有一个
- 避免跨 query/session 串召回
4. 开关控制:
- `websearch.rag.enabled=false` 默认关闭
5. 降级策略:
- `web_rag_search` 失败 -> 回退到 `web_search + web_fetch`
### 验收
1. 新旧链路并存,互不影响。
2. 新链路不会跨 query/session 串数据。
3. 失败可立刻回退到 V1。
## 7.6 第六部分:启动接线与统一管理
### 目标
`RAG Runtime` 成为启动期统一装配、统一管理的依赖。
### 实施项
1.`backend/cmd/start.go` 中:
- 读取 `rag.*` 配置
- 构造 `RAG Runtime`
- 注入给 `memory``newAgent web tools`
2. 统一由启动期管理依赖生命周期:
- 初始化
- 健康检查
- 关闭清理
3. 业务层禁止直接 new 底层实现:
- 禁止业务自己构建 `MilvusStore`
- 禁止业务自己构建 `EinoEmbedder`
- 禁止业务自己拼 `Pipeline`
### 验收
1. 依赖管理集中在启动层。
2. 业务代码只依赖方法入口,不接触底层实现。
3. 后续替换实现时,无需大面积修改业务层代码。
---
## 8. 推荐目录改造方案
建议新增或调整如下文件:
1. `backend/infra/rag/runtime.go`
2. `backend/infra/rag/factory.go`
3. `backend/infra/rag/service.go`
4. `backend/infra/rag/README.md` 或在本文件持续追加
5. `backend/infra/rag/embed/eino_embedder.go`
6. `backend/infra/rag/rerank/eino_reranker.go`
7. `backend/infra/rag/store/milvus_store.go`
8. `backend/infra/rag/core/pipeline_test.go`
9. `backend/infra/rag/chunk/text_chunker_test.go`
10. `backend/infra/rag/corpus/memory_corpus_test.go`
11. `backend/infra/rag/corpus/web_corpus_test.go`
12. `backend/infra/rag/store/milvus_store_integration_test.go`
配套改动文件:
1. `backend/cmd/start.go`
2. `backend/config.example.yaml`
3. `backend/memory/service/read_service.go`
4. `backend/newAgent/tools/registry.go`
5. `backend/agent/通用能力接入文档.md`
---
## 9. 配置建议
建议新增如下配置结构:
```yaml
rag:
enabled: true
store: "milvus"
topK: 8
threshold: 0.55
retrieve:
timeoutMs: 1500
ingest:
chunkSize: 400
chunkOverlap: 80
embed:
provider: "eino"
model: ""
timeoutMs: 1200
dimension: 1024
reranker:
enabled: true
provider: "eino"
timeoutMs: 1200
memory:
rag:
enabled: false
websearch:
rag:
enabled: false
```
说明:
1. `rag.enabled` 控制公共层是否启用。
2. `memory.rag.enabled``websearch.rag.enabled` 控制业务级切流。
3. 即使 `rag.enabled=true`,也不代表所有业务立刻默认走新链路。
---
## 10. 回滚策略
推荐回滚顺序如下:
1. 先关业务级开关:
- `memory.rag.enabled=false`
- `websearch.rag.enabled=false`
2. 再关重排:
- `rag.reranker.enabled=false`
3. 再切底层实现:
- `rag.store=inmemory`
- `rag.embed.provider=mock`
- `rag.reranker.provider=noop`
4. 若仍异常,再回退到业务旧链路
这样可以做到:
1. 不因单个 provider 波动打断主流程。
2. 保留最小可用能力。
3. 故障定位粒度更细。
---
## 11. 风险与应对
1. 风险Milvus 过滤能力与现有 metadata 结构不匹配。
- 应对:高频过滤字段单独建模,不依赖大 JSON 粗暴过滤。
2. 风险embedding/rerank provider 波动影响延迟。
- 应对:超时控制 + fallback + 业务级开关。
3. 风险:业务层绕过 Infra 直接依赖底层实现。
- 应对:通过 `Runtime` 方法面统一收口,代码评审禁止横向绕过。
4. 风险:新旧检索路径长期并存导致维护成本上升。
- 应对:本轮先保留兜底,稳定后明确删除旧实现。
5. 风险:跨 query/session 串召回。
- 应对:`WebRetrieve` 强制校验 `query_id/session_id` 至少其一存在。
---
## 12. 最小落地顺序
如果按“尽快落成可接入 Infra”的优先级来排本轮建议顺序如下
1. 先做 `runtime/factory/service`,把依赖注入和方法面收口。
2. 再实现 `MilvusStore + EinoEmbedder + EinoReranker`
3. 再补 timeout、日志、指标、测试。
4. 然后优先接 `memory`
5. 最后接 `websearch`
原因:
1. 若先接业务、不先收口方法面,后面会把底层细节泄露到业务层。
2. 若先接 websearch、不先接 memory会导致共享 Infra 价值不够集中,面试叙事也不完整。
---
## 13. 本轮完成后的预期收益
完成本方案后,项目会获得以下收益:
1. `memory``websearch` 共享一套真正可运行的 RAG 基础设施。
2. 业务侧不再重复实现切块、召回、重排与降级逻辑。
3. `infra/rag` 成为正式公共能力,具备统一依赖注入与统一管理能力。
4. 后续新增新语料域时,只需新增 `CorpusAdapter + 方法面`,无需再复制一套 RAG 链路。
5. 项目简历叙事会更完整:
- “抽象并实现共享 RAG Infra”
- “统一 Memory/WebSearch 的检索与重排能力”
- “通过依赖注入与门面方法收口底层复杂度”
---
## 14. 当前建议结论
建议把本轮目标明确为:
1. **不是**“再给 RAG 补几个占位实现”。
2. **而是**“把 `backend/infra/rag` 一次性做成正式可接入的公共基础设施”。
关键落点是两句话:
1. 依赖注入统一由 `infra/rag` 自己负责。
2. 对外只暴露方法入口,业务侧不直接接触底层实现细节。
只要这两点收住,后续 `memory``websearch`、甚至更多语料域都会明显更好管理。

View File

@@ -5,30 +5,63 @@ import "github.com/spf13/viper"
// Config 是 RAG Core 运行配置。
type Config struct {
Enabled bool
Store string
TopK int
Threshold float64
EmbedProvider string
EmbedModel string
EmbedBaseURL string
EmbedAPIKeyEnv string
EmbedTimeoutMS int
EmbedDimension int
RerankerEnabled bool
RerankerProvider string
RerankerTimeoutMS int
ChunkSize int
ChunkOverlap int
RetrieveTimeoutMS int
MilvusAddress string
MilvusToken string
MilvusDBName string
MilvusCollectionName string
MilvusMetricType string
MilvusRequestTimeoutMS int
}
// LoadFromViper 读取 rag 配置并补默认值。
func LoadFromViper() Config {
cfg := Config{
Enabled: viper.GetBool("rag.enabled"),
TopK: viper.GetInt("rag.topK"),
Threshold: viper.GetFloat64("rag.threshold"),
RerankerEnabled: viper.GetBool("rag.reranker.enabled"),
RerankerTimeoutMS: viper.GetInt("rag.reranker.timeoutMs"),
ChunkSize: viper.GetInt("rag.ingest.chunkSize"),
ChunkOverlap: viper.GetInt("rag.ingest.chunkOverlap"),
RetrieveTimeoutMS: viper.GetInt("rag.retrieve.timeoutMs"),
Enabled: viper.GetBool("rag.enabled"),
Store: viper.GetString("rag.store"),
TopK: viper.GetInt("rag.topK"),
Threshold: viper.GetFloat64("rag.threshold"),
EmbedProvider: viper.GetString("rag.embed.provider"),
EmbedModel: viper.GetString("rag.embed.model"),
EmbedBaseURL: viper.GetString("rag.embed.baseURL"),
EmbedAPIKeyEnv: viper.GetString("rag.embed.apiKeyEnv"),
EmbedTimeoutMS: viper.GetInt("rag.embed.timeoutMs"),
EmbedDimension: viper.GetInt("rag.embed.dimension"),
RerankerEnabled: viper.GetBool("rag.reranker.enabled"),
RerankerProvider: viper.GetString("rag.reranker.provider"),
RerankerTimeoutMS: viper.GetInt("rag.reranker.timeoutMs"),
ChunkSize: viper.GetInt("rag.ingest.chunkSize"),
ChunkOverlap: viper.GetInt("rag.ingest.chunkOverlap"),
RetrieveTimeoutMS: viper.GetInt("rag.retrieve.timeoutMs"),
MilvusAddress: viper.GetString("rag.milvus.address"),
MilvusToken: viper.GetString("rag.milvus.token"),
MilvusDBName: viper.GetString("rag.milvus.dbName"),
MilvusCollectionName: viper.GetString("rag.milvus.collectionName"),
MilvusMetricType: viper.GetString("rag.milvus.metricType"),
MilvusRequestTimeoutMS: viper.GetInt("rag.milvus.requestTimeoutMs"),
}
if cfg.Store == "" {
cfg.Store = "inmemory"
}
if cfg.TopK <= 0 {
cfg.TopK = 8
@@ -36,6 +69,24 @@ func LoadFromViper() Config {
if cfg.Threshold < 0 {
cfg.Threshold = 0
}
if cfg.EmbedProvider == "" {
cfg.EmbedProvider = "mock"
}
if cfg.EmbedBaseURL == "" {
cfg.EmbedBaseURL = viper.GetString("agent.baseURL")
}
if cfg.EmbedAPIKeyEnv == "" {
cfg.EmbedAPIKeyEnv = "ARK_API_KEY"
}
if cfg.EmbedTimeoutMS <= 0 {
cfg.EmbedTimeoutMS = 1200
}
if cfg.EmbedDimension <= 0 {
cfg.EmbedDimension = 1024
}
if cfg.RerankerProvider == "" {
cfg.RerankerProvider = "noop"
}
if cfg.RerankerTimeoutMS <= 0 {
cfg.RerankerTimeoutMS = 1200
}
@@ -48,5 +99,20 @@ func LoadFromViper() Config {
if cfg.RetrieveTimeoutMS <= 0 {
cfg.RetrieveTimeoutMS = 1500
}
if cfg.MilvusAddress == "" {
cfg.MilvusAddress = "http://localhost:19530"
}
if cfg.MilvusToken == "" {
cfg.MilvusToken = "root:Milvus"
}
if cfg.MilvusCollectionName == "" {
cfg.MilvusCollectionName = "smartflow_rag_chunks"
}
if cfg.MilvusMetricType == "" {
cfg.MilvusMetricType = "COSINE"
}
if cfg.MilvusRequestTimeoutMS <= 0 {
cfg.MilvusRequestTimeoutMS = 1500
}
return cfg
}

View File

@@ -0,0 +1,190 @@
package core
import (
"context"
"errors"
"fmt"
"log"
"sort"
"strings"
)
// ObserveLevel 表示观测事件等级。
type ObserveLevel string
const (
ObserveLevelInfo ObserveLevel = "info"
ObserveLevelWarn ObserveLevel = "warn"
ObserveLevelError ObserveLevel = "error"
)
// ObserveEvent 描述一次统一观测事件。
//
// 职责边界:
// 1. 只承载 RAG Infra 的结构化运行信息;
// 2. 不绑定具体日志系统、指标系统或 tracing 实现;
// 3. 字段内容应尽量稳定,便于后续统一接入全局观测平台。
type ObserveEvent struct {
Level ObserveLevel
Component string
Operation string
Fields map[string]any
}
// Observer 是 RAG Infra 的最小观测接口。
//
// 职责边界:
// 1. 负责消费结构化事件;
// 2. 不负责决定业务逻辑是否继续执行;
// 3. 任一实现都不应反向影响主链路稳定性。
type Observer interface {
Observe(ctx context.Context, event ObserveEvent)
}
// ObserverFunc 允许用函数快速适配 Observer。
type ObserverFunc func(ctx context.Context, event ObserveEvent)
func (f ObserverFunc) Observe(ctx context.Context, event ObserveEvent) {
if f == nil {
return
}
f(ctx, event)
}
// NewNopObserver 返回空实现,适合在未接入统一观测平台时兜底。
func NewNopObserver() Observer {
return ObserverFunc(func(context.Context, ObserveEvent) {})
}
// NewLoggerObserver 返回标准日志适配器。
//
// 说明:
// 1. 当前项目尚未建立统一日志平台时,先把结构化字段稳定打印出来;
// 2. 后续若项目引入统一 logger/metrics/tracing只需替换该 Observer 注入实现;
// 3. 该适配器默认保持单行输出,减少和现有日志风格的割裂感。
func NewLoggerObserver(logger *log.Logger) Observer {
if logger == nil {
logger = log.Default()
}
return &loggerObserver{logger: logger}
}
type loggerObserver struct {
logger *log.Logger
}
func (o *loggerObserver) Observe(ctx context.Context, event ObserveEvent) {
if o == nil || o.logger == nil {
return
}
level := strings.TrimSpace(string(event.Level))
if level == "" {
level = string(ObserveLevelInfo)
}
component := strings.TrimSpace(event.Component)
if component == "" {
component = "unknown"
}
operation := strings.TrimSpace(event.Operation)
if operation == "" {
operation = "unknown"
}
fields := ObserveFieldsFromContext(ctx)
for key, value := range event.Fields {
key = strings.TrimSpace(key)
if key == "" || !shouldKeepObserveField(value) {
continue
}
fields[key] = value
}
parts := []string{
"rag",
fmt.Sprintf("level=%s", level),
fmt.Sprintf("component=%s", component),
fmt.Sprintf("operation=%s", operation),
}
keys := make([]string, 0, len(fields))
for key := range fields {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
parts = append(parts, fmt.Sprintf("%s=%v", key, fields[key]))
}
o.logger.Print(strings.Join(parts, " "))
}
type observeFieldsContextKey struct{}
// WithObserveFields 把通用观测字段挂入上下文,便于下游组件复用。
//
// 步骤化说明:
// 1. 先读取已有上下文字段,保证 Runtime / Pipeline / Store 能逐层补充信息;
// 2. 后写字段覆盖同名旧值,确保下游拿到的是最新语义;
// 3. 仅保存“有意义”的字段,避免日志长期堆积大量空值。
func WithObserveFields(ctx context.Context, fields map[string]any) context.Context {
if len(fields) == 0 {
return ctx
}
if ctx == nil {
ctx = context.Background()
}
merged := ObserveFieldsFromContext(ctx)
for key, value := range fields {
key = strings.TrimSpace(key)
if key == "" || !shouldKeepObserveField(value) {
continue
}
merged[key] = value
}
if len(merged) == 0 {
return ctx
}
return context.WithValue(ctx, observeFieldsContextKey{}, merged)
}
// ObserveFieldsFromContext 提取上下文中已经累积的观测字段。
func ObserveFieldsFromContext(ctx context.Context) map[string]any {
if ctx == nil {
return map[string]any{}
}
raw, ok := ctx.Value(observeFieldsContextKey{}).(map[string]any)
if !ok || len(raw) == 0 {
return map[string]any{}
}
result := make(map[string]any, len(raw))
for key, value := range raw {
result[key] = value
}
return result
}
// ClassifyErrorCode 统一把常见错误压缩为稳定错误码,便于后续接入全局观测平台。
func ClassifyErrorCode(err error) string {
switch {
case err == nil:
return ""
case errors.Is(err, context.DeadlineExceeded):
return "DEADLINE_EXCEEDED"
case errors.Is(err, context.Canceled):
return "CANCELED"
default:
return "RAG_ERROR"
}
}
func shouldKeepObserveField(value any) bool {
if value == nil {
return false
}
if text, ok := value.(string); ok {
return strings.TrimSpace(text) != ""
}
return true
}

View File

@@ -28,6 +28,7 @@ type Pipeline struct {
store VectorStore
reranker Reranker
logger *log.Logger
observer Observer
}
func NewPipeline(chunker Chunker, embedder Embedder, store VectorStore, reranker Reranker) *Pipeline {
@@ -37,9 +38,26 @@ func NewPipeline(chunker Chunker, embedder Embedder, store VectorStore, reranker
store: store,
reranker: reranker,
logger: log.Default(),
observer: NewNopObserver(),
}
}
// SetLogger 设置 Pipeline 使用的日志器。
func (p *Pipeline) SetLogger(logger *log.Logger) {
if p == nil || logger == nil {
return
}
p.logger = logger
}
// SetObserver 设置 Pipeline 使用的统一观测器。
func (p *Pipeline) SetObserver(observer Observer) {
if p == nil || observer == nil {
return
}
p.observer = observer
}
// Ingest 执行统一入库流程。
//
// 步骤化说明:
@@ -63,6 +81,24 @@ func (p *Pipeline) Ingest(
if err != nil {
return nil, err
}
return p.IngestDocuments(ctx, corpus.Name(), docs, opt)
}
// IngestDocuments 执行“已标准化文档”的统一入库流程。
//
// 职责边界:
// 1. 负责处理已经完成 CorpusAdapter 映射的标准文档;
// 2. 负责统一切块、向量化与 Upsert
// 3. 不负责再做业务输入解析,避免 Runtime 为拿到 document_id 重复 build 文档。
func (p *Pipeline) IngestDocuments(
ctx context.Context,
corpusName string,
docs []SourceDocument,
opt IngestOption,
) (*IngestResult, error) {
if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil {
return nil, ErrNilDependency
}
if len(docs) == 0 {
return &IngestResult{DocumentCount: 0, ChunkCount: 0}, nil
}
@@ -102,7 +138,7 @@ func (p *Pipeline) Ingest(
now := time.Now()
for i, chunk := range chunks {
metadata := cloneMap(chunk.Metadata)
metadata["corpus"] = corpus.Name()
metadata["corpus"] = corpusName
metadata["document_id"] = chunk.DocumentID
metadata["chunk_order"] = chunk.Order
rows = append(rows, VectorRow{
@@ -214,7 +250,23 @@ func (p *Pipeline) Retrieve(
// 2. rerank 异常不终止主流程,统一降级为原排序。
result.FallbackUsed = true
result.FallbackReason = FallbackReasonRerankFailed
p.logger.Printf("rag rerank fallback: reason=%s err=%v", FallbackReasonRerankFailed, rerankErr)
if p.observer != nil {
p.observer.Observe(ctx, ObserveEvent{
Level: ObserveLevelWarn,
Component: "pipeline",
Operation: "rerank_fallback",
Fields: map[string]any{
"status": "fallback",
"fallback_reason": FallbackReasonRerankFailed,
"candidate_count": len(candidates),
"top_k": topK,
"error": rerankErr,
"error_code": ClassifyErrorCode(rerankErr),
},
})
} else if p.logger != nil {
p.logger.Printf("rag rerank fallback: reason=%s err=%v", FallbackReasonRerankFailed, rerankErr)
}
return result, nil
}
result.Items = reranked

View File

@@ -22,7 +22,11 @@ type MemoryIngestItem struct {
MemoryType string
Title string
Content string
Confidence float64
Importance float64
SensitivityLevel int
IsExplicit bool
Status string
TTLAt *time.Time
CreatedAt *time.Time
}
@@ -71,7 +75,12 @@ func (c *MemoryCorpus) BuildIngestDocuments(_ context.Context, input any) ([]cor
"assistant_id": strings.TrimSpace(item.AssistantID),
"run_id": strings.TrimSpace(item.RunID),
"memory_type": strings.TrimSpace(strings.ToLower(item.MemoryType)),
"title": strings.TrimSpace(item.Title),
"confidence": item.Confidence,
"importance": item.Importance,
"sensitivity_level": item.SensitivityLevel,
"is_explicit": item.IsExplicit,
"status": strings.TrimSpace(item.Status),
}
if item.TTLAt != nil {
metadata["ttl_at"] = item.TTLAt.Format(time.RFC3339)

View File

@@ -3,19 +3,97 @@ package embed
import (
"context"
"errors"
"strings"
"time"
openaiembedding "github.com/cloudwego/eino-ext/libs/acl/openai"
einoembedding "github.com/cloudwego/eino/components/embedding"
)
// EinoEmbedder 是 Eino embedding 的占位实现
// EinoConfig 描述 Eino embedding 运行参数
type EinoConfig struct {
APIKey string
BaseURL string
Model string
TimeoutMS int
Dimension int
}
// EinoEmbedder 是基于 Eino 的 embedding 适配器。
//
// 说明:
// 1. 本轮先占位接口,避免过早耦合具体 Provider
// 2. 后续接入真实 embedding 时,只替换此文件内部实现。
type EinoEmbedder struct{}
func NewEinoEmbedder() *EinoEmbedder {
return &EinoEmbedder{}
// 1. 对 infra/rag 暴露统一 []float32 结果,屏蔽 Eino/OpenAI 兼容实现细节
// 2. 超时由该适配器自身收口,避免业务侧每次调用都手写超时控制;
// 3. 当前底层走 Eino Ext 的 OpenAI 兼容 embedding client便于接 Ark/OpenAI 兼容接口。
type EinoEmbedder struct {
client einoembedding.Embedder
model string
timeout time.Duration
dimension int
}
func (e *EinoEmbedder) Embed(_ context.Context, _ []string, _ string) ([][]float32, error) {
return nil, errors.New("eino embedder is not implemented yet")
func NewEinoEmbedder(ctx context.Context, cfg EinoConfig) (*EinoEmbedder, error) {
if strings.TrimSpace(cfg.APIKey) == "" {
return nil, errors.New("eino embedder api key is empty")
}
if strings.TrimSpace(cfg.Model) == "" {
return nil, errors.New("eino embedder model is empty")
}
clientCfg := &openaiembedding.EmbeddingConfig{
APIKey: strings.TrimSpace(cfg.APIKey),
BaseURL: strings.TrimSpace(cfg.BaseURL),
Model: strings.TrimSpace(cfg.Model),
}
if cfg.Dimension > 0 {
clientCfg.Dimensions = &cfg.Dimension
}
client, err := openaiembedding.NewEmbeddingClient(ctx, clientCfg)
if err != nil {
return nil, err
}
timeout := 1200 * time.Millisecond
if cfg.TimeoutMS > 0 {
timeout = time.Duration(cfg.TimeoutMS) * time.Millisecond
}
return &EinoEmbedder{
client: client,
model: strings.TrimSpace(cfg.Model),
timeout: timeout,
dimension: cfg.Dimension,
}, nil
}
func (e *EinoEmbedder) Embed(ctx context.Context, texts []string, _ string) ([][]float32, error) {
if e == nil || e.client == nil {
return nil, errors.New("eino embedder is not initialized")
}
if len(texts) == 0 {
return nil, nil
}
callCtx := ctx
cancel := func() {}
if e.timeout > 0 {
callCtx, cancel = context.WithTimeout(ctx, e.timeout)
}
defer cancel()
vectors, err := e.client.EmbedStrings(callCtx, texts, einoembedding.WithModel(e.model))
if err != nil {
return nil, err
}
result := make([][]float32, 0, len(vectors))
for _, vector := range vectors {
converted := make([]float32, len(vector))
for i, value := range vector {
converted[i] = float32(value)
}
result = append(result, converted)
}
return result, nil
}

View File

@@ -0,0 +1,139 @@
package rag
import (
"context"
"fmt"
"log"
"os"
"strings"
ragchunk "github.com/LoveLosita/smartflow/backend/infra/rag/chunk"
ragconfig "github.com/LoveLosita/smartflow/backend/infra/rag/config"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
ragembed "github.com/LoveLosita/smartflow/backend/infra/rag/embed"
ragrerank "github.com/LoveLosita/smartflow/backend/infra/rag/rerank"
ragstore "github.com/LoveLosita/smartflow/backend/infra/rag/store"
)
// FactoryDeps 描述 Runtime 工厂所需的可选依赖。
//
// 说明:
// 1. Logger 仅作为“当前项目尚无统一日志系统”时的默认落点;
// 2. Observer 是正式的统一观测插槽,后续可替换为项目级 logger/metrics/tracing 适配器;
// 3. 业务侧仍然只拿 Runtime不直接碰底层装配细节。
type FactoryDeps struct {
Logger *log.Logger
Observer Observer
}
// NewRuntimeFromConfig 按配置统一组装 RAG Runtime。
//
// 设计说明:
// 1. 所有底层实现选择都收口到这里,业务侧不再自行 new store/embedder/reranker
// 2. 即使后续引入更多 provider也应优先扩展本工厂而不是把选择逻辑扩散到业务模块
// 3. 观测能力也在此统一注入,避免 runtime/store/pipeline 各自偷偷打印日志。
func NewRuntimeFromConfig(ctx context.Context, cfg ragconfig.Config, deps FactoryDeps) (Runtime, error) {
logger, observer := normalizeFactoryDeps(deps)
embedder, err := buildEmbedder(ctx, cfg)
if err != nil {
return nil, err
}
store, err := buildStore(cfg, logger, observer)
if err != nil {
return nil, err
}
reranker, err := buildReranker(cfg, observer)
if err != nil {
return nil, err
}
pipeline := core.NewPipeline(ragchunk.NewTextChunker(), embedder, store, reranker)
pipeline.SetLogger(logger)
pipeline.SetObserver(observer)
return newRuntime(cfg, pipeline, observer), nil
}
func normalizeFactoryDeps(deps FactoryDeps) (*log.Logger, Observer) {
logger := deps.Logger
if logger == nil {
logger = log.Default()
}
observer := deps.Observer
if observer == nil {
observer = NewLoggerObserver(logger)
}
return logger, observer
}
func buildEmbedder(ctx context.Context, cfg ragconfig.Config) (core.Embedder, error) {
switch strings.ToLower(strings.TrimSpace(cfg.EmbedProvider)) {
case "", "mock":
return ragembed.NewMockEmbedder(cfg.EmbedDimension), nil
case "eino":
apiKey := strings.TrimSpace(os.Getenv(cfg.EmbedAPIKeyEnv))
if apiKey == "" {
return nil, fmt.Errorf("rag embed api key is empty: env=%s", cfg.EmbedAPIKeyEnv)
}
return ragembed.NewEinoEmbedder(ctx, ragembed.EinoConfig{
APIKey: apiKey,
BaseURL: cfg.EmbedBaseURL,
Model: cfg.EmbedModel,
TimeoutMS: cfg.EmbedTimeoutMS,
Dimension: cfg.EmbedDimension,
})
default:
return nil, fmt.Errorf("unsupported rag embed provider: %s", cfg.EmbedProvider)
}
}
func buildStore(cfg ragconfig.Config, logger *log.Logger, observer Observer) (core.VectorStore, error) {
switch strings.ToLower(strings.TrimSpace(cfg.Store)) {
case "", "inmemory":
return ragstore.NewInMemoryVectorStore(), nil
case "milvus":
return ragstore.NewMilvusStore(ragstore.MilvusConfig{
Address: cfg.MilvusAddress,
Token: cfg.MilvusToken,
DBName: cfg.MilvusDBName,
CollectionName: cfg.MilvusCollectionName,
RequestTimeoutMS: cfg.MilvusRequestTimeoutMS,
Dimension: cfg.EmbedDimension,
MetricType: cfg.MilvusMetricType,
Logger: logger,
Observer: observer,
})
default:
return nil, fmt.Errorf("unsupported rag store: %s", cfg.Store)
}
}
func buildReranker(cfg ragconfig.Config, observer Observer) (core.Reranker, error) {
if !cfg.RerankerEnabled {
return ragrerank.NewNoopReranker(), nil
}
switch strings.ToLower(strings.TrimSpace(cfg.RerankerProvider)) {
case "", "noop":
return ragrerank.NewNoopReranker(), nil
case "eino":
if observer != nil {
observer.Observe(context.Background(), ObserveEvent{
Level: ObserveLevelWarn,
Component: "factory",
Operation: "reranker_fallback",
Fields: map[string]any{
"provider": "eino",
"status": "fallback",
"fallback_target": "noop",
"reason": "reranker_not_implemented",
},
})
}
return ragrerank.NewNoopReranker(), nil
default:
return nil, fmt.Errorf("unsupported rag reranker provider: %s", cfg.RerankerProvider)
}
}

View File

@@ -0,0 +1,32 @@
package rag
import (
"log"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)
// ObserveLevel 对外暴露统一观测等级别名,避免启动层直接依赖 core 细节。
type ObserveLevel = core.ObserveLevel
const (
ObserveLevelInfo = core.ObserveLevelInfo
ObserveLevelWarn = core.ObserveLevelWarn
ObserveLevelError = core.ObserveLevelError
)
// ObserveEvent 对外暴露统一观测事件别名。
type ObserveEvent = core.ObserveEvent
// Observer 对外暴露统一观测接口别名。
type Observer = core.Observer
// NewNopObserver 返回空实现。
func NewNopObserver() Observer {
return core.NewNopObserver()
}
// NewLoggerObserver 返回标准日志适配器。
func NewLoggerObserver(logger *log.Logger) Observer {
return core.NewLoggerObserver(logger)
}

View File

@@ -0,0 +1,380 @@
package rag
import (
"context"
"fmt"
"strings"
"time"
ragconfig "github.com/LoveLosita/smartflow/backend/infra/rag/config"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
"github.com/LoveLosita/smartflow/backend/infra/rag/corpus"
)
type runtime struct {
cfg ragconfig.Config
pipeline *core.Pipeline
memoryCorpus *corpus.MemoryCorpus
webCorpus *corpus.WebCorpus
observer Observer
}
func newRuntime(cfg ragconfig.Config, pipeline *core.Pipeline, observer Observer) Runtime {
if observer == nil {
observer = NewNopObserver()
}
return &runtime{
cfg: cfg,
pipeline: pipeline,
memoryCorpus: corpus.NewMemoryCorpus(),
webCorpus: corpus.NewWebCorpus(),
observer: observer,
}
}
// IngestMemory 统一承接记忆语料入库。
func (r *runtime) IngestMemory(ctx context.Context, req MemoryIngestRequest) (*IngestResult, error) {
items := make([]corpus.MemoryIngestItem, 0, len(req.Items))
for _, item := range req.Items {
items = append(items, corpus.MemoryIngestItem{
MemoryID: item.MemoryID,
UserID: item.UserID,
ConversationID: item.ConversationID,
AssistantID: item.AssistantID,
RunID: item.RunID,
MemoryType: item.MemoryType,
Title: item.Title,
Content: item.Content,
Confidence: item.Confidence,
Importance: item.Importance,
SensitivityLevel: item.SensitivityLevel,
IsExplicit: item.IsExplicit,
Status: item.Status,
TTLAt: item.TTLAt,
CreatedAt: item.CreatedAt,
})
}
return r.ingestWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, items, req.Action)
}
// RetrieveMemory 统一承接记忆语料检索。
func (r *runtime) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (*RetrieveResult, error) {
corpusInput := corpus.MemoryRetrieveInput{
UserID: req.UserID,
ConversationID: req.ConversationID,
AssistantID: req.AssistantID,
RunID: req.RunID,
}
if len(req.MemoryTypes) == 1 {
corpusInput.MemoryType = req.MemoryTypes[0]
}
result, err := r.retrieveWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, core.RetrieveRequest{
Query: req.Query,
TopK: normalizeTopK(req.TopK, r.cfg.TopK),
Threshold: normalizeThreshold(req.Threshold, r.cfg.Threshold),
Action: normalizeAction(req.Action, "search"),
CorpusInput: corpusInput,
})
if err != nil {
return nil, err
}
if len(req.MemoryTypes) <= 1 {
return result, nil
}
// 1. 当前底层过滤仍以等值条件为主,先保持 Runtime 做多类型二次筛选;
// 2. 这样可以避免把 “memory_type in (...)” 的实现细节扩散到所有 Store
// 3. 等后续底层过滤能力统一后,再考虑把该逻辑继续下沉。
allowed := make(map[string]struct{}, len(req.MemoryTypes))
for _, item := range req.MemoryTypes {
value := strings.TrimSpace(strings.ToLower(item))
if value == "" {
continue
}
allowed[value] = struct{}{}
}
filtered := make([]RetrieveHit, 0, len(result.Items))
for _, item := range result.Items {
memoryType := strings.TrimSpace(strings.ToLower(asString(item.Metadata["memory_type"])))
if len(allowed) > 0 {
if _, ok := allowed[memoryType]; !ok {
continue
}
}
filtered = append(filtered, item)
}
result.Items = filtered
if req.TopK > 0 && len(result.Items) > req.TopK {
result.Items = result.Items[:req.TopK]
}
return result, nil
}
// IngestWeb 统一承接网页语料入库。
func (r *runtime) IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestResult, error) {
items := make([]corpus.WebIngestItem, 0, len(req.Items))
for _, item := range req.Items {
items = append(items, corpus.WebIngestItem{
URL: item.URL,
Title: item.Title,
Content: item.Content,
Snippet: item.Snippet,
Domain: item.Domain,
QueryID: item.QueryID,
SessionID: item.SessionID,
PublishedAt: item.PublishedAt,
FetchedAt: item.FetchedAt,
SourceRank: item.SourceRank,
})
}
return r.ingestWithCorpus(ctx, req.TraceID, "web", r.webCorpus, items, req.Action)
}
// RetrieveWeb 统一承接网页语料检索。
func (r *runtime) RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (*RetrieveResult, error) {
return r.retrieveWithCorpus(ctx, req.TraceID, "web", r.webCorpus, core.RetrieveRequest{
Query: req.Query,
TopK: normalizeTopK(req.TopK, r.cfg.TopK),
Threshold: normalizeThreshold(req.Threshold, r.cfg.Threshold),
Action: normalizeAction(req.Action, "search"),
CorpusInput: corpus.WebRetrieveInput{
QueryID: req.QueryID,
SessionID: req.SessionID,
Domain: req.Domain,
},
})
}
func (r *runtime) ingestWithCorpus(
ctx context.Context,
traceID string,
corpusName string,
adapter core.CorpusAdapter,
input any,
action string,
) (*IngestResult, error) {
start := time.Now()
if r == nil || r.pipeline == nil || adapter == nil {
return nil, core.ErrNilDependency
}
action = normalizeAction(action, "add")
observeCtx := newObserveContext(ctx, traceID, corpusName, action)
docs, err := adapter.BuildIngestDocuments(observeCtx, input)
if err != nil {
r.observe(observeCtx, ObserveEvent{
Level: ObserveLevelError,
Component: "runtime",
Operation: "ingest",
Fields: map[string]any{
"status": "failed",
"latency_ms": time.Since(start).Milliseconds(),
"phase": "build_documents",
"error": err,
"error_code": core.ClassifyErrorCode(err),
"input_count": estimateInputCount(input),
},
})
return nil, err
}
docIDs := make([]string, 0, len(docs))
for _, doc := range docs {
docIDs = append(docIDs, doc.ID)
}
result, err := r.pipeline.IngestDocuments(observeCtx, adapter.Name(), docs, core.IngestOption{
Chunk: core.ChunkOption{
ChunkSize: r.cfg.ChunkSize,
ChunkOverlap: r.cfg.ChunkOverlap,
},
Action: action,
})
if err != nil {
r.observe(observeCtx, ObserveEvent{
Level: ObserveLevelError,
Component: "runtime",
Operation: "ingest",
Fields: map[string]any{
"status": "failed",
"latency_ms": time.Since(start).Milliseconds(),
"document_count": len(docs),
"error": err,
"error_code": core.ClassifyErrorCode(err),
},
})
return nil, err
}
r.observe(observeCtx, ObserveEvent{
Level: ObserveLevelInfo,
Component: "runtime",
Operation: "ingest",
Fields: map[string]any{
"status": "success",
"latency_ms": time.Since(start).Milliseconds(),
"document_count": result.DocumentCount,
"chunk_count": result.ChunkCount,
},
})
return &IngestResult{
DocumentCount: result.DocumentCount,
ChunkCount: result.ChunkCount,
DocumentIDs: docIDs,
}, nil
}
func (r *runtime) retrieveWithCorpus(
ctx context.Context,
traceID string,
corpusName string,
adapter core.CorpusAdapter,
req core.RetrieveRequest,
) (*RetrieveResult, error) {
start := time.Now()
if r == nil || r.pipeline == nil || adapter == nil {
return nil, core.ErrNilDependency
}
action := normalizeAction(req.Action, "search")
req.Action = action
observeCtx := newObserveContext(ctx, traceID, corpusName, action)
timeoutCtx := observeCtx
cancel := func() {}
if r.cfg.RetrieveTimeoutMS > 0 {
timeoutCtx, cancel = context.WithTimeout(observeCtx, time.Duration(r.cfg.RetrieveTimeoutMS)*time.Millisecond)
}
defer cancel()
result, err := r.pipeline.Retrieve(timeoutCtx, adapter, req)
if err != nil {
r.observe(observeCtx, ObserveEvent{
Level: ObserveLevelError,
Component: "runtime",
Operation: "retrieve",
Fields: map[string]any{
"status": "failed",
"latency_ms": time.Since(start).Milliseconds(),
"query_len": len(strings.TrimSpace(req.Query)),
"top_k": req.TopK,
"threshold": req.Threshold,
"error": err,
"error_code": core.ClassifyErrorCode(err),
},
})
return nil, err
}
items := make([]RetrieveHit, 0, len(result.Items))
for _, item := range result.Items {
items = append(items, RetrieveHit{
ChunkID: item.ChunkID,
DocumentID: item.DocumentID,
Text: item.Text,
Score: item.Score,
Metadata: cloneMap(item.Metadata),
})
}
r.observe(observeCtx, ObserveEvent{
Level: ObserveLevelInfo,
Component: "runtime",
Operation: "retrieve",
Fields: map[string]any{
"status": "success",
"latency_ms": time.Since(start).Milliseconds(),
"query_len": len(strings.TrimSpace(req.Query)),
"top_k": req.TopK,
"threshold": req.Threshold,
"raw_count": result.RawCount,
"hit_count": len(result.Items),
"fallback_used": result.FallbackUsed,
"fallback_reason": result.FallbackReason,
},
})
return &RetrieveResult{
Items: items,
RawCount: result.RawCount,
FallbackUsed: result.FallbackUsed,
FallbackReason: result.FallbackReason,
}, nil
}
func (r *runtime) observe(ctx context.Context, event ObserveEvent) {
if r == nil || r.observer == nil {
return
}
r.observer.Observe(ctx, event)
}
func newObserveContext(ctx context.Context, traceID string, corpusName string, action string) context.Context {
fields := map[string]any{
"corpus": corpusName,
"action": action,
}
if traceID = strings.TrimSpace(traceID); traceID != "" {
fields["trace_id"] = traceID
}
return core.WithObserveFields(ctx, fields)
}
func estimateInputCount(input any) int {
switch value := input.(type) {
case []corpus.MemoryIngestItem:
return len(value)
case []corpus.WebIngestItem:
return len(value)
default:
return 0
}
}
func normalizeAction(action string, fallback string) string {
action = strings.TrimSpace(action)
if action == "" {
return fallback
}
return action
}
func normalizeTopK(topK int, fallback int) int {
if topK > 0 {
return topK
}
if fallback > 0 {
return fallback
}
return 8
}
func normalizeThreshold(threshold float64, fallback float64) float64 {
if threshold >= 0 {
return threshold
}
if fallback >= 0 {
return fallback
}
return 0
}
func cloneMap(src map[string]any) map[string]any {
if len(src) == 0 {
return map[string]any{}
}
dst := make(map[string]any, len(src))
for key, value := range src {
dst[key] = value
}
return dst
}
func asString(v any) string {
if v == nil {
return ""
}
return strings.TrimSpace(fmt.Sprintf("%v", v))
}

View File

@@ -0,0 +1,117 @@
package rag
import (
"context"
"time"
)
// Runtime 是 RAG Infra 对业务侧暴露的唯一稳定方法面。
//
// 职责边界:
// 1. 负责承接 memory/web 两类语料的统一入库与检索入口;
// 2. 负责屏蔽底层 Pipeline / Store / Embedder / Reranker 的装配细节;
// 3. 不负责 provider 搜索、HTML 抓取、prompt 注入等业务语义。
type Runtime interface {
IngestMemory(ctx context.Context, req MemoryIngestRequest) (*IngestResult, error)
RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (*RetrieveResult, error)
IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestResult, error)
RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (*RetrieveResult, error)
}
// IngestResult 描述一次统一入库执行摘要。
type IngestResult struct {
DocumentCount int
ChunkCount int
DocumentIDs []string
}
// RetrieveHit 是对业务侧暴露的统一命中项。
type RetrieveHit struct {
ChunkID string
DocumentID string
Text string
Score float64
Metadata map[string]any
}
// RetrieveResult 描述一次检索执行摘要。
type RetrieveResult struct {
Items []RetrieveHit
RawCount int
FallbackUsed bool
FallbackReason string
}
// MemoryIngestItem 是 memory 语料入库项。
type MemoryIngestItem struct {
MemoryID int64
UserID int
ConversationID string
AssistantID string
RunID string
MemoryType string
Title string
Content string
Confidence float64
Importance float64
SensitivityLevel int
IsExplicit bool
Status string
TTLAt *time.Time
CreatedAt *time.Time
}
// MemoryIngestRequest 描述一次记忆向量入库请求。
type MemoryIngestRequest struct {
TraceID string
Action string
Items []MemoryIngestItem
}
// MemoryRetrieveRequest 描述一次记忆检索请求。
type MemoryRetrieveRequest struct {
TraceID string
Query string
TopK int
Threshold float64
Action string
UserID int
ConversationID string
AssistantID string
RunID string
MemoryTypes []string
}
// WebIngestItem 是网页语料入库项。
type WebIngestItem struct {
URL string
Title string
Content string
Snippet string
Domain string
QueryID string
SessionID string
PublishedAt *time.Time
FetchedAt *time.Time
SourceRank int
}
// WebIngestRequest 描述一次网页语料入库请求。
type WebIngestRequest struct {
TraceID string
Action string
Items []WebIngestItem
}
// WebRetrieveRequest 描述一次网页检索请求。
type WebRetrieveRequest struct {
TraceID string
Query string
TopK int
Threshold float64
Action string
QueryID string
SessionID string
Domain string
}

View File

@@ -1,35 +1,894 @@
package store
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)
// MilvusStore 是 Milvus 连接器占位实现
// MilvusConfig 描述 Milvus REST 存储配置
type MilvusConfig struct {
// Address 应指向 Milvus REST 入口。
// 当前项目联调验证使用 195309091 仅用于 health/metrics不承载本文实现所走的 REST API。
Address string
Token string
DBName string
CollectionName string
RequestTimeoutMS int
Dimension int
MetricType string
Logger *log.Logger
Observer core.Observer
}
// MilvusStore 是基于 Milvus REST API 的向量存储实现。
//
// 说明:
// 1. 本轮先保留接口结构,便于后续平滑替换 InMemoryStore
// 2. 真实接入时需补充连接池、集合初始化、元数据过滤与错误转换。
type MilvusStore struct{}
func NewMilvusStore() *MilvusStore {
return &MilvusStore{}
// 设计说明:
// 1. 本实现优先保证“项目内可接入、可管理、可灰度”,不强依赖额外 SDK
// 2. 通过固定字段 + metadata JSON 的方式兼顾过滤能力与元数据完整性;
// 3. collection 在首次写入时自动创建,避免启动期额外初始化脚本。
type MilvusStore struct {
cfg MilvusConfig
client *http.Client
observer core.Observer
mu sync.Mutex
ensured bool
}
func (s *MilvusStore) Upsert(_ context.Context, _ []core.VectorRow) error {
return errors.New("milvus store is not implemented yet")
const (
milvusPrimaryField = "id"
milvusVectorField = "vector"
milvusTextField = "text"
milvusMetadataField = "metadata"
milvusCorpusField = "corpus"
milvusDocumentField = "document_id"
milvusUserIDField = "user_id"
milvusAssistantField = "assistant_id"
milvusConvField = "conversation_id"
milvusRunField = "run_id"
milvusMemoryType = "memory_type"
milvusQueryIDField = "query_id"
milvusSessionField = "session_id"
milvusDomainField = "domain"
milvusChunkOrder = "chunk_order"
milvusUpdatedAtField = "updated_at"
)
var milvusFilterFieldMap = map[string]string{
"corpus": milvusCorpusField,
"document_id": milvusDocumentField,
"user_id": milvusUserIDField,
"assistant_id": milvusAssistantField,
"conversation_id": milvusConvField,
"run_id": milvusRunField,
"memory_type": milvusMemoryType,
"query_id": milvusQueryIDField,
"session_id": milvusSessionField,
"domain": milvusDomainField,
"chunk_order": milvusChunkOrder,
}
func (s *MilvusStore) Search(_ context.Context, _ core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
return nil, errors.New("milvus store is not implemented yet")
func NewMilvusStore(cfg MilvusConfig) (*MilvusStore, error) {
cfg.Address = strings.TrimRight(strings.TrimSpace(cfg.Address), "/")
if cfg.Address == "" {
return nil, errors.New("milvus address is empty")
}
if cfg.CollectionName == "" {
cfg.CollectionName = "smartflow_rag_chunks"
}
if cfg.MetricType == "" {
cfg.MetricType = "COSINE"
}
if cfg.RequestTimeoutMS <= 0 {
cfg.RequestTimeoutMS = 1500
}
if cfg.Logger == nil {
cfg.Logger = log.Default()
}
if cfg.Observer == nil {
cfg.Observer = core.NewLoggerObserver(cfg.Logger)
}
return &MilvusStore{
cfg: cfg,
client: &http.Client{Timeout: time.Duration(cfg.RequestTimeoutMS) * time.Millisecond},
observer: cfg.Observer,
}, nil
}
func (s *MilvusStore) Delete(_ context.Context, _ []string) error {
return errors.New("milvus store is not implemented yet")
func (s *MilvusStore) Upsert(ctx context.Context, rows []core.VectorRow) error {
start := time.Now()
if len(rows) == 0 {
return nil
}
if err := s.ensureCollection(ctx, len(rows[0].Vector)); err != nil {
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelError,
Component: "store",
Operation: "upsert",
Fields: map[string]any{
"status": "failed",
"row_count": len(rows),
"vector_dim": len(rows[0].Vector),
"latency_ms": time.Since(start).Milliseconds(),
"error": err,
"error_code": core.ClassifyErrorCode(err),
},
})
return err
}
data := make([]map[string]any, 0, len(rows))
for _, row := range rows {
item := mapRowToMilvusEntity(row)
data = append(data, item)
}
_, err := s.postJSON(ctx, "/v2/vectordb/entities/upsert", map[string]any{
"collectionName": s.cfg.CollectionName,
"data": data,
"dbName": blankToNil(s.cfg.DBName),
})
if err != nil {
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelError,
Component: "store",
Operation: "upsert",
Fields: map[string]any{
"status": "failed",
"row_count": len(rows),
"vector_dim": len(rows[0].Vector),
"latency_ms": time.Since(start).Milliseconds(),
"error": err,
"error_code": core.ClassifyErrorCode(err),
},
})
return err
}
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelInfo,
Component: "store",
Operation: "upsert",
Fields: map[string]any{
"status": "success",
"row_count": len(rows),
"vector_dim": len(rows[0].Vector),
"latency_ms": time.Since(start).Milliseconds(),
},
})
return err
}
func (s *MilvusStore) Get(_ context.Context, _ []string) ([]core.VectorRow, error) {
return nil, errors.New("milvus store is not implemented yet")
func (s *MilvusStore) Search(ctx context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
start := time.Now()
if len(req.QueryVector) == 0 {
return nil, nil
}
if err := s.ensureCollection(ctx, len(req.QueryVector)); err != nil {
if isMilvusCollectionMissing(err) {
return nil, nil
}
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelError,
Component: "store",
Operation: "search",
Fields: map[string]any{
"status": "failed",
"top_k": req.TopK,
"filter_count": len(req.Filter),
"vector_dim": len(req.QueryVector),
"latency_ms": time.Since(start).Milliseconds(),
"error": err,
"error_code": core.ClassifyErrorCode(err),
},
})
return nil, err
}
filterExpr, err := buildMilvusFilter(req.Filter)
if err != nil {
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelError,
Component: "store",
Operation: "search",
Fields: map[string]any{
"status": "failed",
"top_k": req.TopK,
"filter_count": len(req.Filter),
"vector_dim": len(req.QueryVector),
"latency_ms": time.Since(start).Milliseconds(),
"error": err,
"error_code": core.ClassifyErrorCode(err),
},
})
return nil, err
}
body := map[string]any{
"collectionName": s.cfg.CollectionName,
"data": [][]float32{req.QueryVector},
"annsField": milvusVectorField,
"limit": normalizeMilvusTopK(req.TopK),
"outputFields": milvusOutputFields(false),
}
if filterExpr != "" {
body["filter"] = filterExpr
}
if s.cfg.DBName != "" {
body["dbName"] = s.cfg.DBName
}
respBody, err := s.postJSON(ctx, "/v2/vectordb/entities/search", body)
if err != nil {
if isMilvusCollectionMissing(err) {
return nil, nil
}
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelError,
Component: "store",
Operation: "search",
Fields: map[string]any{
"status": "failed",
"top_k": req.TopK,
"filter_count": len(req.Filter),
"vector_dim": len(req.QueryVector),
"latency_ms": time.Since(start).Milliseconds(),
"error": err,
"error_code": core.ClassifyErrorCode(err),
},
})
return nil, err
}
var resp milvusSearchResponse
if err = json.Unmarshal(respBody, &resp); err != nil {
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelError,
Component: "store",
Operation: "search",
Fields: map[string]any{
"status": "failed",
"top_k": req.TopK,
"filter_count": len(req.Filter),
"vector_dim": len(req.QueryVector),
"latency_ms": time.Since(start).Milliseconds(),
"error": err,
"error_code": core.ClassifyErrorCode(err),
},
})
return nil, err
}
if resp.Code != 0 && resp.Code != 200 {
err = fmt.Errorf("milvus search failed: code=%d message=%s", resp.Code, resp.Message)
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelError,
Component: "store",
Operation: "search",
Fields: map[string]any{
"status": "failed",
"top_k": req.TopK,
"filter_count": len(req.Filter),
"vector_dim": len(req.QueryVector),
"latency_ms": time.Since(start).Milliseconds(),
"error": err,
"error_code": core.ClassifyErrorCode(err),
},
})
return nil, err
}
result := make([]core.ScoredVectorRow, 0, len(resp.Data))
for _, item := range resp.Data {
row, score := item.toVectorRow()
result = append(result, core.ScoredVectorRow{
Row: row,
Score: score,
})
}
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelInfo,
Component: "store",
Operation: "search",
Fields: map[string]any{
"status": "success",
"top_k": req.TopK,
"filter_count": len(req.Filter),
"vector_dim": len(req.QueryVector),
"result_count": len(result),
"latency_ms": time.Since(start).Milliseconds(),
},
})
return result, nil
}
func (s *MilvusStore) Delete(ctx context.Context, ids []string) error {
start := time.Now()
if len(ids) == 0 {
return nil
}
filter := fmt.Sprintf(`%s in [%s]`, milvusPrimaryField, joinQuotedStrings(ids))
_, err := s.postJSON(ctx, "/v2/vectordb/entities/delete", map[string]any{
"collectionName": s.cfg.CollectionName,
"filter": filter,
"dbName": blankToNil(s.cfg.DBName),
})
if isMilvusCollectionMissing(err) {
return nil
}
if err != nil {
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelError,
Component: "store",
Operation: "delete",
Fields: map[string]any{
"status": "failed",
"id_count": len(ids),
"latency_ms": time.Since(start).Milliseconds(),
"error": err,
"error_code": core.ClassifyErrorCode(err),
},
})
return err
}
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelInfo,
Component: "store",
Operation: "delete",
Fields: map[string]any{
"status": "success",
"id_count": len(ids),
"latency_ms": time.Since(start).Milliseconds(),
},
})
return err
}
func (s *MilvusStore) Get(ctx context.Context, ids []string) ([]core.VectorRow, error) {
start := time.Now()
if len(ids) == 0 {
return nil, nil
}
respBody, err := s.postJSON(ctx, "/v2/vectordb/entities/get", map[string]any{
"collectionName": s.cfg.CollectionName,
"id": ids,
"outputFields": milvusOutputFields(true),
"dbName": blankToNil(s.cfg.DBName),
})
if err != nil {
if isMilvusCollectionMissing(err) {
return nil, nil
}
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelError,
Component: "store",
Operation: "get",
Fields: map[string]any{
"status": "failed",
"id_count": len(ids),
"latency_ms": time.Since(start).Milliseconds(),
"error": err,
"error_code": core.ClassifyErrorCode(err),
},
})
return nil, err
}
var resp milvusGetResponse
if err = json.Unmarshal(respBody, &resp); err != nil {
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelError,
Component: "store",
Operation: "get",
Fields: map[string]any{
"status": "failed",
"id_count": len(ids),
"latency_ms": time.Since(start).Milliseconds(),
"error": err,
"error_code": core.ClassifyErrorCode(err),
},
})
return nil, err
}
if resp.Code != 0 && resp.Code != 200 {
err = fmt.Errorf("milvus get failed: code=%d message=%s", resp.Code, resp.Message)
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelError,
Component: "store",
Operation: "get",
Fields: map[string]any{
"status": "failed",
"id_count": len(ids),
"latency_ms": time.Since(start).Milliseconds(),
"error": err,
"error_code": core.ClassifyErrorCode(err),
},
})
return nil, err
}
rows := make([]core.VectorRow, 0, len(resp.Data))
for _, item := range resp.Data {
rows = append(rows, mapMilvusRow(item, true))
}
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelInfo,
Component: "store",
Operation: "get",
Fields: map[string]any{
"status": "success",
"id_count": len(ids),
"row_count": len(rows),
"latency_ms": time.Since(start).Milliseconds(),
},
})
return rows, nil
}
func (s *MilvusStore) ensureCollection(ctx context.Context, dimension int) error {
start := time.Now()
if dimension <= 0 {
dimension = s.cfg.Dimension
}
if dimension <= 0 {
return errors.New("milvus vector dimension is invalid")
}
s.mu.Lock()
defer s.mu.Unlock()
if s.ensured {
return nil
}
payload := map[string]any{
"collectionName": s.cfg.CollectionName,
"schema": map[string]any{
"autoId": false,
"enabledDynamicField": false,
"fields": []map[string]any{
buildVarcharField(milvusPrimaryField, true, 256),
buildVectorField(milvusVectorField, dimension),
buildVarcharField(milvusTextField, false, 65535),
{"fieldName": milvusMetadataField, "dataType": "JSON"},
buildVarcharField(milvusCorpusField, false, 64),
buildVarcharField(milvusDocumentField, false, 256),
{"fieldName": milvusUserIDField, "dataType": "Int64"},
buildVarcharField(milvusAssistantField, false, 128),
buildVarcharField(milvusConvField, false, 128),
buildVarcharField(milvusRunField, false, 128),
buildVarcharField(milvusMemoryType, false, 64),
buildVarcharField(milvusQueryIDField, false, 128),
buildVarcharField(milvusSessionField, false, 128),
buildVarcharField(milvusDomainField, false, 128),
{"fieldName": milvusChunkOrder, "dataType": "Int64"},
{"fieldName": milvusUpdatedAtField, "dataType": "Int64"},
},
},
"indexParams": []map[string]any{
{
"fieldName": milvusVectorField,
"indexName": milvusVectorField,
"metricType": s.cfg.MetricType,
"indexType": "AUTOINDEX",
},
},
}
if s.cfg.DBName != "" {
payload["dbName"] = s.cfg.DBName
}
_, err := s.postJSON(ctx, "/v2/vectordb/collections/create", payload)
if err != nil {
if isMilvusAlreadyExists(err) {
s.ensured = true
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelInfo,
Component: "store",
Operation: "ensure_collection",
Fields: map[string]any{
"status": "already_exists",
"vector_dim": dimension,
"metric_type": s.cfg.MetricType,
"latency_ms": time.Since(start).Milliseconds(),
},
})
return nil
}
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelError,
Component: "store",
Operation: "ensure_collection",
Fields: map[string]any{
"status": "failed",
"vector_dim": dimension,
"metric_type": s.cfg.MetricType,
"latency_ms": time.Since(start).Milliseconds(),
"error": err,
"error_code": core.ClassifyErrorCode(err),
},
})
return err
}
s.ensured = true
s.observe(ctx, core.ObserveEvent{
Level: core.ObserveLevelInfo,
Component: "store",
Operation: "ensure_collection",
Fields: map[string]any{
"status": "created",
"vector_dim": dimension,
"metric_type": s.cfg.MetricType,
"latency_ms": time.Since(start).Milliseconds(),
},
})
return nil
}
func (s *MilvusStore) postJSON(ctx context.Context, path string, payload map[string]any) ([]byte, error) {
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.cfg.Address+path, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
if token := strings.TrimSpace(s.cfg.Token); token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
resp, err := s.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, readErr
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("milvus http failed: status=%d body=%s", resp.StatusCode, string(respBody))
}
var basic milvusBasicResponse
if jsonErr := json.Unmarshal(respBody, &basic); jsonErr == nil {
if basic.Code != 0 && basic.Code != 200 {
return nil, fmt.Errorf("milvus api failed: code=%d message=%s", basic.Code, basic.Message)
}
}
return respBody, nil
}
func (s *MilvusStore) observe(ctx context.Context, event core.ObserveEvent) {
if s == nil || s.observer == nil {
return
}
fields := cloneMap(event.Fields)
fields["store"] = "milvus"
fields["collection"] = s.cfg.CollectionName
if strings.TrimSpace(s.cfg.DBName) != "" {
fields["db_name"] = s.cfg.DBName
}
s.observer.Observe(ctx, core.ObserveEvent{
Level: event.Level,
Component: event.Component,
Operation: event.Operation,
Fields: fields,
})
}
func mapRowToMilvusEntity(row core.VectorRow) map[string]any {
metadata := cloneMap(row.Metadata)
entity := map[string]any{
milvusPrimaryField: row.ID,
milvusVectorField: row.Vector,
milvusTextField: row.Text,
milvusMetadataField: metadata,
milvusCorpusField: asString(metadata["corpus"]),
milvusDocumentField: asString(metadata["document_id"]),
milvusUpdatedAtField: func() int64 {
if row.UpdatedAt.IsZero() {
return time.Now().UnixMilli()
}
return row.UpdatedAt.UnixMilli()
}(),
}
assignMilvusScalar(entity, milvusUserIDField, metadata["user_id"])
assignMilvusScalar(entity, milvusAssistantField, metadata["assistant_id"])
assignMilvusScalar(entity, milvusConvField, metadata["conversation_id"])
assignMilvusScalar(entity, milvusRunField, metadata["run_id"])
assignMilvusScalar(entity, milvusMemoryType, metadata["memory_type"])
assignMilvusScalar(entity, milvusQueryIDField, metadata["query_id"])
assignMilvusScalar(entity, milvusSessionField, metadata["session_id"])
assignMilvusScalar(entity, milvusDomainField, metadata["domain"])
assignMilvusScalar(entity, milvusChunkOrder, metadata["chunk_order"])
return entity
}
func assignMilvusScalar(target map[string]any, field string, value any) {
if value == nil {
return
}
switch field {
case milvusUserIDField, milvusChunkOrder:
if parsed, ok := toInt64(value); ok {
target[field] = parsed
}
default:
if text := asString(value); text != "" {
target[field] = text
}
}
}
func buildMilvusFilter(filter map[string]any) (string, error) {
if len(filter) == 0 {
return "", nil
}
parts := make([]string, 0, len(filter))
for key, value := range filter {
field, ok := milvusFilterFieldMap[key]
if !ok {
return "", fmt.Errorf("unsupported milvus filter key: %s", key)
}
switch field {
case milvusUserIDField, milvusChunkOrder:
parsed, parseOK := toInt64(value)
if !parseOK {
return "", fmt.Errorf("milvus filter key=%s expects integer", key)
}
parts = append(parts, fmt.Sprintf("%s == %d", field, parsed))
default:
text := escapeMilvusString(asString(value))
parts = append(parts, fmt.Sprintf(`%s == "%s"`, field, text))
}
}
return strings.Join(parts, " and "), nil
}
func buildVarcharField(name string, isPrimary bool, maxLength int) map[string]any {
field := map[string]any{
"fieldName": name,
"dataType": "VarChar",
"elementTypeParams": map[string]any{"max_length": maxLength},
}
if isPrimary {
field["isPrimary"] = true
}
return field
}
func buildVectorField(name string, dimension int) map[string]any {
return map[string]any{
"fieldName": name,
"dataType": "FloatVector",
"elementTypeParams": map[string]any{"dim": dimension},
}
}
func milvusOutputFields(includeVector bool) []string {
fields := []string{
milvusTextField,
milvusMetadataField,
milvusCorpusField,
milvusDocumentField,
milvusUserIDField,
milvusAssistantField,
milvusConvField,
milvusRunField,
milvusMemoryType,
milvusQueryIDField,
milvusSessionField,
milvusDomainField,
milvusChunkOrder,
milvusUpdatedAtField,
}
if includeVector {
fields = append(fields, milvusVectorField)
}
return fields
}
func normalizeMilvusTopK(topK int) int {
if topK <= 0 {
return 8
}
return topK
}
func blankToNil(v string) any {
v = strings.TrimSpace(v)
if v == "" {
return nil
}
return v
}
func escapeMilvusString(v string) string {
v = strings.ReplaceAll(v, `\`, `\\`)
return strings.ReplaceAll(v, `"`, `\"`)
}
func joinQuotedStrings(values []string) string {
parts := make([]string, 0, len(values))
for _, value := range values {
parts = append(parts, fmt.Sprintf(`"%s"`, escapeMilvusString(value)))
}
return strings.Join(parts, ",")
}
func cloneMap(src map[string]any) map[string]any {
if len(src) == 0 {
return map[string]any{}
}
dst := make(map[string]any, len(src))
for key, value := range src {
dst[key] = value
}
return dst
}
func asString(v any) string {
if v == nil {
return ""
}
return strings.TrimSpace(fmt.Sprintf("%v", v))
}
func toInt64(v any) (int64, bool) {
switch value := v.(type) {
case int:
return int64(value), true
case int32:
return int64(value), true
case int64:
return value, true
case float64:
return int64(value), true
case json.Number:
parsed, err := value.Int64()
return parsed, err == nil
case string:
parsed, err := strconv.ParseInt(strings.TrimSpace(value), 10, 64)
return parsed, err == nil
default:
return 0, false
}
}
func isMilvusAlreadyExists(err error) bool {
if err == nil {
return false
}
text := strings.ToLower(err.Error())
return strings.Contains(text, "already exist") || strings.Contains(text, "already exists")
}
func isMilvusCollectionMissing(err error) bool {
if err == nil {
return false
}
text := strings.ToLower(err.Error())
return strings.Contains(text, "can't find collection") || strings.Contains(text, "collection not found")
}
type milvusBasicResponse struct {
Code int `json:"code"`
Message string `json:"message"`
}
type milvusSearchResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data []milvusSearchItem `json:"data"`
}
type milvusSearchItem map[string]any
func (m milvusSearchItem) toVectorRow() (core.VectorRow, float64) {
row := mapMilvusRow(map[string]any(m), false)
score := 0.0
if value, ok := m["distance"].(float64); ok {
score = value
}
return row, score
}
type milvusGetResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data []map[string]any `json:"data"`
}
func mapMilvusRow(raw map[string]any, includeVector bool) core.VectorRow {
metadata := cloneMap(readMetadataMap(raw[milvusMetadataField]))
assignMetadataIfPresent(metadata, "corpus", raw[milvusCorpusField])
assignMetadataIfPresent(metadata, "document_id", raw[milvusDocumentField])
assignMetadataIfPresent(metadata, "user_id", raw[milvusUserIDField])
assignMetadataIfPresent(metadata, "assistant_id", raw[milvusAssistantField])
assignMetadataIfPresent(metadata, "conversation_id", raw[milvusConvField])
assignMetadataIfPresent(metadata, "run_id", raw[milvusRunField])
assignMetadataIfPresent(metadata, "memory_type", raw[milvusMemoryType])
assignMetadataIfPresent(metadata, "query_id", raw[milvusQueryIDField])
assignMetadataIfPresent(metadata, "session_id", raw[milvusSessionField])
assignMetadataIfPresent(metadata, "domain", raw[milvusDomainField])
assignMetadataIfPresent(metadata, "chunk_order", raw[milvusChunkOrder])
row := core.VectorRow{
ID: asString(raw[milvusPrimaryField]),
Text: asString(raw[milvusTextField]),
Metadata: metadata,
}
if row.ID == "" {
row.ID = asString(raw["id"])
}
if includeVector {
row.Vector = readFloat32Vector(raw[milvusVectorField])
}
return row
}
func readMetadataMap(value any) map[string]any {
switch data := value.(type) {
case map[string]any:
return data
default:
return map[string]any{}
}
}
func readFloat32Vector(value any) []float32 {
switch vector := value.(type) {
case []float32:
return vector
case []any:
result := make([]float32, 0, len(vector))
for _, item := range vector {
switch number := item.(type) {
case float64:
result = append(result, float32(number))
case float32:
result = append(result, number)
}
}
return result
default:
return nil
}
}
func assignMetadataIfPresent(target map[string]any, key string, value any) {
if value == nil {
return
}
switch typed := value.(type) {
case string:
if strings.TrimSpace(typed) == "" {
return
}
target[key] = strings.TrimSpace(typed)
default:
target[key] = typed
}
}

View File

@@ -5,4 +5,5 @@ import "github.com/LoveLosita/smartflow/backend/infra/rag/core"
// EnsureCompile 用于静态校验实现是否满足接口。
func EnsureCompile() {
var _ core.VectorStore = (*InMemoryVectorStore)(nil)
var _ core.VectorStore = (*MilvusStore)(nil)
}