Version: 0.4.4.dev.260307

feat: 🚀 增强会话管理与缓存机制

* 会话 ID 空值兜底,若 `conversation_id` 为空时自动生成 UUID
* 在响应头写入 `X-Conversation-ID`,供前端使用,保持同一会话状态

perf:  会话状态缓存优化

* 当缓存未命中但 DB 已确认/创建会话后,调用 `SetConversationStatus` 回写 Redis
* 缓存写回失败时记录日志,不中断聊天主流程,确保业务流畅性

fix: 🐛 修复历史消息顺序问题与编译错误

* 修复历史消息顺序问题,保证返回的 N 条历史消息按时间正序喂给模型

  * 通过反转 `created_at desc` 查询结果的切片,确保模型输入顺序正确
* 修复 `fmt.Errorf` 参数不匹配问题,修正编译错误
* 整理 `agent-cache.go` 为标准 UTF-8 编码,避免 Go 编译报错 `invalid UTF-8 encoding`

feat: 🛠️ 独立构建 MCP 服务器

* 使用 `Codex` 构建独立于后端的 MCP 服务器,简化与 Codex 的协作
* 通过该服务器方便 Codex 直接测试和查看 Redis 与 MySQL 中的数据
This commit is contained in:
LoveLosita
2026-03-07 15:25:40 +08:00
parent 204e78d1fe
commit 26c350f378
27 changed files with 2274 additions and 17 deletions

View File

@@ -0,0 +1,60 @@
package audit
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
)
type Logger struct {
mu sync.Mutex
file *os.File
}
type Record struct {
Timestamp time.Time `json:"timestamp"`
Tool string `json:"tool"`
Caller string `json:"caller"`
Success bool `json:"success"`
DurationMs int64 `json:"duration_ms"`
Meta map[string]any `json:"meta,omitempty"`
Error string `json:"error,omitempty"`
}
func New(path string) (*Logger, error) {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, fmt.Errorf("create audit dir: %w", err)
}
f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
if err != nil {
return nil, fmt.Errorf("open audit log file: %w", err)
}
return &Logger{file: f}, nil
}
func (l *Logger) Close() error {
if l == nil || l.file == nil {
return nil
}
return l.file.Close()
}
func (l *Logger) Log(record Record) {
if l == nil || l.file == nil {
return
}
if record.Timestamp.IsZero() {
record.Timestamp = time.Now()
}
body, err := json.Marshal(record)
if err != nil {
return
}
l.mu.Lock()
defer l.mu.Unlock()
_, _ = l.file.Write(append(body, '\n'))
}

View File

@@ -0,0 +1,163 @@
package config
import (
"fmt"
"os"
"strconv"
"strings"
"time"
)
type Config struct {
ServerName string
ServerVersion string
ProtocolVersion string
DefaultCaller string
ToolTimeout time.Duration
RateLimitRPS float64
RateLimitBurst float64
MaxResultRows int
AuditLogPath string
EnforceWhitelist bool
RedisScanMaxKeys int
RedisScanMaxCount int
RedisValueMaxItems int
RedisMaxStringBytes int
MySQL MySQLConfig
Redis RedisConfig
}
type MySQLConfig struct {
Host string
Port int
User string
Password string
Database string
Params string
AllowedDatabases []string
AllowedTables []string
}
type RedisConfig struct {
Addr string
Password string
DB int
}
func LoadFromEnv() (Config, error) {
cfg := Config{
ServerName: getEnv("MCP_SERVER_NAME", "smartflow-mcp-server"),
ServerVersion: getEnv("MCP_SERVER_VERSION", "0.1.0"),
ProtocolVersion: getEnv("MCP_PROTOCOL_VERSION", "2024-11-05"),
DefaultCaller: getEnv("MCP_DEFAULT_CALLER", "unknown"),
ToolTimeout: getEnvDurationMS("MCP_TOOL_TIMEOUT_MS", 5000),
RateLimitRPS: getEnvFloat("MCP_RATE_LIMIT_RPS", 5),
RateLimitBurst: getEnvFloat("MCP_RATE_LIMIT_BURST", 10),
MaxResultRows: getEnvInt("MCP_MAX_RESULT_ROWS", 500),
AuditLogPath: getEnv("MCP_AUDIT_LOG_PATH", "logs/audit.log"),
EnforceWhitelist: getEnvBool("MCP_ENFORCE_WHITELIST", false),
RedisScanMaxKeys: getEnvInt("MCP_REDIS_SCAN_MAX_KEYS", 200),
RedisScanMaxCount: getEnvInt("MCP_REDIS_SCAN_MAX_COUNT", 200),
RedisValueMaxItems: getEnvInt("MCP_REDIS_VALUE_MAX_ITEMS", 100),
RedisMaxStringBytes: getEnvInt("MCP_REDIS_MAX_STRING_BYTES", 4096),
MySQL: MySQLConfig{
Host: getEnv("MYSQL_HOST", "127.0.0.1"),
Port: getEnvInt("MYSQL_PORT", 3306),
User: getEnv("MYSQL_USER", ""),
Password: getEnv("MYSQL_PASSWORD", ""),
Database: getEnv("MYSQL_DATABASE", ""),
Params: getEnv("MYSQL_PARAMS", "charset=utf8mb4&parseTime=true&loc=Local"),
AllowedDatabases: splitCommaList(getEnv("MYSQL_ALLOWED_DATABASES", "")),
AllowedTables: splitCommaList(getEnv("MYSQL_ALLOWED_TABLES", "")),
},
Redis: RedisConfig{
Addr: getEnv("REDIS_ADDR", "127.0.0.1:6379"),
Password: getEnv("REDIS_PASSWORD", ""),
DB: getEnvInt("REDIS_DB", 0),
},
}
if cfg.MySQL.User == "" || cfg.MySQL.Database == "" {
return Config{}, fmt.Errorf("MYSQL_USER and MYSQL_DATABASE are required")
}
if cfg.Redis.Addr == "" {
return Config{}, fmt.Errorf("REDIS_ADDR is required")
}
if cfg.MaxResultRows <= 0 {
return Config{}, fmt.Errorf("MCP_MAX_RESULT_ROWS must be > 0")
}
if cfg.RedisScanMaxKeys <= 0 {
return Config{}, fmt.Errorf("MCP_REDIS_SCAN_MAX_KEYS must be > 0")
}
if cfg.RedisScanMaxCount <= 0 {
return Config{}, fmt.Errorf("MCP_REDIS_SCAN_MAX_COUNT must be > 0")
}
if cfg.ToolTimeout <= 0 {
return Config{}, fmt.Errorf("MCP_TOOL_TIMEOUT_MS must be > 0")
}
if cfg.RateLimitRPS <= 0 || cfg.RateLimitBurst <= 0 {
return Config{}, fmt.Errorf("MCP_RATE_LIMIT_RPS and MCP_RATE_LIMIT_BURST must be > 0")
}
return cfg, nil
}
func getEnv(key string, defaultValue string) string {
if v, ok := os.LookupEnv(key); ok {
return strings.TrimSpace(v)
}
return defaultValue
}
func getEnvInt(key string, defaultValue int) int {
v := getEnv(key, "")
if v == "" {
return defaultValue
}
n, err := strconv.Atoi(v)
if err != nil {
return defaultValue
}
return n
}
func getEnvFloat(key string, defaultValue float64) float64 {
v := getEnv(key, "")
if v == "" {
return defaultValue
}
n, err := strconv.ParseFloat(v, 64)
if err != nil {
return defaultValue
}
return n
}
func getEnvBool(key string, defaultValue bool) bool {
v := strings.ToLower(getEnv(key, ""))
if v == "" {
return defaultValue
}
return v == "1" || v == "true" || v == "yes" || v == "on"
}
func getEnvDurationMS(key string, defaultValueMs int) time.Duration {
ms := getEnvInt(key, defaultValueMs)
return time.Duration(ms) * time.Millisecond
}
func splitCommaList(raw string) []string {
if strings.TrimSpace(raw) == "" {
return nil
}
parts := strings.Split(raw, ",")
out := make([]string, 0, len(parts))
for _, p := range parts {
trimmed := strings.TrimSpace(p)
if trimmed != "" {
out = append(out, strings.ToLower(trimmed))
}
}
return out
}

