Version: 0.9.65.dev.260503

后端:
1. 阶段 1.5/1.6
收口 llm-service / rag-service,统一模型出口与检索基础设施入口,清退 backend/infra/llm 与 backend/infra/rag 旧实现;
2. 同步更新相关调用链与微服务迁移计划文档
This commit is contained in:
Losita
2026-05-03 23:21:03 +08:00
parent a6c1e5d077
commit 9902ca3563
65 changed files with 550 additions and 376 deletions

View File

@@ -1,87 +0,0 @@
// 过渡期统一 Ark 调用封装。
//
// 这里保留 CallArkText / CallArkJSON方便暂时还直接持有 *ark.ChatModel 的调用点
// 逐步迁移到统一 Client。后续 memory 也可以直接复用这套中立层。
package llm
import (
"context"
"errors"
"strings"
"github.com/cloudwego/eino-ext/components/model/ark"
einoModel "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)
// ArkCallOptions carries the common knobs for a call made via ark.ChatModel.
//
// Design intent:
// 1. Pull the repeated Ark call boilerplate into one shared layer first;
// 2. Then let WrapArkClient provide the unified Client on top of it;
// 3. So upper layers only care about business prompts and structured results.
type ArkCallOptions struct {
	Temperature float64
	MaxTokens   int
	Thinking    ThinkingMode
}
// CallArkText invokes an ark model and returns the plain-text response.
//
// Responsibility boundaries:
// 1. Assembles the two-part system + user message list;
// 2. Applies thinking / temperature / maxTokens uniformly;
// 3. Rejects empty responses;
// 4. Does NOT parse JSON and does NOT validate business fields.
func CallArkText(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, options ArkCallOptions) (string, error) {
	if chatModel == nil {
		return "", errors.New("ark model is nil")
	}
	messages := []*schema.Message{
		schema.SystemMessage(systemPrompt),
		schema.UserMessage(userPrompt),
	}
	resp, err := chatModel.Generate(ctx, messages, buildArkOptions(options)...)
	if err != nil {
		return "", err
	}
	// Defensive: guard against a (nil, nil) return from the provider.
	if resp == nil {
		return "", errors.New("模型返回为空")
	}
	text := strings.TrimSpace(resp.Content)
	if text == "" {
		return "", errors.New("模型返回内容为空")
	}
	return text, nil
}
// CallArkJSON invokes an ark model and parses the reply directly into T.
// It returns the parsed value plus the raw text so callers can log the
// original output or fall back when parsing fails.
func CallArkJSON[T any](ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, options ArkCallOptions) (*T, string, error) {
	raw, err := CallArkText(ctx, chatModel, systemPrompt, userPrompt, options)
	if err != nil {
		return nil, "", err
	}
	parsed, err := ParseJSONObject[T](raw)
	if err != nil {
		// Keep the raw text alongside the parse error for diagnostics.
		return nil, raw, err
	}
	return parsed, raw, nil
}
// buildArkOptions translates ArkCallOptions into eino model options.
// Thinking is disabled unless explicitly enabled; MaxTokens is only forwarded
// when positive so the provider default applies otherwise.
func buildArkOptions(options ArkCallOptions) []einoModel.Option {
	thinkingType := arkModel.ThinkingTypeDisabled
	if options.Thinking == ThinkingModeEnabled {
		thinkingType = arkModel.ThinkingTypeEnabled
	}
	opts := []einoModel.Option{
		ark.WithThinking(&arkModel.Thinking{Type: thinkingType}),
		einoModel.WithTemperature(float32(options.Temperature)),
	}
	if options.MaxTokens > 0 {
		opts = append(opts, einoModel.WithMaxTokens(options.MaxTokens))
	}
	return opts
}

View File

@@ -1,123 +0,0 @@
package llm
import (
"context"
"errors"
"io"
"github.com/cloudwego/eino-ext/components/model/ark"
einoModel "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)
// WrapArkClient adapts an ark.ChatModel to the unified Client.
//
// Responsibility boundaries:
// 1. generateText: calls ark.ChatModel.Generate (non-streaming), used by GenerateJSON;
// 2. streamText: calls ark.ChatModel.Stream (streaming) for streaming-output scenarios;
// 3. both share the same options translation.
func WrapArkClient(arkChatModel *ark.ChatModel) *Client {
	if arkChatModel == nil {
		return nil
	}
	// Non-streaming text generation, used by the GenerateJSON / GenerateText paths.
	generateFunc := func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error) {
		arkOpts := buildArkStreamOptions(options)
		msg, err := arkChatModel.Generate(ctx, messages, arkOpts...)
		if err != nil {
			return nil, err
		}
		if msg == nil {
			return nil, errors.New("ark model returned nil message")
		}
		var usage *schema.TokenUsage
		finishReason := ""
		if msg.ResponseMeta != nil {
			// Clone so later accumulation never mutates the SDK's struct.
			usage = CloneUsage(msg.ResponseMeta.Usage)
			finishReason = msg.ResponseMeta.FinishReason
		}
		return &TextResult{
			Text:         msg.Content,
			Usage:        usage,
			FinishReason: finishReason,
		}, nil
	}
	// Streaming text generation.
	streamFunc := func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error) {
		arkOpts := buildArkStreamOptions(options)
		reader, err := arkChatModel.Stream(ctx, messages, arkOpts...)
		if err != nil {
			return nil, err
		}
		return &arkStreamReaderAdapter{reader: reader}, nil
	}
	return NewClient(generateFunc, streamFunc)
}
// buildArkStreamOptions converts the unified GenerateOptions into ark call
// options (shared by both the streaming and non-streaming paths).
func buildArkStreamOptions(options GenerateOptions) []einoModel.Option {
	thinkingEnabled := options.Thinking == ThinkingModeEnabled
	// Thinking
	thinkingType := arkModel.ThinkingTypeDisabled
	if thinkingEnabled {
		thinkingType = arkModel.ThinkingTypeEnabled
	}
	opts := []einoModel.Option{
		ark.WithThinking(&arkModel.Thinking{Type: thinkingType}),
	}
	// Temperature: thinking models require temperature=1, otherwise the API
	// silently ignores the thinking setting.
	if thinkingEnabled {
		opts = append(opts, einoModel.WithTemperature(1.0))
	} else if options.Temperature > 0 {
		opts = append(opts, einoModel.WithTemperature(float32(options.Temperature)))
	}
	// MaxTokens: in thinking mode the thinking tokens consume the max_tokens
	// budget; the caller's value only expresses "desired output length", so the
	// real budget must leave room for reasoning. Floor at 16000 so the chain of
	// thought is not truncated into empty or non-JSON output.
	maxTokens := options.MaxTokens
	if thinkingEnabled {
		const minThinkingBudget = 16000
		if maxTokens < minThinkingBudget {
			maxTokens = minThinkingBudget
		}
	}
	if maxTokens > 0 {
		opts = append(opts, einoModel.WithMaxTokens(maxTokens))
	}
	return opts
}
// arkStreamReaderAdapter adapts the reader returned by ark.ChatModel.Stream.
// ark.Stream returns schema.StreamReader[*schema.Message], whose Close() has
// no return value, while our StreamReader interface requires Close() error.
type arkStreamReaderAdapter struct {
	reader *schema.StreamReader[*schema.Message]
}

// Recv forwards to the ark reader's Recv; a nil adapter/reader yields io.EOF.
func (r *arkStreamReaderAdapter) Recv() (*schema.Message, error) {
	if r == nil || r.reader == nil {
		return nil, io.EOF
	}
	return r.reader.Recv()
}

// Close forwards to the ark reader's Close. The ark Close() returns nothing,
// so we adapt it to always return nil.
func (r *arkStreamReaderAdapter) Close() error {
	if r == nil || r.reader == nil {
		return nil
	}
	r.reader.Close()
	return nil
}

View File

@@ -1,337 +0,0 @@
package llm
import (
"context"
"errors"
"fmt"
"strings"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses"
)
// ArkResponsesMessage describes one input message for a Responses call.
//
// Responsibility boundaries:
// 1. Expresses the role plus multimodal content (text / image);
// 2. Does not generate business prompts;
// 3. Does not validate fields of the output JSON.
type ArkResponsesMessage struct {
	Role        string
	Text        string
	ImageURL    string
	ImageDetail string
}

// ArkResponsesOptions describes generation options for a Responses call.
type ArkResponsesOptions struct {
	Model           string
	Temperature     float64
	MaxOutputTokens int
	Thinking        ThinkingMode
	TextFormat      string
}

// ArkResponsesUsage passes token usage through in a provider-neutral shape.
type ArkResponsesUsage struct {
	InputTokens  int64
	OutputTokens int64
	TotalTokens  int64
}

// ArkResponsesResult is the unified output structure of an Ark Responses call.
type ArkResponsesResult struct {
	Text             string
	Status           string
	IncompleteReason string
	ErrorCode        string
	ErrorMessage     string
	Usage            *ArkResponsesUsage
}

// ArkResponsesClient is the unified model gateway over the Ark SDK Responses API.
type ArkResponsesClient struct {
	model  string
	client *arkruntime.Client
}
// NewArkResponsesClient creates an Ark SDK Responses client.
//
// Notes:
// 1. An empty model returns nil, meaning the capability is not enabled;
// 2. An empty baseURL falls back to the SDK default endpoint;
// 3. Only creates the client — no connectivity probing is performed.
func NewArkResponsesClient(apiKey string, baseURL string, model string) *ArkResponsesClient {
	model = strings.TrimSpace(model)
	if model == "" {
		return nil
	}
	options := make([]arkruntime.ConfigOption, 0, 1)
	if strings.TrimSpace(baseURL) != "" {
		options = append(options, arkruntime.WithBaseUrl(strings.TrimSpace(baseURL)))
	}
	return &ArkResponsesClient{
		model:  model,
		client: arkruntime.NewClientWithApiKey(strings.TrimSpace(apiKey), options...),
	}
}
// GenerateText performs one non-streaming Responses call and extracts the text.
// On a "failed" status it returns the partial result together with an error so
// callers still see error codes / usage for diagnostics.
func (c *ArkResponsesClient) GenerateText(ctx context.Context, messages []ArkResponsesMessage, options ArkResponsesOptions) (*ArkResponsesResult, error) {
	req, err := c.buildRequest(messages, options)
	if err != nil {
		return nil, err
	}
	resp, err := c.client.CreateResponses(ctx, req)
	if err != nil {
		return nil, err
	}
	result := buildArkResponsesResult(resp)
	if result.Status == "failed" {
		if result.ErrorMessage != "" {
			return result, fmt.Errorf("ark responses failed: %s", result.ErrorMessage)
		}
		return result, errors.New("ark responses failed")
	}
	if strings.TrimSpace(result.Text) == "" {
		return result, FormatEmptyResponseError("ark_responses")
	}
	return result, nil
}
// GenerateArkResponsesJSON calls Responses first, then parses the text into a
// JSON structure T. The raw result is returned in all non-nil cases so callers
// can keep the original output for logging or fallback.
func GenerateArkResponsesJSON[T any](ctx context.Context, client *ArkResponsesClient, messages []ArkResponsesMessage, options ArkResponsesOptions) (*T, *ArkResponsesResult, error) {
	if client == nil {
		return nil, nil, errors.New("ark responses client is not ready")
	}
	result, err := client.GenerateText(ctx, messages, options)
	if err != nil {
		return nil, result, err
	}
	parsed, err := ParseJSONObject[T](result.Text)
	if err != nil {
		return nil, result, err
	}
	return parsed, result, nil
}
// buildRequest assembles a Responses API request from unified messages/options.
//
// Steps:
// 1. Validate the client is ready and messages are non-empty;
// 2. Resolve the model name (per-call option wins over the client default);
// 3. Convert each message into a Responses input item;
// 4. Apply temperature / max tokens / thinking / text format only when set.
func (c *ArkResponsesClient) buildRequest(messages []ArkResponsesMessage, options ArkResponsesOptions) (*responses.ResponsesRequest, error) {
	if c == nil || c.client == nil {
		return nil, errors.New("ark responses client is not ready")
	}
	if len(messages) == 0 {
		return nil, errors.New("ark responses messages is empty")
	}
	modelName := strings.TrimSpace(options.Model)
	if modelName == "" {
		modelName = c.model
	}
	if modelName == "" {
		return nil, errors.New("ark responses model is empty")
	}
	inputItems := make([]*responses.InputItem, 0, len(messages))
	for idx := range messages {
		item, err := buildInputItem(messages[idx])
		if err != nil {
			return nil, fmt.Errorf("build ark responses message[%d] failed: %w", idx, err)
		}
		inputItems = append(inputItems, item)
	}
	request := &responses.ResponsesRequest{
		Model: modelName,
		Input: &responses.ResponsesInput{
			Union: &responses.ResponsesInput_ListValue{
				ListValue: &responses.InputItemList{ListValue: inputItems},
			},
		},
	}
	if options.Temperature > 0 {
		request.Temperature = float64Ptr(options.Temperature)
	}
	if options.MaxOutputTokens > 0 {
		request.MaxOutputTokens = int64Ptr(int64(options.MaxOutputTokens))
	}
	// ThinkingModeDefault deliberately leaves request.Thinking unset so the
	// provider default applies.
	switch options.Thinking {
	case ThinkingModeEnabled:
		thinkingType := responses.ThinkingType_enabled
		request.Thinking = &responses.ResponsesThinking{Type: &thinkingType}
	case ThinkingModeDisabled:
		thinkingType := responses.ThinkingType_disabled
		request.Thinking = &responses.ResponsesThinking{Type: &thinkingType}
	}
	if textType, ok := parseTextType(options.TextFormat); ok {
		request.Text = &responses.ResponsesText{
			Format: &responses.TextFormat{
				Type: textType,
			},
		}
	}
	return request, nil
}
// buildInputItem converts one ArkResponsesMessage into a Responses input item,
// emitting a text part and/or an image part. A message with neither is an error.
func buildInputItem(message ArkResponsesMessage) (*responses.InputItem, error) {
	role, ok := parseMessageRole(message.Role)
	if !ok {
		return nil, fmt.Errorf("unsupported message role: %s", strings.TrimSpace(message.Role))
	}
	content := make([]*responses.ContentItem, 0, 2)
	if text := strings.TrimSpace(message.Text); text != "" {
		content = append(content, &responses.ContentItem{
			Union: &responses.ContentItem_Text{
				Text: &responses.ContentItemText{
					Type: responses.ContentItemType_input_text,
					Text: text,
				},
			},
		})
	}
	if imageURL := strings.TrimSpace(message.ImageURL); imageURL != "" {
		image := &responses.ContentItemImage{
			Type:     responses.ContentItemType_input_image,
			ImageUrl: stringPtr(imageURL),
		}
		// Detail is optional; unknown values are simply omitted.
		if detail, ok := parseImageDetail(message.ImageDetail); ok {
			image.Detail = &detail
		}
		content = append(content, &responses.ContentItem{
			Union: &responses.ContentItem_Image{
				Image: image,
			},
		})
	}
	if len(content) == 0 {
		return nil, errors.New("message content is empty")
	}
	return &responses.InputItem{
		Union: &responses.InputItem_InputMessage{
			InputMessage: &responses.ItemInputMessage{
				Role:    role,
				Content: content,
			},
		},
	}, nil
}
// buildArkResponsesResult flattens the SDK response object into the unified
// ArkResponsesResult; a nil response yields an empty (but non-nil) result.
func buildArkResponsesResult(resp *responses.ResponseObject) *ArkResponsesResult {
	if resp == nil {
		return &ArkResponsesResult{}
	}
	result := &ArkResponsesResult{
		Text:   extractArkResponsesText(resp),
		Status: strings.TrimSpace(resp.GetStatus().String()),
	}
	if details := resp.GetIncompleteDetails(); details != nil {
		result.IncompleteReason = strings.TrimSpace(details.GetReason())
	}
	if responseErr := resp.GetError(); responseErr != nil {
		result.ErrorCode = strings.TrimSpace(responseErr.GetCode())
		result.ErrorMessage = strings.TrimSpace(responseErr.GetMessage())
	}
	if usage := resp.GetUsage(); usage != nil {
		result.Usage = &ArkResponsesUsage{
			InputTokens:  usage.GetInputTokens(),
			OutputTokens: usage.GetOutputTokens(),
			TotalTokens:  usage.GetTotalTokens(),
		}
	}
	return result
}
// extractArkResponsesText joins all non-empty text parts of all output messages
// with newlines, skipping non-message output items.
func extractArkResponsesText(resp *responses.ResponseObject) string {
	if resp == nil {
		return ""
	}
	textParts := make([]string, 0, 2)
	for _, outputItem := range resp.GetOutput() {
		outputMessage := outputItem.GetOutputMessage()
		if outputMessage == nil {
			continue
		}
		for _, contentItem := range outputMessage.GetContent() {
			text := strings.TrimSpace(contentItem.GetText().GetText())
			if text == "" {
				continue
			}
			textParts = append(textParts, text)
		}
	}
	return strings.TrimSpace(strings.Join(textParts, "\n"))
}
// parseMessageRole maps a free-form role string (case/space-insensitive) to
// the SDK enum; the bool reports whether the role is supported.
func parseMessageRole(raw string) (responses.MessageRole_Enum, bool) {
	switch strings.ToLower(strings.TrimSpace(raw)) {
	case "user":
		return responses.MessageRole_user, true
	case "system":
		return responses.MessageRole_system, true
	case "developer":
		return responses.MessageRole_developer, true
	case "assistant":
		return responses.MessageRole_assistant, true
	default:
		return responses.MessageRole_unspecified, false
	}
}

// parseImageDetail maps an image detail string to the SDK enum; unknown values
// return (auto, false) so the caller can simply omit the field.
func parseImageDetail(raw string) (responses.ContentItemImageDetail_Enum, bool) {
	switch strings.ToLower(strings.TrimSpace(raw)) {
	case "high":
		return responses.ContentItemImageDetail_high, true
	case "low":
		return responses.ContentItemImageDetail_low, true
	case "auto":
		return responses.ContentItemImageDetail_auto, true
	default:
		return responses.ContentItemImageDetail_auto, false
	}
}

// parseTextType maps a text-format string to the SDK enum; empty or unknown
// values return (unspecified, false) so no format is forced on the request.
func parseTextType(raw string) (responses.TextType_Enum, bool) {
	switch strings.ToLower(strings.TrimSpace(raw)) {
	case "":
		return responses.TextType_unspecified, false
	case "text":
		return responses.TextType_text, true
	case "json_object":
		return responses.TextType_json_object, true
	default:
		return responses.TextType_unspecified, false
	}
}
// stringPtr returns a pointer to value, for the SDK's optional string fields.
func stringPtr(value string) *string {
	return &value
}

