✨ 为原有流式聊天链路补充“聊天结束后异步调用 LLM 生成对话标题并落库”的机制,相关测试已通过 📄 新增“获取对话元信息”接口,便于前端统一获取对话的各类信息,包括上述异步生成的标题
128 lines
3.8 KiB
Go
128 lines
3.8 KiB
Go
package dao
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"strings"
|
||
|
||
"github.com/LoveLosita/smartflow/backend/model"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type AgentDAO struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
func NewAgentDAO(db *gorm.DB) *AgentDAO {
|
||
return &AgentDAO{db: db}
|
||
}
|
||
|
||
func (a *AgentDAO) SaveChatHistory(ctx context.Context, userID int, conversationID string, role, message string) error {
|
||
userChat := model.ChatHistory{
|
||
UserID: userID,
|
||
MessageContent: &message,
|
||
Role: &role,
|
||
ChatID: conversationID,
|
||
}
|
||
if err := a.db.WithContext(ctx).Create(&userChat).Error; err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (a *AgentDAO) CreateNewChat(userID int, chatID string) (int64, error) {
|
||
chat := model.AgentChat{
|
||
ChatID: chatID,
|
||
UserID: userID,
|
||
MessageCount: 0,
|
||
LastMessageAt: nil,
|
||
}
|
||
if err := a.db.Create(&chat).Error; err != nil {
|
||
return 0, err
|
||
}
|
||
return chat.ID, nil
|
||
}
|
||
|
||
func (a *AgentDAO) GetUserChatHistories(ctx context.Context, userID, limit int, chatID string) ([]model.ChatHistory, error) {
|
||
var histories []model.ChatHistory
|
||
err := a.db.WithContext(ctx).
|
||
Where("user_id = ? AND chat_id = ?", userID, chatID).
|
||
Order("created_at desc").
|
||
Limit(limit).
|
||
Find(&histories).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
// 保留“最近 N 条”的前提下,反转为时间正序,便于模型消费
|
||
for i, j := 0, len(histories)-1; i < j; i, j = i+1, j-1 {
|
||
histories[i], histories[j] = histories[j], histories[i]
|
||
}
|
||
return histories, nil
|
||
}
|
||
|
||
func (a *AgentDAO) IfChatExists(ctx context.Context, userID int, chatID string) (bool, error) {
|
||
var chat model.AgentChat
|
||
err := a.db.WithContext(ctx).Where("user_id = ? AND chat_id = ?", userID, chatID).First(&chat).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return false, nil // 没有找到记录,表示会话不存在
|
||
}
|
||
return false, err
|
||
}
|
||
return true, nil
|
||
}
|
||
|
||
// GetConversationMeta 查询单个会话的元信息。
|
||
// 用途:
|
||
// 1) 给前端提供“当前会话标题/消息数/最近消息时间”等展示字段;
|
||
// 2) 与流式聊天接口解耦,避免在 SSE 头部里塞动态标题。
|
||
func (a *AgentDAO) GetConversationMeta(ctx context.Context, userID int, chatID string) (*model.AgentChat, error) {
|
||
var chat model.AgentChat
|
||
err := a.db.WithContext(ctx).
|
||
Select("chat_id", "title", "message_count", "last_message_at", "status").
|
||
Where("user_id = ? AND chat_id = ?", userID, chatID).
|
||
First(&chat).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &chat, nil
|
||
}
|
||
|
||
// GetConversationTitle 读取当前会话标题。
|
||
// 返回值说明:
|
||
// 1) title:标题内容(若为空表示尚未生成);
|
||
// 2) exists:会话是否存在;
|
||
// 3) err:数据库错误。
|
||
func (a *AgentDAO) GetConversationTitle(ctx context.Context, userID int, chatID string) (title string, exists bool, err error) {
|
||
var chat model.AgentChat
|
||
queryErr := a.db.WithContext(ctx).
|
||
Select("title").
|
||
Where("user_id = ? AND chat_id = ?", userID, chatID).
|
||
First(&chat).Error
|
||
if queryErr != nil {
|
||
if errors.Is(queryErr, gorm.ErrRecordNotFound) {
|
||
return "", false, nil
|
||
}
|
||
return "", false, queryErr
|
||
}
|
||
if chat.Title == nil {
|
||
return "", true, nil
|
||
}
|
||
return strings.TrimSpace(*chat.Title), true, nil
|
||
}
|
||
|
||
// UpdateConversationTitleIfEmpty 仅在标题为空时写入会话标题。
|
||
// 设计目的:
|
||
// 1) 避免每轮对话都覆盖已有标题;
|
||
// 2) 并发下保持幂等:多个 goroutine 同时尝试写标题,最终只会成功一次。
|
||
func (a *AgentDAO) UpdateConversationTitleIfEmpty(ctx context.Context, userID int, chatID, title string) error {
|
||
normalized := strings.TrimSpace(title)
|
||
if normalized == "" {
|
||
return nil
|
||
}
|
||
return a.db.WithContext(ctx).
|
||
Model(&model.AgentChat{}).
|
||
Where("user_id = ? AND chat_id = ? AND (title IS NULL OR title = '')", userID, chatID).
|
||
Update("title", normalized).Error
|
||
}
|