Version: 0.9.65.dev.260503

后端:
1. 阶段 1.5/1.6
收口 llm-service / rag-service,统一模型出口与检索基础设施入口,清退 backend/infra/llm 与 backend/infra/rag 旧实现;
2. 同步更新相关调用链与微服务迁移计划文档
This commit is contained in:
Losita
2026-05-03 23:21:03 +08:00
parent a6c1e5d077
commit 9902ca3563
65 changed files with 550 additions and 376 deletions

View File

@@ -0,0 +1,73 @@
package llm
import (
"context"
"errors"
"strings"
"github.com/cloudwego/eino-ext/components/model/ark"
einoModel "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)
// ArkCallOptions carries the common parameters used when calling an
// ark.ChatModel directly.
type ArkCallOptions struct {
	Temperature float64      // sampling temperature; forwarded unconditionally by buildArkOptions
	MaxTokens   int          // output cap; only forwarded when > 0
	Thinking    ThinkingMode // whether the model's thinking mode is enabled
}
// CallArkText invokes the ark model once and returns the trimmed plain-text
// reply. A nil model, a nil response, or an all-whitespace reply is an error.
func CallArkText(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, options ArkCallOptions) (string, error) {
	if chatModel == nil {
		return "", errors.New("ark model is nil")
	}
	prompt := []*schema.Message{
		schema.SystemMessage(systemPrompt),
		schema.UserMessage(userPrompt),
	}
	reply, err := chatModel.Generate(ctx, prompt, buildArkOptions(options)...)
	if err != nil {
		return "", err
	}
	if reply == nil {
		return "", errors.New("模型返回为空")
	}
	content := strings.TrimSpace(reply.Content)
	if content == "" {
		return "", errors.New("模型返回内容为空")
	}
	return content, nil
}
// CallArkJSON invokes the ark model and decodes the reply as JSON into T.
// The raw text is returned alongside the parsed value so callers can log
// parse failures.
func CallArkJSON[T any](ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, options ArkCallOptions) (*T, string, error) {
	rawText, callErr := CallArkText(ctx, chatModel, systemPrompt, userPrompt, options)
	if callErr != nil {
		return nil, "", callErr
	}
	decoded, parseErr := ParseJSONObject[T](rawText)
	if parseErr != nil {
		return nil, rawText, parseErr
	}
	return decoded, rawText, nil
}
// buildArkOptions translates ArkCallOptions into eino model options.
// Thinking defaults to disabled unless explicitly set to ThinkingModeEnabled.
// NOTE(review): unlike buildArkStreamOptions, the temperature is forwarded
// unconditionally (including 0) — confirm whether that asymmetry is intended.
func buildArkOptions(options ArkCallOptions) []einoModel.Option {
	thinkingType := arkModel.ThinkingTypeDisabled
	if options.Thinking == ThinkingModeEnabled {
		thinkingType = arkModel.ThinkingTypeEnabled
	}
	opts := []einoModel.Option{
		ark.WithThinking(&arkModel.Thinking{Type: thinkingType}),
		einoModel.WithTemperature(float32(options.Temperature)),
	}
	// MaxTokens is optional; zero means "let the provider decide".
	if options.MaxTokens > 0 {
		opts = append(opts, einoModel.WithMaxTokens(options.MaxTokens))
	}
	return opts
}

View File

@@ -0,0 +1,111 @@
package llm
import (
"context"
"errors"
"io"
"github.com/cloudwego/eino-ext/components/model/ark"
einoModel "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)
// WrapArkClient adapts an ark.ChatModel to the unified Client.
// 1. generateText goes through Generate, backing GenerateJSON/GenerateText.
// 2. streamText goes through Stream for streaming-output scenarios.
// 3. Both paths share the same option-conversion logic.
// A nil model yields a nil Client so upper layers can degrade gracefully.
func WrapArkClient(arkChatModel *ark.ChatModel) *Client {
	if arkChatModel == nil {
		return nil
	}
	generateFunc := func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error) {
		arkOpts := buildArkStreamOptions(options)
		msg, err := arkChatModel.Generate(ctx, messages, arkOpts...)
		if err != nil {
			return nil, err
		}
		if msg == nil {
			return nil, errors.New("ark model returned nil message")
		}
		// Usage and finish reason exist only when the provider attached
		// response metadata; usage is cloned to avoid shared pointers.
		var usage *schema.TokenUsage
		finishReason := ""
		if msg.ResponseMeta != nil {
			usage = CloneUsage(msg.ResponseMeta.Usage)
			finishReason = msg.ResponseMeta.FinishReason
		}
		return &TextResult{
			Text:         msg.Content,
			Usage:        usage,
			FinishReason: finishReason,
		}, nil
	}
	streamFunc := func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error) {
		arkOpts := buildArkStreamOptions(options)
		reader, err := arkChatModel.Stream(ctx, messages, arkOpts...)
		if err != nil {
			return nil, err
		}
		return &arkStreamReaderAdapter{reader: reader}, nil
	}
	return NewClient(generateFunc, streamFunc)
}
// buildArkStreamOptions converts unified GenerateOptions into ark call
// options (shared by both the Generate and Stream paths).
// 1. Thinking mode selects the ark thinking type.
// 2. With thinking on, temperature is pinned to 1.0; otherwise a positive
//    caller temperature is forwarded as-is (zero is left to the provider).
// 3. With thinking on, MaxTokens is raised to at least 16000 so the thinking
//    budget is not starved.
func buildArkStreamOptions(options GenerateOptions) []einoModel.Option {
	thinkingEnabled := options.Thinking == ThinkingModeEnabled
	thinkingType := arkModel.ThinkingTypeDisabled
	if thinkingEnabled {
		thinkingType = arkModel.ThinkingTypeEnabled
	}
	opts := []einoModel.Option{
		ark.WithThinking(&arkModel.Thinking{Type: thinkingType}),
	}
	if thinkingEnabled {
		opts = append(opts, einoModel.WithTemperature(1.0))
	} else if options.Temperature > 0 {
		opts = append(opts, einoModel.WithTemperature(float32(options.Temperature)))
	}
	maxTokens := options.MaxTokens
	if thinkingEnabled {
		// Guarantee a minimum output budget when reasoning is enabled.
		const minThinkingBudget = 16000
		if maxTokens < minThinkingBudget {
			maxTokens = minThinkingBudget
		}
	}
	if maxTokens > 0 {
		opts = append(opts, einoModel.WithMaxTokens(maxTokens))
	}
	return opts
}
// arkStreamReaderAdapter converts ark's streaming reader into the unified
// StreamReader interface.
type arkStreamReaderAdapter struct {
	reader *schema.StreamReader[*schema.Message]
}

// Recv forwards to the underlying reader; a nil adapter or reader behaves
// like an already-exhausted stream (io.EOF).
func (r *arkStreamReaderAdapter) Recv() (*schema.Message, error) {
	if r == nil || r.reader == nil {
		return nil, io.EOF
	}
	return r.reader.Recv()
}

// Close adapts ark's Close (which returns nothing) to the error-returning
// StreamReader interface; closing a nil adapter is a no-op.
func (r *arkStreamReaderAdapter) Close() error {
	if r == nil || r.reader == nil {
		return nil
	}
	r.reader.Close()
	return nil
}

View File

@@ -0,0 +1,330 @@
package llm
import (
"context"
"errors"
"fmt"
"strings"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses"
)
// ArkResponsesMessage describes one input message of a Responses call.
type ArkResponsesMessage struct {
	Role        string // "user", "system", "developer" or "assistant"
	Text        string // optional text part
	ImageURL    string // optional image part
	ImageDetail string // optional image detail hint: "high", "low" or "auto"
}

// ArkResponsesOptions describes the per-call Responses parameters.
type ArkResponsesOptions struct {
	Model           string       // overrides the client default model when non-empty
	Temperature     float64      // forwarded only when > 0
	MaxOutputTokens int          // forwarded only when > 0
	Thinking        ThinkingMode // default leaves the thinking field unset
	TextFormat      string       // "text" or "json_object"; empty means unset
}

// ArkResponsesUsage is the normalized token-usage view.
type ArkResponsesUsage struct {
	InputTokens  int64
	OutputTokens int64
	TotalTokens  int64
}

// ArkResponsesResult is the unified output of a Responses call.
type ArkResponsesResult struct {
	Text             string // concatenated output text
	Status           string // provider status string, e.g. "failed"
	IncompleteReason string
	ErrorCode        string
	ErrorMessage     string
	Usage            *ArkResponsesUsage
}

// ArkResponsesClient is the unified model gateway for the Ark SDK Responses API.
type ArkResponsesClient struct {
	model  string // default model used when the call options carry none
	client *arkruntime.Client
}
// NewArkResponsesClient builds an Ark SDK Responses client.
// 1. An empty model means this capability is disabled: nil is returned.
// 2. An empty baseURL falls back to the SDK default endpoint.
// 3. Construction is purely local; no connectivity probing happens here.
func NewArkResponsesClient(apiKey string, baseURL string, model string) *ArkResponsesClient {
	trimmedModel := strings.TrimSpace(model)
	if trimmedModel == "" {
		return nil
	}
	trimmedBaseURL := strings.TrimSpace(baseURL)
	configOptions := make([]arkruntime.ConfigOption, 0, 1)
	if trimmedBaseURL != "" {
		configOptions = append(configOptions, arkruntime.WithBaseUrl(trimmedBaseURL))
	}
	return &ArkResponsesClient{
		model:  trimmedModel,
		client: arkruntime.NewClientWithApiKey(strings.TrimSpace(apiKey), configOptions...),
	}
}
// GenerateText performs one non-streaming Responses call and extracts the
// text. A "failed" status or an all-whitespace text body is surfaced as an
// error, while the partially-populated result is still returned for logging.
func (c *ArkResponsesClient) GenerateText(ctx context.Context, messages []ArkResponsesMessage, options ArkResponsesOptions) (*ArkResponsesResult, error) {
	// buildRequest also validates client readiness and non-empty messages.
	req, err := c.buildRequest(messages, options)
	if err != nil {
		return nil, err
	}
	resp, err := c.client.CreateResponses(ctx, req)
	if err != nil {
		return nil, err
	}
	result := buildArkResponsesResult(resp)
	if result.Status == "failed" {
		if result.ErrorMessage != "" {
			return result, fmt.Errorf("ark responses failed: %s", result.ErrorMessage)
		}
		return result, errors.New("ark responses failed")
	}
	if strings.TrimSpace(result.Text) == "" {
		return result, FormatEmptyResponseError("ark_responses")
	}
	return result, nil
}
// GenerateArkResponsesJSON runs one Responses call and decodes the reply text
// as JSON into T. The call result is returned even on parse failure so that
// callers can log the raw output.
func GenerateArkResponsesJSON[T any](ctx context.Context, client *ArkResponsesClient, messages []ArkResponsesMessage, options ArkResponsesOptions) (*T, *ArkResponsesResult, error) {
	if client == nil {
		return nil, nil, errors.New("ark responses client is not ready")
	}
	callResult, callErr := client.GenerateText(ctx, messages, options)
	if callErr != nil {
		return nil, callResult, callErr
	}
	decoded, parseErr := ParseJSONObject[T](callResult.Text)
	if parseErr != nil {
		return nil, callResult, parseErr
	}
	return decoded, callResult, nil
}
// buildRequest assembles a Responses request from unified messages/options.
// Validation performed here: client readiness, non-empty messages, and a
// resolvable model name (the per-call override falls back to the client
// default).
func (c *ArkResponsesClient) buildRequest(messages []ArkResponsesMessage, options ArkResponsesOptions) (*responses.ResponsesRequest, error) {
	if c == nil || c.client == nil {
		return nil, errors.New("ark responses client is not ready")
	}
	if len(messages) == 0 {
		return nil, errors.New("ark responses messages is empty")
	}
	modelName := strings.TrimSpace(options.Model)
	if modelName == "" {
		modelName = c.model
	}
	if modelName == "" {
		return nil, errors.New("ark responses model is empty")
	}
	inputItems := make([]*responses.InputItem, 0, len(messages))
	for idx := range messages {
		item, err := buildInputItem(messages[idx])
		if err != nil {
			return nil, fmt.Errorf("build ark responses message[%d] failed: %w", idx, err)
		}
		inputItems = append(inputItems, item)
	}
	request := &responses.ResponsesRequest{
		Model: modelName,
		Input: &responses.ResponsesInput{
			Union: &responses.ResponsesInput_ListValue{
				ListValue: &responses.InputItemList{ListValue: inputItems},
			},
		},
	}
	// Optional knobs: only set when the caller provided a meaningful value.
	if options.Temperature > 0 {
		request.Temperature = float64Ptr(options.Temperature)
	}
	if options.MaxOutputTokens > 0 {
		request.MaxOutputTokens = int64Ptr(int64(options.MaxOutputTokens))
	}
	// ThinkingModeDefault intentionally leaves request.Thinking unset.
	switch options.Thinking {
	case ThinkingModeEnabled:
		thinkingType := responses.ThinkingType_enabled
		request.Thinking = &responses.ResponsesThinking{Type: &thinkingType}
	case ThinkingModeDisabled:
		thinkingType := responses.ThinkingType_disabled
		request.Thinking = &responses.ResponsesThinking{Type: &thinkingType}
	}
	if textType, ok := parseTextType(options.TextFormat); ok {
		request.Text = &responses.ResponsesText{
			Format: &responses.TextFormat{
				Type: textType,
			},
		}
	}
	return request, nil
}
// buildInputItem converts one unified message into an SDK input item with up
// to two content parts (text first, then image). A message that yields no
// content at all is rejected.
func buildInputItem(message ArkResponsesMessage) (*responses.InputItem, error) {
	role, ok := parseMessageRole(message.Role)
	if !ok {
		return nil, fmt.Errorf("unsupported message role: %s", strings.TrimSpace(message.Role))
	}
	content := make([]*responses.ContentItem, 0, 2)
	if text := strings.TrimSpace(message.Text); text != "" {
		content = append(content, &responses.ContentItem{
			Union: &responses.ContentItem_Text{
				Text: &responses.ContentItemText{
					Type: responses.ContentItemType_input_text,
					Text: text,
				},
			},
		})
	}
	if imageURL := strings.TrimSpace(message.ImageURL); imageURL != "" {
		image := &responses.ContentItemImage{
			Type:     responses.ContentItemType_input_image,
			ImageUrl: stringPtr(imageURL),
		}
		// Detail is optional; an unrecognized value simply leaves it unset.
		if detail, ok := parseImageDetail(message.ImageDetail); ok {
			image.Detail = &detail
		}
		content = append(content, &responses.ContentItem{
			Union: &responses.ContentItem_Image{
				Image: image,
			},
		})
	}
	if len(content) == 0 {
		return nil, errors.New("message content is empty")
	}
	return &responses.InputItem{
		Union: &responses.InputItem_InputMessage{
			InputMessage: &responses.ItemInputMessage{
				Role:    role,
				Content: content,
			},
		},
	}, nil
}
// buildArkResponsesResult flattens an SDK response object into the unified
// result; a nil response yields an empty (but non-nil) result.
func buildArkResponsesResult(resp *responses.ResponseObject) *ArkResponsesResult {
	if resp == nil {
		return &ArkResponsesResult{}
	}
	result := &ArkResponsesResult{
		Text:   extractArkResponsesText(resp),
		Status: strings.TrimSpace(resp.GetStatus().String()),
	}
	if details := resp.GetIncompleteDetails(); details != nil {
		result.IncompleteReason = strings.TrimSpace(details.GetReason())
	}
	if responseErr := resp.GetError(); responseErr != nil {
		result.ErrorCode = strings.TrimSpace(responseErr.GetCode())
		result.ErrorMessage = strings.TrimSpace(responseErr.GetMessage())
	}
	if usage := resp.GetUsage(); usage != nil {
		result.Usage = &ArkResponsesUsage{
			InputTokens:  usage.GetInputTokens(),
			OutputTokens: usage.GetOutputTokens(),
			TotalTokens:  usage.GetTotalTokens(),
		}
	}
	return result
}
// extractArkResponsesText joins every non-empty text part of every output
// message with newlines; non-message output items are skipped.
func extractArkResponsesText(resp *responses.ResponseObject) string {
	if resp == nil {
		return ""
	}
	textParts := make([]string, 0, 2)
	for _, outputItem := range resp.GetOutput() {
		outputMessage := outputItem.GetOutputMessage()
		if outputMessage == nil {
			continue
		}
		for _, contentItem := range outputMessage.GetContent() {
			text := strings.TrimSpace(contentItem.GetText().GetText())
			if text == "" {
				continue
			}
			textParts = append(textParts, text)
		}
	}
	return strings.TrimSpace(strings.Join(textParts, "\n"))
}
// parseMessageRole maps a free-form role string (case-insensitive, trimmed)
// to the SDK enum; ok=false signals an unsupported role.
func parseMessageRole(raw string) (responses.MessageRole_Enum, bool) {
	switch strings.ToLower(strings.TrimSpace(raw)) {
	case "user":
		return responses.MessageRole_user, true
	case "system":
		return responses.MessageRole_system, true
	case "developer":
		return responses.MessageRole_developer, true
	case "assistant":
		return responses.MessageRole_assistant, true
	default:
		return responses.MessageRole_unspecified, false
	}
}

// parseImageDetail maps an image-detail hint to the SDK enum; unknown values
// return ok=false so callers leave the field unset.
func parseImageDetail(raw string) (responses.ContentItemImageDetail_Enum, bool) {
	switch strings.ToLower(strings.TrimSpace(raw)) {
	case "high":
		return responses.ContentItemImageDetail_high, true
	case "low":
		return responses.ContentItemImageDetail_low, true
	case "auto":
		return responses.ContentItemImageDetail_auto, true
	default:
		return responses.ContentItemImageDetail_auto, false
	}
}

// parseTextType maps the requested text format to the SDK enum; both an
// empty and an unknown value return ok=false, leaving the request format
// unset.
func parseTextType(raw string) (responses.TextType_Enum, bool) {
	switch strings.ToLower(strings.TrimSpace(raw)) {
	case "":
		return responses.TextType_unspecified, false
	case "text":
		return responses.TextType_text, true
	case "json_object":
		return responses.TextType_json_object, true
	default:
		return responses.TextType_unspecified, false
	}
}
// stringPtr returns a pointer to value (SDK request fields are optional
// pointers).
func stringPtr(value string) *string {
	return &value
}

// float64Ptr returns a pointer to value.
func float64Ptr(value float64) *float64 {
	return &value
}

// int64Ptr returns a pointer to value.
func int64Ptr(value int64) *int64 {
	return &value
}

View File

@@ -0,0 +1,174 @@
package llm
import (
"context"
"errors"
"fmt"
"strings"
"github.com/cloudwego/eino/schema"
)
// ThinkingMode expresses what a call expects from the model's thinking mode.
type ThinkingMode string

const (
	ThinkingModeDefault  ThinkingMode = "default" // leave the decision to the adapter/provider
	ThinkingModeEnabled  ThinkingMode = "enabled"
	ThinkingModeDisabled ThinkingMode = "disabled"
)

// GenerateOptions collects the most common shared parameters of a text call.
type GenerateOptions struct {
	Temperature float64
	MaxTokens   int
	Thinking    ThinkingMode
	Metadata    map[string]any
}

