diff --git a/backend/api/agent.go b/backend/api/agent.go index 7322804..b3de0f0 100644 --- a/backend/api/agent.go +++ b/backend/api/agent.go @@ -3,11 +3,13 @@ package api import ( "io" "net/http" + "strings" "github.com/LoveLosita/smartflow/backend/model" "github.com/LoveLosita/smartflow/backend/respond" "github.com/LoveLosita/smartflow/backend/service" "github.com/gin-gonic/gin" + "github.com/google/uuid" ) type AgentHandler struct { @@ -33,9 +35,18 @@ func (api *AgentHandler) ChatAgent(c *gin.Context) { c.JSON(http.StatusBadRequest, respond.WrongParamType) return } + + // 兼容:如果前端没传会话 ID,后端兜底创建一个 + conversationID := strings.TrimSpace(req.ConversationID) + if conversationID == "" { + conversationID = uuid.NewString() + } + // 把最终生效的会话 ID 回传给前端,方便后续继续同一会话 + c.Writer.Header().Set("X-Conversation-ID", conversationID) + userID := c.GetInt("user_id") // 从上下文中获取用户 ID // 3. 调用 Service 层的聊天方法,获取输出通道和错误通道 - outChan, errChan := api.svc.AgentChat(c.Request.Context(), req.Message, req.Thinking, userID, req.ConversationID) + outChan, errChan := api.svc.AgentChat(c.Request.Context(), req.Message, req.Thinking, userID, conversationID) // 4. 循环转发消息/错误 c.Stream(func(w io.Writer) bool { select { diff --git a/backend/dao/agent-cache.go b/backend/dao/agent-cache.go index 7c4a630..c30e723 100644 --- a/backend/dao/agent-cache.go +++ b/backend/dao/agent-cache.go @@ -95,7 +95,7 @@ func (m *AgentCache) BackfillHistory(ctx context.Context, sessionID string, mess for i, msg := range messages { data, err := json.Marshal(msg) if err != nil { - return fmt.Errorf("marshal failed at index %d: %w", err) + return fmt.Errorf("marshal failed at index %d: %w", i, err) } values[i] = data } diff --git a/backend/dao/agent.go b/backend/dao/agent.go index b8d1c54..b6d2903 100644 --- a/backend/dao/agent.go +++ b/backend/dao/agent.go @@ -44,10 +44,18 @@ func (a *AgentDAO) CreateNewChat(userID int, chatID string) (int64, error) { func (a *AgentDAO) GetUserChatHistories(ctx context.Context, userID, limit int, chatID string) ([]model.ChatHistory, error) { var histories []model.ChatHistory - err := a.db.WithContext(ctx).Where("user_id = ? AND chat_id = ?", userID, chatID).Order("created_at desc").Limit(limit).Find(&histories).Error + err := a.db.WithContext(ctx). + Where("user_id = ? AND chat_id = ?", userID, chatID). + Order("created_at desc"). + Limit(limit). + Find(&histories).Error if err != nil { return nil, err } + // 保留“最近 N 条”的前提下,反转为时间正序,便于模型消费 + for i, j := 0, len(histories)-1; i < j; i, j = i+1, j-1 { + histories[i], histories[j] = histories[j], histories[i] + } return histories, nil } diff --git a/backend/service/agent.go b/backend/service/agent.go index b54d527..d8ca440 100644 --- a/backend/service/agent.go +++ b/backend/service/agent.go @@ -3,12 +3,14 @@ package service import ( "context" "log" + "strings" "github.com/LoveLosita/smartflow/backend/agent" "github.com/LoveLosita/smartflow/backend/conv" "github.com/LoveLosita/smartflow/backend/dao" "github.com/LoveLosita/smartflow/backend/inits" "github.com/cloudwego/eino/schema" + "github.com/google/uuid" ) type AgentService struct { @@ -25,10 +27,21 @@ func NewAgentService(aiHub *inits.AIHub, repo *dao.AgentDAO, agentRedis *dao.Age } } +func normalizeConversationID(chatID string) string { + trimmed := strings.TrimSpace(chatID) + if trimmed == "" { + return uuid.NewString() + } + return trimmed +} + func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThinking bool, userID int, chatID string) (<-chan string, <-chan error) { //1. 创建一个输出通道 outChan := make(chan string, 5) errChan := make(chan error, 1) + //补充:会话 ID 兜底,避免上层漏传 + chatID = normalizeConversationID(chatID) + //2. 先确保这个会话存在(如果不存在就创建一个新的) //先看看缓存里面有没有这个会话 result, err := s.agentCache.GetConversationStatus(ctx, chatID) @@ -49,20 +62,23 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin } if !innerResult { //如果会话不存在,先创建一个新的会话 - _, err := s.repo.CreateNewChat(userID, chatID) - if err != nil { + if _, err = s.repo.CreateNewChat(userID, chatID); err != nil { errChan <- err close(outChan) close(errChan) return outChan, errChan } } + //补充:把“会话存在”状态回写缓存,后续请求可直接命中 + if err = s.agentCache.SetConversationStatus(ctx, chatID); err != nil { + //缓存回写失败不影响主流程 + log.Printf("failed to set conversation status cache for %s: %v", chatID, err) + } } //能走到这里,要么缓存里有这个会话,要么数据库里有这个会话了 //4. 提取出历史消息,构建上下文 //先尝试从缓存里拿历史消息 - var chatHistory []*schema.Message - chatHistory, err = s.agentCache.GetHistory(ctx, chatID) + chatHistory, err := s.agentCache.GetHistory(ctx, chatID) if err != nil { errChan <- err close(outChan) @@ -82,8 +98,7 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin //再转换成 Eino 的消息格式 chatHistory = conv.ToEinoMessages(histories) //把历史消息放到缓存里,方便下次直接拿 - err = s.agentCache.BackfillHistory(ctx, chatID, chatHistory) - if err != nil { + if err = s.agentCache.BackfillHistory(ctx, chatID, chatHistory); err != nil { errChan <- err close(outChan) close(errChan) @@ -93,16 +108,18 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin //3. 将用户消息异步落缓存和库 go func() { //这里先不管落库成功与否了,毕竟不想因为落库失败而影响用户的聊天体验 - _ = s.agentCache.PushMessage(ctx, chatID, &schema.Message{ - Role: "user", + bg := context.Background() + _ = s.agentCache.PushMessage(bg, chatID, &schema.Message{ + Role: schema.User, Content: userMessage, }) - _ = s.repo.SaveChatHistory(ctx, userID, chatID, "user", userMessage) + _ = s.repo.SaveChatHistory(bg, userID, chatID, "user", userMessage) }() //5. 启动一个 goroutine 来处理聊天逻辑 go func() { defer close(outChan) // 确保在函数结束时关闭通道 + defer close(errChan) //3. 调用 StreamChat 函数进行流式聊天 fullText, err := agent.StreamChat(ctx, s.AIHub.Worker, userMessage, ifThinking, chatHistory, outChan) if err != nil { @@ -111,13 +128,13 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin } //4. 将 AI 的回复异步落缓存和库 go func() { - _ = s.agentCache.PushMessage(ctx, chatID, &schema.Message{ - Role: "assistant", + bg := context.Background() + _ = s.agentCache.PushMessage(bg, chatID, &schema.Message{ + Role: schema.Assistant, Content: fullText, }) - err = s.repo.SaveChatHistory(context.Background(), userID, chatID, "assistant", fullText) - if err != nil { - log.Printf("Failed to save chat history to database: %v", err) + if saveErr := s.repo.SaveChatHistory(bg, userID, chatID, "assistant", fullText); saveErr != nil { + log.Printf("failed to save chat history to database: %v", saveErr) return } }() diff --git a/infra/smartflow-mcp-server/.editorconfig b/infra/smartflow-mcp-server/.editorconfig new file mode 100644 index 0000000..521afc9 --- /dev/null +++ b/infra/smartflow-mcp-server/.editorconfig @@ -0,0 +1,10 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true + +[*.md] +trim_trailing_whitespace = false diff --git a/infra/smartflow-mcp-server/.env.example b/infra/smartflow-mcp-server/.env.example new file mode 100644 index 0000000..7cbc8d9 --- /dev/null +++ b/infra/smartflow-mcp-server/.env.example @@ -0,0 +1,46 @@ +# ========================= +# MCP server metadata +# ========================= +MCP_SERVER_NAME=smartflow-mcp-server +MCP_SERVER_VERSION=0.1.0 +MCP_PROTOCOL_VERSION=2024-11-05 +MCP_DEFAULT_CALLER=codex + +# ========================= +# Governance & safety +# ========================= +MCP_TOOL_TIMEOUT_MS=5000 +MCP_RATE_LIMIT_RPS=5 +MCP_RATE_LIMIT_BURST=10 +MCP_MAX_RESULT_ROWS=500 +MCP_ENFORCE_WHITELIST=false +MCP_AUDIT_LOG_PATH=logs/audit.log + +# Redis scan/value caps +MCP_REDIS_SCAN_MAX_KEYS=200 +MCP_REDIS_SCAN_MAX_COUNT=200 +MCP_REDIS_VALUE_MAX_ITEMS=100 +MCP_REDIS_MAX_STRING_BYTES=4096 + +# ========================= +# MySQL +# ========================= +MYSQL_HOST=127.0.0.1 +MYSQL_PORT=3306 +MYSQL_USER=readonly_user +MYSQL_PASSWORD=replace_me +MYSQL_DATABASE=smartflow +MYSQL_PARAMS=charset=utf8mb4&parseTime=true&loc=Local + +# Comma-separated whitelist (optional) +# Example: MYSQL_ALLOWED_DATABASES=smartflow,analytics +MYSQL_ALLOWED_DATABASES=smartflow +# Example: MYSQL_ALLOWED_TABLES=smartflow.users,smartflow.tasks +MYSQL_ALLOWED_TABLES=smartflow.users,smartflow.tasks + +# ========================= +# Redis +# ========================= +REDIS_ADDR=127.0.0.1:6379 +REDIS_PASSWORD= +REDIS_DB=0 diff --git a/infra/smartflow-mcp-server/.gitignore b/infra/smartflow-mcp-server/.gitignore new file mode 100644 index 0000000..b9a8e1e --- /dev/null +++ b/infra/smartflow-mcp-server/.gitignore @@ -0,0 +1,2 @@ +.env +logs/*.log diff --git a/infra/smartflow-mcp-server/README.md b/infra/smartflow-mcp-server/README.md new file mode 100644 index 0000000..3fb6a64 --- /dev/null +++ b/infra/smartflow-mcp-server/README.md @@ -0,0 +1,213 @@ +# smartflow-mcp-server (MVP) + +用于让 Codex 通过 MCP(stdio)只读访问 MySQL 与 Redis,面向接口联调与测试。 + +## 1. 功能范围(第一阶段) + +只实现 3 个只读工具: + +1. `mysql_query_readonly` +2. `redis_get` +3. `redis_scan` + +未实现任何写操作工具。 + +## 2. 目录结构 + +```text +infra/smartflow-mcp-server +├─ cmd/server/main.go +├─ internal +│ ├─ audit +│ ├─ config +│ ├─ envutil +│ ├─ mcp +│ ├─ ratelimit +│ ├─ security +│ ├─ store +│ └─ tools +├─ .env.example +├─ go.mod +└─ README.md +``` + +## 3. 快速启动 + +```bash +go mod tidy +go test ./... +go run ./cmd/server +``` + +服务采用 stdio MCP 协议,不会启动 HTTP 端口。 + +## 4. 配置说明(全部来自环境变量) + +复制并编辑: + +```bash +cp .env.example .env +``` + +关键变量: + +- `MYSQL_HOST` / `MYSQL_PORT` / `MYSQL_USER` / `MYSQL_PASSWORD` / `MYSQL_DATABASE` +- `REDIS_ADDR` / `REDIS_PASSWORD` / `REDIS_DB` +- `MYSQL_ALLOWED_DATABASES`:逗号分隔 +- `MYSQL_ALLOWED_TABLES`:逗号分隔,支持 `db.table` 或 `table` +- `MCP_ENFORCE_WHITELIST`:`true` 时无明确表引用会拒绝执行 +- `MCP_TOOL_TIMEOUT_MS`:单次工具调用超时 +- `MCP_RATE_LIMIT_RPS` + `MCP_RATE_LIMIT_BURST`:基础令牌桶限流 +- `MCP_MAX_RESULT_ROWS`:MySQL 最大返回行数 +- `MCP_REDIS_SCAN_MAX_KEYS`:`redis_scan` 最大返回 key 数 +- `MCP_AUDIT_LOG_PATH`:审计日志路径 + +## 5. 工具说明 + +### 5.1 `mysql_query_readonly` + +输入: + +```json +{ + "sql": "SELECT id, name FROM users WHERE id = ?", + "params": [1] +} +``` + +安全限制: + +- 仅允许 `SELECT` / `SHOW` / `DESCRIBE` / `EXPLAIN` +- 禁止分号 `;`(多语句) +- 禁止注释 `--` / `#` / `/* */` +- 禁止 DDL/DML 关键字(`INSERT`/`UPDATE`/`DELETE`/`ALTER`/`DROP`/`TRUNCATE` 等) +- 支持库/表白名单校验 + +输出(结构化): + +- `columns` +- `rows` +- `rowCount` +- `truncated` +- `durationMs` + +### 5.2 `redis_get` + +输入: + +```json +{ + "key": "user:1001" +} +``` + +输出: + +- `exists` +- `key` +- `type` +- `value` +- `truncated` +- `durationMs` + +### 5.3 `redis_scan` + +输入: + +```json +{ + "pattern": "user:*", + "count": 50 +} +``` + +输出: + +- `pattern` +- `keys` +- `returned` +- `nextCursor` +- `truncated` +- `durationMs` + +## 6. 审计日志 + +每次工具调用会记录(JSON 行格式): + +- 时间 +- 工具名 +- 调用方(caller) +- 是否成功 +- 耗时 +- 脱敏后的输入摘要 +- 错误信息(截断) + +敏感字段处理: + +- SQL 字符串字面量与数字会脱敏 +- Redis key 仅保留前后少量字符 + +## 7. Codex MCP 配置示例(stdio) + +可按客户端配置格式接入,示例: + +```json +{ + "mcpServers": { + "smartflow-db-readonly": { + "command": "go", + "args": ["run", "./cmd/server"], + "cwd": "E:/SmartFlow-Agent/infra/smartflow-mcp-server", + "env": { + "MYSQL_HOST": "127.0.0.1", + "MYSQL_PORT": "3306", + "MYSQL_USER": "readonly_user", + "MYSQL_PASSWORD": "replace_me", + "MYSQL_DATABASE": "smartflow", + "MYSQL_ALLOWED_DATABASES": "smartflow", + "MYSQL_ALLOWED_TABLES": "smartflow.users,smartflow.tasks", + "REDIS_ADDR": "127.0.0.1:6379", + "REDIS_DB": "0", + "MCP_TOOL_TIMEOUT_MS": "5000", + "MCP_RATE_LIMIT_RPS": "5", + "MCP_RATE_LIMIT_BURST": "10", + "MCP_MAX_RESULT_ROWS": "500", + "MCP_REDIS_SCAN_MAX_KEYS": "200", + "MCP_AUDIT_LOG_PATH": "logs/audit.log" + } + } + } +} +``` + +## 8. 安全限制生效示例 + +- SQL 多语句:`SELECT 1; SELECT 2` -> 被拒绝(semicolon is not allowed) +- SQL 注释绕过:`SELECT * FROM users --x` -> 被拒绝(sql comments are not allowed) +- 写操作:`DELETE FROM users` -> 被拒绝(dangerous sql keyword detected) +- 白名单外表:`SELECT * FROM admin.secret` -> 被拒绝(table not in whitelist) +- Redis 大范围扫描:`redis_scan` 返回数量受 `MCP_REDIS_SCAN_MAX_KEYS` 限制 + +## 9. 风险说明(MVP 已知边界) + +1. SQL 校验采用关键字与模式匹配,不是完整 SQL AST 解析,建议二阶段引入 AST 级校验。 +2. `SHOW DATABASES` 等无显式表引用语句在非严格模式下可执行;生产建议开启 `MCP_ENFORCE_WHITELIST=true`。 +3. Redis 复杂类型返回做了截断保护,但仍建议在生产环境设置更小上限。 + +## 10. 常见问题(FAQ) + +### Q1: 启动时报 `MYSQL_USER and MYSQL_DATABASE are required` + +检查环境变量是否正确加载,建议先确认 `.env` 存在于 `infra/smartflow-mcp-server`。 + +### Q2: 为什么调用工具报限流 + +默认启用了令牌桶限流,调大 `MCP_RATE_LIMIT_RPS` 与 `MCP_RATE_LIMIT_BURST` 即可。 + +### Q3: 为什么 `redis_scan` 返回不全 + +是预期行为,结果数被 `MCP_REDIS_SCAN_MAX_KEYS` 限制,避免全量扫描拖垮 Redis。 + +### Q4: 审计日志在哪里 + +默认在 `logs/audit.log`,可用 `MCP_AUDIT_LOG_PATH` 自定义。 diff --git a/infra/smartflow-mcp-server/cmd/server/main.go b/infra/smartflow-mcp-server/cmd/server/main.go new file mode 100644 index 0000000..dc9fd0f --- /dev/null +++ b/infra/smartflow-mcp-server/cmd/server/main.go @@ -0,0 +1,90 @@ +package main + +import ( + "context" + "log" + "os" + "os/signal" + "syscall" + + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/audit" + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config" + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/envutil" + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/mcp" + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit" + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security" + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store" + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools" +) + +func main() { + if err := envutil.LoadDotEnv(".env"); err != nil { + log.Fatalf("load .env failed: %v", err) + } + + cfg, err := config.LoadFromEnv() + if err != nil { + log.Fatalf("load config failed: %v", err) + } + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + auditLogger, err := audit.New(cfg.AuditLogPath) + if err != nil { + log.Fatalf("init audit logger failed: %v", err) + } + defer func() { + _ = auditLogger.Close() + }() + + mysqlClient, err := store.NewMySQLClient(ctx, cfg.MySQL) + if err != nil { + log.Fatalf("init mysql failed: %v", err) + } + defer func() { + _ = mysqlClient.Close() + }() + + redisClient, err := store.NewRedisClient(ctx, cfg.Redis) + if err != nil { + log.Fatalf("init redis failed: %v", err) + } + defer func() { + _ = redisClient.Close() + }() + + sqlValidator := security.NewSQLValidator( + cfg.MySQL.Database, + cfg.EnforceWhitelist, + cfg.MySQL.AllowedDatabases, + cfg.MySQL.AllowedTables, + ) + + registry, err := tools.NewRegistry( + tools.NewMySQLReadOnlyTool(mysqlClient, sqlValidator, cfg.MaxResultRows), + tools.NewRedisGetTool(redisClient, cfg.RedisValueMaxItems, cfg.RedisMaxStringBytes), + tools.NewRedisScanTool(redisClient, cfg.RedisScanMaxKeys, cfg.RedisScanMaxCount), + ) + if err != nil { + log.Fatalf("init tool registry failed: %v", err) + } + + limiter := ratelimit.New(cfg.RateLimitRPS, cfg.RateLimitBurst) + server := mcp.NewServer( + os.Stdin, + os.Stdout, + registry, + auditLogger, + limiter, + cfg.ServerName, + cfg.ServerVersion, + cfg.ProtocolVersion, + cfg.DefaultCaller, + cfg.ToolTimeout, + ) + + if err := server.Serve(ctx); err != nil { + log.Fatalf("mcp server exited with error: %v", err) + } +} diff --git a/infra/smartflow-mcp-server/go.mod b/infra/smartflow-mcp-server/go.mod new file mode 100644 index 0000000..12c018c --- /dev/null +++ b/infra/smartflow-mcp-server/go.mod @@ -0,0 +1,14 @@ +module github.com/LoveLosita/smartflow/infra/smartflow-mcp-server + +go 1.23.4 + +require ( + github.com/go-redis/redis/v8 v8.11.5 + github.com/go-sql-driver/mysql v1.8.1 +) + +require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect +) diff --git a/infra/smartflow-mcp-server/go.sum b/infra/smartflow-mcp-server/go.sum new file mode 100644 index 0000000..88effed --- /dev/null +++ b/infra/smartflow-mcp-server/go.sum @@ -0,0 +1,28 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= +github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= +github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/infra/smartflow-mcp-server/internal/audit/logger.go b/infra/smartflow-mcp-server/internal/audit/logger.go new file mode 100644 index 0000000..087a058 --- /dev/null +++ b/infra/smartflow-mcp-server/internal/audit/logger.go @@ -0,0 +1,60 @@ +package audit + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "time" +) + +type Logger struct { + mu sync.Mutex + file *os.File +} + +type Record struct { + Timestamp time.Time `json:"timestamp"` + Tool string `json:"tool"` + Caller string `json:"caller"` + Success bool `json:"success"` + DurationMs int64 `json:"duration_ms"` + Meta map[string]any `json:"meta,omitempty"` + Error string `json:"error,omitempty"` +} + +func New(path string) (*Logger, error) { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, fmt.Errorf("create audit dir: %w", err) + } + f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + if err != nil { + return nil, fmt.Errorf("open audit log file: %w", err) + } + return &Logger{file: f}, nil +} + +func (l *Logger) Close() error { + if l == nil || l.file == nil { + return nil + } + return l.file.Close() +} + +func (l *Logger) Log(record Record) { + if l == nil || l.file == nil { + return + } + if record.Timestamp.IsZero() { + record.Timestamp = time.Now() + } + body, err := json.Marshal(record) + if err != nil { + return + } + l.mu.Lock() + defer l.mu.Unlock() + _, _ = l.file.Write(append(body, '\n')) +} diff --git a/infra/smartflow-mcp-server/internal/config/config.go b/infra/smartflow-mcp-server/internal/config/config.go new file mode 100644 index 0000000..3d4c7c8 --- /dev/null +++ b/infra/smartflow-mcp-server/internal/config/config.go @@ -0,0 +1,163 @@ +package config + +import ( + "fmt" + "os" + "strconv" + "strings" + "time" +) + +type Config struct { + ServerName string + ServerVersion string + ProtocolVersion string + DefaultCaller string + ToolTimeout time.Duration + RateLimitRPS float64 + RateLimitBurst float64 + MaxResultRows int + AuditLogPath string + EnforceWhitelist bool + RedisScanMaxKeys int + RedisScanMaxCount int + RedisValueMaxItems int + RedisMaxStringBytes int + + MySQL MySQLConfig + Redis RedisConfig +} + +type MySQLConfig struct { + Host string + Port int + User string + Password string + Database string + Params string + AllowedDatabases []string + AllowedTables []string +} + +type RedisConfig struct { + Addr string + Password string + DB int +} + +func LoadFromEnv() (Config, error) { + cfg := Config{ + ServerName: getEnv("MCP_SERVER_NAME", "smartflow-mcp-server"), + ServerVersion: getEnv("MCP_SERVER_VERSION", "0.1.0"), + ProtocolVersion: getEnv("MCP_PROTOCOL_VERSION", "2024-11-05"), + DefaultCaller: getEnv("MCP_DEFAULT_CALLER", "unknown"), + ToolTimeout: getEnvDurationMS("MCP_TOOL_TIMEOUT_MS", 5000), + RateLimitRPS: getEnvFloat("MCP_RATE_LIMIT_RPS", 5), + RateLimitBurst: getEnvFloat("MCP_RATE_LIMIT_BURST", 10), + MaxResultRows: getEnvInt("MCP_MAX_RESULT_ROWS", 500), + AuditLogPath: getEnv("MCP_AUDIT_LOG_PATH", "logs/audit.log"), + EnforceWhitelist: getEnvBool("MCP_ENFORCE_WHITELIST", false), + RedisScanMaxKeys: getEnvInt("MCP_REDIS_SCAN_MAX_KEYS", 200), + RedisScanMaxCount: getEnvInt("MCP_REDIS_SCAN_MAX_COUNT", 200), + RedisValueMaxItems: getEnvInt("MCP_REDIS_VALUE_MAX_ITEMS", 100), + RedisMaxStringBytes: getEnvInt("MCP_REDIS_MAX_STRING_BYTES", 4096), + MySQL: MySQLConfig{ + Host: getEnv("MYSQL_HOST", "127.0.0.1"), + Port: getEnvInt("MYSQL_PORT", 3306), + User: getEnv("MYSQL_USER", ""), + Password: getEnv("MYSQL_PASSWORD", ""), + Database: getEnv("MYSQL_DATABASE", ""), + Params: getEnv("MYSQL_PARAMS", "charset=utf8mb4&parseTime=true&loc=Local"), + AllowedDatabases: splitCommaList(getEnv("MYSQL_ALLOWED_DATABASES", "")), + AllowedTables: splitCommaList(getEnv("MYSQL_ALLOWED_TABLES", "")), + }, + Redis: RedisConfig{ + Addr: getEnv("REDIS_ADDR", "127.0.0.1:6379"), + Password: getEnv("REDIS_PASSWORD", ""), + DB: getEnvInt("REDIS_DB", 0), + }, + } + + if cfg.MySQL.User == "" || cfg.MySQL.Database == "" { + return Config{}, fmt.Errorf("MYSQL_USER and MYSQL_DATABASE are required") + } + if cfg.Redis.Addr == "" { + return Config{}, fmt.Errorf("REDIS_ADDR is required") + } + if cfg.MaxResultRows <= 0 { + return Config{}, fmt.Errorf("MCP_MAX_RESULT_ROWS must be > 0") + } + if cfg.RedisScanMaxKeys <= 0 { + return Config{}, fmt.Errorf("MCP_REDIS_SCAN_MAX_KEYS must be > 0") + } + if cfg.RedisScanMaxCount <= 0 { + return Config{}, fmt.Errorf("MCP_REDIS_SCAN_MAX_COUNT must be > 0") + } + if cfg.ToolTimeout <= 0 { + return Config{}, fmt.Errorf("MCP_TOOL_TIMEOUT_MS must be > 0") + } + if cfg.RateLimitRPS <= 0 || cfg.RateLimitBurst <= 0 { + return Config{}, fmt.Errorf("MCP_RATE_LIMIT_RPS and MCP_RATE_LIMIT_BURST must be > 0") + } + + return cfg, nil +} + +func getEnv(key string, defaultValue string) string { + if v, ok := os.LookupEnv(key); ok { + return strings.TrimSpace(v) + } + return defaultValue +} + +func getEnvInt(key string, defaultValue int) int { + v := getEnv(key, "") + if v == "" { + return defaultValue + } + n, err := strconv.Atoi(v) + if err != nil { + return defaultValue + } + return n +} + +func getEnvFloat(key string, defaultValue float64) float64 { + v := getEnv(key, "") + if v == "" { + return defaultValue + } + n, err := strconv.ParseFloat(v, 64) + if err != nil { + return defaultValue + } + return n +} + +func getEnvBool(key string, defaultValue bool) bool { + v := strings.ToLower(getEnv(key, "")) + if v == "" { + return defaultValue + } + return v == "1" || v == "true" || v == "yes" || v == "on" +} + +func getEnvDurationMS(key string, defaultValueMs int) time.Duration { + ms := getEnvInt(key, defaultValueMs) + return time.Duration(ms) * time.Millisecond +} + +func splitCommaList(raw string) []string { + if strings.TrimSpace(raw) == "" { + return nil + } + parts := strings.Split(raw, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + trimmed := strings.TrimSpace(p) + if trimmed != "" { + out = append(out, strings.ToLower(trimmed)) + } + } + return out +} diff --git a/infra/smartflow-mcp-server/internal/envutil/loader.go b/infra/smartflow-mcp-server/internal/envutil/loader.go new file mode 100644 index 0000000..0a5544c --- /dev/null +++ b/infra/smartflow-mcp-server/internal/envutil/loader.go @@ -0,0 +1,44 @@ +package envutil + +import ( + "bufio" + "fmt" + "os" + "strings" +) + +func LoadDotEnv(path string) error { + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("open .env: %w", err) + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + i := strings.Index(line, "=") + if i <= 0 { + continue + } + key := strings.TrimSpace(line[:i]) + value := strings.TrimSpace(line[i+1:]) + value = strings.Trim(value, "\"") + if key == "" { + continue + } + if _, exists := os.LookupEnv(key); !exists { + _ = os.Setenv(key, value) + } + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("scan .env: %w", err) + } + return nil +} diff --git a/infra/smartflow-mcp-server/internal/mcp/server.go b/infra/smartflow-mcp-server/internal/mcp/server.go new file mode 100644 index 0000000..57dfdbf --- /dev/null +++ b/infra/smartflow-mcp-server/internal/mcp/server.go @@ -0,0 +1,394 @@ +package mcp + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "strconv" + "strings" + "sync" + "time" + + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/audit" + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit" + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security" + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools" +) + +const ( + jsonRPCVersion = "2.0" + + errCodeParseError = -32700 + errCodeInvalidRequest = -32600 + errCodeMethodNotFound = -32601 + errCodeInvalidParams = -32602 + errCodeInternalError = -32603 +) + +type Server struct { + reader *bufio.Reader + writer io.Writer + writeMu sync.Mutex + registry *tools.Registry + auditLogger *audit.Logger + limiter *ratelimit.Limiter + serverName string + serverVersion string + protocolVersion string + defaultCaller string + toolTimeout time.Duration +} + +type request struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +type response struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Result any `json:"result,omitempty"` + Error *respError `json:"error,omitempty"` +} + +type respError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type initializeParams struct { + ProtocolVersion string `json:"protocolVersion"` +} + +type toolCallParams struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + Meta map[string]any `json:"_meta,omitempty"` +} + +func NewServer( + in io.Reader, + out io.Writer, + registry *tools.Registry, + auditLogger *audit.Logger, + limiter *ratelimit.Limiter, + serverName string, + serverVersion string, + protocolVersion string, + defaultCaller string, + toolTimeout time.Duration, +) *Server { + return &Server{ + reader: bufio.NewReader(in), + writer: out, + registry: registry, + auditLogger: auditLogger, + limiter: limiter, + serverName: serverName, + serverVersion: serverVersion, + protocolVersion: protocolVersion, + defaultCaller: defaultCaller, + toolTimeout: toolTimeout, + } +} + +func (s *Server) Serve(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return nil + default: + } + + body, err := readMessage(s.reader) + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + log.Printf("read message failed: %v", err) + continue + } + + var req request + if err := json.Unmarshal(body, &req); err != nil { + _ = s.writeError(nil, errCodeParseError, "invalid json") + continue + } + if req.Method == "" || req.JSONRPC != jsonRPCVersion { + _ = s.writeError(req.ID, errCodeInvalidRequest, "invalid json-rpc request") + continue + } + + if err := s.handleRequest(ctx, req); err != nil { + log.Printf("handle request failed: %v", err) + } + } +} + +func (s *Server) handleRequest(ctx context.Context, req request) error { + switch req.Method { + case "initialize": + return s.handleInitialize(req) + case "notifications/initialized": + return nil + case "tools/list": + if len(req.ID) == 0 { + return nil + } + return s.writeResult(req.ID, map[string]any{"tools": s.registry.List()}) + case "tools/call": + if len(req.ID) == 0 { + return nil + } + return s.handleToolCall(ctx, req) + case "ping": + if len(req.ID) == 0 { + return nil + } + return s.writeResult(req.ID, map[string]any{}) + default: + if len(req.ID) == 0 { + return nil + } + return s.writeError(req.ID, errCodeMethodNotFound, "method not found") + } +} + +func (s *Server) handleInitialize(req request) error { + if len(req.ID) == 0 { + return nil + } + + var params initializeParams + if len(req.Params) > 0 { + _ = json.Unmarshal(req.Params, ¶ms) + } + _ = params + + return s.writeResult(req.ID, map[string]any{ + "protocolVersion": s.protocolVersion, + "capabilities": map[string]any{ + "tools": map[string]any{"listChanged": false}, + }, + "serverInfo": map[string]any{ + "name": s.serverName, + "version": s.serverVersion, + }, + }) +} + +func (s *Server) handleToolCall(ctx context.Context, req request) error { + var params toolCallParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return s.writeError(req.ID, errCodeInvalidParams, "invalid tool call params") + } + if params.Name == "" { + return s.writeError(req.ID, errCodeInvalidParams, "tool name is required") + } + if params.Arguments == nil { + params.Arguments = map[string]any{} + } + + caller := extractCaller(params, s.defaultCaller) + rateKey := fmt.Sprintf("%s:%s", caller, params.Name) + if !s.limiter.Allow(rateKey) { + result := buildToolErrorResult("rate limit exceeded") + s.auditLogger.Log(audit.Record{ + Tool: params.Name, + Caller: caller, + Success: false, + DurationMs: 0, + Meta: sanitizeAuditInput(params.Name, params.Arguments), + Error: "rate limit exceeded", + }) + return s.writeResult(req.ID, result) + } + + tool, ok := s.registry.Find(params.Name) + if !ok { + result := buildToolErrorResult("tool not found") + s.auditLogger.Log(audit.Record{ + Tool: params.Name, + Caller: caller, + Success: false, + DurationMs: 0, + Meta: sanitizeAuditInput(params.Name, params.Arguments), + Error: "tool not found", + }) + return s.writeResult(req.ID, result) + } + + start := time.Now() + toolCtx, cancel := context.WithTimeout(ctx, s.toolTimeout) + defer cancel() + + output, err := tool.Execute(toolCtx, params.Arguments) + duration := time.Since(start).Milliseconds() + if err != nil { + errMsg := sanitizeError(err) + s.auditLogger.Log(audit.Record{ + Tool: params.Name, + Caller: caller, + Success: false, + DurationMs: duration, + Meta: sanitizeAuditInput(params.Name, params.Arguments), + Error: errMsg, + }) + return s.writeResult(req.ID, buildToolErrorResult(errMsg)) + } + + s.auditLogger.Log(audit.Record{ + Tool: params.Name, + Caller: caller, + Success: true, + DurationMs: duration, + Meta: sanitizeAuditInput(params.Name, params.Arguments), + }) + return s.writeResult(req.ID, buildToolSuccessResult(output)) +} + +func buildToolSuccessResult(output map[string]any) map[string]any { + asJSON, _ := json.Marshal(output) + return map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": string(asJSON)}, + }, + "structuredContent": output, + } +} + +func buildToolErrorResult(message string) map[string]any { + return map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": message}, + }, + "structuredContent": map[string]any{"error": message}, + "isError": true, + } +} + +func extractCaller(params toolCallParams, fallback string) string { + if params.Meta != nil { + if v, ok := params.Meta["caller"]; ok { + if s, ok := v.(string); ok && strings.TrimSpace(s) != "" { + return strings.TrimSpace(s) + } + } + } + if v, ok := params.Arguments["caller"]; ok { + if s, ok := v.(string); ok && strings.TrimSpace(s) != "" { + return strings.TrimSpace(s) + } + } + return fallback +} + +func sanitizeAuditInput(toolName string, args map[string]any) map[string]any { + meta := map[string]any{} + switch toolName { + case "mysql_query_readonly": + sql, _ := args["sql"].(string) + meta["sql"] = security.RedactSQL(sql) + if params, ok := args["params"].([]any); ok { + meta["paramCount"] = len(params) + } + case "redis_get": + if key, ok := args["key"].(string); ok { + meta["key"] = security.RedactKey(key) + } + case "redis_scan": + if pattern, ok := args["pattern"].(string); ok { + if len(pattern) > 40 { + meta["pattern"] = pattern[:40] + } else { + meta["pattern"] = pattern + } + } + if count, ok := args["count"]; ok { + meta["count"] = count + } + default: + meta["args"] = "masked" + } + return meta +} + +func sanitizeError(err error) string { + msg := err.Error() + msg = strings.ReplaceAll(msg, "\n", " ") + msg = strings.ReplaceAll(msg, "\r", " ") + if len(msg) > 300 { + return msg[:300] + } + return msg +} + +func (s *Server) writeResult(id json.RawMessage, result any) error { + resp := response{JSONRPC: jsonRPCVersion, ID: id, Result: result} + return s.writeMessage(resp) +} + +func (s *Server) writeError(id json.RawMessage, code int, message string) error { + resp := response{ + JSONRPC: jsonRPCVersion, + ID: id, + Error: &respError{Code: code, Message: message}, + } + return s.writeMessage(resp) +} + +func (s *Server) writeMessage(payload any) error { + body, err := json.Marshal(payload) + if err != nil { + return err + } + header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(body)) + s.writeMu.Lock() + defer s.writeMu.Unlock() + if _, err := s.writer.Write([]byte(header)); err != nil { + return err + } + _, err = s.writer.Write(body) + return err +} + +func readMessage(reader *bufio.Reader) ([]byte, error) { + contentLength := -1 + for { + line, err := reader.ReadString('\n') + if err != nil { + return nil, err + } + line = strings.TrimRight(line, "\r\n") + if line == "" { + break + } + parts := strings.SplitN(line, ":", 2) + if len(parts) != 2 { + continue + } + headerName := strings.TrimSpace(parts[0]) + headerValue := strings.TrimSpace(parts[1]) + if strings.EqualFold(headerName, "Content-Length") { + n, err := strconv.Atoi(headerValue) + if err != nil || n < 0 { + return nil, fmt.Errorf("invalid Content-Length") + } + contentLength = n + } + } + if contentLength < 0 { + return nil, fmt.Errorf("missing Content-Length") + } + body := make([]byte, contentLength) + if _, err := io.ReadFull(reader, body); err != nil { + return nil, err + } + return body, nil +} diff --git a/infra/smartflow-mcp-server/internal/mcp/server_test.go b/infra/smartflow-mcp-server/internal/mcp/server_test.go new file mode 100644 index 0000000..e6776e0 --- /dev/null +++ b/infra/smartflow-mcp-server/internal/mcp/server_test.go @@ -0,0 +1,56 @@ +package mcp + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit" + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools" +) + +type dummyTool struct{} + +func (d dummyTool) Name() string { return "dummy" } +func (d dummyTool) Description() string { return "dummy" } +func (d dummyTool) InputSchema() map[string]any { + return map[string]any{"type": "object"} +} +func (d dummyTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) { + return map[string]any{"ok": true}, nil +} + +func TestHandleInitialize(t *testing.T) { + registry, err := tools.NewRegistry(dummyTool{}) + if err != nil { + t.Fatal(err) + } + out := bytes.NewBuffer(nil) + s := NewServer(bytes.NewBuffer(nil), out, registry, nil, ratelimit.New(10, 10), "name", "ver", "2024-11-05", "unknown", time.Second) + + req := request{JSONRPC: jsonRPCVersion, ID: json.RawMessage("1"), Method: "initialize", Params: json.RawMessage(`{"protocolVersion":"2024-11-05"}`)} + if err := s.handleRequest(context.Background(), req); err != nil { + t.Fatalf("handle initialize error: %v", err) + } + if !bytes.Contains(out.Bytes(), []byte("protocolVersion")) { + t.Fatalf("response missing protocolVersion: %s", out.String()) + } +} + +func TestReadMessage(t *testing.T) { + body := []byte(`{"jsonrpc":"2.0","id":1,"method":"ping"}`) + frame := fmt.Sprintf("Content-Length: %d\r\n\r\n%s", len(body), string(body)) + input := bufio.NewReader(bytes.NewBufferString(frame)) + + msg, err := readMessage(input) + if err != nil { + t.Fatalf("read message failed: %v", err) + } + if !bytes.Equal(msg, body) { + t.Fatalf("unexpected body: %s", string(msg)) + } +} diff --git a/infra/smartflow-mcp-server/internal/ratelimit/limiter.go b/infra/smartflow-mcp-server/internal/ratelimit/limiter.go new file mode 100644 index 0000000..a02a2e1 --- /dev/null +++ b/infra/smartflow-mcp-server/internal/ratelimit/limiter.go @@ -0,0 +1,51 @@ +package ratelimit + +import ( + "sync" + "time" +) + +type Limiter struct { + mu sync.Mutex + rate float64 + burst float64 + buckets map[string]*bucket +} + +type bucket struct { + tokens float64 + last time.Time +} + +func New(rate, burst float64) *Limiter { + return &Limiter{ + rate: rate, + burst: burst, + buckets: make(map[string]*bucket), + } +} + +func (l *Limiter) Allow(key string) bool { + now := time.Now() + l.mu.Lock() + defer l.mu.Unlock() + + b, ok := l.buckets[key] + if !ok { + l.buckets[key] = &bucket{tokens: l.burst - 1, last: now} + return true + } + + elapsed := now.Sub(b.last).Seconds() + b.tokens += elapsed * l.rate + if b.tokens > l.burst { + b.tokens = l.burst + } + b.last = now + + if b.tokens < 1 { + return false + } + b.tokens -= 1 + return true +} diff --git a/infra/smartflow-mcp-server/internal/ratelimit/limiter_test.go b/infra/smartflow-mcp-server/internal/ratelimit/limiter_test.go new file mode 100644 index 0000000..85fb135 --- /dev/null +++ b/infra/smartflow-mcp-server/internal/ratelimit/limiter_test.go @@ -0,0 +1,26 @@ +package ratelimit + +import ( + "testing" + "time" +) + +func TestLimiter(t *testing.T) { + l := New(2, 2) + key := "user:tool" + + if !l.Allow(key) { + t.Fatal("first request should pass") + } + if !l.Allow(key) { + t.Fatal("second request should pass") + } + if l.Allow(key) { + t.Fatal("third request should be rate limited") + } + + time.Sleep(600 * time.Millisecond) + if !l.Allow(key) { + t.Fatal("request should pass after token refill") + } +} diff --git a/infra/smartflow-mcp-server/internal/security/redact.go b/infra/smartflow-mcp-server/internal/security/redact.go new file mode 100644 index 0000000..771840e --- /dev/null +++ b/infra/smartflow-mcp-server/internal/security/redact.go @@ -0,0 +1,27 @@ +package security + +import ( + "regexp" + "strings" +) + +var ( + singleQuotedString = regexp.MustCompile(`'([^'\\]|\\.)*'`) + doubleQuotedString = regexp.MustCompile(`"([^"\\]|\\.)*"`) + numericLiteral = regexp.MustCompile(`\b\d+\b`) +) + +func RedactSQL(sql string) string { + masked := singleQuotedString.ReplaceAllString(sql, "'***'") + masked = doubleQuotedString.ReplaceAllString(masked, `"***"`) + masked = numericLiteral.ReplaceAllString(masked, "?") + return strings.TrimSpace(masked) +} + +func RedactKey(key string) string { + key = strings.TrimSpace(key) + if len(key) <= 4 { + return "****" + } + return key[:2] + "***" + key[len(key)-2:] +} diff --git a/infra/smartflow-mcp-server/internal/security/sql_validator.go b/infra/smartflow-mcp-server/internal/security/sql_validator.go new file mode 100644 index 0000000..92a798f --- /dev/null +++ b/infra/smartflow-mcp-server/internal/security/sql_validator.go @@ -0,0 +1,193 @@ +package security + +import ( + "fmt" + "regexp" + "strings" +) + +var ( + commentPattern = regexp.MustCompile(`(?s)/\*.*?\*/|--|#`) + forbiddenWords = regexp.MustCompile(`(?i)\b(INSERT|UPDATE|DELETE|ALTER|DROP|TRUNCATE|CREATE|REPLACE|RENAME|GRANT|REVOKE|MERGE|CALL|EXEC|LOCK|UNLOCK|LOAD|OUTFILE|INFILE|HANDLER|SET|USE)\b`) + fromJoinRef = regexp.MustCompile(`(?i)\b(?:FROM|JOIN)\s+([` + "`" + `"\w\.]+)`) + describeRef = regexp.MustCompile(`(?i)\b(?:DESCRIBE|DESC)\s+([` + "`" + `"\w\.]+)`) + showFromRef = regexp.MustCompile(`(?i)\bSHOW\b[\w\s]*\b(?:FROM|IN)\s+([` + "`" + `"\w\.]+)`) +) + +type SQLValidator struct { + defaultDatabase string + enforceWhitelist bool + allowedDatabases map[string]struct{} + allowedTables map[string]struct{} + allowedTableOnly map[string]struct{} +} + +type sqlObjectRef struct { + database string + table string +} + +func NewSQLValidator(defaultDatabase string, enforceWhitelist bool, allowedDatabases []string, allowedTables []string) *SQLValidator { + v := &SQLValidator{ + defaultDatabase: strings.ToLower(strings.TrimSpace(defaultDatabase)), + enforceWhitelist: enforceWhitelist, + allowedDatabases: make(map[string]struct{}), + allowedTables: make(map[string]struct{}), + allowedTableOnly: make(map[string]struct{}), + } + for _, db := range allowedDatabases { + db = strings.ToLower(strings.TrimSpace(db)) + if db != "" { + v.allowedDatabases[db] = struct{}{} + } + } + for _, tbl := range allowedTables { + tbl = strings.ToLower(strings.TrimSpace(strings.Trim(tbl, "`\""))) + if tbl == "" { + continue + } + v.allowedTables[tbl] = struct{}{} + if dot := strings.LastIndex(tbl, "."); dot >= 0 && dot < len(tbl)-1 { + v.allowedTableOnly[tbl[dot+1:]] = struct{}{} + } else { + v.allowedTableOnly[tbl] = struct{}{} + } + } + return v +} + +func (v *SQLValidator) ValidateReadOnlySQL(sql string) error { + trimmed := strings.TrimSpace(sql) + if trimmed == "" { + return fmt.Errorf("sql is required") + } + if strings.Contains(trimmed, ";") { + return fmt.Errorf("semicolon is not allowed") + } + if commentPattern.MatchString(trimmed) { + return fmt.Errorf("sql comments are not allowed") + } + + first := strings.ToUpper(firstToken(trimmed)) + switch first { + case "SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN": + default: + return fmt.Errorf("only SELECT/SHOW/DESCRIBE/EXPLAIN are allowed") + } + + withoutLiterals := removeStringLiterals(trimmed) + if forbiddenWords.MatchString(strings.ToUpper(withoutLiterals)) { + return fmt.Errorf("dangerous sql keyword detected") + } + + refs := extractObjectRefs(withoutLiterals) + if err := v.validateWhitelist(refs); err != nil { + return err + } + return nil +} + +func (v *SQLValidator) validateWhitelist(refs []sqlObjectRef) error { + hasDBAllowlist := len(v.allowedDatabases) > 0 + hasTableAllowlist := len(v.allowedTables) > 0 + + if !v.enforceWhitelist && !hasDBAllowlist && !hasTableAllowlist { + return nil + } + if len(refs) == 0 { + if v.enforceWhitelist { + return fmt.Errorf("sql does not contain explicit table reference under whitelist mode") + } + return nil + } + + for _, ref := range refs { + db := ref.database + tbl := ref.table + if db == "" { + db = v.defaultDatabase + } + + if hasDBAllowlist { + if db == "" { + return fmt.Errorf("database is not explicit and no default database set") + } + if _, ok := v.allowedDatabases[db]; !ok { + return fmt.Errorf("database %s is not in whitelist", db) + } + } + + if hasTableAllowlist { + full := tbl + if db != "" { + full = db + "." + tbl + } + if _, ok := v.allowedTables[full]; ok { + continue + } + if _, ok := v.allowedTableOnly[tbl]; ok { + continue + } + return fmt.Errorf("table %s is not in whitelist", full) + } + } + + return nil +} + +func extractObjectRefs(sql string) []sqlObjectRef { + refs := make([]sqlObjectRef, 0) + seen := make(map[string]struct{}) + + for _, re := range []*regexp.Regexp{fromJoinRef, describeRef, showFromRef} { + matches := re.FindAllStringSubmatch(sql, -1) + for _, m := range matches { + if len(m) < 2 { + continue + } + ref := normalizeRef(m[1]) + if ref.table == "" { + continue + } + key := ref.database + "." + ref.table + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + refs = append(refs, ref) + } + } + return refs +} + +func normalizeRef(raw string) sqlObjectRef { + clean := strings.ToLower(strings.TrimSpace(raw)) + clean = strings.Trim(clean, "`\"'(),") + clean = strings.TrimSpace(clean) + if clean == "" { + return sqlObjectRef{} + } + parts := strings.Split(clean, ".") + if len(parts) == 1 { + return sqlObjectRef{table: parts[0]} + } + return sqlObjectRef{database: parts[0], table: parts[len(parts)-1]} +} + +func firstToken(sql string) string { + for i, r := range sql { + if r == ' ' || r == '\n' || r == '\t' || r == '\r' { + if i == 0 { + continue + } + return sql[:i] + } + } + return sql +} + +func removeStringLiterals(sql string) string { + masked := singleQuotedString.ReplaceAllString(sql, "''") + masked = doubleQuotedString.ReplaceAllString(masked, `""`) + return masked +} diff --git a/infra/smartflow-mcp-server/internal/security/sql_validator_test.go b/infra/smartflow-mcp-server/internal/security/sql_validator_test.go new file mode 100644 index 0000000..0d0e203 --- /dev/null +++ b/infra/smartflow-mcp-server/internal/security/sql_validator_test.go @@ -0,0 +1,44 @@ +package security + +import "testing" + +func TestValidateReadOnlySQL(t *testing.T) { + validator := NewSQLValidator("smartflow", true, []string{"smartflow"}, []string{"smartflow.users", "smartflow.tasks"}) + + tests := []struct { + name string + sql string + wantErr bool + }{ + {name: "allow select", sql: "SELECT id, name FROM users WHERE id = 1", wantErr: false}, + {name: "allow explain", sql: "EXPLAIN SELECT * FROM tasks", wantErr: false}, + {name: "reject insert", sql: "INSERT INTO users(name) VALUES('x')", wantErr: true}, + {name: "reject multi statement", sql: "SELECT * FROM users; SELECT * FROM tasks", wantErr: true}, + {name: "reject comment", sql: "SELECT * FROM users -- bypass", wantErr: true}, + {name: "reject not whitelisted table", sql: "SELECT * FROM orders", wantErr: true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := validator.ValidateReadOnlySQL(tc.sql) + if tc.wantErr && err == nil { + t.Fatalf("expected error, got nil") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, got %v", err) + } + }) + } +} + +func TestRedact(t *testing.T) { + masked := RedactSQL("SELECT * FROM users WHERE token='abc123' AND id=42") + if masked == "" || masked == "SELECT * FROM users WHERE token='abc123' AND id=42" { + t.Fatalf("redaction not applied: %s", masked) + } + + key := RedactKey("very-sensitive-key") + if key == "very-sensitive-key" { + t.Fatalf("key not redacted") + } +} diff --git a/infra/smartflow-mcp-server/internal/store/mysql.go b/infra/smartflow-mcp-server/internal/store/mysql.go new file mode 100644 index 0000000..71338ce --- /dev/null +++ b/infra/smartflow-mcp-server/internal/store/mysql.go @@ -0,0 +1,198 @@ +package store + +import ( + "context" + "database/sql" + "fmt" + "net/url" + "reflect" + "strconv" + "strings" + "time" + + cfgpkg "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config" + "github.com/go-sql-driver/mysql" +) + +type MySQLClient struct { + db *sql.DB +} + +type QueryColumn struct { + Name string `json:"name"` + DatabaseType string `json:"databaseType"` + Nullable *bool `json:"nullable,omitempty"` + ScanType string `json:"scanType,omitempty"` +} + +type QueryResult struct { + Columns []QueryColumn `json:"columns"` + Rows []map[string]any `json:"rows"` + RowCount int `json:"rowCount"` + Truncated bool `json:"truncated"` + DurationMs int64 `json:"durationMs"` +} + +func NewMySQLClient(ctx context.Context, cfg cfgpkg.MySQLConfig) (*MySQLClient, error) { + dsn := mysql.Config{ + User: cfg.User, + Passwd: cfg.Password, + Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), + Net: "tcp", + DBName: cfg.Database, + AllowNativePasswords: true, + ParseTime: true, + } + + applyMySQLParams(&dsn, cfg.Params) + + db, err := sql.Open("mysql", dsn.FormatDSN()) + if err != nil { + return nil, fmt.Errorf("open mysql: %w", err) + } + db.SetConnMaxLifetime(5 * time.Minute) + db.SetMaxOpenConns(5) + db.SetMaxIdleConns(5) + + pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + if err := db.PingContext(pingCtx); err != nil { + _ = db.Close() + return nil, fmt.Errorf("ping mysql: %w", err) + } + + return &MySQLClient{db: db}, nil +} + +func applyMySQLParams(cfg *mysql.Config, raw string) { + if strings.TrimSpace(raw) == "" { + return + } + values, err := url.ParseQuery(raw) + if err != nil { + return + } + if cfg.Params == nil { + cfg.Params = make(map[string]string) + } + for key, valueList := range values { + if len(valueList) == 0 { + continue + } + value := valueList[0] + switch strings.ToLower(key) { + case "parsetime": + if parsed, err := strconv.ParseBool(value); err == nil { + cfg.ParseTime = parsed + } + case "loc": + if loc, err := time.LoadLocation(value); err == nil { + cfg.Loc = loc + } + case "collation": + cfg.Collation = value + default: + cfg.Params[key] = value + } + } +} + +func (c *MySQLClient) Close() error { + if c == nil || c.db == nil { + return nil + } + return c.db.Close() +} + +func (c *MySQLClient) QueryReadOnly(ctx context.Context, query string, args []any, maxRows int) (QueryResult, error) { + start := time.Now() + rows, err := c.db.QueryContext(ctx, query, args...) + if err != nil { + return QueryResult{}, err + } + defer rows.Close() + + columnTypes, err := rows.ColumnTypes() + if err != nil { + return QueryResult{}, err + } + + columns := make([]QueryColumn, 0, len(columnTypes)) + columnNames := make([]string, 0, len(columnTypes)) + for _, ct := range columnTypes { + var nullablePtr *bool + if nullable, ok := ct.Nullable(); ok { + n := nullable + nullablePtr = &n + } + scanType := "" + if st := ct.ScanType(); st != nil { + scanType = st.String() + } + columns = append(columns, QueryColumn{ + Name: ct.Name(), + DatabaseType: ct.DatabaseTypeName(), + Nullable: nullablePtr, + ScanType: scanType, + }) + columnNames = append(columnNames, ct.Name()) + } + + resultRows := make([]map[string]any, 0) + truncated := false + for rows.Next() { + if len(resultRows) >= maxRows { + truncated = true + break + } + scanned, err := scanRow(rows, len(columnNames)) + if err != nil { + return QueryResult{}, err + } + rowMap := make(map[string]any, len(columnNames)) + for i, name := range columnNames { + rowMap[name] = normalizeValue(scanned[i]) + } + resultRows = append(resultRows, rowMap) + } + if err := rows.Err(); err != nil { + return QueryResult{}, err + } + + return QueryResult{ + Columns: columns, + Rows: resultRows, + RowCount: len(resultRows), + Truncated: truncated, + DurationMs: time.Since(start).Milliseconds(), + }, nil +} + +func scanRow(rows *sql.Rows, size int) ([]any, error) { + dest := make([]any, size) + holders := make([]any, size) + for i := range dest { + holders[i] = &dest[i] + } + if err := rows.Scan(holders...); err != nil { + return nil, err + } + return dest, nil +} + +func normalizeValue(v any) any { + switch val := v.(type) { + case nil: + return nil + case []byte: + return string(val) + case time.Time: + return val.Format(time.RFC3339Nano) + default: + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Ptr && rv.IsNil() { + return nil + } + return v + } +} diff --git a/infra/smartflow-mcp-server/internal/store/redis.go b/infra/smartflow-mcp-server/internal/store/redis.go new file mode 100644 index 0000000..4472577 --- /dev/null +++ b/infra/smartflow-mcp-server/internal/store/redis.go @@ -0,0 +1,189 @@ +package store + +import ( + "context" + "fmt" + "strings" + "time" + + cfgpkg "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config" + "github.com/go-redis/redis/v8" +) + +type RedisClient struct { + client *redis.Client +} + +type RedisGetResult struct { + Exists bool `json:"exists"` + Key string `json:"key"` + Type string `json:"type"` + Value any `json:"value,omitempty"` + Truncated bool `json:"truncated"` + DurationMs int64 `json:"durationMs"` +} + +type RedisScanResult struct { + Pattern string `json:"pattern"` + Keys []string `json:"keys"` + Returned int `json:"returned"` + NextCursor uint64 `json:"nextCursor"` + Truncated bool `json:"truncated"` + DurationMs int64 `json:"durationMs"` +} + +func NewRedisClient(ctx context.Context, cfg cfgpkg.RedisConfig) (*RedisClient, error) { + client := redis.NewClient(&redis.Options{ + Addr: cfg.Addr, + Password: cfg.Password, + DB: cfg.DB, + }) + + pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + if err := client.Ping(pingCtx).Err(); err != nil { + _ = client.Close() + return nil, fmt.Errorf("ping redis: %w", err) + } + return &RedisClient{client: client}, nil +} + +func (c *RedisClient) Close() error { + if c == nil || c.client == nil { + return nil + } + return c.client.Close() +} + +func (c *RedisClient) GetWithType(ctx context.Context, key string, maxItems int, maxStringBytes int) (RedisGetResult, error) { + start := time.Now() + t, err := c.client.Type(ctx, key).Result() + if err != nil { + return RedisGetResult{}, err + } + if t == "none" { + return RedisGetResult{ + Exists: false, + Key: key, + Type: "none", + Truncated: false, + DurationMs: time.Since(start).Milliseconds(), + }, nil + } + + result := RedisGetResult{Exists: true, Key: key, Type: t} + switch t { + case "string": + v, err := c.client.Get(ctx, key).Result() + if err != nil { + return RedisGetResult{}, err + } + if len(v) > maxStringBytes { + result.Value = v[:maxStringBytes] + result.Truncated = true + } else { + result.Value = v + } + case "list": + vals, err := c.client.LRange(ctx, key, 0, int64(maxItems-1)).Result() + if err != nil { + return RedisGetResult{}, err + } + result.Value = vals + if length, _ := c.client.LLen(ctx, key).Result(); length > int64(maxItems) { + result.Truncated = true + } + case "set": + vals, err := c.client.SMembers(ctx, key).Result() + if err != nil { + return RedisGetResult{}, err + } + if len(vals) > maxItems { + result.Value = vals[:maxItems] + result.Truncated = true + } else { + result.Value = vals + } + case "zset": + vals, err := c.client.ZRangeWithScores(ctx, key, 0, int64(maxItems-1)).Result() + if err != nil { + return RedisGetResult{}, err + } + resultRows := make([]map[string]any, 0, len(vals)) + for _, item := range vals { + resultRows = append(resultRows, map[string]any{"member": item.Member, "score": item.Score}) + } + result.Value = resultRows + if length, _ := c.client.ZCard(ctx, key).Result(); length > int64(maxItems) { + result.Truncated = true + } + case "hash": + vals, err := c.client.HGetAll(ctx, key).Result() + if err != nil { + return RedisGetResult{}, err + } + if len(vals) <= maxItems { + result.Value = vals + } else { + trimmed := make(map[string]string, maxItems) + count := 0 + for k, v := range vals { + trimmed[k] = v + count++ + if count >= maxItems { + break + } + } + result.Value = trimmed + result.Truncated = true + } + default: + raw, err := c.client.Dump(ctx, key).Result() + if err != nil { + return RedisGetResult{}, err + } + result.Value = strings.ToUpper(fmt.Sprintf("UNSUPPORTED_TYPE_%s_DUMP_SIZE_%d", t, len(raw))) + } + result.DurationMs = time.Since(start).Milliseconds() + return result, nil +} + +func (c *RedisClient) ScanKeys(ctx context.Context, pattern string, count int64, maxKeys int) (RedisScanResult, error) { + start := time.Now() + if pattern == "" { + pattern = "*" + } + if count <= 0 { + count = 20 + } + + keys := make([]string, 0, maxKeys) + var cursor uint64 + truncated := false + for { + batch, nextCursor, err := c.client.Scan(ctx, cursor, pattern, count).Result() + if err != nil { + return RedisScanResult{}, err + } + for _, key := range batch { + if len(keys) >= maxKeys { + truncated = true + break + } + keys = append(keys, key) + } + cursor = nextCursor + if truncated || cursor == 0 { + break + } + } + + return RedisScanResult{ + Pattern: pattern, + Keys: keys, + Returned: len(keys), + NextCursor: cursor, + Truncated: truncated, + DurationMs: time.Since(start).Milliseconds(), + }, nil +} diff --git a/infra/smartflow-mcp-server/internal/tools/integration_test.go b/infra/smartflow-mcp-server/internal/tools/integration_test.go new file mode 100644 index 0000000..15d6a09 --- /dev/null +++ b/infra/smartflow-mcp-server/internal/tools/integration_test.go @@ -0,0 +1,95 @@ +package tools + +import ( + "context" + "os" + "strconv" + "testing" + "time" + + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config" + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security" + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store" +) + +func TestIntegrationMySQLReadOnlyTool(t *testing.T) { + if os.Getenv("MCP_IT_RUN") != "1" { + t.Skip("set MCP_IT_RUN=1 to run integration tests") + } + + port := 3306 + if p := os.Getenv("MYSQL_PORT"); p != "" { + if n, err := strconv.Atoi(p); err == nil { + port = n + } + } + + mysqlCfg := config.MySQLConfig{ + Host: os.Getenv("MYSQL_HOST"), + Port: port, + User: os.Getenv("MYSQL_USER"), + Password: os.Getenv("MYSQL_PASSWORD"), + Database: os.Getenv("MYSQL_DATABASE"), + Params: "charset=utf8mb4&parseTime=true&loc=Local", + } + if mysqlCfg.Host == "" || mysqlCfg.User == "" || mysqlCfg.Database == "" { + t.Skip("missing MYSQL_* env for integration test") + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + client, err := store.NewMySQLClient(ctx, mysqlCfg) + if err != nil { + t.Fatalf("mysql not available: %v", err) + } + defer func() { _ = client.Close() }() + + validator := security.NewSQLValidator(mysqlCfg.Database, false, nil, nil) + tool := NewMySQLReadOnlyTool(client, validator, 10) + res, err := tool.Execute(ctx, map[string]any{"sql": "SELECT 1 AS ok"}) + if err != nil { + t.Fatalf("tool execute failed: %v", err) + } + if res["rowCount"].(int) < 1 { + t.Fatalf("expected rowCount >= 1") + } +} + +func TestIntegrationRedisTools(t *testing.T) { + if os.Getenv("MCP_IT_RUN") != "1" { + t.Skip("set MCP_IT_RUN=1 to run integration tests") + } + + db := 0 + if p := os.Getenv("REDIS_DB"); p != "" { + if n, err := strconv.Atoi(p); err == nil { + db = n + } + } + redisCfg := config.RedisConfig{ + Addr: os.Getenv("REDIS_ADDR"), + Password: os.Getenv("REDIS_PASSWORD"), + DB: db, + } + if redisCfg.Addr == "" { + t.Skip("REDIS_ADDR is empty") + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + client, err := store.NewRedisClient(ctx, redisCfg) + if err != nil { + t.Fatalf("redis not available: %v", err) + } + defer func() { _ = client.Close() }() + + getTool := NewRedisGetTool(client, 10, 128) + if _, err := getTool.Execute(ctx, map[string]any{"key": "__integration_missing_key__"}); err != nil { + t.Fatalf("redis_get failed: %v", err) + } + + scanTool := NewRedisScanTool(client, 10, 10) + if _, err := scanTool.Execute(ctx, map[string]any{"pattern": "*", "count": float64(5)}); err != nil { + t.Fatalf("redis_scan failed: %v", err) + } +} diff --git a/infra/smartflow-mcp-server/internal/tools/mysql_readonly.go b/infra/smartflow-mcp-server/internal/tools/mysql_readonly.go new file mode 100644 index 0000000..e56d864 --- /dev/null +++ b/infra/smartflow-mcp-server/internal/tools/mysql_readonly.go @@ -0,0 +1,95 @@ +package tools + +import ( + "context" + "fmt" + + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security" + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store" +) + +type MySQLReadOnlyTool struct { + client *store.MySQLClient + validator *security.SQLValidator + maxRows int +} + +func NewMySQLReadOnlyTool(client *store.MySQLClient, validator *security.SQLValidator, maxRows int) *MySQLReadOnlyTool { + return &MySQLReadOnlyTool{client: client, validator: validator, maxRows: maxRows} +} + +func (t *MySQLReadOnlyTool) Name() string { + return "mysql_query_readonly" +} + +func (t *MySQLReadOnlyTool) Description() string { + return "Execute read-only SQL on MySQL. Only SELECT/SHOW/DESCRIBE/EXPLAIN are allowed." +} + +func (t *MySQLReadOnlyTool) InputSchema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "sql": map[string]any{ + "type": "string", + "description": "Read-only SQL statement", + }, + "params": map[string]any{ + "type": "array", + "description": "Optional bind parameters", + "items": map[string]any{ + "type": []string{"string", "number", "boolean", "null"}, + }, + }, + }, + "required": []string{"sql"}, + "additionalProperties": false, + } +} + +func (t *MySQLReadOnlyTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) { + rawSQL, ok := args["sql"].(string) + if !ok || rawSQL == "" { + return nil, fmt.Errorf("sql must be a non-empty string") + } + if err := t.validator.ValidateReadOnlySQL(rawSQL); err != nil { + return nil, err + } + + params, err := normalizeParams(args["params"]) + if err != nil { + return nil, err + } + + res, err := t.client.QueryReadOnly(ctx, rawSQL, params, t.maxRows) + if err != nil { + return nil, err + } + return map[string]any{ + "columns": res.Columns, + "rows": res.Rows, + "rowCount": res.RowCount, + "truncated": res.Truncated, + "durationMs": res.DurationMs, + }, nil +} + +func normalizeParams(raw any) ([]any, error) { + if raw == nil { + return nil, nil + } + arr, ok := raw.([]any) + if !ok { + return nil, fmt.Errorf("params must be an array") + } + out := make([]any, 0, len(arr)) + for _, item := range arr { + switch v := item.(type) { + case string, float64, bool, nil: + out = append(out, v) + default: + return nil, fmt.Errorf("params contains unsupported type") + } + } + return out, nil +} diff --git a/infra/smartflow-mcp-server/internal/tools/redis_tools.go b/infra/smartflow-mcp-server/internal/tools/redis_tools.go new file mode 100644 index 0000000..1d734d4 --- /dev/null +++ b/infra/smartflow-mcp-server/internal/tools/redis_tools.go @@ -0,0 +1,130 @@ +package tools + +import ( + "context" + "fmt" + + "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store" +) + +type RedisGetTool struct { + client *store.RedisClient + valueMaxItems int + maxStringBytes int +} + +func NewRedisGetTool(client *store.RedisClient, valueMaxItems int, maxStringBytes int) *RedisGetTool { + return &RedisGetTool{client: client, valueMaxItems: valueMaxItems, maxStringBytes: maxStringBytes} +} + +func (t *RedisGetTool) Name() string { + return "redis_get" +} + +func (t *RedisGetTool) Description() string { + return "Get a Redis key by name and return its type and value." +} + +func (t *RedisGetTool) InputSchema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "key": map[string]any{ + "type": "string", + "description": "Redis key", + }, + }, + "required": []string{"key"}, + "additionalProperties": false, + } +} + +func (t *RedisGetTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) { + key, ok := args["key"].(string) + if !ok || key == "" { + return nil, fmt.Errorf("key must be a non-empty string") + } + res, err := t.client.GetWithType(ctx, key, t.valueMaxItems, t.maxStringBytes) + if err != nil { + return nil, err + } + return map[string]any{ + "exists": res.Exists, + "key": res.Key, + "type": res.Type, + "value": res.Value, + "truncated": res.Truncated, + "durationMs": res.DurationMs, + }, nil +} + +type RedisScanTool struct { + client *store.RedisClient + maxKeys int + maxScanCount int +} + +func NewRedisScanTool(client *store.RedisClient, maxKeys int, maxScanCount int) *RedisScanTool { + return &RedisScanTool{client: client, maxKeys: maxKeys, maxScanCount: maxScanCount} +} + +func (t *RedisScanTool) Name() string { + return "redis_scan" +} + +func (t *RedisScanTool) Description() string { + return "Scan Redis keys by pattern with capped result size." +} + +func (t *RedisScanTool) InputSchema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "pattern": map[string]any{ + "type": "string", + "description": "Pattern, for example user:*", + }, + "count": map[string]any{ + "type": "number", + "description": "Optional scan count hint", + }, + }, + "required": []string{"pattern"}, + "additionalProperties": false, + } +} + +func (t *RedisScanTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) { + pattern, ok := args["pattern"].(string) + if !ok || pattern == "" { + return nil, fmt.Errorf("pattern must be a non-empty string") + } + + count := int64(20) + if rawCount, ok := args["count"]; ok { + number, ok := rawCount.(float64) + if !ok { + return nil, fmt.Errorf("count must be a number") + } + if number <= 0 { + return nil, fmt.Errorf("count must be > 0") + } + count = int64(number) + } + if count > int64(t.maxScanCount) { + count = int64(t.maxScanCount) + } + + res, err := t.client.ScanKeys(ctx, pattern, count, t.maxKeys) + if err != nil { + return nil, err + } + return map[string]any{ + "pattern": res.Pattern, + "keys": res.Keys, + "returned": res.Returned, + "nextCursor": res.NextCursor, + "truncated": res.Truncated, + "durationMs": res.DurationMs, + }, nil +} diff --git a/infra/smartflow-mcp-server/internal/tools/registry.go b/infra/smartflow-mcp-server/internal/tools/registry.go new file mode 100644 index 0000000..0e60325 --- /dev/null +++ b/infra/smartflow-mcp-server/internal/tools/registry.go @@ -0,0 +1,53 @@ +package tools + +import ( + "context" + "fmt" + "sort" +) + +type Tool interface { + Name() string + Description() string + InputSchema() map[string]any + Execute(ctx context.Context, args map[string]any) (map[string]any, error) +} + +type Registry struct { + tools map[string]Tool +} + +func NewRegistry(toolList ...Tool) (*Registry, error) { + r := &Registry{tools: make(map[string]Tool, len(toolList))} + for _, t := range toolList { + name := t.Name() + if _, exists := r.tools[name]; exists { + return nil, fmt.Errorf("duplicated tool name: %s", name) + } + r.tools[name] = t + } + return r, nil +} + +func (r *Registry) Find(name string) (Tool, bool) { + t, ok := r.tools[name] + return t, ok +} + +func (r *Registry) List() []map[string]any { + out := make([]map[string]any, 0, len(r.tools)) + names := make([]string, 0, len(r.tools)) + for name := range r.tools { + names = append(names, name) + } + sort.Strings(names) + for _, name := range names { + t := r.tools[name] + out = append(out, map[string]any{ + "name": t.Name(), + "description": t.Description(), + "inputSchema": t.InputSchema(), + }) + } + return out +}