View File

@@ -0,0 +1,44 @@
package envutil
import (
"bufio"
"fmt"
"os"
"strings"
)
func LoadDotEnv(path string) error {
f, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("open .env: %w", err)
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
i := strings.Index(line, "=")
if i <= 0 {
continue
}
key := strings.TrimSpace(line[:i])
value := strings.TrimSpace(line[i+1:])
value = strings.Trim(value, "\"")
if key == "" {
continue
}
if _, exists := os.LookupEnv(key); !exists {
_ = os.Setenv(key, value)
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("scan .env: %w", err)
}
return nil
}

View File

@@ -0,0 +1,394 @@
package mcp
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"strconv"
"strings"
"sync"
"time"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/audit"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools"
)
const (
jsonRPCVersion = "2.0"
errCodeParseError = -32700
errCodeInvalidRequest = -32600
errCodeMethodNotFound = -32601
errCodeInvalidParams = -32602
errCodeInternalError = -32603
)
type Server struct {
reader *bufio.Reader
writer io.Writer
writeMu sync.Mutex
registry *tools.Registry
auditLogger *audit.Logger
limiter *ratelimit.Limiter
serverName string
serverVersion string
protocolVersion string
defaultCaller string
toolTimeout time.Duration
}
type request struct {
JSONRPC string `json:"jsonrpc"`
ID json.RawMessage `json:"id,omitempty"`
Method string `json:"method"`
Params json.RawMessage `json:"params,omitempty"`
}
type response struct {
JSONRPC string `json:"jsonrpc"`
ID json.RawMessage `json:"id,omitempty"`
Result any `json:"result,omitempty"`
Error *respError `json:"error,omitempty"`
}
type respError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type initializeParams struct {
ProtocolVersion string `json:"protocolVersion"`
}
type toolCallParams struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
Meta map[string]any `json:"_meta,omitempty"`
}
func NewServer(
in io.Reader,
out io.Writer,
registry *tools.Registry,
auditLogger *audit.Logger,
limiter *ratelimit.Limiter,
serverName string,
serverVersion string,
protocolVersion string,
defaultCaller string,
toolTimeout time.Duration,
) *Server {
return &Server{
reader: bufio.NewReader(in),
writer: out,
registry: registry,
auditLogger: auditLogger,
limiter: limiter,
serverName: serverName,
serverVersion: serverVersion,
protocolVersion: protocolVersion,
defaultCaller: defaultCaller,
toolTimeout: toolTimeout,
}
}
func (s *Server) Serve(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
default:
}
body, err := readMessage(s.reader)
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
log.Printf("read message failed: %v", err)
continue
}
var req request
if err := json.Unmarshal(body, &req); err != nil {
_ = s.writeError(nil, errCodeParseError, "invalid json")
continue
}
if req.Method == "" || req.JSONRPC != jsonRPCVersion {
_ = s.writeError(req.ID, errCodeInvalidRequest, "invalid json-rpc request")
continue
}
if err := s.handleRequest(ctx, req); err != nil {
log.Printf("handle request failed: %v", err)
}
}
}
func (s *Server) handleRequest(ctx context.Context, req request) error {
switch req.Method {
case "initialize":
return s.handleInitialize(req)
case "notifications/initialized":
return nil
case "tools/list":
if len(req.ID) == 0 {
return nil
}
return s.writeResult(req.ID, map[string]any{"tools": s.registry.List()})
case "tools/call":
if len(req.ID) == 0 {
return nil
}
return s.handleToolCall(ctx, req)
case "ping":
if len(req.ID) == 0 {
return nil
}
return s.writeResult(req.ID, map[string]any{})
default:
if len(req.ID) == 0 {
return nil
}
return s.writeError(req.ID, errCodeMethodNotFound, "method not found")
}
}
func (s *Server) handleInitialize(req request) error {
if len(req.ID) == 0 {
return nil
}
var params initializeParams
if len(req.Params) > 0 {
_ = json.Unmarshal(req.Params, &params)
}
_ = params
return s.writeResult(req.ID, map[string]any{
"protocolVersion": s.protocolVersion,
"capabilities": map[string]any{
"tools": map[string]any{"listChanged": false},
},
"serverInfo": map[string]any{
"name": s.serverName,
"version": s.serverVersion,
},
})
}
func (s *Server) handleToolCall(ctx context.Context, req request) error {
var params toolCallParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return s.writeError(req.ID, errCodeInvalidParams, "invalid tool call params")
}
if params.Name == "" {
return s.writeError(req.ID, errCodeInvalidParams, "tool name is required")
}
if params.Arguments == nil {
params.Arguments = map[string]any{}
}
caller := extractCaller(params, s.defaultCaller)
rateKey := fmt.Sprintf("%s:%s", caller, params.Name)
if !s.limiter.Allow(rateKey) {
result := buildToolErrorResult("rate limit exceeded")
s.auditLogger.Log(audit.Record{
Tool: params.Name,
Caller: caller,
Success: false,
DurationMs: 0,
Meta: sanitizeAuditInput(params.Name, params.Arguments),
Error: "rate limit exceeded",
})
return s.writeResult(req.ID, result)
}
tool, ok := s.registry.Find(params.Name)
if !ok {
result := buildToolErrorResult("tool not found")
s.auditLogger.Log(audit.Record{
Tool: params.Name,
Caller: caller,
Success: false,
DurationMs: 0,
Meta: sanitizeAuditInput(params.Name, params.Arguments),
Error: "tool not found",
})
return s.writeResult(req.ID, result)
}
start := time.Now()
toolCtx, cancel := context.WithTimeout(ctx, s.toolTimeout)
defer cancel()
output, err := tool.Execute(toolCtx, params.Arguments)
duration := time.Since(start).Milliseconds()
if err != nil {
errMsg := sanitizeError(err)
s.auditLogger.Log(audit.Record{
Tool: params.Name,
Caller: caller,
Success: false,
DurationMs: duration,
Meta: sanitizeAuditInput(params.Name, params.Arguments),
Error: errMsg,
})
return s.writeResult(req.ID, buildToolErrorResult(errMsg))
}
s.auditLogger.Log(audit.Record{
Tool: params.Name,
Caller: caller,
Success: true,
DurationMs: duration,
Meta: sanitizeAuditInput(params.Name, params.Arguments),
})
return s.writeResult(req.ID, buildToolSuccessResult(output))
}
func buildToolSuccessResult(output map[string]any) map[string]any {
asJSON, _ := json.Marshal(output)
return map[string]any{
"content": []map[string]any{
{"type": "text", "text": string(asJSON)},
},
"structuredContent": output,
}
}
func buildToolErrorResult(message string) map[string]any {
return map[string]any{
"content": []map[string]any{
{"type": "text", "text": message},
},
"structuredContent": map[string]any{"error": message},
"isError": true,
}
}
func extractCaller(params toolCallParams, fallback string) string {
if params.Meta != nil {
if v, ok := params.Meta["caller"]; ok {
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
return strings.TrimSpace(s)
}
}
}
if v, ok := params.Arguments["caller"]; ok {
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
return strings.TrimSpace(s)
}
}
return fallback
}
func sanitizeAuditInput(toolName string, args map[string]any) map[string]any {
meta := map[string]any{}
switch toolName {
case "mysql_query_readonly":
sql, _ := args["sql"].(string)
meta["sql"] = security.RedactSQL(sql)
if params, ok := args["params"].([]any); ok {
meta["paramCount"] = len(params)
}
case "redis_get":
if key, ok := args["key"].(string); ok {
meta["key"] = security.RedactKey(key)
}
case "redis_scan":
if pattern, ok := args["pattern"].(string); ok {
if len(pattern) > 40 {
meta["pattern"] = pattern[:40]
} else {
meta["pattern"] = pattern
}
}
if count, ok := args["count"]; ok {
meta["count"] = count
}
default:
meta["args"] = "masked"
}
return meta
}
func sanitizeError(err error) string {
msg := err.Error()
msg = strings.ReplaceAll(msg, "\n", " ")
msg = strings.ReplaceAll(msg, "\r", " ")
if len(msg) > 300 {
return msg[:300]
}
return msg
}
func (s *Server) writeResult(id json.RawMessage, result any) error {
resp := response{JSONRPC: jsonRPCVersion, ID: id, Result: result}
return s.writeMessage(resp)
}
func (s *Server) writeError(id json.RawMessage, code int, message string) error {
resp := response{
JSONRPC: jsonRPCVersion,
ID: id,
Error: &respError{Code: code, Message: message},
}
return s.writeMessage(resp)
}
func (s *Server) writeMessage(payload any) error {
body, err := json.Marshal(payload)
if err != nil {
return err
}
header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(body))
s.writeMu.Lock()
defer s.writeMu.Unlock()
if _, err := s.writer.Write([]byte(header)); err != nil {
return err
}
_, err = s.writer.Write(body)
return err
}
func readMessage(reader *bufio.Reader) ([]byte, error) {
contentLength := -1
for {
line, err := reader.ReadString('\n')
if err != nil {
return nil, err
}
line = strings.TrimRight(line, "\r\n")
if line == "" {
break
}
parts := strings.SplitN(line, ":", 2)
if len(parts) != 2 {
continue
}
headerName := strings.TrimSpace(parts[0])
headerValue := strings.TrimSpace(parts[1])
if strings.EqualFold(headerName, "Content-Length") {
n, err := strconv.Atoi(headerValue)
if err != nil || n < 0 {
return nil, fmt.Errorf("invalid Content-Length")
}
contentLength = n
}
}
if contentLength < 0 {
return nil, fmt.Errorf("missing Content-Length")
}
body := make([]byte, contentLength)
if _, err := io.ReadFull(reader, body); err != nil {
return nil, err
}
return body, nil
}