// TextResult holds the final output of one text generation plus usage.
// 1. Text is the model's plain-text reply.
// 2. Usage lets upper layers do unified accounting.
// 3. No JSON parsing or business field mapping happens here.
type TextResult struct {
	Text         string
	Usage        *schema.TokenUsage
	FinishReason string
}

// StreamReader abstracts a chunk-by-chunk streaming message reader.
type StreamReader interface {
	Recv() (*schema.Message, error)
	Close() error
}

// TextGenerateFunc is the unified text-generation function signature.
type TextGenerateFunc func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error)

// StreamGenerateFunc is the unified streaming-generation function signature.
type StreamGenerateFunc func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error)

// Client is the unified model-client facade.
// 1. It performs only minimal input validation and empty-response defense.
// 2. It does not assemble prompts or implement business fallbacks.
// 3. Provider specifics are funneled in through adapters.
type Client struct {
	generateText TextGenerateFunc
	streamText   StreamGenerateFunc
}
// NewClient assembles the unified model client from the two generation hooks.
func NewClient(generateText TextGenerateFunc, streamText StreamGenerateFunc) *Client {
	client := &Client{}
	client.generateText = generateText
	client.streamText = streamText
	return client
}
// GenerateText performs one unified text generation. It validates readiness
// and input up front, then defends against nil or all-whitespace results.
func (c *Client) GenerateText(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error) {
	switch {
	case c == nil || c.generateText == nil:
		return nil, errors.New("llm client is not ready")
	case len(messages) == 0:
		return nil, errors.New("llm messages is empty")
	}
	generated, err := c.generateText(ctx, messages, options)
	if err != nil {
		return nil, err
	}
	switch {
	case generated == nil:
		return nil, errors.New("llm result is nil")
	case strings.TrimSpace(generated.Text) == "":
		return nil, errors.New("llm returned empty text")
	}
	return generated, nil
}
// GenerateJSON runs unified text generation followed by unified JSON
// parsing. On parse failure the text result is still returned for logging.
func GenerateJSON[T any](ctx context.Context, client *Client, messages []*schema.Message, options GenerateOptions) (*T, *TextResult, error) {
	textResult, genErr := client.GenerateText(ctx, messages, options)
	if genErr != nil {
		return nil, nil, genErr
	}
	decoded, parseErr := ParseJSONObject[T](textResult.Text)
	if parseErr != nil {
		return nil, textResult, parseErr
	}
	return decoded, textResult, nil
}
// Stream opens the unified streaming entry point after the same minimal
// validation as GenerateText.
func (c *Client) Stream(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error) {
	switch {
	case c == nil || c.streamText == nil:
		return nil, errors.New("llm stream client is not ready")
	case len(messages) == 0:
		return nil, errors.New("llm messages is empty")
	default:
		return c.streamText(ctx, messages, options)
	}
}
// BuildSystemUserMessages assembles the common system + history + user
// message list. All-whitespace system/user prompts are skipped so no empty
// message is emitted.
func BuildSystemUserMessages(systemPrompt string, history []*schema.Message, userPrompt string) []*schema.Message {
	out := make([]*schema.Message, 0, len(history)+2)
	if strings.TrimSpace(systemPrompt) != "" {
		out = append(out, schema.SystemMessage(systemPrompt))
	}
	out = append(out, history...)
	if strings.TrimSpace(userPrompt) != "" {
		out = append(out, schema.UserMessage(userPrompt))
	}
	return out
}
// CloneUsage deep-copies a token usage value so later accumulation cannot
// mutate a shared pointer; nil passes through as nil.
func CloneUsage(usage *schema.TokenUsage) *schema.TokenUsage {
	if usage == nil {
		return nil
	}
	dup := new(schema.TokenUsage)
	*dup = *usage
	return dup
}
// MergeUsage combines two usage snapshots, keeping the larger value of each
// counter as the accumulated result. Either side may be nil; the other is
// cloned so the result never aliases the inputs.
func MergeUsage(base *schema.TokenUsage, incoming *schema.TokenUsage) *schema.TokenUsage {
	if incoming == nil {
		return CloneUsage(base)
	}
	if base == nil {
		return CloneUsage(incoming)
	}
	larger := func(current, candidate int) int {
		if candidate > current {
			return candidate
		}
		return current
	}
	merged := *base
	merged.PromptTokens = larger(merged.PromptTokens, incoming.PromptTokens)
	merged.CompletionTokens = larger(merged.CompletionTokens, incoming.CompletionTokens)
	merged.TotalTokens = larger(merged.TotalTokens, incoming.TotalTokens)
	merged.PromptTokenDetails.CachedTokens = larger(merged.PromptTokenDetails.CachedTokens, incoming.PromptTokenDetails.CachedTokens)
	merged.CompletionTokensDetails.ReasoningTokens = larger(merged.CompletionTokensDetails.ReasoningTokens, incoming.CompletionTokensDetails.ReasoningTokens)
	return &merged
}
// FormatEmptyResponseError produces the shared error text for empty model
// results; an all-whitespace scene is reported as "unknown".
func FormatEmptyResponseError(scene string) error {
	name := strings.TrimSpace(scene)
	if name == "" {
		name = "unknown"
	}
	return fmt.Errorf("模型在 %s 场景返回空结果", name)
}

View File

@@ -0,0 +1,102 @@
package llm
import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"unicode/utf8"
)
// ParseJSONObject decodes the JSON object embedded in a model reply.
// 1. Markdown code-fence wrappers are stripped first.
// 2. The outermost JSON object is then extracted from the mixed text.
// 3. Only structural parsing happens here; field validation is up to callers.
func ParseJSONObject[T any](raw string) (*T, error) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return nil, errors.New("模型返回为空,无法解析 JSON")
	}
	candidate := ExtractJSONObject(trimmed)
	if candidate == "" {
		return nil, fmt.Errorf("模型返回中未找到 JSON 对象: %s", truncateForError(trimmed))
	}
	result := new(T)
	if err := json.Unmarshal([]byte(candidate), result); err != nil {
		return nil, fmt.Errorf("JSON 解析失败: %w", err)
	}
	return result, nil
}
// ExtractJSONObject pulls the first complete JSON object out of mixed text.
// String and escape state are tracked so braces inside string literals never
// affect the balance count; an unbalanced object yields "".
func ExtractJSONObject(text string) string {
	body := trimMarkdownCodeFence(strings.TrimSpace(text))
	if body == "" {
		return ""
	}
	first := strings.IndexByte(body, '{')
	if first < 0 {
		return ""
	}
	var (
		balance  int
		inQuote  bool
		skipNext bool
	)
	for pos := first; pos < len(body); pos++ {
		current := body[pos]
		if skipNext {
			skipNext = false
			continue
		}
		switch {
		case current == '\\' && inQuote:
			skipNext = true
		case current == '"':
			inQuote = !inQuote
		case inQuote:
			// Characters inside a string never change the brace balance.
		case current == '{':
			balance++
		case current == '}':
			balance--
			if balance == 0 {
				return body[first : pos+1]
			}
		}
	}
	return ""
}
// trimMarkdownCodeFence strips a surrounding ``` code fence (with optional
// language tag) from text and returns the inner body.
// Three shapes are handled: no fence (returned as-is); a multi-line fence
// whose first line is the opening fence and whose last line may be the
// closing fence; and a single-line fence such as "```{...}```", which the
// previous implementation silently reduced to "" (dropping the payload).
func trimMarkdownCodeFence(text string) string {
	trimmed := strings.TrimSpace(text)
	if !strings.HasPrefix(trimmed, "```") {
		return trimmed
	}
	if !strings.Contains(trimmed, "\n") {
		// Single-line fence: strip both fences in place.
		body := strings.TrimPrefix(trimmed, "```")
		body = strings.TrimSuffix(body, "```")
		return strings.TrimSpace(body)
	}
	lines := strings.Split(trimmed, "\n")
	// Drop the opening fence line, including any language tag ("```json").
	body := lines[1:]
	if len(body) > 0 && strings.TrimSpace(body[len(body)-1]) == "```" {
		body = body[:len(body)-1]
	}
	return strings.TrimSpace(strings.Join(body, "\n"))
}
// truncateForError shortens long model output for inclusion in error text.
// The cut is made on a UTF-8 rune boundary: the previous byte-index slice at
// 160 could split a multi-byte character (common with Chinese replies) and
// embed invalid UTF-8 in the error message.
func truncateForError(text string) string {
	const limit = 160
	if len(text) <= limit {
		return text
	}
	cut := limit
	// Walk back to the start byte of the rune straddling the limit.
	for cut > 0 && !utf8.RuneStart(text[cut]) {
		cut--
	}
	return text[:cut] + "..."
}

View File

@@ -0,0 +1,109 @@
package llm
import (
"strings"
"github.com/LoveLosita/smartflow/backend/inits"
)
// Service only exposes pre-built model clients; it owns no prompts and no
// business orchestration.
type Service struct {
	liteClient                 *Client
	proClient                  *Client
	maxClient                  *Client
	courseImageResponsesClient *ArkResponsesClient
}

// Options describes the startup-time dependencies llm-service takes over.
// 1. AIHub remains the in-process source of Ark ChatModels, but the service
//    layer only keeps unified Clients.
// 2. CourseImageResponsesClient may be injected up front for tests or
//    special startup paths.
// 3. An empty field is not an error: the matching client stays nil and upper
//    layers keep their graceful-degradation behavior.
type Options struct {
	AIHub                      *inits.AIHub
	APIKey                     string
	BaseURL                    string
	CourseVisionModel          string
	CourseImageResponsesClient *ArkResponsesClient
}

// AgentModelClients exposes, in one shot, the model assignment commonly used
// by the newAgent graph.
type AgentModelClients struct {
	Chat    *Client
	Plan    *Client
	Execute *Client
	Deliver *Client
	Summary *Client
}
// New constructs the llm-service facade.
// 1. It deliberately returns no error so upper layers can degrade step by
//    step on nil clients.
// 2. Whenever AIHub is initialized, its ChatModels are wrapped into unified
//    Clients.
// 3. The course-image Responses client is built here so business code never
//    depends on the Responses SDK directly.
func New(opts Options) *Service {
	service := &Service{}
	if hub := opts.AIHub; hub != nil {
		service.liteClient = WrapArkClient(hub.Lite)
		service.proClient = WrapArkClient(hub.Pro)
		service.maxClient = WrapArkClient(hub.Max)
	}
	if opts.CourseImageResponsesClient != nil {
		service.courseImageResponsesClient = opts.CourseImageResponsesClient
		return service
	}
	key := strings.TrimSpace(opts.APIKey)
	visionModel := strings.TrimSpace(opts.CourseVisionModel)
	if key != "" && visionModel != "" {
		service.courseImageResponsesClient = NewArkResponsesClient(key, strings.TrimSpace(opts.BaseURL), visionModel)
	}
	return service
}
// LiteClient returns the low-cost short-output model client (may be nil).
func (s *Service) LiteClient() *Client {
	if s == nil {
		return nil
	}
	return s.liteClient
}

// ProClient returns the default complex-conversation model client (may be nil).
func (s *Service) ProClient() *Client {
	if s == nil {
		return nil
	}
	return s.proClient
}

// MaxClient returns the deep-reasoning model client (may be nil).
func (s *Service) MaxClient() *Client {
	if s == nil {
		return nil
	}
	return s.maxClient
}

// CourseImageResponsesClient returns the Responses client used for course
// image parsing (may be nil).
func (s *Service) CourseImageResponsesClient() *ArkResponsesClient {
	if s == nil {
		return nil
	}
	return s.courseImageResponsesClient
}

// NewAgentModelClients returns the model assignment commonly used by the
// newAgent graph; a nil service yields all-nil clients.
func (s *Service) NewAgentModelClients() AgentModelClients {
	if s == nil {
		return AgentModelClients{}
	}
	return AgentModelClients{
		Chat:    s.ProClient(),
		Plan:    s.MaxClient(),
		Execute: s.MaxClient(),
		Deliver: s.ProClient(),
		Summary: s.LiteClient(),
	}
}

118
backend/services/rag/api.go Normal file
View File

@@ -0,0 +1,118 @@
package rag
import (
"context"
"time"
)
// Runtime is the only stable method surface the RAG service exposes to
// business code.
//
// Responsibility boundaries:
// 1. Unified ingest and retrieve entry points for memory and web corpora;
// 2. Hides the assembly details of the underlying Pipeline / Store /
//    Embedder / Reranker;
// 3. Carries no business semantics such as provider search, HTML fetching or
//    prompt injection.
type Runtime interface {
	IngestMemory(ctx context.Context, req MemoryIngestRequest) (*IngestResult, error)
	RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (*RetrieveResult, error)
	DeleteMemory(ctx context.Context, documentIDs []string) error
	IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestResult, error)
	RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (*RetrieveResult, error)
}

// IngestResult summarizes one unified ingest execution.
type IngestResult struct {
	DocumentCount int
	ChunkCount    int
	DocumentIDs   []string
}

// RetrieveHit is the unified hit item exposed to business code.
type RetrieveHit struct {
	ChunkID    string
	DocumentID string
	Text       string
	Score      float64
	Metadata   map[string]any
}

// RetrieveResult summarizes one retrieval execution.
type RetrieveResult struct {
	Items          []RetrieveHit
	RawCount       int // presumably the candidate count before trimming — confirm in implementations
	FallbackUsed   bool
	FallbackReason string
}

// MemoryIngestItem is one memory corpus entry to ingest.
type MemoryIngestItem struct {
	MemoryID         int64
	UserID           int
	ConversationID   string
	AssistantID      string
	RunID            string
	MemoryType       string
	Title            string
	Content          string
	Confidence       float64
	Importance       float64
	SensitivityLevel int
	IsExplicit       bool
	Status           string
	TTLAt            *time.Time
	CreatedAt        *time.Time
}

// MemoryIngestRequest describes one memory vector-ingest request.
type MemoryIngestRequest struct {
	TraceID string
	Action  string
	Items   []MemoryIngestItem
}

// MemoryRetrieveRequest describes one memory retrieval request.
type MemoryRetrieveRequest struct {
	TraceID        string
	Query          string
	TopK           int
	Threshold      float64
	Action         string
	UserID         int
	ConversationID string
	AssistantID    string
	RunID          string
	MemoryTypes    []string
}

// WebIngestItem is one web corpus entry to ingest.
type WebIngestItem struct {
	URL         string
	Title       string
	Content     string
	Snippet     string
	Domain      string
	QueryID     string
	SessionID   string
	PublishedAt *time.Time
	FetchedAt   *time.Time
	SourceRank  int
}

// WebIngestRequest describes one web corpus ingest request.
type WebIngestRequest struct {
	TraceID string
	Action  string
	Items   []WebIngestItem
}

// WebRetrieveRequest describes one web retrieval request.
type WebRetrieveRequest struct {
	TraceID   string
	Query     string
	TopK      int
	Threshold float64
	Action    string
	QueryID   string
	SessionID string
	Domain    string
}

View File

@@ -0,0 +1,85 @@
package chunk
import (
"context"
"fmt"
"strings"
"github.com/LoveLosita/smartflow/backend/services/rag/core"
)
// TextChunker is the default fixed-window text chunker (stateless).
type TextChunker struct{}

// NewTextChunker constructs the default chunker.
func NewTextChunker() *TextChunker {
	return &TextChunker{}
}
// Chunk performs fixed-window chunking over the document text.
//
// Steps:
// 1. Normalize whitespace first so no useless chunk reaches the vector
//    store;
// 2. Slide a chunk_size/overlap window across the rune sequence;
// 3. Each chunk inherits the document metadata plus its own "chunk_order"
//    index.
func (c *TextChunker) Chunk(_ context.Context, doc core.SourceDocument, opt core.ChunkOption) ([]core.Chunk, error) {
	if strings.TrimSpace(doc.ID) == "" {
		return nil, fmt.Errorf("empty document id")
	}
	normalized := strings.TrimSpace(doc.Text)
	if normalized == "" {
		return nil, nil
	}
	// Sanitize window parameters: defaults, non-negative overlap, and an
	// overlap strictly smaller than the window size.
	size := opt.ChunkSize
	if size <= 0 {
		size = 400
	}
	overlap := opt.ChunkOverlap
	if overlap < 0 {
		overlap = 0
	}
	if overlap >= size {
		overlap = size / 5
	}
	stride := size - overlap
	if stride <= 0 {
		stride = size
	}
	symbols := []rune(normalized)
	chunks := make([]core.Chunk, 0, len(symbols)/stride+1)
	sequence := 0
	for begin := 0; begin < len(symbols); begin += stride {
		finish := begin + size
		if finish > len(symbols) {
			finish = len(symbols)
		}
		piece := strings.TrimSpace(string(symbols[begin:finish]))
		if piece != "" {
			meta := cloneMap(doc.Metadata)
			meta["chunk_order"] = sequence
			chunks = append(chunks, core.Chunk{
				ID:         fmt.Sprintf("%s#%d", doc.ID, sequence),
				DocumentID: doc.ID,
				Text:       piece,
				Order:      sequence,
				Metadata:   meta,
			})
			sequence++
			if finish == len(symbols) {
				break
			}
		}
	}
	return chunks, nil
}
// cloneMap returns a shallow copy of src; nil or empty input yields a fresh
// empty (non-nil) map so callers can always write to the result.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for key, value := range src {
		dst[key] = value
	}
	return dst
}

View File