// float64Ptr returns a pointer to value, for the SDK's optional float fields.
func float64Ptr(value float64) *float64 {
	return &value
}

// int64Ptr returns a pointer to value, for the SDK's optional int fields.
func int64Ptr(value int64) *int64 {
	return &value
}

View File

@@ -1,217 +0,0 @@
package llm
import (
"context"
"errors"
"fmt"
"strings"
"github.com/cloudwego/eino/schema"
)
// ThinkingMode expresses how the caller wants reasoning configured for a call.
//
// Responsibility boundaries:
// 1. Only expresses "how the caller wants the reasoning mode configured";
// 2. Is not bound to any specific vendor's parameter enum;
// 3. Translating it into ark / OpenAI / other provider options is the job of
//    the adapter layers.
type ThinkingMode string

const (
	ThinkingModeDefault  ThinkingMode = "default"
	ThinkingModeEnabled  ThinkingMode = "enabled"
	ThinkingModeDisabled ThinkingMode = "disabled"
)

// GenerateOptions is the unified set of model-call options.
//
// Design intent:
// 1. Collapse the parameters every skill/worker keeps re-passing into one struct;
// 2. Let upper layers express "what I want" instead of re-building options;
// 3. Do not try to cover every provider parameter yet — only the common knobs.
type GenerateOptions struct {
	Temperature float64
	MaxTokens   int
	Thinking    ThinkingMode
	Metadata    map[string]any
}
// TextResult is the unified text-generation result.
//
// Responsibility boundaries:
// 1. Text holds the model's final plain-text output;
// 2. Usage holds this call's token usage for later unified accounting;
// 3. Not responsible for JSON parsing or business-field mapping.
type TextResult struct {
	Text  string
	Usage *schema.TokenUsage
	// FinishReason passes through the provider's stop reason so upper layers
	// can tell whether output was truncated (e.g. "length").
	FinishReason string
}

// StreamReader abstracts "a streaming source you can Recv from chunk by chunk".
//
// We avoid depending on any concrete SDK reader type: the skeleton is still
// converging, and ark, an OpenAI-compatible layer, or any other provider can
// be adapted onto this minimal interface later.
type StreamReader interface {
	Recv() (*schema.Message, error)
	Close() error
}

// TextGenerateFunc is the unified adapter signature for text generation.
type TextGenerateFunc func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error)

// StreamGenerateFunc is the unified adapter signature for streaming generation.
type StreamGenerateFunc func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error)
// Client is the unified model-client facade.
//
// Responsibility boundaries:
// 1. Funnels every caller's "model call intent" through one entry point;
// 2. Handles shared parameter validation, empty-response defense, and
//    GenerateJSON reuse;
// 3. Does not write prompts, does not implement business fallbacks, and holds
//    no vendor SDK details directly.
type Client struct {
	generateText TextGenerateFunc
	streamText   StreamGenerateFunc
}

// NewClient creates the unified model client from the two adapter functions.
func NewClient(generateText TextGenerateFunc, streamText StreamGenerateFunc) *Client {
	return &Client{
		generateText: generateText,
		streamText:   streamText,
	}
}
// GenerateText performs one unified text generation.
//
// Responsibility boundaries:
// 1. Performs the minimal necessary input validation;
// 2. Uniformly intercepts common failures such as "model returned nothing";
// 3. Does not assemble business prompts or map text into business structures.
func (c *Client) GenerateText(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error) {
	// Safe on a nil receiver: Go allows methods on nil pointers.
	if c == nil || c.generateText == nil {
		return nil, errors.New("llm client is not ready")
	}
	if len(messages) == 0 {
		return nil, errors.New("llm messages is empty")
	}
	result, err := c.generateText(ctx, messages, options)
	if err != nil {
		return nil, err
	}
	if result == nil {
		return nil, errors.New("llm result is nil")
	}
	if strings.TrimSpace(result.Text) == "" {
		return nil, errors.New("llm returned empty text")
	}
	return result, nil
}
// GenerateJSON runs unified text generation, then unified JSON parsing.
//
// Design notes:
// 1. Collapses the common "Generate -> extract JSON -> unmarshal" chain;
// 2. Upper layers only care about the business structure, no parsing boilerplate;
// 3. Returns parsed + rawResult so metrics and fallbacks keep the original text.
func GenerateJSON[T any](ctx context.Context, client *Client, messages []*schema.Message, options GenerateOptions) (*T, *TextResult, error) {
	result, err := client.GenerateText(ctx, messages, options)
	if err != nil {
		return nil, nil, err
	}
	parsed, err := ParseJSONObject[T](result.Text)
	if err != nil {
		return nil, result, err
	}
	return parsed, result, nil
}
// Stream opens the unified streaming entry point.
//
// Responsibility boundaries:
// 1. Only exposes "streaming generation capability" to upper layers;
// 2. Does not translate chunks to the OpenAI protocol — that lives in stream/;
// 3. Does not accumulate full text or persist token accounting.
func (c *Client) Stream(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error) {
	if c == nil || c.streamText == nil {
		return nil, errors.New("llm stream client is not ready")
	}
	if len(messages) == 0 {
		return nil, errors.New("llm messages is empty")
	}
	return c.streamText(ctx, messages, options)
}
// BuildSystemUserMessages assembles the common "system + history + user"
// message list.
//
// Design notes:
// 1. Captures the most stable message layout once, trimming per-domain boilerplate;
// 2. Only assembles the slice — prompt generation stays with the caller;
// 3. Reused across agent / memory and other capability domains.
// Blank (whitespace-only) system or user prompts are omitted entirely.
func BuildSystemUserMessages(systemPrompt string, history []*schema.Message, userPrompt string) []*schema.Message {
	out := make([]*schema.Message, 0, len(history)+2)
	if strings.TrimSpace(systemPrompt) != "" {
		out = append(out, schema.SystemMessage(systemPrompt))
	}
	// Appending an empty history slice is a no-op, so no length guard needed.
	out = append(out, history...)
	if strings.TrimSpace(userPrompt) != "" {
		out = append(out, schema.UserMessage(userPrompt))
	}
	return out
}
// CloneUsage returns an independent copy of usage so later accumulation at
// multiple sites never shares one pointer; nil in, nil out.
func CloneUsage(usage *schema.TokenUsage) *schema.TokenUsage {
	if usage == nil {
		return nil
	}
	clone := new(schema.TokenUsage)
	*clone = *usage
	return clone
}
// MergeUsage merges two usage fragments.
//
// Merge policy:
// 1. For "different stream fragments of the same call", take the larger value;
// 2. For "accumulating independent calls", the upper layer should add
//    explicitly instead of using this function;
// 3. This function only applies to converging chunked usage of one call.
func MergeUsage(base *schema.TokenUsage, incoming *schema.TokenUsage) *schema.TokenUsage {
	if incoming == nil {
		return CloneUsage(base)
	}
	if base == nil {
		return CloneUsage(incoming)
	}
	merged := *base
	if incoming.PromptTokens > merged.PromptTokens {
		merged.PromptTokens = incoming.PromptTokens
	}
	if incoming.CompletionTokens > merged.CompletionTokens {
		merged.CompletionTokens = incoming.CompletionTokens
	}
	if incoming.TotalTokens > merged.TotalTokens {
		merged.TotalTokens = incoming.TotalTokens
	}
	if incoming.PromptTokenDetails.CachedTokens > merged.PromptTokenDetails.CachedTokens {
		merged.PromptTokenDetails.CachedTokens = incoming.PromptTokenDetails.CachedTokens
	}
	if incoming.CompletionTokensDetails.ReasoningTokens > merged.CompletionTokensDetails.ReasoningTokens {
		merged.CompletionTokensDetails.ReasoningTokens = incoming.CompletionTokensDetails.ReasoningTokens
	}
	return &merged
}
// FormatEmptyResponseError builds the shared "model returned an empty result"
// error message for the given scene; a blank scene is reported as "unknown".
func FormatEmptyResponseError(scene string) error {
	name := strings.TrimSpace(scene)
	if name == "" {
		name = "unknown"
	}
	return fmt.Errorf("模型在 %s 场景返回空结果", name)
}

View File

@@ -1,112 +0,0 @@
package llm
import (
"encoding/json"
"errors"
"fmt"
"strings"
)
// ParseJSONObject extracts the outermost JSON object from a model reply and
// unmarshals it into T.
//
// Responsibility boundaries:
// 1. Tolerates explanatory text / markdown code fences around the JSON;
// 2. Extracts the outermost object and deserializes it into the target type;
// 3. Leaves business-field validation to the caller.
func ParseJSONObject[T any](raw string) (*T, error) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return nil, errors.New("模型返回为空,无法解析 JSON")
	}
	payload := ExtractJSONObject(trimmed)
	if payload == "" {
		return nil, fmt.Errorf("模型返回中未找到 JSON 对象: %s", truncateForError(trimmed))
	}
	result := new(T)
	if err := json.Unmarshal([]byte(payload), result); err != nil {
		return nil, fmt.Errorf("JSON 解析失败: %w", err)
	}
	return result, nil
}
// ExtractJSONObject pulls the first complete JSON object out of mixed text.
//
// Design notes:
// 1. LLMs often wrap results like "here is the result: {...}";
// 2. A string/escape-aware brace counter is used instead of a regex so deeply
//    nested objects are not cut short;
// 3. Only objects are extracted, not arrays — current contracts are objects.
func ExtractJSONObject(text string) string {
	clean := trimMarkdownCodeFence(strings.TrimSpace(text))
	start := strings.Index(clean, "{")
	if start < 0 {
		return ""
	}
	var (
		depth    int
		inString bool
		escaped  bool
	)
	for i := start; i < len(clean); i++ {
		c := clean[i]
		switch {
		case escaped:
			escaped = false
		case c == '\\' && inString:
			escaped = true
		case c == '"':
			inString = !inString
		case inString:
			// Ignore braces inside string literals.
		case c == '{':
			depth++
		case c == '}':
			depth--
			if depth == 0 {
				return clean[start : i+1]
			}
		}
	}
	// Unbalanced braces: no complete object found.
	return ""
}
// trimMarkdownCodeFence strips a surrounding markdown code fence
// (```json ... ```) so the inner text survives unchanged.
//
// Steps:
// 1. If the text does not start with ``` there is nothing to strip;
// 2. Drop the opening fence line (```json or bare ```);
// 3. Drop a trailing ``` line when present, keeping interior newlines intact
//    so the JSON's line structure is not disturbed.
func trimMarkdownCodeFence(text string) string {
	trimmed := strings.TrimSpace(text)
	if !strings.HasPrefix(trimmed, "```") {
		return trimmed
	}
	// strings.Split always returns at least one element, so lines[1:] is safe
	// (the previous len(lines) == 0 guard was unreachable).
	lines := strings.Split(trimmed, "\n")
	body := lines[1:]
	if len(body) > 0 && strings.TrimSpace(body[len(body)-1]) == "```" {
		body = body[:len(body)-1]
	}
	return strings.TrimSpace(strings.Join(body, "\n"))
}
func truncateForError(text string) string {
if len(text) <= 160 {
return text
}
return text[:160] + "..."
}

View File

@@ -1,85 +0,0 @@
package chunk
import (
"context"
"fmt"
"strings"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)
// TextChunker is the default fixed-window text chunker; it is stateless.
type TextChunker struct{}

// NewTextChunker creates a TextChunker.
func NewTextChunker() *TextChunker {
	return &TextChunker{}
}
// Chunk performs fixed-window chunking on the document text.
//
// Steps:
// 1. Normalize whitespace first so empty chunks never reach the vector store;
// 2. Slide a window of chunk_size with chunk_overlap between windows;
// 3. Each chunk inherits the source metadata plus its chunk order.
// A blank document produces (nil, nil); a blank document ID is an error.
func (c *TextChunker) Chunk(_ context.Context, doc core.SourceDocument, opt core.ChunkOption) ([]core.Chunk, error) {
	if strings.TrimSpace(doc.ID) == "" {
		return nil, fmt.Errorf("empty document id")
	}
	text := strings.TrimSpace(doc.Text)
	if text == "" {
		return nil, nil
	}
	// Normalize chunking parameters to safe values.
	if opt.ChunkSize <= 0 {
		opt.ChunkSize = 400
	}
	if opt.ChunkOverlap < 0 {
		opt.ChunkOverlap = 0
	}
	// Overlap must stay below the window, otherwise the scan never advances;
	// clamp it to a fifth of the window.
	if opt.ChunkOverlap >= opt.ChunkSize {
		opt.ChunkOverlap = opt.ChunkSize / 5
	}
	// Work in runes so multi-byte characters are never split mid-sequence.
	runes := []rune(text)
	step := opt.ChunkSize - opt.ChunkOverlap
	if step <= 0 {
		step = opt.ChunkSize
	}
	result := make([]core.Chunk, 0, len(runes)/step+1)
	order := 0
	for start := 0; start < len(runes); start += step {
		end := start + opt.ChunkSize
		if end > len(runes) {
			end = len(runes)
		}
		chunkText := strings.TrimSpace(string(runes[start:end]))
		if chunkText == "" {
			continue
		}
		metadata := cloneMap(doc.Metadata)
		metadata["chunk_order"] = order
		result = append(result, core.Chunk{
			ID:         fmt.Sprintf("%s#%d", doc.ID, order),
			DocumentID: doc.ID,
			Text:       chunkText,
			Order:      order,
			Metadata:   metadata,
		})
		order++
		// The final window reached the end; an overlapped tail re-scan would
		// only duplicate content.
		if end == len(runes) {
			break
		}
	}
	return result, nil
}
// cloneMap returns an independent shallow copy of src. A nil or empty input
// yields a fresh, non-nil empty map so callers can write into it safely.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for key, value := range src {
		dst[key] = value
	}
	return dst
}

View File

@@ -1,113 +0,0 @@
package config
import "github.com/spf13/viper"
// Config is the runtime configuration for RAG Core: global switch and store
// selection, embedding, reranking, chunking, retrieval timeouts, and Milvus
// connection settings. Defaults are filled in by LoadFromViper.
type Config struct {
	Enabled                bool
	Store                  string
	TopK                   int
	Threshold              float64
	EmbedProvider          string
	EmbedModel             string
	EmbedBaseURL           string
	EmbedTimeoutMS         int
	EmbedDimension         int
	RerankerEnabled        bool
	RerankerProvider       string
	RerankerTimeoutMS      int
	ChunkSize              int
	ChunkOverlap           int
	RetrieveTimeoutMS      int
	MilvusAddress          string
	MilvusToken            string
	MilvusDBName           string
	MilvusCollectionName   string
	MilvusMetricType       string
	MilvusRequestTimeoutMS int
}
// LoadFromViper reads the rag.* configuration keys and fills in defaults for
// any unset value.
func LoadFromViper() Config {
	cfg := Config{
		Enabled:                viper.GetBool("rag.enabled"),
		Store:                  viper.GetString("rag.store"),
		TopK:                   viper.GetInt("rag.topK"),
		Threshold:              viper.GetFloat64("rag.threshold"),
		EmbedProvider:          viper.GetString("rag.embed.provider"),
		EmbedModel:             viper.GetString("rag.embed.model"),
		EmbedBaseURL:           viper.GetString("rag.embed.baseURL"),
		EmbedTimeoutMS:         viper.GetInt("rag.embed.timeoutMs"),
		EmbedDimension:         viper.GetInt("rag.embed.dimension"),
		RerankerEnabled:        viper.GetBool("rag.reranker.enabled"),
		RerankerProvider:       viper.GetString("rag.reranker.provider"),
		RerankerTimeoutMS:      viper.GetInt("rag.reranker.timeoutMs"),
		ChunkSize:              viper.GetInt("rag.ingest.chunkSize"),
		ChunkOverlap:           viper.GetInt("rag.ingest.chunkOverlap"),
		RetrieveTimeoutMS:      viper.GetInt("rag.retrieve.timeoutMs"),
		MilvusAddress:          viper.GetString("rag.milvus.address"),
		MilvusToken:            viper.GetString("rag.milvus.token"),
		MilvusDBName:           viper.GetString("rag.milvus.dbName"),
		MilvusCollectionName:   viper.GetString("rag.milvus.collectionName"),
		MilvusMetricType:       viper.GetString("rag.milvus.metricType"),
		MilvusRequestTimeoutMS: viper.GetInt("rag.milvus.requestTimeoutMs"),
	}
	if cfg.Store == "" {
		cfg.Store = "inmemory"
	}
	if cfg.TopK <= 0 {
		cfg.TopK = 8
	}
	if cfg.Threshold < 0 {
		cfg.Threshold = 0
	}
	if cfg.EmbedProvider == "" {
		cfg.EmbedProvider = "mock"
	}
	if cfg.EmbedBaseURL == "" {
		// Falls back to the agent's base URL when no dedicated embed URL is set.
		cfg.EmbedBaseURL = viper.GetString("agent.baseURL")
	}
	if cfg.EmbedTimeoutMS <= 0 {
		cfg.EmbedTimeoutMS = 1200
	}
	if cfg.EmbedDimension <= 0 {
		cfg.EmbedDimension = 1024
	}
	if cfg.RerankerProvider == "" {
		cfg.RerankerProvider = "noop"
	}
	if cfg.RerankerTimeoutMS <= 0 {
		cfg.RerankerTimeoutMS = 1200
	}
	if cfg.ChunkSize <= 0 {
		cfg.ChunkSize = 400
	}
	if cfg.ChunkOverlap < 0 {
		// NOTE(review): unlike the other defaults this only fires for negative
		// values, so an unset (0) overlap never receives the 80 default —
		// confirm whether 0 is intentionally treated as a valid overlap.
		cfg.ChunkOverlap = 80
	}
	if cfg.RetrieveTimeoutMS <= 0 {
		cfg.RetrieveTimeoutMS = 1500
	}
	if cfg.MilvusAddress == "" {
		cfg.MilvusAddress = "http://localhost:19530"
	}
	if cfg.MilvusToken == "" {
		cfg.MilvusToken = "root:Milvus"
	}
	if cfg.MilvusCollectionName == "" {
		cfg.MilvusCollectionName = "smartflow_rag_chunks"
	}
	if cfg.MilvusMetricType == "" {
		cfg.MilvusMetricType = "COSINE"
	}
	if cfg.MilvusRequestTimeoutMS <= 0 {
		cfg.MilvusRequestTimeoutMS = 1500
	}
	return cfg
}