View File

@@ -0,0 +1,56 @@
package mcp
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"testing"
"time"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools"
)
type dummyTool struct{}
func (d dummyTool) Name() string { return "dummy" }
func (d dummyTool) Description() string { return "dummy" }
func (d dummyTool) InputSchema() map[string]any {
return map[string]any{"type": "object"}
}
func (d dummyTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
return map[string]any{"ok": true}, nil
}
func TestHandleInitialize(t *testing.T) {
registry, err := tools.NewRegistry(dummyTool{})
if err != nil {
t.Fatal(err)
}
out := bytes.NewBuffer(nil)
s := NewServer(bytes.NewBuffer(nil), out, registry, nil, ratelimit.New(10, 10), "name", "ver", "2024-11-05", "unknown", time.Second)
req := request{JSONRPC: jsonRPCVersion, ID: json.RawMessage("1"), Method: "initialize", Params: json.RawMessage(`{"protocolVersion":"2024-11-05"}`)}
if err := s.handleRequest(context.Background(), req); err != nil {
t.Fatalf("handle initialize error: %v", err)
}
if !bytes.Contains(out.Bytes(), []byte("protocolVersion")) {
t.Fatalf("response missing protocolVersion: %s", out.String())
}
}
func TestReadMessage(t *testing.T) {
body := []byte(`{"jsonrpc":"2.0","id":1,"method":"ping"}`)
frame := fmt.Sprintf("Content-Length: %d\r\n\r\n%s", len(body), string(body))
input := bufio.NewReader(bytes.NewBufferString(frame))
msg, err := readMessage(input)
if err != nil {
t.Fatalf("read message failed: %v", err)
}
if !bytes.Equal(msg, body) {
t.Fatalf("unexpected body: %s", string(msg))
}
}