@@ -0,0 +1,113 @@
package config
import "github.com/spf13/viper"
// Config is the runtime configuration of the RAG core.
type Config struct {
	Enabled                bool
	Store                  string // vector store backend; defaults to "inmemory"
	TopK                   int
	Threshold              float64
	EmbedProvider          string // defaults to "mock"
	EmbedModel             string
	EmbedBaseURL           string // falls back to agent.baseURL
	EmbedTimeoutMS         int
	EmbedDimension         int
	RerankerEnabled        bool
	RerankerProvider       string // defaults to "noop"
	RerankerTimeoutMS      int
	ChunkSize              int
	ChunkOverlap           int
	RetrieveTimeoutMS      int
	MilvusAddress          string
	MilvusToken            string
	MilvusDBName           string
	MilvusCollectionName   string
	MilvusMetricType       string
	MilvusRequestTimeoutMS int
}
// LoadFromViper reads the "rag.*" configuration tree and fills in defaults
// for every unset/invalid value.
// NOTE(review): ChunkOverlap is only defaulted to 80 when the configured
// value is negative; an unset key reads as 0 via viper, so the 80 default is
// effectively unreachable — confirm whether 0 overlap is intended for unset
// configs.
func LoadFromViper() Config {
	cfg := Config{
		Enabled:                viper.GetBool("rag.enabled"),
		Store:                  viper.GetString("rag.store"),
		TopK:                   viper.GetInt("rag.topK"),
		Threshold:              viper.GetFloat64("rag.threshold"),
		EmbedProvider:          viper.GetString("rag.embed.provider"),
		EmbedModel:             viper.GetString("rag.embed.model"),
		EmbedBaseURL:           viper.GetString("rag.embed.baseURL"),
		EmbedTimeoutMS:         viper.GetInt("rag.embed.timeoutMs"),
		EmbedDimension:         viper.GetInt("rag.embed.dimension"),
		RerankerEnabled:        viper.GetBool("rag.reranker.enabled"),
		RerankerProvider:       viper.GetString("rag.reranker.provider"),
		RerankerTimeoutMS:      viper.GetInt("rag.reranker.timeoutMs"),
		ChunkSize:              viper.GetInt("rag.ingest.chunkSize"),
		ChunkOverlap:           viper.GetInt("rag.ingest.chunkOverlap"),
		RetrieveTimeoutMS:      viper.GetInt("rag.retrieve.timeoutMs"),
		MilvusAddress:          viper.GetString("rag.milvus.address"),
		MilvusToken:            viper.GetString("rag.milvus.token"),
		MilvusDBName:           viper.GetString("rag.milvus.dbName"),
		MilvusCollectionName:   viper.GetString("rag.milvus.collectionName"),
		MilvusMetricType:       viper.GetString("rag.milvus.metricType"),
		MilvusRequestTimeoutMS: viper.GetInt("rag.milvus.requestTimeoutMs"),
	}
	// Defaults below keep the service usable with a minimal config file.
	if cfg.Store == "" {
		cfg.Store = "inmemory"
	}
	if cfg.TopK <= 0 {
		cfg.TopK = 8
	}
	if cfg.Threshold < 0 {
		cfg.Threshold = 0
	}
	if cfg.EmbedProvider == "" {
		cfg.EmbedProvider = "mock"
	}
	if cfg.EmbedBaseURL == "" {
		cfg.EmbedBaseURL = viper.GetString("agent.baseURL")
	}
	if cfg.EmbedTimeoutMS <= 0 {
		cfg.EmbedTimeoutMS = 1200
	}
	if cfg.EmbedDimension <= 0 {
		cfg.EmbedDimension = 1024
	}
	if cfg.RerankerProvider == "" {
		cfg.RerankerProvider = "noop"
	}
	if cfg.RerankerTimeoutMS <= 0 {
		cfg.RerankerTimeoutMS = 1200
	}
	if cfg.ChunkSize <= 0 {
		cfg.ChunkSize = 400
	}
	if cfg.ChunkOverlap < 0 {
		cfg.ChunkOverlap = 80
	}
	if cfg.RetrieveTimeoutMS <= 0 {
		cfg.RetrieveTimeoutMS = 1500
	}
	if cfg.MilvusAddress == "" {
		cfg.MilvusAddress = "http://localhost:19530"
	}
	if cfg.MilvusToken == "" {
		cfg.MilvusToken = "root:Milvus"
	}
	if cfg.MilvusCollectionName == "" {
		cfg.MilvusCollectionName = "smartflow_rag_chunks"
	}
	if cfg.MilvusMetricType == "" {
		cfg.MilvusMetricType = "COSINE"
	}
	if cfg.MilvusRequestTimeoutMS <= 0 {
		cfg.MilvusRequestTimeoutMS = 1500
	}
	return cfg
}

View File

@@ -0,0 +1,17 @@
package core
import "errors"
var (
	// ErrInvalidQuery indicates a retrieve request without a usable query.
	ErrInvalidQuery = errors.New("invalid query")
	// ErrInvalidTopK indicates an illegal topK value.
	ErrInvalidTopK = errors.New("invalid top_k")
	// ErrNilDependency indicates a required pipeline dependency was not injected.
	ErrNilDependency = errors.New("nil dependency")
)

const (
	// FallbackReasonRerankFailed marks degradation after a rerank failure.
	FallbackReasonRerankFailed = "RERANK_FAILED"
)

View File

@@ -0,0 +1,38 @@
package core
import "context"
// Chunker splits a source document into chunks.
type Chunker interface {
	Chunk(ctx context.Context, doc SourceDocument, opt ChunkOption) ([]Chunk, error)
}

// Embedder turns texts into vectors; action is a caller-supplied label whose
// semantics are defined by implementations.
type Embedder interface {
	Embed(ctx context.Context, texts []string, action string) ([][]float32, error)
}

// Retriever recalls candidate chunks for a request.
type Retriever interface {
	Retrieve(ctx context.Context, req RetrieveRequest) ([]ScoredChunk, error)
}

// Reranker re-orders candidates against the query.
type Reranker interface {
	Rerank(ctx context.Context, query string, candidates []ScoredChunk, topK int) ([]ScoredChunk, error)
}

// VectorStore reads and writes the vector database.
type VectorStore interface {
	Upsert(ctx context.Context, rows []VectorRow) error
	Search(ctx context.Context, req VectorSearchRequest) ([]ScoredVectorRow, error)
	Delete(ctx context.Context, ids []string) error
	Get(ctx context.Context, ids []string) ([]VectorRow, error)
}

// CorpusAdapter maps business corpora onto unified documents and retrieval
// filters.
type CorpusAdapter interface {
	Name() string
	BuildIngestDocuments(ctx context.Context, input any) ([]SourceDocument, error)
	BuildRetrieveFilter(ctx context.Context, req any) (map[string]any, error)
}

View File

@@ -0,0 +1,190 @@
package core
import (
"context"
"errors"
"fmt"
"log"
"sort"
"strings"
)
// ObserveLevel is the severity of an observation event.
type ObserveLevel string

const (
	ObserveLevelInfo  ObserveLevel = "info"
	ObserveLevelWarn  ObserveLevel = "warn"
	ObserveLevelError ObserveLevel = "error"
)

// ObserveEvent describes one unified observation event.
//
// Responsibility boundaries:
// 1. Carries only the RAG service's structured runtime information;
// 2. Is not bound to any concrete logging, metrics or tracing system;
// 3. Field content should stay stable so a global observability platform can
//    be attached later.
type ObserveEvent struct {
	Level     ObserveLevel
	Component string
	Operation string
	Fields    map[string]any
}
// Observer is the minimal observation interface of the RAG service.
//
// Responsibility boundaries:
// 1. Consumes structured events;
// 2. Never decides whether business logic continues to execute;
// 3. No implementation may destabilize the main code path.
type Observer interface {
	Observe(ctx context.Context, event ObserveEvent)
}

// ObserverFunc adapts a plain function to Observer.
type ObserverFunc func(ctx context.Context, event ObserveEvent)

// Observe invokes the wrapped function; a nil function is a no-op.
func (f ObserverFunc) Observe(ctx context.Context, event ObserveEvent) {
	if f == nil {
		return
	}
	f(ctx, event)
}

// NewNopObserver returns an empty implementation, a safe default before a
// unified observability platform is wired in.
func NewNopObserver() Observer {
	return ObserverFunc(func(context.Context, ObserveEvent) {})
}
// NewLoggerObserver returns a standard-log adapter.
//
// Notes:
// 1. Until the project has a unified logging platform, structured fields are
//    printed in a stable form;
// 2. Introducing real logger/metrics/tracing later only requires replacing
//    the injected Observer;
// 3. Output stays on a single line to match the existing log style.
// A nil logger falls back to log.Default().
func NewLoggerObserver(logger *log.Logger) Observer {
	if logger == nil {
		logger = log.Default()
	}
	return &loggerObserver{logger: logger}
}

// loggerObserver prints events through a standard *log.Logger.
type loggerObserver struct {
	logger *log.Logger
}
// Observe renders one event as a single log line of sorted key=value pairs.
//
// Steps:
//  1. Normalize level/component/operation, defaulting blanks to stable values;
//  2. Merge context-carried fields with event fields (event fields win on collision);
//  3. Sort field keys so output is deterministic and diff-friendly.
func (o *loggerObserver) Observe(ctx context.Context, event ObserveEvent) {
	if o == nil || o.logger == nil {
		return
	}
	level := strings.TrimSpace(string(event.Level))
	if level == "" {
		level = string(ObserveLevelInfo)
	}
	component := strings.TrimSpace(event.Component)
	if component == "" {
		component = "unknown"
	}
	operation := strings.TrimSpace(event.Operation)
	if operation == "" {
		operation = "unknown"
	}
	// Context fields come first so event-level fields can override them.
	fields := ObserveFieldsFromContext(ctx)
	for key, value := range event.Fields {
		key = strings.TrimSpace(key)
		if key == "" || !shouldKeepObserveField(value) {
			continue
		}
		fields[key] = value
	}
	parts := []string{
		"rag",
		fmt.Sprintf("level=%s", level),
		fmt.Sprintf("component=%s", component),
		fmt.Sprintf("operation=%s", operation),
	}
	keys := make([]string, 0, len(fields))
	for key := range fields {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	for _, key := range keys {
		parts = append(parts, fmt.Sprintf("%s=%v", key, fields[key]))
	}
	o.logger.Print(strings.Join(parts, " "))
}
type observeFieldsContextKey struct{}
// WithObserveFields 把通用观测字段挂入上下文,便于下游组件复用。
//
// 步骤化说明:
// 1. 先读取已有上下文字段,保证 Runtime / Pipeline / Store 能逐层补充信息;
// 2. 后写字段覆盖同名旧值,确保下游拿到的是最新语义;
// 3. 仅保存“有意义”的字段,避免日志长期堆积大量空值。
func WithObserveFields(ctx context.Context, fields map[string]any) context.Context {
if len(fields) == 0 {
return ctx
}
if ctx == nil {
ctx = context.Background()
}
merged := ObserveFieldsFromContext(ctx)
for key, value := range fields {
key = strings.TrimSpace(key)
if key == "" || !shouldKeepObserveField(value) {
continue
}
merged[key] = value
}
if len(merged) == 0 {
return ctx
}
return context.WithValue(ctx, observeFieldsContextKey{}, merged)
}
// ObserveFieldsFromContext 提取上下文中已经累积的观测字段。
func ObserveFieldsFromContext(ctx context.Context) map[string]any {
if ctx == nil {
return map[string]any{}
}
raw, ok := ctx.Value(observeFieldsContextKey{}).(map[string]any)
if !ok || len(raw) == 0 {
return map[string]any{}
}
result := make(map[string]any, len(raw))
for key, value := range raw {
result[key] = value
}
return result
}
// ClassifyErrorCode 统一把常见错误压缩为稳定错误码,便于后续接入全局观测平台。
func ClassifyErrorCode(err error) string {
switch {
case err == nil:
return ""
case errors.Is(err, context.DeadlineExceeded):
return "DEADLINE_EXCEEDED"
case errors.Is(err, context.Canceled):
return "CANCELED"
default:
return "RAG_ERROR"
}
}
func shouldKeepObserveField(value any) bool {
if value == nil {
return false
}
if text, ok := value.(string); ok {
return strings.TrimSpace(text) != ""
}
return true
}

View File

@@ -0,0 +1,366 @@
package core
import (
"context"
"errors"
"fmt"
"log"
"runtime/debug"
"strings"
"time"
)
const (
	defaultTopK       = 8   // candidate count when req.TopK <= 0
	defaultThreshold  = 0   // score threshold when req.Threshold < 0
	defaultChunkSize  = 400 // chunk size when the configured one is invalid
	defaultChunkOvLap = 80  // chunk overlap fallback when overlap >= size
)

// Pipeline is the RAG core orchestrator.
//
// Responsibility boundaries:
//  1. Unifies the chunk/embed/retrieve/rerank flow;
//  2. Owns the failure-degradation semantics;
//  3. Carries no concrete business semantics (those come from CorpusAdapter).
type Pipeline struct {
	chunker  Chunker
	embedder Embedder
	store    VectorStore
	reranker Reranker
	logger   *log.Logger
	observer Observer
}

// NewPipeline wires the four stage dependencies with safe logging/observation defaults.
func NewPipeline(chunker Chunker, embedder Embedder, store VectorStore, reranker Reranker) *Pipeline {
	return &Pipeline{
		chunker:  chunker,
		embedder: embedder,
		store:    store,
		reranker: reranker,
		logger:   log.Default(),
		observer: NewNopObserver(),
	}
}

// SetLogger replaces the pipeline logger; a nil receiver or nil logger is ignored.
func (p *Pipeline) SetLogger(logger *log.Logger) {
	if p == nil || logger == nil {
		return
	}
	p.logger = logger
}

// SetObserver replaces the unified observer; a nil receiver or nil observer is ignored.
func (p *Pipeline) SetObserver(observer Observer) {
	if p == nil || observer == nil {
		return
	}
	p.observer = observer
}
// Ingest runs the unified ingest flow for raw business input.
//
// Steps:
//  1. The CorpusAdapter builds unified documents first so every corpus enters the same way;
//  2. Chunking and vectorization are centralized here instead of per business module;
//  3. A single Upsert writes everything; failures return directly and the caller decides on retries.
func (p *Pipeline) Ingest(
	ctx context.Context,
	corpus CorpusAdapter,
	input any,
	opt IngestOption,
) (result *IngestResult, err error) {
	defer p.recoverExecutionPanic(ctx, "ingest", &err)
	if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil {
		return nil, ErrNilDependency
	}
	if corpus == nil {
		return nil, errors.New("nil corpus adapter")
	}
	docs, err := corpus.BuildIngestDocuments(ctx, input)
	if err != nil {
		return nil, err
	}
	return p.IngestDocuments(ctx, corpus.Name(), docs, opt)
}
// IngestDocuments runs the unified ingest flow for already-normalized documents.
//
// Responsibility boundaries:
//  1. Handles documents that already went through CorpusAdapter mapping;
//  2. Owns unified chunking, vectorization and Upsert;
//  3. Does no business-input parsing, so the Runtime never re-builds documents
//     just to obtain a document_id.
func (p *Pipeline) IngestDocuments(
	ctx context.Context,
	corpusName string,
	docs []SourceDocument,
	opt IngestOption,
) (result *IngestResult, err error) {
	defer p.recoverExecutionPanic(ctx, "ingest_documents", &err)
	if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil {
		return nil, ErrNilDependency
	}
	if len(docs) == 0 {
		return &IngestResult{DocumentCount: 0, ChunkCount: 0}, nil
	}
	chunkOpt := normalizeChunkOption(opt.Chunk)
	chunks := make([]Chunk, 0, len(docs)*2)
	for _, doc := range docs {
		// 1. Chunk each document independently; abort on failure to avoid writing partial results.
		docChunks, chunkErr := p.chunker.Chunk(ctx, doc, chunkOpt)
		if chunkErr != nil {
			return nil, chunkErr
		}
		chunks = append(chunks, docChunks...)
	}
	if len(chunks) == 0 {
		return &IngestResult{DocumentCount: len(docs), ChunkCount: 0}, nil
	}
	texts := make([]string, 0, len(chunks))
	for _, chunk := range chunks {
		texts = append(texts, chunk.Text)
	}
	// Embedding "action" defaults to add for ingest-style calls.
	action := strings.TrimSpace(opt.Action)
	if action == "" {
		action = "add"
	}
	vectors, err := p.embedder.Embed(ctx, texts, action)
	if err != nil {
		return nil, err
	}
	if len(vectors) != len(chunks) {
		return nil, fmt.Errorf("embedding result length mismatch: chunks=%d vectors=%d", len(chunks), len(vectors))
	}
	rows := make([]VectorRow, 0, len(chunks))
	now := time.Now()
	for i, chunk := range chunks {
		// Copy chunk metadata before stamping pipeline-owned keys so the
		// source map is never mutated.
		metadata := cloneMap(chunk.Metadata)
		metadata["corpus"] = corpusName
		metadata["document_id"] = chunk.DocumentID
		metadata["chunk_order"] = chunk.Order
		rows = append(rows, VectorRow{
			ID:        chunk.ID,
			Vector:    vectors[i],
			Text:      chunk.Text,
			Metadata:  metadata,
			CreatedAt: now,
			UpdatedAt: now,
		})
	}
	if err = p.store.Upsert(ctx, rows); err != nil {
		return nil, err
	}
	return &IngestResult{
		DocumentCount: len(docs),
		ChunkCount:    len(chunks),
	}, nil
}
// Retrieve runs the unified retrieval flow.
//
// Steps:
//  1. Embed the query and run the vector search;
//  2. Apply the score threshold to drop low-quality candidates;
//  3. Optionally rerank; on rerank failure degrade to the original order and record the fallback.
func (p *Pipeline) Retrieve(
	ctx context.Context,
	corpus CorpusAdapter,
	req RetrieveRequest,
) (result *RetrieveResult, err error) {
	defer p.recoverExecutionPanic(ctx, "retrieve", &err)
	if p == nil || p.embedder == nil || p.store == nil {
		return nil, ErrNilDependency
	}
	query := strings.TrimSpace(req.Query)
	if query == "" {
		return nil, ErrInvalidQuery
	}
	topK := req.TopK
	if topK <= 0 {
		topK = defaultTopK
	}
	threshold := req.Threshold
	if threshold < 0 {
		threshold = defaultThreshold
	}
	filter := cloneMap(req.Filter)
	if corpus != nil {
		// 1. Merge the corpus filter first so recall never crosses corpora.
		corpusFilter, err := corpus.BuildRetrieveFilter(ctx, req.CorpusInput)
		if err != nil {
			return nil, err
		}
		filter = mergeMap(filter, corpusFilter)
		filter["corpus"] = corpus.Name()
	}
	action := strings.TrimSpace(req.Action)
	if action == "" {
		action = "search"
	}
	vectors, err := p.embedder.Embed(ctx, []string{query}, action)
	if err != nil {
		return nil, err
	}
	if len(vectors) != 1 {
		return nil, fmt.Errorf("embedding query length mismatch: %d", len(vectors))
	}
	scoredRows, err := p.store.Search(ctx, VectorSearchRequest{
		QueryVector: vectors[0],
		TopK:        topK,
		Filter:      filter,
	})
	if err != nil {
		return nil, err
	}
	// RawCount records the pre-threshold candidate count for observability.
	rawCount := len(scoredRows)
	candidates := make([]ScoredChunk, 0, len(scoredRows))
	for _, row := range scoredRows {
		if row.Score < threshold {
			continue
		}
		candidates = append(candidates, ScoredChunk{
			ChunkID:    row.Row.ID,
			DocumentID: asString(row.Row.Metadata["document_id"]),
			Text:       row.Row.Text,
			Score:      row.Score,
			Metadata:   cloneMap(row.Row.Metadata),
		})
	}
	result = &RetrieveResult{
		Items:        candidates,
		RawCount:     rawCount,
		FallbackUsed: false,
	}
	if len(candidates) == 0 || p.reranker == nil {
		return result, nil
	}
	reranked, rerankErr := p.reranker.Rerank(ctx, query, candidates, topK)
	if rerankErr != nil {
		// 2. A rerank failure does not abort the main flow; degrade to the original order.
		result.FallbackUsed = true
		result.FallbackReason = FallbackReasonRerankFailed
		if p.observer != nil {
			p.observer.Observe(ctx, ObserveEvent{
				Level:     ObserveLevelWarn,
				Component: "pipeline",
				Operation: "rerank_fallback",
				Fields: map[string]any{
					"status":          "fallback",
					"fallback_reason": FallbackReasonRerankFailed,
					"candidate_count": len(candidates),
					"top_k":           topK,
					"error":           rerankErr,
					"error_code":      ClassifyErrorCode(rerankErr),
				},
			})
		} else if p.logger != nil {
			p.logger.Printf("rag rerank fallback: reason=%s err=%v", FallbackReasonRerankFailed, rerankErr)
		}
		return result, nil
	}
	result.Items = reranked
	return result, nil
}
// Delete removes the given vector IDs from the store.
//
// NOTE(review): unlike Ingest/Retrieve this returns nil (silent success)
// when the pipeline or store is missing, instead of ErrNilDependency —
// confirm callers rely on that lenient behavior before unifying.
func (p *Pipeline) Delete(ctx context.Context, ids []string) error {
	if p == nil || p.store == nil {
		return nil
	}
	return p.store.Delete(ctx, ids)
}

// recoverExecutionPanic converts a panic raised during a pipeline operation
// into a regular error assigned through errPtr (the caller's named return).
// It must be invoked via defer so recover() takes effect.
func (p *Pipeline) recoverExecutionPanic(ctx context.Context, operation string, errPtr *error) {
	recovered := recover()
	if recovered == nil || errPtr == nil {
		return
	}
	panicErr := fmt.Errorf("rag pipeline panic recovered: operation=%s panic=%v", operation, recovered)
	*errPtr = panicErr
	// 1. Pipeline is the unified orchestration boundary around chunk/embed/store/rerank;
	//    a panic inside a third-party dependency must not kill the caller's request.
	// 2. After recover we continue with error semantics so runtime/service decide on
	//    degradation, fallback or logging.
	// 3. The stack goes only to the observability layer, never into the return value,
	//    keeping huge stack traces out of upper-layer business error messages.
	if p != nil && p.observer != nil {
		p.observer.Observe(ctx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "pipeline",
			Operation: operation + "_panic_recovered",
			Fields: map[string]any{
				"status":     "failed",
				"panic":      fmt.Sprintf("%v", recovered),
				"panic_type": fmt.Sprintf("%T", recovered),
				"error":      panicErr,
				"error_code": ClassifyErrorCode(panicErr),
				"stack":      string(debug.Stack()),
			},
		})
		return
	}
	if p != nil && p.logger != nil {
		p.logger.Printf("rag pipeline panic recovered: operation=%s panic=%v stack=%s", operation, recovered, string(debug.Stack()))
	}
}
// normalizeChunkOption fills invalid chunk parameters with safe defaults:
// a non-positive size falls back to defaultChunkSize, a negative overlap
// becomes zero, and an overlap reaching the chunk size is reduced until valid.
func normalizeChunkOption(opt ChunkOption) ChunkOption {
	if opt.ChunkSize <= 0 {
		opt.ChunkSize = defaultChunkSize
	}
	switch {
	case opt.ChunkOverlap < 0:
		opt.ChunkOverlap = 0
	case opt.ChunkOverlap >= opt.ChunkSize:
		opt.ChunkOverlap = defaultChunkOvLap
		if opt.ChunkOverlap >= opt.ChunkSize {
			opt.ChunkOverlap = opt.ChunkSize / 5
		}
	}
	return opt
}
// cloneMap returns a shallow copy of src; nil or empty input yields a fresh
// empty map so callers can always write to the result safely.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for key, value := range src {
		dst[key] = value
	}
	return dst
}