View File

@@ -1,17 +0,0 @@
package core
import "errors"
var (
	// ErrInvalidQuery indicates a retrieve request lacks a valid query.
	ErrInvalidQuery = errors.New("invalid query")
	// ErrInvalidTopK indicates an invalid topK value.
	ErrInvalidTopK = errors.New("invalid top_k")
	// ErrNilDependency indicates a required pipeline dependency was not injected.
	ErrNilDependency = errors.New("nil dependency")
)

const (
	// FallbackReasonRerankFailed marks a degradation after a rerank failure.
	FallbackReasonRerankFailed = "RERANK_FAILED"
)

View File

@@ -1,38 +0,0 @@
package core
import "context"
// Chunker splits document text into chunks.
type Chunker interface {
	Chunk(ctx context.Context, doc SourceDocument, opt ChunkOption) ([]Chunk, error)
}

// Embedder turns texts into vectors; action distinguishes the embedding
// purpose (presumably ingest vs. query — confirm against implementations).
type Embedder interface {
	Embed(ctx context.Context, texts []string, action string) ([][]float32, error)
}

// Retriever recalls candidate chunks for a request.
type Retriever interface {
	Retrieve(ctx context.Context, req RetrieveRequest) ([]ScoredChunk, error)
}

// Reranker re-orders candidates for a query, keeping at most topK.
type Reranker interface {
	Rerank(ctx context.Context, query string, candidates []ScoredChunk, topK int) ([]ScoredChunk, error)
}

// VectorStore reads and writes vector rows.
type VectorStore interface {
	Upsert(ctx context.Context, rows []VectorRow) error
	Search(ctx context.Context, req VectorSearchRequest) ([]ScoredVectorRow, error)
	Delete(ctx context.Context, ids []string) error
	Get(ctx context.Context, ids []string) ([]VectorRow, error)
}

// CorpusAdapter maps business corpora into unified documents/filters.
type CorpusAdapter interface {
	Name() string
	BuildIngestDocuments(ctx context.Context, input any) ([]SourceDocument, error)
	BuildRetrieveFilter(ctx context.Context, req any) (map[string]any, error)
}

View File

@@ -1,190 +0,0 @@
package core
import (
"context"
"errors"
"fmt"
"log"
"sort"
"strings"
)
// ObserveLevel is the severity level of an observation event.
type ObserveLevel string

const (
	ObserveLevelInfo  ObserveLevel = "info"
	ObserveLevelWarn  ObserveLevel = "warn"
	ObserveLevelError ObserveLevel = "error"
)

// ObserveEvent describes one unified observation event.
//
// Responsibility boundaries:
// 1. Only carries structured runtime information of the RAG infra;
// 2. Is not bound to a concrete logging, metrics, or tracing implementation;
// 3. Field contents should stay stable to ease later platform integration.
type ObserveEvent struct {
	Level     ObserveLevel
	Component string
	Operation string
	Fields    map[string]any
}

// Observer is the minimal observation interface of the RAG infra.
//
// Responsibility boundaries:
// 1. Consumes structured events;
// 2. Never decides whether business logic keeps running;
// 3. No implementation may destabilize the main execution path.
type Observer interface {
	Observe(ctx context.Context, event ObserveEvent)
}

// ObserverFunc adapts a plain function to the Observer interface.
type ObserverFunc func(ctx context.Context, event ObserveEvent)

// Observe invokes the function; a nil function is a silent no-op.
func (f ObserverFunc) Observe(ctx context.Context, event ObserveEvent) {
	if f == nil {
		return
	}
	f(ctx, event)
}

// NewNopObserver returns a no-op implementation, the fallback before any
// unified observation platform is wired in.
func NewNopObserver() Observer {
	return ObserverFunc(func(context.Context, ObserveEvent) {})
}
// NewLoggerObserver returns a standard-log adapter.
//
// Notes:
// 1. Before a unified logging platform exists, print the structured fields stably;
// 2. Once unified logger/metrics/tracing arrives, only this injected Observer
//    needs replacing;
// 3. Output stays single-line to match the existing log style.
// A nil logger falls back to log.Default().
func NewLoggerObserver(logger *log.Logger) Observer {
	if logger == nil {
		logger = log.Default()
	}
	return &loggerObserver{logger: logger}
}

// loggerObserver writes events to a standard library logger.
type loggerObserver struct {
	logger *log.Logger
}
// Observe renders one event as a single "rag key=value ..." log line.
// Context fields are merged first, event fields override them, and keys are
// sorted so output is deterministic and diff-friendly.
func (o *loggerObserver) Observe(ctx context.Context, event ObserveEvent) {
	if o == nil || o.logger == nil {
		return
	}
	// Normalize level/component/operation to non-empty values.
	level := strings.TrimSpace(string(event.Level))
	if level == "" {
		level = string(ObserveLevelInfo)
	}
	component := strings.TrimSpace(event.Component)
	if component == "" {
		component = "unknown"
	}
	operation := strings.TrimSpace(event.Operation)
	if operation == "" {
		operation = "unknown"
	}
	// Start from the context-accumulated fields, then let event fields win.
	fields := ObserveFieldsFromContext(ctx)
	for key, value := range event.Fields {
		key = strings.TrimSpace(key)
		if key == "" || !shouldKeepObserveField(value) {
			continue
		}
		fields[key] = value
	}
	parts := []string{
		"rag",
		fmt.Sprintf("level=%s", level),
		fmt.Sprintf("component=%s", component),
		fmt.Sprintf("operation=%s", operation),
	}
	// Sort keys for stable output.
	keys := make([]string, 0, len(fields))
	for key := range fields {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	for _, key := range keys {
		parts = append(parts, fmt.Sprintf("%s=%v", key, fields[key]))
	}
	o.logger.Print(strings.Join(parts, " "))
}
// observeFieldsContextKey is the private context key for accumulated fields.
type observeFieldsContextKey struct{}

// WithObserveFields attaches shared observation fields to the context so
// downstream components can reuse them.
//
// Steps:
// 1. Read fields already on the context so Runtime / Pipeline / Store can add
//    information layer by layer;
// 2. Later writes override same-named older values, so downstream sees the
//    freshest semantics;
// 3. Only "meaningful" fields are kept, avoiding piles of empty values in logs.
func WithObserveFields(ctx context.Context, fields map[string]any) context.Context {
	if len(fields) == 0 {
		return ctx
	}
	if ctx == nil {
		ctx = context.Background()
	}
	merged := ObserveFieldsFromContext(ctx)
	for key, value := range fields {
		key = strings.TrimSpace(key)
		if key == "" || !shouldKeepObserveField(value) {
			continue
		}
		merged[key] = value
	}
	if len(merged) == 0 {
		return ctx
	}
	return context.WithValue(ctx, observeFieldsContextKey{}, merged)
}

// ObserveFieldsFromContext extracts the observation fields accumulated on the
// context. It always returns a fresh map, so callers may mutate it freely.
func ObserveFieldsFromContext(ctx context.Context) map[string]any {
	if ctx == nil {
		return map[string]any{}
	}
	raw, ok := ctx.Value(observeFieldsContextKey{}).(map[string]any)
	if !ok || len(raw) == 0 {
		return map[string]any{}
	}
	result := make(map[string]any, len(raw))
	for key, value := range raw {
		result[key] = value
	}
	return result
}
// ClassifyErrorCode 统一把常见错误压缩为稳定错误码,便于后续接入全局观测平台。
func ClassifyErrorCode(err error) string {
switch {
case err == nil:
return ""
case errors.Is(err, context.DeadlineExceeded):
return "DEADLINE_EXCEEDED"
case errors.Is(err, context.Canceled):
return "CANCELED"
default:
return "RAG_ERROR"
}
}
// shouldKeepObserveField reports whether a field value is worth logging:
// untyped nils and blank strings are dropped, every other value is kept.
func shouldKeepObserveField(value any) bool {
	switch v := value.(type) {
	case nil:
		return false
	case string:
		return strings.TrimSpace(v) != ""
	default:
		return true
	}
}

View File

@@ -1,366 +0,0 @@
package core
import (
"context"
"errors"
"fmt"
"log"
"runtime/debug"
"strings"
"time"
)
// Default pipeline parameters (topK, score threshold, chunk window/overlap).
// NOTE(review): not referenced in the visible portion of this file — confirm
// usage further down before removing.
const (
	defaultTopK       = 8
	defaultThreshold  = 0
	defaultChunkSize  = 400
	defaultChunkOvLap = 80
)
// Pipeline is the RAG Core orchestrator.
//
// Responsibility boundaries:
// 1. Unifies the chunk/embed/retrieve/rerank flow;
// 2. Owns the failure-degradation semantics;
// 3. Carries no concrete business semantics (those come from CorpusAdapter).
type Pipeline struct {
	chunker  Chunker
	embedder Embedder
	store    VectorStore
	reranker Reranker
	logger   *log.Logger
	observer Observer
}

// NewPipeline wires the pipeline with its four stages plus default
// logger/observer implementations.
func NewPipeline(chunker Chunker, embedder Embedder, store VectorStore, reranker Reranker) *Pipeline {
	return &Pipeline{
		chunker:  chunker,
		embedder: embedder,
		store:    store,
		reranker: reranker,
		logger:   log.Default(),
		observer: NewNopObserver(),
	}
}

// SetLogger replaces the pipeline logger; nil receiver or logger is ignored.
func (p *Pipeline) SetLogger(logger *log.Logger) {
	if p == nil || logger == nil {
		return
	}
	p.logger = logger
}

// SetObserver replaces the unified observer; nil receiver or observer is ignored.
func (p *Pipeline) SetObserver(observer Observer) {
	if p == nil || observer == nil {
		return
	}
	p.observer = observer
}
// Ingest runs the unified ingestion flow for a raw corpus input.
//
// Steps:
//  1. The CorpusAdapter first maps the input to standardized documents so all
//     corpus entry points behave the same.
//  2. Chunking and embedding are then shared here instead of being duplicated
//     per business module.
//  3. The rows are upserted in one shot; on failure the error is returned and
//     the caller decides whether to retry.
func (p *Pipeline) Ingest(
	ctx context.Context,
	corpus CorpusAdapter,
	input any,
	opt IngestOption,
) (result *IngestResult, err error) {
	// Panics from third-party stages are converted into a regular error.
	defer p.recoverExecutionPanic(ctx, "ingest", &err)
	if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil {
		return nil, ErrNilDependency
	}
	if corpus == nil {
		return nil, errors.New("nil corpus adapter")
	}
	docs, err := corpus.BuildIngestDocuments(ctx, input)
	if err != nil {
		return nil, err
	}
	return p.IngestDocuments(ctx, corpus.Name(), docs, opt)
}
// IngestDocuments runs the unified ingestion flow for documents that are
// already standardized.
//
// Responsibilities:
//  1. Handle documents that have already been mapped by a CorpusAdapter.
//  2. Own the shared chunking, embedding and upsert steps.
//  3. Do no business input parsing here, so Runtime never has to re-build
//     documents just to learn a document_id.
func (p *Pipeline) IngestDocuments(
	ctx context.Context,
	corpusName string,
	docs []SourceDocument,
	opt IngestOption,
) (result *IngestResult, err error) {
	defer p.recoverExecutionPanic(ctx, "ingest_documents", &err)
	if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil {
		return nil, ErrNilDependency
	}
	if len(docs) == 0 {
		return &IngestResult{DocumentCount: 0, ChunkCount: 0}, nil
	}
	chunkOpt := normalizeChunkOption(opt.Chunk)
	chunks := make([]Chunk, 0, len(docs)*2)
	for _, doc := range docs {
		// 1. Chunk each document independently; abort on failure so a
		//    half-finished batch is never written.
		docChunks, chunkErr := p.chunker.Chunk(ctx, doc, chunkOpt)
		if chunkErr != nil {
			return nil, chunkErr
		}
		chunks = append(chunks, docChunks...)
	}
	if len(chunks) == 0 {
		return &IngestResult{DocumentCount: len(docs), ChunkCount: 0}, nil
	}
	texts := make([]string, 0, len(chunks))
	for _, chunk := range chunks {
		texts = append(texts, chunk.Text)
	}
	// Action selects the embedding variant; "add" is the ingest default.
	action := strings.TrimSpace(opt.Action)
	if action == "" {
		action = "add"
	}
	vectors, err := p.embedder.Embed(ctx, texts, action)
	if err != nil {
		return nil, err
	}
	if len(vectors) != len(chunks) {
		return nil, fmt.Errorf("embedding result length mismatch: chunks=%d vectors=%d", len(chunks), len(vectors))
	}
	rows := make([]VectorRow, 0, len(chunks))
	now := time.Now()
	for i, chunk := range chunks {
		// Stamp corpus/document/order into metadata so retrieval can filter
		// and regroup chunks without an extra lookup.
		metadata := cloneMap(chunk.Metadata)
		metadata["corpus"] = corpusName
		metadata["document_id"] = chunk.DocumentID
		metadata["chunk_order"] = chunk.Order
		rows = append(rows, VectorRow{
			ID:        chunk.ID,
			Vector:    vectors[i],
			Text:      chunk.Text,
			Metadata:  metadata,
			CreatedAt: now,
			UpdatedAt: now,
		})
	}
	if err = p.store.Upsert(ctx, rows); err != nil {
		return nil, err
	}
	return &IngestResult{
		DocumentCount: len(docs),
		ChunkCount:    len(chunks),
	}, nil
}
// Retrieve runs the unified retrieval flow.
//
// Steps:
//  1. Embed the query and run the vector search.
//  2. Apply the score threshold to drop low-quality candidates.
//  3. Optionally rerank; on rerank failure degrade to the original ordering
//     and record the fallback.
func (p *Pipeline) Retrieve(
	ctx context.Context,
	corpus CorpusAdapter,
	req RetrieveRequest,
) (result *RetrieveResult, err error) {
	defer p.recoverExecutionPanic(ctx, "retrieve", &err)
	if p == nil || p.embedder == nil || p.store == nil {
		return nil, ErrNilDependency
	}
	query := strings.TrimSpace(req.Query)
	if query == "" {
		return nil, ErrInvalidQuery
	}
	topK := req.TopK
	if topK <= 0 {
		topK = defaultTopK
	}
	threshold := req.Threshold
	if threshold < 0 {
		threshold = defaultThreshold
	}
	filter := cloneMap(req.Filter)
	if corpus != nil {
		// 1. Merge the corpus filter first so recall never crosses corpora.
		corpusFilter, err := corpus.BuildRetrieveFilter(ctx, req.CorpusInput)
		if err != nil {
			return nil, err
		}
		filter = mergeMap(filter, corpusFilter)
		filter["corpus"] = corpus.Name()
	}
	// "search" is the query-side embedding action default.
	action := strings.TrimSpace(req.Action)
	if action == "" {
		action = "search"
	}
	vectors, err := p.embedder.Embed(ctx, []string{query}, action)
	if err != nil {
		return nil, err
	}
	if len(vectors) != 1 {
		return nil, fmt.Errorf("embedding query length mismatch: %d", len(vectors))
	}
	scoredRows, err := p.store.Search(ctx, VectorSearchRequest{
		QueryVector: vectors[0],
		TopK:        topK,
		Filter:      filter,
	})
	if err != nil {
		return nil, err
	}
	// RawCount reports the pre-threshold candidate count for observability.
	rawCount := len(scoredRows)
	candidates := make([]ScoredChunk, 0, len(scoredRows))
	for _, row := range scoredRows {
		if row.Score < threshold {
			continue
		}
		candidates = append(candidates, ScoredChunk{
			ChunkID:    row.Row.ID,
			DocumentID: asString(row.Row.Metadata["document_id"]),
			Text:       row.Row.Text,
			Score:      row.Score,
			Metadata:   cloneMap(row.Row.Metadata),
		})
	}
	result = &RetrieveResult{
		Items:        candidates,
		RawCount:     rawCount,
		FallbackUsed: false,
	}
	if len(candidates) == 0 || p.reranker == nil {
		return result, nil
	}
	reranked, rerankErr := p.reranker.Rerank(ctx, query, candidates, topK)
	if rerankErr != nil {
		// 2. A rerank failure must not kill the request: degrade to the
		//    original score order and report the fallback.
		result.FallbackUsed = true
		result.FallbackReason = FallbackReasonRerankFailed
		if p.observer != nil {
			p.observer.Observe(ctx, ObserveEvent{
				Level:     ObserveLevelWarn,
				Component: "pipeline",
				Operation: "rerank_fallback",
				Fields: map[string]any{
					"status":          "fallback",
					"fallback_reason": FallbackReasonRerankFailed,
					"candidate_count": len(candidates),
					"top_k":           topK,
					"error":           rerankErr,
					"error_code":      ClassifyErrorCode(rerankErr),
				},
			})
		} else if p.logger != nil {
			p.logger.Printf("rag rerank fallback: reason=%s err=%v", FallbackReasonRerankFailed, rerankErr)
		}
		return result, nil
	}
	result.Items = reranked
	return result, nil
}
// Delete removes the vector rows with the given IDs by delegating to the
// store; a nil pipeline or store is treated as a no-op rather than an error.
func (p *Pipeline) Delete(ctx context.Context, ids []string) error {
	if p != nil && p.store != nil {
		return p.store.Delete(ctx, ids)
	}
	return nil
}
// recoverExecutionPanic converts a panic inside an orchestration entry point
// into a regular error written through errPtr, reporting the details to the
// observer (preferred) or the logger. It must be invoked via defer so that
// recover() is effective.
func (p *Pipeline) recoverExecutionPanic(ctx context.Context, operation string, errPtr *error) {
	recovered := recover()
	if recovered == nil || errPtr == nil {
		return
	}
	panicErr := fmt.Errorf("rag pipeline panic recovered: operation=%s panic=%v", operation, recovered)
	*errPtr = panicErr
	// 1. The pipeline is the shared boundary around chunk/embed/store/rerank;
	//    a panic in a third-party dependency must not kill the caller.
	// 2. After recovering we continue with error semantics so runtime/service
	//    can decide whether to degrade, retry or just log.
	// 3. The stack trace only goes to the observation layer, never into the
	//    returned error, to keep business-facing messages short.
	if p != nil && p.observer != nil {
		p.observer.Observe(ctx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "pipeline",
			Operation: operation + "_panic_recovered",
			Fields: map[string]any{
				"status":     "failed",
				"panic":      fmt.Sprintf("%v", recovered),
				"panic_type": fmt.Sprintf("%T", recovered),
				"error":      panicErr,
				"error_code": ClassifyErrorCode(panicErr),
				"stack":      string(debug.Stack()),
			},
		})
		return
	}
	if p != nil && p.logger != nil {
		p.logger.Printf("rag pipeline panic recovered: operation=%s panic=%v stack=%s", operation, recovered, string(debug.Stack()))
	}
}
// normalizeChunkOption clamps chunking parameters to safe values: a
// non-positive size falls back to defaultChunkSize, a negative overlap
// becomes zero, and an overlap that would reach the chunk size is reduced
// (first to defaultChunkOvLap, then to a fifth of the size) so chunking
// always makes forward progress.
func normalizeChunkOption(opt ChunkOption) ChunkOption {
	size := opt.ChunkSize
	if size <= 0 {
		size = defaultChunkSize
	}
	overlap := opt.ChunkOverlap
	switch {
	case overlap < 0:
		overlap = 0
	case overlap >= size:
		overlap = defaultChunkOvLap
		if overlap >= size {
			overlap = size / 5
		}
	}
	return ChunkOption{ChunkSize: size, ChunkOverlap: overlap}
}
// cloneMap returns a shallow copy of src; a nil or empty input yields a
// fresh, non-nil empty map so callers can always write into the result.
func cloneMap(src map[string]any) map[string]any {
	out := make(map[string]any, len(src))
	for k, v := range src {
		out[k] = v
	}
	return out
}