View File

@@ -0,0 +1,51 @@
package ratelimit
import (
"sync"
"time"
)
type Limiter struct {
mu sync.Mutex
rate float64
burst float64
buckets map[string]*bucket
}
type bucket struct {
tokens float64
last time.Time
}
func New(rate, burst float64) *Limiter {
return &Limiter{
rate: rate,
burst: burst,
buckets: make(map[string]*bucket),
}
}
func (l *Limiter) Allow(key string) bool {
now := time.Now()
l.mu.Lock()
defer l.mu.Unlock()
b, ok := l.buckets[key]
if !ok {
l.buckets[key] = &bucket{tokens: l.burst - 1, last: now}
return true
}
elapsed := now.Sub(b.last).Seconds()
b.tokens += elapsed * l.rate
if b.tokens > l.burst {
b.tokens = l.burst
}
b.last = now
if b.tokens < 1 {
return false
}
b.tokens -= 1
return true
}

View File

@@ -0,0 +1,26 @@
package ratelimit
import (
"testing"
"time"
)
func TestLimiter(t *testing.T) {
l := New(2, 2)
key := "user:tool"
if !l.Allow(key) {
t.Fatal("first request should pass")
}
if !l.Allow(key) {
t.Fatal("second request should pass")
}
if l.Allow(key) {
t.Fatal("third request should be rate limited")
}
time.Sleep(600 * time.Millisecond)
if !l.Allow(key) {
t.Fatal("request should pass after token refill")
}
}

View File

@@ -0,0 +1,27 @@
package security
import (
"regexp"
"strings"
)
var (
singleQuotedString = regexp.MustCompile(`'([^'\\]|\\.)*'`)
doubleQuotedString = regexp.MustCompile(`"([^"\\]|\\.)*"`)
numericLiteral = regexp.MustCompile(`\b\d+\b`)
)
func RedactSQL(sql string) string {
masked := singleQuotedString.ReplaceAllString(sql, "'***'")
masked = doubleQuotedString.ReplaceAllString(masked, `"***"`)
masked = numericLiteral.ReplaceAllString(masked, "?")
return strings.TrimSpace(masked)
}
func RedactKey(key string) string {
key = strings.TrimSpace(key)
if len(key) <= 4 {
return "****"
}
return key[:2] + "***" + key[len(key)-2:]
}

View File