// mergeMap writes every entry of ext into base (allocating base when nil)
// and returns base; ext values win on key collisions.
func mergeMap(base map[string]any, ext map[string]any) map[string]any {
	if base == nil {
		base = make(map[string]any, len(ext))
	}
	for key, value := range ext {
		base[key] = value
	}
	return base
}

// asString renders a non-nil value with fmt's default formatting; nil maps to "".
func asString(v any) string {
	if v == nil {
		return ""
	}
	return fmt.Sprintf("%v", v)
}

View File

@@ -0,0 +1,94 @@
package core
import "time"
// SourceDocument is the unified corpus document model.
//
// Responsibility boundaries:
//  1. Describes only a raw document that can be chunked and indexed;
//  2. Carries no business workflow state.
type SourceDocument struct {
	ID        string
	Text      string
	Title     string
	Metadata  map[string]any
	CreatedAt time.Time
}

// Chunk is the standard chunking result.
type Chunk struct {
	ID         string
	DocumentID string // parent SourceDocument.ID
	Text       string
	Order      int // chunk position within the document
	Metadata   map[string]any
}

// ChunkOption controls chunking parameters.
type ChunkOption struct {
	ChunkSize    int
	ChunkOverlap int
}

// IngestOption controls ingest parameters.
type IngestOption struct {
	Chunk ChunkOption
	// Action selects the embedding mode (add/update/search).
	Action string
}

// IngestResult summarizes one ingest execution.
type IngestResult struct {
	DocumentCount int
	ChunkCount    int
}

// RetrieveRequest is the unified retrieval request.
type RetrieveRequest struct {
	Query       string
	TopK        int
	Threshold   float64
	Action      string
	Filter      map[string]any
	CorpusInput any // corpus-specific filter input consumed by CorpusAdapter
}

// ScoredChunk is a unified recall result.
type ScoredChunk struct {
	ChunkID    string
	DocumentID string
	Text       string
	Score      float64
	Metadata   map[string]any
}

// RetrieveResult summarizes a retrieval execution.
type RetrieveResult struct {
	Items          []ScoredChunk
	RawCount       int  // candidate count before threshold filtering
	FallbackUsed   bool // true when rerank degraded to the original order
	FallbackReason string
}

// VectorRow is the standard vector-store row.
type VectorRow struct {
	ID        string
	Vector    []float32
	Text      string
	Metadata  map[string]any
	CreatedAt time.Time
	UpdatedAt time.Time
}

// VectorSearchRequest is a vector search request.
type VectorSearchRequest struct {
	QueryVector []float32
	TopK        int
	Filter      map[string]any
}

// ScoredVectorRow is a vector search hit.
type ScoredVectorRow struct {
	Row   VectorRow
	Score float64
}

View File

@@ -0,0 +1,13 @@
package corpus
import (
"crypto/sha256"
"encoding/hex"
"strings"
)
// hashLikeText returns a short stable fingerprint of text: the input is
// lower-cased and trimmed, hashed with SHA-256, and the first 8 digest bytes
// are hex-encoded (16 characters).
func hashLikeText(text string) string {
	canonical := strings.TrimSpace(strings.ToLower(text))
	digest := sha256.Sum256([]byte(canonical))
	return hex.EncodeToString(digest[:8])
}

View File

@@ -0,0 +1,158 @@
package corpus
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/services/rag/core"
)
const memoryCorpusName = "memory"

// MemoryIngestItem is a single memory-corpus ingest entry.
type MemoryIngestItem struct {
	MemoryID         int64 // persisted memory ID; <= 0 means not yet persisted
	UserID           int   // mandatory owner; must be positive
	ConversationID   string
	AssistantID      string
	RunID            string
	MemoryType       string
	Title            string
	Content          string // indexable text; blank items are skipped
	Confidence       float64
	Importance       float64
	SensitivityLevel int
	IsExplicit       bool
	Status           string
	TTLAt            *time.Time // optional expiry, serialized as RFC3339 metadata
	CreatedAt        *time.Time // optional creation time; defaults to now during ingest
}

// MemoryRetrieveInput is the filter input for memory retrieval.
type MemoryRetrieveInput struct {
	UserID         int // mandatory; retrieval is always scoped per user
	ConversationID string
	AssistantID    string
	RunID          string
	MemoryType     string
}

// MemoryCorpus is the memory corpus adapter.
type MemoryCorpus struct{}

// NewMemoryCorpus constructs the stateless memory adapter.
func NewMemoryCorpus() *MemoryCorpus {
	return &MemoryCorpus{}
}

// Name returns the corpus identifier stamped into metadata and filters.
func (c *MemoryCorpus) Name() string {
	return memoryCorpusName
}
// BuildIngestDocuments maps memory ingest items (single, pointer or slice
// forms) into unified source documents.
//
// Rules:
//  1. user_id must be positive for every item;
//  2. Items with blank content are skipped silently;
//  3. The document ID is "memory:<id>"; items without a persisted ID fall
//     back to a content-hash-based ID scoped by user.
func (c *MemoryCorpus) BuildIngestDocuments(_ context.Context, input any) ([]core.SourceDocument, error) {
	items, err := toMemoryItems(input)
	if err != nil {
		return nil, err
	}
	result := make([]core.SourceDocument, 0, len(items))
	for _, item := range items {
		if item.UserID <= 0 {
			return nil, errors.New("memory ingest item user_id is invalid")
		}
		text := strings.TrimSpace(item.Content)
		if text == "" {
			continue
		}
		docID := fmt.Sprintf("memory:%d", item.MemoryID)
		if item.MemoryID <= 0 {
			docID = fmt.Sprintf("memory:uid:%d:%s", item.UserID, hashLikeText(text))
		}
		metadata := map[string]any{
			"user_id":           item.UserID,
			"conversation_id":   strings.TrimSpace(item.ConversationID),
			"assistant_id":      strings.TrimSpace(item.AssistantID),
			"run_id":            strings.TrimSpace(item.RunID),
			"memory_type":       strings.TrimSpace(strings.ToLower(item.MemoryType)),
			"title":             strings.TrimSpace(item.Title),
			"confidence":        item.Confidence,
			"importance":        item.Importance,
			"sensitivity_level": item.SensitivityLevel,
			"is_explicit":       item.IsExplicit,
			"status":            strings.TrimSpace(item.Status),
		}
		if item.TTLAt != nil {
			metadata["ttl_at"] = item.TTLAt.Format(time.RFC3339)
		}
		createdAt := time.Now()
		if item.CreatedAt != nil {
			createdAt = *item.CreatedAt
		}
		result = append(result, core.SourceDocument{
			ID:        docID,
			Text:      text,
			Title:     strings.TrimSpace(item.Title),
			Metadata:  metadata,
			CreatedAt: createdAt,
		})
	}
	return result, nil
}
// BuildRetrieveFilter maps a MemoryRetrieveInput (value or pointer) into a
// vector-store filter. user_id is mandatory; the remaining dimensions are
// optional and only added when non-blank.
func (c *MemoryCorpus) BuildRetrieveFilter(_ context.Context, req any) (map[string]any, error) {
	var input MemoryRetrieveInput
	switch value := req.(type) {
	case MemoryRetrieveInput:
		input = value
	case *MemoryRetrieveInput:
		if value == nil {
			return nil, errors.New("invalid memory retrieve input")
		}
		input = *value
	case nil:
		return nil, errors.New("memory retrieve input is nil")
	default:
		return nil, errors.New("invalid memory retrieve input")
	}
	if input.UserID <= 0 {
		return nil, errors.New("memory retrieve user_id is invalid")
	}
	filter := map[string]any{"user_id": input.UserID}
	addIfSet := func(key, raw string) {
		if v := strings.TrimSpace(raw); v != "" {
			filter[key] = v
		}
	}
	addIfSet("conversation_id", input.ConversationID)
	addIfSet("assistant_id", input.AssistantID)
	addIfSet("run_id", input.RunID)
	addIfSet("memory_type", strings.ToLower(input.MemoryType))
	return filter, nil
}
// toMemoryItems normalizes the supported ingest input shapes (single item,
// pointer, slice, slice of pointers) into a flat item slice; nil pointers in
// a pointer slice are skipped.
func toMemoryItems(input any) ([]MemoryIngestItem, error) {
	switch value := input.(type) {
	case MemoryIngestItem:
		return []MemoryIngestItem{value}, nil
	case *MemoryIngestItem:
		if value == nil {
			return nil, errors.New("memory ingest item is nil")
		}
		return []MemoryIngestItem{*value}, nil
	case []MemoryIngestItem:
		return value, nil
	case []*MemoryIngestItem:
		flattened := make([]MemoryIngestItem, 0, len(value))
		for _, item := range value {
			if item == nil {
				continue
			}
			flattened = append(flattened, *item)
		}
		return flattened, nil
	}
	return nil, errors.New("invalid memory ingest input")
}

View File

@@ -0,0 +1,163 @@
package corpus
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/LoveLosita/smartflow/backend/services/rag/core"
)
const webCorpusName = "web"

// WebIngestItem is a single web-corpus ingest entry.
type WebIngestItem struct {
	URL         string // mandatory source URL
	Title       string
	Content     string
	Snippet     string
	Domain      string
	QueryID     string // search question this page was fetched for
	SessionID   string
	PublishedAt *time.Time // optional, serialized as RFC3339 metadata
	FetchedAt   *time.Time // optional; also used as the document CreatedAt
	SourceRank  int
}

// WebRetrieveInput is the filter input for web retrieval.
type WebRetrieveInput struct {
	QueryID   string // at least one of QueryID/SessionID is required
	SessionID string
	Domain    string // optional domain narrowing
}

// WebCorpus is the web corpus adapter.
type WebCorpus struct{}

// NewWebCorpus constructs the stateless web adapter.
func NewWebCorpus() *WebCorpus {
	return &WebCorpus{}
}

// Name returns the corpus identifier stamped into metadata and filters.
func (c *WebCorpus) Name() string {
	return webCorpusName
}
// BuildIngestDocuments maps web ingest items (single, pointer or slice
// forms) into unified source documents.
//
// Rules:
//  1. Every item must carry a non-blank URL;
//  2. Items whose combined title/snippet/content text is blank are skipped;
//  3. The document ID hashes URL + text so identical page content dedupes stably.
func (c *WebCorpus) BuildIngestDocuments(_ context.Context, input any) ([]core.SourceDocument, error) {
	items, err := toWebItems(input)
	if err != nil {
		return nil, err
	}
	result := make([]core.SourceDocument, 0, len(items))
	for _, item := range items {
		url := strings.TrimSpace(item.URL)
		if url == "" {
			return nil, errors.New("web ingest item url is empty")
		}
		mainText := buildWebText(item)
		if strings.TrimSpace(mainText) == "" {
			continue
		}
		docID := fmt.Sprintf("web:%s", hashLikeText(url+"|"+mainText))
		metadata := map[string]any{
			"url":         url,
			"domain":      strings.TrimSpace(item.Domain),
			"query_id":    strings.TrimSpace(item.QueryID),
			"session_id":  strings.TrimSpace(item.SessionID),
			"source_rank": item.SourceRank,
		}
		if item.PublishedAt != nil {
			metadata["published_at"] = item.PublishedAt.Format(time.RFC3339)
		}
		if item.FetchedAt != nil {
			metadata["fetched_at"] = item.FetchedAt.Format(time.RFC3339)
		}
		createdAt := time.Now()
		if item.FetchedAt != nil {
			createdAt = *item.FetchedAt
		}
		result = append(result, core.SourceDocument{
			ID:        docID,
			Text:      mainText,
			Title:     strings.TrimSpace(item.Title),
			Metadata:  metadata,
			CreatedAt: createdAt,
		})
	}
	return result, nil
}
// BuildRetrieveFilter maps a WebRetrieveInput (value or pointer) into a
// vector-store filter; at least one of query_id/session_id is required so
// recall never crosses unrelated questions.
func (c *WebCorpus) BuildRetrieveFilter(_ context.Context, req any) (map[string]any, error) {
	var input WebRetrieveInput
	switch value := req.(type) {
	case WebRetrieveInput:
		input = value
	case *WebRetrieveInput:
		if value == nil {
			return nil, errors.New("invalid web retrieve input")
		}
		input = *value
	case nil:
		return nil, errors.New("web retrieve input is nil")
	default:
		return nil, errors.New("invalid web retrieve input")
	}
	queryID := strings.TrimSpace(input.QueryID)
	sessionID := strings.TrimSpace(input.SessionID)
	if queryID == "" && sessionID == "" {
		return nil, errors.New("web retrieve filter requires query_id or session_id")
	}
	filter := map[string]any{}
	if queryID != "" {
		filter["query_id"] = queryID
	}
	if sessionID != "" {
		filter["session_id"] = sessionID
	}
	if domain := strings.TrimSpace(input.Domain); domain != "" {
		filter["domain"] = domain
	}
	return filter, nil
}
// toWebItems normalizes the supported ingest input shapes (single item,
// pointer, slice, slice of pointers) into a flat item slice; nil pointers in
// a pointer slice are skipped.
func toWebItems(input any) ([]WebIngestItem, error) {
	switch value := input.(type) {
	case WebIngestItem:
		return []WebIngestItem{value}, nil
	case *WebIngestItem:
		if value == nil {
			return nil, errors.New("web ingest item is nil")
		}
		return []WebIngestItem{*value}, nil
	case []WebIngestItem:
		return value, nil
	case []*WebIngestItem:
		flattened := make([]WebIngestItem, 0, len(value))
		for _, item := range value {
			if item == nil {
				continue
			}
			flattened = append(flattened, *item)
		}
		return flattened, nil
	}
	return nil, errors.New("invalid web ingest input")
}

// buildWebText concatenates title, snippet and content (each trimmed, blanks
// skipped) into the indexable main text, separated by blank lines.
func buildWebText(item WebIngestItem) string {
	segments := make([]string, 0, 3)
	for _, raw := range []string{item.Title, item.Snippet, item.Content} {
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			segments = append(segments, trimmed)
		}
	}
	return strings.Join(segments, "\n\n")
}

View File

@@ -0,0 +1,208 @@
package embed
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
openaiembedding "github.com/cloudwego/eino-ext/libs/acl/openai"
einoembedding "github.com/cloudwego/eino/components/embedding"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
arkmodel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)
// EinoConfig describes the runtime parameters of the Eino embedding adapter.
type EinoConfig struct {
	APIKey    string // Ark API key (mandatory)
	BaseURL   string // Ark service root; misconfigured embedding endpoints are normalized away
	Model     string // embedding model name (mandatory)
	TimeoutMS int    // per-call timeout in milliseconds; <= 0 falls back to 1200ms
	Dimension int    // requested vector dimension; <= 0 keeps the model default
}