// mergeMap writes every entry of ext into base (mutating base) and returns
// it; a nil base is replaced with a fresh map first.
func mergeMap(base map[string]any, ext map[string]any) map[string]any {
	if base == nil {
		base = make(map[string]any)
	}
	for k, v := range ext {
		base[k] = v
	}
	return base
}

// asString renders an arbitrary value with default %v formatting; nil maps
// to the empty string.
func asString(v any) string {
	if v == nil {
		return ""
	}
	return fmt.Sprint(v)
}

View File

@@ -1,94 +0,0 @@
package core
import "time"
// SourceDocument is the unified corpus document model.
//
// Scope:
//  1. Describes only raw documents that can be chunked and indexed.
//  2. Carries no business workflow state.
type SourceDocument struct {
	ID        string
	Text      string
	Title     string
	Metadata  map[string]any
	CreatedAt time.Time
}

// Chunk is one standardized chunking result.
type Chunk struct {
	ID         string
	DocumentID string // ID of the owning SourceDocument
	Text       string
	Order      int // position of this chunk within its document
	Metadata   map[string]any
}

// ChunkOption controls chunking parameters.
type ChunkOption struct {
	ChunkSize    int
	ChunkOverlap int
}

// IngestOption controls ingestion parameters.
type IngestOption struct {
	Chunk ChunkOption
	// Action selects the embedding variant (add/update/search).
	Action string
}

// IngestResult summarizes one ingestion run.
type IngestResult struct {
	DocumentCount int
	ChunkCount    int
}

// RetrieveRequest is the unified retrieval request.
type RetrieveRequest struct {
	Query       string
	TopK        int
	Threshold   float64
	Action      string
	Filter      map[string]any
	CorpusInput any // corpus-specific filter input, interpreted by the adapter
}

// ScoredChunk is one unified recall result.
type ScoredChunk struct {
	ChunkID    string
	DocumentID string
	Text       string
	Score      float64
	Metadata   map[string]any
}

// RetrieveResult summarizes one retrieval run.
type RetrieveResult struct {
	Items          []ScoredChunk
	RawCount       int // candidate count before threshold filtering
	FallbackUsed   bool
	FallbackReason string
}

// VectorRow is the standard vector-store row.
type VectorRow struct {
	ID        string
	Vector    []float32
	Text      string
	Metadata  map[string]any
	CreatedAt time.Time
	UpdatedAt time.Time
}

// VectorSearchRequest is a vector search request.
type VectorSearchRequest struct {
	QueryVector []float32
	TopK        int
	Filter      map[string]any
}

// ScoredVectorRow is one vector search hit.
type ScoredVectorRow struct {
	Row   VectorRow
	Score float64
}

View File

@@ -1,13 +0,0 @@
package corpus
import (
"crypto/sha256"
"encoding/hex"
"strings"
)
// hashLikeText produces a short, stable fingerprint for text that should be
// treated as "the same" after lower-casing and trimming: the first 8 bytes of
// the SHA-256 digest, hex-encoded (16 characters).
func hashLikeText(text string) string {
	canonical := strings.TrimSpace(strings.ToLower(text))
	digest := sha256.Sum256([]byte(canonical))
	return hex.EncodeToString(digest[:8])
}

View File

@@ -1,158 +0,0 @@
package corpus
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)
// memoryCorpusName is the corpus tag stamped into memory vector metadata.
const memoryCorpusName = "memory"

// MemoryIngestItem is one memory-corpus ingest item.
type MemoryIngestItem struct {
	MemoryID         int64
	UserID           int
	ConversationID   string
	AssistantID      string
	RunID            string
	MemoryType       string
	Title            string
	Content          string
	Confidence       float64
	Importance       float64
	SensitivityLevel int
	IsExplicit       bool
	Status           string
	TTLAt            *time.Time // optional expiry; serialized as RFC3339 in metadata
	CreatedAt        *time.Time // optional creation time; defaults to now at ingest
}

// MemoryRetrieveInput is the memory retrieval filter input.
type MemoryRetrieveInput struct {
	UserID         int
	ConversationID string
	AssistantID    string
	RunID          string
	MemoryType     string
}

// MemoryCorpus is the memory corpus adapter.
type MemoryCorpus struct{}

// NewMemoryCorpus constructs the memory corpus adapter.
func NewMemoryCorpus() *MemoryCorpus {
	return &MemoryCorpus{}
}

// Name returns the stable corpus identifier ("memory").
func (c *MemoryCorpus) Name() string {
	return memoryCorpusName
}
// BuildIngestDocuments maps memory ingest items onto unified source documents.
//
// Behavior:
//  1. Any item with a non-positive UserID fails the whole batch.
//  2. Items whose content is blank after trimming are skipped.
//  3. The document ID is "memory:<MemoryID>"; when MemoryID is missing a
//     content-hash based fallback ID is derived so the document stays stable.
func (c *MemoryCorpus) BuildIngestDocuments(_ context.Context, input any) ([]core.SourceDocument, error) {
	items, err := toMemoryItems(input)
	if err != nil {
		return nil, err
	}
	result := make([]core.SourceDocument, 0, len(items))
	for _, item := range items {
		if item.UserID <= 0 {
			return nil, errors.New("memory ingest item user_id is invalid")
		}
		text := strings.TrimSpace(item.Content)
		if text == "" {
			continue
		}
		docID := fmt.Sprintf("memory:%d", item.MemoryID)
		if item.MemoryID <= 0 {
			// Fallback ID for items not yet persisted with a memory ID.
			docID = fmt.Sprintf("memory:uid:%d:%s", item.UserID, hashLikeText(text))
		}
		// All retrieval-relevant attributes are carried as metadata so the
		// store can filter without a side lookup.
		metadata := map[string]any{
			"user_id":           item.UserID,
			"conversation_id":   strings.TrimSpace(item.ConversationID),
			"assistant_id":      strings.TrimSpace(item.AssistantID),
			"run_id":            strings.TrimSpace(item.RunID),
			"memory_type":       strings.TrimSpace(strings.ToLower(item.MemoryType)),
			"title":             strings.TrimSpace(item.Title),
			"confidence":        item.Confidence,
			"importance":        item.Importance,
			"sensitivity_level": item.SensitivityLevel,
			"is_explicit":       item.IsExplicit,
			"status":            strings.TrimSpace(item.Status),
		}
		if item.TTLAt != nil {
			metadata["ttl_at"] = item.TTLAt.Format(time.RFC3339)
		}
		createdAt := time.Now()
		if item.CreatedAt != nil {
			createdAt = *item.CreatedAt
		}
		result = append(result, core.SourceDocument{
			ID:        docID,
			Text:      text,
			Title:     strings.TrimSpace(item.Title),
			Metadata:  metadata,
			CreatedAt: createdAt,
		})
	}
	return result, nil
}
// BuildRetrieveFilter converts a MemoryRetrieveInput (value or non-nil
// pointer) into a metadata filter map. UserID is mandatory; the remaining
// identifiers are added only when non-blank, and memory_type is lower-cased.
func (c *MemoryCorpus) BuildRetrieveFilter(_ context.Context, req any) (map[string]any, error) {
	var input MemoryRetrieveInput
	switch typed := req.(type) {
	case MemoryRetrieveInput:
		input = typed
	case *MemoryRetrieveInput:
		if typed == nil {
			// A typed nil pointer is not distinguishable from a bad input.
			return nil, errors.New("invalid memory retrieve input")
		}
		input = *typed
	case nil:
		return nil, errors.New("memory retrieve input is nil")
	default:
		return nil, errors.New("invalid memory retrieve input")
	}
	if input.UserID <= 0 {
		return nil, errors.New("memory retrieve user_id is invalid")
	}
	filter := map[string]any{
		"user_id": input.UserID,
	}
	addIfPresent := func(key, value string) {
		if trimmed := strings.TrimSpace(value); trimmed != "" {
			filter[key] = trimmed
		}
	}
	addIfPresent("conversation_id", input.ConversationID)
	addIfPresent("assistant_id", input.AssistantID)
	addIfPresent("run_id", input.RunID)
	addIfPresent("memory_type", strings.ToLower(input.MemoryType))
	return filter, nil
}
// toMemoryItems normalizes the accepted ingest input shapes (single item,
// pointer, slice, or slice of pointers) into a flat []MemoryIngestItem.
// Nil pointers inside a pointer slice are silently skipped.
func toMemoryItems(input any) ([]MemoryIngestItem, error) {
	switch typed := input.(type) {
	case MemoryIngestItem:
		return []MemoryIngestItem{typed}, nil
	case *MemoryIngestItem:
		if typed == nil {
			return nil, errors.New("memory ingest item is nil")
		}
		return []MemoryIngestItem{*typed}, nil
	case []MemoryIngestItem:
		return typed, nil
	case []*MemoryIngestItem:
		flattened := make([]MemoryIngestItem, 0, len(typed))
		for _, item := range typed {
			if item != nil {
				flattened = append(flattened, *item)
			}
		}
		return flattened, nil
	default:
		return nil, errors.New("invalid memory ingest input")
	}
}

View File

@@ -1,163 +0,0 @@
package corpus
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)
// webCorpusName is the corpus tag stamped into web vector metadata.
const webCorpusName = "web"

// WebIngestItem is one web-corpus ingest item.
type WebIngestItem struct {
	URL         string
	Title       string
	Content     string
	Snippet     string
	Domain      string
	QueryID     string
	SessionID   string
	PublishedAt *time.Time // optional; serialized as RFC3339 in metadata
	FetchedAt   *time.Time // optional; also used as the document creation time
	SourceRank  int
}

// WebRetrieveInput is the web retrieval filter input.
type WebRetrieveInput struct {
	QueryID   string
	SessionID string
	Domain    string
}

// WebCorpus is the web corpus adapter.
type WebCorpus struct{}

// NewWebCorpus constructs the web corpus adapter.
func NewWebCorpus() *WebCorpus {
	return &WebCorpus{}
}

// Name returns the stable corpus identifier ("web").
func (c *WebCorpus) Name() string {
	return webCorpusName
}
// BuildIngestDocuments maps web ingest items onto unified source documents.
//
// Behavior:
//  1. An item with an empty URL fails the whole batch.
//  2. Items whose assembled text (title/snippet/content) is blank are skipped.
//  3. The document ID is derived from a hash of URL plus text, so re-fetching
//     identical content produces the same document.
func (c *WebCorpus) BuildIngestDocuments(_ context.Context, input any) ([]core.SourceDocument, error) {
	items, err := toWebItems(input)
	if err != nil {
		return nil, err
	}
	result := make([]core.SourceDocument, 0, len(items))
	for _, item := range items {
		url := strings.TrimSpace(item.URL)
		if url == "" {
			return nil, errors.New("web ingest item url is empty")
		}
		mainText := buildWebText(item)
		if strings.TrimSpace(mainText) == "" {
			continue
		}
		docID := fmt.Sprintf("web:%s", hashLikeText(url+"|"+mainText))
		// query_id / session_id allow retrieval to stay scoped to one search.
		metadata := map[string]any{
			"url":         url,
			"domain":      strings.TrimSpace(item.Domain),
			"query_id":    strings.TrimSpace(item.QueryID),
			"session_id":  strings.TrimSpace(item.SessionID),
			"source_rank": item.SourceRank,
		}
		if item.PublishedAt != nil {
			metadata["published_at"] = item.PublishedAt.Format(time.RFC3339)
		}
		if item.FetchedAt != nil {
			metadata["fetched_at"] = item.FetchedAt.Format(time.RFC3339)
		}
		createdAt := time.Now()
		if item.FetchedAt != nil {
			createdAt = *item.FetchedAt
		}
		result = append(result, core.SourceDocument{
			ID:        docID,
			Text:      mainText,
			Title:     strings.TrimSpace(item.Title),
			Metadata:  metadata,
			CreatedAt: createdAt,
		})
	}
	return result, nil
}
// BuildRetrieveFilter converts a WebRetrieveInput (value or non-nil pointer)
// into a metadata filter. At least one of query_id / session_id must be
// present so retrieval never mixes unrelated questions; domain is optional.
func (c *WebCorpus) BuildRetrieveFilter(_ context.Context, req any) (map[string]any, error) {
	var input WebRetrieveInput
	switch typed := req.(type) {
	case WebRetrieveInput:
		input = typed
	case *WebRetrieveInput:
		if typed == nil {
			// A typed nil pointer is not distinguishable from a bad input.
			return nil, errors.New("invalid web retrieve input")
		}
		input = *typed
	case nil:
		return nil, errors.New("web retrieve input is nil")
	default:
		return nil, errors.New("invalid web retrieve input")
	}
	queryID := strings.TrimSpace(input.QueryID)
	sessionID := strings.TrimSpace(input.SessionID)
	if queryID == "" && sessionID == "" {
		return nil, errors.New("web retrieve filter requires query_id or session_id")
	}
	filter := make(map[string]any)
	if queryID != "" {
		filter["query_id"] = queryID
	}
	if sessionID != "" {
		filter["session_id"] = sessionID
	}
	if domain := strings.TrimSpace(input.Domain); domain != "" {
		filter["domain"] = domain
	}
	return filter, nil
}
// toWebItems normalizes the accepted ingest input shapes (single item,
// pointer, slice, or slice of pointers) into a flat []WebIngestItem.
// Nil pointers inside a pointer slice are silently skipped.
func toWebItems(input any) ([]WebIngestItem, error) {
	switch typed := input.(type) {
	case WebIngestItem:
		return []WebIngestItem{typed}, nil
	case *WebIngestItem:
		if typed == nil {
			return nil, errors.New("web ingest item is nil")
		}
		return []WebIngestItem{*typed}, nil
	case []WebIngestItem:
		return typed, nil
	case []*WebIngestItem:
		flattened := make([]WebIngestItem, 0, len(typed))
		for _, item := range typed {
			if item != nil {
				flattened = append(flattened, *item)
			}
		}
		return flattened, nil
	default:
		return nil, errors.New("invalid web ingest input")
	}
}

// buildWebText assembles the indexable text from title, snippet and content,
// keeping only non-blank parts and separating them with blank lines.
func buildWebText(item WebIngestItem) string {
	sections := make([]string, 0, 3)
	for _, candidate := range []string{item.Title, item.Snippet, item.Content} {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			sections = append(sections, trimmed)
		}
	}
	return strings.Join(sections, "\n\n")
}

View File

@@ -1,208 +0,0 @@
package embed
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
openaiembedding "github.com/cloudwego/eino-ext/libs/acl/openai"
einoembedding "github.com/cloudwego/eino/components/embedding"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
arkmodel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)
// EinoConfig describes the runtime parameters of the Eino embedding client.
type EinoConfig struct {
	APIKey    string
	BaseURL   string
	Model     string
	TimeoutMS int // per-call timeout in milliseconds; <=0 uses the built-in default
	Dimension int // requested vector width; <=0 lets the provider decide
}

