Files
LoveLosita 26c350f378 Version: 0.4.4.dev.260307
feat: 🚀 增强会话管理与缓存机制

* 会话 ID 空值兜底,若 `conversation_id` 为空时自动生成 UUID
* 在响应头写入 `X-Conversation-ID`,供前端使用,保持同一会话状态

perf:  会话状态缓存优化

* 当缓存未命中但 DB 已确认/创建会话后,调用 `SetConversationStatus` 回写 Redis
* 缓存写回失败时记录日志,不中断聊天主流程,确保业务流畅性

fix: 🐛 修复历史消息顺序问题与编译错误

* 修复历史消息顺序问题,保证返回的 N 条历史消息按时间正序喂给模型

  * 通过反转 `created_at desc` 查询结果的切片,确保模型输入顺序正确
* 修复 `fmt.Errorf` 参数不匹配问题,修正编译错误
* 整理 `agent-cache.go` 为标准 UTF-8 编码,避免 Go 编译报错 `invalid UTF-8 encoding`

feat: 🛠️ 独立构建 MCP 服务器

* 使用 `Codex` 构建独立于后端的 MCP 服务器,简化与 Codex 的协作
* 通过该服务器方便 Codex 直接测试和查看 Redis 与 MySQL 中的数据
2026-03-07 15:25:40 +08:00

199 lines
4.4 KiB
Go