@@ -0,0 +1,193 @@
package security
import (
"fmt"
"regexp"
"strings"
)
var (
commentPattern = regexp.MustCompile(`(?s)/\*.*?\*/|--|#`)
forbiddenWords = regexp.MustCompile(`(?i)\b(INSERT|UPDATE|DELETE|ALTER|DROP|TRUNCATE|CREATE|REPLACE|RENAME|GRANT|REVOKE|MERGE|CALL|EXEC|LOCK|UNLOCK|LOAD|OUTFILE|INFILE|HANDLER|SET|USE)\b`)
fromJoinRef = regexp.MustCompile(`(?i)\b(?:FROM|JOIN)\s+([` + "`" + `"\w\.]+)`)
describeRef = regexp.MustCompile(`(?i)\b(?:DESCRIBE|DESC)\s+([` + "`" + `"\w\.]+)`)
showFromRef = regexp.MustCompile(`(?i)\bSHOW\b[\w\s]*\b(?:FROM|IN)\s+([` + "`" + `"\w\.]+)`)
)
type SQLValidator struct {
defaultDatabase string
enforceWhitelist bool
allowedDatabases map[string]struct{}
allowedTables map[string]struct{}
allowedTableOnly map[string]struct{}
}
type sqlObjectRef struct {
database string
table string
}
func NewSQLValidator(defaultDatabase string, enforceWhitelist bool, allowedDatabases []string, allowedTables []string) *SQLValidator {
v := &SQLValidator{
defaultDatabase: strings.ToLower(strings.TrimSpace(defaultDatabase)),
enforceWhitelist: enforceWhitelist,
allowedDatabases: make(map[string]struct{}),
allowedTables: make(map[string]struct{}),
allowedTableOnly: make(map[string]struct{}),
}
for _, db := range allowedDatabases {
db = strings.ToLower(strings.TrimSpace(db))
if db != "" {
v.allowedDatabases[db] = struct{}{}
}
}
for _, tbl := range allowedTables {
tbl = strings.ToLower(strings.TrimSpace(strings.Trim(tbl, "`\"")))
if tbl == "" {
continue
}
v.allowedTables[tbl] = struct{}{}
if dot := strings.LastIndex(tbl, "."); dot >= 0 && dot < len(tbl)-1 {
v.allowedTableOnly[tbl[dot+1:]] = struct{}{}
} else {
v.allowedTableOnly[tbl] = struct{}{}
}
}
return v
}
func (v *SQLValidator) ValidateReadOnlySQL(sql string) error {
trimmed := strings.TrimSpace(sql)
if trimmed == "" {
return fmt.Errorf("sql is required")
}
if strings.Contains(trimmed, ";") {
return fmt.Errorf("semicolon is not allowed")
}
if commentPattern.MatchString(trimmed) {
return fmt.Errorf("sql comments are not allowed")
}
first := strings.ToUpper(firstToken(trimmed))
switch first {
case "SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN":
default:
return fmt.Errorf("only SELECT/SHOW/DESCRIBE/EXPLAIN are allowed")
}
withoutLiterals := removeStringLiterals(trimmed)
if forbiddenWords.MatchString(strings.ToUpper(withoutLiterals)) {
return fmt.Errorf("dangerous sql keyword detected")
}
refs := extractObjectRefs(withoutLiterals)
if err := v.validateWhitelist(refs); err != nil {
return err
}
return nil
}
func (v *SQLValidator) validateWhitelist(refs []sqlObjectRef) error {
hasDBAllowlist := len(v.allowedDatabases) > 0
hasTableAllowlist := len(v.allowedTables) > 0
if !v.enforceWhitelist && !hasDBAllowlist && !hasTableAllowlist {
return nil
}
if len(refs) == 0 {
if v.enforceWhitelist {
return fmt.Errorf("sql does not contain explicit table reference under whitelist mode")
}
return nil
}
for _, ref := range refs {
db := ref.database
tbl := ref.table
if db == "" {
db = v.defaultDatabase
}
if hasDBAllowlist {
if db == "" {
return fmt.Errorf("database is not explicit and no default database set")
}
if _, ok := v.allowedDatabases[db]; !ok {
return fmt.Errorf("database %s is not in whitelist", db)
}
}
if hasTableAllowlist {
full := tbl
if db != "" {
full = db + "." + tbl
}
if _, ok := v.allowedTables[full]; ok {
continue
}
if _, ok := v.allowedTableOnly[tbl]; ok {
continue
}
return fmt.Errorf("table %s is not in whitelist", full)
}
}
return nil
}
func extractObjectRefs(sql string) []sqlObjectRef {
refs := make([]sqlObjectRef, 0)
seen := make(map[string]struct{})
for _, re := range []*regexp.Regexp{fromJoinRef, describeRef, showFromRef} {
matches := re.FindAllStringSubmatch(sql, -1)
for _, m := range matches {
if len(m) < 2 {
continue
}
ref := normalizeRef(m[1])
if ref.table == "" {
continue
}
key := ref.database + "." + ref.table
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
refs = append(refs, ref)
}
}
return refs
}
func normalizeRef(raw string) sqlObjectRef {
clean := strings.ToLower(strings.TrimSpace(raw))
clean = strings.Trim(clean, "`\"'(),")
clean = strings.TrimSpace(clean)
if clean == "" {
return sqlObjectRef{}
}
parts := strings.Split(clean, ".")
if len(parts) == 1 {
return sqlObjectRef{table: parts[0]}
}
return sqlObjectRef{database: parts[0], table: parts[len(parts)-1]}
}
func firstToken(sql string) string {
for i, r := range sql {
if r == ' ' || r == '\n' || r == '\t' || r == '\r' {
if i == 0 {
continue
}
return sql[:i]
}
}
return sql
}
func removeStringLiterals(sql string) string {
masked := singleQuotedString.ReplaceAllString(sql, "''")
masked = doubleQuotedString.ReplaceAllString(masked, `""`)
return masked
}

View File

@@ -0,0 +1,44 @@
package security
import "testing"
func TestValidateReadOnlySQL(t *testing.T) {
validator := NewSQLValidator("smartflow", true, []string{"smartflow"}, []string{"smartflow.users", "smartflow.tasks"})
tests := []struct {
name string
sql string
wantErr bool
}{
{name: "allow select", sql: "SELECT id, name FROM users WHERE id = 1", wantErr: false},
{name: "allow explain", sql: "EXPLAIN SELECT * FROM tasks", wantErr: false},
{name: "reject insert", sql: "INSERT INTO users(name) VALUES('x')", wantErr: true},
{name: "reject multi statement", sql: "SELECT * FROM users; SELECT * FROM tasks", wantErr: true},
{name: "reject comment", sql: "SELECT * FROM users -- bypass", wantErr: true},
{name: "reject not whitelisted table", sql: "SELECT * FROM orders", wantErr: true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := validator.ValidateReadOnlySQL(tc.sql)
if tc.wantErr && err == nil {
t.Fatalf("expected error, got nil")
}
if !tc.wantErr && err != nil {
t.Fatalf("expected no error, got %v", err)
}
})
}
}
func TestRedact(t *testing.T) {
masked := RedactSQL("SELECT * FROM users WHERE token='abc123' AND id=42")
if masked == "" || masked == "SELECT * FROM users WHERE token='abc123' AND id=42" {
t.Fatalf("redaction not applied: %s", masked)
}
key := RedactKey("very-sensitive-key")
if key == "very-sensitive-key" {
t.Fatalf("key not redacted")
}
}

View File