// EinoEmbedder is the Eino-based embedding adapter.
//
// Notes:
//  1. It exposes unified [][]float32 results to infra/rag and hides the
//     differences between the underlying SDKs.
//  2. Text embedding keeps using the current stable OpenAI-compatible path so
//     unrelated models are unaffected.
//  3. Multimodal embedding models go through the native Ark
//     `/embeddings/multimodal` endpoint, because vision models are not
//     compatible with the standard `/embeddings` one.
type EinoEmbedder struct {
	textClient       einoembedding.Embedder // OpenAI-compatible text path
	multimodalClient *arkruntime.Client     // native Ark multimodal path
	model            string
	timeout          time.Duration
	dimension        int
}
// NewEinoEmbedder builds the embedding adapter for the configured model,
// choosing the native Ark multimodal client for doubao-embedding-vision-*
// models and the OpenAI-compatible client for everything else.
// APIKey and Model are mandatory; TimeoutMS defaults to 1200ms.
func NewEinoEmbedder(ctx context.Context, cfg EinoConfig) (*EinoEmbedder, error) {
	if strings.TrimSpace(cfg.APIKey) == "" {
		return nil, errors.New("eino embedder api key is empty")
	}
	if strings.TrimSpace(cfg.Model) == "" {
		return nil, errors.New("eino embedder model is empty")
	}
	timeout := 1200 * time.Millisecond
	if cfg.TimeoutMS > 0 {
		timeout = time.Duration(cfg.TimeoutMS) * time.Millisecond
	}
	baseURL := normalizeEmbeddingBaseURL(cfg.BaseURL)
	model := strings.TrimSpace(cfg.Model)
	httpClient := &http.Client{Timeout: timeout}
	// 1. Models like `doubao-embedding-vision-*` do not support the standard
	//    `/embeddings` endpoint.
	// 2. They are routed to the native Ark multimodal embedding API instead of
	//    relying on incorrect endpoint concatenation.
	// 3. The text path is kept so regular text embedding models retain their
	//    existing behavior.
	if isMultimodalEmbeddingModel(model) {
		arkOptions := []arkruntime.ConfigOption{
			arkruntime.WithHTTPClient(httpClient),
		}
		if baseURL != "" {
			arkOptions = append(arkOptions, arkruntime.WithBaseUrl(baseURL))
		}
		return &EinoEmbedder{
			multimodalClient: arkruntime.NewClientWithApiKey(
				strings.TrimSpace(cfg.APIKey),
				arkOptions...,
			),
			model:     model,
			timeout:   timeout,
			dimension: cfg.Dimension,
		}, nil
	}
	clientCfg := &openaiembedding.EmbeddingConfig{
		APIKey:     strings.TrimSpace(cfg.APIKey),
		BaseURL:    baseURL,
		Model:      model,
		HTTPClient: httpClient,
	}
	if cfg.Dimension > 0 {
		clientCfg.Dimensions = &cfg.Dimension
	}
	client, err := openaiembedding.NewEmbeddingClient(ctx, clientCfg)
	if err != nil {
		return nil, err
	}
	return &EinoEmbedder{
		textClient: client,
		model:      model,
		timeout:    timeout,
		dimension:  cfg.Dimension,
	}, nil
}
// Embed turns the texts into float32 vectors via whichever client was wired
// at construction time; the action parameter is ignored by this adapter.
// SDK panics are recovered into errors and the configured timeout is applied
// to the whole call.
func (e *EinoEmbedder) Embed(ctx context.Context, texts []string, _ string) (result [][]float32, err error) {
	if e == nil {
		return nil, errors.New("eino embedder is not initialized")
	}
	if len(texts) == 0 {
		return nil, nil
	}
	callCtx := ctx
	cancel := func() {}
	if e.timeout > 0 {
		callCtx, cancel = context.WithTimeout(ctx, e.timeout)
	}
	defer cancel()
	// 1. A panic inside a third-party SDK must not escape into the RAG path.
	// 2. Recover at the model-call boundary and convert it into an error so
	//    the upper layer can degrade.
	// 3. This keeps both the memory write path and the agent reply path from
	//    being killed by a failed vector sync.
	defer func() {
		if recovered := recover(); recovered != nil {
			err = fmt.Errorf("eino embedder panic recovered: %v", recovered)
			result = nil
		}
	}()
	if e.multimodalClient != nil {
		return e.embedTextsWithMultimodalAPI(callCtx, texts)
	}
	if e.textClient == nil {
		return nil, errors.New("eino embedder client is not initialized")
	}
	vectors, err := e.textClient.EmbedStrings(callCtx, texts, einoembedding.WithModel(e.model))
	if err != nil {
		return nil, err
	}
	// The text client returns float64 vectors; convert to the unified float32.
	result = make([][]float32, 0, len(vectors))
	for _, vector := range vectors {
		converted := make([]float32, len(vector))
		for i, value := range vector {
			converted[i] = float32(value)
		}
		result = append(result, converted)
	}
	return result, nil
}
// embedTextsWithMultimodalAPI embeds each text through the native Ark
// multimodal endpoint, one request per text, returning vectors in input order.
func (e *EinoEmbedder) embedTextsWithMultimodalAPI(ctx context.Context, texts []string) ([][]float32, error) {
	if e.multimodalClient == nil {
		return nil, errors.New("eino multimodal embedder client is not initialized")
	}
	vectors := make([][]float32, 0, len(texts))
	for _, text := range texts {
		text := text // take a stable address for the request pointer field
		req := arkmodel.MultiModalEmbeddingRequest{
			Model: e.model,
			Input: []arkmodel.MultimodalEmbeddingInput{
				{
					Type: arkmodel.MultiModalEmbeddingInputTypeText,
					Text: &text,
				},
			},
		}
		if e.dimension > 0 {
			req.Dimensions = &e.dimension
		}
		// 1. Ark's multimodal embedding request models "one sample composed of
		//    several parts".
		// 2. This path only sends text, so each text is submitted separately
		//    to avoid merging distinct texts into one multimodal sample.
		// 3. If batched multimodal embedding is ever needed, add a dedicated
		//    batch API instead of quietly changing the semantics here.
		resp, err := e.multimodalClient.CreateMultiModalEmbeddings(ctx, req)
		if err != nil {
			return nil, err
		}
		converted := make([]float32, len(resp.Data.Embedding))
		copy(converted, resp.Data.Embedding)
		vectors = append(vectors, converted)
	}
	return vectors, nil
}
// isMultimodalEmbeddingModel reports whether the model name belongs to the
// doubao-embedding-vision-* family, which must be served through the Ark
// multimodal endpoint instead of the standard /embeddings one.
func isMultimodalEmbeddingModel(model string) bool {
	normalized := strings.ToLower(strings.TrimSpace(model))
	return strings.HasPrefix(normalized, "doubao-embedding-vision-")
}

// normalizeEmbeddingBaseURL reduces a configured base URL to the Ark service
// root.
//
//  1. Configuration should carry the service root, not a concrete embedding
//     endpoint.
//  2. Two common misconfigurations are tolerated: URLs ending in
//     `/embeddings` or `/embeddings/multimodal`.
//  3. The suffix is matched case-insensitively but removed case-preservingly,
//     so the SDK can append the correct endpoint without duplication.
func normalizeEmbeddingBaseURL(raw string) string {
	base := strings.TrimRight(strings.TrimSpace(raw), "/")
	if base == "" {
		return ""
	}
	lower := strings.ToLower(base)
	for _, suffix := range []string{"/embeddings/multimodal", "/embeddings"} {
		if strings.HasSuffix(lower, suffix) {
			return base[:len(base)-len(suffix)]
		}
	}
	return base
}

View File

@@ -1,46 +0,0 @@
package embed
import (
"context"
"crypto/sha256"
"encoding/binary"
"strings"
)
const defaultDim = 16
// MockEmbedder 是本地可运行的占位向量化实现。
//
// 说明:
// 1. 该实现用于开发阶段打通链路,不代表真实语义向量质量;
// 2. 后续可替换为 Eino embedding 实现,接口保持不变。
type MockEmbedder struct {
dim int
}
func NewMockEmbedder(dim int) *MockEmbedder {
if dim <= 0 {
dim = defaultDim
}
return &MockEmbedder{dim: dim}
}
func (e *MockEmbedder) Embed(_ context.Context, texts []string, _ string) ([][]float32, error) {
result := make([][]float32, 0, len(texts))
for _, text := range texts {
result = append(result, e.embedOne(text))
}
return result, nil
}
func (e *MockEmbedder) embedOne(text string) []float32 {
normalized := strings.TrimSpace(strings.ToLower(text))
sum := sha256.Sum256([]byte(normalized))
vec := make([]float32, e.dim)
for i := 0; i < e.dim; i++ {
offset := (i * 4) % len(sum)
v := binary.BigEndian.Uint32(sum[offset : offset+4])
vec[i] = float32(v%1000) / 1000
}
return vec
}

View File

@@ -1,142 +0,0 @@
package rag
import (
"context"
"fmt"
"log"
"os"
"strings"
ragchunk "github.com/LoveLosita/smartflow/backend/infra/rag/chunk"
ragconfig "github.com/LoveLosita/smartflow/backend/infra/rag/config"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
ragembed "github.com/LoveLosita/smartflow/backend/infra/rag/embed"
ragrerank "github.com/LoveLosita/smartflow/backend/infra/rag/rerank"
ragstore "github.com/LoveLosita/smartflow/backend/infra/rag/store"
)
// FactoryDeps lists the optional dependencies of the Runtime factory.
//
// Notes:
//  1. Logger is only the default sink while the project has no unified
//     logging system yet.
//  2. Observer is the real observation slot; it can later be replaced by a
//     project-level logger/metrics/tracing adapter.
//  3. Business code still only receives a Runtime and never touches the
//     low-level assembly details here.
type FactoryDeps struct {
	Logger   *log.Logger
	Observer Observer
}

// NewRuntimeFromConfig assembles a RAG Runtime from configuration.
//
// Design:
//  1. All implementation choices funnel through this factory; business code
//     never constructs stores/embedders/rerankers itself.
//  2. New providers should extend this factory instead of scattering the
//     selection logic into business modules.
//  3. Observation wiring also happens here so runtime/store/pipeline do not
//     each log on their own.
func NewRuntimeFromConfig(ctx context.Context, cfg ragconfig.Config, deps FactoryDeps) (Runtime, error) {
	logger, observer := normalizeFactoryDeps(deps)
	embedder, err := buildEmbedder(ctx, cfg)
	if err != nil {
		return nil, err
	}
	store, err := buildStore(cfg, logger, observer)
	if err != nil {
		return nil, err
	}
	reranker, err := buildReranker(cfg, observer)
	if err != nil {
		return nil, err
	}
	pipeline := core.NewPipeline(ragchunk.NewTextChunker(), embedder, store, reranker)
	pipeline.SetLogger(logger)
	pipeline.SetObserver(observer)
	return newRuntime(cfg, pipeline, observer), nil
}

// normalizeFactoryDeps fills in defaults: the standard logger, and a
// logger-backed observer when none is supplied.
func normalizeFactoryDeps(deps FactoryDeps) (*log.Logger, Observer) {
	logger := deps.Logger
	if logger == nil {
		logger = log.Default()
	}
	observer := deps.Observer
	if observer == nil {
		observer = NewLoggerObserver(logger)
	}
	return logger, observer
}

// buildEmbedder selects the embedding backend from cfg.EmbedProvider
// ("" / "mock" -> MockEmbedder, "eino" -> Eino adapter).
func buildEmbedder(ctx context.Context, cfg ragconfig.Config) (core.Embedder, error) {
	switch strings.ToLower(strings.TrimSpace(cfg.EmbedProvider)) {
	case "", "mock":
		return ragembed.NewMockEmbedder(cfg.EmbedDimension), nil
	case "eino":
		// 1. RAG embedding shares the credential source of the regular LLM
		//    path and reads ARK_API_KEY directly.
		// 2. This avoids an extra "env-name config -> env read" indirection
		//    and keeps configuration from forking.
		// 3. If separate embedding credentials are ever needed, design
		//    explicit fields instead of implicitly passing env names around.
		apiKey := strings.TrimSpace(os.Getenv("ARK_API_KEY"))
		if apiKey == "" {
			return nil, fmt.Errorf("rag embed api key is empty: env=%s", "ARK_API_KEY")
		}
		return ragembed.NewEinoEmbedder(ctx, ragembed.EinoConfig{
			APIKey:    apiKey,
			BaseURL:   cfg.EmbedBaseURL,
			Model:     cfg.EmbedModel,
			TimeoutMS: cfg.EmbedTimeoutMS,
			Dimension: cfg.EmbedDimension,
		})
	default:
		return nil, fmt.Errorf("unsupported rag embed provider: %s", cfg.EmbedProvider)
	}
}

// buildStore selects the vector store implementation from cfg.Store
// ("" / "inmemory" -> in-memory store, "milvus" -> Milvus).
func buildStore(cfg ragconfig.Config, logger *log.Logger, observer Observer) (core.VectorStore, error) {
	switch strings.ToLower(strings.TrimSpace(cfg.Store)) {
	case "", "inmemory":
		return ragstore.NewInMemoryVectorStore(), nil
	case "milvus":
		return ragstore.NewMilvusStore(ragstore.MilvusConfig{
			Address:          cfg.MilvusAddress,
			Token:            cfg.MilvusToken,
			DBName:           cfg.MilvusDBName,
			CollectionName:   cfg.MilvusCollectionName,
			RequestTimeoutMS: cfg.MilvusRequestTimeoutMS,
			Dimension:        cfg.EmbedDimension,
			MetricType:       cfg.MilvusMetricType,
			Logger:           logger,
			Observer:         observer,
		})
	default:
		return nil, fmt.Errorf("unsupported rag store: %s", cfg.Store)
	}
}

// buildReranker selects the reranker; the not-yet-implemented "eino" provider
// degrades to noop and emits a warning event so the fallback is visible.
func buildReranker(cfg ragconfig.Config, observer Observer) (core.Reranker, error) {
	if !cfg.RerankerEnabled {
		return ragrerank.NewNoopReranker(), nil
	}
	switch strings.ToLower(strings.TrimSpace(cfg.RerankerProvider)) {
	case "", "noop":
		return ragrerank.NewNoopReranker(), nil
	case "eino":
		if observer != nil {
			observer.Observe(context.Background(), ObserveEvent{
				Level:     ObserveLevelWarn,
				Component: "factory",
				Operation: "reranker_fallback",
				Fields: map[string]any{
					"provider":        "eino",
					"status":          "fallback",
					"fallback_target": "noop",
					"reason":          "reranker_not_implemented",
				},
			})
		}
		return ragrerank.NewNoopReranker(), nil
	default:
		return nil, fmt.Errorf("unsupported rag reranker provider: %s", cfg.RerankerProvider)
	}
}

View File

@@ -1,32 +0,0 @@
package rag
import (
"log"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)
// ObserveLevel re-exports the unified observation level so the startup layer
// does not depend on core details directly.
type ObserveLevel = core.ObserveLevel

const (
	ObserveLevelInfo  = core.ObserveLevelInfo
	ObserveLevelWarn  = core.ObserveLevelWarn
	ObserveLevelError = core.ObserveLevelError
)

// ObserveEvent re-exports the unified observation event.
type ObserveEvent = core.ObserveEvent

// Observer re-exports the unified observation interface.
type Observer = core.Observer

// NewNopObserver returns the no-op observer implementation.
func NewNopObserver() Observer {
	return core.NewNopObserver()
}

// NewLoggerObserver returns the standard-log observation adapter.
func NewLoggerObserver(logger *log.Logger) Observer {
	return core.NewLoggerObserver(logger)
}

View File

@@ -1,23 +0,0 @@
package rag
import (
"github.com/LoveLosita/smartflow/backend/infra/rag/chunk"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
"github.com/LoveLosita/smartflow/backend/infra/rag/embed"
"github.com/LoveLosita/smartflow/backend/infra/rag/rerank"
"github.com/LoveLosita/smartflow/backend/infra/rag/store"
)
// NewDefaultPipeline builds a RAG pipeline that runs with zero external
// dependencies.
//
// Current policy:
//  1. Default to the local MockEmbedder + InMemoryStore so the stack runs
//     without any external service.
//  2. Switching to Milvus / Eino later only swaps these dependencies; the
//     calling convention stays the same.
func NewDefaultPipeline() *core.Pipeline {
	return core.NewPipeline(
		chunk.NewTextChunker(),
		embed.NewMockEmbedder(16),
		store.NewInMemoryVectorStore(),
		rerank.NewNoopReranker(),
	)
}

View File

@@ -1,19 +0,0 @@
package rerank
import (
"context"
"errors"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)
// EinoReranker is a placeholder for a future Eino-backed reranker.
type EinoReranker struct{}

// NewEinoReranker constructs the placeholder reranker.
func NewEinoReranker() *EinoReranker {
	return new(EinoReranker)
}

// Rerank always fails with a "not implemented" error; the factory is
// expected to fall back to the noop reranker instead of using this one.
func (r *EinoReranker) Rerank(_ context.Context, _ string, _ []core.ScoredChunk, _ int) ([]core.ScoredChunk, error) {
	return nil, errors.New("eino reranker is not implemented yet")
}

View File

@@ -1,30 +0,0 @@
package rerank
import (
"context"
"sort"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)
// NoopReranker is the default reranker: it only re-sorts candidates by their
// existing score (descending, stable) and truncates to topK.
type NoopReranker struct{}

// NewNoopReranker constructs the default reranker.
func NewNoopReranker() *NoopReranker {
	return new(NoopReranker)
}

// Rerank sorts a copy of the candidates by score (descending, stable) and
// keeps at most topK entries; a non-positive topK keeps everything. Empty
// input yields nil.
func (r *NoopReranker) Rerank(_ context.Context, _ string, candidates []core.ScoredChunk, topK int) ([]core.ScoredChunk, error) {
	if len(candidates) == 0 {
		return nil, nil
	}
	ranked := append([]core.ScoredChunk(nil), candidates...)
	sort.SliceStable(ranked, func(a, b int) bool {
		return ranked[a].Score > ranked[b].Score
	})
	if topK > 0 && topK < len(ranked) {
		ranked = ranked[:topK]
	}
	return ranked, nil
}

View File

@@ -1,98 +0,0 @@
package retrieve
import (
"context"
"fmt"
"strings"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)
// VectorRetriever is the generic retriever (embed + vector search) without
// rerank or corpus handling.
type VectorRetriever struct {
	embedder core.Embedder
	store    core.VectorStore
}

// NewVectorRetriever wires an embedder and a vector store into a retriever.
func NewVectorRetriever(embedder core.Embedder, store core.VectorStore) *VectorRetriever {
	return &VectorRetriever{
		embedder: embedder,
		store:    store,
	}
}
// Retrieve embeds the query, runs a vector search, and returns
// threshold-filtered candidates. Panics from the underlying embedder/store
// are recovered into a regular error.
func (r *VectorRetriever) Retrieve(ctx context.Context, req core.RetrieveRequest) (result []core.ScoredChunk, err error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		err = fmt.Errorf("vector retriever panic recovered: %v", recovered)
	}()
	if r == nil || r.embedder == nil || r.store == nil {
		return nil, core.ErrNilDependency
	}
	query := strings.TrimSpace(req.Query)
	if query == "" {
		return nil, core.ErrInvalidQuery
	}
	topK := req.TopK
	if topK <= 0 {
		topK = 8 // local default; mirrors core's defaultTopK — TODO confirm they stay in sync
	}
	action := strings.TrimSpace(req.Action)
	if action == "" {
		action = "search"
	}
	vectors, err := r.embedder.Embed(ctx, []string{query}, action)
	if err != nil {
		return nil, err
	}
	if len(vectors) != 1 {
		return nil, fmt.Errorf("embedding query length mismatch: %d", len(vectors))
	}
	rows, err := r.store.Search(ctx, core.VectorSearchRequest{
		QueryVector: vectors[0],
		TopK:        topK,
		Filter:      req.Filter,
	})
	if err != nil {
		return nil, err
	}
	// Drop everything below the requested score threshold.
	result = make([]core.ScoredChunk, 0, len(rows))
	for _, row := range rows {
		if row.Score < req.Threshold {
			continue
		}
		result = append(result, core.ScoredChunk{
			ChunkID:    row.Row.ID,
			DocumentID: asString(row.Row.Metadata["document_id"]),
			Text:       row.Row.Text,
			Score:      row.Score,
			Metadata:   cloneMap(row.Row.Metadata),
		})
	}
	return result, nil
}
// cloneMap returns a shallow copy of src; a nil or empty input yields a
// fresh, non-nil empty map.
func cloneMap(src map[string]any) map[string]any {
	out := make(map[string]any, len(src))
	for key, value := range src {
		out[key] = value
	}
	return out
}

// asString renders an arbitrary value with default %v formatting; nil maps
// to the empty string.
func asString(v any) string {
	if v == nil {
		return ""
	}
	return fmt.Sprint(v)
}