// EinoEmbedder is the Eino-based embedding adapter.
//
// Notes:
//  1. Exposes unified []float32 results to infra/rag, hiding underlying SDK differences.
//  2. Text embedding keeps the current stable OpenAI-compatible route so unrelated models are unaffected.
//  3. Multimodal embedding models go through the native Ark `/embeddings/multimodal`
//     endpoint because vision models are incompatible with the standard `/embeddings`.
type EinoEmbedder struct {
	textClient       einoembedding.Embedder // OpenAI-compatible text route; nil for multimodal models
	multimodalClient *arkruntime.Client     // native Ark route; nil for plain text models
	model            string
	timeout          time.Duration
	dimension        int
}
// NewEinoEmbedder builds the embedding adapter and selects the transport
// route based on the model family.
//
// Routing:
//  1. `doubao-embedding-vision-*` models do not support the standard `/embeddings`;
//  2. those are sent straight to the native Ark multimodal embedding API,
//     avoiding broken endpoint concatenation;
//  3. the text route is kept so plain text embedding models behave as before.
func NewEinoEmbedder(ctx context.Context, cfg EinoConfig) (*EinoEmbedder, error) {
	if strings.TrimSpace(cfg.APIKey) == "" {
		return nil, errors.New("eino embedder api key is empty")
	}
	if strings.TrimSpace(cfg.Model) == "" {
		return nil, errors.New("eino embedder model is empty")
	}
	timeout := 1200 * time.Millisecond
	if cfg.TimeoutMS > 0 {
		timeout = time.Duration(cfg.TimeoutMS) * time.Millisecond
	}
	baseURL := normalizeEmbeddingBaseURL(cfg.BaseURL)
	model := strings.TrimSpace(cfg.Model)
	httpClient := &http.Client{Timeout: timeout}
	if isMultimodalEmbeddingModel(model) {
		arkOptions := []arkruntime.ConfigOption{
			arkruntime.WithHTTPClient(httpClient),
		}
		if baseURL != "" {
			arkOptions = append(arkOptions, arkruntime.WithBaseUrl(baseURL))
		}
		return &EinoEmbedder{
			multimodalClient: arkruntime.NewClientWithApiKey(
				strings.TrimSpace(cfg.APIKey),
				arkOptions...,
			),
			model:     model,
			timeout:   timeout,
			dimension: cfg.Dimension,
		}, nil
	}
	clientCfg := &openaiembedding.EmbeddingConfig{
		APIKey:     strings.TrimSpace(cfg.APIKey),
		BaseURL:    baseURL,
		Model:      model,
		HTTPClient: httpClient,
	}
	if cfg.Dimension > 0 {
		clientCfg.Dimensions = &cfg.Dimension
	}
	client, err := openaiembedding.NewEmbeddingClient(ctx, clientCfg)
	if err != nil {
		return nil, err
	}
	return &EinoEmbedder{
		textClient: client,
		model:      model,
		timeout:    timeout,
		dimension:  cfg.Dimension,
	}, nil
}
// Embed converts texts into float32 vectors via whichever route the
// constructor selected; the action parameter is unused by this adapter.
func (e *EinoEmbedder) Embed(ctx context.Context, texts []string, _ string) (result [][]float32, err error) {
	if e == nil {
		return nil, errors.New("eino embedder is not initialized")
	}
	if len(texts) == 0 {
		return nil, nil
	}
	callCtx := ctx
	cancel := func() {}
	if e.timeout > 0 {
		callCtx, cancel = context.WithTimeout(ctx, e.timeout)
	}
	defer cancel()
	// 1. A panic inside the third-party SDK must not pierce into the RAG main path.
	// 2. Recover at the model-call boundary and convert to an error for upper-layer degradation.
	// 3. This keeps the memory write path and the agent reply path from crashing on a failed vector sync.
	defer func() {
		if recovered := recover(); recovered != nil {
			err = fmt.Errorf("eino embedder panic recovered: %v", recovered)
			result = nil
		}
	}()
	if e.multimodalClient != nil {
		return e.embedTextsWithMultimodalAPI(callCtx, texts)
	}
	if e.textClient == nil {
		return nil, errors.New("eino embedder client is not initialized")
	}
	vectors, err := e.textClient.EmbedStrings(callCtx, texts, einoembedding.WithModel(e.model))
	if err != nil {
		return nil, err
	}
	// Convert the SDK's float64 vectors into the unified float32 shape.
	result = make([][]float32, 0, len(vectors))
	for _, vector := range vectors {
		converted := make([]float32, len(vector))
		for i, value := range vector {
			converted[i] = float32(value)
		}
		result = append(result, converted)
	}
	return result, nil
}
// embedTextsWithMultimodalAPI embeds each text through the native Ark
// multimodal endpoint, issuing one request per text.
func (e *EinoEmbedder) embedTextsWithMultimodalAPI(ctx context.Context, texts []string) ([][]float32, error) {
	if e.multimodalClient == nil {
		return nil, errors.New("eino multimodal embedder client is not initialized")
	}
	vectors := make([][]float32, 0, len(texts))
	for _, text := range texts {
		// Re-declare so &text points at a per-iteration variable (required pre-Go 1.22).
		text := text
		req := arkmodel.MultiModalEmbeddingRequest{
			Model: e.model,
			Input: []arkmodel.MultimodalEmbeddingInput{
				{
					Type: arkmodel.MultiModalEmbeddingInputTypeText,
					Text: &text,
				},
			},
		}
		if e.dimension > 0 {
			req.Dimensions = &e.dimension
		}
		// 1. Ark's multimodal embedding request body is "one content made of several parts".
		// 2. Only text is sent here, so each text goes as its own request to avoid wrongly
		//    merging several texts into one multimodal sample.
		// 3. If batched multimodal embedding is ever needed, add a dedicated batch API
		//    instead of silently changing the semantics here.
		resp, err := e.multimodalClient.CreateMultiModalEmbeddings(ctx, req)
		if err != nil {
			return nil, err
		}
		converted := make([]float32, len(resp.Data.Embedding))
		copy(converted, resp.Data.Embedding)
		vectors = append(vectors, converted)
	}
	return vectors, nil
}
// isMultimodalEmbeddingModel reports whether the model name belongs to the
// Ark multimodal embedding family (doubao-embedding-vision-*), which must be
// called via /embeddings/multimodal instead of the standard /embeddings.
func isMultimodalEmbeddingModel(model string) bool {
	return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "doubao-embedding-vision-")
}

// normalizeEmbeddingBaseURL sanitizes the configured embedding base URL.
//
// Rules:
//  1. Config should hold the Ark service root, not a concrete embedding endpoint;
//  2. Two common misconfigurations are tolerated: a trailing "/embeddings" or
//     "/embeddings/multimodal" segment (matched case-insensitively);
//  3. The root is returned without a trailing slash so each SDK can append its
//     own suffix without duplication.
func normalizeEmbeddingBaseURL(raw string) string {
	baseURL := strings.TrimRight(strings.TrimSpace(raw), "/")
	if baseURL == "" {
		return ""
	}
	lower := strings.ToLower(baseURL)
	// Check the longer suffix first; slice instead of TrimSuffix so the
	// match stays case-insensitive regardless of the original casing.
	for _, suffix := range []string{"/embeddings/multimodal", "/embeddings"} {
		if strings.HasSuffix(lower, suffix) {
			return baseURL[:len(baseURL)-len(suffix)]
		}
	}
	return baseURL
}

View File

@@ -0,0 +1,46 @@
package embed
import (
"context"
"crypto/sha256"
"encoding/binary"
"strings"
)
const defaultDim = 16
// MockEmbedder 是本地可运行的占位向量化实现。
//
// 说明:
// 1. 该实现用于开发阶段打通链路,不代表真实语义向量质量;
// 2. 后续可替换为 Eino embedding 实现,接口保持不变。
type MockEmbedder struct {
dim int
}
func NewMockEmbedder(dim int) *MockEmbedder {
if dim <= 0 {
dim = defaultDim
}
return &MockEmbedder{dim: dim}
}
func (e *MockEmbedder) Embed(_ context.Context, texts []string, _ string) ([][]float32, error) {
result := make([][]float32, 0, len(texts))
for _, text := range texts {
result = append(result, e.embedOne(text))
}
return result, nil
}
func (e *MockEmbedder) embedOne(text string) []float32 {
normalized := strings.TrimSpace(strings.ToLower(text))
sum := sha256.Sum256([]byte(normalized))
vec := make([]float32, e.dim)
for i := 0; i < e.dim; i++ {
offset := (i * 4) % len(sum)
v := binary.BigEndian.Uint32(sum[offset : offset+4])
vec[i] = float32(v%1000) / 1000
}
return vec
}

View File

@@ -0,0 +1,142 @@
package rag
import (
"context"
"fmt"
"log"
"os"
"strings"
ragchunk "github.com/LoveLosita/smartflow/backend/services/rag/chunk"
ragconfig "github.com/LoveLosita/smartflow/backend/services/rag/config"
"github.com/LoveLosita/smartflow/backend/services/rag/core"
ragembed "github.com/LoveLosita/smartflow/backend/services/rag/embed"
ragrerank "github.com/LoveLosita/smartflow/backend/services/rag/rerank"
ragstore "github.com/LoveLosita/smartflow/backend/services/rag/store"
)
// FactoryDeps describes the optional dependencies of the runtime factory.
//
// Notes:
//  1. Logger is only the default sink while the project has no unified logging system;
//  2. Observer is the official observability slot, replaceable later by a
//     project-level logger/metrics/tracing adapter;
//  3. Business code still receives only a Runtime and never touches assembly details.
type FactoryDeps struct {
	Logger   *log.Logger
	Observer Observer
}

// NewRuntimeFromConfig assembles the RAG Runtime from configuration.
//
// Design notes:
//  1. Every low-level implementation choice is funneled here; business code
//     never news up store/embedder/reranker itself;
//  2. New providers should extend this factory instead of scattering selection
//     logic into business modules;
//  3. Observability is injected here once, so runtime/store/pipeline never log on their own.
func NewRuntimeFromConfig(ctx context.Context, cfg ragconfig.Config, deps FactoryDeps) (Runtime, error) {
	logger, observer := normalizeFactoryDeps(deps)
	embedder, err := buildEmbedder(ctx, cfg)
	if err != nil {
		return nil, err
	}
	store, err := buildStore(cfg, logger, observer)
	if err != nil {
		return nil, err
	}
	reranker, err := buildReranker(cfg, observer)
	if err != nil {
		return nil, err
	}
	pipeline := core.NewPipeline(ragchunk.NewTextChunker(), embedder, store, reranker)
	pipeline.SetLogger(logger)
	pipeline.SetObserver(observer)
	return newRuntime(cfg, pipeline, observer), nil
}

// normalizeFactoryDeps fills nil dependencies with safe defaults: the
// standard logger and a logger-backed observer.
func normalizeFactoryDeps(deps FactoryDeps) (*log.Logger, Observer) {
	logger := deps.Logger
	if logger == nil {
		logger = log.Default()
	}
	observer := deps.Observer
	if observer == nil {
		observer = NewLoggerObserver(logger)
	}
	return logger, observer
}
// buildEmbedder selects the embedding implementation by provider name
// (""/"mock" -> local mock, "eino" -> Ark-backed adapter).
func buildEmbedder(ctx context.Context, cfg ragconfig.Config) (core.Embedder, error) {
	switch strings.ToLower(strings.TrimSpace(cfg.EmbedProvider)) {
	case "", "mock":
		return ragembed.NewMockEmbedder(cfg.EmbedDimension), nil
	case "eino":
		// 1. RAG embedding shares the key source of the regular LLM route and reads ARK_API_KEY directly;
		// 2. this avoids maintaining an extra "env-name config -> read env" indirection and configuration forks;
		// 3. if multiple embedding credentials are ever needed, design explicit fields
		//    instead of implicitly passing env names around.
		apiKey := strings.TrimSpace(os.Getenv("ARK_API_KEY"))
		if apiKey == "" {
			return nil, fmt.Errorf("rag embed api key is empty: env=%s", "ARK_API_KEY")
		}
		return ragembed.NewEinoEmbedder(ctx, ragembed.EinoConfig{
			APIKey:    apiKey,
			BaseURL:   cfg.EmbedBaseURL,
			Model:     cfg.EmbedModel,
			TimeoutMS: cfg.EmbedTimeoutMS,
			Dimension: cfg.EmbedDimension,
		})
	default:
		return nil, fmt.Errorf("unsupported rag embed provider: %s", cfg.EmbedProvider)
	}
}

// buildStore selects the vector store implementation by name
// (""/"inmemory" -> local store, "milvus" -> Milvus-backed store).
func buildStore(cfg ragconfig.Config, logger *log.Logger, observer Observer) (core.VectorStore, error) {
	switch strings.ToLower(strings.TrimSpace(cfg.Store)) {
	case "", "inmemory":
		return ragstore.NewInMemoryVectorStore(), nil
	case "milvus":
		return ragstore.NewMilvusStore(ragstore.MilvusConfig{
			Address:          cfg.MilvusAddress,
			Token:            cfg.MilvusToken,
			DBName:           cfg.MilvusDBName,
			CollectionName:   cfg.MilvusCollectionName,
			RequestTimeoutMS: cfg.MilvusRequestTimeoutMS,
			Dimension:        cfg.EmbedDimension,
			MetricType:       cfg.MilvusMetricType,
			Logger:           logger,
			Observer:         observer,
		})
	default:
		return nil, fmt.Errorf("unsupported rag store: %s", cfg.Store)
	}
}

// buildReranker selects the reranker; the "eino" provider is not implemented
// yet and degrades to noop with a warning event.
//
// NOTE(review): the fallback event uses context.Background() because no ctx
// reaches this helper — acceptable for a one-off assembly-time warning.
func buildReranker(cfg ragconfig.Config, observer Observer) (core.Reranker, error) {
	if !cfg.RerankerEnabled {
		return ragrerank.NewNoopReranker(), nil
	}
	switch strings.ToLower(strings.TrimSpace(cfg.RerankerProvider)) {
	case "", "noop":
		return ragrerank.NewNoopReranker(), nil
	case "eino":
		if observer != nil {
			observer.Observe(context.Background(), ObserveEvent{
				Level:     ObserveLevelWarn,
				Component: "factory",
				Operation: "reranker_fallback",
				Fields: map[string]any{
					"provider":        "eino",
					"status":          "fallback",
					"fallback_target": "noop",
					"reason":          "reranker_not_implemented",
				},
			})
		}
		return ragrerank.NewNoopReranker(), nil
	default:
		return nil, fmt.Errorf("unsupported rag reranker provider: %s", cfg.RerankerProvider)
	}
}

View File

@@ -0,0 +1,32 @@
package rag
import (
"log"
"github.com/LoveLosita/smartflow/backend/services/rag/core"
)
// ObserveLevel re-exports the unified observation level so the bootstrap
// layer does not depend directly on core internals.
type ObserveLevel = core.ObserveLevel

const (
	ObserveLevelInfo  = core.ObserveLevelInfo
	ObserveLevelWarn  = core.ObserveLevelWarn
	ObserveLevelError = core.ObserveLevelError
)

// ObserveEvent re-exports the unified observation event.
type ObserveEvent = core.ObserveEvent

// Observer re-exports the unified observation interface.
type Observer = core.Observer

// NewNopObserver returns the no-op implementation.
func NewNopObserver() Observer {
	return core.NewNopObserver()
}

// NewLoggerObserver returns the standard-log adapter.
func NewLoggerObserver(logger *log.Logger) Observer {
	return core.NewLoggerObserver(logger)
}

View File

@@ -0,0 +1,23 @@
package rag
import (
"github.com/LoveLosita/smartflow/backend/services/rag/chunk"
"github.com/LoveLosita/smartflow/backend/services/rag/core"
"github.com/LoveLosita/smartflow/backend/services/rag/embed"
"github.com/LoveLosita/smartflow/backend/services/rag/rerank"
"github.com/LoveLosita/smartflow/backend/services/rag/store"
)
// NewDefaultPipeline builds a runnable default RAG pipeline.
//
// Current policy:
//  1. Uses the local MockEmbedder + InMemoryStore so the pipeline runs with zero external dependencies;
//  2. Switching to Milvus / Eino later only swaps the dependencies without changing call sites.
func NewDefaultPipeline() *core.Pipeline {
	return core.NewPipeline(
		chunk.NewTextChunker(),
		embed.NewMockEmbedder(16),
		store.NewInMemoryVectorStore(),
		rerank.NewNoopReranker(),
	)
}

View File

@@ -0,0 +1,19 @@
package rerank
import (
"context"
"errors"
"github.com/LoveLosita/smartflow/backend/services/rag/core"
)
// EinoReranker is a placeholder for the Eino-backed reranker.
type EinoReranker struct{}

// NewEinoReranker returns the placeholder reranker.
func NewEinoReranker() *EinoReranker {
	return &EinoReranker{}
}

// Rerank always fails because the Eino implementation is not wired in yet;
// the factory currently falls back to the noop reranker instead of using this.
func (r *EinoReranker) Rerank(_ context.Context, _ string, _ []core.ScoredChunk, _ int) ([]core.ScoredChunk, error) {
	return nil, errors.New("eino reranker is not implemented yet")
}

View File

@@ -0,0 +1,30 @@
package rerank
import (
"context"
"sort"
"github.com/LoveLosita/smartflow/backend/services/rag/core"
)
// NoopReranker is the default reranker: it only orders candidates by their
// existing score (descending) and truncates to topK.
type NoopReranker struct{}

// NewNoopReranker returns the score-order reranker.
func NewNoopReranker() *NoopReranker {
	return &NoopReranker{}
}

// Rerank stably sorts a copy of candidates by descending score; topK <= 0 or
// topK >= len means "return everything". Empty input yields nil.
func (r *NoopReranker) Rerank(_ context.Context, _ string, candidates []core.ScoredChunk, topK int) ([]core.ScoredChunk, error) {
	if len(candidates) == 0 {
		return nil, nil
	}
	ordered := append([]core.ScoredChunk(nil), candidates...)
	sort.SliceStable(ordered, func(a, b int) bool {
		return ordered[a].Score > ordered[b].Score
	})
	if topK > 0 && topK < len(ordered) {
		return ordered[:topK], nil
	}
	return ordered, nil
}

View File

@@ -0,0 +1,98 @@
package retrieve
import (
"context"
"fmt"
"strings"
"github.com/LoveLosita/smartflow/backend/services/rag/core"
)
// VectorRetriever is the generic retriever: it embeds the query and then
// runs a vector search against the store.
type VectorRetriever struct {
	embedder core.Embedder
	store    core.VectorStore
}

// NewVectorRetriever wires an embedder and a vector store together.
func NewVectorRetriever(embedder core.Embedder, store core.VectorStore) *VectorRetriever {
	retriever := new(VectorRetriever)
	retriever.embedder = embedder
	retriever.store = store
	return retriever
}
// Retrieve embeds the trimmed query and runs a vector search, dropping hits
// whose score falls below req.Threshold. TopK defaults to 8 and Action to
// "search" when unset. Any panic from a dependency is converted into an
// error by the deferred recover, so callers never observe a panic.
func (r *VectorRetriever) Retrieve(ctx context.Context, req core.RetrieveRequest) (result []core.ScoredChunk, err error) {
	// Named return values let the deferred handler replace err on panic.
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		err = fmt.Errorf("vector retriever panic recovered: %v", recovered)
	}()
	if r == nil || r.embedder == nil || r.store == nil {
		return nil, core.ErrNilDependency
	}
	query := strings.TrimSpace(req.Query)
	if query == "" {
		return nil, core.ErrInvalidQuery
	}
	topK := req.TopK
	if topK <= 0 {
		topK = 8 // default fan-out when the caller does not specify one
	}
	action := strings.TrimSpace(req.Action)
	if action == "" {
		action = "search"
	}
	vectors, err := r.embedder.Embed(ctx, []string{query}, action)
	if err != nil {
		return nil, err
	}
	// One query in, exactly one vector out; anything else is a contract breach.
	if len(vectors) != 1 {
		return nil, fmt.Errorf("embedding query length mismatch: %d", len(vectors))
	}
	rows, err := r.store.Search(ctx, core.VectorSearchRequest{
		QueryVector: vectors[0],
		TopK:        topK,
		Filter:      req.Filter,
	})
	if err != nil {
		return nil, err
	}
	result = make([]core.ScoredChunk, 0, len(rows))
	for _, row := range rows {
		// Threshold filtering happens client-side; the store returns raw hits.
		if row.Score < req.Threshold {
			continue
		}
		result = append(result, core.ScoredChunk{
			ChunkID:    row.Row.ID,
			DocumentID: asString(row.Row.Metadata["document_id"]),
			Text:       row.Row.Text,
			Score:      row.Score,
			Metadata:   cloneMap(row.Row.Metadata),
		})
	}
	return result, nil
}
// cloneMap returns a shallow copy of src. A nil or empty input still yields
// a fresh, non-nil empty map so callers can mutate the result safely.
func cloneMap(src map[string]any) map[string]any {
	out := make(map[string]any, len(src))
	for key, value := range src {
		out[key] = value
	}
	return out
}
// asString renders an arbitrary metadata value via %v; nil becomes "".
func asString(v any) string {
	if v == nil {
		return ""
	}
	rendered := fmt.Sprintf("%v", v)
	return rendered
}