@@ -0,0 +1,198 @@
package store
import (
"context"
"database/sql"
"fmt"
"net/url"
"reflect"
"strconv"
"strings"
"time"
cfgpkg "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
"github.com/go-sql-driver/mysql"
)
type MySQLClient struct {
db *sql.DB
}
type QueryColumn struct {
Name string `json:"name"`
DatabaseType string `json:"databaseType"`
Nullable *bool `json:"nullable,omitempty"`
ScanType string `json:"scanType,omitempty"`
}
type QueryResult struct {
Columns []QueryColumn `json:"columns"`
Rows []map[string]any `json:"rows"`
RowCount int `json:"rowCount"`
Truncated bool `json:"truncated"`
DurationMs int64 `json:"durationMs"`
}
func NewMySQLClient(ctx context.Context, cfg cfgpkg.MySQLConfig) (*MySQLClient, error) {
dsn := mysql.Config{
User: cfg.User,
Passwd: cfg.Password,
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Net: "tcp",
DBName: cfg.Database,
AllowNativePasswords: true,
ParseTime: true,
}
applyMySQLParams(&dsn, cfg.Params)
db, err := sql.Open("mysql", dsn.FormatDSN())
if err != nil {
return nil, fmt.Errorf("open mysql: %w", err)
}
db.SetConnMaxLifetime(5 * time.Minute)
db.SetMaxOpenConns(5)
db.SetMaxIdleConns(5)
pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
if err := db.PingContext(pingCtx); err != nil {
_ = db.Close()
return nil, fmt.Errorf("ping mysql: %w", err)
}
return &MySQLClient{db: db}, nil
}
func applyMySQLParams(cfg *mysql.Config, raw string) {
if strings.TrimSpace(raw) == "" {
return
}
values, err := url.ParseQuery(raw)
if err != nil {
return
}
if cfg.Params == nil {
cfg.Params = make(map[string]string)
}
for key, valueList := range values {
if len(valueList) == 0 {
continue
}
value := valueList[0]
switch strings.ToLower(key) {
case "parsetime":
if parsed, err := strconv.ParseBool(value); err == nil {
cfg.ParseTime = parsed
}
case "loc":
if loc, err := time.LoadLocation(value); err == nil {
cfg.Loc = loc
}
case "collation":
cfg.Collation = value
default:
cfg.Params[key] = value
}
}
}
func (c *MySQLClient) Close() error {
if c == nil || c.db == nil {
return nil
}
return c.db.Close()
}
func (c *MySQLClient) QueryReadOnly(ctx context.Context, query string, args []any, maxRows int) (QueryResult, error) {
start := time.Now()
rows, err := c.db.QueryContext(ctx, query, args...)
if err != nil {
return QueryResult{}, err
}
defer rows.Close()
columnTypes, err := rows.ColumnTypes()
if err != nil {
return QueryResult{}, err
}
columns := make([]QueryColumn, 0, len(columnTypes))
columnNames := make([]string, 0, len(columnTypes))
for _, ct := range columnTypes {
var nullablePtr *bool
if nullable, ok := ct.Nullable(); ok {
n := nullable
nullablePtr = &n
}
scanType := ""
if st := ct.ScanType(); st != nil {
scanType = st.String()
}
columns = append(columns, QueryColumn{
Name: ct.Name(),
DatabaseType: ct.DatabaseTypeName(),
Nullable: nullablePtr,
ScanType: scanType,
})
columnNames = append(columnNames, ct.Name())
}
resultRows := make([]map[string]any, 0)
truncated := false
for rows.Next() {
if len(resultRows) >= maxRows {
truncated = true
break
}
scanned, err := scanRow(rows, len(columnNames))
if err != nil {
return QueryResult{}, err
}
rowMap := make(map[string]any, len(columnNames))
for i, name := range columnNames {
rowMap[name] = normalizeValue(scanned[i])
}
resultRows = append(resultRows, rowMap)
}
if err := rows.Err(); err != nil {
return QueryResult{}, err
}
return QueryResult{
Columns: columns,
Rows: resultRows,
RowCount: len(resultRows),
Truncated: truncated,
DurationMs: time.Since(start).Milliseconds(),
}, nil
}
func scanRow(rows *sql.Rows, size int) ([]any, error) {
dest := make([]any, size)
holders := make([]any, size)
for i := range dest {
holders[i] = &dest[i]
}
if err := rows.Scan(holders...); err != nil {
return nil, err
}
return dest, nil
}
func normalizeValue(v any) any {
switch val := v.(type) {
case nil:
return nil
case []byte:
return string(val)
case time.Time:
return val.Format(time.RFC3339Nano)
default:
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Ptr && rv.IsNil() {
return nil
}
return v
}
}

View File