View File

@@ -1,434 +0,0 @@
package rag
import (
"context"
"fmt"
"runtime/debug"
"strings"
"time"
ragconfig "github.com/LoveLosita/smartflow/backend/infra/rag/config"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
"github.com/LoveLosita/smartflow/backend/infra/rag/corpus"
)
type runtime struct {
cfg ragconfig.Config
pipeline *core.Pipeline
memoryCorpus *corpus.MemoryCorpus
webCorpus *corpus.WebCorpus
observer Observer
}
// newRuntime assembles a runtime around the given pipeline; a nil observer
// is replaced with a no-op implementation so observe calls are always safe.
func newRuntime(cfg ragconfig.Config, pipeline *core.Pipeline, observer Observer) Runtime {
	obs := observer
	if obs == nil {
		obs = NewNopObserver()
	}
	rt := &runtime{
		cfg:      cfg,
		pipeline: pipeline,
		observer: obs,
	}
	rt.memoryCorpus = corpus.NewMemoryCorpus()
	rt.webCorpus = corpus.NewWebCorpus()
	return rt
}
// IngestMemory is the unified entry point for memory-corpus ingestion: it
// maps the public request items onto the corpus-layer shape and delegates
// to the shared ingestion path.
func (r *runtime) IngestMemory(ctx context.Context, req MemoryIngestRequest) (result *IngestResult, err error) {
	defer r.recoverPublicPanic(ctx, req.TraceID, "memory", normalizeAction(req.Action, "add"), "ingest", &err)
	items := make([]corpus.MemoryIngestItem, len(req.Items))
	for i, src := range req.Items {
		items[i] = corpus.MemoryIngestItem{
			MemoryID:         src.MemoryID,
			UserID:           src.UserID,
			ConversationID:   src.ConversationID,
			AssistantID:      src.AssistantID,
			RunID:            src.RunID,
			MemoryType:       src.MemoryType,
			Title:            src.Title,
			Content:          src.Content,
			Confidence:       src.Confidence,
			Importance:       src.Importance,
			SensitivityLevel: src.SensitivityLevel,
			IsExplicit:       src.IsExplicit,
			Status:           src.Status,
			TTLAt:            src.TTLAt,
			CreatedAt:        src.CreatedAt,
		}
	}
	return r.ingestWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, items, req.Action)
}
// RetrieveMemory is the unified entry point for memory-corpus retrieval.
//
// A single requested memory type is pushed down into the store-level filter;
// multiple types are post-filtered here after the shared retrieval runs.
func (r *runtime) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (result *RetrieveResult, err error) {
	defer r.recoverPublicPanic(ctx, req.TraceID, "memory", normalizeAction(req.Action, "search"), "retrieve", &err)
	corpusInput := corpus.MemoryRetrieveInput{
		UserID:         req.UserID,
		ConversationID: req.ConversationID,
		AssistantID:    req.AssistantID,
		RunID:          req.RunID,
	}
	// Exactly one type can use the underlying equality filter directly.
	if len(req.MemoryTypes) == 1 {
		corpusInput.MemoryType = req.MemoryTypes[0]
	}
	result, err = r.retrieveWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, core.RetrieveRequest{
		Query:       req.Query,
		TopK:        normalizeTopK(req.TopK, r.cfg.TopK),
		Threshold:   normalizeThreshold(req.Threshold, r.cfg.Threshold),
		Action:      normalizeAction(req.Action, "search"),
		CorpusInput: corpusInput,
	})
	if err != nil {
		return nil, err
	}
	if len(req.MemoryTypes) <= 1 {
		return result, nil
	}
	// 1. The underlying filter currently supports equality conditions only,
	//    so the runtime does the multi-type second-pass filtering itself;
	// 2. this keeps the "memory_type in (...)" implementation detail from
	//    spreading into every Store;
	// 3. once unified filtering lands in the lower layers, this logic can be
	//    pushed further down.
	allowed := make(map[string]struct{}, len(req.MemoryTypes))
	for _, item := range req.MemoryTypes {
		value := strings.TrimSpace(strings.ToLower(item))
		if value == "" {
			continue
		}
		allowed[value] = struct{}{}
	}
	filtered := make([]RetrieveHit, 0, len(result.Items))
	for _, item := range result.Items {
		memoryType := strings.TrimSpace(strings.ToLower(asString(item.Metadata["memory_type"])))
		if len(allowed) > 0 {
			if _, ok := allowed[memoryType]; !ok {
				continue
			}
		}
		filtered = append(filtered, item)
	}
	result.Items = filtered
	// Re-apply the caller's TopK after the second-pass filter shrank the set.
	if req.TopK > 0 && len(result.Items) > req.TopK {
		result.Items = result.Items[:req.TopK]
	}
	return result, nil
}
// DeleteMemory removes the given document IDs' vectors from the memory
// corpus; a nil runtime/pipeline or an empty ID list is a silent no-op.
func (r *runtime) DeleteMemory(ctx context.Context, documentIDs []string) (err error) {
	defer r.recoverPublicPanic(ctx, "", "memory", "delete", "delete", &err)
	switch {
	case r == nil, r.pipeline == nil, len(documentIDs) == 0:
		return nil
	}
	return r.pipeline.Delete(ctx, documentIDs)
}
// IngestWeb is the unified entry point for web-corpus ingestion: it maps
// the public request items onto the corpus-layer shape and delegates to
// the shared ingestion path.
func (r *runtime) IngestWeb(ctx context.Context, req WebIngestRequest) (result *IngestResult, err error) {
	defer r.recoverPublicPanic(ctx, req.TraceID, "web", normalizeAction(req.Action, "add"), "ingest", &err)
	items := make([]corpus.WebIngestItem, len(req.Items))
	for i, src := range req.Items {
		items[i] = corpus.WebIngestItem{
			URL:         src.URL,
			Title:       src.Title,
			Content:     src.Content,
			Snippet:     src.Snippet,
			Domain:      src.Domain,
			QueryID:     src.QueryID,
			SessionID:   src.SessionID,
			PublishedAt: src.PublishedAt,
			FetchedAt:   src.FetchedAt,
			SourceRank:  src.SourceRank,
		}
	}
	return r.ingestWithCorpus(ctx, req.TraceID, "web", r.webCorpus, items, req.Action)
}
// RetrieveWeb is the unified entry point for web-corpus retrieval.
func (r *runtime) RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (result *RetrieveResult, err error) {
	defer r.recoverPublicPanic(ctx, req.TraceID, "web", normalizeAction(req.Action, "search"), "retrieve", &err)
	coreReq := core.RetrieveRequest{
		Query:     req.Query,
		TopK:      normalizeTopK(req.TopK, r.cfg.TopK),
		Threshold: normalizeThreshold(req.Threshold, r.cfg.Threshold),
		Action:    normalizeAction(req.Action, "search"),
	}
	coreReq.CorpusInput = corpus.WebRetrieveInput{
		QueryID:   req.QueryID,
		SessionID: req.SessionID,
		Domain:    req.Domain,
	}
	return r.retrieveWithCorpus(ctx, req.TraceID, "web", r.webCorpus, coreReq)
}
// ingestWithCorpus is the shared ingestion path for all corpora: it asks the
// adapter to build documents, pushes them through the pipeline, and emits a
// structured success/failure event with latency on every exit.
func (r *runtime) ingestWithCorpus(
	ctx context.Context,
	traceID string,
	corpusName string,
	adapter core.CorpusAdapter,
	input any,
	action string,
) (*IngestResult, error) {
	start := time.Now()
	if r == nil || r.pipeline == nil || adapter == nil {
		return nil, core.ErrNilDependency
	}
	action = normalizeAction(action, "add")
	observeCtx := newObserveContext(ctx, traceID, corpusName, action)
	docs, err := adapter.BuildIngestDocuments(observeCtx, input)
	if err != nil {
		// Failure while mapping the raw input into documents.
		r.observe(observeCtx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "runtime",
			Operation: "ingest",
			Fields: map[string]any{
				"status":      "failed",
				"latency_ms":  time.Since(start).Milliseconds(),
				"phase":       "build_documents",
				"error":       err,
				"error_code":  core.ClassifyErrorCode(err),
				"input_count": estimateInputCount(input),
			},
		})
		return nil, err
	}
	// Collect document IDs up front so callers can later delete by them.
	docIDs := make([]string, 0, len(docs))
	for _, doc := range docs {
		docIDs = append(docIDs, doc.ID)
	}
	result, err := r.pipeline.IngestDocuments(observeCtx, adapter.Name(), docs, core.IngestOption{
		Chunk: core.ChunkOption{
			ChunkSize:    r.cfg.ChunkSize,
			ChunkOverlap: r.cfg.ChunkOverlap,
		},
		Action: action,
	})
	if err != nil {
		// Failure inside the chunk/embed/store pipeline.
		r.observe(observeCtx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "runtime",
			Operation: "ingest",
			Fields: map[string]any{
				"status":         "failed",
				"latency_ms":     time.Since(start).Milliseconds(),
				"document_count": len(docs),
				"error":          err,
				"error_code":     core.ClassifyErrorCode(err),
			},
		})
		return nil, err
	}
	r.observe(observeCtx, ObserveEvent{
		Level:     ObserveLevelInfo,
		Component: "runtime",
		Operation: "ingest",
		Fields: map[string]any{
			"status":         "success",
			"latency_ms":     time.Since(start).Milliseconds(),
			"document_count": result.DocumentCount,
			"chunk_count":    result.ChunkCount,
		},
	})
	return &IngestResult{
		DocumentCount: result.DocumentCount,
		ChunkCount:    result.ChunkCount,
		DocumentIDs:   docIDs,
	}, nil
}
// retrieveWithCorpus is the shared retrieval path for all corpora: it applies
// the configured retrieval timeout, delegates to the pipeline, copies hits
// into the public RetrieveHit shape, and emits a structured event.
func (r *runtime) retrieveWithCorpus(
	ctx context.Context,
	traceID string,
	corpusName string,
	adapter core.CorpusAdapter,
	req core.RetrieveRequest,
) (*RetrieveResult, error) {
	start := time.Now()
	if r == nil || r.pipeline == nil || adapter == nil {
		return nil, core.ErrNilDependency
	}
	action := normalizeAction(req.Action, "search")
	req.Action = action
	observeCtx := newObserveContext(ctx, traceID, corpusName, action)
	// Apply the retrieve timeout only when configured; cancel defaults to a
	// no-op so the deferred call below is always safe.
	timeoutCtx := observeCtx
	cancel := func() {}
	if r.cfg.RetrieveTimeoutMS > 0 {
		timeoutCtx, cancel = context.WithTimeout(observeCtx, time.Duration(r.cfg.RetrieveTimeoutMS)*time.Millisecond)
	}
	defer cancel()
	result, err := r.pipeline.Retrieve(timeoutCtx, adapter, req)
	if err != nil {
		r.observe(observeCtx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "runtime",
			Operation: "retrieve",
			Fields: map[string]any{
				"status":     "failed",
				"latency_ms": time.Since(start).Milliseconds(),
				"query_len":  len(strings.TrimSpace(req.Query)),
				"top_k":      req.TopK,
				"threshold":  req.Threshold,
				"error":      err,
				"error_code": core.ClassifyErrorCode(err),
			},
		})
		return nil, err
	}
	// Copy the hits (including metadata) so callers cannot mutate
	// pipeline-owned state.
	items := make([]RetrieveHit, 0, len(result.Items))
	for _, item := range result.Items {
		items = append(items, RetrieveHit{
			ChunkID:    item.ChunkID,
			DocumentID: item.DocumentID,
			Text:       item.Text,
			Score:      item.Score,
			Metadata:   cloneMap(item.Metadata),
		})
	}
	r.observe(observeCtx, ObserveEvent{
		Level:     ObserveLevelInfo,
		Component: "runtime",
		Operation: "retrieve",
		Fields: map[string]any{
			"status":          "success",
			"latency_ms":      time.Since(start).Milliseconds(),
			"query_len":       len(strings.TrimSpace(req.Query)),
			"top_k":           req.TopK,
			"threshold":       req.Threshold,
			"raw_count":       result.RawCount,
			"hit_count":       len(result.Items),
			"fallback_used":   result.FallbackUsed,
			"fallback_reason": result.FallbackReason,
		},
	})
	return &RetrieveResult{
		Items:          items,
		RawCount:       result.RawCount,
		FallbackUsed:   result.FallbackUsed,
		FallbackReason: result.FallbackReason,
	}, nil
}
// observe forwards the event to the configured observer, tolerating a nil
// receiver or observer so callers never need to guard.
func (r *runtime) observe(ctx context.Context, event ObserveEvent) {
	if r != nil && r.observer != nil {
		r.observer.Observe(ctx, event)
	}
}
// recoverPublicPanic is deferred by every public Runtime method: it converts
// a panic from any lower layer into an error on the named return and emits a
// structured observation that includes the stack trace.
//
// It must be invoked directly via defer in the same goroutine — recover only
// works inside a deferred function.
func (r *runtime) recoverPublicPanic(
	ctx context.Context,
	traceID string,
	corpusName string,
	action string,
	operation string,
	errPtr *error,
) {
	recovered := recover()
	if recovered == nil || errPtr == nil {
		return
	}
	// 1. runtime is the final method surface RAG infra exposes to business
	//    code; no lower-layer panic may escape into business goroutines.
	// 2. The panic is turned into an error plus a structured observation so
	//    the out-of-control dependency layer can still be tracked down.
	// 3. The stack is kept so the root cause remains diagnosable without
	//    crashing the process — a bare "recovered" line would be useless.
	panicErr := fmt.Errorf("rag runtime panic recovered: corpus=%s operation=%s panic=%v", corpusName, operation, recovered)
	*errPtr = panicErr
	observeCtx := newObserveContext(ctx, traceID, corpusName, action)
	r.observe(observeCtx, ObserveEvent{
		Level:     ObserveLevelError,
		Component: "runtime",
		Operation: operation + "_panic_recovered",
		Fields: map[string]any{
			"status":     "failed",
			"panic":      fmt.Sprintf("%v", recovered),
			"panic_type": fmt.Sprintf("%T", recovered),
			"error":      panicErr,
			"error_code": core.ClassifyErrorCode(panicErr),
			"stack":      string(debug.Stack()),
		},
	})
}
// newObserveContext attaches corpus/action (and trace_id when present) as
// structured observation fields to the context.
func newObserveContext(ctx context.Context, traceID string, corpusName string, action string) context.Context {
	fields := map[string]any{
		"corpus": corpusName,
		"action": action,
	}
	trimmed := strings.TrimSpace(traceID)
	if trimmed != "" {
		fields["trace_id"] = trimmed
	}
	return core.WithObserveFields(ctx, fields)
}
// estimateInputCount reports how many ingest items the opaque input holds;
// unknown shapes count as zero (used only for observability).
func estimateInputCount(input any) int {
	if items, ok := input.([]corpus.MemoryIngestItem); ok {
		return len(items)
	}
	if items, ok := input.([]corpus.WebIngestItem); ok {
		return len(items)
	}
	return 0
}
// normalizeAction trims the action and substitutes the fallback when the
// result is empty.
func normalizeAction(action string, fallback string) string {
	if trimmed := strings.TrimSpace(action); trimmed != "" {
		return trimmed
	}
	return fallback
}
// normalizeTopK picks the first positive value among topK, fallback, and
// the built-in default of 8.
func normalizeTopK(topK int, fallback int) int {
	switch {
	case topK > 0:
		return topK
	case fallback > 0:
		return fallback
	default:
		return 8
	}
}
// normalizeThreshold picks the first non-negative value among threshold,
// fallback, and zero.
//
// NOTE(review): unlike normalizeTopK, a zero threshold is treated as a valid
// explicit value and does NOT fall through to the configured fallback —
// presumably because 0 means "no score filtering"; confirm with callers.
func normalizeThreshold(threshold float64, fallback float64) float64 {
	switch {
	case threshold >= 0:
		return threshold
	case fallback >= 0:
		return fallback
	default:
		return 0
	}
}
// cloneMap returns a shallow copy of src; nil or empty input yields a fresh
// empty (non-nil) map.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for k, v := range src {
		dst[k] = v
	}
	return dst
}
// asString renders any value as a whitespace-trimmed string; nil maps to "".
func asString(v any) string {
	if v == nil {
		return ""
	}
	text := fmt.Sprint(v)
	return strings.TrimSpace(text)
}

View File

@@ -1,118 +0,0 @@
package rag
import (
"context"
"time"
)
// Runtime is the only stable method surface RAG infra exposes to business code.
//
// Responsibilities:
//  1. unified ingestion and retrieval entry points for the memory/web corpora;
//  2. hiding the assembly details of the underlying Pipeline / Store /
//     Embedder / Reranker;
//  3. NOT provider search, HTML fetching, prompt injection or other
//     business-level semantics.
type Runtime interface {
	IngestMemory(ctx context.Context, req MemoryIngestRequest) (*IngestResult, error)
	RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (*RetrieveResult, error)
	DeleteMemory(ctx context.Context, documentIDs []string) error
	IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestResult, error)
	RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (*RetrieveResult, error)
}

// IngestResult summarizes one unified ingestion run.
type IngestResult struct {
	DocumentCount int
	ChunkCount    int
	DocumentIDs   []string
}

// RetrieveHit is the unified hit item exposed to business code.
type RetrieveHit struct {
	ChunkID    string
	DocumentID string
	Text       string
	Score      float64
	Metadata   map[string]any
}

// RetrieveResult summarizes one retrieval run.
type RetrieveResult struct {
	Items          []RetrieveHit
	RawCount       int
	FallbackUsed   bool
	FallbackReason string
}

// MemoryIngestItem is a single memory-corpus ingestion item.
type MemoryIngestItem struct {
	MemoryID         int64
	UserID           int
	ConversationID   string
	AssistantID      string
	RunID            string
	MemoryType       string
	Title            string
	Content          string
	Confidence       float64
	Importance       float64
	SensitivityLevel int
	IsExplicit       bool
	Status           string
	TTLAt            *time.Time
	CreatedAt        *time.Time
}

// MemoryIngestRequest describes one memory vector ingestion request.
type MemoryIngestRequest struct {
	TraceID string
	Action  string
	Items   []MemoryIngestItem
}

// MemoryRetrieveRequest describes one memory retrieval request.
type MemoryRetrieveRequest struct {
	TraceID        string
	Query          string
	TopK           int
	Threshold      float64
	Action         string
	UserID         int
	ConversationID string
	AssistantID    string
	RunID          string
	MemoryTypes    []string
}