View File

@@ -0,0 +1,434 @@
package rag
import (
"context"
"fmt"
"runtime/debug"
"strings"
"time"
ragconfig "github.com/LoveLosita/smartflow/backend/services/rag/config"
"github.com/LoveLosita/smartflow/backend/services/rag/core"
"github.com/LoveLosita/smartflow/backend/services/rag/corpus"
)
// runtime is the concrete Runtime implementation: it bundles configuration,
// the shared pipeline, both corpus adapters and an observer.
type runtime struct {
	cfg          ragconfig.Config
	pipeline     *core.Pipeline
	memoryCorpus *corpus.MemoryCorpus
	webCorpus    *corpus.WebCorpus
	observer     Observer
}

// newRuntime assembles a runtime; a nil observer falls back to the no-op one.
func newRuntime(cfg ragconfig.Config, pipeline *core.Pipeline, observer Observer) Runtime {
	effective := observer
	if effective == nil {
		effective = NewNopObserver()
	}
	rt := &runtime{
		cfg:          cfg,
		pipeline:     pipeline,
		memoryCorpus: corpus.NewMemoryCorpus(),
		webCorpus:    corpus.NewWebCorpus(),
		observer:     effective,
	}
	return rt
}
// IngestMemory is the unified entry point for memory corpus ingestion: it
// maps the public request items onto corpus-level ingest items and delegates
// to the shared ingest path. Panics are converted to errors by the deferred
// recover handler.
func (r *runtime) IngestMemory(ctx context.Context, req MemoryIngestRequest) (result *IngestResult, err error) {
	defer r.recoverPublicPanic(ctx, req.TraceID, "memory", normalizeAction(req.Action, "add"), "ingest", &err)
	corpusItems := make([]corpus.MemoryIngestItem, 0, len(req.Items))
	for _, src := range req.Items {
		corpusItems = append(corpusItems, corpus.MemoryIngestItem{
			MemoryID:         src.MemoryID,
			UserID:           src.UserID,
			ConversationID:   src.ConversationID,
			AssistantID:      src.AssistantID,
			RunID:            src.RunID,
			MemoryType:       src.MemoryType,
			Title:            src.Title,
			Content:          src.Content,
			Confidence:       src.Confidence,
			Importance:       src.Importance,
			SensitivityLevel: src.SensitivityLevel,
			IsExplicit:       src.IsExplicit,
			Status:           src.Status,
			TTLAt:            src.TTLAt,
			CreatedAt:        src.CreatedAt,
		})
	}
	return r.ingestWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, corpusItems, req.Action)
}
// RetrieveMemory is the unified entry point for memory corpus retrieval.
// It maps the request onto the core retrieve pipeline; when more than one
// memory type is requested, it additionally applies a second-pass type
// filter on the returned hits.
func (r *runtime) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (result *RetrieveResult, err error) {
	defer r.recoverPublicPanic(ctx, req.TraceID, "memory", normalizeAction(req.Action, "search"), "retrieve", &err)
	corpusInput := corpus.MemoryRetrieveInput{
		UserID:         req.UserID,
		ConversationID: req.ConversationID,
		AssistantID:    req.AssistantID,
		RunID:          req.RunID,
	}
	// A single requested type can be pushed down as an equality filter.
	if len(req.MemoryTypes) == 1 {
		corpusInput.MemoryType = req.MemoryTypes[0]
	}
	result, err = r.retrieveWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, core.RetrieveRequest{
		Query:       req.Query,
		TopK:        normalizeTopK(req.TopK, r.cfg.TopK),
		Threshold:   normalizeThreshold(req.Threshold, r.cfg.Threshold),
		Action:      normalizeAction(req.Action, "search"),
		CorpusInput: corpusInput,
	})
	if err != nil {
		return nil, err
	}
	if len(req.MemoryTypes) <= 1 {
		return result, nil
	}
	// 1. The underlying filter layer currently supports equality filters only,
	//    so the runtime keeps doing the multi-type second-pass filtering here;
	// 2. this avoids leaking a "memory_type in (...)" implementation detail
	//    into every Store;
	// 3. once the lower-level filter capability is unified, consider pushing
	//    this logic further down.
	allowed := make(map[string]struct{}, len(req.MemoryTypes))
	for _, item := range req.MemoryTypes {
		value := strings.TrimSpace(strings.ToLower(item))
		if value == "" {
			continue
		}
		allowed[value] = struct{}{}
	}
	filtered := make([]RetrieveHit, 0, len(result.Items))
	for _, item := range result.Items {
		memoryType := strings.TrimSpace(strings.ToLower(asString(item.Metadata["memory_type"])))
		// An all-blank MemoryTypes list leaves "allowed" empty: keep all hits.
		if len(allowed) > 0 {
			if _, ok := allowed[memoryType]; !ok {
				continue
			}
		}
		filtered = append(filtered, item)
	}
	result.Items = filtered
	// Re-apply the caller's TopK after filtering shrank the candidate list.
	if req.TopK > 0 && len(result.Items) > req.TopK {
		result.Items = result.Items[:req.TopK]
	}
	return result, nil
}
// DeleteMemory removes the given document vectors from the memory corpus.
// A nil receiver, a missing pipeline or an empty id list is a no-op.
func (r *runtime) DeleteMemory(ctx context.Context, documentIDs []string) (err error) {
	defer r.recoverPublicPanic(ctx, "", "memory", "delete", "delete", &err)
	if r == nil || r.pipeline == nil {
		return nil
	}
	if len(documentIDs) == 0 {
		return nil
	}
	return r.pipeline.Delete(ctx, documentIDs)
}
// IngestWeb is the unified entry point for web corpus ingestion: it maps the
// public request items onto corpus-level ingest items and delegates to the
// shared ingest path. Panics are converted to errors by the deferred handler.
func (r *runtime) IngestWeb(ctx context.Context, req WebIngestRequest) (result *IngestResult, err error) {
	defer r.recoverPublicPanic(ctx, req.TraceID, "web", normalizeAction(req.Action, "add"), "ingest", &err)
	corpusItems := make([]corpus.WebIngestItem, 0, len(req.Items))
	for _, src := range req.Items {
		corpusItems = append(corpusItems, corpus.WebIngestItem{
			URL:         src.URL,
			Title:       src.Title,
			Content:     src.Content,
			Snippet:     src.Snippet,
			Domain:      src.Domain,
			QueryID:     src.QueryID,
			SessionID:   src.SessionID,
			PublishedAt: src.PublishedAt,
			FetchedAt:   src.FetchedAt,
			SourceRank:  src.SourceRank,
		})
	}
	return r.ingestWithCorpus(ctx, req.TraceID, "web", r.webCorpus, corpusItems, req.Action)
}
// RetrieveWeb is the unified entry point for web corpus retrieval. Defaults
// are resolved via the normalize* helpers before delegating to the shared
// retrieve path; panics become errors via the deferred handler.
func (r *runtime) RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (result *RetrieveResult, err error) {
	defer r.recoverPublicPanic(ctx, req.TraceID, "web", normalizeAction(req.Action, "search"), "retrieve", &err)
	coreReq := core.RetrieveRequest{
		Query:     req.Query,
		TopK:      normalizeTopK(req.TopK, r.cfg.TopK),
		Threshold: normalizeThreshold(req.Threshold, r.cfg.Threshold),
		Action:    normalizeAction(req.Action, "search"),
		CorpusInput: corpus.WebRetrieveInput{
			QueryID:   req.QueryID,
			SessionID: req.SessionID,
			Domain:    req.Domain,
		},
	}
	return r.retrieveWithCorpus(ctx, req.TraceID, "web", r.webCorpus, coreReq)
}
// ingestWithCorpus is the shared ingest path: the corpus adapter builds
// documents from the typed input, the pipeline chunks/embeds/stores them,
// and every outcome (build failure, pipeline failure, success) emits a
// structured observe event carrying the latency.
func (r *runtime) ingestWithCorpus(
	ctx context.Context,
	traceID string,
	corpusName string,
	adapter core.CorpusAdapter,
	input any,
	action string,
) (*IngestResult, error) {
	start := time.Now()
	if r == nil || r.pipeline == nil || adapter == nil {
		return nil, core.ErrNilDependency
	}
	action = normalizeAction(action, "add")
	observeCtx := newObserveContext(ctx, traceID, corpusName, action)
	docs, err := adapter.BuildIngestDocuments(observeCtx, input)
	if err != nil {
		r.observe(observeCtx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "runtime",
			Operation: "ingest",
			Fields: map[string]any{
				"status":      "failed",
				"latency_ms":  time.Since(start).Milliseconds(),
				"phase":       "build_documents",
				"error":       err,
				"error_code":  core.ClassifyErrorCode(err),
				"input_count": estimateInputCount(input),
			},
		})
		return nil, err
	}
	// Collect the ids up front so callers can later delete exactly these docs.
	docIDs := make([]string, 0, len(docs))
	for _, doc := range docs {
		docIDs = append(docIDs, doc.ID)
	}
	result, err := r.pipeline.IngestDocuments(observeCtx, adapter.Name(), docs, core.IngestOption{
		Chunk: core.ChunkOption{
			ChunkSize:    r.cfg.ChunkSize,
			ChunkOverlap: r.cfg.ChunkOverlap,
		},
		Action: action,
	})
	if err != nil {
		r.observe(observeCtx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "runtime",
			Operation: "ingest",
			Fields: map[string]any{
				"status":         "failed",
				"latency_ms":     time.Since(start).Milliseconds(),
				"document_count": len(docs),
				"error":          err,
				"error_code":     core.ClassifyErrorCode(err),
			},
		})
		return nil, err
	}
	r.observe(observeCtx, ObserveEvent{
		Level:     ObserveLevelInfo,
		Component: "runtime",
		Operation: "ingest",
		Fields: map[string]any{
			"status":         "success",
			"latency_ms":     time.Since(start).Milliseconds(),
			"document_count": result.DocumentCount,
			"chunk_count":    result.ChunkCount,
		},
	})
	return &IngestResult{
		DocumentCount: result.DocumentCount,
		ChunkCount:    result.ChunkCount,
		DocumentIDs:   docIDs,
	}, nil
}
// retrieveWithCorpus is the shared retrieve path: it applies the configured
// timeout (RetrieveTimeoutMS > 0), runs the pipeline with the given corpus
// adapter, copies hits into public RetrieveHit values, and emits structured
// observe events for both failure and success.
func (r *runtime) retrieveWithCorpus(
	ctx context.Context,
	traceID string,
	corpusName string,
	adapter core.CorpusAdapter,
	req core.RetrieveRequest,
) (*RetrieveResult, error) {
	start := time.Now()
	if r == nil || r.pipeline == nil || adapter == nil {
		return nil, core.ErrNilDependency
	}
	action := normalizeAction(req.Action, "search")
	req.Action = action
	observeCtx := newObserveContext(ctx, traceID, corpusName, action)
	// Timeout is optional: when disabled, cancel stays a no-op.
	timeoutCtx := observeCtx
	cancel := func() {}
	if r.cfg.RetrieveTimeoutMS > 0 {
		timeoutCtx, cancel = context.WithTimeout(observeCtx, time.Duration(r.cfg.RetrieveTimeoutMS)*time.Millisecond)
	}
	defer cancel()
	result, err := r.pipeline.Retrieve(timeoutCtx, adapter, req)
	if err != nil {
		r.observe(observeCtx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "runtime",
			Operation: "retrieve",
			Fields: map[string]any{
				"status":     "failed",
				"latency_ms": time.Since(start).Milliseconds(),
				"query_len":  len(strings.TrimSpace(req.Query)),
				"top_k":      req.TopK,
				"threshold":  req.Threshold,
				"error":      err,
				"error_code": core.ClassifyErrorCode(err),
			},
		})
		return nil, err
	}
	// Copy each hit (including a cloned metadata map) so callers cannot
	// mutate pipeline-internal state through the returned result.
	items := make([]RetrieveHit, 0, len(result.Items))
	for _, item := range result.Items {
		items = append(items, RetrieveHit{
			ChunkID:    item.ChunkID,
			DocumentID: item.DocumentID,
			Text:       item.Text,
			Score:      item.Score,
			Metadata:   cloneMap(item.Metadata),
		})
	}
	r.observe(observeCtx, ObserveEvent{
		Level:     ObserveLevelInfo,
		Component: "runtime",
		Operation: "retrieve",
		Fields: map[string]any{
			"status":          "success",
			"latency_ms":      time.Since(start).Milliseconds(),
			"query_len":       len(strings.TrimSpace(req.Query)),
			"top_k":           req.TopK,
			"threshold":       req.Threshold,
			"raw_count":       result.RawCount,
			"hit_count":       len(result.Items),
			"fallback_used":   result.FallbackUsed,
			"fallback_reason": result.FallbackReason,
		},
	})
	return &RetrieveResult{
		Items:          items,
		RawCount:       result.RawCount,
		FallbackUsed:   result.FallbackUsed,
		FallbackReason: result.FallbackReason,
	}, nil
}
// observe forwards an event to the configured observer, tolerating both a
// nil receiver and a nil observer.
func (r *runtime) observe(ctx context.Context, event ObserveEvent) {
	if r == nil {
		return
	}
	if r.observer == nil {
		return
	}
	r.observer.Observe(ctx, event)
}
// recoverPublicPanic is deferred by every public runtime method. It converts
// a recovered panic into the method's returned error (via errPtr) and records
// a structured observe event including the stack.
func (r *runtime) recoverPublicPanic(
	ctx context.Context,
	traceID string,
	corpusName string,
	action string,
	operation string,
	errPtr *error,
) {
	recovered := recover()
	if recovered == nil || errPtr == nil {
		return
	}
	// 1. runtime is the final method surface the RAG service exposes to the
	//    business side; no lower-layer panic may escape into caller goroutines.
	// 2. The panic is uniformly converted into an error plus a structured
	//    observe event, which makes it traceable which dependency misbehaved.
	// 3. The stack is kept so the root cause stays diagnosable even though the
	//    process does not crash — a bare "recovered" message alone would not
	//    allow a post-mortem.
	panicErr := fmt.Errorf("rag runtime panic recovered: corpus=%s operation=%s panic=%v", corpusName, operation, recovered)
	*errPtr = panicErr
	observeCtx := newObserveContext(ctx, traceID, corpusName, action)
	r.observe(observeCtx, ObserveEvent{
		Level:     ObserveLevelError,
		Component: "runtime",
		Operation: operation + "_panic_recovered",
		Fields: map[string]any{
			"status":     "failed",
			"panic":      fmt.Sprintf("%v", recovered),
			"panic_type": fmt.Sprintf("%T", recovered),
			"error":      panicErr,
			"error_code": core.ClassifyErrorCode(panicErr),
			"stack":      string(debug.Stack()),
		},
	})
}
// newObserveContext attaches corpus/action labels (and the trace id when
// non-blank) to the context so downstream observe events share them.
func newObserveContext(ctx context.Context, traceID string, corpusName string, action string) context.Context {
	fields := map[string]any{
		"corpus": corpusName,
		"action": action,
	}
	trimmed := strings.TrimSpace(traceID)
	if trimmed != "" {
		fields["trace_id"] = trimmed
	}
	return core.WithObserveFields(ctx, fields)
}
// estimateInputCount best-effort counts ingest items for observability only;
// unknown input shapes report 0.
func estimateInputCount(input any) int {
	if items, ok := input.([]corpus.MemoryIngestItem); ok {
		return len(items)
	}
	if items, ok := input.([]corpus.WebIngestItem); ok {
		return len(items)
	}
	return 0
}
// normalizeAction trims the action and substitutes fallback when blank.
func normalizeAction(action string, fallback string) string {
	if trimmed := strings.TrimSpace(action); trimmed != "" {
		return trimmed
	}
	return fallback
}
// normalizeTopK resolves the effective topK: an explicit positive value
// wins, then a positive configured fallback, then the hard default of 8.
func normalizeTopK(topK int, fallback int) int {
	switch {
	case topK > 0:
		return topK
	case fallback > 0:
		return fallback
	default:
		return 8
	}
}
// normalizeThreshold resolves the score threshold. Negative values are the
// "unset" sentinel: a non-negative threshold (including 0) is taken as-is,
// otherwise a non-negative fallback applies, otherwise 0.
func normalizeThreshold(threshold float64, fallback float64) float64 {
	switch {
	case threshold >= 0:
		return threshold
	case fallback >= 0:
		return fallback
	default:
		return 0
	}
}
// cloneMap shallow-copies src; nil/empty input yields a fresh non-nil map.
func cloneMap(src map[string]any) map[string]any {
	clone := make(map[string]any, len(src))
	for k, v := range src {
		clone[k] = v
	}
	return clone
}
// asString renders v with %v and strips surrounding whitespace; nil -> "".
func asString(v any) string {
	if v == nil {
		return ""
	}
	rendered := fmt.Sprintf("%v", v)
	return strings.TrimSpace(rendered)
}

View File

@@ -0,0 +1,111 @@
package rag
import (
"context"
ragconfig "github.com/LoveLosita/smartflow/backend/services/rag/config"
)
// Options carries the low-level runtime the rag-service needs to hold.
type Options struct {
	Runtime Runtime
}

// Service is the single public entry point exposed by rag-service.
//
// Responsibility boundary:
//  1. Hold the runtime and funnel both the memory and web capability lines
//     through the service layer.
//  2. Assemble the runtime from configuration inside the service entry.
//  3. Keep chunk / embed / store implementation details out of this layer;
//     they live in sub-packages of the service tree.
type Service struct {
	runtime Runtime
}

// New builds a service around a caller-supplied runtime.
func New(opts Options) *Service {
	return &Service{runtime: opts.Runtime}
}
// NewFromConfig builds a self-contained RAG service from the service-tree
// configuration and factory dependencies. When RAG is disabled, the service
// is returned without a runtime and every public method degrades to a no-op.
func NewFromConfig(ctx context.Context, cfg ragconfig.Config, deps FactoryDeps) (*Service, error) {
	if !cfg.Enabled {
		return New(Options{}), nil
	}
	rt, err := NewRuntimeFromConfig(ctx, cfg, deps)
	if err != nil {
		return nil, err
	}
	return NewWithRuntime(rt), nil
}
// Runtime returns the runtime currently held by the service (nil-safe).
func (s *Service) Runtime() Runtime {
	if s != nil {
		return s.runtime
	}
	return nil
}
// ensureContext shields the runtime from nil contexts passed by legacy call
// sites. Previously only DeleteMemory carried this guard; applying it to all
// delegating methods makes the service surface consistent.
func ensureContext(ctx context.Context) context.Context {
	if ctx == nil {
		return context.Background()
	}
	return ctx
}