@@ -0,0 +1,189 @@
package store
import (
"context"
"fmt"
"strings"
"time"
cfgpkg "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
"github.com/go-redis/redis/v8"
)
type RedisClient struct {
client *redis.Client
}
type RedisGetResult struct {
Exists bool `json:"exists"`
Key string `json:"key"`
Type string `json:"type"`
Value any `json:"value,omitempty"`
Truncated bool `json:"truncated"`
DurationMs int64 `json:"durationMs"`
}
type RedisScanResult struct {
Pattern string `json:"pattern"`
Keys []string `json:"keys"`
Returned int `json:"returned"`
NextCursor uint64 `json:"nextCursor"`
Truncated bool `json:"truncated"`
DurationMs int64 `json:"durationMs"`
}
func NewRedisClient(ctx context.Context, cfg cfgpkg.RedisConfig) (*RedisClient, error) {
client := redis.NewClient(&redis.Options{
Addr: cfg.Addr,
Password: cfg.Password,
DB: cfg.DB,
})
pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
if err := client.Ping(pingCtx).Err(); err != nil {
_ = client.Close()
return nil, fmt.Errorf("ping redis: %w", err)
}
return &RedisClient{client: client}, nil
}
func (c *RedisClient) Close() error {
if c == nil || c.client == nil {
return nil
}
return c.client.Close()
}
func (c *RedisClient) GetWithType(ctx context.Context, key string, maxItems int, maxStringBytes int) (RedisGetResult, error) {
start := time.Now()
t, err := c.client.Type(ctx, key).Result()
if err != nil {
return RedisGetResult{}, err
}
if t == "none" {
return RedisGetResult{
Exists: false,
Key: key,
Type: "none",
Truncated: false,
DurationMs: time.Since(start).Milliseconds(),
}, nil
}
result := RedisGetResult{Exists: true, Key: key, Type: t}
switch t {
case "string":
v, err := c.client.Get(ctx, key).Result()
if err != nil {
return RedisGetResult{}, err
}
if len(v) > maxStringBytes {
result.Value = v[:maxStringBytes]
result.Truncated = true
} else {
result.Value = v
}
case "list":
vals, err := c.client.LRange(ctx, key, 0, int64(maxItems-1)).Result()
if err != nil {
return RedisGetResult{}, err
}
result.Value = vals
if length, _ := c.client.LLen(ctx, key).Result(); length > int64(maxItems) {
result.Truncated = true
}
case "set":
vals, err := c.client.SMembers(ctx, key).Result()
if err != nil {
return RedisGetResult{}, err
}
if len(vals) > maxItems {
result.Value = vals[:maxItems]
result.Truncated = true
} else {
result.Value = vals
}
case "zset":
vals, err := c.client.ZRangeWithScores(ctx, key, 0, int64(maxItems-1)).Result()
if err != nil {
return RedisGetResult{}, err
}
resultRows := make([]map[string]any, 0, len(vals))
for _, item := range vals {
resultRows = append(resultRows, map[string]any{"member": item.Member, "score": item.Score})
}
result.Value = resultRows
if length, _ := c.client.ZCard(ctx, key).Result(); length > int64(maxItems) {
result.Truncated = true
}
case "hash":
vals, err := c.client.HGetAll(ctx, key).Result()
if err != nil {
return RedisGetResult{}, err
}
if len(vals) <= maxItems {
result.Value = vals
} else {
trimmed := make(map[string]string, maxItems)
count := 0
for k, v := range vals {
trimmed[k] = v
count++
if count >= maxItems {
break
}
}
result.Value = trimmed
result.Truncated = true
}
default:
raw, err := c.client.Dump(ctx, key).Result()
if err != nil {
return RedisGetResult{}, err
}
result.Value = strings.ToUpper(fmt.Sprintf("UNSUPPORTED_TYPE_%s_DUMP_SIZE_%d", t, len(raw)))
}
result.DurationMs = time.Since(start).Milliseconds()
return result, nil
}
func (c *RedisClient) ScanKeys(ctx context.Context, pattern string, count int64, maxKeys int) (RedisScanResult, error) {
start := time.Now()
if pattern == "" {
pattern = "*"
}
if count <= 0 {
count = 20
}
keys := make([]string, 0, maxKeys)
var cursor uint64
truncated := false
for {
batch, nextCursor, err := c.client.Scan(ctx, cursor, pattern, count).Result()
if err != nil {
return RedisScanResult{}, err
}
for _, key := range batch {
if len(keys) >= maxKeys {
truncated = true
break
}
keys = append(keys, key)
}
cursor = nextCursor
if truncated || cursor == 0 {
break
}
}
return RedisScanResult{
Pattern: pattern,
Keys: keys,
Returned: len(keys),
NextCursor: cursor,
Truncated: truncated,
DurationMs: time.Since(start).Milliseconds(),
}, nil
}

View File

@@ -0,0 +1,95 @@
package tools
import (
"context"
"os"
"strconv"
"testing"
"time"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store"
)
func TestIntegrationMySQLReadOnlyTool(t *testing.T) {
if os.Getenv("MCP_IT_RUN") != "1" {
t.Skip("set MCP_IT_RUN=1 to run integration tests")
}
port := 3306
if p := os.Getenv("MYSQL_PORT"); p != "" {
if n, err := strconv.Atoi(p); err == nil {
port = n
}
}
mysqlCfg := config.MySQLConfig{
Host: os.Getenv("MYSQL_HOST"),
Port: port,
User: os.Getenv("MYSQL_USER"),
Password: os.Getenv("MYSQL_PASSWORD"),
Database: os.Getenv("MYSQL_DATABASE"),
Params: "charset=utf8mb4&parseTime=true&loc=Local",
}
if mysqlCfg.Host == "" || mysqlCfg.User == "" || mysqlCfg.Database == "" {
t.Skip("missing MYSQL_* env for integration test")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
client, err := store.NewMySQLClient(ctx, mysqlCfg)
if err != nil {
t.Fatalf("mysql not available: %v", err)
}
defer func() { _ = client.Close() }()
validator := security.NewSQLValidator(mysqlCfg.Database, false, nil, nil)
tool := NewMySQLReadOnlyTool(client, validator, 10)
res, err := tool.Execute(ctx, map[string]any{"sql": "SELECT 1 AS ok"})
if err != nil {
t.Fatalf("tool execute failed: %v", err)
}
if res["rowCount"].(int) < 1 {
t.Fatalf("expected rowCount >= 1")
}
}
func TestIntegrationRedisTools(t *testing.T) {
if os.Getenv("MCP_IT_RUN") != "1" {
t.Skip("set MCP_IT_RUN=1 to run integration tests")
}
db := 0
if p := os.Getenv("REDIS_DB"); p != "" {
if n, err := strconv.Atoi(p); err == nil {
db = n
}
}
redisCfg := config.RedisConfig{
Addr: os.Getenv("REDIS_ADDR"),
Password: os.Getenv("REDIS_PASSWORD"),
DB: db,
}
if redisCfg.Addr == "" {
t.Skip("REDIS_ADDR is empty")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
client, err := store.NewRedisClient(ctx, redisCfg)
if err != nil {
t.Fatalf("redis not available: %v", err)
}
defer func() { _ = client.Close() }()
getTool := NewRedisGetTool(client, 10, 128)
if _, err := getTool.Execute(ctx, map[string]any{"key": "__integration_missing_key__"}); err != nil {
t.Fatalf("redis_get failed: %v", err)
}
scanTool := NewRedisScanTool(client, 10, 10)
if _, err := scanTool.Execute(ctx, map[string]any{"pattern": "*", "count": float64(5)}); err != nil {
t.Fatalf("redis_scan failed: %v", err)
}
}

View File

@@ -0,0 +1,95 @@
package tools
import (
"context"
"fmt"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store"
)
type MySQLReadOnlyTool struct {
client *store.MySQLClient
validator *security.SQLValidator
maxRows int
}
func NewMySQLReadOnlyTool(client *store.MySQLClient, validator *security.SQLValidator, maxRows int) *MySQLReadOnlyTool {
return &MySQLReadOnlyTool{client: client, validator: validator, maxRows: maxRows}
}
func (t *MySQLReadOnlyTool) Name() string {
return "mysql_query_readonly"
}
func (t *MySQLReadOnlyTool) Description() string {
return "Execute read-only SQL on MySQL. Only SELECT/SHOW/DESCRIBE/EXPLAIN are allowed."
}
func (t *MySQLReadOnlyTool) InputSchema() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"sql": map[string]any{
"type": "string",
"description": "Read-only SQL statement",
},
"params": map[string]any{
"type": "array",
"description": "Optional bind parameters",
"items": map[string]any{
"type": []string{"string", "number", "boolean", "null"},
},
},
},
"required": []string{"sql"},
"additionalProperties": false,
}
}
func (t *MySQLReadOnlyTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
rawSQL, ok := args["sql"].(string)
if !ok || rawSQL == "" {
return nil, fmt.Errorf("sql must be a non-empty string")
}
if err := t.validator.ValidateReadOnlySQL(rawSQL); err != nil {
return nil, err
}
params, err := normalizeParams(args["params"])
if err != nil {
return nil, err
}
res, err := t.client.QueryReadOnly(ctx, rawSQL, params, t.maxRows)
if err != nil {
return nil, err
}
return map[string]any{
"columns": res.Columns,
"rows": res.Rows,
"rowCount": res.RowCount,
"truncated": res.Truncated,
"durationMs": res.DurationMs,
}, nil
}
func normalizeParams(raw any) ([]any, error) {
if raw == nil {
return nil, nil
}
arr, ok := raw.([]any)
if !ok {
return nil, fmt.Errorf("params must be an array")
}
out := make([]any, 0, len(arr))
for _, item := range arr {
switch v := item.(type) {
case string, float64, bool, nil:
out = append(out, v)
default:
return nil, fmt.Errorf("params contains unsupported type")
}
}
return out, nil
}