// WebIngestItem is a single web-corpus ingestion item.
type WebIngestItem struct {
	URL         string
	Title       string
	Content     string
	Snippet     string
	Domain      string
	QueryID     string
	SessionID   string
	PublishedAt *time.Time
	FetchedAt   *time.Time
	SourceRank  int
}

// WebIngestRequest describes one web-corpus ingestion request.
type WebIngestRequest struct {
	TraceID string
	Action  string
	Items   []WebIngestItem
}

// WebRetrieveRequest describes one web retrieval request.
type WebRetrieveRequest struct {
	TraceID   string
	Query     string
	TopK      int
	Threshold float64
	Action    string
	QueryID   string
	SessionID string
	Domain    string
}

View File

@@ -1,182 +0,0 @@
package store
import (
"context"
"errors"
"fmt"
"math"
"sort"
"strings"
"sync"
"time"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)
// InMemoryVectorStore is a vector store implementation for local development.
//
// Note:
//  1. intended for development/debugging only — not recommended in production;
//  2. real environments can swap in MilvusStore; the interface is identical.
type InMemoryVectorStore struct {
	mu   sync.RWMutex
	rows map[string]core.VectorRow
}
// NewInMemoryVectorStore returns an empty, ready-to-use in-memory store.
func NewInMemoryVectorStore() *InMemoryVectorStore {
	store := &InMemoryVectorStore{}
	store.rows = map[string]core.VectorRow{}
	return store
}
// Upsert inserts or replaces rows keyed by row.ID.
//
// Timestamp rules: an existing row keeps its original CreatedAt and receives
// a fresh UpdatedAt; a new row gets CreatedAt only when the caller left it
// zero, and always gets UpdatedAt set to now.
func (s *InMemoryVectorStore) Upsert(_ context.Context, rows []core.VectorRow) error {
	if s == nil {
		return errors.New("inmemory vector store is nil")
	}
	if len(rows) == 0 {
		return nil
	}
	now := time.Now()
	s.mu.Lock()
	defer s.mu.Unlock()
	// Lazily initialize so the zero value of the struct is usable.
	if s.rows == nil {
		s.rows = make(map[string]core.VectorRow)
	}
	for _, row := range rows {
		current, exists := s.rows[row.ID]
		if exists {
			row.CreatedAt = current.CreatedAt
			row.UpdatedAt = now
		} else {
			if row.CreatedAt.IsZero() {
				row.CreatedAt = now
			}
			row.UpdatedAt = now
		}
		s.rows[row.ID] = row
	}
	return nil
}
// Search scores every stored row that passes the metadata filter by cosine
// similarity against the query vector and returns the top-K best matches.
func (s *InMemoryVectorStore) Search(_ context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
	if s == nil {
		return nil, errors.New("inmemory vector store is nil")
	}
	limit := req.TopK
	if limit <= 0 {
		limit = 8
	}
	s.mu.RLock()
	defer s.mu.RUnlock()
	matches := make([]core.ScoredVectorRow, 0, len(s.rows))
	for _, row := range s.rows {
		if !matchMetadataFilter(row.Metadata, req.Filter) {
			continue
		}
		matches = append(matches, core.ScoredVectorRow{
			Row:   row,
			Score: cosineSimilarity(req.QueryVector, row.Vector),
		})
	}
	sort.SliceStable(matches, func(i, j int) bool {
		return matches[i].Score > matches[j].Score
	})
	if len(matches) > limit {
		matches = matches[:limit]
	}
	return matches, nil
}
// Delete removes the rows with the listed IDs; unknown IDs are ignored.
func (s *InMemoryVectorStore) Delete(_ context.Context, ids []string) error {
	switch {
	case s == nil:
		return errors.New("inmemory vector store is nil")
	case len(ids) == 0:
		return nil
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, id := range ids {
		delete(s.rows, id)
	}
	return nil
}
// Get returns the stored rows for the given IDs, silently skipping misses.
func (s *InMemoryVectorStore) Get(_ context.Context, ids []string) ([]core.VectorRow, error) {
	switch {
	case s == nil:
		return nil, errors.New("inmemory vector store is nil")
	case len(ids) == 0:
		return nil, nil
	}
	s.mu.RLock()
	defer s.mu.RUnlock()
	found := make([]core.VectorRow, 0, len(ids))
	for _, id := range ids {
		if row, ok := s.rows[id]; ok {
			found = append(found, row)
		}
	}
	return found, nil
}
// cosineSimilarity computes the cosine of the angle between a and b over
// their shared prefix (the shorter length); empty input or a zero-magnitude
// vector yields 0.
func cosineSimilarity(a, b []float32) float64 {
	n := len(a)
	if len(b) < n {
		n = len(b)
	}
	if n == 0 {
		return 0
	}
	var dot, magA, magB float64
	for i := 0; i < n; i++ {
		x, y := float64(a[i]), float64(b[i])
		dot += x * y
		magA += x * x
		magB += y * y
	}
	if magA == 0 || magB == 0 {
		return 0
	}
	return dot / (math.Sqrt(magA) * math.Sqrt(magB))
}
// matchMetadataFilter reports whether metadata satisfies every equality
// condition in filter; an empty filter matches everything.
func matchMetadataFilter(metadata map[string]any, filter map[string]any) bool {
	for key, wanted := range filter {
		got, exists := metadata[key]
		if !exists || !equalAny(got, wanted) {
			return false
		}
	}
	return true
}
// equalAny compares two values by their trimmed string rendering, so the
// int 1 and the string "1" are considered equal.
func equalAny(left any, right any) bool {
	l, r := toString(left), toString(right)
	return l == r
}
// toString renders any value as a trimmed string; nil maps to "".
func toString(v any) string {
	if v != nil {
		return fmtAny(v)
	}
	return ""
}
// fmtAny formats v with the default %v verb and trims surrounding whitespace.
func fmtAny(v any) string {
	return strings.TrimSpace(fmt.Sprint(v))
}

View File

@@ -1,927 +0,0 @@
package store
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/LoveLosita/smartflow/backend/infra/rag/core"
)
// MilvusConfig describes the Milvus REST store configuration.
type MilvusConfig struct {
	// Address must point at the Milvus REST entry point.
	// NOTE(review): the original (Chinese) comment states that in project
	// integration, 19530/9091 were used only for health/metrics and do not
	// serve the REST API used by this implementation — confirm the deployed
	// REST endpoint before changing Address.
	Address          string
	Token            string
	DBName           string
	CollectionName   string
	RequestTimeoutMS int
	Dimension        int
	MetricType       string
	Logger           *log.Logger
	Observer         core.Observer
}

// MilvusStore is a vector store implementation backed by the Milvus REST API.
//
// Design notes:
//  1. prioritizes in-project integration, manageability and gradual rollout
//     over pulling in an extra SDK dependency;
//  2. fixed scalar columns plus a metadata JSON field balance filtering
//     power with metadata completeness;
//  3. the collection is auto-created on first write, avoiding startup-time
//     initialization scripts.
type MilvusStore struct {
	cfg      MilvusConfig
	client   *http.Client
	observer core.Observer
	mu       sync.Mutex // guards ensured
	ensured  bool       // whether the collection has been verified/created
}

// Fixed Milvus column names shared by the schema, entity mapping and filters.
const (
	milvusPrimaryField   = "id"
	milvusVectorField    = "vector"
	milvusTextField      = "text"
	milvusMetadataField  = "metadata"
	milvusCorpusField    = "corpus"
	milvusDocumentField  = "document_id"
	milvusUserIDField    = "user_id"
	milvusAssistantField = "assistant_id"
	milvusConvField      = "conversation_id"
	milvusRunField       = "run_id"
	milvusMemoryType     = "memory_type"
	milvusQueryIDField   = "query_id"
	milvusSessionField   = "session_id"
	milvusDomainField    = "domain"
	milvusChunkOrder     = "chunk_order"
	milvusUpdatedAtField = "updated_at"
)