package store
import (
"context"
"database/sql"
"fmt"
"net/url"
"reflect"
"strconv"
"strings"
"time"
cfgpkg "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
"github.com/go-sql-driver/mysql"
)
// MySQLClient wraps the *sql.DB handle used to run queries against MySQL.
// Instances are created by NewMySQLClient.
type MySQLClient struct {
	// db is the underlying database/sql handle; nil on the zero value.
	db *sql.DB
}
// QueryColumn describes a single column of a query result set.
type QueryColumn struct {
	// Name is the column name; serialized as "name".
	Name string `json:"name"`

	// DatabaseType is the database type name of the column; serialized as
	// "databaseType".
	DatabaseType string `json:"databaseType"`

	// Nullable is nil when nullability is unknown, in which case the field is
	// omitted from the JSON output.
	Nullable *bool `json:"nullable,omitempty"`

	// ScanType is the textual form of the Go scan type; omitted from the JSON
	// output when empty.
	ScanType string `json:"scanType,omitempty"`
}
// QueryResult is the JSON-serializable outcome of a query: column metadata,
// the collected rows, and bookkeeping about how the query ran.
type QueryResult struct {
	// Columns holds one entry per result column.
	Columns []QueryColumn `json:"columns"`

	// Rows holds one map per returned row.
	Rows []map[string]any `json:"rows"`

	// RowCount is the number of rows carried in this result.
	RowCount int `json:"rowCount"`

	// Truncated is set when the row limit cut the result short.
	Truncated bool `json:"truncated"`

	// DurationMs is the elapsed time of the query in milliseconds.
	DurationMs int64 `json:"durationMs"`
}
// NewMySQLClient opens a MySQL handle for cfg over TCP and verifies it with a
// ping before returning.
//
// The DSN is assembled from cfg (user, password, host, port, database) with
// native-password auth and time parsing enabled; cfg.Params is then applied on
// top via applyMySQLParams. The handle is limited to 5 open and 5 idle
// connections, each recycled after 5 minutes.
//
// The ping runs under ctx with an additional 3-second timeout. If it fails the
// handle is closed and the error is returned wrapped as "ping mysql: ...".
func NewMySQLClient(ctx context.Context, cfg cfgpkg.MySQLConfig) (*MySQLClient, error) {
	// An IPv6 literal such as "::1" must be bracketed before the port is
	// appended; otherwise the resulting "::1:3306" cannot be dialed. Hosts
	// that are already bracketed are left untouched.
	host := cfg.Host
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		host = "[" + host + "]"
	}

	dsn := mysql.Config{
		User:                 cfg.User,
		Passwd:               cfg.Password,
		Addr:                 fmt.Sprintf("%s:%d", host, cfg.Port),
		Net:                  "tcp",
		DBName:               cfg.Database,
		AllowNativePasswords: true,
		ParseTime:            true,
	}
	applyMySQLParams(&dsn, cfg.Params)

	db, err := sql.Open("mysql", dsn.FormatDSN())
	if err != nil {
		return nil, fmt.Errorf("open mysql: %w", err)
	}
	db.SetConnMaxLifetime(5 * time.Minute)
	db.SetMaxOpenConns(5)
	db.SetMaxIdleConns(5)

	pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()
	if err := db.PingContext(pingCtx); err != nil {
		// The ping failure is the error worth reporting; a Close error on a
		// handle that never worked adds nothing.
		_ = db.Close()
		return nil, fmt.Errorf("ping mysql: %w", err)
	}
	return &MySQLClient{db: db}, nil
}
// applyMySQLParams overlays extra connection parameters onto cfg.
//
// raw is a URL-query-style string ("key=value&key2=value2"). Only the first
// value of each key is used. Keys are matched case-insensitively against
// "parseTime", "loc" and "collation", which set the corresponding typed fields
// on cfg; every other key is copied verbatim into cfg.Params.
//
// The function is best-effort: blank input, input that fails to parse, and
// invalid parseTime/loc values are ignored without reporting an error.
func applyMySQLParams(cfg *mysql.Config, raw string) {
	if strings.TrimSpace(raw) == "" {
		return
	}

	parsed, parseErr := url.ParseQuery(raw)
	if parseErr != nil {
		return
	}

	if cfg.Params == nil {
		cfg.Params = map[string]string{}
	}

	for name, candidates := range parsed {
		if len(candidates) == 0 {
			continue
		}
		first := candidates[0]
		lowered := strings.ToLower(name)

		switch {
		case lowered == "parsetime":
			flag, convErr := strconv.ParseBool(first)
			if convErr == nil {
				cfg.ParseTime = flag
			}
		case lowered == "loc":
			zone, zoneErr := time.LoadLocation(first)
			if zoneErr == nil {
				cfg.Loc = zone
			}
		case lowered == "collation":
			cfg.Collation = first
		default:
			// Unrecognized keys keep their original spelling.
			cfg.Params[name] = first
		}
	}
}
// Close closes the underlying *sql.DB and returns its error. It is a no-op
// returning nil when the receiver or its db handle is nil.
func (c *MySQLClient) Close() error {
	if c != nil && c.db != nil {
		return c.db.Close()
	}
	return nil
}
// QueryReadOnly executes query with args and returns up to maxRows rows
// together with metadata for every result column.
//
// The statement runs inside a transaction requested with
// sql.TxOptions{ReadOnly: true} that is always rolled back, so read-only
// access is requested from the database itself instead of depending solely on
// callers to filter the SQL text.
//
// Each row is a map from column name to the value produced by normalizeValue.
// When the result set contains duplicate column names (for example
// "SELECT a.id, b.id ..."), later duplicates receive a numeric suffix
// ("id_2", "id_3", ...) in both Columns and Rows so that no value silently
// overwrites another.
//
// Truncated reports that at least one more row was available once maxRows
// rows had been collected; with maxRows <= 0 no rows are collected at all.
// Rows is always non-nil so it JSON-encodes as [] rather than null.
//
// An error is returned if the client is uninitialized, the transaction cannot
// be started, the query fails, or a row cannot be scanned.
func (c *MySQLClient) QueryReadOnly(ctx context.Context, query string, args []any, maxRows int) (QueryResult, error) {
	// Mirror Close's nil-safety: report an error instead of panicking.
	if c == nil || c.db == nil {
		return QueryResult{}, fmt.Errorf("mysql client is not initialized")
	}

	start := time.Now()

	tx, err := c.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
	if err != nil {
		return QueryResult{}, fmt.Errorf("begin read-only transaction: %w", err)
	}
	// Nothing is ever committed. Rollback only releases the connection, so
	// its error carries no useful information for the caller.
	defer func() { _ = tx.Rollback() }()

	rows, err := tx.QueryContext(ctx, query, args...)
	if err != nil {
		return QueryResult{}, err
	}
	// Registered after the rollback, so rows are closed before it runs.
	defer rows.Close()

	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		return QueryResult{}, err
	}

	names := make([]string, len(columnTypes))
	for i, ct := range columnTypes {
		names[i] = ct.Name()
	}
	keys := uniqueColumnKeys(names)

	columns := make([]QueryColumn, 0, len(columnTypes))
	for i, ct := range columnTypes {
		col := QueryColumn{Name: keys[i], DatabaseType: ct.DatabaseTypeName()}
		if nullable, ok := ct.Nullable(); ok {
			col.Nullable = &nullable
		}
		if st := ct.ScanType(); st != nil {
			col.ScanType = st.String()
		}
		columns = append(columns, col)
	}

	resultRows := make([]map[string]any, 0)
	truncated := false
	for rows.Next() {
		// Reaching this point with a full result means one more row exists.
		if len(resultRows) >= maxRows {
			truncated = true
			break
		}
		scanned, err := scanRow(rows, len(keys))
		if err != nil {
			return QueryResult{}, err
		}
		rowMap := make(map[string]any, len(keys))
		for i, key := range keys {
			rowMap[key] = normalizeValue(scanned[i])
		}
		resultRows = append(resultRows, rowMap)
	}
	if err := rows.Err(); err != nil {
		return QueryResult{}, err
	}

	return QueryResult{
		Columns:    columns,
		Rows:       resultRows,
		RowCount:   len(resultRows),
		Truncated:  truncated,
		DurationMs: time.Since(start).Milliseconds(),
	}, nil
}