View File

@@ -0,0 +1,130 @@
package tools
import (
"context"
"fmt"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store"
)
type RedisGetTool struct {
client *store.RedisClient
valueMaxItems int
maxStringBytes int
}
func NewRedisGetTool(client *store.RedisClient, valueMaxItems int, maxStringBytes int) *RedisGetTool {
return &RedisGetTool{client: client, valueMaxItems: valueMaxItems, maxStringBytes: maxStringBytes}
}
func (t *RedisGetTool) Name() string {
return "redis_get"
}
func (t *RedisGetTool) Description() string {
return "Get a Redis key by name and return its type and value."
}
func (t *RedisGetTool) InputSchema() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"key": map[string]any{
"type": "string",
"description": "Redis key",
},
},
"required": []string{"key"},
"additionalProperties": false,
}
}
func (t *RedisGetTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
key, ok := args["key"].(string)
if !ok || key == "" {
return nil, fmt.Errorf("key must be a non-empty string")
}
res, err := t.client.GetWithType(ctx, key, t.valueMaxItems, t.maxStringBytes)
if err != nil {
return nil, err
}
return map[string]any{
"exists": res.Exists,
"key": res.Key,
"type": res.Type,
"value": res.Value,
"truncated": res.Truncated,
"durationMs": res.DurationMs,
}, nil
}
type RedisScanTool struct {
client *store.RedisClient
maxKeys int
maxScanCount int
}
func NewRedisScanTool(client *store.RedisClient, maxKeys int, maxScanCount int) *RedisScanTool {
return &RedisScanTool{client: client, maxKeys: maxKeys, maxScanCount: maxScanCount}
}
func (t *RedisScanTool) Name() string {
return "redis_scan"
}
func (t *RedisScanTool) Description() string {
return "Scan Redis keys by pattern with capped result size."
}
func (t *RedisScanTool) InputSchema() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"pattern": map[string]any{
"type": "string",
"description": "Pattern, for example user:*",
},
"count": map[string]any{
"type": "number",
"description": "Optional scan count hint",
},
},
"required": []string{"pattern"},
"additionalProperties": false,
}
}
func (t *RedisScanTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
pattern, ok := args["pattern"].(string)
if !ok || pattern == "" {
return nil, fmt.Errorf("pattern must be a non-empty string")
}
count := int64(20)
if rawCount, ok := args["count"]; ok {
number, ok := rawCount.(float64)
if !ok {
return nil, fmt.Errorf("count must be a number")
}
if number <= 0 {
return nil, fmt.Errorf("count must be > 0")
}
count = int64(number)
}
if count > int64(t.maxScanCount) {
count = int64(t.maxScanCount)
}
res, err := t.client.ScanKeys(ctx, pattern, count, t.maxKeys)
if err != nil {
return nil, err
}
return map[string]any{
"pattern": res.Pattern,
"keys": res.Keys,
"returned": res.Returned,
"nextCursor": res.NextCursor,
"truncated": res.Truncated,
"durationMs": res.DurationMs,
}, nil
}

View File

@@ -0,0 +1,53 @@
package tools
import (
"context"
"fmt"
"sort"
)
type Tool interface {
Name() string
Description() string
InputSchema() map[string]any
Execute(ctx context.Context, args map[string]any) (map[string]any, error)
}
type Registry struct {
tools map[string]Tool
}
func NewRegistry(toolList ...Tool) (*Registry, error) {
r := &Registry{tools: make(map[string]Tool, len(toolList))}
for _, t := range toolList {
name := t.Name()
if _, exists := r.tools[name]; exists {
return nil, fmt.Errorf("duplicated tool name: %s", name)
}
r.tools[name] = t
}
return r, nil
}
func (r *Registry) Find(name string) (Tool, bool) {
t, ok := r.tools[name]
return t, ok
}
func (r *Registry) List() []map[string]any {
out := make([]map[string]any, 0, len(r.tools))
names := make([]string, 0, len(r.tools))
for name := range r.tools {
names = append(names, name)
}
sort.Strings(names)
for _, name := range names {
t := r.tools[name]
out = append(out, map[string]any{
"name": t.Name(),
"description": t.Description(),
"inputSchema": t.InputSchema(),
})
}
return out
}