// milvusFilterFieldMap maps externally accepted filter keys onto Milvus
// column names; keys outside this map are rejected by buildMilvusFilter.
var milvusFilterFieldMap = map[string]string{
	"corpus":          milvusCorpusField,
	"document_id":     milvusDocumentField,
	"user_id":         milvusUserIDField,
	"assistant_id":    milvusAssistantField,
	"conversation_id": milvusConvField,
	"run_id":          milvusRunField,
	"memory_type":     milvusMemoryType,
	"query_id":        milvusQueryIDField,
	"session_id":      milvusSessionField,
	"domain":          milvusDomainField,
	"chunk_order":     milvusChunkOrder,
}
// NewMilvusStore validates the address and fills in defaults for the
// collection name, metric type, request timeout, logger and observer.
func NewMilvusStore(cfg MilvusConfig) (*MilvusStore, error) {
	cfg.Address = strings.TrimRight(strings.TrimSpace(cfg.Address), "/")
	if cfg.Address == "" {
		return nil, errors.New("milvus address is empty")
	}
	if cfg.CollectionName == "" {
		cfg.CollectionName = "smartflow_rag_chunks"
	}
	if cfg.MetricType == "" {
		cfg.MetricType = "COSINE"
	}
	if cfg.RequestTimeoutMS <= 0 {
		cfg.RequestTimeoutMS = 1500
	}
	if cfg.Logger == nil {
		cfg.Logger = log.Default()
	}
	if cfg.Observer == nil {
		cfg.Observer = core.NewLoggerObserver(cfg.Logger)
	}
	store := &MilvusStore{
		cfg:      cfg,
		observer: cfg.Observer,
	}
	store.client = &http.Client{Timeout: time.Duration(cfg.RequestTimeoutMS) * time.Millisecond}
	return store, nil
}
// Upsert writes rows through the Milvus REST upsert endpoint, creating the
// collection on first use based on the first row's vector dimension.
//
// Fix/cleanup vs the previous version: the two duplicated failure-observe
// blocks are centralized in one closure, and the success path returns an
// explicit nil instead of a stale err variable.
func (s *MilvusStore) Upsert(ctx context.Context, rows []core.VectorRow) error {
	if err := s.ensureReady(); err != nil {
		return err
	}
	start := time.Now()
	if len(rows) == 0 {
		return nil
	}
	// Both the ensure-collection and the REST-write failure paths emit the
	// same structured event; centralize it instead of repeating the block.
	observeFailure := func(err error) {
		s.observe(ctx, core.ObserveEvent{
			Level:     core.ObserveLevelError,
			Component: "store",
			Operation: "upsert",
			Fields: map[string]any{
				"status":     "failed",
				"row_count":  len(rows),
				"vector_dim": len(rows[0].Vector),
				"latency_ms": time.Since(start).Milliseconds(),
				"error":      err,
				"error_code": core.ClassifyErrorCode(err),
			},
		})
	}
	if err := s.ensureCollection(ctx, len(rows[0].Vector)); err != nil {
		observeFailure(err)
		return err
	}
	data := make([]map[string]any, 0, len(rows))
	for _, row := range rows {
		data = append(data, mapRowToMilvusEntity(row))
	}
	if _, err := s.postJSON(ctx, "/v2/vectordb/entities/upsert", map[string]any{
		"collectionName": s.cfg.CollectionName,
		"data":           data,
		"dbName":         blankToNil(s.cfg.DBName),
	}); err != nil {
		observeFailure(err)
		return err
	}
	s.observe(ctx, core.ObserveEvent{
		Level:     core.ObserveLevelInfo,
		Component: "store",
		Operation: "upsert",
		Fields: map[string]any{
			"status":     "success",
			"row_count":  len(rows),
			"vector_dim": len(rows[0].Vector),
			"latency_ms": time.Since(start).Milliseconds(),
		},
	})
	return nil
}
// Search sends the query vector to the Milvus REST search endpoint and maps
// the hits back into scored core rows.
//
// A missing collection is treated as "no results" rather than an error so a
// cold store behaves like an empty one.
//
// Cleanup vs the previous version: the five byte-identical failure-observe
// blocks are centralized in one closure.
func (s *MilvusStore) Search(ctx context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
	if err := s.ensureReady(); err != nil {
		return nil, err
	}
	start := time.Now()
	if len(req.QueryVector) == 0 {
		return nil, nil
	}
	// Every failure path emits the same structured event with the same
	// request metadata; centralize it.
	observeFailure := func(err error) {
		s.observe(ctx, core.ObserveEvent{
			Level:     core.ObserveLevelError,
			Component: "store",
			Operation: "search",
			Fields: map[string]any{
				"status":       "failed",
				"top_k":        req.TopK,
				"filter_count": len(req.Filter),
				"vector_dim":   len(req.QueryVector),
				"latency_ms":   time.Since(start).Milliseconds(),
				"error":        err,
				"error_code":   core.ClassifyErrorCode(err),
			},
		})
	}
	if err := s.ensureCollection(ctx, len(req.QueryVector)); err != nil {
		if isMilvusCollectionMissing(err) {
			return nil, nil
		}
		observeFailure(err)
		return nil, err
	}
	filterExpr, err := buildMilvusFilter(req.Filter)
	if err != nil {
		observeFailure(err)
		return nil, err
	}
	body := map[string]any{
		"collectionName": s.cfg.CollectionName,
		"data":           [][]float32{req.QueryVector},
		"annsField":      milvusVectorField,
		"limit":          normalizeMilvusTopK(req.TopK),
		"outputFields":   milvusOutputFields(false),
	}
	if filterExpr != "" {
		body["filter"] = filterExpr
	}
	if s.cfg.DBName != "" {
		body["dbName"] = s.cfg.DBName
	}
	respBody, err := s.postJSON(ctx, "/v2/vectordb/entities/search", body)
	if err != nil {
		if isMilvusCollectionMissing(err) {
			return nil, nil
		}
		observeFailure(err)
		return nil, err
	}
	var resp milvusSearchResponse
	if err = json.Unmarshal(respBody, &resp); err != nil {
		observeFailure(err)
		return nil, err
	}
	// Milvus reports success as code 0 or 200 depending on version.
	if resp.Code != 0 && resp.Code != 200 {
		err = fmt.Errorf("milvus search failed: code=%d message=%s", resp.Code, resp.Message)
		observeFailure(err)
		return nil, err
	}
	result := make([]core.ScoredVectorRow, 0, len(resp.Data))
	for _, item := range resp.Data {
		row, score := item.toVectorRow()
		result = append(result, core.ScoredVectorRow{
			Row:   row,
			Score: score,
		})
	}
	s.observe(ctx, core.ObserveEvent{
		Level:     core.ObserveLevelInfo,
		Component: "store",
		Operation: "search",
		Fields: map[string]any{
			"status":       "success",
			"top_k":        req.TopK,
			"filter_count": len(req.Filter),
			"vector_dim":   len(req.QueryVector),
			"result_count": len(result),
			"latency_ms":   time.Since(start).Milliseconds(),
		},
	})
	return result, nil
}
// Delete removes the entities whose primary keys are listed.
//
// A missing collection is treated as an already-satisfied delete.
//
// Fix vs the previous version: the final statement returned the stale err
// variable (provably nil at that point); it now returns nil explicitly.
func (s *MilvusStore) Delete(ctx context.Context, ids []string) error {
	if err := s.ensureReady(); err != nil {
		return err
	}
	start := time.Now()
	if len(ids) == 0 {
		return nil
	}
	// Milvus REST deletes by filter expression, not by an id-list parameter.
	filter := fmt.Sprintf(`%s in [%s]`, milvusPrimaryField, joinQuotedStrings(ids))
	_, err := s.postJSON(ctx, "/v2/vectordb/entities/delete", map[string]any{
		"collectionName": s.cfg.CollectionName,
		"filter":         filter,
		"dbName":         blankToNil(s.cfg.DBName),
	})
	if isMilvusCollectionMissing(err) {
		return nil
	}
	if err != nil {
		s.observe(ctx, core.ObserveEvent{
			Level:     core.ObserveLevelError,
			Component: "store",
			Operation: "delete",
			Fields: map[string]any{
				"status":     "failed",
				"id_count":   len(ids),
				"latency_ms": time.Since(start).Milliseconds(),
				"error":      err,
				"error_code": core.ClassifyErrorCode(err),
			},
		})
		return err
	}
	s.observe(ctx, core.ObserveEvent{
		Level:     core.ObserveLevelInfo,
		Component: "store",
		Operation: "delete",
		Fields: map[string]any{
			"status":     "success",
			"id_count":   len(ids),
			"latency_ms": time.Since(start).Milliseconds(),
		},
	})
	return nil
}
// Get fetches full rows (including vectors) for the given primary keys.
//
// A missing collection yields (nil, nil): a cold store looks empty.
//
// Cleanup vs the previous version: the repeated failure-observe blocks are
// centralized in one closure.
func (s *MilvusStore) Get(ctx context.Context, ids []string) ([]core.VectorRow, error) {
	if err := s.ensureReady(); err != nil {
		return nil, err
	}
	start := time.Now()
	if len(ids) == 0 {
		return nil, nil
	}
	// Every failure path emits the same structured event; centralize it.
	observeFailure := func(err error) {
		s.observe(ctx, core.ObserveEvent{
			Level:     core.ObserveLevelError,
			Component: "store",
			Operation: "get",
			Fields: map[string]any{
				"status":     "failed",
				"id_count":   len(ids),
				"latency_ms": time.Since(start).Milliseconds(),
				"error":      err,
				"error_code": core.ClassifyErrorCode(err),
			},
		})
	}
	respBody, err := s.postJSON(ctx, "/v2/vectordb/entities/get", map[string]any{
		"collectionName": s.cfg.CollectionName,
		"id":             ids,
		"outputFields":   milvusOutputFields(true),
		"dbName":         blankToNil(s.cfg.DBName),
	})
	if err != nil {
		if isMilvusCollectionMissing(err) {
			return nil, nil
		}
		observeFailure(err)
		return nil, err
	}
	var resp milvusGetResponse
	if err = json.Unmarshal(respBody, &resp); err != nil {
		observeFailure(err)
		return nil, err
	}
	// Milvus reports success as code 0 or 200 depending on version.
	if resp.Code != 0 && resp.Code != 200 {
		err = fmt.Errorf("milvus get failed: code=%d message=%s", resp.Code, resp.Message)
		observeFailure(err)
		return nil, err
	}
	rows := make([]core.VectorRow, 0, len(resp.Data))
	for _, item := range resp.Data {
		rows = append(rows, mapMilvusRow(item, true))
	}
	s.observe(ctx, core.ObserveEvent{
		Level:     core.ObserveLevelInfo,
		Component: "store",
		Operation: "get",
		Fields: map[string]any{
			"status":     "success",
			"id_count":   len(ids),
			"row_count":  len(rows),
			"latency_ms": time.Since(start).Milliseconds(),
		},
	})
	return rows, nil
}
// ensureCollection creates the collection (fixed schema + AUTOINDEX on the
// vector column) on first use; subsequent calls are no-ops guarded by the
// mutex-protected ensured flag. An "already exists" answer from Milvus is
// treated as success.
//
// NOTE(review): the dimension is locked in by the first successful ensure;
// later calls with a different dimension are not re-validated — confirm all
// writers share one embedding dimension.
func (s *MilvusStore) ensureCollection(ctx context.Context, dimension int) error {
	if err := s.ensureReady(); err != nil {
		return err
	}
	start := time.Now()
	if dimension <= 0 {
		dimension = s.cfg.Dimension
	}
	if dimension <= 0 {
		return errors.New("milvus vector dimension is invalid")
	}
	// Serialize creation attempts; the ensured flag makes retries idempotent.
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.ensured {
		return nil
	}
	payload := map[string]any{
		"collectionName": s.cfg.CollectionName,
		"schema": map[string]any{
			"autoId":              false,
			"enabledDynamicField": false,
			"fields": []map[string]any{
				buildVarcharField(milvusPrimaryField, true, 256),
				buildVectorField(milvusVectorField, dimension),
				buildVarcharField(milvusTextField, false, 65535),
				{"fieldName": milvusMetadataField, "dataType": "JSON"},
				buildVarcharField(milvusCorpusField, false, 64),
				buildVarcharField(milvusDocumentField, false, 256),
				{"fieldName": milvusUserIDField, "dataType": "Int64"},
				buildVarcharField(milvusAssistantField, false, 128),
				buildVarcharField(milvusConvField, false, 128),
				buildVarcharField(milvusRunField, false, 128),
				buildVarcharField(milvusMemoryType, false, 64),
				buildVarcharField(milvusQueryIDField, false, 128),
				buildVarcharField(milvusSessionField, false, 128),
				buildVarcharField(milvusDomainField, false, 128),
				{"fieldName": milvusChunkOrder, "dataType": "Int64"},
				{"fieldName": milvusUpdatedAtField, "dataType": "Int64"},
			},
		},
		"indexParams": []map[string]any{
			{
				"fieldName":  milvusVectorField,
				"indexName":  milvusVectorField,
				"metricType": s.cfg.MetricType,
				"indexType":  "AUTOINDEX",
			},
		},
	}
	if s.cfg.DBName != "" {
		payload["dbName"] = s.cfg.DBName
	}
	_, err := s.postJSON(ctx, "/v2/vectordb/collections/create", payload)
	if err != nil {
		// A concurrent or earlier creation is success for our purposes.
		if isMilvusAlreadyExists(err) {
			s.ensured = true
			s.observe(ctx, core.ObserveEvent{
				Level:     core.ObserveLevelInfo,
				Component: "store",
				Operation: "ensure_collection",
				Fields: map[string]any{
					"status":      "already_exists",
					"vector_dim":  dimension,
					"metric_type": s.cfg.MetricType,
					"latency_ms":  time.Since(start).Milliseconds(),
				},
			})
			return nil
		}
		s.observe(ctx, core.ObserveEvent{
			Level:     core.ObserveLevelError,
			Component: "store",
			Operation: "ensure_collection",
			Fields: map[string]any{
				"status":      "failed",
				"vector_dim":  dimension,
				"metric_type": s.cfg.MetricType,
				"latency_ms":  time.Since(start).Milliseconds(),
				"error":       err,
				"error_code":  core.ClassifyErrorCode(err),
			},
		})
		return err
	}
	s.ensured = true
	s.observe(ctx, core.ObserveEvent{
		Level:     core.ObserveLevelInfo,
		Component: "store",
		Operation: "ensure_collection",
		Fields: map[string]any{
			"status":      "created",
			"vector_dim":  dimension,
			"metric_type": s.cfg.MetricType,
			"latency_ms":  time.Since(start).Milliseconds(),
		},
	})
	return nil
}
// postJSON sends a JSON POST to the Milvus REST API and returns the raw
// response body.
//
// Errors are raised for transport failures, HTTP status >= 400, and for
// bodies that decode to a non-zero/non-200 Milvus application code; bodies
// that do not parse as the standard {code, message} envelope are returned
// as-is for the caller to interpret.
func (s *MilvusStore) postJSON(ctx context.Context, path string, payload map[string]any) ([]byte, error) {
	if err := s.ensureReady(); err != nil {
		return nil, err
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.cfg.Address+path, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	if token := strings.TrimSpace(s.cfg.Token); token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}
	resp, err := s.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	respBody, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, readErr
	}
	if resp.StatusCode >= 400 {
		return nil, fmt.Errorf("milvus http failed: status=%d body=%s", resp.StatusCode, string(respBody))
	}
	// Best-effort application-level error detection: only enforced when the
	// body parses as the standard envelope.
	var basic milvusBasicResponse
	if jsonErr := json.Unmarshal(respBody, &basic); jsonErr == nil {
		if basic.Code != 0 && basic.Code != 200 {
			return nil, fmt.Errorf("milvus api failed: code=%d message=%s", basic.Code, basic.Message)
		}
	}
	return respBody, nil
}
// ensureReady verifies the store was built through NewMilvusStore and still
// carries the mandatory configuration.
func (s *MilvusStore) ensureReady() error {
	switch {
	case s == nil, s.client == nil:
		return errors.New("milvus store is not initialized")
	case strings.TrimSpace(s.cfg.Address) == "":
		return errors.New("milvus address is empty")
	case strings.TrimSpace(s.cfg.CollectionName) == "":
		return errors.New("milvus collection name is empty")
	}
	return nil
}
// observe forwards the event to the configured observer, enriching the
// fields with store-level identity (store name, collection, optional db).
func (s *MilvusStore) observe(ctx context.Context, event core.ObserveEvent) {
	if s == nil || s.observer == nil {
		return
	}
	fields := cloneMap(event.Fields)
	fields["store"] = "milvus"
	fields["collection"] = s.cfg.CollectionName
	if strings.TrimSpace(s.cfg.DBName) != "" {
		fields["db_name"] = s.cfg.DBName
	}
	enriched := core.ObserveEvent{
		Level:     event.Level,
		Component: event.Component,
		Operation: event.Operation,
		Fields:    fields,
	}
	s.observer.Observe(ctx, enriched)
}
// mapRowToMilvusEntity flattens a VectorRow into the Milvus entity shape:
// fixed top-level columns for filtering plus the full metadata kept as JSON.
func mapRowToMilvusEntity(row core.VectorRow) map[string]any {
	metadata := cloneMap(row.Metadata)
	updatedAt := time.Now().UnixMilli()
	if !row.UpdatedAt.IsZero() {
		updatedAt = row.UpdatedAt.UnixMilli()
	}
	entity := map[string]any{
		milvusPrimaryField:   row.ID,
		milvusVectorField:    row.Vector,
		milvusTextField:      row.Text,
		milvusMetadataField:  metadata,
		milvusCorpusField:    asString(metadata["corpus"]),
		milvusDocumentField:  asString(metadata["document_id"]),
		milvusUpdatedAtField: updatedAt,
	}
	// Promote well-known metadata keys into scalar columns for filtering.
	scalars := []struct {
		field string
		key   string
	}{
		{milvusUserIDField, "user_id"},
		{milvusAssistantField, "assistant_id"},
		{milvusConvField, "conversation_id"},
		{milvusRunField, "run_id"},
		{milvusMemoryType, "memory_type"},
		{milvusQueryIDField, "query_id"},
		{milvusSessionField, "session_id"},
		{milvusDomainField, "domain"},
		{milvusChunkOrder, "chunk_order"},
	}
	for _, sc := range scalars {
		assignMilvusScalar(entity, sc.field, metadata[sc.key])
	}
	return entity
}
// assignMilvusScalar writes value into target[field] using the column's
// expected type: integer columns (user_id, chunk_order) require a value
// convertible via toInt64, everything else is stored as a non-blank string.
// Unconvertible or blank values are silently skipped.
func assignMilvusScalar(target map[string]any, field string, value any) {
	if value == nil {
		return
	}
	if field == milvusUserIDField || field == milvusChunkOrder {
		if parsed, ok := toInt64(value); ok {
			target[field] = parsed
		}
		return
	}
	if text := asString(value); text != "" {
		target[field] = text
	}
}
// buildMilvusFilter translates a generic filter map into a Milvus boolean
// expression, joining per-key equality checks with "and".
//
// Keys must appear in milvusFilterFieldMap; integer columns (user_id,
// chunk_order) require an integer-convertible value, all other columns are
// compared as escaped quoted strings. Returns ("", nil) for an empty filter.
//
// Go map iteration order is randomized, so the sub-expressions are sorted
// before joining: the same filter always yields the same string, which keeps
// logs comparable and makes the output usable as a cache key. A tiny
// insertion sort is used to avoid pulling in a new import for a handful of
// entries ("and" is commutative, so reordering does not change semantics).
func buildMilvusFilter(filter map[string]any) (string, error) {
	if len(filter) == 0 {
		return "", nil
	}
	parts := make([]string, 0, len(filter))
	for key, value := range filter {
		field, ok := milvusFilterFieldMap[key]
		if !ok {
			return "", fmt.Errorf("unsupported milvus filter key: %s", key)
		}
		switch field {
		case milvusUserIDField, milvusChunkOrder:
			parsed, parseOK := toInt64(value)
			if !parseOK {
				return "", fmt.Errorf("milvus filter key=%s expects integer", key)
			}
			parts = append(parts, fmt.Sprintf("%s == %d", field, parsed))
		default:
			text := escapeMilvusString(asString(value))
			parts = append(parts, fmt.Sprintf(`%s == "%s"`, field, text))
		}
	}
	// Deterministic output: sort the sub-expressions lexicographically.
	for i := 1; i < len(parts); i++ {
		for j := i; j > 0 && parts[j] < parts[j-1]; j-- {
			parts[j], parts[j-1] = parts[j-1], parts[j]
		}
	}
	return strings.Join(parts, " and "), nil
}
// buildVarcharField builds the schema descriptor for a VarChar column with
// the given maximum length; isPrimary additionally marks it as the primary
// key. The map layout matches the Milvus RESTful collection-create payload.
func buildVarcharField(name string, isPrimary bool, maxLength int) map[string]any {
	params := map[string]any{"max_length": maxLength}
	field := map[string]any{
		"fieldName":         name,
		"dataType":          "VarChar",
		"elementTypeParams": params,
	}
	if !isPrimary {
		return field
	}
	field["isPrimary"] = true
	return field
}
// buildVectorField builds the schema descriptor for a FloatVector column of
// the given dimension, matching the Milvus RESTful collection-create payload.
func buildVectorField(name string, dimension int) map[string]any {
	field := make(map[string]any, 3)
	field["fieldName"] = name
	field["dataType"] = "FloatVector"
	field["elementTypeParams"] = map[string]any{"dim": dimension}
	return field
}
// milvusOutputFields returns the column list requested from Milvus on
// search/query calls; the vector column is appended only when includeVector
// is set (fetching vectors is comparatively expensive, so it is opt-in).
func milvusOutputFields(includeVector bool) []string {
	fields := []string{
		milvusTextField,
		milvusMetadataField,
		milvusCorpusField,
		milvusDocumentField,
		milvusUserIDField,
		milvusAssistantField,
		milvusConvField,
		milvusRunField,
		milvusMemoryType,
		milvusQueryIDField,
		milvusSessionField,
		milvusDomainField,
		milvusChunkOrder,
		milvusUpdatedAtField,
	}
	if !includeVector {
		return fields
	}
	return append(fields, milvusVectorField)
}
// normalizeMilvusTopK clamps the requested result count: any non-positive
// value falls back to the default of 8.
func normalizeMilvusTopK(topK int) int {
	if topK > 0 {
		return topK
	}
	return 8
}
// blankToNil trims v and returns the trimmed string, or nil when the result
// is empty — useful for optional JSON payload fields that should be omitted
// rather than sent as "".
func blankToNil(v string) any {
	if trimmed := strings.TrimSpace(v); trimmed != "" {
		return trimmed
	}
	return nil
}
// escapeMilvusString escapes backslashes and double quotes so v can be
// embedded inside a double-quoted Milvus filter expression. The single-pass
// Replacer is equivalent to replacing `\` first and `"` second: the quote
// pass never touches backslashes inserted by the first pass.
func escapeMilvusString(v string) string {
	return strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(v)
}
func joinQuotedStrings(values []string) string {
parts := make([]string, 0, len(values))
for _, value := range values {
parts = append(parts, fmt.Sprintf(`"%s"`, escapeMilvusString(value)))
}
return strings.Join(parts, ",")
}
// cloneMap returns a shallow copy of src. A nil or empty input yields a
// fresh empty (non-nil) map, so callers can always write to the result
// without mutating the original.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for key, value := range src {
		dst[key] = value
	}
	return dst
}
// asString renders any value via its default formatting and trims
// surrounding whitespace; nil maps to "".
func asString(v any) string {
	if v == nil {
		return ""
	}
	return strings.TrimSpace(fmt.Sprint(v))
}
func toInt64(v any) (int64, bool) {
switch value := v.(type) {
case int:
return int64(value), true
case int32:
return int64(value), true
case int64:
return value, true
case float64:
return int64(value), true
case json.Number:
parsed, err := value.Int64()
return parsed, err == nil
case string:
parsed, err := strconv.ParseInt(strings.TrimSpace(value), 10, 64)
return parsed, err == nil
default:
return 0, false
}
}
func isMilvusAlreadyExists(err error) bool {
if err == nil {
return false
}
text := strings.ToLower(err.Error())
return strings.Contains(text, "already exist") ||
strings.Contains(text, "already exists") ||
strings.Contains(text, "duplicate collection")
}
func isMilvusCollectionMissing(err error) bool {
if err == nil {
return false
}
text := strings.ToLower(err.Error())
return strings.Contains(text, "can't find collection") || strings.Contains(text, "collection not found")
}
// milvusBasicResponse is the minimal envelope shared by Milvus RESTful API
// responses: a status code plus a human-readable message, used to detect
// API-level failures independent of the payload shape.
type milvusBasicResponse struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}
// milvusSearchResponse is the envelope returned by the Milvus search
// endpoint: the shared code/message pair plus the list of hits.
type milvusSearchResponse struct {
	Code    int                `json:"code"`
	Message string             `json:"message"`
	Data    []milvusSearchItem `json:"data"`
}
// milvusSearchItem is one raw search hit as decoded from the REST payload
// (field name -> value).
type milvusSearchItem map[string]any

// toVectorRow converts the hit into a core.VectorRow (without the vector)
// and extracts the similarity score from the "distance" field. The score is
// taken only when it decodes as float64; otherwise it stays 0.
func (m milvusSearchItem) toVectorRow() (core.VectorRow, float64) {
	row := mapMilvusRow(map[string]any(m), false)
	score := 0.0
	if value, ok := m["distance"].(float64); ok {
		score = value
	}
	return row, score
}
// milvusGetResponse is the envelope returned by the Milvus get/query
// endpoint: the shared code/message pair plus raw row maps.
type milvusGetResponse struct {
	Code    int              `json:"code"`
	Message string           `json:"message"`
	Data    []map[string]any `json:"data"`
}
// mapMilvusRow reconstructs a core.VectorRow from a raw Milvus row map.
// Dedicated scalar columns are folded back into the metadata map (a scalar
// column overwrites a same-named metadata key); the vector column is parsed
// only when includeVector is set. Falls back to the literal "id" column
// when the primary-key field is absent or blank.
func mapMilvusRow(raw map[string]any, includeVector bool) core.VectorRow {
	metadata := cloneMap(readMetadataMap(raw[milvusMetadataField]))
	promoted := []struct {
		key   string
		field string
	}{
		{"corpus", milvusCorpusField},
		{"document_id", milvusDocumentField},
		{"user_id", milvusUserIDField},
		{"assistant_id", milvusAssistantField},
		{"conversation_id", milvusConvField},
		{"run_id", milvusRunField},
		{"memory_type", milvusMemoryType},
		{"query_id", milvusQueryIDField},
		{"session_id", milvusSessionField},
		{"domain", milvusDomainField},
		{"chunk_order", milvusChunkOrder},
	}
	for _, item := range promoted {
		assignMetadataIfPresent(metadata, item.key, raw[item.field])
	}
	row := core.VectorRow{
		ID:       asString(raw[milvusPrimaryField]),
		Text:     asString(raw[milvusTextField]),
		Metadata: metadata,
	}
	if row.ID == "" {
		row.ID = asString(raw["id"])
	}
	if includeVector {
		row.Vector = readFloat32Vector(raw[milvusVectorField])
	}
	return row
}
// readMetadataMap returns value as a map[string]any when it already is one,
// and a fresh empty map for any other type (including nil).
func readMetadataMap(value any) map[string]any {
	if data, ok := value.(map[string]any); ok {
		return data
	}
	return map[string]any{}
}
// readFloat32Vector extracts a []float32 from a decoded vector value.
// A []float32 is returned as-is; a []any keeps only float32/float64
// elements (others are silently dropped); any other type yields nil.
func readFloat32Vector(value any) []float32 {
	if direct, ok := value.([]float32); ok {
		return direct
	}
	items, ok := value.([]any)
	if !ok {
		return nil
	}
	result := make([]float32, 0, len(items))
	for _, item := range items {
		switch number := item.(type) {
		case float64:
			result = append(result, float32(number))
		case float32:
			result = append(result, number)
		}
	}
	return result
}
// assignMetadataIfPresent stores value under key, skipping nil values and
// blank strings. String values are trimmed before storage; non-string
// values are stored unchanged.
func assignMetadataIfPresent(target map[string]any, key string, value any) {
	if value == nil {
		return
	}
	text, isString := value.(string)
	if !isString {
		target[key] = value
		return
	}
	if trimmed := strings.TrimSpace(text); trimmed != "" {
		target[key] = trimmed
	}
}

View File

@@ -1,9 +0,0 @@
package store
import "github.com/LoveLosita/smartflow/backend/infra/rag/core"
// EnsureCompile statically asserts (at compile time) that both vector-store
// implementations satisfy the core.VectorStore interface; it is never
// called at runtime.
func EnsureCompile() {
	var _ core.VectorStore = (*InMemoryVectorStore)(nil)
	var _ core.VectorStore = (*MilvusStore)(nil)
}