// IngestMemory writes memory corpus items. Without a runtime the service
// degrades to a no-op and returns (nil, nil); callers must tolerate a nil
// result in that case.
func (s *Service) IngestMemory(ctx context.Context, req MemoryIngestRequest) (*IngestResult, error) {
	if s == nil || s.runtime == nil {
		return nil, nil
	}
	return s.runtime.IngestMemory(ensureContext(ctx), req)
}

// RetrieveMemory searches the memory corpus; no-op (nil, nil) without a runtime.
func (s *Service) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (*RetrieveResult, error) {
	if s == nil || s.runtime == nil {
		return nil, nil
	}
	return s.runtime.RetrieveMemory(ensureContext(ctx), req)
}

// DeleteMemory deletes the given memory documents; no-op without a runtime.
func (s *Service) DeleteMemory(ctx context.Context, documentIDs []string) error {
	if s == nil || s.runtime == nil {
		return nil
	}
	return s.runtime.DeleteMemory(ensureContext(ctx), documentIDs)
}

// IngestWeb writes web corpus items; no-op (nil, nil) without a runtime.
func (s *Service) IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestResult, error) {
	if s == nil || s.runtime == nil {
		return nil, nil
	}
	return s.runtime.IngestWeb(ensureContext(ctx), req)
}

// RetrieveWeb searches the web corpus; no-op (nil, nil) without a runtime.
func (s *Service) RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (*RetrieveResult, error) {
	if s == nil || s.runtime == nil {
		return nil, nil
	}
	return s.runtime.RetrieveWeb(ensureContext(ctx), req)
}

// EnsureRuntime returns a runtime reference that can be passed further down.
func (s *Service) EnsureRuntime() Runtime {
	if s == nil {
		return nil
	}
	return s.runtime
}

// SetRuntime allows late runtime injection during the assembly phase.
func (s *Service) SetRuntime(runtime Runtime) {
	if s == nil {
		return
	}
	s.runtime = runtime
}

// NewWithRuntime builds a service around an explicit runtime.
func NewWithRuntime(runtime Runtime) *Service {
	return New(Options{Runtime: runtime})
}

View File

@@ -0,0 +1,182 @@
package store
import (
"context"
"errors"
"fmt"
"math"
"sort"
"strings"
"sync"
"time"
"github.com/LoveLosita/smartflow/backend/services/rag/core"
)
// InMemoryVectorStore is a vector store for local development.
//
// Notes:
//  1. For development/debugging only; not recommended for production.
//  2. Real deployments can swap in MilvusStore behind the same interface.
type InMemoryVectorStore struct {
	mu   sync.RWMutex
	rows map[string]core.VectorRow
}

// NewInMemoryVectorStore constructs an empty in-memory store.
func NewInMemoryVectorStore() *InMemoryVectorStore {
	s := new(InMemoryVectorStore)
	s.rows = make(map[string]core.VectorRow)
	return s
}
// Upsert inserts or replaces rows while maintaining timestamps: an existing
// row keeps its original CreatedAt, a new row without one gets the current
// time, and UpdatedAt is always refreshed.
func (s *InMemoryVectorStore) Upsert(_ context.Context, rows []core.VectorRow) error {
	if s == nil {
		return errors.New("inmemory vector store is nil")
	}
	if len(rows) == 0 {
		return nil
	}
	now := time.Now()
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.rows == nil {
		s.rows = make(map[string]core.VectorRow)
	}
	for _, incoming := range rows {
		if previous, ok := s.rows[incoming.ID]; ok {
			incoming.CreatedAt = previous.CreatedAt
		} else if incoming.CreatedAt.IsZero() {
			incoming.CreatedAt = now
		}
		incoming.UpdatedAt = now
		s.rows[incoming.ID] = incoming
	}
	return nil
}
// Search scores every stored row against the query vector via cosine
// similarity, honoring the metadata filter, and returns at most topK hits
// sorted by descending score (topK defaults to 8 when non-positive).
func (s *InMemoryVectorStore) Search(_ context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
	if s == nil {
		return nil, errors.New("inmemory vector store is nil")
	}
	limit := req.TopK
	if limit <= 0 {
		limit = 8
	}
	s.mu.RLock()
	defer s.mu.RUnlock()
	scored := make([]core.ScoredVectorRow, 0, len(s.rows))
	for _, row := range s.rows {
		if !matchMetadataFilter(row.Metadata, req.Filter) {
			continue
		}
		scored = append(scored, core.ScoredVectorRow{
			Row:   row,
			Score: cosineSimilarity(req.QueryVector, row.Vector),
		})
	}
	sort.SliceStable(scored, func(a, b int) bool {
		return scored[a].Score > scored[b].Score
	})
	if len(scored) > limit {
		scored = scored[:limit]
	}
	return scored, nil
}
// Delete removes the rows with the given ids; missing ids are ignored.
func (s *InMemoryVectorStore) Delete(_ context.Context, ids []string) error {
	if s == nil {
		return errors.New("inmemory vector store is nil")
	}
	if len(ids) == 0 {
		return nil
	}
	s.mu.Lock()
	for _, id := range ids {
		delete(s.rows, id)
	}
	s.mu.Unlock()
	return nil
}
// Get returns the stored rows for the given ids, silently skipping ids that
// are not present; an empty id list yields (nil, nil).
func (s *InMemoryVectorStore) Get(_ context.Context, ids []string) ([]core.VectorRow, error) {
	if s == nil {
		return nil, errors.New("inmemory vector store is nil")
	}
	if len(ids) == 0 {
		return nil, nil
	}
	s.mu.RLock()
	defer s.mu.RUnlock()
	found := make([]core.VectorRow, 0, len(ids))
	for _, id := range ids {
		if row, ok := s.rows[id]; ok {
			found = append(found, row)
		}
	}
	return found, nil
}
// cosineSimilarity computes the cosine of the angle between two vectors,
// truncating to the shorter length; empty input or a zero-norm vector
// yields 0.
func cosineSimilarity(a, b []float32) float64 {
	limit := len(a)
	if len(b) < limit {
		limit = len(b)
	}
	if limit == 0 {
		return 0
	}
	var dot, magA, magB float64
	for i := 0; i < limit; i++ {
		x := float64(a[i])
		y := float64(b[i])
		dot += x * y
		magA += x * x
		magB += y * y
	}
	if magA == 0 || magB == 0 {
		return 0
	}
	return dot / (math.Sqrt(magA) * math.Sqrt(magB))
}
// matchMetadataFilter reports whether metadata satisfies every filter entry
// under the loose string equality of equalAny; an empty filter matches all.
func matchMetadataFilter(metadata map[string]any, filter map[string]any) bool {
	for key, wanted := range filter {
		got, ok := metadata[key]
		if !ok || !equalAny(got, wanted) {
			return false
		}
	}
	return true
}
// equalAny compares two values by their trimmed %v renderings; this
// deliberately treats e.g. 1 and "1" as equal for loose metadata filtering.
func equalAny(left any, right any) bool {
	return toString(left) == toString(right)
}

// toString renders v as a trimmed string; nil maps to "".
func toString(v any) string {
	if v == nil {
		return ""
	}
	return fmtAny(v)
}

// fmtAny formats any value via %v and trims surrounding whitespace.
func fmtAny(v any) string {
	rendered := fmt.Sprintf("%v", v)
	return strings.TrimSpace(rendered)
}

View File

@@ -0,0 +1,927 @@
package store
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/LoveLosita/smartflow/backend/services/rag/core"
)
// MilvusConfig describes the configuration of the Milvus REST store.
type MilvusConfig struct {
	// Address must point at the Milvus REST entry point.
	// NOTE(review): the original note says integration verification used
	// "195309091" — likely the 19530/9091 ports — for health/metrics only;
	// those do not carry the REST API used by this implementation. Confirm
	// the exact ports against the deployment.
	Address string
	Token string
	DBName string
	CollectionName string
	RequestTimeoutMS int
	Dimension int
	MetricType string
	Logger *log.Logger
	Observer core.Observer
}

// MilvusStore is a vector store implementation on top of the Milvus REST API.
//
// Design notes:
//  1. This implementation prioritizes being pluggable, manageable and
//     gradually adoptable inside the project without an extra SDK dependency;
//  2. fixed scalar columns plus a metadata JSON column balance filtering
//     capability with metadata completeness;
//  3. the collection is created automatically on first write, avoiding a
//     separate startup initialization script.
type MilvusStore struct {
	cfg MilvusConfig
	client *http.Client
	observer core.Observer
	mu sync.Mutex
	ensured bool
}

// Field names of the fixed collection schema.
const (
	milvusPrimaryField = "id"
	milvusVectorField = "vector"
	milvusTextField = "text"
	milvusMetadataField = "metadata"
	milvusCorpusField = "corpus"
	milvusDocumentField = "document_id"
	milvusUserIDField = "user_id"
	milvusAssistantField = "assistant_id"
	milvusConvField = "conversation_id"
	milvusRunField = "run_id"
	milvusMemoryType = "memory_type"
	milvusQueryIDField = "query_id"
	milvusSessionField = "session_id"
	milvusDomainField = "domain"
	milvusChunkOrder = "chunk_order"
	milvusUpdatedAtField = "updated_at"
)

