Version: 0.9.14.dev.260410
后端:
1. LLM 客户端从 newAgent/llm 提升为 infra/llm 基础设施层
- 删除 backend/newAgent/llm/(ark.go / ark_adapter.go / client.go / json.go)
- 等价迁移至 backend/infra/llm/,所有 newAgent node 与 service 统一改引用 infrallm
- 消除 newAgent 对模型客户端的私有依赖,为 memory / websearch 等多模块复用铺路
2. RAG 基础设施完成可运行态接入(factory / runtime / observer / service 四层成型)
- 新建 backend/infra/rag/factory.go / runtime.go / observe.go / observer.go /
service.go:工厂创建、运行时生命周期、轻量观测接口、检索服务门面
- 更新 infra/rag/config/config.go:补齐 Milvus / Embed / Reranker 全部配置项与默认值
- 更新 infra/rag/embed/eino_embedder.go:增强 Eino embedding 适配,支持 BaseURL / APIKey 环境变量 / 超时 /
维度等参数
- 更新 infra/rag/store/milvus_store.go:完整实现 Milvus 向量存储(建集合 / 建 Index / Upsert / Search /
Delete),支持 COSINE / L2 / IP 度量
- 更新 infra/rag/core/pipeline.go:适配 Runtime 接口,Pipeline 由 factory 注入而非手动拼装
- 更新 infra/rag/corpus/memory_corpus.go / vector_store.go:对接 Memory 模块数据源与 Store 接口扩展
3. Memory 模块从 Day1 骨架升级为 Day2 完整可运行态
- 新建 memory/module.go:统一门面 Module,对外封装 EnqueueExtract / ReadService / ManageService / WithTx /
StartWorker,启动层只依赖这一个入口
- 新建 memory/orchestrator/llm_write_orchestrator.go:LLM 驱动的记忆抽取编排器,替代原 mock 抽取
- 新建 memory/service/read_service.go:按用户开关过滤 + 轻量重排 + 访问时间刷新的读取链路
- 新建 memory/service/manage_service.go:记忆管理面能力(列出 / 软删除 / 开关读写),删除同步写审计日志
- 新建 memory/service/common.go:服务层公共工具
- 新建 memory/worker/loop.go:后台轮询循环 RunPollingLoop,定时抢占 pending 任务并推进
- 新建 memory/utils/audit.go / settings.go:审计日志构造、用户设置过滤等纯函数
- 更新 memory/model/item.go / job.go / settings.go / config.go / status.go:补齐 DTO 字段与状态常量
- 更新 memory/repo/item_repo.go / job_repo.go / audit_repo.go / settings_repo.go:补齐 CRUD 与查询能力
- 更新 memory/worker/runner.go:Runner 对接 Module 与 LLM 抽取器,任务状态机完整化
- 更新 memory/README.md:同步模块现状说明
4. newAgent 接入 Memory 读取注入与工具注册依赖预埋
- 新建 service/agentsvc/agent_memory.go:定义 MemoryReader 接口 + injectMemoryContext,在 graph
执行前统一补充记忆上下文
- 更新 service/agentsvc/agent.go:新增 memoryReader 字段与 SetMemoryReader 方法
- 更新 service/agentsvc/agent_newagent.go:调用 injectMemoryContext 注入 pinned block,检索失败仅降级不阻断主链路
- 更新 newAgent/tools/registry.go:新增 DefaultRegistryDeps(含 RAGRuntime),工具注册表支持依赖注入
5. 启动流程与事件处理器接线更新
- 更新 cmd/start.go:初始化 RAG Runtime → Memory Module → 注册事件处理器 → 启动 Worker 后台轮询
- 更新 service/events/memory_extract_requested.go:改用 memory.Module.WithTx(tx) 统一门面,事件处理器不再直接依赖
repo/service 内部包
6. 缓存插件与配置同步
- 更新 middleware/cache_deleter.go:静默忽略 MemoryJob / MemoryItem / MemoryAuditLog / MemoryUserSetting
等新模型,避免日志刷屏;清理冗余注释
- 更新 config.example.yaml:补齐 rag / memory / websearch 配置段及默认值
- 更新 go.mod / go.sum:新增 eino-ext/openai / json-patch / go-openai 依赖
前端:无 仓库:无
This commit is contained in:
87
backend/infra/llm/ark.go
Normal file
87
backend/infra/llm/ark.go
Normal file
@@ -0,0 +1,87 @@
|
||||
// 过渡期统一 Ark 调用封装。
|
||||
//
|
||||
// 这里保留 CallArkText / CallArkJSON,方便暂时还直接持有 *ark.ChatModel 的调用点
|
||||
// 逐步迁移到统一 Client。后续 memory 也可以直接复用这套中立层。
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
einoModel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
// ArkCallOptions 是基于 ark.ChatModel 的通用调用选项。
|
||||
//
|
||||
// 设计目的:
|
||||
// 1. 先把 Ark 调用样板抽成公共层;
|
||||
// 2. 再由 WrapArkClient 提供统一 Client;
|
||||
// 3. 让上层尽量只关注业务 prompt 和结构化结果。
|
||||
type ArkCallOptions struct {
|
||||
Temperature float64
|
||||
MaxTokens int
|
||||
Thinking ThinkingMode
|
||||
}
|
||||
|
||||
// CallArkText 调用 ark 模型并返回纯文本。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责拼 system + user 两段消息;
|
||||
// 2. 负责统一配置 thinking / temperature / maxTokens;
|
||||
// 3. 负责拦截空响应;
|
||||
// 4. 不负责 JSON 解析,不负责业务字段校验。
|
||||
func CallArkText(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, options ArkCallOptions) (string, error) {
|
||||
if chatModel == nil {
|
||||
return "", errors.New("ark model is nil")
|
||||
}
|
||||
|
||||
messages := []*schema.Message{
|
||||
schema.SystemMessage(systemPrompt),
|
||||
schema.UserMessage(userPrompt),
|
||||
}
|
||||
resp, err := chatModel.Generate(ctx, messages, buildArkOptions(options)...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp == nil {
|
||||
return "", errors.New("模型返回为空")
|
||||
}
|
||||
|
||||
text := strings.TrimSpace(resp.Content)
|
||||
if text == "" {
|
||||
return "", errors.New("模型返回内容为空")
|
||||
}
|
||||
return text, nil
|
||||
}
|
||||
|
||||
// CallArkJSON 调用 ark 模型并直接解析 JSON。
|
||||
func CallArkJSON[T any](ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, options ArkCallOptions) (*T, string, error) {
|
||||
raw, err := CallArkText(ctx, chatModel, systemPrompt, userPrompt, options)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
parsed, err := ParseJSONObject[T](raw)
|
||||
if err != nil {
|
||||
return nil, raw, err
|
||||
}
|
||||
return parsed, raw, nil
|
||||
}
|
||||
|
||||
func buildArkOptions(options ArkCallOptions) []einoModel.Option {
|
||||
thinkingType := arkModel.ThinkingTypeDisabled
|
||||
if options.Thinking == ThinkingModeEnabled {
|
||||
thinkingType = arkModel.ThinkingTypeEnabled
|
||||
}
|
||||
opts := []einoModel.Option{
|
||||
ark.WithThinking(&arkModel.Thinking{Type: thinkingType}),
|
||||
einoModel.WithTemperature(float32(options.Temperature)),
|
||||
}
|
||||
if options.MaxTokens > 0 {
|
||||
opts = append(opts, einoModel.WithMaxTokens(options.MaxTokens))
|
||||
}
|
||||
return opts
|
||||
}
|
||||
111
backend/infra/llm/ark_adapter.go
Normal file
111
backend/infra/llm/ark_adapter.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
einoModel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
// WrapArkClient 将 ark.ChatModel 适配为统一 Client。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. generateText:调用 ark.ChatModel.Generate(非流式),供 GenerateJSON 使用;
|
||||
// 2. streamText:调用 ark.ChatModel.Stream(流式),供需要流式输出的场景使用;
|
||||
// 3. 两者共用同一套 options 转换。
|
||||
func WrapArkClient(arkChatModel *ark.ChatModel) *Client {
|
||||
if arkChatModel == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 非流式文本生成,供 GenerateJSON / GenerateText 调用路径使用。
|
||||
generateFunc := func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error) {
|
||||
arkOpts := buildArkStreamOptions(options)
|
||||
msg, err := arkChatModel.Generate(ctx, messages, arkOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msg == nil {
|
||||
return nil, errors.New("ark model returned nil message")
|
||||
}
|
||||
return &TextResult{Text: msg.Content}, nil
|
||||
}
|
||||
|
||||
// 流式文本生成。
|
||||
streamFunc := func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error) {
|
||||
arkOpts := buildArkStreamOptions(options)
|
||||
reader, err := arkChatModel.Stream(ctx, messages, arkOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &arkStreamReaderAdapter{reader: reader}, nil
|
||||
}
|
||||
|
||||
return NewClient(generateFunc, streamFunc)
|
||||
}
|
||||
|
||||
// buildArkStreamOptions 将统一 GenerateOptions 转换为 ark 的流式调用选项。
|
||||
func buildArkStreamOptions(options GenerateOptions) []einoModel.Option {
|
||||
thinkingEnabled := options.Thinking == ThinkingModeEnabled
|
||||
|
||||
// Thinking
|
||||
thinkingType := arkModel.ThinkingTypeDisabled
|
||||
if thinkingEnabled {
|
||||
thinkingType = arkModel.ThinkingTypeEnabled
|
||||
}
|
||||
opts := []einoModel.Option{
|
||||
ark.WithThinking(&arkModel.Thinking{Type: thinkingType}),
|
||||
}
|
||||
|
||||
// Temperature:thinking 模型强制要求 temperature=1,否则 API 静默忽略 thinking。
|
||||
if thinkingEnabled {
|
||||
opts = append(opts, einoModel.WithTemperature(1.0))
|
||||
} else if options.Temperature > 0 {
|
||||
opts = append(opts, einoModel.WithTemperature(float32(options.Temperature)))
|
||||
}
|
||||
|
||||
// MaxTokens:thinking 模式下 thinking token 占用 max_tokens 预算,
|
||||
// 调用方设定的值仅代表"期望输出长度",实际预算需留出思考空间。
|
||||
// 最低保障 16000,避免思考链被截断导致输出为空或非 JSON。
|
||||
maxTokens := options.MaxTokens
|
||||
if thinkingEnabled {
|
||||
const minThinkingBudget = 16000
|
||||
if maxTokens < minThinkingBudget {
|
||||
maxTokens = minThinkingBudget
|
||||
}
|
||||
}
|
||||
if maxTokens > 0 {
|
||||
opts = append(opts, einoModel.WithMaxTokens(maxTokens))
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
// arkStreamReaderAdapter 适配 ark.ChatModel.Stream 返回的 reader。
|
||||
// ark.Stream 返回 schema.StreamReader[*schema.Message],其 Close() 方法无返回值
|
||||
// 而我们的 StreamReader 接口要求 Close() error
|
||||
type arkStreamReaderAdapter struct {
|
||||
reader *schema.StreamReader[*schema.Message]
|
||||
}
|
||||
|
||||
// Recv 转发到 ark reader 的 Recv 方法。
|
||||
func (r *arkStreamReaderAdapter) Recv() (*schema.Message, error) {
|
||||
if r == nil || r.reader == nil {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return r.reader.Recv()
|
||||
}
|
||||
|
||||
// Close 转发到 ark reader 的 Close 方法。
|
||||
// ark 的 Close() 无返回值,我们适配为返回 nil
|
||||
func (r *arkStreamReaderAdapter) Close() error {
|
||||
if r == nil || r.reader == nil {
|
||||
return nil
|
||||
}
|
||||
r.reader.Close()
|
||||
return nil
|
||||
}
|
||||
215
backend/infra/llm/client.go
Normal file
215
backend/infra/llm/client.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// ThinkingMode 描述本次模型调用对 thinking 的期望。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 这里只表达“调用方希望怎样配置推理模式”;
|
||||
// 2. 不直接绑定某个具体模型厂商的参数枚举;
|
||||
// 3. 真正如何把它翻译成 ark / OpenAI / 其他 provider 的 option,由后续适配层负责。
|
||||
type ThinkingMode string
|
||||
|
||||
const (
|
||||
ThinkingModeDefault ThinkingMode = "default"
|
||||
ThinkingModeEnabled ThinkingMode = "enabled"
|
||||
ThinkingModeDisabled ThinkingMode = "disabled"
|
||||
)
|
||||
|
||||
// GenerateOptions 是统一模型调用选项。
|
||||
//
|
||||
// 设计目的:
|
||||
// 1. 先把“每个 skill / worker 都会反复传的参数”收敛成一份结构;
|
||||
// 2. 让上层以后只表达“我要什么”,不再自己重复组织 option;
|
||||
// 3. 暂时不追求覆盖所有 provider 参数,先把最常用的几个公共位抽出来。
|
||||
type GenerateOptions struct {
|
||||
Temperature float64
|
||||
MaxTokens int
|
||||
Thinking ThinkingMode
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
// TextResult 是统一文本生成结果。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. Text 保存模型最终返回的纯文本;
|
||||
// 2. Usage 保存本次调用的 token 使用量,供后续统一统计;
|
||||
// 3. 不负责 JSON 解析,不负责业务字段映射。
|
||||
type TextResult struct {
|
||||
Text string
|
||||
Usage *schema.TokenUsage
|
||||
}
|
||||
|
||||
// StreamReader 抽象了“可逐块 Recv 的流式返回器”。
|
||||
//
|
||||
// 之所以不直接依赖某个具体 SDK 的 reader 类型,是因为现在还处在骨架收敛阶段,
|
||||
// 后续接 ark、OpenAI 兼容层还是别的 provider,都可以往这个最小接口上适配。
|
||||
type StreamReader interface {
|
||||
Recv() (*schema.Message, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// TextGenerateFunc 是文本生成的统一适配函数签名。
|
||||
type TextGenerateFunc func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error)
|
||||
|
||||
// StreamGenerateFunc 是流式生成的统一适配函数签名。
|
||||
type StreamGenerateFunc func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error)
|
||||
|
||||
// Client 是统一模型客户端门面。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责把调用方的“模型调用意图”收敛到统一入口;
|
||||
// 2. 负责统一参数校验、空响应防御、GenerateJSON 复用;
|
||||
// 3. 不负责写 prompt,不负责业务 fallback,也不直接持有具体厂商 SDK 细节。
|
||||
type Client struct {
|
||||
generateText TextGenerateFunc
|
||||
streamText StreamGenerateFunc
|
||||
}
|
||||
|
||||
// NewClient 创建统一模型客户端。
|
||||
func NewClient(generateText TextGenerateFunc, streamText StreamGenerateFunc) *Client {
|
||||
return &Client{
|
||||
generateText: generateText,
|
||||
streamText: streamText,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateText 执行一次统一文本生成。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责做最小必要的入参校验;
|
||||
// 2. 负责统一拦截“模型空响应”这类公共问题;
|
||||
// 3. 不负责业务 prompt 拼接,也不负责把文本再映射成业务结构。
|
||||
func (c *Client) GenerateText(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error) {
|
||||
if c == nil || c.generateText == nil {
|
||||
return nil, errors.New("llm client is not ready")
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return nil, errors.New("llm messages is empty")
|
||||
}
|
||||
|
||||
result, err := c.generateText(ctx, messages, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result == nil {
|
||||
return nil, errors.New("llm result is nil")
|
||||
}
|
||||
if strings.TrimSpace(result.Text) == "" {
|
||||
return nil, errors.New("llm returned empty text")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateJSON 先走统一文本生成,再走统一 JSON 解析。
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. 把“Generate -> 提取 JSON -> 反序列化”这段公共链路收敛起来;
|
||||
// 2. 上层只关心业务结构,不需要重复实现解析样板;
|
||||
// 3. 返回 parsed + rawResult,方便打点与回退时保留原文。
|
||||
func GenerateJSON[T any](ctx context.Context, client *Client, messages []*schema.Message, options GenerateOptions) (*T, *TextResult, error) {
|
||||
result, err := client.GenerateText(ctx, messages, options)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
parsed, err := ParseJSONObject[T](result.Text)
|
||||
if err != nil {
|
||||
return nil, result, err
|
||||
}
|
||||
return parsed, result, nil
|
||||
}
|
||||
|
||||
// Stream 打开统一流式调用入口。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 只负责把“流式生成能力”暴露给上层;
|
||||
// 2. 不负责 chunk 到 OpenAI 协议的转换,那部分应放在 stream/;
|
||||
// 3. 不负责累计全文,也不负责 token 统计落库。
|
||||
func (c *Client) Stream(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error) {
|
||||
if c == nil || c.streamText == nil {
|
||||
return nil, errors.New("llm stream client is not ready")
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return nil, errors.New("llm messages is empty")
|
||||
}
|
||||
return c.streamText(ctx, messages, options)
|
||||
}
|
||||
|
||||
// BuildSystemUserMessages 构造最常见的“system + history + user”消息列表。
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. 先把最稳定的消息编排方式沉淀下来,减少各业务域样板代码;
|
||||
// 2. 只做消息切片装配,不做 prompt 生成;
|
||||
// 3. 供 agent / memory 等多个能力域复用。
|
||||
func BuildSystemUserMessages(systemPrompt string, history []*schema.Message, userPrompt string) []*schema.Message {
|
||||
messages := make([]*schema.Message, 0, len(history)+2)
|
||||
if strings.TrimSpace(systemPrompt) != "" {
|
||||
messages = append(messages, schema.SystemMessage(systemPrompt))
|
||||
}
|
||||
if len(history) > 0 {
|
||||
messages = append(messages, history...)
|
||||
}
|
||||
if strings.TrimSpace(userPrompt) != "" {
|
||||
messages = append(messages, schema.UserMessage(userPrompt))
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
// CloneUsage 深拷贝 token usage,避免后续多处累加时共享同一指针。
|
||||
func CloneUsage(usage *schema.TokenUsage) *schema.TokenUsage {
|
||||
if usage == nil {
|
||||
return nil
|
||||
}
|
||||
copied := *usage
|
||||
return &copied
|
||||
}
|
||||
|
||||
// MergeUsage 合并两段 usage。
|
||||
//
|
||||
// 合并策略:
|
||||
// 1. 对“同一次调用不同流分片”的场景,取更大值作为最终值;
|
||||
// 2. 对“多次独立调用累计”的场景,应由上层显式做加法,而不是用这个函数;
|
||||
// 3. 该函数只适用于“同一次调用的分块 usage 收敛”。
|
||||
func MergeUsage(base *schema.TokenUsage, incoming *schema.TokenUsage) *schema.TokenUsage {
|
||||
if incoming == nil {
|
||||
return CloneUsage(base)
|
||||
}
|
||||
if base == nil {
|
||||
return CloneUsage(incoming)
|
||||
}
|
||||
|
||||
merged := *base
|
||||
if incoming.PromptTokens > merged.PromptTokens {
|
||||
merged.PromptTokens = incoming.PromptTokens
|
||||
}
|
||||
if incoming.CompletionTokens > merged.CompletionTokens {
|
||||
merged.CompletionTokens = incoming.CompletionTokens
|
||||
}
|
||||
if incoming.TotalTokens > merged.TotalTokens {
|
||||
merged.TotalTokens = incoming.TotalTokens
|
||||
}
|
||||
if incoming.PromptTokenDetails.CachedTokens > merged.PromptTokenDetails.CachedTokens {
|
||||
merged.PromptTokenDetails.CachedTokens = incoming.PromptTokenDetails.CachedTokens
|
||||
}
|
||||
if incoming.CompletionTokensDetails.ReasoningTokens > merged.CompletionTokensDetails.ReasoningTokens {
|
||||
merged.CompletionTokensDetails.ReasoningTokens = incoming.CompletionTokensDetails.ReasoningTokens
|
||||
}
|
||||
return &merged
|
||||
}
|
||||
|
||||
// FormatEmptyResponseError builds the standard "model returned empty result"
// error for the given scene, defaulting a blank scene label to "unknown".
func FormatEmptyResponseError(scene string) error {
	label := strings.TrimSpace(scene)
	if label == "" {
		label = "unknown"
	}
	return fmt.Errorf("模型在 %s 场景返回空结果", label)
}
|
||||
112
backend/infra/llm/json.go
Normal file
112
backend/infra/llm/json.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package llm
|
||||
|
||||
import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"unicode/utf8"
)
|
||||
|
||||
// ParseJSONObject 解析模型返回中的 JSON 对象。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责处理“模型输出前后夹杂解释文字 / markdown 代码块”的常见情况;
|
||||
// 2. 负责提取最外层 JSON object 并反序列化为目标结构;
|
||||
// 3. 不负责业务字段合法性校验,应由上层调用方自行校验。
|
||||
func ParseJSONObject[T any](raw string) (*T, error) {
|
||||
clean := strings.TrimSpace(raw)
|
||||
if clean == "" {
|
||||
return nil, errors.New("模型返回为空,无法解析 JSON")
|
||||
}
|
||||
|
||||
objectText := ExtractJSONObject(clean)
|
||||
if objectText == "" {
|
||||
return nil, fmt.Errorf("模型返回中未找到 JSON 对象: %s", truncateForError(clean))
|
||||
}
|
||||
|
||||
var out T
|
||||
if err := json.Unmarshal([]byte(objectText), &out); err != nil {
|
||||
return nil, fmt.Errorf("JSON 解析失败: %w", err)
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
// ExtractJSONObject 从混合文本里提取第一个完整 JSON 对象。
|
||||
//
|
||||
// 设计说明:
|
||||
// 1. LLM 很容易输出“这里是结果:{...}”这种半结构化文本;
|
||||
// 2. 这里用括号计数而不是正则,避免嵌套对象一多就误截断;
|
||||
// 3. 目前只提取 object,不提取 array,因为当前契约基本都是对象。
|
||||
func ExtractJSONObject(text string) string {
|
||||
clean := trimMarkdownCodeFence(strings.TrimSpace(text))
|
||||
if clean == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
start := strings.Index(clean, "{")
|
||||
if start < 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
depth := 0
|
||||
inString := false
|
||||
escaped := false
|
||||
for idx := start; idx < len(clean); idx++ {
|
||||
ch := clean[idx]
|
||||
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' && inString {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
inString = !inString
|
||||
continue
|
||||
}
|
||||
if inString {
|
||||
continue
|
||||
}
|
||||
|
||||
switch ch {
|
||||
case '{':
|
||||
depth++
|
||||
case '}':
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return clean[start : idx+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// trimMarkdownCodeFence strips a leading ```lang fence line and, when
// present, a trailing ``` line. The fenced body is left untouched so the
// JSON's own line structure survives.
func trimMarkdownCodeFence(text string) string {
	trimmed := strings.TrimSpace(text)
	if !strings.HasPrefix(trimmed, "```") {
		return trimmed
	}

	lines := strings.Split(trimmed, "\n")
	if len(lines) == 0 {
		return trimmed
	}

	// Drop the opening ```json / ``` line, then the closing ``` line if any.
	body := lines[1:]
	if n := len(body); n > 0 && strings.TrimSpace(body[n-1]) == "```" {
		body = body[:n-1]
	}
	return strings.TrimSpace(strings.Join(body, "\n"))
}
|
||||
|
||||
func truncateForError(text string) string {
|
||||
if len(text) <= 160 {
|
||||
return text
|
||||
}
|
||||
return text[:160] + "..."
|
||||
}
|
||||
640
backend/infra/rag/HANDOFF_RAGInfra一步到位接入方案.md
Normal file
640
backend/infra/rag/HANDOFF_RAGInfra一步到位接入方案.md
Normal file
@@ -0,0 +1,640 @@
|
||||
# HANDOFF:RAG Infra 一步到位接入方案
|
||||
|
||||
## 1. 文档目的
|
||||
|
||||
本文用于把 `backend/infra/rag` 从“可运行骨架”推进到“可被业务正式接入的共享基础设施”。
|
||||
|
||||
本文重点回答 4 个问题:
|
||||
|
||||
1. 当前 `RAG Infra` 已经做到了什么,还缺什么。
|
||||
2. 什么样的状态,才算“合格、可接入、可灰度、可回滚”的 `RAG Infra`。
|
||||
3. 如何以“依赖注入 + 对外只暴露方法入口”的方式收口,避免业务侧直接依赖底层实现细节。
|
||||
4. 如何在不打断现有业务的前提下,把 `memory` 与 `websearch` 并行迁移到统一 `RAG Infra`。
|
||||
|
||||
---
|
||||
|
||||
## 2. 当前现状
|
||||
|
||||
## 2.1 已完成部分
|
||||
|
||||
当前 `backend/infra/rag` 已经具备共享骨架,主要包括:
|
||||
|
||||
1. 通用接口与类型:
|
||||
- `core/interfaces.go`
|
||||
- `core/types.go`
|
||||
- `core/errors.go`
|
||||
2. 通用编排器:
|
||||
- `core/pipeline.go`
|
||||
3. 默认切块器:
|
||||
- `chunk/text_chunker.go`
|
||||
4. 语料适配器:
|
||||
- `corpus/memory_corpus.go`
|
||||
- `corpus/web_corpus.go`
|
||||
5. 默认可运行实现:
|
||||
- `embed/mock_embedder.go`
|
||||
- `rerank/noop_reranker.go`
|
||||
- `store/inmemory_store.go`
|
||||
6. 配置骨架:
|
||||
- `config/config.go`
|
||||
|
||||
这说明项目已经完成了“共享 RAG Core 的第一阶段搭骨架”,不再是单纯的设计想法。
|
||||
|
||||
## 2.2 当前存在的问题
|
||||
|
||||
虽然骨架已经有了,但距离“可正式接入的 Infra”还差关键几步:
|
||||
|
||||
1. 运行时没有正式装配入口。
|
||||
- 当前仍主要依赖 `rag.NewDefaultPipeline()`。
|
||||
- 启动阶段没有统一按配置组装 `embedder / store / reranker / corpus runtime`。
|
||||
2. 真实底层实现还是占位。
|
||||
- `embed/eino_embedder.go` 未实现。
|
||||
- `rerank/eino_reranker.go` 未实现。
|
||||
- `store/milvus_store.go` 未实现。
|
||||
3. 配置虽有结构,但还未真正接入运行链路。
|
||||
- `rag/config/config.go` 定义了 `rag.*` 配置。
|
||||
- `backend/cmd/start.go` 尚未实例化并注入 `RAG Runtime`。
|
||||
4. 业务尚未真正切流。
|
||||
- `memory` 读取链路还没有正式走 `Pipeline.Retrieve`。
|
||||
- `websearch` 还没有通过 `WebCorpus + Pipeline` 形成正式 WebRAG 路径。
|
||||
5. 工程化能力不完整。
|
||||
- 缺统一 timeout。
|
||||
- 缺统一日志字段。
|
||||
- 缺基础指标。
|
||||
- 缺单元测试与集成测试。
|
||||
6. 还存在潜在重复实现风险。
|
||||
- `retrieve/vector_retriever.go` 与 `core/pipeline.go` 都承载部分检索逻辑。
|
||||
- 若后续两套逻辑并存,容易出现行为漂移与维护成本上升。
|
||||
|
||||
## 2.3 当前状态结论
|
||||
|
||||
当前 `RAG Infra` 的状态,更准确地说是:
|
||||
|
||||
1. 已经完成“共享骨架搭建”。
|
||||
2. 还没有完成“统一装配、真实实现、正式接入、工程化收口”。
|
||||
3. 目前适合继续扩展,但还不适合直接作为长期稳定的业务依赖面。
|
||||
|
||||
---
|
||||
|
||||
## 3. 目标定义:什么叫“合格的 RAG Infra”
|
||||
|
||||
本轮改造完成后,`backend/infra/rag` 应满足以下标准:
|
||||
|
||||
1. 启动时可统一构造并注入,不再靠业务模块自行拼装底层依赖。
|
||||
2. 对外只暴露稳定方法入口,不暴露底层 `Pipeline / Store / Embedder / Reranker` 的装配细节。
|
||||
3. 支持按配置切换实现:
|
||||
- `inmemory / milvus`
|
||||
- `mock / eino`
|
||||
- `noop / eino`
|
||||
4. 支持 `memory` 与 `websearch` 两类语料复用同一套 `chunk / embed / retrieve / rerank / fallback` 流程。
|
||||
5. 支持灰度开关与回滚,不要求业务“一次性硬切流”。
|
||||
6. 支持基础观测:
|
||||
- 延迟
|
||||
- 命中数
|
||||
- fallback 原因
|
||||
- 错误码
|
||||
7. 具备最小可依赖测试集,保证公共层改动不会悄悄破坏业务。
|
||||
|
||||
---
|
||||
|
||||
## 4. 核心改造原则
|
||||
|
||||
## 4.1 原则一:依赖注入统一由 Infra 自己负责
|
||||
|
||||
`RAG Infra` 必须自己承接“底层实现装配”,业务侧不应感知:
|
||||
|
||||
1. 当前用的是 `Milvus` 还是 `InMemoryStore`。
|
||||
2. 当前用的是 `MockEmbedder` 还是 `EinoEmbedder`。
|
||||
3. 当前是否开启 `Reranker`。
|
||||
4. 当前超时、阈值、切块参数是多少。
|
||||
|
||||
业务只拿到一个已经注入好的 `RAG Runtime` 或 `RAG Service`,直接调用方法。
|
||||
|
||||
## 4.2 原则二:对外只暴露方法,不暴露底层零件
|
||||
|
||||
业务层不应直接依赖这些细粒度对象:
|
||||
|
||||
1. `core.Pipeline`
|
||||
2. `core.VectorStore`
|
||||
3. `core.Embedder`
|
||||
4. `core.Reranker`
|
||||
5. `corpus.MemoryCorpus`
|
||||
6. `corpus.WebCorpus`
|
||||
|
||||
这些对象应被视为 `infra/rag` 内部拼装细节。
|
||||
|
||||
业务层只应调用诸如以下方法:
|
||||
|
||||
1. `IngestMemory`
|
||||
2. `RetrieveMemory`
|
||||
3. `IngestWeb`
|
||||
4. `RetrieveWeb`
|
||||
|
||||
这样做的好处是:
|
||||
|
||||
1. 业务依赖面更稳定。
|
||||
2. 后续替换底层实现时,不会把改动扩散到多个业务模块。
|
||||
3. 便于统一日志、监控、降级和权限边界。
|
||||
|
||||
## 4.3 原则三:业务语义留在业务层,通用 RAG 工序下沉到 Infra
|
||||
|
||||
下沉到 `infra/rag` 的内容:
|
||||
|
||||
1. 切块
|
||||
2. 向量化
|
||||
3. 向量存储
|
||||
4. 召回
|
||||
5. rerank
|
||||
6. threshold 过滤
|
||||
7. fallback 语义
|
||||
8. 统一日志与指标
|
||||
|
||||
留在业务层的内容:
|
||||
|
||||
1. `memory` 的注入优先级、门控规则、显式/隐式策略
|
||||
2. `websearch` 的 provider 搜索、query 改写、时间过滤、domain 白名单、抓取策略
|
||||
3. 最终给模型注入哪些证据、注入多少、如何组织引用
|
||||
|
||||
## 4.4 原则四:并行迁移,不一步删旧
|
||||
|
||||
本轮改造虽然目标是“一步到位把 Infra 做完整”,但切流必须保持并行迁移:
|
||||
|
||||
1. 新 Infra 建好后,先让 `memory` 接入并保留旧逻辑兜底。
|
||||
2. 再让 `websearch` 接入并保留 V1 路径兜底。
|
||||
3. 观察稳定后再删除旧分支。
|
||||
|
||||
---
|
||||
|
||||
## 5. 目标架构
|
||||
|
||||
## 5.1 推荐对外结构
|
||||
|
||||
建议在 `backend/infra/rag` 新增统一对外门面,例如:
|
||||
|
||||
1. `runtime.go`
|
||||
2. `factory.go`
|
||||
3. `service.go`
|
||||
|
||||
推荐把正式对外依赖面收敛为一个接口,例如:
|
||||
|
||||
```go
|
||||
type Runtime interface {
|
||||
IngestMemory(ctx context.Context, input MemoryIngestRequest) (*IngestResult, error)
|
||||
RetrieveMemory(ctx context.Context, input MemoryRetrieveRequest) (*RetrieveResult, error)
|
||||
|
||||
IngestWeb(ctx context.Context, input WebIngestRequest) (*IngestResult, error)
|
||||
RetrieveWeb(ctx context.Context, input WebRetrieveRequest) (*RetrieveResult, error)
|
||||
}
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
1. 业务侧只依赖 `Runtime`。
|
||||
2. `Runtime` 内部再去调用 `Pipeline + CorpusAdapter + Store + Embedder + Reranker`。
|
||||
3. 这样可以保证业务不会直接 import `core` 包下的底层细节。
|
||||
|
||||
## 5.2 推荐内部结构
|
||||
|
||||
建议内部形成以下分工:
|
||||
|
||||
1. `factory.go`
|
||||
- 负责按配置创建 `Embedder / Store / Reranker / Pipeline`
|
||||
2. `runtime.go`
|
||||
- 负责持有 `Pipeline + MemoryCorpus + WebCorpus + Logger + Metrics`
|
||||
3. `service.go`
|
||||
- 负责定义 `Runtime` 接口与对外方法
|
||||
4. `core/`
|
||||
- 保持底层通用编排逻辑
|
||||
5. `corpus/`
|
||||
- 只负责“语料 -> 标准文档”和“业务过滤 -> 标准 filter”
|
||||
|
||||
## 5.3 推荐依赖注入方式
|
||||
|
||||
在 `backend/cmd/start.go` 中,启动期统一创建 `RAG Runtime`,例如:
|
||||
|
||||
1. 读取 `rag.*` 配置
|
||||
2. 构造 `RAGFactory`
|
||||
3. 生成 `RAGRuntime`
|
||||
4. 注入给:
|
||||
- `memory service`
|
||||
- `newAgent web tools`
|
||||
|
||||
业务侧只拿运行好的对象,不再自己 new 任何底层实现。
|
||||
|
||||
---
|
||||
|
||||
## 6. 对外方法面设计
|
||||
|
||||
## 6.1 Memory 对外方法
|
||||
|
||||
推荐对外暴露以下方法:
|
||||
|
||||
1. `IngestMemory`
|
||||
- 输入:标准化后的记忆入库请求
|
||||
- 输出:文档数、chunk 数、同步结果
|
||||
2. `RetrieveMemory`
|
||||
- 输入:用户、会话、助手、run、query、topK、threshold
|
||||
- 输出:标准 `RetrieveResult`
|
||||
|
||||
注意:
|
||||
|
||||
1. `memory` 业务层不应直接调用 `MemoryCorpus`。
|
||||
2. `memory` 业务层不应自己拼向量过滤条件。
|
||||
3. 所有过滤条件由 `RetrieveMemory` 内部统一转换。
|
||||
|
||||
## 6.2 Web 对外方法
|
||||
|
||||
推荐对外暴露以下方法:
|
||||
|
||||
1. `IngestWeb`
|
||||
- 输入:抓取结果 `url/title/snippet/content/domain/query_id/session_id`
|
||||
- 输出:统一入库摘要
|
||||
2. `RetrieveWeb`
|
||||
- 输入:query、query_id/session_id、domain、topK、threshold
|
||||
- 输出:标准 `RetrieveResult`
|
||||
|
||||
注意:
|
||||
|
||||
1. `websearch` 业务层不应直接持有 `WebCorpus`。
|
||||
2. `websearch` 业务层只负责“拿到页面内容”与“决定是否需要调用 RAG”。
|
||||
3. 实际向量入库、检索、rerank 由 `infra/rag` 统一处理。
|
||||
|
||||
## 6.3 对外方法设计边界
|
||||
|
||||
方法层负责什么:
|
||||
|
||||
1. 参数合法性校验
|
||||
2. 内部 filter 组装
|
||||
3. 调 `Pipeline.Ingest / Retrieve`
|
||||
4. 统一日志、指标、fallback
|
||||
|
||||
方法层不负责什么:
|
||||
|
||||
1. 不负责 `websearch provider` 搜索
|
||||
2. 不负责 HTML 抓取
|
||||
3. 不负责 prompt 注入
|
||||
4. 不负责业务排序偏好
|
||||
|
||||
---
|
||||
|
||||
## 7. 具体改造计划
|
||||
|
||||
## 7.1 第一部分:把 RAG Infra 自身做完整
|
||||
|
||||
### 目标
|
||||
|
||||
让 `backend/infra/rag` 成为“正式可注入、正式可切换、正式可依赖”的共享基础设施。
|
||||
|
||||
### 实施项
|
||||
|
||||
1. 新增正式运行时与工厂:
|
||||
- `backend/infra/rag/runtime.go`
|
||||
- `backend/infra/rag/factory.go`
|
||||
- 如有需要,新增 `backend/infra/rag/service.go`
|
||||
2. 扩展配置:
|
||||
- `rag.enabled`
|
||||
- `rag.store`
|
||||
- `rag.embed.provider`
|
||||
- `rag.embed.model`
|
||||
- `rag.embed.timeoutMs`
|
||||
- `rag.embed.dimension`
|
||||
- `rag.reranker.provider`
|
||||
- `rag.reranker.timeoutMs`
|
||||
- `rag.retrieve.timeoutMs`
|
||||
- `rag.ingest.chunkSize`
|
||||
- `rag.ingest.chunkOverlap`
|
||||
3. 收口运行入口:
|
||||
- `rag.NewDefaultPipeline()` 保留为本地 fallback
|
||||
- 正式业务接入走 `NewRuntimeFromConfig(...)`
|
||||
4. 消除重复检索路径:
|
||||
- 明确 `Pipeline` 是官方检索入口
|
||||
- `retrieve/vector_retriever.go` 要么内聚为内部实现,要么后续删除,避免双轨
|
||||
|
||||
### 验收
|
||||
|
||||
1. 启动期可按配置成功构造 `RAG Runtime`。
|
||||
2. 业务侧不需要自己组装 `Pipeline / Store / Embedder / Reranker`。
|
||||
3. 对外暴露面稳定,底层实现可替换。
|
||||
|
||||
## 7.2 第二部分:补齐真实底层实现
|
||||
|
||||
### 目标
|
||||
|
||||
让 `RAG Infra` 具备真实可用的向量能力,而不是停留在 mock。
|
||||
|
||||
### 实施项
|
||||
|
||||
1. 实现 `embed/eino_embedder.go`
|
||||
- 负责 embedding 调用
|
||||
- 负责 embedding timeout
|
||||
- 负责错误包装与统一日志
|
||||
2. 实现 `rerank/eino_reranker.go`
|
||||
- 负责 rerank 调用
|
||||
- 负责 rerank timeout
|
||||
- 负责失败降级到原排序
|
||||
3. 实现 `store/milvus_store.go`
|
||||
- `Upsert`
|
||||
- `Search`
|
||||
- `Delete`
|
||||
- `Get`
|
||||
4. Milvus 元数据设计建议:
|
||||
- 高频过滤字段应做显式标量字段,不建议全部依赖大 JSON 过滤
|
||||
- 重点字段包括:
|
||||
- `corpus`
|
||||
- `user_id`
|
||||
- `assistant_id`
|
||||
- `conversation_id`
|
||||
- `run_id`
|
||||
- `memory_type`
|
||||
- `query_id`
|
||||
- `session_id`
|
||||
- `domain`
|
||||
|
||||
### 验收
|
||||
|
||||
1. `MilvusStore` 在已准备好的 Docker 环境中可稳定完成写入与检索。
|
||||
2. `EinoEmbedder` 和 `EinoReranker` 可按配置启用。
|
||||
3. provider 波动时,主链路仍能 fallback。
|
||||
|
||||
## 7.3 第三部分:补齐工程化能力
|
||||
|
||||
### 目标
|
||||
|
||||
让 `RAG Infra` 具备“可观测、可测试、可回滚”的基础设施属性。
|
||||
|
||||
### 实施项
|
||||
|
||||
1. timeout 接线:
|
||||
- embedding timeout
|
||||
- retrieve timeout
|
||||
- rerank timeout
|
||||
2. 统一日志字段:
|
||||
- `trace_id`
|
||||
- `corpus`
|
||||
- `action`
|
||||
- `provider`
|
||||
- `latency_ms`
|
||||
- `hit_count`
|
||||
- `fallback_reason`
|
||||
3. 指标补齐:
|
||||
- `rag_ingest_count`
|
||||
- `rag_retrieve_count`
|
||||
- `rag_hit_count`
|
||||
- `rag_fallback_rate`
|
||||
- `rag_latency_ms`
|
||||
4. 测试补齐:
|
||||
- `chunker` 单测
|
||||
- `corpus filter` 单测
|
||||
- `pipeline fallback` 单测
|
||||
- `MilvusStore` 集成测试
|
||||
- `memory/web` 过滤隔离测试
|
||||
|
||||
### 验收
|
||||
|
||||
1. 出现检索问题时,可从日志定位是:
|
||||
- 没命中
|
||||
- 超时
|
||||
- rerank 降级
|
||||
- filter 过滤过严
|
||||
2. 公共层测试可稳定覆盖关键路径。
|
||||
|
||||
## 7.4 第四部分:接入 Memory
|
||||
|
||||
### 目标
|
||||
|
||||
让 `memory` 成为第一个正式接入 `RAG Infra` 的业务域。
|
||||
|
||||
### 实施项
|
||||
|
||||
1. 写入链路接入:
|
||||
- 在 memory worker 成功写入 `memory_items` 后,调用 `RAGRuntime.IngestMemory`
|
||||
- 复用 `memory_items.vector_status/vector_id`
|
||||
2. 读取链路接入:
|
||||
- 在 `memory/service/read_service.go` 中新增 `RetrieveMemory` 路径
|
||||
- 强制过滤:
|
||||
- `user_id`
|
||||
- `assistant_id`
|
||||
- `conversation_id`
|
||||
- `run_id`
|
||||
3. 开关控制:
|
||||
- `memory.rag.enabled=false` 默认关闭
|
||||
- 打开后先灰度使用新路径
|
||||
4. 降级策略:
|
||||
- `RAG` 检索失败 -> 回退旧读取链路
|
||||
- `Reranker` 失败 -> 保留原始排序
|
||||
|
||||
### 验收
|
||||
|
||||
1. 开关关闭时行为与当前一致。
|
||||
2. 开关开启时,记忆召回可稳定工作。
|
||||
3. 失败时不会影响主链路回复。
|
||||
|
||||
## 7.5 第五部分:接入 WebSearch
|
||||
|
||||
### 目标
|
||||
|
||||
让 `websearch` 成为第二个正式接入 `RAG Infra` 的业务域,并复用 `WebCorpus`。
|
||||
|
||||
### 实施项
|
||||
|
||||
1. 保留 V1 路径:
|
||||
- `web_search` 做 provider 搜索
|
||||
- `web_fetch` 做正文抓取与清洗
|
||||
2. 新增 V2 路径:
|
||||
- 把抓取结果映射为 `WebIngestItem`
|
||||
- 调 `RAGRuntime.IngestWeb`
|
||||
- 再调 `RAGRuntime.RetrieveWeb`
|
||||
3. 强约束过滤:
|
||||
- `query_id` 或 `session_id` 至少有一个
|
||||
- 避免跨 query/session 串召回
|
||||
4. 开关控制:
|
||||
- `websearch.rag.enabled=false` 默认关闭
|
||||
5. 降级策略:
|
||||
- `web_rag_search` 失败 -> 回退到 `web_search + web_fetch`
|
||||
|
||||
### 验收
|
||||
|
||||
1. 新旧链路并存,互不影响。
|
||||
2. 新链路不会跨 query/session 串数据。
|
||||
3. 失败可立刻回退到 V1。
|
||||
|
||||
## 7.6 第六部分:启动接线与统一管理
|
||||
|
||||
### 目标
|
||||
|
||||
让 `RAG Runtime` 成为启动期统一装配、统一管理的依赖。
|
||||
|
||||
### 实施项
|
||||
|
||||
1. 在 `backend/cmd/start.go` 中:
|
||||
- 读取 `rag.*` 配置
|
||||
- 构造 `RAG Runtime`
|
||||
- 注入给 `memory` 与 `newAgent web tools`
|
||||
2. 统一由启动期管理依赖生命周期:
|
||||
- 初始化
|
||||
- 健康检查
|
||||
- 关闭清理
|
||||
3. 业务层禁止直接 new 底层实现:
|
||||
- 禁止业务自己构建 `MilvusStore`
|
||||
- 禁止业务自己构建 `EinoEmbedder`
|
||||
- 禁止业务自己拼 `Pipeline`
|
||||
|
||||
### 验收
|
||||
|
||||
1. 依赖管理集中在启动层。
|
||||
2. 业务代码只依赖方法入口,不接触底层实现。
|
||||
3. 后续替换实现时,无需大面积修改业务层代码。
|
||||
|
||||
---
|
||||
|
||||
## 8. 推荐目录改造方案
|
||||
|
||||
建议新增或调整如下文件:
|
||||
|
||||
1. `backend/infra/rag/runtime.go`
|
||||
2. `backend/infra/rag/factory.go`
|
||||
3. `backend/infra/rag/service.go`
|
||||
4. `backend/infra/rag/README.md` 或在本文件持续追加
|
||||
5. `backend/infra/rag/embed/eino_embedder.go`
|
||||
6. `backend/infra/rag/rerank/eino_reranker.go`
|
||||
7. `backend/infra/rag/store/milvus_store.go`
|
||||
8. `backend/infra/rag/core/pipeline_test.go`
|
||||
9. `backend/infra/rag/chunk/text_chunker_test.go`
|
||||
10. `backend/infra/rag/corpus/memory_corpus_test.go`
|
||||
11. `backend/infra/rag/corpus/web_corpus_test.go`
|
||||
12. `backend/infra/rag/store/milvus_store_integration_test.go`
|
||||
|
||||
配套改动文件:
|
||||
|
||||
1. `backend/cmd/start.go`
|
||||
2. `backend/config.example.yaml`
|
||||
3. `backend/memory/service/read_service.go`
|
||||
4. `backend/newAgent/tools/registry.go`
|
||||
5. `backend/agent/通用能力接入文档.md`
|
||||
|
||||
---
|
||||
|
||||
## 9. 配置建议
|
||||
|
||||
建议新增如下配置结构:
|
||||
|
||||
```yaml
|
||||
rag:
|
||||
enabled: true
|
||||
store: "milvus"
|
||||
topK: 8
|
||||
threshold: 0.55
|
||||
retrieve:
|
||||
timeoutMs: 1500
|
||||
ingest:
|
||||
chunkSize: 400
|
||||
chunkOverlap: 80
|
||||
embed:
|
||||
provider: "eino"
|
||||
model: ""
|
||||
timeoutMs: 1200
|
||||
dimension: 1024
|
||||
reranker:
|
||||
enabled: true
|
||||
provider: "eino"
|
||||
timeoutMs: 1200
|
||||
|
||||
memory:
|
||||
rag:
|
||||
enabled: false
|
||||
|
||||
websearch:
|
||||
rag:
|
||||
enabled: false
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
1. `rag.enabled` 控制公共层是否启用。
|
||||
2. `memory.rag.enabled` 与 `websearch.rag.enabled` 控制业务级切流。
|
||||
3. 即使 `rag.enabled=true`,也不代表所有业务立刻默认走新链路。
|
||||
|
||||
---
|
||||
|
||||
## 10. 回滚策略
|
||||
|
||||
推荐回滚顺序如下:
|
||||
|
||||
1. 先关业务级开关:
|
||||
- `memory.rag.enabled=false`
|
||||
- `websearch.rag.enabled=false`
|
||||
2. 再关重排:
|
||||
- `rag.reranker.enabled=false`
|
||||
3. 再切底层实现:
|
||||
- `rag.store=inmemory`
|
||||
- `rag.embed.provider=mock`
|
||||
- `rag.reranker.provider=noop`
|
||||
4. 若仍异常,再回退到业务旧链路
|
||||
|
||||
这样可以做到:
|
||||
|
||||
1. 不因单个 provider 波动打断主流程。
|
||||
2. 保留最小可用能力。
|
||||
3. 故障定位粒度更细。
|
||||
|
||||
---
|
||||
|
||||
## 11. 风险与应对
|
||||
|
||||
1. 风险:Milvus 过滤能力与现有 metadata 结构不匹配。
|
||||
- 应对:高频过滤字段单独建模,不依赖大 JSON 粗暴过滤。
|
||||
2. 风险:embedding/rerank provider 波动影响延迟。
|
||||
- 应对:超时控制 + fallback + 业务级开关。
|
||||
3. 风险:业务层绕过 Infra 直接依赖底层实现。
|
||||
- 应对:通过 `Runtime` 方法面统一收口,代码评审禁止横向绕过。
|
||||
4. 风险:新旧检索路径长期并存导致维护成本上升。
|
||||
- 应对:本轮先保留兜底,稳定后明确删除旧实现。
|
||||
5. 风险:跨 query/session 串召回。
|
||||
- 应对:`WebRetrieve` 强制校验 `query_id/session_id` 至少其一存在。
|
||||
|
||||
---
|
||||
|
||||
## 12. 最小落地顺序
|
||||
|
||||
如果按“尽快落成可接入 Infra”的优先级来排,本轮建议顺序如下:
|
||||
|
||||
1. 先做 `runtime/factory/service`,把依赖注入和方法面收口。
|
||||
2. 再实现 `MilvusStore + EinoEmbedder + EinoReranker`。
|
||||
3. 再补 timeout、日志、指标、测试。
|
||||
4. 然后优先接 `memory`。
|
||||
5. 最后接 `websearch`。
|
||||
|
||||
原因:
|
||||
|
||||
1. 若先接业务、不先收口方法面,后面会把底层细节泄露到业务层。
|
||||
2. 若先接 websearch、不先接 memory,会导致共享 Infra 价值不够集中,面试叙事也不完整。
|
||||
|
||||
---
|
||||
|
||||
## 13. 本轮完成后的预期收益
|
||||
|
||||
完成本方案后,项目会获得以下收益:
|
||||
|
||||
1. `memory` 与 `websearch` 共享一套真正可运行的 RAG 基础设施。
|
||||
2. 业务侧不再重复实现切块、召回、重排与降级逻辑。
|
||||
3. `infra/rag` 成为正式公共能力,具备统一依赖注入与统一管理能力。
|
||||
4. 后续新增新语料域时,只需新增 `CorpusAdapter + 方法面`,无需再复制一套 RAG 链路。
|
||||
5. 项目简历叙事会更完整:
|
||||
- “抽象并实现共享 RAG Infra”
|
||||
- “统一 Memory/WebSearch 的检索与重排能力”
|
||||
- “通过依赖注入与门面方法收口底层复杂度”
|
||||
|
||||
---
|
||||
|
||||
## 14. 当前建议结论
|
||||
|
||||
建议把本轮目标明确为:
|
||||
|
||||
1. **不是**“再给 RAG 补几个占位实现”。
|
||||
2. **而是**“把 `backend/infra/rag` 一次性做成正式可接入的公共基础设施”。
|
||||
|
||||
关键落点是两句话:
|
||||
|
||||
1. 依赖注入统一由 `infra/rag` 自己负责。
|
||||
2. 对外只暴露方法入口,业务侧不直接接触底层实现细节。
|
||||
|
||||
只要这两点收住,后续 `memory`、`websearch`、甚至更多语料域都会明显更好管理。
|
||||
@@ -5,30 +5,63 @@ import "github.com/spf13/viper"
|
||||
// Config 是 RAG Core 运行配置。
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
Store string
|
||||
TopK int
|
||||
|
||||
Threshold float64
|
||||
|
||||
EmbedProvider string
|
||||
EmbedModel string
|
||||
EmbedBaseURL string
|
||||
EmbedAPIKeyEnv string
|
||||
EmbedTimeoutMS int
|
||||
EmbedDimension int
|
||||
|
||||
RerankerEnabled bool
|
||||
RerankerProvider string
|
||||
RerankerTimeoutMS int
|
||||
|
||||
ChunkSize int
|
||||
ChunkOverlap int
|
||||
|
||||
RetrieveTimeoutMS int
|
||||
|
||||
MilvusAddress string
|
||||
MilvusToken string
|
||||
MilvusDBName string
|
||||
MilvusCollectionName string
|
||||
MilvusMetricType string
|
||||
MilvusRequestTimeoutMS int
|
||||
}
|
||||
|
||||
// LoadFromViper 读取 rag 配置并补默认值。
|
||||
func LoadFromViper() Config {
|
||||
cfg := Config{
|
||||
Enabled: viper.GetBool("rag.enabled"),
|
||||
TopK: viper.GetInt("rag.topK"),
|
||||
Threshold: viper.GetFloat64("rag.threshold"),
|
||||
RerankerEnabled: viper.GetBool("rag.reranker.enabled"),
|
||||
RerankerTimeoutMS: viper.GetInt("rag.reranker.timeoutMs"),
|
||||
ChunkSize: viper.GetInt("rag.ingest.chunkSize"),
|
||||
ChunkOverlap: viper.GetInt("rag.ingest.chunkOverlap"),
|
||||
RetrieveTimeoutMS: viper.GetInt("rag.retrieve.timeoutMs"),
|
||||
Enabled: viper.GetBool("rag.enabled"),
|
||||
Store: viper.GetString("rag.store"),
|
||||
TopK: viper.GetInt("rag.topK"),
|
||||
Threshold: viper.GetFloat64("rag.threshold"),
|
||||
EmbedProvider: viper.GetString("rag.embed.provider"),
|
||||
EmbedModel: viper.GetString("rag.embed.model"),
|
||||
EmbedBaseURL: viper.GetString("rag.embed.baseURL"),
|
||||
EmbedAPIKeyEnv: viper.GetString("rag.embed.apiKeyEnv"),
|
||||
EmbedTimeoutMS: viper.GetInt("rag.embed.timeoutMs"),
|
||||
EmbedDimension: viper.GetInt("rag.embed.dimension"),
|
||||
RerankerEnabled: viper.GetBool("rag.reranker.enabled"),
|
||||
RerankerProvider: viper.GetString("rag.reranker.provider"),
|
||||
RerankerTimeoutMS: viper.GetInt("rag.reranker.timeoutMs"),
|
||||
ChunkSize: viper.GetInt("rag.ingest.chunkSize"),
|
||||
ChunkOverlap: viper.GetInt("rag.ingest.chunkOverlap"),
|
||||
RetrieveTimeoutMS: viper.GetInt("rag.retrieve.timeoutMs"),
|
||||
MilvusAddress: viper.GetString("rag.milvus.address"),
|
||||
MilvusToken: viper.GetString("rag.milvus.token"),
|
||||
MilvusDBName: viper.GetString("rag.milvus.dbName"),
|
||||
MilvusCollectionName: viper.GetString("rag.milvus.collectionName"),
|
||||
MilvusMetricType: viper.GetString("rag.milvus.metricType"),
|
||||
MilvusRequestTimeoutMS: viper.GetInt("rag.milvus.requestTimeoutMs"),
|
||||
}
|
||||
if cfg.Store == "" {
|
||||
cfg.Store = "inmemory"
|
||||
}
|
||||
if cfg.TopK <= 0 {
|
||||
cfg.TopK = 8
|
||||
@@ -36,6 +69,24 @@ func LoadFromViper() Config {
|
||||
if cfg.Threshold < 0 {
|
||||
cfg.Threshold = 0
|
||||
}
|
||||
if cfg.EmbedProvider == "" {
|
||||
cfg.EmbedProvider = "mock"
|
||||
}
|
||||
if cfg.EmbedBaseURL == "" {
|
||||
cfg.EmbedBaseURL = viper.GetString("agent.baseURL")
|
||||
}
|
||||
if cfg.EmbedAPIKeyEnv == "" {
|
||||
cfg.EmbedAPIKeyEnv = "ARK_API_KEY"
|
||||
}
|
||||
if cfg.EmbedTimeoutMS <= 0 {
|
||||
cfg.EmbedTimeoutMS = 1200
|
||||
}
|
||||
if cfg.EmbedDimension <= 0 {
|
||||
cfg.EmbedDimension = 1024
|
||||
}
|
||||
if cfg.RerankerProvider == "" {
|
||||
cfg.RerankerProvider = "noop"
|
||||
}
|
||||
if cfg.RerankerTimeoutMS <= 0 {
|
||||
cfg.RerankerTimeoutMS = 1200
|
||||
}
|
||||
@@ -48,5 +99,20 @@ func LoadFromViper() Config {
|
||||
if cfg.RetrieveTimeoutMS <= 0 {
|
||||
cfg.RetrieveTimeoutMS = 1500
|
||||
}
|
||||
if cfg.MilvusAddress == "" {
|
||||
cfg.MilvusAddress = "http://localhost:19530"
|
||||
}
|
||||
if cfg.MilvusToken == "" {
|
||||
cfg.MilvusToken = "root:Milvus"
|
||||
}
|
||||
if cfg.MilvusCollectionName == "" {
|
||||
cfg.MilvusCollectionName = "smartflow_rag_chunks"
|
||||
}
|
||||
if cfg.MilvusMetricType == "" {
|
||||
cfg.MilvusMetricType = "COSINE"
|
||||
}
|
||||
if cfg.MilvusRequestTimeoutMS <= 0 {
|
||||
cfg.MilvusRequestTimeoutMS = 1500
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
190
backend/infra/rag/core/observer.go
Normal file
190
backend/infra/rag/core/observer.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ObserveLevel is the severity level of an observation event.
type ObserveLevel string

const (
	ObserveLevelInfo  ObserveLevel = "info"
	ObserveLevelWarn  ObserveLevel = "warn"
	ObserveLevelError ObserveLevel = "error"
)

// ObserveEvent describes one unified observation event.
//
// Responsibility boundaries:
//  1. Only carries structured runtime information for the RAG infra;
//  2. Not bound to any concrete logging, metrics, or tracing system;
//  3. Field contents should stay stable so events can later be fed into a
//     unified observability platform.
type ObserveEvent struct {
	Level     ObserveLevel
	Component string
	Operation string
	Fields    map[string]any
}

// Observer is the minimal observation interface of the RAG infra.
//
// Responsibility boundaries:
//  1. Consumes structured events;
//  2. Does not decide whether business logic keeps running;
//  3. No implementation may affect the stability of the main path.
type Observer interface {
	Observe(ctx context.Context, event ObserveEvent)
}

// ObserverFunc adapts a plain function to the Observer interface.
type ObserverFunc func(ctx context.Context, event ObserveEvent)

// Observe implements Observer; a nil ObserverFunc is a safe no-op.
func (f ObserverFunc) Observe(ctx context.Context, event ObserveEvent) {
	if f == nil {
		return
	}
	f(ctx, event)
}

// NewNopObserver returns an empty implementation, suitable as a fallback
// before a unified observability platform is wired in.
func NewNopObserver() Observer {
	return ObserverFunc(func(context.Context, ObserveEvent) {})
}

// NewLoggerObserver returns an adapter backed by the standard-library logger.
//
// Notes:
//  1. While the project has no unified logging platform, this keeps
//     structured fields printed in a stable form;
//  2. Once a unified logger/metrics/tracing stack is introduced, only this
//     Observer implementation needs to be replaced at the injection point;
//  3. Output stays single-line to match the existing log style.
func NewLoggerObserver(logger *log.Logger) Observer {
	if logger == nil {
		logger = log.Default()
	}
	return &loggerObserver{logger: logger}
}

// loggerObserver renders events through a *log.Logger.
type loggerObserver struct {
	logger *log.Logger
}
|
||||
|
||||
// Observe renders one event as a single "key=value" log line.
//
// Steps:
//  1. Normalize level/component/operation, substituting defaults for blank
//     values so the line format stays stable;
//  2. Merge context-carried fields with event fields (event fields win on
//     duplicate keys; empty values are dropped);
//  3. Sort field keys so output is deterministic and diff-friendly.
func (o *loggerObserver) Observe(ctx context.Context, event ObserveEvent) {
	if o == nil || o.logger == nil {
		return
	}

	level := strings.TrimSpace(string(event.Level))
	if level == "" {
		level = string(ObserveLevelInfo)
	}
	component := strings.TrimSpace(event.Component)
	if component == "" {
		component = "unknown"
	}
	operation := strings.TrimSpace(event.Operation)
	if operation == "" {
		operation = "unknown"
	}

	// Context fields first so per-event fields can override them.
	fields := ObserveFieldsFromContext(ctx)
	for key, value := range event.Fields {
		key = strings.TrimSpace(key)
		if key == "" || !shouldKeepObserveField(value) {
			continue
		}
		fields[key] = value
	}

	parts := []string{
		"rag",
		fmt.Sprintf("level=%s", level),
		fmt.Sprintf("component=%s", component),
		fmt.Sprintf("operation=%s", operation),
	}

	// Deterministic ordering: sort keys before appending.
	keys := make([]string, 0, len(fields))
	for key := range fields {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	for _, key := range keys {
		parts = append(parts, fmt.Sprintf("%s=%v", key, fields[key]))
	}

	o.logger.Print(strings.Join(parts, " "))
}
|
||||
|
||||
// observeFieldsContextKey is the private context key under which accumulated
// observation fields are stored.
type observeFieldsContextKey struct{}

// WithObserveFields attaches shared observation fields to the context so
// downstream components can reuse them.
//
// Steps:
//  1. Read existing context fields first, so Runtime / Pipeline / Store can
//     enrich them layer by layer;
//  2. Later writes override same-named older values, so downstream always
//     sees the latest semantics;
//  3. Only "meaningful" fields are kept, to avoid accumulating empty values
//     in logs over time.
func WithObserveFields(ctx context.Context, fields map[string]any) context.Context {
	if len(fields) == 0 {
		return ctx
	}
	if ctx == nil {
		ctx = context.Background()
	}

	merged := ObserveFieldsFromContext(ctx)
	for key, value := range fields {
		key = strings.TrimSpace(key)
		if key == "" || !shouldKeepObserveField(value) {
			continue
		}
		merged[key] = value
	}
	if len(merged) == 0 {
		return ctx
	}
	return context.WithValue(ctx, observeFieldsContextKey{}, merged)
}
|
||||
|
||||
// ObserveFieldsFromContext 提取上下文中已经累积的观测字段。
|
||||
func ObserveFieldsFromContext(ctx context.Context) map[string]any {
|
||||
if ctx == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
raw, ok := ctx.Value(observeFieldsContextKey{}).(map[string]any)
|
||||
if !ok || len(raw) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
result := make(map[string]any, len(raw))
|
||||
for key, value := range raw {
|
||||
result[key] = value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ClassifyErrorCode 统一把常见错误压缩为稳定错误码,便于后续接入全局观测平台。
|
||||
func ClassifyErrorCode(err error) string {
|
||||
switch {
|
||||
case err == nil:
|
||||
return ""
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
return "DEADLINE_EXCEEDED"
|
||||
case errors.Is(err, context.Canceled):
|
||||
return "CANCELED"
|
||||
default:
|
||||
return "RAG_ERROR"
|
||||
}
|
||||
}
|
||||
|
||||
// shouldKeepObserveField reports whether a field value is worth logging:
// nil values and blank strings are dropped, everything else is kept.
func shouldKeepObserveField(value any) bool {
	switch v := value.(type) {
	case nil:
		return false
	case string:
		return strings.TrimSpace(v) != ""
	default:
		return true
	}
}
|
||||
@@ -28,6 +28,7 @@ type Pipeline struct {
|
||||
store VectorStore
|
||||
reranker Reranker
|
||||
logger *log.Logger
|
||||
observer Observer
|
||||
}
|
||||
|
||||
func NewPipeline(chunker Chunker, embedder Embedder, store VectorStore, reranker Reranker) *Pipeline {
|
||||
@@ -37,9 +38,26 @@ func NewPipeline(chunker Chunker, embedder Embedder, store VectorStore, reranker
|
||||
store: store,
|
||||
reranker: reranker,
|
||||
logger: log.Default(),
|
||||
observer: NewNopObserver(),
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogger 设置 Pipeline 使用的日志器。
|
||||
func (p *Pipeline) SetLogger(logger *log.Logger) {
|
||||
if p == nil || logger == nil {
|
||||
return
|
||||
}
|
||||
p.logger = logger
|
||||
}
|
||||
|
||||
// SetObserver 设置 Pipeline 使用的统一观测器。
|
||||
func (p *Pipeline) SetObserver(observer Observer) {
|
||||
if p == nil || observer == nil {
|
||||
return
|
||||
}
|
||||
p.observer = observer
|
||||
}
|
||||
|
||||
// Ingest 执行统一入库流程。
|
||||
//
|
||||
// 步骤化说明:
|
||||
@@ -63,6 +81,24 @@ func (p *Pipeline) Ingest(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p.IngestDocuments(ctx, corpus.Name(), docs, opt)
|
||||
}
|
||||
|
||||
// IngestDocuments 执行“已标准化文档”的统一入库流程。
|
||||
//
|
||||
// 职责边界:
|
||||
// 1. 负责处理已经完成 CorpusAdapter 映射的标准文档;
|
||||
// 2. 负责统一切块、向量化与 Upsert;
|
||||
// 3. 不负责再做业务输入解析,避免 Runtime 为拿到 document_id 重复 build 文档。
|
||||
func (p *Pipeline) IngestDocuments(
|
||||
ctx context.Context,
|
||||
corpusName string,
|
||||
docs []SourceDocument,
|
||||
opt IngestOption,
|
||||
) (*IngestResult, error) {
|
||||
if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil {
|
||||
return nil, ErrNilDependency
|
||||
}
|
||||
if len(docs) == 0 {
|
||||
return &IngestResult{DocumentCount: 0, ChunkCount: 0}, nil
|
||||
}
|
||||
@@ -102,7 +138,7 @@ func (p *Pipeline) Ingest(
|
||||
now := time.Now()
|
||||
for i, chunk := range chunks {
|
||||
metadata := cloneMap(chunk.Metadata)
|
||||
metadata["corpus"] = corpus.Name()
|
||||
metadata["corpus"] = corpusName
|
||||
metadata["document_id"] = chunk.DocumentID
|
||||
metadata["chunk_order"] = chunk.Order
|
||||
rows = append(rows, VectorRow{
|
||||
@@ -214,7 +250,23 @@ func (p *Pipeline) Retrieve(
|
||||
// 2. rerank 异常不终止主流程,统一降级为原排序。
|
||||
result.FallbackUsed = true
|
||||
result.FallbackReason = FallbackReasonRerankFailed
|
||||
p.logger.Printf("rag rerank fallback: reason=%s err=%v", FallbackReasonRerankFailed, rerankErr)
|
||||
if p.observer != nil {
|
||||
p.observer.Observe(ctx, ObserveEvent{
|
||||
Level: ObserveLevelWarn,
|
||||
Component: "pipeline",
|
||||
Operation: "rerank_fallback",
|
||||
Fields: map[string]any{
|
||||
"status": "fallback",
|
||||
"fallback_reason": FallbackReasonRerankFailed,
|
||||
"candidate_count": len(candidates),
|
||||
"top_k": topK,
|
||||
"error": rerankErr,
|
||||
"error_code": ClassifyErrorCode(rerankErr),
|
||||
},
|
||||
})
|
||||
} else if p.logger != nil {
|
||||
p.logger.Printf("rag rerank fallback: reason=%s err=%v", FallbackReasonRerankFailed, rerankErr)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
result.Items = reranked
|
||||
|
||||
@@ -22,7 +22,11 @@ type MemoryIngestItem struct {
|
||||
MemoryType string
|
||||
Title string
|
||||
Content string
|
||||
Confidence float64
|
||||
Importance float64
|
||||
SensitivityLevel int
|
||||
IsExplicit bool
|
||||
Status string
|
||||
TTLAt *time.Time
|
||||
CreatedAt *time.Time
|
||||
}
|
||||
@@ -71,7 +75,12 @@ func (c *MemoryCorpus) BuildIngestDocuments(_ context.Context, input any) ([]cor
|
||||
"assistant_id": strings.TrimSpace(item.AssistantID),
|
||||
"run_id": strings.TrimSpace(item.RunID),
|
||||
"memory_type": strings.TrimSpace(strings.ToLower(item.MemoryType)),
|
||||
"title": strings.TrimSpace(item.Title),
|
||||
"confidence": item.Confidence,
|
||||
"importance": item.Importance,
|
||||
"sensitivity_level": item.SensitivityLevel,
|
||||
"is_explicit": item.IsExplicit,
|
||||
"status": strings.TrimSpace(item.Status),
|
||||
}
|
||||
if item.TTLAt != nil {
|
||||
metadata["ttl_at"] = item.TTLAt.Format(time.RFC3339)
|
||||
|
||||
@@ -3,19 +3,97 @@ package embed
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
openaiembedding "github.com/cloudwego/eino-ext/libs/acl/openai"
|
||||
einoembedding "github.com/cloudwego/eino/components/embedding"
|
||||
)
|
||||
|
||||
// EinoEmbedder 是 Eino embedding 的占位实现。
|
||||
// EinoConfig 描述 Eino embedding 运行参数。
|
||||
type EinoConfig struct {
|
||||
APIKey string
|
||||
BaseURL string
|
||||
Model string
|
||||
TimeoutMS int
|
||||
Dimension int
|
||||
}
|
||||
|
||||
// EinoEmbedder 是基于 Eino 的 embedding 适配器。
|
||||
//
|
||||
// 说明:
|
||||
// 1. 本轮先占位接口,避免过早耦合具体 Provider;
|
||||
// 2. 后续接入真实 embedding 时,只替换此文件内部实现。
|
||||
type EinoEmbedder struct{}
|
||||
|
||||
func NewEinoEmbedder() *EinoEmbedder {
|
||||
return &EinoEmbedder{}
|
||||
// 1. 对 infra/rag 暴露统一 []float32 结果,屏蔽 Eino/OpenAI 兼容实现细节;
|
||||
// 2. 超时由该适配器自身收口,避免业务侧每次调用都手写超时控制;
|
||||
// 3. 当前底层走 Eino Ext 的 OpenAI 兼容 embedding client,便于接 Ark/OpenAI 兼容接口。
|
||||
type EinoEmbedder struct {
|
||||
client einoembedding.Embedder
|
||||
model string
|
||||
timeout time.Duration
|
||||
dimension int
|
||||
}
|
||||
|
||||
func (e *EinoEmbedder) Embed(_ context.Context, _ []string, _ string) ([][]float32, error) {
|
||||
return nil, errors.New("eino embedder is not implemented yet")
|
||||
// NewEinoEmbedder constructs the Eino-backed embedder from config.
//
// Validation and defaults:
//  1. APIKey and Model are required; blank values fail fast;
//  2. Dimension is only forwarded to the client when positive;
//  3. TimeoutMS defaults to 1200ms when unset or non-positive.
func NewEinoEmbedder(ctx context.Context, cfg EinoConfig) (*EinoEmbedder, error) {
	if strings.TrimSpace(cfg.APIKey) == "" {
		return nil, errors.New("eino embedder api key is empty")
	}
	if strings.TrimSpace(cfg.Model) == "" {
		return nil, errors.New("eino embedder model is empty")
	}

	clientCfg := &openaiembedding.EmbeddingConfig{
		APIKey:  strings.TrimSpace(cfg.APIKey),
		BaseURL: strings.TrimSpace(cfg.BaseURL),
		Model:   strings.TrimSpace(cfg.Model),
	}
	if cfg.Dimension > 0 {
		clientCfg.Dimensions = &cfg.Dimension
	}

	client, err := openaiembedding.NewEmbeddingClient(ctx, clientCfg)
	if err != nil {
		return nil, err
	}

	timeout := 1200 * time.Millisecond
	if cfg.TimeoutMS > 0 {
		timeout = time.Duration(cfg.TimeoutMS) * time.Millisecond
	}

	return &EinoEmbedder{
		client:    client,
		model:     strings.TrimSpace(cfg.Model),
		timeout:   timeout,
		dimension: cfg.Dimension,
	}, nil
}
|
||||
|
||||
// Embed converts texts into []float32 vectors via the underlying Eino
// embedding client.
//
// Notes:
//  1. The adapter owns its own timeout, so callers need no per-call wiring;
//  2. Empty input short-circuits to a nil result without calling the client;
//  3. The client returns float64 vectors, which are narrowed to float32 to
//     match the store-facing interface.
func (e *EinoEmbedder) Embed(ctx context.Context, texts []string, _ string) ([][]float32, error) {
	if e == nil || e.client == nil {
		return nil, errors.New("eino embedder is not initialized")
	}
	if len(texts) == 0 {
		return nil, nil
	}

	callCtx := ctx
	cancel := func() {}
	if e.timeout > 0 {
		callCtx, cancel = context.WithTimeout(ctx, e.timeout)
	}
	defer cancel()

	vectors, err := e.client.EmbedStrings(callCtx, texts, einoembedding.WithModel(e.model))
	if err != nil {
		return nil, err
	}

	result := make([][]float32, 0, len(vectors))
	for _, vector := range vectors {
		converted := make([]float32, len(vector))
		for i, value := range vector {
			converted[i] = float32(value)
		}
		result = append(result, converted)
	}
	return result, nil
}
|
||||
|
||||
139
backend/infra/rag/factory.go
Normal file
139
backend/infra/rag/factory.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
ragchunk "github.com/LoveLosita/smartflow/backend/infra/rag/chunk"
|
||||
ragconfig "github.com/LoveLosita/smartflow/backend/infra/rag/config"
|
||||
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
|
||||
ragembed "github.com/LoveLosita/smartflow/backend/infra/rag/embed"
|
||||
ragrerank "github.com/LoveLosita/smartflow/backend/infra/rag/rerank"
|
||||
ragstore "github.com/LoveLosita/smartflow/backend/infra/rag/store"
|
||||
)
|
||||
|
||||
// FactoryDeps lists the optional dependencies for the Runtime factory.
//
// Notes:
//  1. Logger is only the default sink while the project has no unified
//     logging system;
//  2. Observer is the official observability slot; it can later be replaced
//     by a project-level logger/metrics/tracing adapter;
//  3. Business code still receives only a Runtime and never touches the
//     low-level wiring.
type FactoryDeps struct {
	Logger   *log.Logger
	Observer Observer
}

// NewRuntimeFromConfig assembles the RAG Runtime from configuration.
//
// Design notes:
//  1. Every choice of concrete implementation is funneled through this
//     factory; business code must not new up stores/embedders/rerankers;
//  2. When new providers are added, extend this factory rather than
//     spreading selection logic into business modules;
//  3. Observability is also injected here, so runtime/store/pipeline do not
//     log on their own.
func NewRuntimeFromConfig(ctx context.Context, cfg ragconfig.Config, deps FactoryDeps) (Runtime, error) {
	logger, observer := normalizeFactoryDeps(deps)

	embedder, err := buildEmbedder(ctx, cfg)
	if err != nil {
		return nil, err
	}

	store, err := buildStore(cfg, logger, observer)
	if err != nil {
		return nil, err
	}

	reranker, err := buildReranker(cfg, observer)
	if err != nil {
		return nil, err
	}

	pipeline := core.NewPipeline(ragchunk.NewTextChunker(), embedder, store, reranker)
	pipeline.SetLogger(logger)
	pipeline.SetObserver(observer)
	return newRuntime(cfg, pipeline, observer), nil
}
|
||||
|
||||
func normalizeFactoryDeps(deps FactoryDeps) (*log.Logger, Observer) {
|
||||
logger := deps.Logger
|
||||
if logger == nil {
|
||||
logger = log.Default()
|
||||
}
|
||||
observer := deps.Observer
|
||||
if observer == nil {
|
||||
observer = NewLoggerObserver(logger)
|
||||
}
|
||||
return logger, observer
|
||||
}
|
||||
|
||||
// buildEmbedder selects the embedding implementation by provider name.
//
// Notes:
//  1. "mock" (or empty) returns the in-process mock embedder;
//  2. "eino" reads the API key from the environment variable named in the
//     config and builds the OpenAI-compatible Eino embedder;
//  3. Unknown providers fail fast so misconfiguration surfaces at startup.
func buildEmbedder(ctx context.Context, cfg ragconfig.Config) (core.Embedder, error) {
	switch strings.ToLower(strings.TrimSpace(cfg.EmbedProvider)) {
	case "", "mock":
		return ragembed.NewMockEmbedder(cfg.EmbedDimension), nil
	case "eino":
		apiKey := strings.TrimSpace(os.Getenv(cfg.EmbedAPIKeyEnv))
		if apiKey == "" {
			return nil, fmt.Errorf("rag embed api key is empty: env=%s", cfg.EmbedAPIKeyEnv)
		}
		return ragembed.NewEinoEmbedder(ctx, ragembed.EinoConfig{
			APIKey:    apiKey,
			BaseURL:   cfg.EmbedBaseURL,
			Model:     cfg.EmbedModel,
			TimeoutMS: cfg.EmbedTimeoutMS,
			Dimension: cfg.EmbedDimension,
		})
	default:
		return nil, fmt.Errorf("unsupported rag embed provider: %s", cfg.EmbedProvider)
	}
}
|
||||
|
||||
// buildStore selects the vector-store implementation by name.
//
// Notes:
//  1. "inmemory" (or empty) returns the in-process store, suitable for
//     tests and degraded operation;
//  2. "milvus" wires the Milvus-backed store with the configured
//     connection, collection, metric, and observation settings;
//  3. Unknown store names fail fast at startup.
func buildStore(cfg ragconfig.Config, logger *log.Logger, observer Observer) (core.VectorStore, error) {
	switch strings.ToLower(strings.TrimSpace(cfg.Store)) {
	case "", "inmemory":
		return ragstore.NewInMemoryVectorStore(), nil
	case "milvus":
		return ragstore.NewMilvusStore(ragstore.MilvusConfig{
			Address:          cfg.MilvusAddress,
			Token:            cfg.MilvusToken,
			DBName:           cfg.MilvusDBName,
			CollectionName:   cfg.MilvusCollectionName,
			RequestTimeoutMS: cfg.MilvusRequestTimeoutMS,
			Dimension:        cfg.EmbedDimension,
			MetricType:       cfg.MilvusMetricType,
			Logger:           logger,
			Observer:         observer,
		})
	default:
		return nil, fmt.Errorf("unsupported rag store: %s", cfg.Store)
	}
}
|
||||
|
||||
// buildReranker selects the reranker implementation.
//
// Notes:
//  1. When reranking is disabled, a no-op reranker keeps the pipeline
//     shape uniform;
//  2. "eino" is not implemented yet: it degrades to noop and emits a warn
//     event instead of failing startup;
//  3. Unknown providers fail fast.
func buildReranker(cfg ragconfig.Config, observer Observer) (core.Reranker, error) {
	if !cfg.RerankerEnabled {
		return ragrerank.NewNoopReranker(), nil
	}

	switch strings.ToLower(strings.TrimSpace(cfg.RerankerProvider)) {
	case "", "noop":
		return ragrerank.NewNoopReranker(), nil
	case "eino":
		// Degrade rather than fail: the provider is configured but not yet
		// implemented, so surface a warn event and fall back to noop.
		if observer != nil {
			observer.Observe(context.Background(), ObserveEvent{
				Level:     ObserveLevelWarn,
				Component: "factory",
				Operation: "reranker_fallback",
				Fields: map[string]any{
					"provider":        "eino",
					"status":          "fallback",
					"fallback_target": "noop",
					"reason":          "reranker_not_implemented",
				},
			})
		}
		return ragrerank.NewNoopReranker(), nil
	default:
		return nil, fmt.Errorf("unsupported rag reranker provider: %s", cfg.RerankerProvider)
	}
}
|
||||
32
backend/infra/rag/observe.go
Normal file
32
backend/infra/rag/observe.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
|
||||
)
|
||||
|
||||
// ObserveLevel re-exports the unified observation level so the startup
// layer does not depend on core details.
type ObserveLevel = core.ObserveLevel

const (
	ObserveLevelInfo  = core.ObserveLevelInfo
	ObserveLevelWarn  = core.ObserveLevelWarn
	ObserveLevelError = core.ObserveLevelError
)

// ObserveEvent re-exports the unified observation event.
type ObserveEvent = core.ObserveEvent

// Observer re-exports the unified observation interface.
type Observer = core.Observer

// NewNopObserver returns the no-op implementation.
func NewNopObserver() Observer {
	return core.NewNopObserver()
}

// NewLoggerObserver returns the standard-library log adapter.
func NewLoggerObserver(logger *log.Logger) Observer {
	return core.NewLoggerObserver(logger)
}
|
||||
380
backend/infra/rag/runtime.go
Normal file
380
backend/infra/rag/runtime.go
Normal file
@@ -0,0 +1,380 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
ragconfig "github.com/LoveLosita/smartflow/backend/infra/rag/config"
|
||||
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
|
||||
"github.com/LoveLosita/smartflow/backend/infra/rag/corpus"
|
||||
)
|
||||
|
||||
// runtime is the concrete Runtime implementation assembled by the factory.
type runtime struct {
	cfg          ragconfig.Config
	pipeline     *core.Pipeline
	memoryCorpus *corpus.MemoryCorpus
	webCorpus    *corpus.WebCorpus
	observer     Observer
}

// newRuntime wraps the pipeline into a Runtime; a nil observer is replaced
// by a no-op one so call sites never need nil checks.
func newRuntime(cfg ragconfig.Config, pipeline *core.Pipeline, observer Observer) Runtime {
	if observer == nil {
		observer = NewNopObserver()
	}
	return &runtime{
		cfg:          cfg,
		pipeline:     pipeline,
		memoryCorpus: corpus.NewMemoryCorpus(),
		webCorpus:    corpus.NewWebCorpus(),
		observer:     observer,
	}
}
|
||||
|
||||
// IngestMemory is the unified entry point for ingesting memory corpus
// items. It maps the public request DTO onto the corpus-level item type
// and delegates to the shared ingest path.
func (r *runtime) IngestMemory(ctx context.Context, req MemoryIngestRequest) (*IngestResult, error) {
	items := make([]corpus.MemoryIngestItem, 0, len(req.Items))
	for _, item := range req.Items {
		items = append(items, corpus.MemoryIngestItem{
			MemoryID:         item.MemoryID,
			UserID:           item.UserID,
			ConversationID:   item.ConversationID,
			AssistantID:      item.AssistantID,
			RunID:            item.RunID,
			MemoryType:       item.MemoryType,
			Title:            item.Title,
			Content:          item.Content,
			Confidence:       item.Confidence,
			Importance:       item.Importance,
			SensitivityLevel: item.SensitivityLevel,
			IsExplicit:       item.IsExplicit,
			Status:           item.Status,
			TTLAt:            item.TTLAt,
			CreatedAt:        item.CreatedAt,
		})
	}
	return r.ingestWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, items, req.Action)
}
|
||||
|
||||
// RetrieveMemory is the unified entry point for retrieving memory corpus
// hits. A single-type filter is pushed down via the corpus input; a
// multi-type filter is applied as a second pass over the results (see the
// inline notes below).
func (r *runtime) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (*RetrieveResult, error) {
	corpusInput := corpus.MemoryRetrieveInput{
		UserID:         req.UserID,
		ConversationID: req.ConversationID,
		AssistantID:    req.AssistantID,
		RunID:          req.RunID,
	}
	if len(req.MemoryTypes) == 1 {
		corpusInput.MemoryType = req.MemoryTypes[0]
	}

	result, err := r.retrieveWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, core.RetrieveRequest{
		Query:       req.Query,
		TopK:        normalizeTopK(req.TopK, r.cfg.TopK),
		Threshold:   normalizeThreshold(req.Threshold, r.cfg.Threshold),
		Action:      normalizeAction(req.Action, "search"),
		CorpusInput: corpusInput,
	})
	if err != nil {
		return nil, err
	}
	if len(req.MemoryTypes) <= 1 {
		return result, nil
	}

	// 1. The underlying filter layer is still equality-based, so the
	//    Runtime performs the multi-type second-pass filtering for now;
	// 2. This keeps the "memory_type in (...)" detail out of every Store;
	// 3. Once filtering is unified at the lower layer, push this down.
	allowed := make(map[string]struct{}, len(req.MemoryTypes))
	for _, item := range req.MemoryTypes {
		value := strings.TrimSpace(strings.ToLower(item))
		if value == "" {
			continue
		}
		allowed[value] = struct{}{}
	}

	filtered := make([]RetrieveHit, 0, len(result.Items))
	for _, item := range result.Items {
		memoryType := strings.TrimSpace(strings.ToLower(asString(item.Metadata["memory_type"])))
		// If every requested type was blank, allowed is empty and all
		// hits pass through unchanged.
		if len(allowed) > 0 {
			if _, ok := allowed[memoryType]; !ok {
				continue
			}
		}
		filtered = append(filtered, item)
	}
	result.Items = filtered
	if req.TopK > 0 && len(result.Items) > req.TopK {
		result.Items = result.Items[:req.TopK]
	}
	return result, nil
}
|
||||
|
||||
// IngestWeb is the unified entry point for ingesting web corpus items. It
// maps the public request DTO onto the corpus-level item type and delegates
// to the shared ingest path.
func (r *runtime) IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestResult, error) {
	items := make([]corpus.WebIngestItem, 0, len(req.Items))
	for _, item := range req.Items {
		items = append(items, corpus.WebIngestItem{
			URL:         item.URL,
			Title:       item.Title,
			Content:     item.Content,
			Snippet:     item.Snippet,
			Domain:      item.Domain,
			QueryID:     item.QueryID,
			SessionID:   item.SessionID,
			PublishedAt: item.PublishedAt,
			FetchedAt:   item.FetchedAt,
			SourceRank:  item.SourceRank,
		})
	}
	return r.ingestWithCorpus(ctx, req.TraceID, "web", r.webCorpus, items, req.Action)
}

// RetrieveWeb is the unified entry point for retrieving web corpus hits.
// Query/session/domain scoping is passed down via the corpus input.
func (r *runtime) RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (*RetrieveResult, error) {
	return r.retrieveWithCorpus(ctx, req.TraceID, "web", r.webCorpus, core.RetrieveRequest{
		Query:     req.Query,
		TopK:      normalizeTopK(req.TopK, r.cfg.TopK),
		Threshold: normalizeThreshold(req.Threshold, r.cfg.Threshold),
		Action:    normalizeAction(req.Action, "search"),
		CorpusInput: corpus.WebRetrieveInput{
			QueryID:   req.QueryID,
			SessionID: req.SessionID,
			Domain:    req.Domain,
		},
	})
}
|
||||
|
||||
// ingestWithCorpus is the shared ingest path for every corpus.
//
// Steps:
//  1. Guard against nil dependencies;
//  2. Attach trace/corpus/action observation fields to the context;
//  3. Let the corpus adapter build standardized documents;
//  4. Run the pipeline ingest (chunk -> embed -> upsert);
//  5. Emit one structured event per outcome (failure or success).
func (r *runtime) ingestWithCorpus(
	ctx context.Context,
	traceID string,
	corpusName string,
	adapter core.CorpusAdapter,
	input any,
	action string,
) (*IngestResult, error) {
	start := time.Now()
	if r == nil || r.pipeline == nil || adapter == nil {
		return nil, core.ErrNilDependency
	}

	action = normalizeAction(action, "add")
	observeCtx := newObserveContext(ctx, traceID, corpusName, action)

	docs, err := adapter.BuildIngestDocuments(observeCtx, input)
	if err != nil {
		r.observe(observeCtx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "runtime",
			Operation: "ingest",
			Fields: map[string]any{
				"status":      "failed",
				"latency_ms":  time.Since(start).Milliseconds(),
				"phase":       "build_documents",
				"error":       err,
				"error_code":  core.ClassifyErrorCode(err),
				"input_count": estimateInputCount(input),
			},
		})
		return nil, err
	}

	// Collect document IDs up front so the result can report them even
	// though the pipeline itself only returns counts.
	docIDs := make([]string, 0, len(docs))
	for _, doc := range docs {
		docIDs = append(docIDs, doc.ID)
	}

	result, err := r.pipeline.IngestDocuments(observeCtx, adapter.Name(), docs, core.IngestOption{
		Chunk: core.ChunkOption{
			ChunkSize:    r.cfg.ChunkSize,
			ChunkOverlap: r.cfg.ChunkOverlap,
		},
		Action: action,
	})
	if err != nil {
		r.observe(observeCtx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "runtime",
			Operation: "ingest",
			Fields: map[string]any{
				"status":         "failed",
				"latency_ms":     time.Since(start).Milliseconds(),
				"document_count": len(docs),
				"error":          err,
				"error_code":     core.ClassifyErrorCode(err),
			},
		})
		return nil, err
	}

	r.observe(observeCtx, ObserveEvent{
		Level:     ObserveLevelInfo,
		Component: "runtime",
		Operation: "ingest",
		Fields: map[string]any{
			"status":         "success",
			"latency_ms":     time.Since(start).Milliseconds(),
			"document_count": result.DocumentCount,
			"chunk_count":    result.ChunkCount,
		},
	})
	return &IngestResult{
		DocumentCount: result.DocumentCount,
		ChunkCount:    result.ChunkCount,
		DocumentIDs:   docIDs,
	}, nil
}
|
||||
|
||||
// retrieveWithCorpus is the shared retrieval path for every corpus.
//
// Steps:
//  1. Guard against nil dependencies;
//  2. Attach trace/corpus/action observation fields to the context;
//  3. Apply the configured retrieval timeout around the pipeline call;
//  4. Map pipeline hits onto the public DTO (metadata copied defensively);
//  5. Emit one structured event per outcome.
func (r *runtime) retrieveWithCorpus(
	ctx context.Context,
	traceID string,
	corpusName string,
	adapter core.CorpusAdapter,
	req core.RetrieveRequest,
) (*RetrieveResult, error) {
	start := time.Now()
	if r == nil || r.pipeline == nil || adapter == nil {
		return nil, core.ErrNilDependency
	}

	action := normalizeAction(req.Action, "search")
	req.Action = action
	observeCtx := newObserveContext(ctx, traceID, corpusName, action)

	// The timeout only bounds the pipeline call; events are emitted on the
	// un-timed context so they are not lost after a deadline fires.
	timeoutCtx := observeCtx
	cancel := func() {}
	if r.cfg.RetrieveTimeoutMS > 0 {
		timeoutCtx, cancel = context.WithTimeout(observeCtx, time.Duration(r.cfg.RetrieveTimeoutMS)*time.Millisecond)
	}
	defer cancel()

	result, err := r.pipeline.Retrieve(timeoutCtx, adapter, req)
	if err != nil {
		r.observe(observeCtx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "runtime",
			Operation: "retrieve",
			Fields: map[string]any{
				"status":     "failed",
				"latency_ms": time.Since(start).Milliseconds(),
				"query_len":  len(strings.TrimSpace(req.Query)),
				"top_k":      req.TopK,
				"threshold":  req.Threshold,
				"error":      err,
				"error_code": core.ClassifyErrorCode(err),
			},
		})
		return nil, err
	}

	items := make([]RetrieveHit, 0, len(result.Items))
	for _, item := range result.Items {
		items = append(items, RetrieveHit{
			ChunkID:    item.ChunkID,
			DocumentID: item.DocumentID,
			Text:       item.Text,
			Score:      item.Score,
			Metadata:   cloneMap(item.Metadata),
		})
	}

	r.observe(observeCtx, ObserveEvent{
		Level:     ObserveLevelInfo,
		Component: "runtime",
		Operation: "retrieve",
		Fields: map[string]any{
			"status":          "success",
			"latency_ms":      time.Since(start).Milliseconds(),
			"query_len":       len(strings.TrimSpace(req.Query)),
			"top_k":           req.TopK,
			"threshold":       req.Threshold,
			"raw_count":       result.RawCount,
			"hit_count":       len(result.Items),
			"fallback_used":   result.FallbackUsed,
			"fallback_reason": result.FallbackReason,
		},
	})
	return &RetrieveResult{
		Items:          items,
		RawCount:       result.RawCount,
		FallbackUsed:   result.FallbackUsed,
		FallbackReason: result.FallbackReason,
	}, nil
}
|
||||
|
||||
func (r *runtime) observe(ctx context.Context, event ObserveEvent) {
|
||||
if r == nil || r.observer == nil {
|
||||
return
|
||||
}
|
||||
r.observer.Observe(ctx, event)
|
||||
}
|
||||
|
||||
func newObserveContext(ctx context.Context, traceID string, corpusName string, action string) context.Context {
|
||||
fields := map[string]any{
|
||||
"corpus": corpusName,
|
||||
"action": action,
|
||||
}
|
||||
if traceID = strings.TrimSpace(traceID); traceID != "" {
|
||||
fields["trace_id"] = traceID
|
||||
}
|
||||
return core.WithObserveFields(ctx, fields)
|
||||
}
|
||||
|
||||
func estimateInputCount(input any) int {
|
||||
switch value := input.(type) {
|
||||
case []corpus.MemoryIngestItem:
|
||||
return len(value)
|
||||
case []corpus.WebIngestItem:
|
||||
return len(value)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeAction returns the trimmed action, or fallback when the
// trimmed action is empty. The fallback is returned as-is (not trimmed).
func normalizeAction(action string, fallback string) string {
	if trimmed := strings.TrimSpace(action); trimmed != "" {
		return trimmed
	}
	return fallback
}
|
||||
|
||||
// normalizeTopK picks the first positive value of topK then fallback,
// defaulting to 8 when neither is positive.
func normalizeTopK(topK int, fallback int) int {
	switch {
	case topK > 0:
		return topK
	case fallback > 0:
		return fallback
	default:
		return 8
	}
}
|
||||
|
||||
// normalizeThreshold picks the first non-negative value of threshold then
// fallback, defaulting to 0 when both are negative. Note that 0 itself is
// a valid threshold and is returned as-is.
func normalizeThreshold(threshold float64, fallback float64) float64 {
	switch {
	case threshold >= 0:
		return threshold
	case fallback >= 0:
		return fallback
	default:
		return 0
	}
}
|
||||
|
||||
// cloneMap returns a shallow copy of src. A nil or empty input yields a
// fresh, non-nil empty map so callers can mutate the result safely.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for k, v := range src {
		dst[k] = v
	}
	return dst
}
|
||||
|
||||
// asString renders an arbitrary value via %v and trims surrounding
// whitespace; nil maps to the empty string.
func asString(v any) string {
	if v != nil {
		return strings.TrimSpace(fmt.Sprintf("%v", v))
	}
	return ""
}
|
||||
117
backend/infra/rag/service.go
Normal file
117
backend/infra/rag/service.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Runtime is the single stable API surface the RAG infrastructure exposes
// to business code.
//
// Responsibility boundaries:
//  1. Unified ingest and retrieve entry points for the memory and web corpora.
//  2. Hides assembly details of the underlying Pipeline / Store / Embedder /
//     Reranker.
//  3. Does not own provider search, HTML fetching, or prompt injection —
//     those are business semantics.
type Runtime interface {
	IngestMemory(ctx context.Context, req MemoryIngestRequest) (*IngestResult, error)
	RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (*RetrieveResult, error)

	IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestResult, error)
	RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (*RetrieveResult, error)
}

// IngestResult summarizes one unified ingest execution.
type IngestResult struct {
	DocumentCount int
	ChunkCount    int
	DocumentIDs   []string
}

// RetrieveHit is the unified hit item exposed to business code.
type RetrieveHit struct {
	ChunkID    string
	DocumentID string
	Text       string
	Score      float64
	Metadata   map[string]any
}

// RetrieveResult summarizes one retrieval execution.
type RetrieveResult struct {
	Items          []RetrieveHit
	RawCount       int
	FallbackUsed   bool
	FallbackReason string
}

// MemoryIngestItem is one memory-corpus ingest item.
type MemoryIngestItem struct {
	MemoryID         int64
	UserID           int
	ConversationID   string
	AssistantID      string
	RunID            string
	MemoryType       string
	Title            string
	Content          string
	Confidence       float64
	Importance       float64
	SensitivityLevel int
	IsExplicit       bool
	Status           string
	TTLAt            *time.Time
	CreatedAt        *time.Time
}

// MemoryIngestRequest describes one memory vector-ingest request.
type MemoryIngestRequest struct {
	TraceID string
	Action  string
	Items   []MemoryIngestItem
}

// MemoryRetrieveRequest describes one memory retrieval request.
type MemoryRetrieveRequest struct {
	TraceID        string
	Query          string
	TopK           int
	Threshold      float64
	Action         string
	UserID         int
	ConversationID string
	AssistantID    string
	RunID          string
	MemoryTypes    []string
}

// WebIngestItem is one web-corpus ingest item.
type WebIngestItem struct {
	URL         string
	Title       string
	Content     string
	Snippet     string
	Domain      string
	QueryID     string
	SessionID   string
	PublishedAt *time.Time
	FetchedAt   *time.Time
	SourceRank  int
}

// WebIngestRequest describes one web-corpus ingest request.
type WebIngestRequest struct {
	TraceID string
	Action  string
	Items   []WebIngestItem
}

// WebRetrieveRequest describes one web retrieval request.
type WebRetrieveRequest struct {
	TraceID   string
	Query     string
	TopK      int
	Threshold float64
	Action    string
	QueryID   string
	SessionID string
	Domain    string
}
|
||||
@@ -1,35 +1,894 @@
|
||||
package store
|
||||
|
||||
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)
|
||||
|
||||
// MilvusStore 是 Milvus 连接器占位实现。
|
||||
// MilvusConfig 描述 Milvus REST 存储配置。
|
||||
type MilvusConfig struct {
|
||||
// Address 应指向 Milvus REST 入口。
|
||||
// 当前项目联调验证使用 19530;9091 仅用于 health/metrics,不承载本文实现所走的 REST API。
|
||||
Address string
|
||||
Token string
|
||||
DBName string
|
||||
CollectionName string
|
||||
RequestTimeoutMS int
|
||||
Dimension int
|
||||
MetricType string
|
||||
Logger *log.Logger
|
||||
Observer core.Observer
|
||||
}
|
||||
|
||||
// MilvusStore 是基于 Milvus REST API 的向量存储实现。
|
||||
//
|
||||
// 说明:
|
||||
// 1. 本轮先保留接口结构,便于后续平滑替换 InMemoryStore;
|
||||
// 2. 真实接入时需补充连接池、集合初始化、元数据过滤与错误转换。
|
||||
type MilvusStore struct{}
|
||||
|
||||
func NewMilvusStore() *MilvusStore {
|
||||
return &MilvusStore{}
|
||||
// 设计说明:
|
||||
// 1. 本实现优先保证“项目内可接入、可管理、可灰度”,不强依赖额外 SDK;
|
||||
// 2. 通过固定字段 + metadata JSON 的方式兼顾过滤能力与元数据完整性;
|
||||
// 3. collection 在首次写入时自动创建,避免启动期额外初始化脚本。
|
||||
type MilvusStore struct {
|
||||
cfg MilvusConfig
|
||||
client *http.Client
|
||||
observer core.Observer
|
||||
mu sync.Mutex
|
||||
ensured bool
|
||||
}
|
||||
|
||||
// Fixed column names of the Milvus collection. Filterable metadata keys
// are promoted to dedicated scalar columns; everything else stays inside
// the metadata JSON column.
// (The previous revision's not-implemented Upsert placeholder that
// preceded this block is removed; it duplicated the real method.)
const (
	milvusPrimaryField   = "id"
	milvusVectorField    = "vector"
	milvusTextField      = "text"
	milvusMetadataField  = "metadata"
	milvusCorpusField    = "corpus"
	milvusDocumentField  = "document_id"
	milvusUserIDField    = "user_id"
	milvusAssistantField = "assistant_id"
	milvusConvField      = "conversation_id"
	milvusRunField       = "run_id"
	milvusMemoryType     = "memory_type"
	milvusQueryIDField   = "query_id"
	milvusSessionField   = "session_id"
	milvusDomainField    = "domain"
	milvusChunkOrder     = "chunk_order"
	milvusUpdatedAtField = "updated_at"
)

// milvusFilterFieldMap whitelists the filter keys accepted by
// buildMilvusFilter and maps each onto its scalar column name.
var milvusFilterFieldMap = map[string]string{
	"corpus":          milvusCorpusField,
	"document_id":     milvusDocumentField,
	"user_id":         milvusUserIDField,
	"assistant_id":    milvusAssistantField,
	"conversation_id": milvusConvField,
	"run_id":          milvusRunField,
	"memory_type":     milvusMemoryType,
	"query_id":        milvusQueryIDField,
	"session_id":      milvusSessionField,
	"domain":          milvusDomainField,
	"chunk_order":     milvusChunkOrder,
}
|
||||
|
||||
func (s *MilvusStore) Search(_ context.Context, _ core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
|
||||
return nil, errors.New("milvus store is not implemented yet")
|
||||
func NewMilvusStore(cfg MilvusConfig) (*MilvusStore, error) {
|
||||
cfg.Address = strings.TrimRight(strings.TrimSpace(cfg.Address), "/")
|
||||
if cfg.Address == "" {
|
||||
return nil, errors.New("milvus address is empty")
|
||||
}
|
||||
if cfg.CollectionName == "" {
|
||||
cfg.CollectionName = "smartflow_rag_chunks"
|
||||
}
|
||||
if cfg.MetricType == "" {
|
||||
cfg.MetricType = "COSINE"
|
||||
}
|
||||
if cfg.RequestTimeoutMS <= 0 {
|
||||
cfg.RequestTimeoutMS = 1500
|
||||
}
|
||||
if cfg.Logger == nil {
|
||||
cfg.Logger = log.Default()
|
||||
}
|
||||
if cfg.Observer == nil {
|
||||
cfg.Observer = core.NewLoggerObserver(cfg.Logger)
|
||||
}
|
||||
|
||||
return &MilvusStore{
|
||||
cfg: cfg,
|
||||
client: &http.Client{Timeout: time.Duration(cfg.RequestTimeoutMS) * time.Millisecond},
|
||||
observer: cfg.Observer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *MilvusStore) Delete(_ context.Context, _ []string) error {
|
||||
return errors.New("milvus store is not implemented yet")
|
||||
func (s *MilvusStore) Upsert(ctx context.Context, rows []core.VectorRow) error {
|
||||
start := time.Now()
|
||||
if len(rows) == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := s.ensureCollection(ctx, len(rows[0].Vector)); err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "upsert",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"row_count": len(rows),
|
||||
"vector_dim": len(rows[0].Vector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
data := make([]map[string]any, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
item := mapRowToMilvusEntity(row)
|
||||
data = append(data, item)
|
||||
}
|
||||
|
||||
_, err := s.postJSON(ctx, "/v2/vectordb/entities/upsert", map[string]any{
|
||||
"collectionName": s.cfg.CollectionName,
|
||||
"data": data,
|
||||
"dbName": blankToNil(s.cfg.DBName),
|
||||
})
|
||||
if err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "upsert",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"row_count": len(rows),
|
||||
"vector_dim": len(rows[0].Vector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelInfo,
|
||||
Component: "store",
|
||||
Operation: "upsert",
|
||||
Fields: map[string]any{
|
||||
"status": "success",
|
||||
"row_count": len(rows),
|
||||
"vector_dim": len(rows[0].Vector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *MilvusStore) Get(_ context.Context, _ []string) ([]core.VectorRow, error) {
|
||||
return nil, errors.New("milvus store is not implemented yet")
|
||||
func (s *MilvusStore) Search(ctx context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
|
||||
start := time.Now()
|
||||
if len(req.QueryVector) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err := s.ensureCollection(ctx, len(req.QueryVector)); err != nil {
|
||||
if isMilvusCollectionMissing(err) {
|
||||
return nil, nil
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filterExpr, err := buildMilvusFilter(req.Filter)
|
||||
if err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body := map[string]any{
|
||||
"collectionName": s.cfg.CollectionName,
|
||||
"data": [][]float32{req.QueryVector},
|
||||
"annsField": milvusVectorField,
|
||||
"limit": normalizeMilvusTopK(req.TopK),
|
||||
"outputFields": milvusOutputFields(false),
|
||||
}
|
||||
if filterExpr != "" {
|
||||
body["filter"] = filterExpr
|
||||
}
|
||||
if s.cfg.DBName != "" {
|
||||
body["dbName"] = s.cfg.DBName
|
||||
}
|
||||
|
||||
respBody, err := s.postJSON(ctx, "/v2/vectordb/entities/search", body)
|
||||
if err != nil {
|
||||
if isMilvusCollectionMissing(err) {
|
||||
return nil, nil
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp milvusSearchResponse
|
||||
if err = json.Unmarshal(respBody, &resp); err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
if resp.Code != 0 && resp.Code != 200 {
|
||||
err = fmt.Errorf("milvus search failed: code=%d message=%s", resp.Code, resp.Message)
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]core.ScoredVectorRow, 0, len(resp.Data))
|
||||
for _, item := range resp.Data {
|
||||
row, score := item.toVectorRow()
|
||||
result = append(result, core.ScoredVectorRow{
|
||||
Row: row,
|
||||
Score: score,
|
||||
})
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelInfo,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "success",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"result_count": len(result),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
},
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *MilvusStore) Delete(ctx context.Context, ids []string) error {
|
||||
start := time.Now()
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
filter := fmt.Sprintf(`%s in [%s]`, milvusPrimaryField, joinQuotedStrings(ids))
|
||||
_, err := s.postJSON(ctx, "/v2/vectordb/entities/delete", map[string]any{
|
||||
"collectionName": s.cfg.CollectionName,
|
||||
"filter": filter,
|
||||
"dbName": blankToNil(s.cfg.DBName),
|
||||
})
|
||||
if isMilvusCollectionMissing(err) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "delete",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelInfo,
|
||||
Component: "store",
|
||||
Operation: "delete",
|
||||
Fields: map[string]any{
|
||||
"status": "success",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *MilvusStore) Get(ctx context.Context, ids []string) ([]core.VectorRow, error) {
|
||||
start := time.Now()
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
respBody, err := s.postJSON(ctx, "/v2/vectordb/entities/get", map[string]any{
|
||||
"collectionName": s.cfg.CollectionName,
|
||||
"id": ids,
|
||||
"outputFields": milvusOutputFields(true),
|
||||
"dbName": blankToNil(s.cfg.DBName),
|
||||
})
|
||||
if err != nil {
|
||||
if isMilvusCollectionMissing(err) {
|
||||
return nil, nil
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "get",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp milvusGetResponse
|
||||
if err = json.Unmarshal(respBody, &resp); err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "get",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
if resp.Code != 0 && resp.Code != 200 {
|
||||
err = fmt.Errorf("milvus get failed: code=%d message=%s", resp.Code, resp.Message)
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "get",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows := make([]core.VectorRow, 0, len(resp.Data))
|
||||
for _, item := range resp.Data {
|
||||
rows = append(rows, mapMilvusRow(item, true))
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelInfo,
|
||||
Component: "store",
|
||||
Operation: "get",
|
||||
Fields: map[string]any{
|
||||
"status": "success",
|
||||
"id_count": len(ids),
|
||||
"row_count": len(rows),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
},
|
||||
})
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func (s *MilvusStore) ensureCollection(ctx context.Context, dimension int) error {
|
||||
start := time.Now()
|
||||
if dimension <= 0 {
|
||||
dimension = s.cfg.Dimension
|
||||
}
|
||||
if dimension <= 0 {
|
||||
return errors.New("milvus vector dimension is invalid")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.ensured {
|
||||
return nil
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"collectionName": s.cfg.CollectionName,
|
||||
"schema": map[string]any{
|
||||
"autoId": false,
|
||||
"enabledDynamicField": false,
|
||||
"fields": []map[string]any{
|
||||
buildVarcharField(milvusPrimaryField, true, 256),
|
||||
buildVectorField(milvusVectorField, dimension),
|
||||
buildVarcharField(milvusTextField, false, 65535),
|
||||
{"fieldName": milvusMetadataField, "dataType": "JSON"},
|
||||
buildVarcharField(milvusCorpusField, false, 64),
|
||||
buildVarcharField(milvusDocumentField, false, 256),
|
||||
{"fieldName": milvusUserIDField, "dataType": "Int64"},
|
||||
buildVarcharField(milvusAssistantField, false, 128),
|
||||
buildVarcharField(milvusConvField, false, 128),
|
||||
buildVarcharField(milvusRunField, false, 128),
|
||||
buildVarcharField(milvusMemoryType, false, 64),
|
||||
buildVarcharField(milvusQueryIDField, false, 128),
|
||||
buildVarcharField(milvusSessionField, false, 128),
|
||||
buildVarcharField(milvusDomainField, false, 128),
|
||||
{"fieldName": milvusChunkOrder, "dataType": "Int64"},
|
||||
{"fieldName": milvusUpdatedAtField, "dataType": "Int64"},
|
||||
},
|
||||
},
|
||||
"indexParams": []map[string]any{
|
||||
{
|
||||
"fieldName": milvusVectorField,
|
||||
"indexName": milvusVectorField,
|
||||
"metricType": s.cfg.MetricType,
|
||||
"indexType": "AUTOINDEX",
|
||||
},
|
||||
},
|
||||
}
|
||||
if s.cfg.DBName != "" {
|
||||
payload["dbName"] = s.cfg.DBName
|
||||
}
|
||||
|
||||
_, err := s.postJSON(ctx, "/v2/vectordb/collections/create", payload)
|
||||
if err != nil {
|
||||
if isMilvusAlreadyExists(err) {
|
||||
s.ensured = true
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelInfo,
|
||||
Component: "store",
|
||||
Operation: "ensure_collection",
|
||||
Fields: map[string]any{
|
||||
"status": "already_exists",
|
||||
"vector_dim": dimension,
|
||||
"metric_type": s.cfg.MetricType,
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "ensure_collection",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"vector_dim": dimension,
|
||||
"metric_type": s.cfg.MetricType,
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
s.ensured = true
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelInfo,
|
||||
Component: "store",
|
||||
Operation: "ensure_collection",
|
||||
Fields: map[string]any{
|
||||
"status": "created",
|
||||
"vector_dim": dimension,
|
||||
"metric_type": s.cfg.MetricType,
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// postJSON issues a single POST against the Milvus REST API and returns
// the raw response body for the caller to decode.
//
// Errors are normalized in two layers: an HTTP status >= 400 fails with
// the body embedded in the message, and a best-effort decode of the
// {code,message} envelope surfaces Milvus application errors that arrive
// with a 2xx status.
func (s *MilvusStore) postJSON(ctx context.Context, path string, payload map[string]any) ([]byte, error) {
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.cfg.Address+path, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	if token := strings.TrimSpace(s.cfg.Token); token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}

	resp, err := s.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	respBody, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, readErr
	}

	if resp.StatusCode >= 400 {
		return nil, fmt.Errorf("milvus http failed: status=%d body=%s", resp.StatusCode, string(respBody))
	}

	// Best effort: an undecodable body is passed through untouched; only a
	// decodable envelope carrying a non-success code becomes an error.
	var basic milvusBasicResponse
	if jsonErr := json.Unmarshal(respBody, &basic); jsonErr == nil {
		if basic.Code != 0 && basic.Code != 200 {
			return nil, fmt.Errorf("milvus api failed: code=%d message=%s", basic.Code, basic.Message)
		}
	}

	return respBody, nil
}
|
||||
|
||||
func (s *MilvusStore) observe(ctx context.Context, event core.ObserveEvent) {
|
||||
if s == nil || s.observer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
fields := cloneMap(event.Fields)
|
||||
fields["store"] = "milvus"
|
||||
fields["collection"] = s.cfg.CollectionName
|
||||
if strings.TrimSpace(s.cfg.DBName) != "" {
|
||||
fields["db_name"] = s.cfg.DBName
|
||||
}
|
||||
|
||||
s.observer.Observe(ctx, core.ObserveEvent{
|
||||
Level: event.Level,
|
||||
Component: event.Component,
|
||||
Operation: event.Operation,
|
||||
Fields: fields,
|
||||
})
|
||||
}
|
||||
|
||||
func mapRowToMilvusEntity(row core.VectorRow) map[string]any {
|
||||
metadata := cloneMap(row.Metadata)
|
||||
entity := map[string]any{
|
||||
milvusPrimaryField: row.ID,
|
||||
milvusVectorField: row.Vector,
|
||||
milvusTextField: row.Text,
|
||||
milvusMetadataField: metadata,
|
||||
milvusCorpusField: asString(metadata["corpus"]),
|
||||
milvusDocumentField: asString(metadata["document_id"]),
|
||||
milvusUpdatedAtField: func() int64 {
|
||||
if row.UpdatedAt.IsZero() {
|
||||
return time.Now().UnixMilli()
|
||||
}
|
||||
return row.UpdatedAt.UnixMilli()
|
||||
}(),
|
||||
}
|
||||
assignMilvusScalar(entity, milvusUserIDField, metadata["user_id"])
|
||||
assignMilvusScalar(entity, milvusAssistantField, metadata["assistant_id"])
|
||||
assignMilvusScalar(entity, milvusConvField, metadata["conversation_id"])
|
||||
assignMilvusScalar(entity, milvusRunField, metadata["run_id"])
|
||||
assignMilvusScalar(entity, milvusMemoryType, metadata["memory_type"])
|
||||
assignMilvusScalar(entity, milvusQueryIDField, metadata["query_id"])
|
||||
assignMilvusScalar(entity, milvusSessionField, metadata["session_id"])
|
||||
assignMilvusScalar(entity, milvusDomainField, metadata["domain"])
|
||||
assignMilvusScalar(entity, milvusChunkOrder, metadata["chunk_order"])
|
||||
return entity
|
||||
}
|
||||
|
||||
func assignMilvusScalar(target map[string]any, field string, value any) {
|
||||
if value == nil {
|
||||
return
|
||||
}
|
||||
switch field {
|
||||
case milvusUserIDField, milvusChunkOrder:
|
||||
if parsed, ok := toInt64(value); ok {
|
||||
target[field] = parsed
|
||||
}
|
||||
default:
|
||||
if text := asString(value); text != "" {
|
||||
target[field] = text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildMilvusFilter(filter map[string]any) (string, error) {
|
||||
if len(filter) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(filter))
|
||||
for key, value := range filter {
|
||||
field, ok := milvusFilterFieldMap[key]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unsupported milvus filter key: %s", key)
|
||||
}
|
||||
switch field {
|
||||
case milvusUserIDField, milvusChunkOrder:
|
||||
parsed, parseOK := toInt64(value)
|
||||
if !parseOK {
|
||||
return "", fmt.Errorf("milvus filter key=%s expects integer", key)
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("%s == %d", field, parsed))
|
||||
default:
|
||||
text := escapeMilvusString(asString(value))
|
||||
parts = append(parts, fmt.Sprintf(`%s == "%s"`, field, text))
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " and "), nil
|
||||
}
|
||||
|
||||
// buildVarcharField builds a VarChar field schema entry; isPrimary marks
// the primary-key field (the key is only present when true).
func buildVarcharField(name string, isPrimary bool, maxLength int) map[string]any {
	field := make(map[string]any, 4)
	field["fieldName"] = name
	field["dataType"] = "VarChar"
	field["elementTypeParams"] = map[string]any{"max_length": maxLength}
	if isPrimary {
		field["isPrimary"] = true
	}
	return field
}
|
||||
|
||||
// buildVectorField builds a FloatVector field schema entry with the given
// dimension.
func buildVectorField(name string, dimension int) map[string]any {
	params := map[string]any{"dim": dimension}
	return map[string]any{
		"fieldName":         name,
		"dataType":          "FloatVector",
		"elementTypeParams": params,
	}
}
|
||||
|
||||
func milvusOutputFields(includeVector bool) []string {
|
||||
fields := []string{
|
||||
milvusTextField,
|
||||
milvusMetadataField,
|
||||
milvusCorpusField,
|
||||
milvusDocumentField,
|
||||
milvusUserIDField,
|
||||
milvusAssistantField,
|
||||
milvusConvField,
|
||||
milvusRunField,
|
||||
milvusMemoryType,
|
||||
milvusQueryIDField,
|
||||
milvusSessionField,
|
||||
milvusDomainField,
|
||||
milvusChunkOrder,
|
||||
milvusUpdatedAtField,
|
||||
}
|
||||
if includeVector {
|
||||
fields = append(fields, milvusVectorField)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// normalizeMilvusTopK clamps a non-positive topK to the default of 8.
func normalizeMilvusTopK(topK int) int {
	if topK > 0 {
		return topK
	}
	return 8
}
|
||||
|
||||
// blankToNil maps a blank string to nil so optional JSON payload keys
// (e.g. dbName) are omitted instead of sent empty.
func blankToNil(v string) any {
	if trimmed := strings.TrimSpace(v); trimmed != "" {
		return trimmed
	}
	return nil
}
|
||||
|
||||
// escapeMilvusString escapes backslashes and double quotes so the value
// can be embedded inside a double-quoted Milvus filter literal.
func escapeMilvusString(v string) string {
	return strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(v)
}
|
||||
|
||||
func joinQuotedStrings(values []string) string {
|
||||
parts := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
parts = append(parts, fmt.Sprintf(`"%s"`, escapeMilvusString(value)))
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// cloneMap returns a shallow copy of src. A nil or empty input yields a
// fresh, non-nil empty map so callers can mutate the result safely.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for k, v := range src {
		dst[k] = v
	}
	return dst
}
|
||||
|
||||
// asString renders an arbitrary value via %v and trims surrounding
// whitespace; nil maps to the empty string.
func asString(v any) string {
	if v != nil {
		return strings.TrimSpace(fmt.Sprintf("%v", v))
	}
	return ""
}
|
||||
|
||||
// toInt64 coerces common scalar representations to int64.
// The bool result reports whether the conversion succeeded.
//
// Generalized beyond the original (int/int32/int64/float64/json.Number/
// string) to cover the full signed/unsigned integer families and float32;
// unsigned values exceeding math.MaxInt64 are rejected rather than wrapped.
func toInt64(v any) (int64, bool) {
	switch value := v.(type) {
	case int:
		return int64(value), true
	case int8:
		return int64(value), true
	case int16:
		return int64(value), true
	case int32:
		return int64(value), true
	case int64:
		return value, true
	case uint:
		if uint64(value) > 1<<63-1 {
			return 0, false
		}
		return int64(value), true
	case uint8:
		return int64(value), true
	case uint16:
		return int64(value), true
	case uint32:
		return int64(value), true
	case uint64:
		if value > 1<<63-1 {
			return 0, false
		}
		return int64(value), true
	case float32:
		return int64(value), true
	case float64:
		return int64(value), true
	case json.Number:
		parsed, err := value.Int64()
		return parsed, err == nil
	case string:
		parsed, err := strconv.ParseInt(strings.TrimSpace(value), 10, 64)
		return parsed, err == nil
	default:
		return 0, false
	}
}
|
||||
|
||||
// isMilvusAlreadyExists reports whether err indicates the collection
// already exists. Matching is by substring because the Milvus REST API
// does not expose a structured error code for this case.
//
// The original also checked for "already exists", which is dead code:
// the "already exist" substring matches it as well.
func isMilvusAlreadyExists(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(strings.ToLower(err.Error()), "already exist")
}
|
||||
|
||||
// isMilvusCollectionMissing reports whether err indicates the target
// collection does not exist. Matching is by substring because the Milvus
// REST API does not expose a structured error code for this case.
func isMilvusCollectionMissing(err error) bool {
	if err == nil {
		return false
	}
	msg := strings.ToLower(err.Error())
	for _, marker := range []string{"can't find collection", "collection not found"} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// milvusBasicResponse is the minimal {code,message} envelope every Milvus
// REST response carries; used for generic success/failure detection.
type milvusBasicResponse struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}
|
||||
|
||||
// milvusSearchResponse is the envelope of /v2/vectordb/entities/search.
type milvusSearchResponse struct {
	Code    int                `json:"code"`
	Message string             `json:"message"`
	Data    []milvusSearchItem `json:"data"`
}

// milvusSearchItem is one raw search hit: entity columns plus the
// server-computed "distance" value.
type milvusSearchItem map[string]any
|
||||
|
||||
func (m milvusSearchItem) toVectorRow() (core.VectorRow, float64) {
|
||||
row := mapMilvusRow(map[string]any(m), false)
|
||||
score := 0.0
|
||||
if value, ok := m["distance"].(float64); ok {
|
||||
score = value
|
||||
}
|
||||
return row, score
|
||||
}
|
||||
|
||||
// milvusGetResponse is the envelope of /v2/vectordb/entities/get.
type milvusGetResponse struct {
	Code    int              `json:"code"`
	Message string           `json:"message"`
	Data    []map[string]any `json:"data"`
}
|
||||
|
||||
func mapMilvusRow(raw map[string]any, includeVector bool) core.VectorRow {
|
||||
metadata := cloneMap(readMetadataMap(raw[milvusMetadataField]))
|
||||
assignMetadataIfPresent(metadata, "corpus", raw[milvusCorpusField])
|
||||
assignMetadataIfPresent(metadata, "document_id", raw[milvusDocumentField])
|
||||
assignMetadataIfPresent(metadata, "user_id", raw[milvusUserIDField])
|
||||
assignMetadataIfPresent(metadata, "assistant_id", raw[milvusAssistantField])
|
||||
assignMetadataIfPresent(metadata, "conversation_id", raw[milvusConvField])
|
||||
assignMetadataIfPresent(metadata, "run_id", raw[milvusRunField])
|
||||
assignMetadataIfPresent(metadata, "memory_type", raw[milvusMemoryType])
|
||||
assignMetadataIfPresent(metadata, "query_id", raw[milvusQueryIDField])
|
||||
assignMetadataIfPresent(metadata, "session_id", raw[milvusSessionField])
|
||||
assignMetadataIfPresent(metadata, "domain", raw[milvusDomainField])
|
||||
assignMetadataIfPresent(metadata, "chunk_order", raw[milvusChunkOrder])
|
||||
|
||||
row := core.VectorRow{
|
||||
ID: asString(raw[milvusPrimaryField]),
|
||||
Text: asString(raw[milvusTextField]),
|
||||
Metadata: metadata,
|
||||
}
|
||||
if row.ID == "" {
|
||||
row.ID = asString(raw["id"])
|
||||
}
|
||||
if includeVector {
|
||||
row.Vector = readFloat32Vector(raw[milvusVectorField])
|
||||
}
|
||||
return row
|
||||
}
|
||||
|
||||
// readMetadataMap interprets value as a decoded JSON object. When the
// value already is a map[string]any it is returned as-is (no copy);
// any other shape — including nil — yields a fresh empty map, so the
// result is never nil.
func readMetadataMap(value any) map[string]any {
	if object, ok := value.(map[string]any); ok {
		return object
	}
	return map[string]any{}
}
|
||||
|
||||
// readFloat32Vector decodes a vector column into a []float32.
// It accepts a native []float32 (returned as-is), a []float64, or a
// generic []any whose elements are float64, float32, or json.Number —
// the shapes JSON decoding actually produces. Non-numeric elements
// inside a []any are skipped, preserving the previous behavior; any
// other input yields nil.
func readFloat32Vector(value any) []float32 {
	switch vector := value.(type) {
	case []float32:
		return vector
	case []float64:
		result := make([]float32, len(vector))
		for i, item := range vector {
			result[i] = float32(item)
		}
		return result
	case []any:
		result := make([]float32, 0, len(vector))
		for _, item := range vector {
			switch number := item.(type) {
			case float64:
				result = append(result, float32(number))
			case float32:
				result = append(result, number)
			case json.Number:
				if parsed, err := number.Float64(); err == nil {
					result = append(result, float32(parsed))
				}
			}
		}
		return result
	default:
		return nil
	}
}
|
||||
|
||||
// assignMetadataIfPresent stores value into target[key], skipping nil
// values entirely. String values are trimmed first and dropped when
// blank; every other type is stored unchanged.
func assignMetadataIfPresent(target map[string]any, key string, value any) {
	if value == nil {
		return
	}
	text, isString := value.(string)
	if !isString {
		target[key] = value
		return
	}
	trimmed := strings.TrimSpace(text)
	if trimmed == "" {
		return
	}
	target[key] = trimmed
}
|
||||
|
||||
@@ -5,4 +5,5 @@ import "github.com/LoveLosita/smartflow/backend/infra/rag/core"
|
||||
// EnsureCompile statically asserts that the store implementations
// satisfy the core.VectorStore interface; it performs no work at
// runtime and exists purely so the compiler checks the assignments.
func EnsureCompile() {
	var _ core.VectorStore = (*InMemoryVectorStore)(nil)
	var _ core.VectorStore = (*MilvusStore)(nil)
}
|
||||
|
||||
Reference in New Issue
Block a user