Version: 0.9.65.dev.260503
后端: 1. 阶段 1.5/1.6 收口 llm-service / rag-service,统一模型出口与检索基础设施入口,清退 backend/infra/llm 与 backend/infra/rag 旧实现; 2. 同步更新相关调用链与微服务迁移计划文档
This commit is contained in:
73
backend/services/llm/ark.go
Normal file
73
backend/services/llm/ark.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
einoModel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
// ArkCallOptions carries the common parameters used when calling an
// ark.ChatModel directly (outside the unified Client facade).
type ArkCallOptions struct {
	Temperature float64      // sampling temperature; passed through to the model as float32
	MaxTokens   int          // output token cap; values <= 0 mean "not set"
	Thinking    ThinkingMode // whether this call wants the model's thinking mode enabled
}
|
||||
|
||||
// CallArkText 调用 ark 模型并返回纯文本。
|
||||
func CallArkText(ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, options ArkCallOptions) (string, error) {
|
||||
if chatModel == nil {
|
||||
return "", errors.New("ark model is nil")
|
||||
}
|
||||
|
||||
messages := []*schema.Message{
|
||||
schema.SystemMessage(systemPrompt),
|
||||
schema.UserMessage(userPrompt),
|
||||
}
|
||||
resp, err := chatModel.Generate(ctx, messages, buildArkOptions(options)...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp == nil {
|
||||
return "", errors.New("模型返回为空")
|
||||
}
|
||||
|
||||
text := strings.TrimSpace(resp.Content)
|
||||
if text == "" {
|
||||
return "", errors.New("模型返回内容为空")
|
||||
}
|
||||
return text, nil
|
||||
}
|
||||
|
||||
// CallArkJSON 调用 ark 模型并直接解析 JSON。
|
||||
func CallArkJSON[T any](ctx context.Context, chatModel *ark.ChatModel, systemPrompt, userPrompt string, options ArkCallOptions) (*T, string, error) {
|
||||
raw, err := CallArkText(ctx, chatModel, systemPrompt, userPrompt, options)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
parsed, err := ParseJSONObject[T](raw)
|
||||
if err != nil {
|
||||
return nil, raw, err
|
||||
}
|
||||
return parsed, raw, nil
|
||||
}
|
||||
|
||||
func buildArkOptions(options ArkCallOptions) []einoModel.Option {
|
||||
thinkingType := arkModel.ThinkingTypeDisabled
|
||||
if options.Thinking == ThinkingModeEnabled {
|
||||
thinkingType = arkModel.ThinkingTypeEnabled
|
||||
}
|
||||
|
||||
opts := []einoModel.Option{
|
||||
ark.WithThinking(&arkModel.Thinking{Type: thinkingType}),
|
||||
einoModel.WithTemperature(float32(options.Temperature)),
|
||||
}
|
||||
if options.MaxTokens > 0 {
|
||||
opts = append(opts, einoModel.WithMaxTokens(options.MaxTokens))
|
||||
}
|
||||
return opts
|
||||
}
|
||||
111
backend/services/llm/ark_adapter.go
Normal file
111
backend/services/llm/ark_adapter.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
einoModel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
// WrapArkClient 将 ark.ChatModel 适配为统一 Client。
|
||||
// 1. generateText 走 Generate,供 GenerateJSON/GenerateText 使用。
|
||||
// 2. streamText 走 Stream,供需要流式输出的场景使用。
|
||||
// 3. 两条路径共用同一套参数转换逻辑。
|
||||
func WrapArkClient(arkChatModel *ark.ChatModel) *Client {
|
||||
if arkChatModel == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
generateFunc := func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error) {
|
||||
arkOpts := buildArkStreamOptions(options)
|
||||
msg, err := arkChatModel.Generate(ctx, messages, arkOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msg == nil {
|
||||
return nil, errors.New("ark model returned nil message")
|
||||
}
|
||||
|
||||
var usage *schema.TokenUsage
|
||||
finishReason := ""
|
||||
if msg.ResponseMeta != nil {
|
||||
usage = CloneUsage(msg.ResponseMeta.Usage)
|
||||
finishReason = msg.ResponseMeta.FinishReason
|
||||
}
|
||||
|
||||
return &TextResult{
|
||||
Text: msg.Content,
|
||||
Usage: usage,
|
||||
FinishReason: finishReason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
streamFunc := func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error) {
|
||||
arkOpts := buildArkStreamOptions(options)
|
||||
reader, err := arkChatModel.Stream(ctx, messages, arkOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &arkStreamReaderAdapter{reader: reader}, nil
|
||||
}
|
||||
|
||||
return NewClient(generateFunc, streamFunc)
|
||||
}
|
||||
|
||||
// buildArkStreamOptions 将统一的 GenerateOptions 转换为 ark 的流式调用参数。
|
||||
func buildArkStreamOptions(options GenerateOptions) []einoModel.Option {
|
||||
thinkingEnabled := options.Thinking == ThinkingModeEnabled
|
||||
|
||||
thinkingType := arkModel.ThinkingTypeDisabled
|
||||
if thinkingEnabled {
|
||||
thinkingType = arkModel.ThinkingTypeEnabled
|
||||
}
|
||||
opts := []einoModel.Option{
|
||||
ark.WithThinking(&arkModel.Thinking{Type: thinkingType}),
|
||||
}
|
||||
|
||||
if thinkingEnabled {
|
||||
opts = append(opts, einoModel.WithTemperature(1.0))
|
||||
} else if options.Temperature > 0 {
|
||||
opts = append(opts, einoModel.WithTemperature(float32(options.Temperature)))
|
||||
}
|
||||
|
||||
maxTokens := options.MaxTokens
|
||||
if thinkingEnabled {
|
||||
const minThinkingBudget = 16000
|
||||
if maxTokens < minThinkingBudget {
|
||||
maxTokens = minThinkingBudget
|
||||
}
|
||||
}
|
||||
if maxTokens > 0 {
|
||||
opts = append(opts, einoModel.WithMaxTokens(maxTokens))
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
// arkStreamReaderAdapter 把 ark 的流式 reader 转成统一的 StreamReader 接口。
|
||||
type arkStreamReaderAdapter struct {
|
||||
reader *schema.StreamReader[*schema.Message]
|
||||
}
|
||||
|
||||
// Recv 转发到底层 reader。
|
||||
func (r *arkStreamReaderAdapter) Recv() (*schema.Message, error) {
|
||||
if r == nil || r.reader == nil {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return r.reader.Recv()
|
||||
}
|
||||
|
||||
// Close 适配 ark reader 的 Close 行为。
|
||||
func (r *arkStreamReaderAdapter) Close() error {
|
||||
if r == nil || r.reader == nil {
|
||||
return nil
|
||||
}
|
||||
r.reader.Close()
|
||||
return nil
|
||||
}
|
||||
330
backend/services/llm/ark_responses_client.go
Normal file
330
backend/services/llm/ark_responses_client.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses"
|
||||
)
|
||||
|
||||
// ArkResponsesMessage describes one input message of a Responses call.
type ArkResponsesMessage struct {
	Role        string // "user", "system", "developer" or "assistant"
	Text        string // optional text content; blank text is skipped
	ImageURL    string // optional image URL content; blank URL is skipped
	ImageDetail string // image detail hint: "high", "low" or "auto"; anything else is ignored
}

// ArkResponsesOptions describes the per-call Responses parameters.
type ArkResponsesOptions struct {
	Model           string       // overrides the client's default model when non-blank
	Temperature     float64      // only applied to the request when > 0
	MaxOutputTokens int          // only applied to the request when > 0
	Thinking        ThinkingMode // ThinkingModeDefault leaves thinking unset on the request
	TextFormat      string       // "text" or "json_object"; blank leaves the format unset
}

// ArkResponsesUsage is the normalized token usage of one Responses call.
type ArkResponsesUsage struct {
	InputTokens  int64
	OutputTokens int64
	TotalTokens  int64
}

// ArkResponsesResult is the unified output of a Responses call.
type ArkResponsesResult struct {
	Text             string             // concatenated output text across output messages
	Status           string             // response status string; "failed" triggers an error upstream
	IncompleteReason string             // reason reported when the response is incomplete
	ErrorCode        string             // error code reported by the API, if any
	ErrorMessage     string             // error message reported by the API, if any
	Usage            *ArkResponsesUsage // nil when the API reported no usage
}

// ArkResponsesClient is the unified model outlet for the Ark SDK Responses API.
type ArkResponsesClient struct {
	model  string             // default model used when call options omit one
	client *arkruntime.Client // underlying Ark runtime client
}
|
||||
|
||||
// NewArkResponsesClient 创建 Ark SDK Responses 客户端。
|
||||
// 1. model 为空时直接返回 nil,表示这条能力没有启用。
|
||||
// 2. baseURL 为空时使用 SDK 默认地址。
|
||||
// 3. 这里只负责本地构造,不做连通性探测。
|
||||
func NewArkResponsesClient(apiKey string, baseURL string, model string) *ArkResponsesClient {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
options := make([]arkruntime.ConfigOption, 0, 1)
|
||||
if strings.TrimSpace(baseURL) != "" {
|
||||
options = append(options, arkruntime.WithBaseUrl(strings.TrimSpace(baseURL)))
|
||||
}
|
||||
|
||||
return &ArkResponsesClient{
|
||||
model: model,
|
||||
client: arkruntime.NewClientWithApiKey(strings.TrimSpace(apiKey), options...),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateText 执行一次非流式 Responses 调用并提取文本。
|
||||
func (c *ArkResponsesClient) GenerateText(ctx context.Context, messages []ArkResponsesMessage, options ArkResponsesOptions) (*ArkResponsesResult, error) {
|
||||
req, err := c.buildRequest(messages, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.client.CreateResponses(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := buildArkResponsesResult(resp)
|
||||
if result.Status == "failed" {
|
||||
if result.ErrorMessage != "" {
|
||||
return result, fmt.Errorf("ark responses failed: %s", result.ErrorMessage)
|
||||
}
|
||||
return result, errors.New("ark responses failed")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(result.Text) == "" {
|
||||
return result, FormatEmptyResponseError("ark_responses")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateArkResponsesJSON 先调用 Responses,再解析成 JSON 结构体。
|
||||
func GenerateArkResponsesJSON[T any](ctx context.Context, client *ArkResponsesClient, messages []ArkResponsesMessage, options ArkResponsesOptions) (*T, *ArkResponsesResult, error) {
|
||||
if client == nil {
|
||||
return nil, nil, errors.New("ark responses client is not ready")
|
||||
}
|
||||
|
||||
result, err := client.GenerateText(ctx, messages, options)
|
||||
if err != nil {
|
||||
return nil, result, err
|
||||
}
|
||||
|
||||
parsed, err := ParseJSONObject[T](result.Text)
|
||||
if err != nil {
|
||||
return nil, result, err
|
||||
}
|
||||
return parsed, result, nil
|
||||
}
|
||||
|
||||
// buildRequest assembles a Responses API request from the unified message and
// option structs.
//
// Behavior:
//  1. rejects an unready client and an empty message list;
//  2. resolves the model: per-call option first, then the client default;
//  3. converts each message via buildInputItem, failing with the index on error;
//  4. only sets temperature / max output tokens when they are > 0;
//  5. ThinkingModeDefault intentionally leaves request.Thinking unset;
//  6. only sets a text format when options.TextFormat parses to a known type.
func (c *ArkResponsesClient) buildRequest(messages []ArkResponsesMessage, options ArkResponsesOptions) (*responses.ResponsesRequest, error) {
	if c == nil || c.client == nil {
		return nil, errors.New("ark responses client is not ready")
	}
	if len(messages) == 0 {
		return nil, errors.New("ark responses messages is empty")
	}

	// Per-call model override wins; fall back to the client's default.
	modelName := strings.TrimSpace(options.Model)
	if modelName == "" {
		modelName = c.model
	}
	if modelName == "" {
		return nil, errors.New("ark responses model is empty")
	}

	inputItems := make([]*responses.InputItem, 0, len(messages))
	for idx := range messages {
		item, err := buildInputItem(messages[idx])
		if err != nil {
			return nil, fmt.Errorf("build ark responses message[%d] failed: %w", idx, err)
		}
		inputItems = append(inputItems, item)
	}

	request := &responses.ResponsesRequest{
		Model: modelName,
		Input: &responses.ResponsesInput{
			Union: &responses.ResponsesInput_ListValue{
				ListValue: &responses.InputItemList{ListValue: inputItems},
			},
		},
	}

	if options.Temperature > 0 {
		request.Temperature = float64Ptr(options.Temperature)
	}
	if options.MaxOutputTokens > 0 {
		request.MaxOutputTokens = int64Ptr(int64(options.MaxOutputTokens))
	}

	// ThinkingModeDefault falls through and leaves Thinking nil on purpose.
	switch options.Thinking {
	case ThinkingModeEnabled:
		thinkingType := responses.ThinkingType_enabled
		request.Thinking = &responses.ResponsesThinking{Type: &thinkingType}
	case ThinkingModeDisabled:
		thinkingType := responses.ThinkingType_disabled
		request.Thinking = &responses.ResponsesThinking{Type: &thinkingType}
	}

	if textType, ok := parseTextType(options.TextFormat); ok {
		request.Text = &responses.ResponsesText{
			Format: &responses.TextFormat{
				Type: textType,
			},
		}
	}

	return request, nil
}
|
||||
|
||||
func buildInputItem(message ArkResponsesMessage) (*responses.InputItem, error) {
|
||||
role, ok := parseMessageRole(message.Role)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported message role: %s", strings.TrimSpace(message.Role))
|
||||
}
|
||||
|
||||
content := make([]*responses.ContentItem, 0, 2)
|
||||
if text := strings.TrimSpace(message.Text); text != "" {
|
||||
content = append(content, &responses.ContentItem{
|
||||
Union: &responses.ContentItem_Text{
|
||||
Text: &responses.ContentItemText{
|
||||
Type: responses.ContentItemType_input_text,
|
||||
Text: text,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if imageURL := strings.TrimSpace(message.ImageURL); imageURL != "" {
|
||||
image := &responses.ContentItemImage{
|
||||
Type: responses.ContentItemType_input_image,
|
||||
ImageUrl: stringPtr(imageURL),
|
||||
}
|
||||
if detail, ok := parseImageDetail(message.ImageDetail); ok {
|
||||
image.Detail = &detail
|
||||
}
|
||||
|
||||
content = append(content, &responses.ContentItem{
|
||||
Union: &responses.ContentItem_Image{
|
||||
Image: image,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if len(content) == 0 {
|
||||
return nil, errors.New("message content is empty")
|
||||
}
|
||||
|
||||
return &responses.InputItem{
|
||||
Union: &responses.InputItem_InputMessage{
|
||||
InputMessage: &responses.ItemInputMessage{
|
||||
Role: role,
|
||||
Content: content,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildArkResponsesResult(resp *responses.ResponseObject) *ArkResponsesResult {
|
||||
if resp == nil {
|
||||
return &ArkResponsesResult{}
|
||||
}
|
||||
|
||||
result := &ArkResponsesResult{
|
||||
Text: extractArkResponsesText(resp),
|
||||
Status: strings.TrimSpace(resp.GetStatus().String()),
|
||||
}
|
||||
|
||||
if details := resp.GetIncompleteDetails(); details != nil {
|
||||
result.IncompleteReason = strings.TrimSpace(details.GetReason())
|
||||
}
|
||||
|
||||
if responseErr := resp.GetError(); responseErr != nil {
|
||||
result.ErrorCode = strings.TrimSpace(responseErr.GetCode())
|
||||
result.ErrorMessage = strings.TrimSpace(responseErr.GetMessage())
|
||||
}
|
||||
|
||||
if usage := resp.GetUsage(); usage != nil {
|
||||
result.Usage = &ArkResponsesUsage{
|
||||
InputTokens: usage.GetInputTokens(),
|
||||
OutputTokens: usage.GetOutputTokens(),
|
||||
TotalTokens: usage.GetTotalTokens(),
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func extractArkResponsesText(resp *responses.ResponseObject) string {
|
||||
if resp == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
textParts := make([]string, 0, 2)
|
||||
for _, outputItem := range resp.GetOutput() {
|
||||
outputMessage := outputItem.GetOutputMessage()
|
||||
if outputMessage == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, contentItem := range outputMessage.GetContent() {
|
||||
text := strings.TrimSpace(contentItem.GetText().GetText())
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
textParts = append(textParts, text)
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(strings.Join(textParts, "\n"))
|
||||
}
|
||||
|
||||
func parseMessageRole(raw string) (responses.MessageRole_Enum, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "user":
|
||||
return responses.MessageRole_user, true
|
||||
case "system":
|
||||
return responses.MessageRole_system, true
|
||||
case "developer":
|
||||
return responses.MessageRole_developer, true
|
||||
case "assistant":
|
||||
return responses.MessageRole_assistant, true
|
||||
default:
|
||||
return responses.MessageRole_unspecified, false
|
||||
}
|
||||
}
|
||||
|
||||
func parseImageDetail(raw string) (responses.ContentItemImageDetail_Enum, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "high":
|
||||
return responses.ContentItemImageDetail_high, true
|
||||
case "low":
|
||||
return responses.ContentItemImageDetail_low, true
|
||||
case "auto":
|
||||
return responses.ContentItemImageDetail_auto, true
|
||||
default:
|
||||
return responses.ContentItemImageDetail_auto, false
|
||||
}
|
||||
}
|
||||
|
||||
func parseTextType(raw string) (responses.TextType_Enum, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "":
|
||||
return responses.TextType_unspecified, false
|
||||
case "text":
|
||||
return responses.TextType_text, true
|
||||
case "json_object":
|
||||
return responses.TextType_json_object, true
|
||||
default:
|
||||
return responses.TextType_unspecified, false
|
||||
}
|
||||
}
|
||||
|
||||
// stringPtr returns a pointer to its own copy of value.
func stringPtr(value string) *string {
	v := value
	return &v
}

// float64Ptr returns a pointer to its own copy of value.
func float64Ptr(value float64) *float64 {
	v := value
	return &v
}

// int64Ptr returns a pointer to its own copy of value.
func int64Ptr(value int64) *int64 {
	v := value
	return &v
}
|
||||
174
backend/services/llm/client.go
Normal file
174
backend/services/llm/client.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// ThinkingMode describes the desired thinking behavior for one model call.
type ThinkingMode string

const (
	ThinkingModeDefault  ThinkingMode = "default"  // leave the provider's default behavior
	ThinkingModeEnabled  ThinkingMode = "enabled"  // force thinking on
	ThinkingModeDisabled ThinkingMode = "disabled" // force thinking off
)

// GenerateOptions collects the most common shared parameters of a text call.
type GenerateOptions struct {
	Temperature float64
	MaxTokens   int
	Thinking    ThinkingMode
	Metadata    map[string]any // free-form extras; not read by any code shown in this package
}

// TextResult holds the final output and usage of one text generation.
// 1. Text carries the model's plain-text reply.
// 2. Usage lets upper layers do unified accounting.
// 3. It performs no JSON parsing and no business field mapping.
type TextResult struct {
	Text         string
	Usage        *schema.TokenUsage
	FinishReason string
}

// StreamReader abstracts a streaming response read message by message.
type StreamReader interface {
	Recv() (*schema.Message, error)
	Close() error
}

// TextGenerateFunc is the unified text-generation function signature.
type TextGenerateFunc func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error)

// StreamGenerateFunc is the unified streaming-generation function signature.
type StreamGenerateFunc func(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error)

// Client is the unified model client facade.
// 1. It performs only minimal input validation and empty-response defense.
// 2. It does no prompt assembly and no business fallback.
// 3. Provider-specific details are folded in by upstream adapters.
type Client struct {
	generateText TextGenerateFunc
	streamText   StreamGenerateFunc
}
|
||||
|
||||
// NewClient 创建统一模型客户端。
|
||||
func NewClient(generateText TextGenerateFunc, streamText StreamGenerateFunc) *Client {
|
||||
return &Client{
|
||||
generateText: generateText,
|
||||
streamText: streamText,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateText 执行一次统一文本生成。
|
||||
func (c *Client) GenerateText(ctx context.Context, messages []*schema.Message, options GenerateOptions) (*TextResult, error) {
|
||||
if c == nil || c.generateText == nil {
|
||||
return nil, errors.New("llm client is not ready")
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return nil, errors.New("llm messages is empty")
|
||||
}
|
||||
|
||||
result, err := c.generateText(ctx, messages, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result == nil {
|
||||
return nil, errors.New("llm result is nil")
|
||||
}
|
||||
if strings.TrimSpace(result.Text) == "" {
|
||||
return nil, errors.New("llm returned empty text")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateJSON 先走统一文本生成,再走统一 JSON 解析。
|
||||
func GenerateJSON[T any](ctx context.Context, client *Client, messages []*schema.Message, options GenerateOptions) (*T, *TextResult, error) {
|
||||
result, err := client.GenerateText(ctx, messages, options)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
parsed, err := ParseJSONObject[T](result.Text)
|
||||
if err != nil {
|
||||
return nil, result, err
|
||||
}
|
||||
return parsed, result, nil
|
||||
}
|
||||
|
||||
// Stream 打开统一流式调用入口。
|
||||
func (c *Client) Stream(ctx context.Context, messages []*schema.Message, options GenerateOptions) (StreamReader, error) {
|
||||
if c == nil || c.streamText == nil {
|
||||
return nil, errors.New("llm stream client is not ready")
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return nil, errors.New("llm messages is empty")
|
||||
}
|
||||
return c.streamText(ctx, messages, options)
|
||||
}
|
||||
|
||||
// BuildSystemUserMessages 构造最常见的 system + history + user 消息列表。
|
||||
func BuildSystemUserMessages(systemPrompt string, history []*schema.Message, userPrompt string) []*schema.Message {
|
||||
messages := make([]*schema.Message, 0, len(history)+2)
|
||||
if strings.TrimSpace(systemPrompt) != "" {
|
||||
messages = append(messages, schema.SystemMessage(systemPrompt))
|
||||
}
|
||||
if len(history) > 0 {
|
||||
messages = append(messages, history...)
|
||||
}
|
||||
if strings.TrimSpace(userPrompt) != "" {
|
||||
messages = append(messages, schema.UserMessage(userPrompt))
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
// CloneUsage 深拷贝 token usage,避免后续累加时共享同一个指针。
|
||||
func CloneUsage(usage *schema.TokenUsage) *schema.TokenUsage {
|
||||
if usage == nil {
|
||||
return nil
|
||||
}
|
||||
copied := *usage
|
||||
return &copied
|
||||
}
|
||||
|
||||
// MergeUsage 合并两段 usage,取各字段更大的值作为累计结果。
|
||||
func MergeUsage(base *schema.TokenUsage, incoming *schema.TokenUsage) *schema.TokenUsage {
|
||||
if incoming == nil {
|
||||
return CloneUsage(base)
|
||||
}
|
||||
if base == nil {
|
||||
return CloneUsage(incoming)
|
||||
}
|
||||
|
||||
merged := *base
|
||||
if incoming.PromptTokens > merged.PromptTokens {
|
||||
merged.PromptTokens = incoming.PromptTokens
|
||||
}
|
||||
if incoming.CompletionTokens > merged.CompletionTokens {
|
||||
merged.CompletionTokens = incoming.CompletionTokens
|
||||
}
|
||||
if incoming.TotalTokens > merged.TotalTokens {
|
||||
merged.TotalTokens = incoming.TotalTokens
|
||||
}
|
||||
if incoming.PromptTokenDetails.CachedTokens > merged.PromptTokenDetails.CachedTokens {
|
||||
merged.PromptTokenDetails.CachedTokens = incoming.PromptTokenDetails.CachedTokens
|
||||
}
|
||||
if incoming.CompletionTokensDetails.ReasoningTokens > merged.CompletionTokensDetails.ReasoningTokens {
|
||||
merged.CompletionTokensDetails.ReasoningTokens = incoming.CompletionTokensDetails.ReasoningTokens
|
||||
}
|
||||
return &merged
|
||||
}
|
||||
|
||||
// FormatEmptyResponseError builds the unified error message for an empty
// model result. A blank scene falls back to "unknown".
func FormatEmptyResponseError(scene string) error {
	name := strings.TrimSpace(scene)
	if name == "" {
		name = "unknown"
	}
	return fmt.Errorf("模型在 %s 场景返回空结果", name)
}
|
||||
102
backend/services/llm/json.go
Normal file
102
backend/services/llm/json.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package llm
|
||||
|
||||
import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"unicode/utf8"
)
|
||||
|
||||
// ParseJSONObject 解析模型返回内容中的 JSON 对象。
|
||||
// 1. 先剥离常见的 markdown 代码块包装。
|
||||
// 2. 再从混合文本里提取最外层 JSON 对象。
|
||||
// 3. 这里只负责结构解析,不负责字段合法性校验。
|
||||
func ParseJSONObject[T any](raw string) (*T, error) {
|
||||
clean := strings.TrimSpace(raw)
|
||||
if clean == "" {
|
||||
return nil, errors.New("模型返回为空,无法解析 JSON")
|
||||
}
|
||||
|
||||
objectText := ExtractJSONObject(clean)
|
||||
if objectText == "" {
|
||||
return nil, fmt.Errorf("模型返回中未找到 JSON 对象: %s", truncateForError(clean))
|
||||
}
|
||||
|
||||
var out T
|
||||
if err := json.Unmarshal([]byte(objectText), &out); err != nil {
|
||||
return nil, fmt.Errorf("JSON 解析失败: %w", err)
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
// ExtractJSONObject 从混合文本中提取第一个完整的 JSON 对象。
|
||||
func ExtractJSONObject(text string) string {
|
||||
clean := trimMarkdownCodeFence(strings.TrimSpace(text))
|
||||
if clean == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
start := strings.Index(clean, "{")
|
||||
if start < 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
depth := 0
|
||||
inString := false
|
||||
escaped := false
|
||||
for idx := start; idx < len(clean); idx++ {
|
||||
ch := clean[idx]
|
||||
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' && inString {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
inString = !inString
|
||||
continue
|
||||
}
|
||||
if inString {
|
||||
continue
|
||||
}
|
||||
|
||||
switch ch {
|
||||
case '{':
|
||||
depth++
|
||||
case '}':
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return clean[start : idx+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// trimMarkdownCodeFence strips a leading ```lang fence line and a trailing
// ``` fence line, returning the fenced body. Text without a leading fence is
// returned trimmed but otherwise untouched.
func trimMarkdownCodeFence(text string) string {
	trimmed := strings.TrimSpace(text)
	if !strings.HasPrefix(trimmed, "```") {
		return trimmed
	}

	// Drop the opening fence line (which may carry a language tag such as
	// ```json). strings.Split of a non-empty string always yields at least
	// one element, so lines[1:] is safe — the previous len(lines)==0 guard
	// here was dead code.
	lines := strings.Split(trimmed, "\n")
	body := lines[1:]
	if len(body) > 0 && strings.TrimSpace(body[len(body)-1]) == "```" {
		body = body[:len(body)-1]
	}
	return strings.TrimSpace(strings.Join(body, "\n"))
}
|
||||
|
||||
func truncateForError(text string) string {
|
||||
if len(text) <= 160 {
|
||||
return text
|
||||
}
|
||||
return text[:160] + "..."
|
||||
}
|
||||
109
backend/services/llm/service.go
Normal file
109
backend/services/llm/service.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/inits"
|
||||
)
|
||||
|
||||
// Service only exposes already-constructed model clients; it owns no prompt
// logic and no business orchestration.
type Service struct {
	liteClient                 *Client             // low-cost short-output model
	proClient                  *Client             // default complex-dialogue model
	maxClient                  *Client             // deep-reasoning model
	courseImageResponsesClient *ArkResponsesClient // course image parsing via the Responses API
}

// Options describes the startup dependencies llm-service takes over when it
// initializes.
// 1. AIHub remains the in-process source of Ark ChatModels, but the service
//    layer keeps only unified Clients.
// 2. CourseImageResponsesClient may be injected up front, which helps tests
//    and special startup paths reuse a prebuilt client.
// 3. A blank field is not an error: the value stays nil and upper layers keep
//    their compatibility fallback.
type Options struct {
	AIHub                      *inits.AIHub
	APIKey                     string
	BaseURL                    string
	CourseVisionModel          string
	CourseImageResponsesClient *ArkResponsesClient
}

// AgentModelClients exposes, in one shot, the model assignment commonly used
// by the newAgent graph.
type AgentModelClients struct {
	Chat    *Client
	Plan    *Client
	Execute *Client
	Deliver *Client
	Summary *Client
}
|
||||
|
||||
// New 构造 llm-service。
|
||||
// 1. 不返回 error,是为了让上层继续按 nil 客户端做逐步降级。
|
||||
// 2. 只要 AIHub 已初始化,就把其中的 ChatModel 收敛成统一 Client。
|
||||
// 3. 课程图片解析客户端在这里统一构建,避免业务层直接依赖 Responses SDK。
|
||||
func New(opts Options) *Service {
|
||||
svc := &Service{}
|
||||
|
||||
if opts.AIHub != nil {
|
||||
svc.liteClient = WrapArkClient(opts.AIHub.Lite)
|
||||
svc.proClient = WrapArkClient(opts.AIHub.Pro)
|
||||
svc.maxClient = WrapArkClient(opts.AIHub.Max)
|
||||
}
|
||||
|
||||
if opts.CourseImageResponsesClient != nil {
|
||||
svc.courseImageResponsesClient = opts.CourseImageResponsesClient
|
||||
} else {
|
||||
apiKey := strings.TrimSpace(opts.APIKey)
|
||||
baseURL := strings.TrimSpace(opts.BaseURL)
|
||||
model := strings.TrimSpace(opts.CourseVisionModel)
|
||||
if apiKey != "" && model != "" {
|
||||
svc.courseImageResponsesClient = NewArkResponsesClient(apiKey, baseURL, model)
|
||||
}
|
||||
}
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
// LiteClient 返回低成本短输出模型客户端。
|
||||
func (s *Service) LiteClient() *Client {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.liteClient
|
||||
}
|
||||
|
||||
// ProClient 返回默认复杂对话模型客户端。
|
||||
func (s *Service) ProClient() *Client {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.proClient
|
||||
}
|
||||
|
||||
// MaxClient 返回深度推理模型客户端。
|
||||
func (s *Service) MaxClient() *Client {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.maxClient
|
||||
}
|
||||
|
||||
// CourseImageResponsesClient 返回课程图片解析所用的 Responses 客户端。
|
||||
func (s *Service) CourseImageResponsesClient() *ArkResponsesClient {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.courseImageResponsesClient
|
||||
}
|
||||
|
||||
// NewAgentModelClients 一次性返回 newAgent 图里常用的模型分配。
|
||||
func (s *Service) NewAgentModelClients() AgentModelClients {
|
||||
if s == nil {
|
||||
return AgentModelClients{}
|
||||
}
|
||||
return AgentModelClients{
|
||||
Chat: s.ProClient(),
|
||||
Plan: s.MaxClient(),
|
||||
Execute: s.MaxClient(),
|
||||
Deliver: s.ProClient(),
|
||||
Summary: s.LiteClient(),
|
||||
}
|
||||
}
|
||||
118
backend/services/rag/api.go
Normal file
118
backend/services/rag/api.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Runtime is the only stable method surface the RAG service exposes to
// business code.
//
// Responsibility boundary:
//  1. it is the unified ingest and retrieval entry for the memory and web
//     corpora;
//  2. it hides the assembly details of the underlying Pipeline / Store /
//     Embedder / Reranker;
//  3. it does NOT own provider search, HTML fetching, or prompt-injection
//     semantics.
type Runtime interface {
	IngestMemory(ctx context.Context, req MemoryIngestRequest) (*IngestResult, error)
	RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (*RetrieveResult, error)
	DeleteMemory(ctx context.Context, documentIDs []string) error

	IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestResult, error)
	RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (*RetrieveResult, error)
}

// IngestResult summarizes one unified ingest run.
type IngestResult struct {
	DocumentCount int      // number of source documents ingested
	ChunkCount    int      // number of chunks produced from those documents
	DocumentIDs   []string // IDs of the ingested documents
}

// RetrieveHit is one unified retrieval hit exposed to business code.
type RetrieveHit struct {
	ChunkID    string
	DocumentID string
	Text       string
	Score      float64
	Metadata   map[string]any
}

// RetrieveResult summarizes one retrieval run.
type RetrieveResult struct {
	Items          []RetrieveHit
	RawCount       int    // presumably the hit count before filtering — TODO confirm against the implementation
	FallbackUsed   bool   // whether a fallback retrieval path was taken
	FallbackReason string // human-readable reason when FallbackUsed is true
}

// MemoryIngestItem is one memory-corpus ingest entry.
type MemoryIngestItem struct {
	MemoryID         int64
	UserID           int
	ConversationID   string
	AssistantID      string
	RunID            string
	MemoryType       string
	Title            string
	Content          string
	Confidence       float64
	Importance       float64
	SensitivityLevel int
	IsExplicit       bool
	Status           string
	TTLAt            *time.Time // optional expiry; nil means no TTL
	CreatedAt        *time.Time // optional creation time; nil when unknown
}

// MemoryIngestRequest describes one memory vector ingest request.
type MemoryIngestRequest struct {
	TraceID string
	Action  string
	Items   []MemoryIngestItem
}

// MemoryRetrieveRequest describes one memory retrieval request.
type MemoryRetrieveRequest struct {
	TraceID        string
	Query          string
	TopK           int
	Threshold      float64
	Action         string
	UserID         int
	ConversationID string
	AssistantID    string
	RunID          string
	MemoryTypes    []string
}

// WebIngestItem is one web-corpus ingest entry.
type WebIngestItem struct {
	URL         string
	Title       string
	Content     string
	Snippet     string
	Domain      string
	QueryID     string
	SessionID   string
	PublishedAt *time.Time // optional publish time; nil when unknown
	FetchedAt   *time.Time // optional fetch time; nil when unknown
	SourceRank  int
}

// WebIngestRequest describes one web-corpus ingest request.
type WebIngestRequest struct {
	TraceID string
	Action  string
	Items   []WebIngestItem
}

// WebRetrieveRequest describes one web retrieval request.
type WebRetrieveRequest struct {
	TraceID   string
	Query     string
	TopK      int
	Threshold float64
	Action    string
	QueryID   string
	SessionID string
	Domain    string
}
|
||||
85
backend/services/rag/chunk/text_chunker.go
Normal file
85
backend/services/rag/chunk/text_chunker.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package chunk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// TextChunker is the default fixed-window text chunker.
type TextChunker struct{}

// NewTextChunker constructs a stateless TextChunker.
func NewTextChunker() *TextChunker {
	return &TextChunker{}
}
|
||||
|
||||
// Chunk performs fixed-window chunking over the document text.
//
// Steps:
//  1. normalize whitespace first so empty chunks never reach the vector store;
//  2. slide a window of chunk_size with chunk_overlap over the rune slice;
//  3. each chunk inherits the document metadata plus its chunk order.
func (c *TextChunker) Chunk(_ context.Context, doc core.SourceDocument, opt core.ChunkOption) ([]core.Chunk, error) {
	if strings.TrimSpace(doc.ID) == "" {
		return nil, fmt.Errorf("empty document id")
	}
	text := strings.TrimSpace(doc.Text)
	if text == "" {
		// An empty document is not an error: it just yields no chunks.
		return nil, nil
	}
	// Backfill and sanity-check the window parameters.
	if opt.ChunkSize <= 0 {
		opt.ChunkSize = 400
	}
	if opt.ChunkOverlap < 0 {
		opt.ChunkOverlap = 0
	}
	if opt.ChunkOverlap >= opt.ChunkSize {
		opt.ChunkOverlap = opt.ChunkSize / 5
	}

	// Work in runes so multi-byte text is never split mid-character.
	runes := []rune(text)
	step := opt.ChunkSize - opt.ChunkOverlap
	if step <= 0 {
		step = opt.ChunkSize
	}

	result := make([]core.Chunk, 0, len(runes)/step+1)
	order := 0
	for start := 0; start < len(runes); start += step {
		end := start + opt.ChunkSize
		if end > len(runes) {
			end = len(runes)
		}
		chunkText := strings.TrimSpace(string(runes[start:end]))
		if chunkText == "" {
			continue
		}
		// Each chunk gets its own metadata copy so later edits do not alias.
		metadata := cloneMap(doc.Metadata)
		metadata["chunk_order"] = order
		result = append(result, core.Chunk{
			ID:         fmt.Sprintf("%s#%d", doc.ID, order),
			DocumentID: doc.ID,
			Text:       chunkText,
			Order:      order,
			Metadata:   metadata,
		})
		order++
		if end == len(runes) {
			break
		}
	}
	return result, nil
}
|
||||
|
||||
// cloneMap returns a shallow copy of src. A nil or empty source still yields
// a fresh, writable empty map so callers can always mutate the result safely.
func cloneMap(src map[string]any) map[string]any {
	out := make(map[string]any, len(src))
	for key, value := range src {
		out[key] = value
	}
	return out
}
|
||||
113
backend/services/rag/config/config.go
Normal file
113
backend/services/rag/config/config.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package config
|
||||
|
||||
import "github.com/spf13/viper"
|
||||
|
||||
// Config is the runtime configuration of the RAG core.
type Config struct {
	Enabled bool
	Store   string // vector store backend name ("inmemory", "milvus", ...)
	TopK    int

	Threshold float64 // minimum similarity score kept during retrieval

	// Embedding backend settings.
	EmbedProvider  string
	EmbedModel     string
	EmbedBaseURL   string
	EmbedTimeoutMS int
	EmbedDimension int

	// Reranker settings.
	RerankerEnabled   bool
	RerankerProvider  string
	RerankerTimeoutMS int

	// Ingestion chunking settings.
	ChunkSize    int
	ChunkOverlap int

	RetrieveTimeoutMS int

	// Milvus connection settings.
	MilvusAddress          string
	MilvusToken            string
	MilvusDBName           string
	MilvusCollectionName   string
	MilvusMetricType       string
	MilvusRequestTimeoutMS int
}
|
||||
|
||||
// LoadFromViper reads the rag.* configuration keys and backfills defaults.
func LoadFromViper() Config {
	cfg := Config{
		Enabled:                viper.GetBool("rag.enabled"),
		Store:                  viper.GetString("rag.store"),
		TopK:                   viper.GetInt("rag.topK"),
		Threshold:              viper.GetFloat64("rag.threshold"),
		EmbedProvider:          viper.GetString("rag.embed.provider"),
		EmbedModel:             viper.GetString("rag.embed.model"),
		EmbedBaseURL:           viper.GetString("rag.embed.baseURL"),
		EmbedTimeoutMS:         viper.GetInt("rag.embed.timeoutMs"),
		EmbedDimension:         viper.GetInt("rag.embed.dimension"),
		RerankerEnabled:        viper.GetBool("rag.reranker.enabled"),
		RerankerProvider:       viper.GetString("rag.reranker.provider"),
		RerankerTimeoutMS:      viper.GetInt("rag.reranker.timeoutMs"),
		ChunkSize:              viper.GetInt("rag.ingest.chunkSize"),
		ChunkOverlap:           viper.GetInt("rag.ingest.chunkOverlap"),
		RetrieveTimeoutMS:      viper.GetInt("rag.retrieve.timeoutMs"),
		MilvusAddress:          viper.GetString("rag.milvus.address"),
		MilvusToken:            viper.GetString("rag.milvus.token"),
		MilvusDBName:           viper.GetString("rag.milvus.dbName"),
		MilvusCollectionName:   viper.GetString("rag.milvus.collectionName"),
		MilvusMetricType:       viper.GetString("rag.milvus.metricType"),
		MilvusRequestTimeoutMS: viper.GetInt("rag.milvus.requestTimeoutMs"),
	}
	// Backfill safe defaults for anything unset or invalid.
	if cfg.Store == "" {
		cfg.Store = "inmemory"
	}
	if cfg.TopK <= 0 {
		cfg.TopK = 8
	}
	if cfg.Threshold < 0 {
		cfg.Threshold = 0
	}
	if cfg.EmbedProvider == "" {
		cfg.EmbedProvider = "mock"
	}
	if cfg.EmbedBaseURL == "" {
		// Fall back to the shared agent endpoint when no dedicated one is set.
		cfg.EmbedBaseURL = viper.GetString("agent.baseURL")
	}
	if cfg.EmbedTimeoutMS <= 0 {
		cfg.EmbedTimeoutMS = 1200
	}
	if cfg.EmbedDimension <= 0 {
		cfg.EmbedDimension = 1024
	}
	if cfg.RerankerProvider == "" {
		cfg.RerankerProvider = "noop"
	}
	if cfg.RerankerTimeoutMS <= 0 {
		cfg.RerankerTimeoutMS = 1200
	}
	if cfg.ChunkSize <= 0 {
		cfg.ChunkSize = 400
	}
	if cfg.ChunkOverlap < 0 {
		cfg.ChunkOverlap = 80
	}
	if cfg.RetrieveTimeoutMS <= 0 {
		cfg.RetrieveTimeoutMS = 1500
	}
	if cfg.MilvusAddress == "" {
		cfg.MilvusAddress = "http://localhost:19530"
	}
	if cfg.MilvusToken == "" {
		cfg.MilvusToken = "root:Milvus"
	}
	if cfg.MilvusCollectionName == "" {
		cfg.MilvusCollectionName = "smartflow_rag_chunks"
	}
	if cfg.MilvusMetricType == "" {
		cfg.MilvusMetricType = "COSINE"
	}
	if cfg.MilvusRequestTimeoutMS <= 0 {
		cfg.MilvusRequestTimeoutMS = 1500
	}
	return cfg
}
|
||||
17
backend/services/rag/core/errors.go
Normal file
17
backend/services/rag/core/errors.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package core
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
	// ErrInvalidQuery indicates a retrieval request without a usable query.
	ErrInvalidQuery = errors.New("invalid query")
	// ErrInvalidTopK indicates an illegal topK value.
	ErrInvalidTopK = errors.New("invalid top_k")
	// ErrNilDependency indicates a missing critical pipeline dependency.
	ErrNilDependency = errors.New("nil dependency")
)

const (
	// FallbackReasonRerankFailed marks a degradation caused by a rerank failure.
	FallbackReasonRerankFailed = "RERANK_FAILED"
)
|
||||
38
backend/services/rag/core/interfaces.go
Normal file
38
backend/services/rag/core/interfaces.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package core
|
||||
|
||||
import "context"
|
||||
|
||||
// Chunker splits a source document into chunks.
type Chunker interface {
	Chunk(ctx context.Context, doc SourceDocument, opt ChunkOption) ([]Chunk, error)
}

// Embedder turns texts into vectors.
type Embedder interface {
	Embed(ctx context.Context, texts []string, action string) ([][]float32, error)
}

// Retriever recalls candidate chunks for a request.
type Retriever interface {
	Retrieve(ctx context.Context, req RetrieveRequest) ([]ScoredChunk, error)
}

// Reranker reorders candidate chunks against the query.
type Reranker interface {
	Rerank(ctx context.Context, query string, candidates []ScoredChunk, topK int) ([]ScoredChunk, error)
}

// VectorStore reads and writes the vector database.
type VectorStore interface {
	Upsert(ctx context.Context, rows []VectorRow) error
	Search(ctx context.Context, req VectorSearchRequest) ([]ScoredVectorRow, error)
	Delete(ctx context.Context, ids []string) error
	Get(ctx context.Context, ids []string) ([]VectorRow, error)
}

// CorpusAdapter maps business corpora onto unified documents and filters.
type CorpusAdapter interface {
	Name() string
	BuildIngestDocuments(ctx context.Context, input any) ([]SourceDocument, error)
	BuildRetrieveFilter(ctx context.Context, req any) (map[string]any, error)
}
|
||||
190
backend/services/rag/core/observer.go
Normal file
190
backend/services/rag/core/observer.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ObserveLevel is the severity level attached to an observation event.
type ObserveLevel string

const (
	ObserveLevelInfo  ObserveLevel = "info"
	ObserveLevelWarn  ObserveLevel = "warn"
	ObserveLevelError ObserveLevel = "error"
)

// ObserveEvent describes one unified observation event.
//
// Responsibility boundaries:
//  1. carries only structured runtime information of the RAG service;
//  2. is not bound to a concrete logging, metrics, or tracing system;
//  3. field content should stay stable so it can later feed a global
//     observability platform.
type ObserveEvent struct {
	Level     ObserveLevel
	Component string
	Operation string
	Fields    map[string]any
}

// Observer is the minimal observation interface of the RAG service.
//
// Responsibility boundaries:
//  1. consumes structured events;
//  2. never decides whether business logic continues;
//  3. no implementation may destabilize the main request path.
type Observer interface {
	Observe(ctx context.Context, event ObserveEvent)
}
|
||||
|
||||
// ObserverFunc adapts a plain function to the Observer interface.
type ObserverFunc func(ctx context.Context, event ObserveEvent)

// Observe forwards the event to the wrapped function; a nil func is a no-op.
func (f ObserverFunc) Observe(ctx context.Context, event ObserveEvent) {
	if f == nil {
		return
	}
	f(ctx, event)
}

// NewNopObserver returns an empty implementation, useful as a fallback before
// a unified observability platform is wired in.
func NewNopObserver() Observer {
	return ObserverFunc(func(context.Context, ObserveEvent) {})
}
|
||||
|
||||
// NewLoggerObserver returns a standard-log adapter.
//
// Notes:
//  1. before a unified logging platform exists, it prints the structured
//     fields in a stable form;
//  2. once a unified logger/metrics/tracing arrives, only this injected
//     Observer implementation needs replacing;
//  3. output stays single-line to match the existing log style.
func NewLoggerObserver(logger *log.Logger) Observer {
	if logger == nil {
		logger = log.Default()
	}
	return &loggerObserver{logger: logger}
}

// loggerObserver writes events through a standard *log.Logger.
type loggerObserver struct {
	logger *log.Logger
}
|
||||
|
||||
// Observe renders the event as one "rag key=value ..." log line with a
// deterministic (sorted) field order.
func (o *loggerObserver) Observe(ctx context.Context, event ObserveEvent) {
	if o == nil || o.logger == nil {
		return
	}

	// Backfill defaults so the line always carries level/component/operation.
	level := strings.TrimSpace(string(event.Level))
	if level == "" {
		level = string(ObserveLevelInfo)
	}
	component := strings.TrimSpace(event.Component)
	if component == "" {
		component = "unknown"
	}
	operation := strings.TrimSpace(event.Operation)
	if operation == "" {
		operation = "unknown"
	}

	// Start from context-carried fields; event fields override same keys.
	fields := ObserveFieldsFromContext(ctx)
	for key, value := range event.Fields {
		key = strings.TrimSpace(key)
		if key == "" || !shouldKeepObserveField(value) {
			continue
		}
		fields[key] = value
	}

	parts := []string{
		"rag",
		fmt.Sprintf("level=%s", level),
		fmt.Sprintf("component=%s", component),
		fmt.Sprintf("operation=%s", operation),
	}

	// Sort keys so repeated events are diffable in logs.
	keys := make([]string, 0, len(fields))
	for key := range fields {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	for _, key := range keys {
		parts = append(parts, fmt.Sprintf("%s=%v", key, fields[key]))
	}

	o.logger.Print(strings.Join(parts, " "))
}
|
||||
|
||||
// observeFieldsContextKey is the private context key for accumulated fields.
type observeFieldsContextKey struct{}

// WithObserveFields attaches shared observation fields to the context so
// downstream components can reuse them.
//
// Steps:
//  1. read existing context fields first so Runtime / Pipeline / Store layers
//     can add information incrementally;
//  2. later writes override same-named values so downstream sees the latest
//     semantics;
//  3. keep only "meaningful" fields to avoid piling empty values into logs.
func WithObserveFields(ctx context.Context, fields map[string]any) context.Context {
	if len(fields) == 0 {
		return ctx
	}
	if ctx == nil {
		ctx = context.Background()
	}

	merged := ObserveFieldsFromContext(ctx)
	for key, value := range fields {
		key = strings.TrimSpace(key)
		if key == "" || !shouldKeepObserveField(value) {
			continue
		}
		merged[key] = value
	}
	if len(merged) == 0 {
		return ctx
	}
	return context.WithValue(ctx, observeFieldsContextKey{}, merged)
}
|
||||
|
||||
// ObserveFieldsFromContext 提取上下文中已经累积的观测字段。
|
||||
func ObserveFieldsFromContext(ctx context.Context) map[string]any {
|
||||
if ctx == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
raw, ok := ctx.Value(observeFieldsContextKey{}).(map[string]any)
|
||||
if !ok || len(raw) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
result := make(map[string]any, len(raw))
|
||||
for key, value := range raw {
|
||||
result[key] = value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ClassifyErrorCode compresses common errors into stable codes so they can be
// reported uniformly once a global observability platform is wired in.
func ClassifyErrorCode(err error) string {
	if err == nil {
		return ""
	}
	if errors.Is(err, context.DeadlineExceeded) {
		return "DEADLINE_EXCEEDED"
	}
	if errors.Is(err, context.Canceled) {
		return "CANCELED"
	}
	return "RAG_ERROR"
}
|
||||
|
||||
// shouldKeepObserveField reports whether a field value carries information
// worth logging: nil values and blank strings are dropped, everything else kept.
func shouldKeepObserveField(value any) bool {
	switch v := value.(type) {
	case nil:
		return false
	case string:
		return strings.TrimSpace(v) != ""
	default:
		return true
	}
}
|
||||
366
backend/services/rag/core/pipeline.go
Normal file
366
backend/services/rag/core/pipeline.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
	defaultTopK       = 8
	defaultThreshold  = 0
	defaultChunkSize  = 400
	defaultChunkOvLap = 80
)

// Pipeline is the RAG core orchestrator.
//
// Responsibility boundaries:
//  1. unifies the chunk/embed/retrieve/rerank flow;
//  2. owns the failure-degradation semantics;
//  3. carries no business semantics (those live in CorpusAdapter).
type Pipeline struct {
	chunker  Chunker
	embedder Embedder
	store    VectorStore
	reranker Reranker
	logger   *log.Logger
	observer Observer
}
|
||||
|
||||
// NewPipeline wires the four pipeline stages with a default logger and a
// no-op observer; use SetLogger / SetObserver to override them.
func NewPipeline(chunker Chunker, embedder Embedder, store VectorStore, reranker Reranker) *Pipeline {
	return &Pipeline{
		chunker:  chunker,
		embedder: embedder,
		store:    store,
		reranker: reranker,
		logger:   log.Default(),
		observer: NewNopObserver(),
	}
}

// SetLogger sets the logger used by the Pipeline; nil receiver or argument is ignored.
func (p *Pipeline) SetLogger(logger *log.Logger) {
	if p == nil || logger == nil {
		return
	}
	p.logger = logger
}

// SetObserver sets the unified observer used by the Pipeline; nil receiver or argument is ignored.
func (p *Pipeline) SetObserver(observer Observer) {
	if p == nil || observer == nil {
		return
	}
	p.observer = observer
}
|
||||
|
||||
// Ingest runs the unified ingestion flow.
//
// Steps:
//  1. the CorpusAdapter first builds normalized documents so every corpus
//     entry point stays consistent;
//  2. chunking and embedding are then shared, avoiding per-business
//     reimplementation;
//  3. a single Upsert finishes the write; on failure the error is returned
//     so the caller decides whether to retry.
func (p *Pipeline) Ingest(
	ctx context.Context,
	corpus CorpusAdapter,
	input any,
	opt IngestOption,
) (result *IngestResult, err error) {
	defer p.recoverExecutionPanic(ctx, "ingest", &err)

	if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil {
		return nil, ErrNilDependency
	}
	if corpus == nil {
		return nil, errors.New("nil corpus adapter")
	}

	docs, err := corpus.BuildIngestDocuments(ctx, input)
	if err != nil {
		return nil, err
	}
	return p.IngestDocuments(ctx, corpus.Name(), docs, opt)
}
|
||||
|
||||
// IngestDocuments runs the unified ingestion flow for already-normalized documents.
//
// Responsibility boundaries:
//  1. handles documents that already went through CorpusAdapter mapping;
//  2. owns chunking, embedding, and the Upsert;
//  3. does no business-input parsing, so the Runtime never has to rebuild
//     documents just to learn a document_id.
func (p *Pipeline) IngestDocuments(
	ctx context.Context,
	corpusName string,
	docs []SourceDocument,
	opt IngestOption,
) (result *IngestResult, err error) {
	defer p.recoverExecutionPanic(ctx, "ingest_documents", &err)

	if p == nil || p.chunker == nil || p.embedder == nil || p.store == nil {
		return nil, ErrNilDependency
	}
	if len(docs) == 0 {
		return &IngestResult{DocumentCount: 0, ChunkCount: 0}, nil
	}

	chunkOpt := normalizeChunkOption(opt.Chunk)
	chunks := make([]Chunk, 0, len(docs)*2)
	for _, doc := range docs {
		// 1. Chunk each document independently; abort on failure to avoid
		//    writing a half-finished batch.
		docChunks, chunkErr := p.chunker.Chunk(ctx, doc, chunkOpt)
		if chunkErr != nil {
			return nil, chunkErr
		}
		chunks = append(chunks, docChunks...)
	}
	if len(chunks) == 0 {
		return &IngestResult{DocumentCount: len(docs), ChunkCount: 0}, nil
	}

	texts := make([]string, 0, len(chunks))
	for _, chunk := range chunks {
		texts = append(texts, chunk.Text)
	}

	action := strings.TrimSpace(opt.Action)
	if action == "" {
		action = "add"
	}
	vectors, err := p.embedder.Embed(ctx, texts, action)
	if err != nil {
		return nil, err
	}
	if len(vectors) != len(chunks) {
		return nil, fmt.Errorf("embedding result length mismatch: chunks=%d vectors=%d", len(chunks), len(vectors))
	}

	// Build the vector rows, stamping corpus/document provenance into metadata.
	rows := make([]VectorRow, 0, len(chunks))
	now := time.Now()
	for i, chunk := range chunks {
		metadata := cloneMap(chunk.Metadata)
		metadata["corpus"] = corpusName
		metadata["document_id"] = chunk.DocumentID
		metadata["chunk_order"] = chunk.Order
		rows = append(rows, VectorRow{
			ID:        chunk.ID,
			Vector:    vectors[i],
			Text:      chunk.Text,
			Metadata:  metadata,
			CreatedAt: now,
			UpdatedAt: now,
		})
	}

	if err = p.store.Upsert(ctx, rows); err != nil {
		return nil, err
	}
	return &IngestResult{
		DocumentCount: len(docs),
		ChunkCount:    len(chunks),
	}, nil
}
|
||||
|
||||
// Retrieve runs the unified retrieval flow.
//
// Steps:
//  1. embed the query and run the vector search;
//  2. apply the score threshold to drop low-quality candidates;
//  3. optionally rerank; on rerank failure, degrade to the original order
//     and report it via the observer (or logger).
func (p *Pipeline) Retrieve(
	ctx context.Context,
	corpus CorpusAdapter,
	req RetrieveRequest,
) (result *RetrieveResult, err error) {
	defer p.recoverExecutionPanic(ctx, "retrieve", &err)

	if p == nil || p.embedder == nil || p.store == nil {
		return nil, ErrNilDependency
	}
	query := strings.TrimSpace(req.Query)
	if query == "" {
		return nil, ErrInvalidQuery
	}

	topK := req.TopK
	if topK <= 0 {
		topK = defaultTopK
	}
	threshold := req.Threshold
	if threshold < 0 {
		threshold = defaultThreshold
	}

	filter := cloneMap(req.Filter)
	if corpus != nil {
		// 1. Merge the corpus filter first to prevent cross-corpus recall.
		corpusFilter, err := corpus.BuildRetrieveFilter(ctx, req.CorpusInput)
		if err != nil {
			return nil, err
		}
		filter = mergeMap(filter, corpusFilter)
		filter["corpus"] = corpus.Name()
	}

	action := strings.TrimSpace(req.Action)
	if action == "" {
		action = "search"
	}

	vectors, err := p.embedder.Embed(ctx, []string{query}, action)
	if err != nil {
		return nil, err
	}
	if len(vectors) != 1 {
		return nil, fmt.Errorf("embedding query length mismatch: %d", len(vectors))
	}

	scoredRows, err := p.store.Search(ctx, VectorSearchRequest{
		QueryVector: vectors[0],
		TopK:        topK,
		Filter:      filter,
	})
	if err != nil {
		return nil, err
	}

	// Threshold filtering: keep the raw count for observability.
	rawCount := len(scoredRows)
	candidates := make([]ScoredChunk, 0, len(scoredRows))
	for _, row := range scoredRows {
		if row.Score < threshold {
			continue
		}
		candidates = append(candidates, ScoredChunk{
			ChunkID:    row.Row.ID,
			DocumentID: asString(row.Row.Metadata["document_id"]),
			Text:       row.Row.Text,
			Score:      row.Score,
			Metadata:   cloneMap(row.Row.Metadata),
		})
	}

	result = &RetrieveResult{
		Items:        candidates,
		RawCount:     rawCount,
		FallbackUsed: false,
	}
	if len(candidates) == 0 || p.reranker == nil {
		return result, nil
	}

	reranked, rerankErr := p.reranker.Rerank(ctx, query, candidates, topK)
	if rerankErr != nil {
		// 2. A rerank failure must not abort the request: degrade to the
		//    original ordering and surface the reason.
		result.FallbackUsed = true
		result.FallbackReason = FallbackReasonRerankFailed
		if p.observer != nil {
			p.observer.Observe(ctx, ObserveEvent{
				Level:     ObserveLevelWarn,
				Component: "pipeline",
				Operation: "rerank_fallback",
				Fields: map[string]any{
					"status":          "fallback",
					"fallback_reason": FallbackReasonRerankFailed,
					"candidate_count": len(candidates),
					"top_k":           topK,
					"error":           rerankErr,
					"error_code":      ClassifyErrorCode(rerankErr),
				},
			})
		} else if p.logger != nil {
			p.logger.Printf("rag rerank fallback: reason=%s err=%v", FallbackReasonRerankFailed, rerankErr)
		}
		return result, nil
	}
	result.Items = reranked
	return result, nil
}
|
||||
|
||||
// Delete removes the vectors with the given IDs; a nil pipeline/store is a no-op.
func (p *Pipeline) Delete(ctx context.Context, ids []string) error {
	if p == nil || p.store == nil {
		return nil
	}
	return p.store.Delete(ctx, ids)
}
|
||||
|
||||
// recoverExecutionPanic converts a panic inside a pipeline entry point into a
// regular error and reports it to the observer (or the logger as fallback).
// It must be installed with defer and receive the named-return error pointer.
func (p *Pipeline) recoverExecutionPanic(ctx context.Context, operation string, errPtr *error) {
	recovered := recover()
	if recovered == nil || errPtr == nil {
		return
	}

	panicErr := fmt.Errorf("rag pipeline panic recovered: operation=%s panic=%v", operation, recovered)
	*errPtr = panicErr

	// 1. The Pipeline is the orchestration boundary for chunk/embed/store/rerank;
	//    a third-party dependency panicking must not kill the caller's request.
	// 2. After the unified recover we continue with error semantics, letting
	//    runtime/service decide on degradation, fallback, or logging.
	// 3. The stack goes only to the observation layer, never into the returned
	//    error, to keep huge stack traces out of user-facing error text.
	if p != nil && p.observer != nil {
		p.observer.Observe(ctx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "pipeline",
			Operation: operation + "_panic_recovered",
			Fields: map[string]any{
				"status":     "failed",
				"panic":      fmt.Sprintf("%v", recovered),
				"panic_type": fmt.Sprintf("%T", recovered),
				"error":      panicErr,
				"error_code": ClassifyErrorCode(panicErr),
				"stack":      string(debug.Stack()),
			},
		})
		return
	}
	if p != nil && p.logger != nil {
		p.logger.Printf("rag pipeline panic recovered: operation=%s panic=%v stack=%s", operation, recovered, string(debug.Stack()))
	}
}
|
||||
|
||||
// normalizeChunkOption backfills and sanity-checks chunking parameters.
func normalizeChunkOption(opt ChunkOption) ChunkOption {
	if opt.ChunkSize <= 0 {
		opt.ChunkSize = defaultChunkSize
	}
	if opt.ChunkOverlap < 0 {
		opt.ChunkOverlap = 0
	}
	if opt.ChunkOverlap >= opt.ChunkSize {
		// Overlap must stay below the chunk size: fall back to the default
		// overlap, then to a fifth of the chunk size for very small chunks.
		opt.ChunkOverlap = defaultChunkOvLap
		if opt.ChunkOverlap >= opt.ChunkSize {
			opt.ChunkOverlap = opt.ChunkSize / 5
		}
	}
	return opt
}
|
||||
|
||||
// cloneMap returns a shallow copy of src. A nil or empty source still yields
// a fresh, writable empty map so callers can always mutate the result safely.
func cloneMap(src map[string]any) map[string]any {
	copied := make(map[string]any, len(src))
	for k, v := range src {
		copied[k] = v
	}
	return copied
}
|
||||
|
||||
// mergeMap writes every entry of ext into base (allocating base when nil) and
// returns the mutated base map; ext values win on key collisions.
func mergeMap(base, ext map[string]any) map[string]any {
	if base == nil {
		base = make(map[string]any)
	}
	for k, v := range ext {
		base[k] = v
	}
	return base
}
|
||||
|
||||
// asString renders an arbitrary metadata value as text; nil becomes "".
func asString(v any) string {
	if v == nil {
		return ""
	}
	return fmt.Sprint(v)
}
|
||||
94
backend/services/rag/core/types.go
Normal file
94
backend/services/rag/core/types.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package core
|
||||
|
||||
import "time"
|
||||
|
||||
// SourceDocument is the unified corpus document model.
//
// Responsibility boundaries:
//  1. describes only a raw document that can be chunked and indexed;
//  2. carries no business workflow state.
type SourceDocument struct {
	ID        string
	Text      string
	Title     string
	Metadata  map[string]any
	CreatedAt time.Time
}

// Chunk is the standard chunking result.
type Chunk struct {
	ID         string
	DocumentID string
	Text       string
	Order      int
	Metadata   map[string]any
}

// ChunkOption controls chunking parameters.
type ChunkOption struct {
	ChunkSize    int
	ChunkOverlap int
}

// IngestOption controls ingestion parameters.
type IngestOption struct {
	Chunk ChunkOption
	// Action disambiguates the embedding purpose (add/update/search).
	Action string
}

// IngestResult summarizes one ingestion run.
type IngestResult struct {
	DocumentCount int
	ChunkCount    int
}

// RetrieveRequest is the unified retrieval request.
type RetrieveRequest struct {
	Query       string
	TopK        int
	Threshold   float64
	Action      string
	Filter      map[string]any
	CorpusInput any
}

// ScoredChunk is the unified recall result.
type ScoredChunk struct {
	ChunkID    string
	DocumentID string
	Text       string
	Score      float64
	Metadata   map[string]any
}

// RetrieveResult summarizes one retrieval run.
type RetrieveResult struct {
	Items          []ScoredChunk
	RawCount       int
	FallbackUsed   bool
	FallbackReason string
}

// VectorRow is the standard vector-store row.
type VectorRow struct {
	ID        string
	Vector    []float32
	Text      string
	Metadata  map[string]any
	CreatedAt time.Time
	UpdatedAt time.Time
}

// VectorSearchRequest is a vector search request.
type VectorSearchRequest struct {
	QueryVector []float32
	TopK        int
	Filter      map[string]any
}

// ScoredVectorRow is one vector search hit.
type ScoredVectorRow struct {
	Row   VectorRow
	Score float64
}
|
||||
13
backend/services/rag/corpus/common.go
Normal file
13
backend/services/rag/corpus/common.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package corpus
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// hashLikeText returns a short, case- and surrounding-whitespace-insensitive
// fingerprint of text: the first 8 bytes of its SHA-256 digest, hex encoded.
func hashLikeText(text string) string {
	canonical := strings.TrimSpace(strings.ToLower(text))
	digest := sha256.Sum256([]byte(canonical))
	return hex.EncodeToString(digest[:8])
}
|
||||
158
backend/services/rag/corpus/memory_corpus.go
Normal file
158
backend/services/rag/corpus/memory_corpus.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package corpus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
const memoryCorpusName = "memory"

// MemoryIngestItem is one memory corpus ingest item.
type MemoryIngestItem struct {
	MemoryID         int64
	UserID           int
	ConversationID   string
	AssistantID      string
	RunID            string
	MemoryType       string
	Title            string
	Content          string
	Confidence       float64
	Importance       float64
	SensitivityLevel int
	IsExplicit       bool
	Status           string
	TTLAt            *time.Time
	CreatedAt        *time.Time
}

// MemoryRetrieveInput carries the filter dimensions for memory retrieval.
type MemoryRetrieveInput struct {
	UserID         int
	ConversationID string
	AssistantID    string
	RunID          string
	MemoryType     string
}
|
||||
|
||||
// MemoryCorpus is the memory corpus adapter.
type MemoryCorpus struct{}

// NewMemoryCorpus constructs a stateless MemoryCorpus.
func NewMemoryCorpus() *MemoryCorpus {
	return &MemoryCorpus{}
}

// Name returns the stable corpus identifier ("memory").
func (c *MemoryCorpus) Name() string {
	return memoryCorpusName
}
|
||||
|
||||
// BuildIngestDocuments maps memory ingest items into unified source documents.
//
// Validation: every item needs a positive user_id (hard error); items with
// empty content are silently skipped instead of failing the whole batch.
func (c *MemoryCorpus) BuildIngestDocuments(_ context.Context, input any) ([]core.SourceDocument, error) {
	items, err := toMemoryItems(input)
	if err != nil {
		return nil, err
	}
	result := make([]core.SourceDocument, 0, len(items))
	for _, item := range items {
		if item.UserID <= 0 {
			return nil, errors.New("memory ingest item user_id is invalid")
		}
		text := strings.TrimSpace(item.Content)
		if text == "" {
			continue
		}
		// Prefer the stable memory ID; fall back to a content hash for
		// memories that have not been persisted yet.
		docID := fmt.Sprintf("memory:%d", item.MemoryID)
		if item.MemoryID <= 0 {
			docID = fmt.Sprintf("memory:uid:%d:%s", item.UserID, hashLikeText(text))
		}
		metadata := map[string]any{
			"user_id":           item.UserID,
			"conversation_id":   strings.TrimSpace(item.ConversationID),
			"assistant_id":      strings.TrimSpace(item.AssistantID),
			"run_id":            strings.TrimSpace(item.RunID),
			"memory_type":       strings.TrimSpace(strings.ToLower(item.MemoryType)),
			"title":             strings.TrimSpace(item.Title),
			"confidence":        item.Confidence,
			"importance":        item.Importance,
			"sensitivity_level": item.SensitivityLevel,
			"is_explicit":       item.IsExplicit,
			"status":            strings.TrimSpace(item.Status),
		}
		if item.TTLAt != nil {
			metadata["ttl_at"] = item.TTLAt.Format(time.RFC3339)
		}
		createdAt := time.Now()
		if item.CreatedAt != nil {
			createdAt = *item.CreatedAt
		}
		result = append(result, core.SourceDocument{
			ID:        docID,
			Text:      text,
			Title:     strings.TrimSpace(item.Title),
			Metadata:  metadata,
			CreatedAt: createdAt,
		})
	}
	return result, nil
}
|
||||
|
||||
// BuildRetrieveFilter builds the metadata filter for memory retrieval.
// Accepts MemoryRetrieveInput by value or non-nil pointer. UserID is
// mandatory; every other dimension is added only when non-blank.
func (c *MemoryCorpus) BuildRetrieveFilter(_ context.Context, req any) (map[string]any, error) {
	input, ok := req.(MemoryRetrieveInput)
	if !ok {
		if ptr, isPtr := req.(*MemoryRetrieveInput); isPtr && ptr != nil {
			input = *ptr
		} else if req == nil {
			return nil, errors.New("memory retrieve input is nil")
		} else {
			return nil, errors.New("invalid memory retrieve input")
		}
	}
	if input.UserID <= 0 {
		return nil, errors.New("memory retrieve user_id is invalid")
	}
	filter := map[string]any{
		"user_id": input.UserID,
	}
	if v := strings.TrimSpace(input.ConversationID); v != "" {
		filter["conversation_id"] = v
	}
	if v := strings.TrimSpace(input.AssistantID); v != "" {
		filter["assistant_id"] = v
	}
	if v := strings.TrimSpace(input.RunID); v != "" {
		filter["run_id"] = v
	}
	if v := strings.TrimSpace(strings.ToLower(input.MemoryType)); v != "" {
		filter["memory_type"] = v
	}
	return filter, nil
}
|
||||
|
||||
func toMemoryItems(input any) ([]MemoryIngestItem, error) {
|
||||
switch value := input.(type) {
|
||||
case MemoryIngestItem:
|
||||
return []MemoryIngestItem{value}, nil
|
||||
case *MemoryIngestItem:
|
||||
if value == nil {
|
||||
return nil, errors.New("memory ingest item is nil")
|
||||
}
|
||||
return []MemoryIngestItem{*value}, nil
|
||||
case []MemoryIngestItem:
|
||||
return value, nil
|
||||
case []*MemoryIngestItem:
|
||||
items := make([]MemoryIngestItem, 0, len(value))
|
||||
for _, ptr := range value {
|
||||
if ptr == nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, *ptr)
|
||||
}
|
||||
return items, nil
|
||||
default:
|
||||
return nil, errors.New("invalid memory ingest input")
|
||||
}
|
||||
}
|
||||
163
backend/services/rag/corpus/web_corpus.go
Normal file
163
backend/services/rag/corpus/web_corpus.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package corpus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
const webCorpusName = "web"
|
||||
|
||||
// WebIngestItem is a single web-page corpus ingest entry.
type WebIngestItem struct {
	URL         string     // source page URL; required (empty URL fails ingestion)
	Title       string     // page title; trimmed and stored as the document title
	Content     string     // extracted main body text
	Snippet     string     // short search-result snippet
	Domain      string     // page domain; stored in metadata as an equality-filter key
	QueryID     string     // search query this page was fetched for
	SessionID   string     // session the fetch belongs to
	PublishedAt *time.Time // publish time if known; RFC3339 in metadata when set
	FetchedAt   *time.Time // fetch time if known; also used as the document CreatedAt
	SourceRank  int        // rank of this result in the original search response
}

// WebRetrieveInput is the filter input for web-corpus retrieval.
// At least one of QueryID/SessionID must be non-blank (enforced by
// BuildRetrieveFilter) so results never cross question boundaries.
type WebRetrieveInput struct {
	QueryID   string // equality filter on query_id when non-blank
	SessionID string // equality filter on session_id when non-blank
	Domain    string // optional equality filter on domain
}
|
||||
|
||||
// WebCorpus is the web-page corpus adapter; it maps web ingest/retrieve
// inputs onto the generic core pipeline document and filter shapes.
type WebCorpus struct{}

// NewWebCorpus constructs the stateless web corpus adapter.
func NewWebCorpus() *WebCorpus {
	return &WebCorpus{}
}

// Name returns the corpus identifier ("web").
func (c *WebCorpus) Name() string {
	return webCorpusName
}
|
||||
|
||||
// BuildIngestDocuments converts web ingest input into core source documents.
//
// Behavior:
//   - an item with a blank URL fails the whole batch;
//   - an item whose assembled text (title/snippet/content) is blank is skipped;
//   - the document ID is content-derived via hashLikeText over url+text.
func (c *WebCorpus) BuildIngestDocuments(_ context.Context, input any) ([]core.SourceDocument, error) {
	items, err := toWebItems(input)
	if err != nil {
		return nil, err
	}
	result := make([]core.SourceDocument, 0, len(items))
	for _, item := range items {
		url := strings.TrimSpace(item.URL)
		if url == "" {
			return nil, errors.New("web ingest item url is empty")
		}

		mainText := buildWebText(item)
		if strings.TrimSpace(mainText) == "" {
			continue
		}

		// Content-addressed ID: same url+text yields the same document ID.
		docID := fmt.Sprintf("web:%s", hashLikeText(url+"|"+mainText))
		metadata := map[string]any{
			"url":         url,
			"domain":      strings.TrimSpace(item.Domain),
			"query_id":    strings.TrimSpace(item.QueryID),
			"session_id":  strings.TrimSpace(item.SessionID),
			"source_rank": item.SourceRank,
		}
		if item.PublishedAt != nil {
			metadata["published_at"] = item.PublishedAt.Format(time.RFC3339)
		}
		if item.FetchedAt != nil {
			metadata["fetched_at"] = item.FetchedAt.Format(time.RFC3339)
		}

		// Prefer the fetch time as the document creation time when available.
		createdAt := time.Now()
		if item.FetchedAt != nil {
			createdAt = *item.FetchedAt
		}
		result = append(result, core.SourceDocument{
			ID:        docID,
			Text:      mainText,
			Title:     strings.TrimSpace(item.Title),
			Metadata:  metadata,
			CreatedAt: createdAt,
		})
	}
	return result, nil
}
|
||||
|
||||
func (c *WebCorpus) BuildRetrieveFilter(_ context.Context, req any) (map[string]any, error) {
|
||||
input, ok := req.(WebRetrieveInput)
|
||||
if !ok {
|
||||
if ptr, isPtr := req.(*WebRetrieveInput); isPtr && ptr != nil {
|
||||
input = *ptr
|
||||
} else if req == nil {
|
||||
return nil, errors.New("web retrieve input is nil")
|
||||
} else {
|
||||
return nil, errors.New("invalid web retrieve input")
|
||||
}
|
||||
}
|
||||
|
||||
// 1. query_id/session_id 至少要有一个,避免跨问题串数据。
|
||||
queryID := strings.TrimSpace(input.QueryID)
|
||||
sessionID := strings.TrimSpace(input.SessionID)
|
||||
if queryID == "" && sessionID == "" {
|
||||
return nil, errors.New("web retrieve filter requires query_id or session_id")
|
||||
}
|
||||
|
||||
filter := map[string]any{}
|
||||
if queryID != "" {
|
||||
filter["query_id"] = queryID
|
||||
}
|
||||
if sessionID != "" {
|
||||
filter["session_id"] = sessionID
|
||||
}
|
||||
if domain := strings.TrimSpace(input.Domain); domain != "" {
|
||||
filter["domain"] = domain
|
||||
}
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
func toWebItems(input any) ([]WebIngestItem, error) {
|
||||
switch value := input.(type) {
|
||||
case WebIngestItem:
|
||||
return []WebIngestItem{value}, nil
|
||||
case *WebIngestItem:
|
||||
if value == nil {
|
||||
return nil, errors.New("web ingest item is nil")
|
||||
}
|
||||
return []WebIngestItem{*value}, nil
|
||||
case []WebIngestItem:
|
||||
return value, nil
|
||||
case []*WebIngestItem:
|
||||
items := make([]WebIngestItem, 0, len(value))
|
||||
for _, ptr := range value {
|
||||
if ptr == nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, *ptr)
|
||||
}
|
||||
return items, nil
|
||||
default:
|
||||
return nil, errors.New("invalid web ingest input")
|
||||
}
|
||||
}
|
||||
|
||||
func buildWebText(item WebIngestItem) string {
|
||||
parts := make([]string, 0, 3)
|
||||
if title := strings.TrimSpace(item.Title); title != "" {
|
||||
parts = append(parts, title)
|
||||
}
|
||||
if snippet := strings.TrimSpace(item.Snippet); snippet != "" {
|
||||
parts = append(parts, snippet)
|
||||
}
|
||||
if content := strings.TrimSpace(item.Content); content != "" {
|
||||
parts = append(parts, content)
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
208
backend/services/rag/embed/eino_embedder.go
Normal file
208
backend/services/rag/embed/eino_embedder.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
openaiembedding "github.com/cloudwego/eino-ext/libs/acl/openai"
|
||||
einoembedding "github.com/cloudwego/eino/components/embedding"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
||||
arkmodel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
// EinoConfig describes the runtime parameters for the Eino embedding adapter.
type EinoConfig struct {
	APIKey    string // Ark API key; required
	BaseURL   string // service root URL; misconfigured /embeddings suffixes are normalized away
	Model     string // embedding model name; required
	TimeoutMS int    // per-call timeout in milliseconds; <= 0 falls back to 1200ms
	Dimension int    // requested vector dimension; forwarded only when > 0
}

// EinoEmbedder is the Eino-based embedding adapter.
//
// Notes:
//  1. It exposes a uniform [][]float32 result to infra/rag, hiding the
//     implementation differences of the underlying SDKs.
//  2. Text embedding keeps using the current stable OpenAI-compatible path,
//     so unrelated models are unaffected.
//  3. Multimodal embedding models use Ark's native `/embeddings/multimodal`,
//     because vision models are incompatible with the standard `/embeddings`.
type EinoEmbedder struct {
	textClient       einoembedding.Embedder // OpenAI-compatible text path; nil for multimodal models
	multimodalClient *arkruntime.Client     // Ark native client; set only for doubao-embedding-vision-* models
	model            string                 // trimmed model name
	timeout          time.Duration          // per-call timeout applied in Embed
	dimension        int                    // requested dimension; 0 means provider default
}
|
||||
|
||||
// NewEinoEmbedder builds the embedding adapter for the configured model.
//
// APIKey and Model are required. TimeoutMS <= 0 defaults to 1200ms.
// Vision models are routed to Ark's native multimodal client; all other
// models use the OpenAI-compatible text embedding client.
func NewEinoEmbedder(ctx context.Context, cfg EinoConfig) (*EinoEmbedder, error) {
	if strings.TrimSpace(cfg.APIKey) == "" {
		return nil, errors.New("eino embedder api key is empty")
	}
	if strings.TrimSpace(cfg.Model) == "" {
		return nil, errors.New("eino embedder model is empty")
	}

	timeout := 1200 * time.Millisecond
	if cfg.TimeoutMS > 0 {
		timeout = time.Duration(cfg.TimeoutMS) * time.Millisecond
	}

	baseURL := normalizeEmbeddingBaseURL(cfg.BaseURL)
	model := strings.TrimSpace(cfg.Model)
	httpClient := &http.Client{Timeout: timeout}

	// 1. Models like `doubao-embedding-vision-*` do not support the standard `/embeddings` endpoint.
	// 2. Switch directly to Ark's native multimodal embedding API instead of
	//    relying on faulty endpoint concatenation.
	// 3. The text path is kept so plain text embedding models keep their existing behavior.
	if isMultimodalEmbeddingModel(model) {
		arkOptions := []arkruntime.ConfigOption{
			arkruntime.WithHTTPClient(httpClient),
		}
		if baseURL != "" {
			arkOptions = append(arkOptions, arkruntime.WithBaseUrl(baseURL))
		}

		return &EinoEmbedder{
			multimodalClient: arkruntime.NewClientWithApiKey(
				strings.TrimSpace(cfg.APIKey),
				arkOptions...,
			),
			model:     model,
			timeout:   timeout,
			dimension: cfg.Dimension,
		}, nil
	}

	clientCfg := &openaiembedding.EmbeddingConfig{
		APIKey:     strings.TrimSpace(cfg.APIKey),
		BaseURL:    baseURL,
		Model:      model,
		HTTPClient: httpClient,
	}
	if cfg.Dimension > 0 {
		clientCfg.Dimensions = &cfg.Dimension
	}

	client, err := openaiembedding.NewEmbeddingClient(ctx, clientCfg)
	if err != nil {
		return nil, err
	}

	return &EinoEmbedder{
		textClient: client,
		model:      model,
		timeout:    timeout,
		dimension:  cfg.Dimension,
	}, nil
}
|
||||
|
||||
// Embed vectorizes texts through whichever client the embedder was built
// with; the trailing action argument is accepted for interface compatibility
// and ignored. Empty input returns (nil, nil).
func (e *EinoEmbedder) Embed(ctx context.Context, texts []string, _ string) (result [][]float32, err error) {
	if e == nil {
		return nil, errors.New("eino embedder is not initialized")
	}
	if len(texts) == 0 {
		return nil, nil
	}

	callCtx := ctx
	cancel := func() {}
	if e.timeout > 0 {
		callCtx, cancel = context.WithTimeout(ctx, e.timeout)
	}
	defer cancel()

	// 1. A panic inside a third-party SDK must not escape into the main RAG chain.
	// 2. Recover uniformly at the model-call boundary and convert it into an
	//    error so upstream callers can degrade gracefully.
	// 3. This keeps the memory write path and the agent reply path alive even
	//    when vector synchronization fails.
	defer func() {
		if recovered := recover(); recovered != nil {
			err = fmt.Errorf("eino embedder panic recovered: %v", recovered)
			result = nil
		}
	}()

	if e.multimodalClient != nil {
		return e.embedTextsWithMultimodalAPI(callCtx, texts)
	}
	if e.textClient == nil {
		return nil, errors.New("eino embedder client is not initialized")
	}

	vectors, err := e.textClient.EmbedStrings(callCtx, texts, einoembedding.WithModel(e.model))
	if err != nil {
		return nil, err
	}

	// Downcast the SDK's float64 vectors to the unified float32 representation.
	result = make([][]float32, 0, len(vectors))
	for _, vector := range vectors {
		converted := make([]float32, len(vector))
		for i, value := range vector {
			converted[i] = float32(value)
		}
		result = append(result, converted)
	}
	return result, nil
}
|
||||
|
||||
// embedTextsWithMultimodalAPI embeds each text via Ark's native multimodal
// embedding endpoint, issuing one request per text.
func (e *EinoEmbedder) embedTextsWithMultimodalAPI(ctx context.Context, texts []string) ([][]float32, error) {
	if e.multimodalClient == nil {
		return nil, errors.New("eino multimodal embedder client is not initialized")
	}

	vectors := make([][]float32, 0, len(texts))
	for _, text := range texts {
		text := text // per-iteration copy: &text is taken below (required before Go 1.22 loop-var semantics)
		req := arkmodel.MultiModalEmbeddingRequest{
			Model: e.model,
			Input: []arkmodel.MultimodalEmbeddingInput{
				{
					Type: arkmodel.MultiModalEmbeddingInputTypeText,
					Text: &text,
				},
			},
		}
		if e.dimension > 0 {
			req.Dimensions = &e.dimension
		}

		// 1. Ark's multimodal embedding request body is "one sample composed of multiple parts".
		// 2. Only text is passed here, so each text is sent separately to avoid
		//    wrongly merging several texts into one multimodal sample.
		// 3. If batched multimodal embedding is ever needed, add a dedicated
		//    batch API instead of quietly changing the semantics here.
		resp, err := e.multimodalClient.CreateMultiModalEmbeddings(ctx, req)
		if err != nil {
			return nil, err
		}

		converted := make([]float32, len(resp.Data.Embedding))
		copy(converted, resp.Data.Embedding)
		vectors = append(vectors, converted)
	}
	return vectors, nil
}
|
||||
|
||||
// isMultimodalEmbeddingModel reports whether the model name — compared
// case-insensitively with surrounding whitespace ignored — belongs to the
// doubao-embedding-vision family, which must use Ark's multimodal endpoint.
func isMultimodalEmbeddingModel(model string) bool {
	normalized := strings.ToLower(strings.TrimSpace(model))
	return strings.HasPrefix(normalized, "doubao-embedding-vision-")
}
|
||||
|
||||
// normalizeEmbeddingBaseURL strips misconfigured embedding endpoints from
// the configured base URL so the SDKs can append the correct suffix
// themselves.
//
//  1. Configuration should contain the Ark service root (e.g. .../api/v3),
//     not a concrete embedding endpoint.
//  2. Two common misconfigurations are tolerated: a trailing "/embeddings"
//     and a trailing "/embeddings/multimodal", matched case-insensitively.
//  3. Trailing slashes and surrounding whitespace are removed first; an
//     empty input yields "".
func normalizeEmbeddingBaseURL(raw string) string {
	baseURL := strings.TrimRight(strings.TrimSpace(raw), "/")
	if baseURL == "" {
		return ""
	}

	lowerBaseURL := strings.ToLower(baseURL)

	// Check the longer suffix first so "/embeddings/multimodal" is not
	// partially stripped by the "/embeddings" case. Slicing by suffix length
	// preserves the original casing of the remaining root.
	for _, suffix := range []string{"/embeddings/multimodal", "/embeddings"} {
		if strings.HasSuffix(lowerBaseURL, suffix) {
			return baseURL[:len(baseURL)-len(suffix)]
		}
	}
	return baseURL
}
|
||||
46
backend/services/rag/embed/mock_embedder.go
Normal file
46
backend/services/rag/embed/mock_embedder.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const defaultDim = 16
|
||||
|
||||
// MockEmbedder 是本地可运行的占位向量化实现。
|
||||
//
|
||||
// 说明:
|
||||
// 1. 该实现用于开发阶段打通链路,不代表真实语义向量质量;
|
||||
// 2. 后续可替换为 Eino embedding 实现,接口保持不变。
|
||||
type MockEmbedder struct {
|
||||
dim int
|
||||
}
|
||||
|
||||
func NewMockEmbedder(dim int) *MockEmbedder {
|
||||
if dim <= 0 {
|
||||
dim = defaultDim
|
||||
}
|
||||
return &MockEmbedder{dim: dim}
|
||||
}
|
||||
|
||||
func (e *MockEmbedder) Embed(_ context.Context, texts []string, _ string) ([][]float32, error) {
|
||||
result := make([][]float32, 0, len(texts))
|
||||
for _, text := range texts {
|
||||
result = append(result, e.embedOne(text))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (e *MockEmbedder) embedOne(text string) []float32 {
|
||||
normalized := strings.TrimSpace(strings.ToLower(text))
|
||||
sum := sha256.Sum256([]byte(normalized))
|
||||
vec := make([]float32, e.dim)
|
||||
for i := 0; i < e.dim; i++ {
|
||||
offset := (i * 4) % len(sum)
|
||||
v := binary.BigEndian.Uint32(sum[offset : offset+4])
|
||||
vec[i] = float32(v%1000) / 1000
|
||||
}
|
||||
return vec
|
||||
}
|
||||
142
backend/services/rag/factory.go
Normal file
142
backend/services/rag/factory.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
ragchunk "github.com/LoveLosita/smartflow/backend/services/rag/chunk"
|
||||
ragconfig "github.com/LoveLosita/smartflow/backend/services/rag/config"
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
ragembed "github.com/LoveLosita/smartflow/backend/services/rag/embed"
|
||||
ragrerank "github.com/LoveLosita/smartflow/backend/services/rag/rerank"
|
||||
ragstore "github.com/LoveLosita/smartflow/backend/services/rag/store"
|
||||
)
|
||||
|
||||
// FactoryDeps describes the optional dependencies of the Runtime factory.
//
// Notes:
//  1. Logger is only the default sink while the project has no unified
//     logging system yet;
//  2. Observer is the official unified observability slot and can later be
//     replaced with a project-level logger/metrics/tracing adapter;
//  3. Business code still only receives a Runtime and never touches the
//     low-level assembly details.
type FactoryDeps struct {
	Logger   *log.Logger // defaults to log.Default() when nil
	Observer Observer    // defaults to a logger-backed observer when nil
}
|
||||
|
||||
// NewRuntimeFromConfig assembles the RAG Runtime from configuration.
//
// Design notes:
//  1. All low-level implementation choices are funneled through this factory;
//     business code never constructs stores/embedders/rerankers itself;
//  2. When more providers are introduced, extend this factory rather than
//     spreading selection logic into business modules;
//  3. Observability is injected here as well, so runtime/store/pipeline do
//     not log on their own.
func NewRuntimeFromConfig(ctx context.Context, cfg ragconfig.Config, deps FactoryDeps) (Runtime, error) {
	logger, observer := normalizeFactoryDeps(deps)

	embedder, err := buildEmbedder(ctx, cfg)
	if err != nil {
		return nil, err
	}

	store, err := buildStore(cfg, logger, observer)
	if err != nil {
		return nil, err
	}

	reranker, err := buildReranker(cfg, observer)
	if err != nil {
		return nil, err
	}

	pipeline := core.NewPipeline(ragchunk.NewTextChunker(), embedder, store, reranker)
	pipeline.SetLogger(logger)
	pipeline.SetObserver(observer)
	return newRuntime(cfg, pipeline, observer), nil
}
|
||||
|
||||
func normalizeFactoryDeps(deps FactoryDeps) (*log.Logger, Observer) {
|
||||
logger := deps.Logger
|
||||
if logger == nil {
|
||||
logger = log.Default()
|
||||
}
|
||||
observer := deps.Observer
|
||||
if observer == nil {
|
||||
observer = NewLoggerObserver(logger)
|
||||
}
|
||||
return logger, observer
|
||||
}
|
||||
|
||||
// buildEmbedder selects the embedding implementation by cfg.EmbedProvider:
// empty/"mock" yields the deterministic local mock, "eino" yields the
// Ark-backed embedder; anything else is an error.
func buildEmbedder(ctx context.Context, cfg ragconfig.Config) (core.Embedder, error) {
	switch strings.ToLower(strings.TrimSpace(cfg.EmbedProvider)) {
	case "", "mock":
		return ragembed.NewMockEmbedder(cfg.EmbedDimension), nil
	case "eino":
		// 1. RAG embedding shares the same credential source as the regular
		//    LLM chain by reading ARK_API_KEY directly;
		// 2. This avoids maintaining an extra "env-name config -> read env"
		//    indirection and the configuration drift it invites;
		// 3. If separate embedding credentials are ever needed, design
		//    explicit fields instead of implicitly passing env names around.
		apiKey := strings.TrimSpace(os.Getenv("ARK_API_KEY"))
		if apiKey == "" {
			return nil, fmt.Errorf("rag embed api key is empty: env=%s", "ARK_API_KEY")
		}
		return ragembed.NewEinoEmbedder(ctx, ragembed.EinoConfig{
			APIKey:    apiKey,
			BaseURL:   cfg.EmbedBaseURL,
			Model:     cfg.EmbedModel,
			TimeoutMS: cfg.EmbedTimeoutMS,
			Dimension: cfg.EmbedDimension,
		})
	default:
		return nil, fmt.Errorf("unsupported rag embed provider: %s", cfg.EmbedProvider)
	}
}
|
||||
|
||||
// buildStore selects the vector store backend by cfg.Store:
// empty/"inmemory" yields the in-process store, "milvus" yields the
// Milvus-backed store; anything else is an error.
func buildStore(cfg ragconfig.Config, logger *log.Logger, observer Observer) (core.VectorStore, error) {
	switch strings.ToLower(strings.TrimSpace(cfg.Store)) {
	case "", "inmemory":
		return ragstore.NewInMemoryVectorStore(), nil
	case "milvus":
		return ragstore.NewMilvusStore(ragstore.MilvusConfig{
			Address:          cfg.MilvusAddress,
			Token:            cfg.MilvusToken,
			DBName:           cfg.MilvusDBName,
			CollectionName:   cfg.MilvusCollectionName,
			RequestTimeoutMS: cfg.MilvusRequestTimeoutMS,
			Dimension:        cfg.EmbedDimension, // store dimension must match the embedder output
			MetricType:       cfg.MilvusMetricType,
			Logger:           logger,
			Observer:         observer,
		})
	default:
		return nil, fmt.Errorf("unsupported rag store: %s", cfg.Store)
	}
}
|
||||
|
||||
// buildReranker selects the reranker implementation. A disabled reranker or
// the empty/"noop" provider yields the pass-through reranker; "eino" is not
// implemented yet and falls back to noop with a warn-level observe event.
func buildReranker(cfg ragconfig.Config, observer Observer) (core.Reranker, error) {
	if !cfg.RerankerEnabled {
		return ragrerank.NewNoopReranker(), nil
	}

	switch strings.ToLower(strings.TrimSpace(cfg.RerankerProvider)) {
	case "", "noop":
		return ragrerank.NewNoopReranker(), nil
	case "eino":
		// Record the silent downgrade so it stays visible in observability.
		if observer != nil {
			observer.Observe(context.Background(), ObserveEvent{
				Level:     ObserveLevelWarn,
				Component: "factory",
				Operation: "reranker_fallback",
				Fields: map[string]any{
					"provider":        "eino",
					"status":          "fallback",
					"fallback_target": "noop",
					"reason":          "reranker_not_implemented",
				},
			})
		}
		return ragrerank.NewNoopReranker(), nil
	default:
		return nil, fmt.Errorf("unsupported rag reranker provider: %s", cfg.RerankerProvider)
	}
}
|
||||
32
backend/services/rag/observe.go
Normal file
32
backend/services/rag/observe.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// ObserveLevel re-exports the unified observe level as an alias so the
// bootstrap layer does not depend on core details directly.
type ObserveLevel = core.ObserveLevel

const (
	ObserveLevelInfo  = core.ObserveLevelInfo
	ObserveLevelWarn  = core.ObserveLevelWarn
	ObserveLevelError = core.ObserveLevelError
)

// ObserveEvent re-exports the unified observe event type.
type ObserveEvent = core.ObserveEvent

// Observer re-exports the unified observer interface.
type Observer = core.Observer

// NewNopObserver returns the no-op observer implementation.
func NewNopObserver() Observer {
	return core.NewNopObserver()
}

// NewLoggerObserver returns the standard-log adapter observer.
func NewLoggerObserver(logger *log.Logger) Observer {
	return core.NewLoggerObserver(logger)
}
|
||||
23
backend/services/rag/rag.go
Normal file
23
backend/services/rag/rag.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/chunk"
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/embed"
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/rerank"
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/store"
|
||||
)
|
||||
|
||||
// NewDefaultPipeline constructs the default runnable RAG pipeline.
//
// Current policy:
//  1. Use the local MockEmbedder + InMemoryStore by default so the pipeline
//     runs with zero external dependencies;
//  2. Switching to Milvus / Eino later only replaces the dependencies and
//     does not change how business code calls the pipeline.
func NewDefaultPipeline() *core.Pipeline {
	return core.NewPipeline(
		chunk.NewTextChunker(),
		embed.NewMockEmbedder(16),
		store.NewInMemoryVectorStore(),
		rerank.NewNoopReranker(),
	)
}
|
||||
19
backend/services/rag/rerank/eino_reranker.go
Normal file
19
backend/services/rag/rerank/eino_reranker.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// EinoReranker is a placeholder for the future Eino-backed reranker.
type EinoReranker struct{}

// NewEinoReranker constructs the placeholder reranker.
func NewEinoReranker() *EinoReranker {
	return &EinoReranker{}
}

// Rerank always returns an error: the Eino reranker is not implemented yet.
func (r *EinoReranker) Rerank(_ context.Context, _ string, _ []core.ScoredChunk, _ int) ([]core.ScoredChunk, error) {
	return nil, errors.New("eino reranker is not implemented yet")
}
|
||||
30
backend/services/rag/rerank/noop_reranker.go
Normal file
30
backend/services/rag/rerank/noop_reranker.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// NoopReranker 是默认重排器(仅按原 score 排序)。
|
||||
type NoopReranker struct{}
|
||||
|
||||
func NewNoopReranker() *NoopReranker {
|
||||
return &NoopReranker{}
|
||||
}
|
||||
|
||||
func (r *NoopReranker) Rerank(_ context.Context, _ string, candidates []core.ScoredChunk, topK int) ([]core.ScoredChunk, error) {
|
||||
if len(candidates) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
sorted := make([]core.ScoredChunk, len(candidates))
|
||||
copy(sorted, candidates)
|
||||
sort.SliceStable(sorted, func(i, j int) bool {
|
||||
return sorted[i].Score > sorted[j].Score
|
||||
})
|
||||
if topK <= 0 || topK >= len(sorted) {
|
||||
return sorted, nil
|
||||
}
|
||||
return sorted[:topK], nil
|
||||
}
|
||||
98
backend/services/rag/retrieve/vector_retriever.go
Normal file
98
backend/services/rag/retrieve/vector_retriever.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package retrieve
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// VectorRetriever is the generic retriever: it embeds the query and then
// performs a vector-store search.
type VectorRetriever struct {
	embedder core.Embedder    // turns the query text into a vector
	store    core.VectorStore // executes the similarity search
}

// NewVectorRetriever wires an embedder and a store into a retriever.
// Both dependencies are checked for nil at Retrieve time, not here.
func NewVectorRetriever(embedder core.Embedder, store core.VectorStore) *VectorRetriever {
	return &VectorRetriever{
		embedder: embedder,
		store:    store,
	}
}
|
||||
|
||||
// Retrieve embeds req.Query and searches the vector store, applying a
// default TopK of 8 and a default action of "search", and dropping hits
// whose score is below req.Threshold. Panics from dependencies are
// recovered and surfaced as an error.
func (r *VectorRetriever) Retrieve(ctx context.Context, req core.RetrieveRequest) (result []core.ScoredChunk, err error) {
	// Recover at the retrieval boundary so a dependency panic degrades to an error.
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		err = fmt.Errorf("vector retriever panic recovered: %v", recovered)
	}()

	if r == nil || r.embedder == nil || r.store == nil {
		return nil, core.ErrNilDependency
	}
	query := strings.TrimSpace(req.Query)
	if query == "" {
		return nil, core.ErrInvalidQuery
	}
	topK := req.TopK
	if topK <= 0 {
		topK = 8 // default result budget when the caller does not specify one
	}
	action := strings.TrimSpace(req.Action)
	if action == "" {
		action = "search"
	}

	vectors, err := r.embedder.Embed(ctx, []string{query}, action)
	if err != nil {
		return nil, err
	}
	// One query in, exactly one vector out; anything else is a contract violation.
	if len(vectors) != 1 {
		return nil, fmt.Errorf("embedding query length mismatch: %d", len(vectors))
	}

	rows, err := r.store.Search(ctx, core.VectorSearchRequest{
		QueryVector: vectors[0],
		TopK:        topK,
		Filter:      req.Filter,
	})
	if err != nil {
		return nil, err
	}

	result = make([]core.ScoredChunk, 0, len(rows))
	for _, row := range rows {
		// Drop hits below the caller-provided score threshold.
		if row.Score < req.Threshold {
			continue
		}
		result = append(result, core.ScoredChunk{
			ChunkID:    row.Row.ID,
			DocumentID: asString(row.Row.Metadata["document_id"]),
			Text:       row.Row.Text,
			Score:      row.Score,
			Metadata:   cloneMap(row.Row.Metadata), // copy so callers can mutate safely
		})
	}
	return result, nil
}
|
||||
|
||||
// cloneMap makes a shallow copy of src. The result is always a fresh,
// non-nil map — even for a nil or empty src — so callers can mutate it
// without aliasing the source.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for key, value := range src {
		dst[key] = value
	}
	return dst
}
|
||||
|
||||
// asString renders an arbitrary metadata value as a string. A nil value
// yields ""; a string value is returned as-is (fast path that skips fmt's
// boxing/reflection — "%v" on a string is the identity anyway); everything
// else goes through fmt's default %v formatting.
func asString(v any) string {
	switch value := v.(type) {
	case nil:
		return ""
	case string:
		return value
	default:
		return fmt.Sprintf("%v", value)
	}
}
|
||||
434
backend/services/rag/runtime.go
Normal file
434
backend/services/rag/runtime.go
Normal file
@@ -0,0 +1,434 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
ragconfig "github.com/LoveLosita/smartflow/backend/services/rag/config"
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/corpus"
|
||||
)
|
||||
|
||||
// runtime is the default Runtime implementation: it binds the configured
// pipeline to the built-in memory/web corpus adapters plus an observer.
type runtime struct {
	cfg          ragconfig.Config     // resolved RAG configuration (defaults, timeouts, top-k)
	pipeline     *core.Pipeline       // chunk/embed/store/rerank pipeline
	memoryCorpus *corpus.MemoryCorpus // adapter for the memory corpus
	webCorpus    *corpus.WebCorpus    // adapter for the web corpus
	observer     Observer             // never nil after newRuntime
}
|
||||
|
||||
// newRuntime wires cfg and pipeline into a runtime. A nil observer is
// replaced with the no-op observer so call sites never need nil checks.
func newRuntime(cfg ragconfig.Config, pipeline *core.Pipeline, observer Observer) Runtime {
	if observer == nil {
		observer = NewNopObserver()
	}
	return &runtime{
		cfg:          cfg,
		pipeline:     pipeline,
		memoryCorpus: corpus.NewMemoryCorpus(),
		webCorpus:    corpus.NewWebCorpus(),
		observer:     observer,
	}
}
|
||||
|
||||
// IngestMemory is the unified entry point for memory-corpus ingestion: it
// copies the public request items into the corpus package's item shape and
// delegates to the shared corpus ingest path. Panic handling is delegated
// to recoverPublicPanic (defined elsewhere in this file), which receives &err.
func (r *runtime) IngestMemory(ctx context.Context, req MemoryIngestRequest) (result *IngestResult, err error) {
	defer r.recoverPublicPanic(ctx, req.TraceID, "memory", normalizeAction(req.Action, "add"), "ingest", &err)

	// Field-by-field copy keeps the public DTO decoupled from the corpus type.
	items := make([]corpus.MemoryIngestItem, 0, len(req.Items))
	for _, item := range req.Items {
		items = append(items, corpus.MemoryIngestItem{
			MemoryID:         item.MemoryID,
			UserID:           item.UserID,
			ConversationID:   item.ConversationID,
			AssistantID:      item.AssistantID,
			RunID:            item.RunID,
			MemoryType:       item.MemoryType,
			Title:            item.Title,
			Content:          item.Content,
			Confidence:       item.Confidence,
			Importance:       item.Importance,
			SensitivityLevel: item.SensitivityLevel,
			IsExplicit:       item.IsExplicit,
			Status:           item.Status,
			TTLAt:            item.TTLAt,
			CreatedAt:        item.CreatedAt,
		})
	}
	return r.ingestWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, items, req.Action)
}
|
||||
|
||||
// RetrieveMemory is the unified entry point for memory-corpus retrieval.
// A single requested memory type is pushed down as an equality filter;
// multiple types are applied as a second-pass filter over the results.
func (r *runtime) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (result *RetrieveResult, err error) {
	defer r.recoverPublicPanic(ctx, req.TraceID, "memory", normalizeAction(req.Action, "search"), "retrieve", &err)

	corpusInput := corpus.MemoryRetrieveInput{
		UserID:         req.UserID,
		ConversationID: req.ConversationID,
		AssistantID:    req.AssistantID,
		RunID:          req.RunID,
	}
	// Exactly one requested type can be expressed as a store-level equality filter.
	if len(req.MemoryTypes) == 1 {
		corpusInput.MemoryType = req.MemoryTypes[0]
	}

	result, err = r.retrieveWithCorpus(ctx, req.TraceID, "memory", r.memoryCorpus, core.RetrieveRequest{
		Query:       req.Query,
		TopK:        normalizeTopK(req.TopK, r.cfg.TopK),
		Threshold:   normalizeThreshold(req.Threshold, r.cfg.Threshold),
		Action:      normalizeAction(req.Action, "search"),
		CorpusInput: corpusInput,
	})
	if err != nil {
		return nil, err
	}
	if len(req.MemoryTypes) <= 1 {
		return result, nil
	}

	// 1. The low-level filters are still mostly equality-based, so the Runtime
	//    keeps doing the multi-type second-pass filtering for now;
	// 2. This avoids spreading the "memory_type in (...)" implementation
	//    detail into every Store;
	// 3. Once the stores gain unified filter capabilities, push this logic down.
	allowed := make(map[string]struct{}, len(req.MemoryTypes))
	for _, item := range req.MemoryTypes {
		value := strings.TrimSpace(strings.ToLower(item))
		if value == "" {
			continue
		}
		allowed[value] = struct{}{}
	}

	filtered := make([]RetrieveHit, 0, len(result.Items))
	for _, item := range result.Items {
		memoryType := strings.TrimSpace(strings.ToLower(asString(item.Metadata["memory_type"])))
		if len(allowed) > 0 {
			if _, ok := allowed[memoryType]; !ok {
				continue
			}
		}
		filtered = append(filtered, item)
	}
	result.Items = filtered
	// Re-apply the caller's TopK after filtering shrank the list.
	if req.TopK > 0 && len(result.Items) > req.TopK {
		result.Items = result.Items[:req.TopK]
	}
	return result, nil
}
|
||||
|
||||
// DeleteMemory removes the vectors of the given document IDs from the
// memory corpus. A nil runtime/pipeline or an empty ID list is a no-op.
func (r *runtime) DeleteMemory(ctx context.Context, documentIDs []string) (err error) {
	defer r.recoverPublicPanic(ctx, "", "memory", "delete", "delete", &err)

	if r == nil || r.pipeline == nil || len(documentIDs) == 0 {
		return nil
	}
	return r.pipeline.Delete(ctx, documentIDs)
}
|
||||
|
||||
// IngestWeb is the unified entry point for web-corpus ingestion: it copies
// the public request items into the corpus item shape and delegates to the
// shared corpus ingest path.
func (r *runtime) IngestWeb(ctx context.Context, req WebIngestRequest) (result *IngestResult, err error) {
	defer r.recoverPublicPanic(ctx, req.TraceID, "web", normalizeAction(req.Action, "add"), "ingest", &err)

	// Field-by-field copy keeps the public DTO decoupled from the corpus type.
	items := make([]corpus.WebIngestItem, 0, len(req.Items))
	for _, item := range req.Items {
		items = append(items, corpus.WebIngestItem{
			URL:         item.URL,
			Title:       item.Title,
			Content:     item.Content,
			Snippet:     item.Snippet,
			Domain:      item.Domain,
			QueryID:     item.QueryID,
			SessionID:   item.SessionID,
			PublishedAt: item.PublishedAt,
			FetchedAt:   item.FetchedAt,
			SourceRank:  item.SourceRank,
		})
	}
	return r.ingestWithCorpus(ctx, req.TraceID, "web", r.webCorpus, items, req.Action)
}
|
||||
|
||||
// RetrieveWeb is the unified entry point for web-corpus retrieval; it maps
// the public request onto the core retrieve request and delegates to the
// shared corpus retrieve path.
func (r *runtime) RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (result *RetrieveResult, err error) {
	defer r.recoverPublicPanic(ctx, req.TraceID, "web", normalizeAction(req.Action, "search"), "retrieve", &err)

	return r.retrieveWithCorpus(ctx, req.TraceID, "web", r.webCorpus, core.RetrieveRequest{
		Query:     req.Query,
		TopK:      normalizeTopK(req.TopK, r.cfg.TopK),
		Threshold: normalizeThreshold(req.Threshold, r.cfg.Threshold),
		Action:    normalizeAction(req.Action, "search"),
		CorpusInput: corpus.WebRetrieveInput{
			QueryID:   req.QueryID,
			SessionID: req.SessionID,
			Domain:    req.Domain,
		},
	})
}
|
||||
|
||||
// ingestWithCorpus is the shared ingest path: it builds source documents
// through the corpus adapter, runs the pipeline ingest, and emits one
// observe event per outcome (document-build failure, ingest failure,
// success), each carrying latency and error-classification fields.
func (r *runtime) ingestWithCorpus(
	ctx context.Context,
	traceID string,
	corpusName string,
	adapter core.CorpusAdapter,
	input any,
	action string,
) (*IngestResult, error) {
	start := time.Now()
	if r == nil || r.pipeline == nil || adapter == nil {
		return nil, core.ErrNilDependency
	}

	action = normalizeAction(action, "add")
	observeCtx := newObserveContext(ctx, traceID, corpusName, action)

	docs, err := adapter.BuildIngestDocuments(observeCtx, input)
	if err != nil {
		r.observe(observeCtx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "runtime",
			Operation: "ingest",
			Fields: map[string]any{
				"status":      "failed",
				"latency_ms":  time.Since(start).Milliseconds(),
				"phase":       "build_documents",
				"error":       err,
				"error_code":  core.ClassifyErrorCode(err),
				"input_count": estimateInputCount(input),
			},
		})
		return nil, err
	}

	// Collect the IDs up front: they are returned to the caller even though
	// the pipeline result only carries counts.
	docIDs := make([]string, 0, len(docs))
	for _, doc := range docs {
		docIDs = append(docIDs, doc.ID)
	}

	result, err := r.pipeline.IngestDocuments(observeCtx, adapter.Name(), docs, core.IngestOption{
		Chunk: core.ChunkOption{
			ChunkSize:    r.cfg.ChunkSize,
			ChunkOverlap: r.cfg.ChunkOverlap,
		},
		Action: action,
	})
	if err != nil {
		r.observe(observeCtx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "runtime",
			Operation: "ingest",
			Fields: map[string]any{
				"status":         "failed",
				"latency_ms":     time.Since(start).Milliseconds(),
				"document_count": len(docs),
				"error":          err,
				"error_code":     core.ClassifyErrorCode(err),
			},
		})
		return nil, err
	}

	r.observe(observeCtx, ObserveEvent{
		Level:     ObserveLevelInfo,
		Component: "runtime",
		Operation: "ingest",
		Fields: map[string]any{
			"status":         "success",
			"latency_ms":     time.Since(start).Milliseconds(),
			"document_count": result.DocumentCount,
			"chunk_count":    result.ChunkCount,
		},
	})
	return &IngestResult{
		DocumentCount: result.DocumentCount,
		ChunkCount:    result.ChunkCount,
		DocumentIDs:   docIDs,
	}, nil
}
|
||||
|
||||
// retrieveWithCorpus executes a single retrieval request against the given
// corpus adapter, applying the configured retrieve timeout and emitting one
// structured observe event on both the failure and the success path.
//
// traceID / corpusName / action are folded into the observe context so every
// downstream event carries the same correlation fields.
func (r *runtime) retrieveWithCorpus(
	ctx context.Context,
	traceID string,
	corpusName string,
	adapter core.CorpusAdapter,
	req core.RetrieveRequest,
) (*RetrieveResult, error) {
	start := time.Now()
	// Guard against a half-wired runtime; ErrNilDependency keeps the failure classifiable.
	if r == nil || r.pipeline == nil || adapter == nil {
		return nil, core.ErrNilDependency
	}

	// Default the action to "search" and write the normalized value back into req
	// so the pipeline and the observe fields agree.
	action := normalizeAction(req.Action, "search")
	req.Action = action
	observeCtx := newObserveContext(ctx, traceID, corpusName, action)

	// Only wrap with a timeout when one is configured; cancel stays a no-op otherwise.
	timeoutCtx := observeCtx
	cancel := func() {}
	if r.cfg.RetrieveTimeoutMS > 0 {
		timeoutCtx, cancel = context.WithTimeout(observeCtx, time.Duration(r.cfg.RetrieveTimeoutMS)*time.Millisecond)
	}
	defer cancel()

	result, err := r.pipeline.Retrieve(timeoutCtx, adapter, req)
	if err != nil {
		// Observe with observeCtx (not timeoutCtx) so the event still goes out
		// even when the timeout has already fired.
		r.observe(observeCtx, ObserveEvent{
			Level:     ObserveLevelError,
			Component: "runtime",
			Operation: "retrieve",
			Fields: map[string]any{
				"status":     "failed",
				"latency_ms": time.Since(start).Milliseconds(),
				"query_len":  len(strings.TrimSpace(req.Query)),
				"top_k":      req.TopK,
				"threshold":  req.Threshold,
				"error":      err,
				"error_code": core.ClassifyErrorCode(err),
			},
		})
		return nil, err
	}

	// Copy hits into the public result type; metadata is cloned so callers
	// cannot mutate pipeline-owned maps.
	items := make([]RetrieveHit, 0, len(result.Items))
	for _, item := range result.Items {
		items = append(items, RetrieveHit{
			ChunkID:    item.ChunkID,
			DocumentID: item.DocumentID,
			Text:       item.Text,
			Score:      item.Score,
			Metadata:   cloneMap(item.Metadata),
		})
	}

	r.observe(observeCtx, ObserveEvent{
		Level:     ObserveLevelInfo,
		Component: "runtime",
		Operation: "retrieve",
		Fields: map[string]any{
			"status":          "success",
			"latency_ms":      time.Since(start).Milliseconds(),
			"query_len":       len(strings.TrimSpace(req.Query)),
			"top_k":           req.TopK,
			"threshold":       req.Threshold,
			"raw_count":       result.RawCount,
			"hit_count":       len(result.Items),
			"fallback_used":   result.FallbackUsed,
			"fallback_reason": result.FallbackReason,
		},
	})
	return &RetrieveResult{
		Items:          items,
		RawCount:       result.RawCount,
		FallbackUsed:   result.FallbackUsed,
		FallbackReason: result.FallbackReason,
	}, nil
}
|
||||
|
||||
func (r *runtime) observe(ctx context.Context, event ObserveEvent) {
|
||||
if r == nil || r.observer == nil {
|
||||
return
|
||||
}
|
||||
r.observer.Observe(ctx, event)
|
||||
}
|
||||
|
||||
// recoverPublicPanic converts a panic from any public runtime entry point into
// an error written through errPtr, plus one structured observe event.
// It must be installed directly via defer (recover() is only effective when
// called by the deferred function itself).
func (r *runtime) recoverPublicPanic(
	ctx context.Context,
	traceID string,
	corpusName string,
	action string,
	operation string,
	errPtr *error,
) {
	recovered := recover()
	// Nothing to do on the normal return path, or when the caller gave us
	// nowhere to write the converted error.
	if recovered == nil || errPtr == nil {
		return
	}

	// 1. runtime is the final method surface the RAG service exposes to business
	//    code; no lower-level panic should leak through into business goroutines.
	// 2. Panics are uniformly converted into an error plus one structured observe
	//    event, so it stays traceable which dependency layer went out of control.
	// 3. The stack is kept so the root cause can still be located while the
	//    process stays alive, instead of being left with only a bare "recovered".
	panicErr := fmt.Errorf("rag runtime panic recovered: corpus=%s operation=%s panic=%v", corpusName, operation, recovered)
	*errPtr = panicErr

	observeCtx := newObserveContext(ctx, traceID, corpusName, action)
	r.observe(observeCtx, ObserveEvent{
		Level:     ObserveLevelError,
		Component: "runtime",
		Operation: operation + "_panic_recovered",
		Fields: map[string]any{
			"status":     "failed",
			"panic":      fmt.Sprintf("%v", recovered),
			"panic_type": fmt.Sprintf("%T", recovered),
			"error":      panicErr,
			"error_code": core.ClassifyErrorCode(panicErr),
			"stack":      string(debug.Stack()),
		},
	})
}
|
||||
|
||||
func newObserveContext(ctx context.Context, traceID string, corpusName string, action string) context.Context {
|
||||
fields := map[string]any{
|
||||
"corpus": corpusName,
|
||||
"action": action,
|
||||
}
|
||||
if traceID = strings.TrimSpace(traceID); traceID != "" {
|
||||
fields["trace_id"] = traceID
|
||||
}
|
||||
return core.WithObserveFields(ctx, fields)
|
||||
}
|
||||
|
||||
func estimateInputCount(input any) int {
|
||||
switch value := input.(type) {
|
||||
case []corpus.MemoryIngestItem:
|
||||
return len(value)
|
||||
case []corpus.WebIngestItem:
|
||||
return len(value)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeAction trims surrounding whitespace from action and substitutes
// fallback when nothing remains.
func normalizeAction(action string, fallback string) string {
	if trimmed := strings.TrimSpace(action); trimmed != "" {
		return trimmed
	}
	return fallback
}
|
||||
|
||||
// normalizeTopK picks the first positive value among topK, fallback, and the
// package default of 8.
func normalizeTopK(topK int, fallback int) int {
	switch {
	case topK > 0:
		return topK
	case fallback > 0:
		return fallback
	default:
		return 8
	}
}
|
||||
|
||||
// normalizeThreshold returns the first non-negative value among threshold and
// fallback, defaulting to 0. Note that a threshold of exactly 0 is treated as
// a valid value, not as "unset".
func normalizeThreshold(threshold float64, fallback float64) float64 {
	switch {
	case threshold >= 0:
		return threshold
	case fallback >= 0:
		return fallback
	default:
		return 0
	}
}
|
||||
|
||||
// cloneMap returns a shallow copy of src. The result is always a fresh,
// non-nil map (even for nil/empty input) so callers may write to it
// unconditionally.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for key, value := range src {
		dst[key] = value
	}
	return dst
}
|
||||
|
||||
// asString renders v with %v and trims surrounding whitespace; nil maps to ""
// (never the literal "<nil>").
func asString(v any) string {
	if v == nil {
		return ""
	}
	rendered := fmt.Sprintf("%v", v)
	return strings.TrimSpace(rendered)
}
|
||||
111
backend/services/rag/service.go
Normal file
111
backend/services/rag/service.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
ragconfig "github.com/LoveLosita/smartflow/backend/services/rag/config"
|
||||
)
|
||||
|
||||
// Options carries the low-level runtime that rag-service needs to own.
type Options struct {
	// Runtime is the assembled RAG runtime; it may be nil, in which case the
	// service behaves as disabled (see NewFromConfig).
	Runtime Runtime
}
|
||||
|
||||
// Service is the single public entry point exposed by rag-service.
//
// Responsibility boundaries:
//  1. Owns the runtime and funnels both the memory and the web capability
//     lines through the service layer.
//  2. Performs config-based runtime assembly at the service entry point.
//  3. Does not itself carry chunk / embed / store implementation details;
//     those sink into sub-packages of the service tree.
type Service struct {
	// runtime may be nil (service disabled); every method nil-checks it and
	// degrades to a no-op.
	runtime Runtime
}
|
||||
|
||||
// New 使用调用方传入的运行时构造服务。
|
||||
func New(opts Options) *Service {
|
||||
return &Service{runtime: opts.Runtime}
|
||||
}
|
||||
|
||||
// NewFromConfig builds a self-contained RAG service from the service-tree
// configuration and factory dependencies. When the feature is disabled it
// returns a Service with a nil runtime, whose methods all degrade to no-ops.
func NewFromConfig(ctx context.Context, cfg ragconfig.Config, deps FactoryDeps) (*Service, error) {
	if !cfg.Enabled {
		// Disabled: return an empty service instead of an error so callers keep
		// a single code path.
		return New(Options{}), nil
	}
	runtime, err := NewRuntimeFromConfig(ctx, cfg, deps)
	if err != nil {
		return nil, err
	}
	return NewWithRuntime(runtime), nil
}
|
||||
|
||||
// Runtime 返回当前服务持有的运行时。
|
||||
func (s *Service) Runtime() Runtime {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.runtime
|
||||
}
|
||||
|
||||
// IngestMemory 写入记忆语料。
|
||||
func (s *Service) IngestMemory(ctx context.Context, req MemoryIngestRequest) (*IngestResult, error) {
|
||||
if s == nil || s.runtime == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.runtime.IngestMemory(ctx, req)
|
||||
}
|
||||
|
||||
// RetrieveMemory 检索记忆语料。
|
||||
func (s *Service) RetrieveMemory(ctx context.Context, req MemoryRetrieveRequest) (*RetrieveResult, error) {
|
||||
if s == nil || s.runtime == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.runtime.RetrieveMemory(ctx, req)
|
||||
}
|
||||
|
||||
// DeleteMemory 删除指定记忆文档。
|
||||
func (s *Service) DeleteMemory(ctx context.Context, documentIDs []string) error {
|
||||
if s == nil || s.runtime == nil {
|
||||
return nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return s.runtime.DeleteMemory(ctx, documentIDs)
|
||||
}
|
||||
|
||||
// IngestWeb 写入网页语料。
|
||||
func (s *Service) IngestWeb(ctx context.Context, req WebIngestRequest) (*IngestResult, error) {
|
||||
if s == nil || s.runtime == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.runtime.IngestWeb(ctx, req)
|
||||
}
|
||||
|
||||
// RetrieveWeb 检索网页语料。
|
||||
func (s *Service) RetrieveWeb(ctx context.Context, req WebRetrieveRequest) (*RetrieveResult, error) {
|
||||
if s == nil || s.runtime == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.runtime.RetrieveWeb(ctx, req)
|
||||
}
|
||||
|
||||
// EnsureRuntime 返回一个可继续向下传递的运行时引用。
|
||||
func (s *Service) EnsureRuntime() Runtime {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.runtime
|
||||
}
|
||||
|
||||
// SetRuntime 允许在装配阶段延迟注入运行时。
|
||||
func (s *Service) SetRuntime(runtime Runtime) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.runtime = runtime
|
||||
}
|
||||
|
||||
// NewWithRuntime 用显式运行时构造服务。
|
||||
func NewWithRuntime(runtime Runtime) *Service {
|
||||
return New(Options{Runtime: runtime})
|
||||
}
|
||||
182
backend/services/rag/store/inmemory_store.go
Normal file
182
backend/services/rag/store/inmemory_store.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// InMemoryVectorStore is a vector store implementation for local development.
//
// Notes:
//  1. Debugging/development only; not recommended for production use.
//  2. Real environments can substitute MilvusStore; the interface is identical.
type InMemoryVectorStore struct {
	mu   sync.RWMutex // guards rows
	rows map[string]core.VectorRow // keyed by row ID
}
|
||||
|
||||
func NewInMemoryVectorStore() *InMemoryVectorStore {
|
||||
return &InMemoryVectorStore{
|
||||
rows: make(map[string]core.VectorRow),
|
||||
}
|
||||
}
|
||||
|
||||
// Upsert inserts or replaces rows keyed by row ID while maintaining timestamp
// bookkeeping: an existing row keeps its original CreatedAt; a brand-new row
// receives CreatedAt = now unless the caller supplied one; UpdatedAt is always
// refreshed to the call time.
func (s *InMemoryVectorStore) Upsert(_ context.Context, rows []core.VectorRow) error {
	if s == nil {
		return errors.New("inmemory vector store is nil")
	}
	if len(rows) == 0 {
		return nil
	}
	now := time.Now()
	s.mu.Lock()
	defer s.mu.Unlock()
	// Lazily allocate so a zero-value store still works.
	if s.rows == nil {
		s.rows = make(map[string]core.VectorRow)
	}
	for _, row := range rows {
		current, exists := s.rows[row.ID]
		if exists {
			// Preserve the original creation time across updates.
			row.CreatedAt = current.CreatedAt
			row.UpdatedAt = now
		} else {
			if row.CreatedAt.IsZero() {
				row.CreatedAt = now
			}
			row.UpdatedAt = now
		}
		s.rows[row.ID] = row
	}
	return nil
}
|
||||
|
||||
// Search scores every stored row that matches the metadata filter with cosine
// similarity against the query vector and returns the best topK matches
// (default topK = 8). This is a full linear scan — acceptable only at
// development-scale data volumes.
func (s *InMemoryVectorStore) Search(_ context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
	if s == nil {
		return nil, errors.New("inmemory vector store is nil")
	}
	topK := req.TopK
	if topK <= 0 {
		topK = 8
	}

	s.mu.RLock()
	defer s.mu.RUnlock()

	result := make([]core.ScoredVectorRow, 0, len(s.rows))
	for _, row := range s.rows {
		if !matchMetadataFilter(row.Metadata, req.Filter) {
			continue
		}
		score := cosineSimilarity(req.QueryVector, row.Vector)
		result = append(result, core.ScoredVectorRow{
			Row:   row,
			Score: score,
		})
	}
	// Stable sort keeps equal-score rows in a deterministic relative order.
	sort.SliceStable(result, func(i, j int) bool {
		return result[i].Score > result[j].Score
	})
	if len(result) <= topK {
		return result, nil
	}
	return result[:topK], nil
}
|
||||
|
||||
func (s *InMemoryVectorStore) Delete(_ context.Context, ids []string) error {
|
||||
if s == nil {
|
||||
return errors.New("inmemory vector store is nil")
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, id := range ids {
|
||||
delete(s.rows, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *InMemoryVectorStore) Get(_ context.Context, ids []string) ([]core.VectorRow, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("inmemory vector store is nil")
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
result := make([]core.VectorRow, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
row, exists := s.rows[id]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
result = append(result, row)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// cosineSimilarity computes the cosine similarity of a and b over their common
// prefix (mismatched lengths are truncated to the shorter vector —
// NOTE(review): confirm callers never rely on a length mismatch being an
// error). Returns 0 when either vector is empty or has zero norm.
func cosineSimilarity(a, b []float32) float64 {
	n := len(a)
	if len(b) < n {
		n = len(b)
	}
	// Single guard covers both empty inputs; the original's separate emptiness
	// pre-check made its second n == 0 branch unreachable.
	if n == 0 {
		return 0
	}
	var dot, normA, normB float64
	for i := 0; i < n; i++ {
		av := float64(a[i])
		bv := float64(b[i])
		dot += av * bv
		normA += av * av
		normB += bv * bv
	}
	if normA == 0 || normB == 0 {
		return 0
	}
	return dot / (math.Sqrt(normA) * math.Sqrt(normB))
}
|
||||
|
||||
func matchMetadataFilter(metadata map[string]any, filter map[string]any) bool {
|
||||
if len(filter) == 0 {
|
||||
return true
|
||||
}
|
||||
for key, wanted := range filter {
|
||||
got, exists := metadata[key]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
if !equalAny(got, wanted) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func equalAny(left any, right any) bool {
|
||||
return toString(left) == toString(right)
|
||||
}
|
||||
|
||||
func toString(v any) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
return fmtAny(v)
|
||||
}
|
||||
|
||||
// fmtAny formats v with %v and strips surrounding whitespace.
func fmtAny(v any) string {
	text := fmt.Sprintf("%v", v)
	return strings.TrimSpace(text)
}
|
||||
927
backend/services/rag/store/milvus_store.go
Normal file
927
backend/services/rag/store/milvus_store.go
Normal file
@@ -0,0 +1,927 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
)
|
||||
|
||||
// MilvusConfig describes the configuration of the Milvus REST store.
type MilvusConfig struct {
	// Address must point at the Milvus REST entry point.
	// This project's integration testing uses port 19530; port 9091 only serves
	// health/metrics and does not carry the REST API used by this file.
	Address          string
	Token            string // optional bearer token for the Authorization header
	DBName           string // optional database name; empty means the server default
	CollectionName   string // defaults to "smartflow_rag_chunks" in NewMilvusStore
	RequestTimeoutMS int    // HTTP client timeout; defaults to 1500ms
	Dimension        int    // fallback vector dimension when a call supplies none
	MetricType       string // defaults to "COSINE"
	Logger           *log.Logger
	Observer         core.Observer // defaults to a logger-backed observer
}
|
||||
|
||||
// MilvusStore is a vector store implementation backed by the Milvus REST API.
//
// Design notes:
//  1. The implementation prioritizes "pluggable, manageable, gradually
//     rollable within the project" and avoids depending on an extra SDK;
//  2. Fixed scalar columns plus a metadata JSON column balance filter
//     capability with metadata completeness;
//  3. The collection is auto-created on first write, avoiding a separate
//     startup initialization script.
type MilvusStore struct {
	cfg      MilvusConfig
	client   *http.Client
	observer core.Observer
	mu       sync.Mutex // guards ensured
	ensured  bool       // true once the collection is known to exist
}
|
||||
|
||||
// Fixed collection field names. Common metadata keys are promoted into
// dedicated scalar columns so Milvus filter expressions can target them
// directly (see milvusFilterFieldMap).
const (
	milvusPrimaryField   = "id"
	milvusVectorField    = "vector"
	milvusTextField      = "text"
	milvusMetadataField  = "metadata"
	milvusCorpusField    = "corpus"
	milvusDocumentField  = "document_id"
	milvusUserIDField    = "user_id"
	milvusAssistantField = "assistant_id"
	milvusConvField      = "conversation_id"
	milvusRunField       = "run_id"
	milvusMemoryType     = "memory_type"
	milvusQueryIDField   = "query_id"
	milvusSessionField   = "session_id"
	milvusDomainField    = "domain"
	milvusChunkOrder     = "chunk_order"
	milvusUpdatedAtField = "updated_at"
)
|
||||
|
||||
// milvusFilterFieldMap maps public filter keys to the promoted scalar columns
// that back them in the collection schema.
var milvusFilterFieldMap = map[string]string{
	"corpus":          milvusCorpusField,
	"document_id":     milvusDocumentField,
	"user_id":         milvusUserIDField,
	"assistant_id":    milvusAssistantField,
	"conversation_id": milvusConvField,
	"run_id":          milvusRunField,
	"memory_type":     milvusMemoryType,
	"query_id":        milvusQueryIDField,
	"session_id":      milvusSessionField,
	"domain":          milvusDomainField,
	"chunk_order":     milvusChunkOrder,
}
|
||||
|
||||
func NewMilvusStore(cfg MilvusConfig) (*MilvusStore, error) {
|
||||
cfg.Address = strings.TrimRight(strings.TrimSpace(cfg.Address), "/")
|
||||
if cfg.Address == "" {
|
||||
return nil, errors.New("milvus address is empty")
|
||||
}
|
||||
if cfg.CollectionName == "" {
|
||||
cfg.CollectionName = "smartflow_rag_chunks"
|
||||
}
|
||||
if cfg.MetricType == "" {
|
||||
cfg.MetricType = "COSINE"
|
||||
}
|
||||
if cfg.RequestTimeoutMS <= 0 {
|
||||
cfg.RequestTimeoutMS = 1500
|
||||
}
|
||||
if cfg.Logger == nil {
|
||||
cfg.Logger = log.Default()
|
||||
}
|
||||
if cfg.Observer == nil {
|
||||
cfg.Observer = core.NewLoggerObserver(cfg.Logger)
|
||||
}
|
||||
|
||||
return &MilvusStore{
|
||||
cfg: cfg,
|
||||
client: &http.Client{Timeout: time.Duration(cfg.RequestTimeoutMS) * time.Millisecond},
|
||||
observer: cfg.Observer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *MilvusStore) Upsert(ctx context.Context, rows []core.VectorRow) error {
|
||||
if err := s.ensureReady(); err != nil {
|
||||
return err
|
||||
}
|
||||
start := time.Now()
|
||||
if len(rows) == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := s.ensureCollection(ctx, len(rows[0].Vector)); err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "upsert",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"row_count": len(rows),
|
||||
"vector_dim": len(rows[0].Vector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
data := make([]map[string]any, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
item := mapRowToMilvusEntity(row)
|
||||
data = append(data, item)
|
||||
}
|
||||
|
||||
_, err := s.postJSON(ctx, "/v2/vectordb/entities/upsert", map[string]any{
|
||||
"collectionName": s.cfg.CollectionName,
|
||||
"data": data,
|
||||
"dbName": blankToNil(s.cfg.DBName),
|
||||
})
|
||||
if err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "upsert",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"row_count": len(rows),
|
||||
"vector_dim": len(rows[0].Vector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelInfo,
|
||||
Component: "store",
|
||||
Operation: "upsert",
|
||||
Fields: map[string]any{
|
||||
"status": "success",
|
||||
"row_count": len(rows),
|
||||
"vector_dim": len(rows[0].Vector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *MilvusStore) Search(ctx context.Context, req core.VectorSearchRequest) ([]core.ScoredVectorRow, error) {
|
||||
if err := s.ensureReady(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
start := time.Now()
|
||||
if len(req.QueryVector) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err := s.ensureCollection(ctx, len(req.QueryVector)); err != nil {
|
||||
if isMilvusCollectionMissing(err) {
|
||||
return nil, nil
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filterExpr, err := buildMilvusFilter(req.Filter)
|
||||
if err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body := map[string]any{
|
||||
"collectionName": s.cfg.CollectionName,
|
||||
"data": [][]float32{req.QueryVector},
|
||||
"annsField": milvusVectorField,
|
||||
"limit": normalizeMilvusTopK(req.TopK),
|
||||
"outputFields": milvusOutputFields(false),
|
||||
}
|
||||
if filterExpr != "" {
|
||||
body["filter"] = filterExpr
|
||||
}
|
||||
if s.cfg.DBName != "" {
|
||||
body["dbName"] = s.cfg.DBName
|
||||
}
|
||||
|
||||
respBody, err := s.postJSON(ctx, "/v2/vectordb/entities/search", body)
|
||||
if err != nil {
|
||||
if isMilvusCollectionMissing(err) {
|
||||
return nil, nil
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp milvusSearchResponse
|
||||
if err = json.Unmarshal(respBody, &resp); err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
if resp.Code != 0 && resp.Code != 200 {
|
||||
err = fmt.Errorf("milvus search failed: code=%d message=%s", resp.Code, resp.Message)
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]core.ScoredVectorRow, 0, len(resp.Data))
|
||||
for _, item := range resp.Data {
|
||||
row, score := item.toVectorRow()
|
||||
result = append(result, core.ScoredVectorRow{
|
||||
Row: row,
|
||||
Score: score,
|
||||
})
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelInfo,
|
||||
Component: "store",
|
||||
Operation: "search",
|
||||
Fields: map[string]any{
|
||||
"status": "success",
|
||||
"top_k": req.TopK,
|
||||
"filter_count": len(req.Filter),
|
||||
"vector_dim": len(req.QueryVector),
|
||||
"result_count": len(result),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
},
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *MilvusStore) Delete(ctx context.Context, ids []string) error {
|
||||
if err := s.ensureReady(); err != nil {
|
||||
return err
|
||||
}
|
||||
start := time.Now()
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
filter := fmt.Sprintf(`%s in [%s]`, milvusPrimaryField, joinQuotedStrings(ids))
|
||||
_, err := s.postJSON(ctx, "/v2/vectordb/entities/delete", map[string]any{
|
||||
"collectionName": s.cfg.CollectionName,
|
||||
"filter": filter,
|
||||
"dbName": blankToNil(s.cfg.DBName),
|
||||
})
|
||||
if isMilvusCollectionMissing(err) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "delete",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelInfo,
|
||||
Component: "store",
|
||||
Operation: "delete",
|
||||
Fields: map[string]any{
|
||||
"status": "success",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *MilvusStore) Get(ctx context.Context, ids []string) ([]core.VectorRow, error) {
|
||||
if err := s.ensureReady(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
start := time.Now()
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
respBody, err := s.postJSON(ctx, "/v2/vectordb/entities/get", map[string]any{
|
||||
"collectionName": s.cfg.CollectionName,
|
||||
"id": ids,
|
||||
"outputFields": milvusOutputFields(true),
|
||||
"dbName": blankToNil(s.cfg.DBName),
|
||||
})
|
||||
if err != nil {
|
||||
if isMilvusCollectionMissing(err) {
|
||||
return nil, nil
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "get",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp milvusGetResponse
|
||||
if err = json.Unmarshal(respBody, &resp); err != nil {
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "get",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
if resp.Code != 0 && resp.Code != 200 {
|
||||
err = fmt.Errorf("milvus get failed: code=%d message=%s", resp.Code, resp.Message)
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelError,
|
||||
Component: "store",
|
||||
Operation: "get",
|
||||
Fields: map[string]any{
|
||||
"status": "failed",
|
||||
"id_count": len(ids),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
"error": err,
|
||||
"error_code": core.ClassifyErrorCode(err),
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows := make([]core.VectorRow, 0, len(resp.Data))
|
||||
for _, item := range resp.Data {
|
||||
rows = append(rows, mapMilvusRow(item, true))
|
||||
}
|
||||
s.observe(ctx, core.ObserveEvent{
|
||||
Level: core.ObserveLevelInfo,
|
||||
Component: "store",
|
||||
Operation: "get",
|
||||
Fields: map[string]any{
|
||||
"status": "success",
|
||||
"id_count": len(ids),
|
||||
"row_count": len(rows),
|
||||
"latency_ms": time.Since(start).Milliseconds(),
|
||||
},
|
||||
})
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// ensureCollection lazily creates the backing collection (schema + vector
// index) on first use, guarded by s.mu and memoized via s.ensured. An
// already-existing collection is treated as success. dimension falls back to
// the configured Dimension when the caller passes a non-positive value.
func (s *MilvusStore) ensureCollection(ctx context.Context, dimension int) error {
	if err := s.ensureReady(); err != nil {
		return err
	}
	start := time.Now()
	if dimension <= 0 {
		dimension = s.cfg.Dimension
	}
	if dimension <= 0 {
		return errors.New("milvus vector dimension is invalid")
	}

	// Serialize creation attempts; the ensured flag is only read/written
	// under the lock.
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.ensured {
		return nil
	}

	// Fixed schema: scalar copies of common metadata keys are promoted to
	// dedicated columns so filters can target them; the full metadata stays in
	// a JSON column.
	payload := map[string]any{
		"collectionName": s.cfg.CollectionName,
		"schema": map[string]any{
			"autoId":              false,
			"enabledDynamicField": false,
			"fields": []map[string]any{
				buildVarcharField(milvusPrimaryField, true, 256),
				buildVectorField(milvusVectorField, dimension),
				buildVarcharField(milvusTextField, false, 65535),
				{"fieldName": milvusMetadataField, "dataType": "JSON"},
				buildVarcharField(milvusCorpusField, false, 64),
				buildVarcharField(milvusDocumentField, false, 256),
				{"fieldName": milvusUserIDField, "dataType": "Int64"},
				buildVarcharField(milvusAssistantField, false, 128),
				buildVarcharField(milvusConvField, false, 128),
				buildVarcharField(milvusRunField, false, 128),
				buildVarcharField(milvusMemoryType, false, 64),
				buildVarcharField(milvusQueryIDField, false, 128),
				buildVarcharField(milvusSessionField, false, 128),
				buildVarcharField(milvusDomainField, false, 128),
				{"fieldName": milvusChunkOrder, "dataType": "Int64"},
				{"fieldName": milvusUpdatedAtField, "dataType": "Int64"},
			},
		},
		"indexParams": []map[string]any{
			{
				"fieldName":  milvusVectorField,
				"indexName":  milvusVectorField,
				"metricType": s.cfg.MetricType,
				"indexType":  "AUTOINDEX",
			},
		},
	}
	if s.cfg.DBName != "" {
		payload["dbName"] = s.cfg.DBName
	}

	_, err := s.postJSON(ctx, "/v2/vectordb/collections/create", payload)
	if err != nil {
		// A concurrent or prior creation is success from our point of view.
		if isMilvusAlreadyExists(err) {
			s.ensured = true
			s.observe(ctx, core.ObserveEvent{
				Level:     core.ObserveLevelInfo,
				Component: "store",
				Operation: "ensure_collection",
				Fields: map[string]any{
					"status":      "already_exists",
					"vector_dim":  dimension,
					"metric_type": s.cfg.MetricType,
					"latency_ms":  time.Since(start).Milliseconds(),
				},
			})
			return nil
		}
		s.observe(ctx, core.ObserveEvent{
			Level:     core.ObserveLevelError,
			Component: "store",
			Operation: "ensure_collection",
			Fields: map[string]any{
				"status":      "failed",
				"vector_dim":  dimension,
				"metric_type": s.cfg.MetricType,
				"latency_ms":  time.Since(start).Milliseconds(),
				"error":       err,
				"error_code":  core.ClassifyErrorCode(err),
			},
		})
		return err
	}
	s.ensured = true
	s.observe(ctx, core.ObserveEvent{
		Level:     core.ObserveLevelInfo,
		Component: "store",
		Operation: "ensure_collection",
		Fields: map[string]any{
			"status":      "created",
			"vector_dim":  dimension,
			"metric_type": s.cfg.MetricType,
			"latency_ms":  time.Since(start).Milliseconds(),
		},
	})
	return nil
}
|
||||
|
||||
// postJSON issues a POST with a JSON body to path under the configured
// address, attaching the bearer token when configured. It fails on HTTP
// status >= 400 and on Milvus envelope codes other than 0/200; otherwise the
// raw body is returned for caller-side decoding.
func (s *MilvusStore) postJSON(ctx context.Context, path string, payload map[string]any) ([]byte, error) {
	if err := s.ensureReady(); err != nil {
		return nil, err
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.cfg.Address+path, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	if token := strings.TrimSpace(s.cfg.Token); token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}

	resp, err := s.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	respBody, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, readErr
	}

	// Transport-level failure: surface status plus the raw body for diagnosis.
	if resp.StatusCode >= 400 {
		return nil, fmt.Errorf("milvus http failed: status=%d body=%s", resp.StatusCode, string(respBody))
	}

	// Best-effort envelope check: a body that does not decode as the basic
	// envelope is deliberately passed through for the caller to interpret.
	var basic milvusBasicResponse
	if jsonErr := json.Unmarshal(respBody, &basic); jsonErr == nil {
		if basic.Code != 0 && basic.Code != 200 {
			return nil, fmt.Errorf("milvus api failed: code=%d message=%s", basic.Code, basic.Message)
		}
	}

	return respBody, nil
}
|
||||
|
||||
func (s *MilvusStore) ensureReady() error {
|
||||
if s == nil || s.client == nil {
|
||||
return errors.New("milvus store is not initialized")
|
||||
}
|
||||
if strings.TrimSpace(s.cfg.Address) == "" {
|
||||
return errors.New("milvus address is empty")
|
||||
}
|
||||
if strings.TrimSpace(s.cfg.CollectionName) == "" {
|
||||
return errors.New("milvus collection name is empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MilvusStore) observe(ctx context.Context, event core.ObserveEvent) {
|
||||
if s == nil || s.observer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
fields := cloneMap(event.Fields)
|
||||
fields["store"] = "milvus"
|
||||
fields["collection"] = s.cfg.CollectionName
|
||||
if strings.TrimSpace(s.cfg.DBName) != "" {
|
||||
fields["db_name"] = s.cfg.DBName
|
||||
}
|
||||
|
||||
s.observer.Observe(ctx, core.ObserveEvent{
|
||||
Level: event.Level,
|
||||
Component: event.Component,
|
||||
Operation: event.Operation,
|
||||
Fields: fields,
|
||||
})
|
||||
}
|
||||
|
||||
func mapRowToMilvusEntity(row core.VectorRow) map[string]any {
|
||||
metadata := cloneMap(row.Metadata)
|
||||
entity := map[string]any{
|
||||
milvusPrimaryField: row.ID,
|
||||
milvusVectorField: row.Vector,
|
||||
milvusTextField: row.Text,
|
||||
milvusMetadataField: metadata,
|
||||
milvusCorpusField: asString(metadata["corpus"]),
|
||||
milvusDocumentField: asString(metadata["document_id"]),
|
||||
milvusUpdatedAtField: func() int64 {
|
||||
if row.UpdatedAt.IsZero() {
|
||||
return time.Now().UnixMilli()
|
||||
}
|
||||
return row.UpdatedAt.UnixMilli()
|
||||
}(),
|
||||
}
|
||||
assignMilvusScalar(entity, milvusUserIDField, metadata["user_id"])
|
||||
assignMilvusScalar(entity, milvusAssistantField, metadata["assistant_id"])
|
||||
assignMilvusScalar(entity, milvusConvField, metadata["conversation_id"])
|
||||
assignMilvusScalar(entity, milvusRunField, metadata["run_id"])
|
||||
assignMilvusScalar(entity, milvusMemoryType, metadata["memory_type"])
|
||||
assignMilvusScalar(entity, milvusQueryIDField, metadata["query_id"])
|
||||
assignMilvusScalar(entity, milvusSessionField, metadata["session_id"])
|
||||
assignMilvusScalar(entity, milvusDomainField, metadata["domain"])
|
||||
assignMilvusScalar(entity, milvusChunkOrder, metadata["chunk_order"])
|
||||
return entity
|
||||
}
|
||||
|
||||
func assignMilvusScalar(target map[string]any, field string, value any) {
|
||||
if value == nil {
|
||||
return
|
||||
}
|
||||
switch field {
|
||||
case milvusUserIDField, milvusChunkOrder:
|
||||
if parsed, ok := toInt64(value); ok {
|
||||
target[field] = parsed
|
||||
}
|
||||
default:
|
||||
if text := asString(value); text != "" {
|
||||
target[field] = text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildMilvusFilter(filter map[string]any) (string, error) {
|
||||
if len(filter) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(filter))
|
||||
for key, value := range filter {
|
||||
field, ok := milvusFilterFieldMap[key]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unsupported milvus filter key: %s", key)
|
||||
}
|
||||
switch field {
|
||||
case milvusUserIDField, milvusChunkOrder:
|
||||
parsed, parseOK := toInt64(value)
|
||||
if !parseOK {
|
||||
return "", fmt.Errorf("milvus filter key=%s expects integer", key)
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("%s == %d", field, parsed))
|
||||
default:
|
||||
text := escapeMilvusString(asString(value))
|
||||
parts = append(parts, fmt.Sprintf(`%s == "%s"`, field, text))
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " and "), nil
|
||||
}
|
||||
|
||||
// buildVarcharField describes a VarChar column for the Milvus
// create-collection payload; maxLength caps the stored string length and
// isPrimary marks the column as the collection's primary key.
func buildVarcharField(name string, isPrimary bool, maxLength int) map[string]any {
	spec := map[string]any{
		"fieldName":         name,
		"dataType":          "VarChar",
		"elementTypeParams": map[string]any{"max_length": maxLength},
	}
	if !isPrimary {
		return spec
	}
	spec["isPrimary"] = true
	return spec
}
|
||||
|
||||
// buildVectorField describes a FloatVector column of the given
// dimensionality for the Milvus create-collection payload.
func buildVectorField(name string, dimension int) map[string]any {
	params := map[string]any{"dim": dimension}
	return map[string]any{
		"fieldName":         name,
		"dataType":          "FloatVector",
		"elementTypeParams": params,
	}
}
|
||||
|
||||
func milvusOutputFields(includeVector bool) []string {
|
||||
fields := []string{
|
||||
milvusTextField,
|
||||
milvusMetadataField,
|
||||
milvusCorpusField,
|
||||
milvusDocumentField,
|
||||
milvusUserIDField,
|
||||
milvusAssistantField,
|
||||
milvusConvField,
|
||||
milvusRunField,
|
||||
milvusMemoryType,
|
||||
milvusQueryIDField,
|
||||
milvusSessionField,
|
||||
milvusDomainField,
|
||||
milvusChunkOrder,
|
||||
milvusUpdatedAtField,
|
||||
}
|
||||
if includeVector {
|
||||
fields = append(fields, milvusVectorField)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// normalizeMilvusTopK clamps a non-positive topK to the default of 8;
// positive values pass through unchanged.
func normalizeMilvusTopK(topK int) int {
	if topK > 0 {
		return topK
	}
	return 8
}
|
||||
|
||||
// blankToNil returns nil when v is empty or all whitespace, otherwise the
// trimmed string. Useful for payloads where blank fields should become null.
func blankToNil(v string) any {
	if trimmed := strings.TrimSpace(v); trimmed != "" {
		return trimmed
	}
	return nil
}
|
||||
|
||||
// escapeMilvusString escapes backslashes and double quotes so the value
// can be safely embedded inside a double-quoted Milvus filter literal.
func escapeMilvusString(v string) string {
	var b strings.Builder
	b.Grow(len(v))
	for _, r := range v {
		// Each backslash or quote gets a preceding backslash; a single
		// pass is equivalent to escaping backslashes first, quotes second.
		if r == '\\' || r == '"' {
			b.WriteByte('\\')
		}
		b.WriteRune(r)
	}
	return b.String()
}
|
||||
|
||||
func joinQuotedStrings(values []string) string {
|
||||
parts := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
parts = append(parts, fmt.Sprintf(`"%s"`, escapeMilvusString(value)))
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// cloneMap returns a shallow copy of src. A nil or empty input yields a
// fresh empty (non-nil) map so callers can always write to the result.
func cloneMap(src map[string]any) map[string]any {
	dst := make(map[string]any, len(src))
	for key, value := range src {
		dst[key] = value
	}
	return dst
}
|
||||
|
||||
// asString renders any value with fmt's default verb and trims surrounding
// whitespace; nil maps to the empty string.
func asString(v any) string {
	switch value := v.(type) {
	case nil:
		return ""
	case string:
		return strings.TrimSpace(value)
	default:
		return strings.TrimSpace(fmt.Sprintf("%v", value))
	}
}
|
||||
|
||||
func toInt64(v any) (int64, bool) {
|
||||
switch value := v.(type) {
|
||||
case int:
|
||||
return int64(value), true
|
||||
case int32:
|
||||
return int64(value), true
|
||||
case int64:
|
||||
return value, true
|
||||
case float64:
|
||||
return int64(value), true
|
||||
case json.Number:
|
||||
parsed, err := value.Int64()
|
||||
return parsed, err == nil
|
||||
case string:
|
||||
parsed, err := strconv.ParseInt(strings.TrimSpace(value), 10, 64)
|
||||
return parsed, err == nil
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func isMilvusAlreadyExists(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
text := strings.ToLower(err.Error())
|
||||
return strings.Contains(text, "already exist") ||
|
||||
strings.Contains(text, "already exists") ||
|
||||
strings.Contains(text, "duplicate collection")
|
||||
}
|
||||
|
||||
func isMilvusCollectionMissing(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
text := strings.ToLower(err.Error())
|
||||
return strings.Contains(text, "can't find collection") || strings.Contains(text, "collection not found")
|
||||
}
|
||||
|
||||
// milvusBasicResponse captures the code/message envelope shared by all
// Milvus HTTP API responses; a code other than 0 or 200 signals failure.
type milvusBasicResponse struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}
|
||||
|
||||
// milvusSearchResponse is the payload returned by the Milvus search
// endpoint: the common code/message envelope plus one item per hit.
type milvusSearchResponse struct {
	Code    int                `json:"code"`
	Message string             `json:"message"`
	Data    []milvusSearchItem `json:"data"`
}
|
||||
|
||||
// milvusSearchItem is a single raw search hit as decoded from JSON.
type milvusSearchItem map[string]any

// toVectorRow converts the hit into a core.VectorRow plus its similarity
// score, taken from the "distance" field (0 when absent or non-numeric).
func (m milvusSearchItem) toVectorRow() (core.VectorRow, float64) {
	row := mapMilvusRow(map[string]any(m), false)
	score := 0.0
	if value, ok := m["distance"].(float64); ok {
		score = value
	}
	return row, score
}
|
||||
|
||||
// milvusGetResponse is the payload returned by the Milvus query/get
// endpoint; Data holds one raw entity map per matched row.
type milvusGetResponse struct {
	Code    int              `json:"code"`
	Message string           `json:"message"`
	Data    []map[string]any `json:"data"`
}
|
||||
|
||||
// mapMilvusRow rebuilds a core.VectorRow from a raw Milvus entity map,
// merging the promoted scalar columns back into the metadata map (the
// inverse of mapRowToMilvusEntity). Scalar columns overwrite any same-key
// entries carried inside the serialized metadata field. The vector is
// decoded only when includeVector is true, since search results may omit it.
func mapMilvusRow(raw map[string]any, includeVector bool) core.VectorRow {
	metadata := cloneMap(readMetadataMap(raw[milvusMetadataField]))
	assignMetadataIfPresent(metadata, "corpus", raw[milvusCorpusField])
	assignMetadataIfPresent(metadata, "document_id", raw[milvusDocumentField])
	assignMetadataIfPresent(metadata, "user_id", raw[milvusUserIDField])
	assignMetadataIfPresent(metadata, "assistant_id", raw[milvusAssistantField])
	assignMetadataIfPresent(metadata, "conversation_id", raw[milvusConvField])
	assignMetadataIfPresent(metadata, "run_id", raw[milvusRunField])
	assignMetadataIfPresent(metadata, "memory_type", raw[milvusMemoryType])
	assignMetadataIfPresent(metadata, "query_id", raw[milvusQueryIDField])
	assignMetadataIfPresent(metadata, "session_id", raw[milvusSessionField])
	assignMetadataIfPresent(metadata, "domain", raw[milvusDomainField])
	assignMetadataIfPresent(metadata, "chunk_order", raw[milvusChunkOrder])

	row := core.VectorRow{
		ID:       asString(raw[milvusPrimaryField]),
		Text:     asString(raw[milvusTextField]),
		Metadata: metadata,
	}
	// Fallback: some endpoints return the primary key under a generic "id" key.
	if row.ID == "" {
		row.ID = asString(raw["id"])
	}
	if includeVector {
		row.Vector = readFloat32Vector(raw[milvusVectorField])
	}
	return row
}
|
||||
|
||||
// readMetadataMap returns value as-is when it already is a map, or a
// fresh empty map otherwise (never nil).
func readMetadataMap(value any) map[string]any {
	if m, ok := value.(map[string]any); ok {
		return m
	}
	return map[string]any{}
}
|
||||
|
||||
// readFloat32Vector decodes a vector that may arrive either as a native
// []float32 or as a JSON-decoded []any of numbers. Non-numeric elements
// are skipped; any other shape yields nil.
func readFloat32Vector(value any) []float32 {
	if native, ok := value.([]float32); ok {
		return native
	}
	generic, ok := value.([]any)
	if !ok {
		return nil
	}
	result := make([]float32, 0, len(generic))
	for _, item := range generic {
		switch number := item.(type) {
		case float64:
			result = append(result, float32(number))
		case float32:
			result = append(result, number)
		}
	}
	return result
}
|
||||
|
||||
// assignMetadataIfPresent stores value under key, skipping nils and blank
// strings; string values are stored trimmed, everything else verbatim.
func assignMetadataIfPresent(target map[string]any, key string, value any) {
	if value == nil {
		return
	}
	text, isString := value.(string)
	if !isString {
		target[key] = value
		return
	}
	if trimmed := strings.TrimSpace(text); trimmed != "" {
		target[key] = trimmed
	}
}
|
||||
9
backend/services/rag/store/vector_store.go
Normal file
9
backend/services/rag/store/vector_store.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package store
|
||||
|
||||
import "github.com/LoveLosita/smartflow/backend/services/rag/core"
|
||||
|
||||
// EnsureCompile statically asserts that both vector store implementations
// satisfy the core.VectorStore interface; it performs no work at runtime.
func EnsureCompile() {
	var _ core.VectorStore = (*InMemoryVectorStore)(nil)
	var _ core.VectorStore = (*MilvusStore)(nil)
}
|
||||
Reference in New Issue
Block a user