// milvusFilterFieldMap maps public filter keys onto collection column names.
var milvusFilterFieldMap = map[string]string{
	"corpus": milvusCorpusField,
	"document_id": milvusDocumentField,
	"user_id": milvusUserIDField,
	"assistant_id": milvusAssistantField,
	"conversation_id": milvusConvField,
	"run_id": milvusRunField,
	"memory_type": milvusMemoryType,
	"query_id": milvusQueryIDField,
	"session_id": milvusSessionField,
	"domain": milvusDomainField,
	"chunk_order": milvusChunkOrder,
}
// NewMilvusStore validates the configuration, fills in defaults (collection
// "smartflow_rag_chunks", metric COSINE, 1500ms request timeout, standard
// logger and a logger-backed observer) and returns the REST-backed store.
func NewMilvusStore(cfg MilvusConfig) (*MilvusStore, error) {
	cfg.Address = strings.TrimRight(strings.TrimSpace(cfg.Address), "/")
	if cfg.Address == "" {
		return nil, errors.New("milvus address is empty")
	}
	if cfg.CollectionName == "" {
		cfg.CollectionName = "smartflow_rag_chunks"
	}
	if cfg.MetricType == "" {
		cfg.MetricType = "COSINE"
	}
	if cfg.RequestTimeoutMS <= 0 {
		cfg.RequestTimeoutMS = 1500
	}
	if cfg.Logger == nil {
		cfg.Logger = log.Default()
	}
	if cfg.Observer == nil {
		cfg.Observer = core.NewLoggerObserver(cfg.Logger)
	}
	httpClient := &http.Client{Timeout: time.Duration(cfg.RequestTimeoutMS) * time.Millisecond}
	milvus := &MilvusStore{
		cfg:      cfg,
		client:   httpClient,
		observer: cfg.Observer,
	}
	return milvus, nil
}
// Upsert writes rows into the collection, creating it on first use. Both
// failure exits previously duplicated the identical observe event; it is now
// built once in a closure. The success path returns an explicit nil instead
// of the (always nil at that point) err variable.
func (s *MilvusStore) Upsert(ctx context.Context, rows []core.VectorRow) error {
	if err := s.ensureReady(); err != nil {
		return err
	}
	if len(rows) == 0 {
		return nil
	}
	start := time.Now()
	// observeFailure centralizes the failure event; latency is computed at
	// call time, matching the previous per-site time.Since(start).
	observeFailure := func(err error) {
		s.observe(ctx, core.ObserveEvent{
			Level:     core.ObserveLevelError,
			Component: "store",
			Operation: "upsert",
			Fields: map[string]any{
				"status":     "failed",
				"row_count":  len(rows),
				"vector_dim": len(rows[0].Vector),
				"latency_ms": time.Since(start).Milliseconds(),
				"error":      err,
				"error_code": core.ClassifyErrorCode(err),
			},
		})
	}
	if err := s.ensureCollection(ctx, len(rows[0].Vector)); err != nil {
		observeFailure(err)
		return err
	}
	data := make([]map[string]any, 0, len(rows))
	for _, row := range rows {
		data = append(data, mapRowToMilvusEntity(row))
	}
	if _, err := s.postJSON(ctx, "/v2/vectordb/entities/upsert", map[string]any{
		"collectionName": s.cfg.CollectionName,
		"data":           data,
		"dbName":         blankToNil(s.cfg.DBName),
	}); err != nil {
		observeFailure(err)
		return err
	}
	s.observe(ctx, core.ObserveEvent{
		Level:     core.ObserveLevelInfo,
		Component: "store",
		Operation: "upsert",
		Fields: map[string]any{
			"status":     "success",
			"row_count":  len(rows),
			"vector_dim": len(rows[0].Vector),
			"latency_ms": time.Since(start).Milliseconds(),
		},
	})
	return nil
}
// Search runs an ANN search over the collection with an optional scalar
// filter and maps the hits back to scored rows. A missing collection is
// treated as "no results" rather than an error so cold starts degrade
// gracefully. The five byte-identical failure observe blocks of the
// original are collapsed into one closure.
func (s *MilvusStore) Search(ctx context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
	if err := s.ensureReady(); err != nil {
		return nil, err
	}
	if len(req.QueryVector) == 0 {
		return nil, nil
	}
	start := time.Now()
	// observeFailure reports a failed search with the shared field set;
	// latency is computed at call time, matching the previous per-site code.
	observeFailure := func(err error) {
		s.observe(ctx, core.ObserveEvent{
			Level:     core.ObserveLevelError,
			Component: "store",
			Operation: "search",
			Fields: map[string]any{
				"status":       "failed",
				"top_k":        req.TopK,
				"filter_count": len(req.Filter),
				"vector_dim":   len(req.QueryVector),
				"latency_ms":   time.Since(start).Milliseconds(),
				"error":        err,
				"error_code":   core.ClassifyErrorCode(err),
			},
		})
	}
	if err := s.ensureCollection(ctx, len(req.QueryVector)); err != nil {
		if isMilvusCollectionMissing(err) {
			return nil, nil
		}
		observeFailure(err)
		return nil, err
	}
	filterExpr, err := buildMilvusFilter(req.Filter)
	if err != nil {
		observeFailure(err)
		return nil, err
	}
	body := map[string]any{
		"collectionName": s.cfg.CollectionName,
		"data":           [][]float32{req.QueryVector},
		"annsField":      milvusVectorField,
		"limit":          normalizeMilvusTopK(req.TopK),
		"outputFields":   milvusOutputFields(false),
	}
	if filterExpr != "" {
		body["filter"] = filterExpr
	}
	if s.cfg.DBName != "" {
		body["dbName"] = s.cfg.DBName
	}
	respBody, err := s.postJSON(ctx, "/v2/vectordb/entities/search", body)
	if err != nil {
		if isMilvusCollectionMissing(err) {
			return nil, nil
		}
		observeFailure(err)
		return nil, err
	}
	var resp milvusSearchResponse
	if err = json.Unmarshal(respBody, &resp); err != nil {
		observeFailure(err)
		return nil, err
	}
	if resp.Code != 0 && resp.Code != 200 {
		err = fmt.Errorf("milvus search failed: code=%d message=%s", resp.Code, resp.Message)
		observeFailure(err)
		return nil, err
	}
	result := make([]core.ScoredVectorRow, 0, len(resp.Data))
	for _, item := range resp.Data {
		row, score := item.toVectorRow()
		result = append(result, core.ScoredVectorRow{
			Row:   row,
			Score: score,
		})
	}
	s.observe(ctx, core.ObserveEvent{
		Level:     core.ObserveLevelInfo,
		Component: "store",
		Operation: "search",
		Fields: map[string]any{
			"status":       "success",
			"top_k":        req.TopK,
			"filter_count": len(req.Filter),
			"vector_dim":   len(req.QueryVector),
			"result_count": len(result),
			"latency_ms":   time.Since(start).Milliseconds(),
		},
	})
	return result, nil
}
// Delete removes rows by primary id using a filter expression. A missing
// collection is ignored so deletes remain idempotent during cold start.
func (s *MilvusStore) Delete(ctx context.Context, ids []string) error {
	if err := s.ensureReady(); err != nil {
		return err
	}
	if len(ids) == 0 {
		return nil
	}
	start := time.Now()
	expr := fmt.Sprintf(`%s in [%s]`, milvusPrimaryField, joinQuotedStrings(ids))
	_, err := s.postJSON(ctx, "/v2/vectordb/entities/delete", map[string]any{
		"collectionName": s.cfg.CollectionName,
		"filter":         expr,
		"dbName":         blankToNil(s.cfg.DBName),
	})
	if isMilvusCollectionMissing(err) {
		return nil
	}
	if err != nil {
		s.observe(ctx, core.ObserveEvent{
			Level:     core.ObserveLevelError,
			Component: "store",
			Operation: "delete",
			Fields: map[string]any{
				"status":     "failed",
				"id_count":   len(ids),
				"latency_ms": time.Since(start).Milliseconds(),
				"error":      err,
				"error_code": core.ClassifyErrorCode(err),
			},
		})
		return err
	}
	s.observe(ctx, core.ObserveEvent{
		Level:     core.ObserveLevelInfo,
		Component: "store",
		Operation: "delete",
		Fields: map[string]any{
			"status":     "success",
			"id_count":   len(ids),
			"latency_ms": time.Since(start).Milliseconds(),
		},
	})
	return nil
}
// Get fetches rows (including vectors) by primary id. A missing collection
// yields (nil, nil) so lookups stay tolerant during cold start. The three
// byte-identical failure observe blocks are collapsed into one closure.
func (s *MilvusStore) Get(ctx context.Context, ids []string) ([]core.VectorRow, error) {
	if err := s.ensureReady(); err != nil {
		return nil, err
	}
	if len(ids) == 0 {
		return nil, nil
	}
	start := time.Now()
	// observeFailure reports a failed get with the shared field set; latency
	// is computed at call time, matching the previous per-site code.
	observeFailure := func(err error) {
		s.observe(ctx, core.ObserveEvent{
			Level:     core.ObserveLevelError,
			Component: "store",
			Operation: "get",
			Fields: map[string]any{
				"status":     "failed",
				"id_count":   len(ids),
				"latency_ms": time.Since(start).Milliseconds(),
				"error":      err,
				"error_code": core.ClassifyErrorCode(err),
			},
		})
	}
	respBody, err := s.postJSON(ctx, "/v2/vectordb/entities/get", map[string]any{
		"collectionName": s.cfg.CollectionName,
		"id":             ids,
		"outputFields":   milvusOutputFields(true),
		"dbName":         blankToNil(s.cfg.DBName),
	})
	if err != nil {
		if isMilvusCollectionMissing(err) {
			return nil, nil
		}
		observeFailure(err)
		return nil, err
	}
	var resp milvusGetResponse
	if err = json.Unmarshal(respBody, &resp); err != nil {
		observeFailure(err)
		return nil, err
	}
	if resp.Code != 0 && resp.Code != 200 {
		err = fmt.Errorf("milvus get failed: code=%d message=%s", resp.Code, resp.Message)
		observeFailure(err)
		return nil, err
	}
	rows := make([]core.VectorRow, 0, len(resp.Data))
	for _, item := range resp.Data {
		rows = append(rows, mapMilvusRow(item, true))
	}
	s.observe(ctx, core.ObserveEvent{
		Level:     core.ObserveLevelInfo,
		Component: "store",
		Operation: "get",
		Fields: map[string]any{
			"status":     "success",
			"id_count":   len(ids),
			"row_count":  len(rows),
			"latency_ms": time.Since(start).Milliseconds(),
		},
	})
	return rows, nil
}
// ensureCollection lazily creates the collection with the fixed schema and
// an AUTOINDEX vector index, exactly once per store instance (guarded by
// s.mu / s.ensured). An "already exists" response from Milvus is treated as
// success. The dimension falls back to cfg.Dimension when the caller passes
// a non-positive value.
func (s *MilvusStore) ensureCollection(ctx context.Context, dimension int) error {
	if err := s.ensureReady(); err != nil {
		return err
	}
	start := time.Now()
	if dimension <= 0 {
		dimension = s.cfg.Dimension
	}
	if dimension <= 0 {
		return errors.New("milvus vector dimension is invalid")
	}
	// The mutex serializes creation attempts; once ensured, later calls are
	// cheap no-ops.
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.ensured {
		return nil
	}
	payload := map[string]any{
		"collectionName": s.cfg.CollectionName,
		"schema": map[string]any{
			"autoId": false,
			"enabledDynamicField": false,
			"fields": []map[string]any{
				buildVarcharField(milvusPrimaryField, true, 256),
				buildVectorField(milvusVectorField, dimension),
				buildVarcharField(milvusTextField, false, 65535),
				{"fieldName": milvusMetadataField, "dataType": "JSON"},
				buildVarcharField(milvusCorpusField, false, 64),
				buildVarcharField(milvusDocumentField, false, 256),
				{"fieldName": milvusUserIDField, "dataType": "Int64"},
				buildVarcharField(milvusAssistantField, false, 128),
				buildVarcharField(milvusConvField, false, 128),
				buildVarcharField(milvusRunField, false, 128),
				buildVarcharField(milvusMemoryType, false, 64),
				buildVarcharField(milvusQueryIDField, false, 128),
				buildVarcharField(milvusSessionField, false, 128),
				buildVarcharField(milvusDomainField, false, 128),
				{"fieldName": milvusChunkOrder, "dataType": "Int64"},
				{"fieldName": milvusUpdatedAtField, "dataType": "Int64"},
			},
		},
		"indexParams": []map[string]any{
			{
				"fieldName": milvusVectorField,
				"indexName": milvusVectorField,
				"metricType": s.cfg.MetricType,
				"indexType": "AUTOINDEX",
			},
		},
	}
	if s.cfg.DBName != "" {
		payload["dbName"] = s.cfg.DBName
	}
	_, err := s.postJSON(ctx, "/v2/vectordb/collections/create", payload)
	if err != nil {
		// A concurrent/previous creation is fine: mark ensured and move on.
		if isMilvusAlreadyExists(err) {
			s.ensured = true
			s.observe(ctx, core.ObserveEvent{
				Level: core.ObserveLevelInfo,
				Component: "store",
				Operation: "ensure_collection",
				Fields: map[string]any{
					"status": "already_exists",
					"vector_dim": dimension,
					"metric_type": s.cfg.MetricType,
					"latency_ms": time.Since(start).Milliseconds(),
				},
			})
			return nil
		}
		s.observe(ctx, core.ObserveEvent{
			Level: core.ObserveLevelError,
			Component: "store",
			Operation: "ensure_collection",
			Fields: map[string]any{
				"status": "failed",
				"vector_dim": dimension,
				"metric_type": s.cfg.MetricType,
				"latency_ms": time.Since(start).Milliseconds(),
				"error": err,
				"error_code": core.ClassifyErrorCode(err),
			},
		})
		return err
	}
	s.ensured = true
	s.observe(ctx, core.ObserveEvent{
		Level: core.ObserveLevelInfo,
		Component: "store",
		Operation: "ensure_collection",
		Fields: map[string]any{
			"status": "created",
			"vector_dim": dimension,
			"metric_type": s.cfg.MetricType,
			"latency_ms": time.Since(start).Milliseconds(),
		},
	})
	return nil
}
// postJSON sends a JSON payload to the Milvus HTTP v2 endpoint at path and
// returns the raw response body. It fails on transport errors, HTTP status
// codes >= 400, and Milvus API-level error codes embedded in the body.
func (s *MilvusStore) postJSON(ctx context.Context, path string, payload map[string]any) ([]byte, error) {
	if err := s.ensureReady(); err != nil {
		return nil, err
	}
	encoded, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.cfg.Address+path, bytes.NewReader(encoded))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	if token := strings.TrimSpace(s.cfg.Token); token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}
	resp, err := s.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	data, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, readErr
	}
	if resp.StatusCode >= http.StatusBadRequest {
		return nil, fmt.Errorf("milvus http failed: status=%d body=%s", resp.StatusCode, string(data))
	}
	// Milvus wraps API-level failures in a JSON envelope with a non-success
	// code; a body that fails to parse as the envelope is returned as-is.
	var envelope milvusBasicResponse
	if jsonErr := json.Unmarshal(data, &envelope); jsonErr == nil {
		if envelope.Code != 0 && envelope.Code != 200 {
			return nil, fmt.Errorf("milvus api failed: code=%d message=%s", envelope.Code, envelope.Message)
		}
	}
	return data, nil
}
// ensureReady verifies the store has the minimum configuration required to
// talk to Milvus: an initialized HTTP client, a non-blank address, and a
// non-blank collection name. Safe to call on a nil receiver.
func (s *MilvusStore) ensureReady() error {
	switch {
	case s == nil || s.client == nil:
		return errors.New("milvus store is not initialized")
	case strings.TrimSpace(s.cfg.Address) == "":
		return errors.New("milvus address is empty")
	case strings.TrimSpace(s.cfg.CollectionName) == "":
		return errors.New("milvus collection name is empty")
	default:
		return nil
	}
}
// observe forwards an observability event to the configured observer,
// enriching its fields with the store/collection identity (and db_name when
// configured). It is a no-op when the store or observer is nil.
func (s *MilvusStore) observe(ctx context.Context, event core.ObserveEvent) {
	if s == nil || s.observer == nil {
		return
	}
	enriched := cloneMap(event.Fields)
	enriched["store"] = "milvus"
	enriched["collection"] = s.cfg.CollectionName
	if strings.TrimSpace(s.cfg.DBName) != "" {
		enriched["db_name"] = s.cfg.DBName
	}
	event.Fields = enriched
	s.observer.Observe(ctx, event)
}
// mapRowToMilvusEntity converts a core.VectorRow into the flat entity map
// expected by the Milvus insert API. Selected scalar columns are promoted
// out of the row metadata; updated_at falls back to the current time in
// milliseconds when the row carries no timestamp.
func mapRowToMilvusEntity(row core.VectorRow) map[string]any {
	metadata := cloneMap(row.Metadata)
	updatedAt := time.Now().UnixMilli()
	if !row.UpdatedAt.IsZero() {
		updatedAt = row.UpdatedAt.UnixMilli()
	}
	entity := map[string]any{
		milvusPrimaryField:   row.ID,
		milvusVectorField:    row.Vector,
		milvusTextField:      row.Text,
		milvusMetadataField:  metadata,
		milvusCorpusField:    asString(metadata["corpus"]),
		milvusDocumentField:  asString(metadata["document_id"]),
		milvusUpdatedAtField: updatedAt,
	}
	// Each promotion is independent, so iteration order over this map does
	// not affect the resulting entity.
	for field, metaKey := range map[string]string{
		milvusUserIDField:    "user_id",
		milvusAssistantField: "assistant_id",
		milvusConvField:      "conversation_id",
		milvusRunField:       "run_id",
		milvusMemoryType:     "memory_type",
		milvusQueryIDField:   "query_id",
		milvusSessionField:   "session_id",
		milvusDomainField:    "domain",
		milvusChunkOrder:     "chunk_order",
	} {
		assignMilvusScalar(entity, field, metadata[metaKey])
	}
	return entity
}
// assignMilvusScalar writes a metadata value into the entity map under the
// given scalar column, coercing to int64 for the integer columns and to a
// trimmed string otherwise. Nil, empty, and unparseable values are skipped.
func assignMilvusScalar(target map[string]any, field string, value any) {
	if value == nil {
		return
	}
	if field == milvusUserIDField || field == milvusChunkOrder {
		if number, ok := toInt64(value); ok {
			target[field] = number
		}
		return
	}
	if text := asString(value); text != "" {
		target[field] = text
	}
}
// buildMilvusFilter translates a generic key/value filter into a Milvus
// boolean expression, AND-ing one equality clause per entry. Unknown filter
// keys and non-integer values for the integer columns are rejected.
func buildMilvusFilter(filter map[string]any) (string, error) {
	if len(filter) == 0 {
		return "", nil
	}
	clauses := make([]string, 0, len(filter))
	for key, value := range filter {
		field, known := milvusFilterFieldMap[key]
		if !known {
			return "", fmt.Errorf("unsupported milvus filter key: %s", key)
		}
		if field == milvusUserIDField || field == milvusChunkOrder {
			number, ok := toInt64(value)
			if !ok {
				return "", fmt.Errorf("milvus filter key=%s expects integer", key)
			}
			clauses = append(clauses, fmt.Sprintf("%s == %d", field, number))
			continue
		}
		// String columns: quote and escape the value for the expression.
		clauses = append(clauses, fmt.Sprintf(`%s == "%s"`, field, escapeMilvusString(asString(value))))
	}
	return strings.Join(clauses, " and "), nil
}
// buildVarcharField describes a VarChar column for the Milvus collection
// schema API, optionally marking it as the primary key.
func buildVarcharField(name string, isPrimary bool, maxLength int) map[string]any {
	spec := map[string]any{
		"fieldName":         name,
		"dataType":          "VarChar",
		"elementTypeParams": map[string]any{"max_length": maxLength},
	}
	if !isPrimary {
		return spec
	}
	spec["isPrimary"] = true
	return spec
}
// buildVectorField describes a FloatVector column of the given dimension
// for the Milvus collection schema API.
func buildVectorField(name string, dimension int) map[string]any {
	params := map[string]any{"dim": dimension}
	return map[string]any{
		"fieldName":         name,
		"dataType":          "FloatVector",
		"elementTypeParams": params,
	}
}
// milvusOutputFields lists the columns requested from Milvus on query and
// search calls; the embedding vector is appended only when asked for.
func milvusOutputFields(includeVector bool) []string {
	fields := append(make([]string, 0, 15),
		milvusTextField,
		milvusMetadataField,
		milvusCorpusField,
		milvusDocumentField,
		milvusUserIDField,
		milvusAssistantField,
		milvusConvField,
		milvusRunField,
		milvusMemoryType,
		milvusQueryIDField,
		milvusSessionField,
		milvusDomainField,
		milvusChunkOrder,
		milvusUpdatedAtField,
	)
	if includeVector {
		fields = append(fields, milvusVectorField)
	}
	return fields
}
// normalizeMilvusTopK sanitizes a requested result count, substituting the
// default of 8 when the caller passes zero or a negative value.
func normalizeMilvusTopK(topK int) int {
	const defaultTopK = 8
	if topK > 0 {
		return topK
	}
	return defaultTopK
}
// blankToNil trims v and maps blank strings to nil so optional values can
// be omitted from JSON payloads; non-blank input is returned trimmed.
func blankToNil(v string) any {
	trimmed := strings.TrimSpace(v)
	if len(trimmed) == 0 {
		return nil
	}
	return trimmed
}
// escapeMilvusString escapes backslashes and double quotes so a value can
// be embedded inside a double-quoted Milvus filter expression. A single
// pass over the input is equivalent to escaping backslashes first and
// quotes second.
func escapeMilvusString(v string) string {
	return strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(v)
}
func joinQuotedStrings(values []string) string {
parts := make([]string, 0, len(values))
for _, value := range values {
parts = append(parts, fmt.Sprintf(`"%s"`, escapeMilvusString(value)))
}
return strings.Join(parts, ",")
}
// cloneMap returns a shallow copy of src. A nil or empty source yields a
// fresh non-nil empty map so callers can safely write into the result.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for key, value := range src {
		dst[key] = value
	}
	return dst
}
// asString renders an arbitrary value using fmt's default formatting and
// trims surrounding whitespace; nil maps to the empty string.
func asString(v any) string {
	if v == nil {
		return ""
	}
	return strings.TrimSpace(fmt.Sprint(v))
}
// toInt64 attempts to coerce v into an int64, reporting success via the
// boolean. Supported inputs: all Go signed and unsigned integer widths
// (unsigned values above math.MaxInt64 are rejected rather than wrapped),
// float32/float64 (fractional part truncated, matching the prior float64
// behavior), json.Number, and decimal strings (surrounding whitespace is
// ignored). Anything else fails.
func toInt64(v any) (int64, bool) {
	// 1<<63 - 1 == math.MaxInt64; spelled out to avoid importing math.
	const maxInt64 = uint64(1<<63 - 1)
	switch value := v.(type) {
	case int:
		return int64(value), true
	case int8:
		return int64(value), true
	case int16:
		return int64(value), true
	case int32:
		return int64(value), true
	case int64:
		return value, true
	case uint8:
		return int64(value), true
	case uint16:
		return int64(value), true
	case uint32:
		return int64(value), true
	case uint:
		if uint64(value) > maxInt64 {
			return 0, false
		}
		return int64(value), true
	case uint64:
		if value > maxInt64 {
			return 0, false
		}
		return int64(value), true
	case float32:
		return int64(value), true
	case float64:
		return int64(value), true
	case json.Number:
		parsed, err := value.Int64()
		return parsed, err == nil
	case string:
		parsed, err := strconv.ParseInt(strings.TrimSpace(value), 10, 64)
		return parsed, err == nil
	default:
		return 0, false
	}
}
func isMilvusAlreadyExists(err error) bool {
if err == nil {
return false
}
text := strings.ToLower(err.Error())
return strings.Contains(text, "already exist") ||
strings.Contains(text, "already exists") ||
strings.Contains(text, "duplicate collection")
}
func isMilvusCollectionMissing(err error) bool {
if err == nil {
return false
}
text := strings.ToLower(err.Error())
return strings.Contains(text, "can't find collection") || strings.Contains(text, "collection not found")
}
// milvusBasicResponse is the generic envelope returned by Milvus HTTP v2
// endpoints; a Code other than 0 or 200 signals an API-level failure (see
// postJSON).
type milvusBasicResponse struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}
// milvusSearchResponse is the envelope for Milvus search results; Data
// holds the raw hits, each carrying the requested output fields plus a
// "distance" score.
type milvusSearchResponse struct {
	Code    int                `json:"code"`
	Message string             `json:"message"`
	Data    []milvusSearchItem `json:"data"`
}
// milvusSearchItem is a single raw hit from a Milvus search response.
type milvusSearchItem map[string]any

// toVectorRow converts the hit into a core.VectorRow plus its similarity
// score. A missing or non-float64 "distance" field yields a score of 0.
func (m milvusSearchItem) toVectorRow() (core.VectorRow, float64) {
	var score float64
	if distance, ok := m["distance"].(float64); ok {
		score = distance
	}
	return mapMilvusRow(map[string]any(m), false), score
}
// milvusGetResponse is the envelope for Milvus query/get results; Data
// holds one raw field map per returned record.
type milvusGetResponse struct {
	Code    int              `json:"code"`
	Message string           `json:"message"`
	Data    []map[string]any `json:"data"`
}
// mapMilvusRow converts a raw Milvus record into a core.VectorRow, folding
// the denormalized scalar columns back into the metadata map. The record
// ID falls back to the generic "id" key when the primary field is absent,
// and the embedding is decoded only when includeVector is true.
func mapMilvusRow(raw map[string]any, includeVector bool) core.VectorRow {
	metadata := cloneMap(readMetadataMap(raw[milvusMetadataField]))
	scalarColumns := []struct {
		metaKey string
		field   string
	}{
		{"corpus", milvusCorpusField},
		{"document_id", milvusDocumentField},
		{"user_id", milvusUserIDField},
		{"assistant_id", milvusAssistantField},
		{"conversation_id", milvusConvField},
		{"run_id", milvusRunField},
		{"memory_type", milvusMemoryType},
		{"query_id", milvusQueryIDField},
		{"session_id", milvusSessionField},
		{"domain", milvusDomainField},
		{"chunk_order", milvusChunkOrder},
	}
	for _, column := range scalarColumns {
		assignMetadataIfPresent(metadata, column.metaKey, raw[column.field])
	}
	id := asString(raw[milvusPrimaryField])
	if id == "" {
		id = asString(raw["id"])
	}
	row := core.VectorRow{
		ID:       id,
		Text:     asString(raw[milvusTextField]),
		Metadata: metadata,
	}
	if includeVector {
		row.Vector = readFloat32Vector(raw[milvusVectorField])
	}
	return row
}
// readMetadataMap narrows an arbitrary decoded JSON value to a metadata
// map, substituting a fresh empty map for anything that is not one.
func readMetadataMap(value any) map[string]any {
	if data, ok := value.(map[string]any); ok {
		return data
	}
	return map[string]any{}
}
// readFloat32Vector decodes a vector payload into []float32. It accepts a
// native []float32, a []float64, or a generic []any of JSON-decoded
// numbers (float64, float32, json.Number, int, int64); elements of any
// other type are skipped, matching the prior lenient behavior. Anything
// else yields nil.
func readFloat32Vector(value any) []float32 {
	switch vector := value.(type) {
	case []float32:
		return vector
	case []float64:
		result := make([]float32, len(vector))
		for i, number := range vector {
			result[i] = float32(number)
		}
		return result
	case []any:
		result := make([]float32, 0, len(vector))
		for _, item := range vector {
			switch number := item.(type) {
			case float64:
				result = append(result, float32(number))
			case float32:
				result = append(result, number)
			case json.Number:
				// Decoders configured with UseNumber emit json.Number.
				if parsed, err := number.Float64(); err == nil {
					result = append(result, float32(parsed))
				}
			case int:
				result = append(result, float32(number))
			case int64:
				result = append(result, float32(number))
			}
		}
		return result
	default:
		return nil
	}
}
// assignMetadataIfPresent stores value under key in the target map. Nil
// values and blank strings are skipped; string values are trimmed before
// being stored, while non-string values are stored as-is.
func assignMetadataIfPresent(target map[string]any, key string, value any) {
	if value == nil {
		return
	}
	text, isString := value.(string)
	if !isString {
		target[key] = value
		return
	}
	if trimmed := strings.TrimSpace(text); trimmed != "" {
		target[key] = trimmed
	}
}

View File

@@ -0,0 +1,9 @@
package store
import "github.com/LoveLosita/smartflow/backend/services/rag/core"
// EnsureCompile 用于静态校验实现是否满足接口。
// EnsureCompile statically asserts that both vector store implementations
// satisfy the core.VectorStore interface; it does no work at runtime.
func EnsureCompile() {
	var (
		_ core.VectorStore = (*InMemoryVectorStore)(nil)
		_ core.VectorStore = (*MilvusStore)(nil)
	)
}