// uniqueColumnKeys returns names with duplicates disambiguated by appending
// "_2", "_3", ... to later occurrences; the first occurrence keeps its name.
func uniqueColumnKeys(names []string) []string {
	seen := make(map[string]struct{}, len(names))
	keys := make([]string, len(names))
	for i, name := range names {
		key := name
		// Keep counting until the candidate collides with nothing, including
		// a real column that happens to be called e.g. "id_2".
		for n := 2; ; n++ {
			if _, taken := seen[key]; !taken {
				break
			}
			key = name + "_" + strconv.Itoa(n)
		}
		seen[key] = struct{}{}
		keys[i] = key
	}
	return keys
}
// scanRow reads the current row of rows into a freshly allocated slice of
// size untyped values and returns it. Any error from rows.Scan is returned
// unchanged along with a nil slice.
func scanRow(rows *sql.Rows, size int) ([]any, error) {
	values := make([]any, size)

	// Scan needs pointers, so hand it the address of every slot in values.
	targets := make([]any, 0, size)
	for i := 0; i < size; i++ {
		targets = append(targets, &values[i])
	}

	scanErr := rows.Scan(targets...)
	if scanErr != nil {
		return nil, scanErr
	}
	return values, nil
}
// normalizeValue converts a scanned value into a form suited for output:
//   - nil and nil pointers become nil,
//   - []byte becomes a string,
//   - time.Time becomes its RFC 3339 representation with nanoseconds,
//   - anything else is returned unchanged.
func normalizeValue(v any) any {
	if v == nil {
		return nil
	}
	if raw, isBytes := v.([]byte); isBytes {
		return string(raw)
	}
	if ts, isTime := v.(time.Time); isTime {
		return ts.Format(time.RFC3339Nano)
	}
	// A typed nil pointer stored in an interface is not == nil, so it has to
	// be detected through reflection.
	if ref := reflect.ValueOf(v); ref.Kind() == reflect.Ptr && ref.IsNil() {
		return nil
	}
	return v
}