Version: 0.4.4.dev.260307
feat: 🚀 增强会话管理与缓存机制 * 会话 ID 空值兜底,若 `conversation_id` 为空时自动生成 UUID * 在响应头写入 `X-Conversation-ID`,供前端使用,保持同一会话状态 perf: ⚡ 会话状态缓存优化 * 当缓存未命中但 DB 已确认/创建会话后,调用 `SetConversationStatus` 回写 Redis * 缓存写回失败时记录日志,不中断聊天主流程,确保业务流畅性 fix: 🐛 修复历史消息顺序问题与编译错误 * 修复历史消息顺序问题,保证返回的 N 条历史消息按时间正序喂给模型 * 通过反转 `created_at desc` 查询结果的切片,确保模型输入顺序正确 * 修复 `fmt.Errorf` 参数不匹配问题,修正编译错误 * 整理 `agent-cache.go` 为标准 UTF-8 编码,避免 Go 编译报错 `invalid UTF-8 encoding` feat: 🛠️ 独立构建 MCP 服务器 * 使用 `Codex` 构建独立于后端的 MCP 服务器,简化与 Codex 的协作 * 通过该服务器方便 Codex 直接测试和查看 Redis 与 MySQL 中的数据
This commit is contained in:
198
infra/smartflow-mcp-server/internal/store/mysql.go
Normal file
198
infra/smartflow-mcp-server/internal/store/mysql.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfgpkg "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
type MySQLClient struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
type QueryColumn struct {
|
||||
Name string `json:"name"`
|
||||
DatabaseType string `json:"databaseType"`
|
||||
Nullable *bool `json:"nullable,omitempty"`
|
||||
ScanType string `json:"scanType,omitempty"`
|
||||
}
|
||||
|
||||
type QueryResult struct {
|
||||
Columns []QueryColumn `json:"columns"`
|
||||
Rows []map[string]any `json:"rows"`
|
||||
RowCount int `json:"rowCount"`
|
||||
Truncated bool `json:"truncated"`
|
||||
DurationMs int64 `json:"durationMs"`
|
||||
}
|
||||
|
||||
func NewMySQLClient(ctx context.Context, cfg cfgpkg.MySQLConfig) (*MySQLClient, error) {
|
||||
dsn := mysql.Config{
|
||||
User: cfg.User,
|
||||
Passwd: cfg.Password,
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
Net: "tcp",
|
||||
DBName: cfg.Database,
|
||||
AllowNativePasswords: true,
|
||||
ParseTime: true,
|
||||
}
|
||||
|
||||
applyMySQLParams(&dsn, cfg.Params)
|
||||
|
||||
db, err := sql.Open("mysql", dsn.FormatDSN())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open mysql: %w", err)
|
||||
}
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
db.SetMaxOpenConns(5)
|
||||
db.SetMaxIdleConns(5)
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := db.PingContext(pingCtx); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("ping mysql: %w", err)
|
||||
}
|
||||
|
||||
return &MySQLClient{db: db}, nil
|
||||
}
|
||||
|
||||
func applyMySQLParams(cfg *mysql.Config, raw string) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return
|
||||
}
|
||||
values, err := url.ParseQuery(raw)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cfg.Params == nil {
|
||||
cfg.Params = make(map[string]string)
|
||||
}
|
||||
for key, valueList := range values {
|
||||
if len(valueList) == 0 {
|
||||
continue
|
||||
}
|
||||
value := valueList[0]
|
||||
switch strings.ToLower(key) {
|
||||
case "parsetime":
|
||||
if parsed, err := strconv.ParseBool(value); err == nil {
|
||||
cfg.ParseTime = parsed
|
||||
}
|
||||
case "loc":
|
||||
if loc, err := time.LoadLocation(value); err == nil {
|
||||
cfg.Loc = loc
|
||||
}
|
||||
case "collation":
|
||||
cfg.Collation = value
|
||||
default:
|
||||
cfg.Params[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MySQLClient) Close() error {
|
||||
if c == nil || c.db == nil {
|
||||
return nil
|
||||
}
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func (c *MySQLClient) QueryReadOnly(ctx context.Context, query string, args []any, maxRows int) (QueryResult, error) {
|
||||
start := time.Now()
|
||||
rows, err := c.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return QueryResult{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columnTypes, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return QueryResult{}, err
|
||||
}
|
||||
|
||||
columns := make([]QueryColumn, 0, len(columnTypes))
|
||||
columnNames := make([]string, 0, len(columnTypes))
|
||||
for _, ct := range columnTypes {
|
||||
var nullablePtr *bool
|
||||
if nullable, ok := ct.Nullable(); ok {
|
||||
n := nullable
|
||||
nullablePtr = &n
|
||||
}
|
||||
scanType := ""
|
||||
if st := ct.ScanType(); st != nil {
|
||||
scanType = st.String()
|
||||
}
|
||||
columns = append(columns, QueryColumn{
|
||||
Name: ct.Name(),
|
||||
DatabaseType: ct.DatabaseTypeName(),
|
||||
Nullable: nullablePtr,
|
||||
ScanType: scanType,
|
||||
})
|
||||
columnNames = append(columnNames, ct.Name())
|
||||
}
|
||||
|
||||
resultRows := make([]map[string]any, 0)
|
||||
truncated := false
|
||||
for rows.Next() {
|
||||
if len(resultRows) >= maxRows {
|
||||
truncated = true
|
||||
break
|
||||
}
|
||||
scanned, err := scanRow(rows, len(columnNames))
|
||||
if err != nil {
|
||||
return QueryResult{}, err
|
||||
}
|
||||
rowMap := make(map[string]any, len(columnNames))
|
||||
for i, name := range columnNames {
|
||||
rowMap[name] = normalizeValue(scanned[i])
|
||||
}
|
||||
resultRows = append(resultRows, rowMap)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return QueryResult{}, err
|
||||
}
|
||||
|
||||
return QueryResult{
|
||||
Columns: columns,
|
||||
Rows: resultRows,
|
||||
RowCount: len(resultRows),
|
||||
Truncated: truncated,
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func scanRow(rows *sql.Rows, size int) ([]any, error) {
|
||||
dest := make([]any, size)
|
||||
holders := make([]any, size)
|
||||
for i := range dest {
|
||||
holders[i] = &dest[i]
|
||||
}
|
||||
if err := rows.Scan(holders...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
func normalizeValue(v any) any {
|
||||
switch val := v.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case []byte:
|
||||
return string(val)
|
||||
case time.Time:
|
||||
return val.Format(time.RFC3339Nano)
|
||||
default:
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
}
|
||||
189
infra/smartflow-mcp-server/internal/store/redis.go
Normal file
189
infra/smartflow-mcp-server/internal/store/redis.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfgpkg "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
type RedisClient struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
type RedisGetResult struct {
|
||||
Exists bool `json:"exists"`
|
||||
Key string `json:"key"`
|
||||
Type string `json:"type"`
|
||||
Value any `json:"value,omitempty"`
|
||||
Truncated bool `json:"truncated"`
|
||||
DurationMs int64 `json:"durationMs"`
|
||||
}
|
||||
|
||||
type RedisScanResult struct {
|
||||
Pattern string `json:"pattern"`
|
||||
Keys []string `json:"keys"`
|
||||
Returned int `json:"returned"`
|
||||
NextCursor uint64 `json:"nextCursor"`
|
||||
Truncated bool `json:"truncated"`
|
||||
DurationMs int64 `json:"durationMs"`
|
||||
}
|
||||
|
||||
func NewRedisClient(ctx context.Context, cfg cfgpkg.RedisConfig) (*RedisClient, error) {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
})
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := client.Ping(pingCtx).Err(); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, fmt.Errorf("ping redis: %w", err)
|
||||
}
|
||||
return &RedisClient{client: client}, nil
|
||||
}
|
||||
|
||||
func (c *RedisClient) Close() error {
|
||||
if c == nil || c.client == nil {
|
||||
return nil
|
||||
}
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
func (c *RedisClient) GetWithType(ctx context.Context, key string, maxItems int, maxStringBytes int) (RedisGetResult, error) {
|
||||
start := time.Now()
|
||||
t, err := c.client.Type(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
if t == "none" {
|
||||
return RedisGetResult{
|
||||
Exists: false,
|
||||
Key: key,
|
||||
Type: "none",
|
||||
Truncated: false,
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
result := RedisGetResult{Exists: true, Key: key, Type: t}
|
||||
switch t {
|
||||
case "string":
|
||||
v, err := c.client.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
if len(v) > maxStringBytes {
|
||||
result.Value = v[:maxStringBytes]
|
||||
result.Truncated = true
|
||||
} else {
|
||||
result.Value = v
|
||||
}
|
||||
case "list":
|
||||
vals, err := c.client.LRange(ctx, key, 0, int64(maxItems-1)).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
result.Value = vals
|
||||
if length, _ := c.client.LLen(ctx, key).Result(); length > int64(maxItems) {
|
||||
result.Truncated = true
|
||||
}
|
||||
case "set":
|
||||
vals, err := c.client.SMembers(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
if len(vals) > maxItems {
|
||||
result.Value = vals[:maxItems]
|
||||
result.Truncated = true
|
||||
} else {
|
||||
result.Value = vals
|
||||
}
|
||||
case "zset":
|
||||
vals, err := c.client.ZRangeWithScores(ctx, key, 0, int64(maxItems-1)).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
resultRows := make([]map[string]any, 0, len(vals))
|
||||
for _, item := range vals {
|
||||
resultRows = append(resultRows, map[string]any{"member": item.Member, "score": item.Score})
|
||||
}
|
||||
result.Value = resultRows
|
||||
if length, _ := c.client.ZCard(ctx, key).Result(); length > int64(maxItems) {
|
||||
result.Truncated = true
|
||||
}
|
||||
case "hash":
|
||||
vals, err := c.client.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
if len(vals) <= maxItems {
|
||||
result.Value = vals
|
||||
} else {
|
||||
trimmed := make(map[string]string, maxItems)
|
||||
count := 0
|
||||
for k, v := range vals {
|
||||
trimmed[k] = v
|
||||
count++
|
||||
if count >= maxItems {
|
||||
break
|
||||
}
|
||||
}
|
||||
result.Value = trimmed
|
||||
result.Truncated = true
|
||||
}
|
||||
default:
|
||||
raw, err := c.client.Dump(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
result.Value = strings.ToUpper(fmt.Sprintf("UNSUPPORTED_TYPE_%s_DUMP_SIZE_%d", t, len(raw)))
|
||||
}
|
||||
result.DurationMs = time.Since(start).Milliseconds()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *RedisClient) ScanKeys(ctx context.Context, pattern string, count int64, maxKeys int) (RedisScanResult, error) {
|
||||
start := time.Now()
|
||||
if pattern == "" {
|
||||
pattern = "*"
|
||||
}
|
||||
if count <= 0 {
|
||||
count = 20
|
||||
}
|
||||
|
||||
keys := make([]string, 0, maxKeys)
|
||||
var cursor uint64
|
||||
truncated := false
|
||||
for {
|
||||
batch, nextCursor, err := c.client.Scan(ctx, cursor, pattern, count).Result()
|
||||
if err != nil {
|
||||
return RedisScanResult{}, err
|
||||
}
|
||||
for _, key := range batch {
|
||||
if len(keys) >= maxKeys {
|
||||
truncated = true
|
||||
break
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
cursor = nextCursor
|
||||
if truncated || cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return RedisScanResult{
|
||||
Pattern: pattern,
|
||||
Keys: keys,
|
||||
Returned: len(keys),
|
||||
NextCursor: cursor,
|
||||
Truncated: truncated,
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user