Version: 0.4.4.dev.260307
feat: 🚀 增强会话管理与缓存机制 * 会话 ID 空值兜底,若 `conversation_id` 为空时自动生成 UUID * 在响应头写入 `X-Conversation-ID`,供前端使用,保持同一会话状态 perf: ⚡ 会话状态缓存优化 * 当缓存未命中但 DB 已确认/创建会话后,调用 `SetConversationStatus` 回写 Redis * 缓存写回失败时记录日志,不中断聊天主流程,确保业务流畅性 fix: 🐛 修复历史消息顺序问题与编译错误 * 修复历史消息顺序问题,保证返回的 N 条历史消息按时间正序喂给模型 * 通过反转 `created_at desc` 查询结果的切片,确保模型输入顺序正确 * 修复 `fmt.Errorf` 参数不匹配问题,修正编译错误 * 整理 `agent-cache.go` 为标准 UTF-8 编码,避免 Go 编译报错 `invalid UTF-8 encoding` feat: 🛠️ 独立构建 MCP 服务器 * 使用 `Codex` 构建独立于后端的 MCP 服务器,简化与 Codex 的协作 * 通过该服务器方便 Codex 直接测试和查看 Redis 与 MySQL 中的数据
This commit is contained in:
10
infra/smartflow-mcp-server/.editorconfig
Normal file
10
infra/smartflow-mcp-server/.editorconfig
Normal file
@@ -0,0 +1,10 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
charset = utf-8
|
||||
end_of_line = lf
|
||||
insert_final_newline = true
|
||||
trim_trailing_whitespace = true
|
||||
|
||||
[*.md]
|
||||
trim_trailing_whitespace = false
|
||||
46
infra/smartflow-mcp-server/.env.example
Normal file
46
infra/smartflow-mcp-server/.env.example
Normal file
@@ -0,0 +1,46 @@
|
||||
# =========================
|
||||
# MCP server metadata
|
||||
# =========================
|
||||
MCP_SERVER_NAME=smartflow-mcp-server
|
||||
MCP_SERVER_VERSION=0.1.0
|
||||
MCP_PROTOCOL_VERSION=2024-11-05
|
||||
MCP_DEFAULT_CALLER=codex
|
||||
|
||||
# =========================
|
||||
# Governance & safety
|
||||
# =========================
|
||||
MCP_TOOL_TIMEOUT_MS=5000
|
||||
MCP_RATE_LIMIT_RPS=5
|
||||
MCP_RATE_LIMIT_BURST=10
|
||||
MCP_MAX_RESULT_ROWS=500
|
||||
MCP_ENFORCE_WHITELIST=false
|
||||
MCP_AUDIT_LOG_PATH=logs/audit.log
|
||||
|
||||
# Redis scan/value caps
|
||||
MCP_REDIS_SCAN_MAX_KEYS=200
|
||||
MCP_REDIS_SCAN_MAX_COUNT=200
|
||||
MCP_REDIS_VALUE_MAX_ITEMS=100
|
||||
MCP_REDIS_MAX_STRING_BYTES=4096
|
||||
|
||||
# =========================
|
||||
# MySQL
|
||||
# =========================
|
||||
MYSQL_HOST=127.0.0.1
|
||||
MYSQL_PORT=3306
|
||||
MYSQL_USER=readonly_user
|
||||
MYSQL_PASSWORD=replace_me
|
||||
MYSQL_DATABASE=smartflow
|
||||
MYSQL_PARAMS=charset=utf8mb4&parseTime=true&loc=Local
|
||||
|
||||
# Comma-separated whitelist (optional)
|
||||
# Example: MYSQL_ALLOWED_DATABASES=smartflow,analytics
|
||||
MYSQL_ALLOWED_DATABASES=smartflow
|
||||
# Example: MYSQL_ALLOWED_TABLES=smartflow.users,smartflow.tasks
|
||||
MYSQL_ALLOWED_TABLES=smartflow.users,smartflow.tasks
|
||||
|
||||
# =========================
|
||||
# Redis
|
||||
# =========================
|
||||
REDIS_ADDR=127.0.0.1:6379
|
||||
REDIS_PASSWORD=
|
||||
REDIS_DB=0
|
||||
2
infra/smartflow-mcp-server/.gitignore
vendored
Normal file
2
infra/smartflow-mcp-server/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
.env
|
||||
logs/*.log
|
||||
213
infra/smartflow-mcp-server/README.md
Normal file
213
infra/smartflow-mcp-server/README.md
Normal file
@@ -0,0 +1,213 @@
|
||||
# smartflow-mcp-server (MVP)
|
||||
|
||||
用于让 Codex 通过 MCP(stdio)只读访问 MySQL 与 Redis,面向接口联调与测试。
|
||||
|
||||
## 1. 功能范围(第一阶段)
|
||||
|
||||
只实现 3 个只读工具:
|
||||
|
||||
1. `mysql_query_readonly`
|
||||
2. `redis_get`
|
||||
3. `redis_scan`
|
||||
|
||||
未实现任何写操作工具。
|
||||
|
||||
## 2. 目录结构
|
||||
|
||||
```text
|
||||
infra/smartflow-mcp-server
|
||||
├─ cmd/server/main.go
|
||||
├─ internal
|
||||
│ ├─ audit
|
||||
│ ├─ config
|
||||
│ ├─ envutil
|
||||
│ ├─ mcp
|
||||
│ ├─ ratelimit
|
||||
│ ├─ security
|
||||
│ ├─ store
|
||||
│ └─ tools
|
||||
├─ .env.example
|
||||
├─ go.mod
|
||||
└─ README.md
|
||||
```
|
||||
|
||||
## 3. 快速启动
|
||||
|
||||
```bash
|
||||
go mod tidy
|
||||
go test ./...
|
||||
go run ./cmd/server
|
||||
```
|
||||
|
||||
服务采用 stdio MCP 协议,不会启动 HTTP 端口。
|
||||
|
||||
## 4. 配置说明(全部来自环境变量)
|
||||
|
||||
复制并编辑:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
关键变量:
|
||||
|
||||
- `MYSQL_HOST` / `MYSQL_PORT` / `MYSQL_USER` / `MYSQL_PASSWORD` / `MYSQL_DATABASE`
|
||||
- `REDIS_ADDR` / `REDIS_PASSWORD` / `REDIS_DB`
|
||||
- `MYSQL_ALLOWED_DATABASES`:逗号分隔
|
||||
- `MYSQL_ALLOWED_TABLES`:逗号分隔,支持 `db.table` 或 `table`
|
||||
- `MCP_ENFORCE_WHITELIST`:`true` 时无明确表引用会拒绝执行
|
||||
- `MCP_TOOL_TIMEOUT_MS`:单次工具调用超时
|
||||
- `MCP_RATE_LIMIT_RPS` + `MCP_RATE_LIMIT_BURST`:基础令牌桶限流
|
||||
- `MCP_MAX_RESULT_ROWS`:MySQL 最大返回行数
|
||||
- `MCP_REDIS_SCAN_MAX_KEYS`:`redis_scan` 最大返回 key 数
|
||||
- `MCP_AUDIT_LOG_PATH`:审计日志路径
|
||||
|
||||
## 5. 工具说明
|
||||
|
||||
### 5.1 `mysql_query_readonly`
|
||||
|
||||
输入:
|
||||
|
||||
```json
|
||||
{
|
||||
"sql": "SELECT id, name FROM users WHERE id = ?",
|
||||
"params": [1]
|
||||
}
|
||||
```
|
||||
|
||||
安全限制:
|
||||
|
||||
- 仅允许 `SELECT` / `SHOW` / `DESCRIBE` / `EXPLAIN`
|
||||
- 禁止分号 `;`(多语句)
|
||||
- 禁止注释 `--` / `#` / `/* */`
|
||||
- 禁止 DDL/DML 关键字(`INSERT`/`UPDATE`/`DELETE`/`ALTER`/`DROP`/`TRUNCATE` 等)
|
||||
- 支持库/表白名单校验
|
||||
|
||||
输出(结构化):
|
||||
|
||||
- `columns`
|
||||
- `rows`
|
||||
- `rowCount`
|
||||
- `truncated`
|
||||
- `durationMs`
|
||||
|
||||
### 5.2 `redis_get`
|
||||
|
||||
输入:
|
||||
|
||||
```json
|
||||
{
|
||||
"key": "user:1001"
|
||||
}
|
||||
```
|
||||
|
||||
输出:
|
||||
|
||||
- `exists`
|
||||
- `key`
|
||||
- `type`
|
||||
- `value`
|
||||
- `truncated`
|
||||
- `durationMs`
|
||||
|
||||
### 5.3 `redis_scan`
|
||||
|
||||
输入:
|
||||
|
||||
```json
|
||||
{
|
||||
"pattern": "user:*",
|
||||
"count": 50
|
||||
}
|
||||
```
|
||||
|
||||
输出:
|
||||
|
||||
- `pattern`
|
||||
- `keys`
|
||||
- `returned`
|
||||
- `nextCursor`
|
||||
- `truncated`
|
||||
- `durationMs`
|
||||
|
||||
## 6. 审计日志
|
||||
|
||||
每次工具调用会记录(JSON 行格式):
|
||||
|
||||
- 时间
|
||||
- 工具名
|
||||
- 调用方(caller)
|
||||
- 是否成功
|
||||
- 耗时
|
||||
- 脱敏后的输入摘要
|
||||
- 错误信息(截断)
|
||||
|
||||
敏感字段处理:
|
||||
|
||||
- SQL 字符串字面量与数字会脱敏
|
||||
- Redis key 仅保留前后少量字符
|
||||
|
||||
## 7. Codex MCP 配置示例(stdio)
|
||||
|
||||
可按客户端配置格式接入,示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"smartflow-db-readonly": {
|
||||
"command": "go",
|
||||
"args": ["run", "./cmd/server"],
|
||||
"cwd": "E:/SmartFlow-Agent/infra/smartflow-mcp-server",
|
||||
"env": {
|
||||
"MYSQL_HOST": "127.0.0.1",
|
||||
"MYSQL_PORT": "3306",
|
||||
"MYSQL_USER": "readonly_user",
|
||||
"MYSQL_PASSWORD": "replace_me",
|
||||
"MYSQL_DATABASE": "smartflow",
|
||||
"MYSQL_ALLOWED_DATABASES": "smartflow",
|
||||
"MYSQL_ALLOWED_TABLES": "smartflow.users,smartflow.tasks",
|
||||
"REDIS_ADDR": "127.0.0.1:6379",
|
||||
"REDIS_DB": "0",
|
||||
"MCP_TOOL_TIMEOUT_MS": "5000",
|
||||
"MCP_RATE_LIMIT_RPS": "5",
|
||||
"MCP_RATE_LIMIT_BURST": "10",
|
||||
"MCP_MAX_RESULT_ROWS": "500",
|
||||
"MCP_REDIS_SCAN_MAX_KEYS": "200",
|
||||
"MCP_AUDIT_LOG_PATH": "logs/audit.log"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 8. 安全限制生效示例
|
||||
|
||||
- SQL 多语句:`SELECT 1; SELECT 2` -> 被拒绝(semicolon is not allowed)
|
||||
- SQL 注释绕过:`SELECT * FROM users --x` -> 被拒绝(sql comments are not allowed)
|
||||
- 写操作:`DELETE FROM users` -> 被拒绝(dangerous sql keyword detected)
|
||||
- 白名单外表:`SELECT * FROM admin.secret` -> 被拒绝(table not in whitelist)
|
||||
- Redis 大范围扫描:`redis_scan` 返回数量受 `MCP_REDIS_SCAN_MAX_KEYS` 限制
|
||||
|
||||
## 9. 风险说明(MVP 已知边界)
|
||||
|
||||
1. SQL 校验采用关键字与模式匹配,不是完整 SQL AST 解析,建议二阶段引入 AST 级校验。
|
||||
2. `SHOW DATABASES` 等无显式表引用语句在非严格模式下可执行;生产建议开启 `MCP_ENFORCE_WHITELIST=true`。
|
||||
3. Redis 复杂类型返回做了截断保护,但仍建议在生产环境设置更小上限。
|
||||
|
||||
## 10. 常见问题(FAQ)
|
||||
|
||||
### Q1: 启动时报 `MYSQL_USER and MYSQL_DATABASE are required`
|
||||
|
||||
检查环境变量是否正确加载,建议先确认 `.env` 存在于 `infra/smartflow-mcp-server`。
|
||||
|
||||
### Q2: 为什么调用工具报限流
|
||||
|
||||
默认启用了令牌桶限流,调大 `MCP_RATE_LIMIT_RPS` 与 `MCP_RATE_LIMIT_BURST` 即可。
|
||||
|
||||
### Q3: 为什么 `redis_scan` 返回不全
|
||||
|
||||
是预期行为,结果数被 `MCP_REDIS_SCAN_MAX_KEYS` 限制,避免全量扫描拖垮 Redis。
|
||||
|
||||
### Q4: 审计日志在哪里
|
||||
|
||||
默认在 `logs/audit.log`,可用 `MCP_AUDIT_LOG_PATH` 自定义。
|
||||
90
infra/smartflow-mcp-server/cmd/server/main.go
Normal file
90
infra/smartflow-mcp-server/cmd/server/main.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/audit"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/envutil"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/mcp"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := envutil.LoadDotEnv(".env"); err != nil {
|
||||
log.Fatalf("load .env failed: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadFromEnv()
|
||||
if err != nil {
|
||||
log.Fatalf("load config failed: %v", err)
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
auditLogger, err := audit.New(cfg.AuditLogPath)
|
||||
if err != nil {
|
||||
log.Fatalf("init audit logger failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = auditLogger.Close()
|
||||
}()
|
||||
|
||||
mysqlClient, err := store.NewMySQLClient(ctx, cfg.MySQL)
|
||||
if err != nil {
|
||||
log.Fatalf("init mysql failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = mysqlClient.Close()
|
||||
}()
|
||||
|
||||
redisClient, err := store.NewRedisClient(ctx, cfg.Redis)
|
||||
if err != nil {
|
||||
log.Fatalf("init redis failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = redisClient.Close()
|
||||
}()
|
||||
|
||||
sqlValidator := security.NewSQLValidator(
|
||||
cfg.MySQL.Database,
|
||||
cfg.EnforceWhitelist,
|
||||
cfg.MySQL.AllowedDatabases,
|
||||
cfg.MySQL.AllowedTables,
|
||||
)
|
||||
|
||||
registry, err := tools.NewRegistry(
|
||||
tools.NewMySQLReadOnlyTool(mysqlClient, sqlValidator, cfg.MaxResultRows),
|
||||
tools.NewRedisGetTool(redisClient, cfg.RedisValueMaxItems, cfg.RedisMaxStringBytes),
|
||||
tools.NewRedisScanTool(redisClient, cfg.RedisScanMaxKeys, cfg.RedisScanMaxCount),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatalf("init tool registry failed: %v", err)
|
||||
}
|
||||
|
||||
limiter := ratelimit.New(cfg.RateLimitRPS, cfg.RateLimitBurst)
|
||||
server := mcp.NewServer(
|
||||
os.Stdin,
|
||||
os.Stdout,
|
||||
registry,
|
||||
auditLogger,
|
||||
limiter,
|
||||
cfg.ServerName,
|
||||
cfg.ServerVersion,
|
||||
cfg.ProtocolVersion,
|
||||
cfg.DefaultCaller,
|
||||
cfg.ToolTimeout,
|
||||
)
|
||||
|
||||
if err := server.Serve(ctx); err != nil {
|
||||
log.Fatalf("mcp server exited with error: %v", err)
|
||||
}
|
||||
}
|
||||
14
infra/smartflow-mcp-server/go.mod
Normal file
14
infra/smartflow-mcp-server/go.mod
Normal file
@@ -0,0 +1,14 @@
|
||||
module github.com/LoveLosita/smartflow/infra/smartflow-mcp-server
|
||||
|
||||
go 1.23.4
|
||||
|
||||
require (
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/go-sql-driver/mysql v1.8.1
|
||||
)
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
)
|
||||
28
infra/smartflow-mcp-server/go.sum
Normal file
28
infra/smartflow-mcp-server/go.sum
Normal file
@@ -0,0 +1,28 @@
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
|
||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
|
||||
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0=
|
||||
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
||||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM=
|
||||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
60
infra/smartflow-mcp-server/internal/audit/logger.go
Normal file
60
infra/smartflow-mcp-server/internal/audit/logger.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
mu sync.Mutex
|
||||
file *os.File
|
||||
}
|
||||
|
||||
type Record struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Tool string `json:"tool"`
|
||||
Caller string `json:"caller"`
|
||||
Success bool `json:"success"`
|
||||
DurationMs int64 `json:"duration_ms"`
|
||||
Meta map[string]any `json:"meta,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func New(path string) (*Logger, error) {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create audit dir: %w", err)
|
||||
}
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open audit log file: %w", err)
|
||||
}
|
||||
return &Logger{file: f}, nil
|
||||
}
|
||||
|
||||
func (l *Logger) Close() error {
|
||||
if l == nil || l.file == nil {
|
||||
return nil
|
||||
}
|
||||
return l.file.Close()
|
||||
}
|
||||
|
||||
func (l *Logger) Log(record Record) {
|
||||
if l == nil || l.file == nil {
|
||||
return
|
||||
}
|
||||
if record.Timestamp.IsZero() {
|
||||
record.Timestamp = time.Now()
|
||||
}
|
||||
body, err := json.Marshal(record)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
_, _ = l.file.Write(append(body, '\n'))
|
||||
}
|
||||
163
infra/smartflow-mcp-server/internal/config/config.go
Normal file
163
infra/smartflow-mcp-server/internal/config/config.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ServerName string
|
||||
ServerVersion string
|
||||
ProtocolVersion string
|
||||
DefaultCaller string
|
||||
ToolTimeout time.Duration
|
||||
RateLimitRPS float64
|
||||
RateLimitBurst float64
|
||||
MaxResultRows int
|
||||
AuditLogPath string
|
||||
EnforceWhitelist bool
|
||||
RedisScanMaxKeys int
|
||||
RedisScanMaxCount int
|
||||
RedisValueMaxItems int
|
||||
RedisMaxStringBytes int
|
||||
|
||||
MySQL MySQLConfig
|
||||
Redis RedisConfig
|
||||
}
|
||||
|
||||
type MySQLConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
Database string
|
||||
Params string
|
||||
AllowedDatabases []string
|
||||
AllowedTables []string
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
Addr string
|
||||
Password string
|
||||
DB int
|
||||
}
|
||||
|
||||
func LoadFromEnv() (Config, error) {
|
||||
cfg := Config{
|
||||
ServerName: getEnv("MCP_SERVER_NAME", "smartflow-mcp-server"),
|
||||
ServerVersion: getEnv("MCP_SERVER_VERSION", "0.1.0"),
|
||||
ProtocolVersion: getEnv("MCP_PROTOCOL_VERSION", "2024-11-05"),
|
||||
DefaultCaller: getEnv("MCP_DEFAULT_CALLER", "unknown"),
|
||||
ToolTimeout: getEnvDurationMS("MCP_TOOL_TIMEOUT_MS", 5000),
|
||||
RateLimitRPS: getEnvFloat("MCP_RATE_LIMIT_RPS", 5),
|
||||
RateLimitBurst: getEnvFloat("MCP_RATE_LIMIT_BURST", 10),
|
||||
MaxResultRows: getEnvInt("MCP_MAX_RESULT_ROWS", 500),
|
||||
AuditLogPath: getEnv("MCP_AUDIT_LOG_PATH", "logs/audit.log"),
|
||||
EnforceWhitelist: getEnvBool("MCP_ENFORCE_WHITELIST", false),
|
||||
RedisScanMaxKeys: getEnvInt("MCP_REDIS_SCAN_MAX_KEYS", 200),
|
||||
RedisScanMaxCount: getEnvInt("MCP_REDIS_SCAN_MAX_COUNT", 200),
|
||||
RedisValueMaxItems: getEnvInt("MCP_REDIS_VALUE_MAX_ITEMS", 100),
|
||||
RedisMaxStringBytes: getEnvInt("MCP_REDIS_MAX_STRING_BYTES", 4096),
|
||||
MySQL: MySQLConfig{
|
||||
Host: getEnv("MYSQL_HOST", "127.0.0.1"),
|
||||
Port: getEnvInt("MYSQL_PORT", 3306),
|
||||
User: getEnv("MYSQL_USER", ""),
|
||||
Password: getEnv("MYSQL_PASSWORD", ""),
|
||||
Database: getEnv("MYSQL_DATABASE", ""),
|
||||
Params: getEnv("MYSQL_PARAMS", "charset=utf8mb4&parseTime=true&loc=Local"),
|
||||
AllowedDatabases: splitCommaList(getEnv("MYSQL_ALLOWED_DATABASES", "")),
|
||||
AllowedTables: splitCommaList(getEnv("MYSQL_ALLOWED_TABLES", "")),
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Addr: getEnv("REDIS_ADDR", "127.0.0.1:6379"),
|
||||
Password: getEnv("REDIS_PASSWORD", ""),
|
||||
DB: getEnvInt("REDIS_DB", 0),
|
||||
},
|
||||
}
|
||||
|
||||
if cfg.MySQL.User == "" || cfg.MySQL.Database == "" {
|
||||
return Config{}, fmt.Errorf("MYSQL_USER and MYSQL_DATABASE are required")
|
||||
}
|
||||
if cfg.Redis.Addr == "" {
|
||||
return Config{}, fmt.Errorf("REDIS_ADDR is required")
|
||||
}
|
||||
if cfg.MaxResultRows <= 0 {
|
||||
return Config{}, fmt.Errorf("MCP_MAX_RESULT_ROWS must be > 0")
|
||||
}
|
||||
if cfg.RedisScanMaxKeys <= 0 {
|
||||
return Config{}, fmt.Errorf("MCP_REDIS_SCAN_MAX_KEYS must be > 0")
|
||||
}
|
||||
if cfg.RedisScanMaxCount <= 0 {
|
||||
return Config{}, fmt.Errorf("MCP_REDIS_SCAN_MAX_COUNT must be > 0")
|
||||
}
|
||||
if cfg.ToolTimeout <= 0 {
|
||||
return Config{}, fmt.Errorf("MCP_TOOL_TIMEOUT_MS must be > 0")
|
||||
}
|
||||
if cfg.RateLimitRPS <= 0 || cfg.RateLimitBurst <= 0 {
|
||||
return Config{}, fmt.Errorf("MCP_RATE_LIMIT_RPS and MCP_RATE_LIMIT_BURST must be > 0")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func getEnv(key string, defaultValue string) string {
|
||||
if v, ok := os.LookupEnv(key); ok {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvInt(key string, defaultValue int) int {
|
||||
v := getEnv(key, "")
|
||||
if v == "" {
|
||||
return defaultValue
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func getEnvFloat(key string, defaultValue float64) float64 {
|
||||
v := getEnv(key, "")
|
||||
if v == "" {
|
||||
return defaultValue
|
||||
}
|
||||
n, err := strconv.ParseFloat(v, 64)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func getEnvBool(key string, defaultValue bool) bool {
|
||||
v := strings.ToLower(getEnv(key, ""))
|
||||
if v == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return v == "1" || v == "true" || v == "yes" || v == "on"
|
||||
}
|
||||
|
||||
func getEnvDurationMS(key string, defaultValueMs int) time.Duration {
|
||||
ms := getEnvInt(key, defaultValueMs)
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
|
||||
func splitCommaList(raw string) []string {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(raw, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
trimmed := strings.TrimSpace(p)
|
||||
if trimmed != "" {
|
||||
out = append(out, strings.ToLower(trimmed))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
44
infra/smartflow-mcp-server/internal/envutil/loader.go
Normal file
44
infra/smartflow-mcp-server/internal/envutil/loader.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package envutil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func LoadDotEnv(path string) error {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("open .env: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
i := strings.Index(line, "=")
|
||||
if i <= 0 {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(line[:i])
|
||||
value := strings.TrimSpace(line[i+1:])
|
||||
value = strings.Trim(value, "\"")
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := os.LookupEnv(key); !exists {
|
||||
_ = os.Setenv(key, value)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("scan .env: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
394
infra/smartflow-mcp-server/internal/mcp/server.go
Normal file
394
infra/smartflow-mcp-server/internal/mcp/server.go
Normal file
@@ -0,0 +1,394 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/audit"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
jsonRPCVersion = "2.0"
|
||||
|
||||
errCodeParseError = -32700
|
||||
errCodeInvalidRequest = -32600
|
||||
errCodeMethodNotFound = -32601
|
||||
errCodeInvalidParams = -32602
|
||||
errCodeInternalError = -32603
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
reader *bufio.Reader
|
||||
writer io.Writer
|
||||
writeMu sync.Mutex
|
||||
registry *tools.Registry
|
||||
auditLogger *audit.Logger
|
||||
limiter *ratelimit.Limiter
|
||||
serverName string
|
||||
serverVersion string
|
||||
protocolVersion string
|
||||
defaultCaller string
|
||||
toolTimeout time.Duration
|
||||
}
|
||||
|
||||
type request struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id,omitempty"`
|
||||
Method string `json:"method"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type response struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id,omitempty"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error *respError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type respError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type initializeParams struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
}
|
||||
|
||||
type toolCallParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
Meta map[string]any `json:"_meta,omitempty"`
|
||||
}
|
||||
|
||||
func NewServer(
|
||||
in io.Reader,
|
||||
out io.Writer,
|
||||
registry *tools.Registry,
|
||||
auditLogger *audit.Logger,
|
||||
limiter *ratelimit.Limiter,
|
||||
serverName string,
|
||||
serverVersion string,
|
||||
protocolVersion string,
|
||||
defaultCaller string,
|
||||
toolTimeout time.Duration,
|
||||
) *Server {
|
||||
return &Server{
|
||||
reader: bufio.NewReader(in),
|
||||
writer: out,
|
||||
registry: registry,
|
||||
auditLogger: auditLogger,
|
||||
limiter: limiter,
|
||||
serverName: serverName,
|
||||
serverVersion: serverVersion,
|
||||
protocolVersion: protocolVersion,
|
||||
defaultCaller: defaultCaller,
|
||||
toolTimeout: toolTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Serve(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
body, err := readMessage(s.reader)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
log.Printf("read message failed: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
var req request
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
_ = s.writeError(nil, errCodeParseError, "invalid json")
|
||||
continue
|
||||
}
|
||||
if req.Method == "" || req.JSONRPC != jsonRPCVersion {
|
||||
_ = s.writeError(req.ID, errCodeInvalidRequest, "invalid json-rpc request")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.handleRequest(ctx, req); err != nil {
|
||||
log.Printf("handle request failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(ctx context.Context, req request) error {
|
||||
switch req.Method {
|
||||
case "initialize":
|
||||
return s.handleInitialize(req)
|
||||
case "notifications/initialized":
|
||||
return nil
|
||||
case "tools/list":
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.writeResult(req.ID, map[string]any{"tools": s.registry.List()})
|
||||
case "tools/call":
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.handleToolCall(ctx, req)
|
||||
case "ping":
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.writeResult(req.ID, map[string]any{})
|
||||
default:
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.writeError(req.ID, errCodeMethodNotFound, "method not found")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleInitialize(req request) error {
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var params initializeParams
|
||||
if len(req.Params) > 0 {
|
||||
_ = json.Unmarshal(req.Params, ¶ms)
|
||||
}
|
||||
_ = params
|
||||
|
||||
return s.writeResult(req.ID, map[string]any{
|
||||
"protocolVersion": s.protocolVersion,
|
||||
"capabilities": map[string]any{
|
||||
"tools": map[string]any{"listChanged": false},
|
||||
},
|
||||
"serverInfo": map[string]any{
|
||||
"name": s.serverName,
|
||||
"version": s.serverVersion,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleToolCall(ctx context.Context, req request) error {
|
||||
var params toolCallParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
return s.writeError(req.ID, errCodeInvalidParams, "invalid tool call params")
|
||||
}
|
||||
if params.Name == "" {
|
||||
return s.writeError(req.ID, errCodeInvalidParams, "tool name is required")
|
||||
}
|
||||
if params.Arguments == nil {
|
||||
params.Arguments = map[string]any{}
|
||||
}
|
||||
|
||||
caller := extractCaller(params, s.defaultCaller)
|
||||
rateKey := fmt.Sprintf("%s:%s", caller, params.Name)
|
||||
if !s.limiter.Allow(rateKey) {
|
||||
result := buildToolErrorResult("rate limit exceeded")
|
||||
s.auditLogger.Log(audit.Record{
|
||||
Tool: params.Name,
|
||||
Caller: caller,
|
||||
Success: false,
|
||||
DurationMs: 0,
|
||||
Meta: sanitizeAuditInput(params.Name, params.Arguments),
|
||||
Error: "rate limit exceeded",
|
||||
})
|
||||
return s.writeResult(req.ID, result)
|
||||
}
|
||||
|
||||
tool, ok := s.registry.Find(params.Name)
|
||||
if !ok {
|
||||
result := buildToolErrorResult("tool not found")
|
||||
s.auditLogger.Log(audit.Record{
|
||||
Tool: params.Name,
|
||||
Caller: caller,
|
||||
Success: false,
|
||||
DurationMs: 0,
|
||||
Meta: sanitizeAuditInput(params.Name, params.Arguments),
|
||||
Error: "tool not found",
|
||||
})
|
||||
return s.writeResult(req.ID, result)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
toolCtx, cancel := context.WithTimeout(ctx, s.toolTimeout)
|
||||
defer cancel()
|
||||
|
||||
output, err := tool.Execute(toolCtx, params.Arguments)
|
||||
duration := time.Since(start).Milliseconds()
|
||||
if err != nil {
|
||||
errMsg := sanitizeError(err)
|
||||
s.auditLogger.Log(audit.Record{
|
||||
Tool: params.Name,
|
||||
Caller: caller,
|
||||
Success: false,
|
||||
DurationMs: duration,
|
||||
Meta: sanitizeAuditInput(params.Name, params.Arguments),
|
||||
Error: errMsg,
|
||||
})
|
||||
return s.writeResult(req.ID, buildToolErrorResult(errMsg))
|
||||
}
|
||||
|
||||
s.auditLogger.Log(audit.Record{
|
||||
Tool: params.Name,
|
||||
Caller: caller,
|
||||
Success: true,
|
||||
DurationMs: duration,
|
||||
Meta: sanitizeAuditInput(params.Name, params.Arguments),
|
||||
})
|
||||
return s.writeResult(req.ID, buildToolSuccessResult(output))
|
||||
}
|
||||
|
||||
func buildToolSuccessResult(output map[string]any) map[string]any {
|
||||
asJSON, _ := json.Marshal(output)
|
||||
return map[string]any{
|
||||
"content": []map[string]any{
|
||||
{"type": "text", "text": string(asJSON)},
|
||||
},
|
||||
"structuredContent": output,
|
||||
}
|
||||
}
|
||||
|
||||
func buildToolErrorResult(message string) map[string]any {
|
||||
return map[string]any{
|
||||
"content": []map[string]any{
|
||||
{"type": "text", "text": message},
|
||||
},
|
||||
"structuredContent": map[string]any{"error": message},
|
||||
"isError": true,
|
||||
}
|
||||
}
|
||||
|
||||
func extractCaller(params toolCallParams, fallback string) string {
|
||||
if params.Meta != nil {
|
||||
if v, ok := params.Meta["caller"]; ok {
|
||||
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
if v, ok := params.Arguments["caller"]; ok {
|
||||
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func sanitizeAuditInput(toolName string, args map[string]any) map[string]any {
|
||||
meta := map[string]any{}
|
||||
switch toolName {
|
||||
case "mysql_query_readonly":
|
||||
sql, _ := args["sql"].(string)
|
||||
meta["sql"] = security.RedactSQL(sql)
|
||||
if params, ok := args["params"].([]any); ok {
|
||||
meta["paramCount"] = len(params)
|
||||
}
|
||||
case "redis_get":
|
||||
if key, ok := args["key"].(string); ok {
|
||||
meta["key"] = security.RedactKey(key)
|
||||
}
|
||||
case "redis_scan":
|
||||
if pattern, ok := args["pattern"].(string); ok {
|
||||
if len(pattern) > 40 {
|
||||
meta["pattern"] = pattern[:40]
|
||||
} else {
|
||||
meta["pattern"] = pattern
|
||||
}
|
||||
}
|
||||
if count, ok := args["count"]; ok {
|
||||
meta["count"] = count
|
||||
}
|
||||
default:
|
||||
meta["args"] = "masked"
|
||||
}
|
||||
return meta
|
||||
}
|
||||
|
||||
func sanitizeError(err error) string {
|
||||
msg := err.Error()
|
||||
msg = strings.ReplaceAll(msg, "\n", " ")
|
||||
msg = strings.ReplaceAll(msg, "\r", " ")
|
||||
if len(msg) > 300 {
|
||||
return msg[:300]
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func (s *Server) writeResult(id json.RawMessage, result any) error {
|
||||
resp := response{JSONRPC: jsonRPCVersion, ID: id, Result: result}
|
||||
return s.writeMessage(resp)
|
||||
}
|
||||
|
||||
func (s *Server) writeError(id json.RawMessage, code int, message string) error {
|
||||
resp := response{
|
||||
JSONRPC: jsonRPCVersion,
|
||||
ID: id,
|
||||
Error: &respError{Code: code, Message: message},
|
||||
}
|
||||
return s.writeMessage(resp)
|
||||
}
|
||||
|
||||
func (s *Server) writeMessage(payload any) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(body))
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
if _, err := s.writer.Write([]byte(header)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.writer.Write(body)
|
||||
return err
|
||||
}
|
||||
|
||||
func readMessage(reader *bufio.Reader) ([]byte, error) {
|
||||
contentLength := -1
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
if line == "" {
|
||||
break
|
||||
}
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
headerName := strings.TrimSpace(parts[0])
|
||||
headerValue := strings.TrimSpace(parts[1])
|
||||
if strings.EqualFold(headerName, "Content-Length") {
|
||||
n, err := strconv.Atoi(headerValue)
|
||||
if err != nil || n < 0 {
|
||||
return nil, fmt.Errorf("invalid Content-Length")
|
||||
}
|
||||
contentLength = n
|
||||
}
|
||||
}
|
||||
if contentLength < 0 {
|
||||
return nil, fmt.Errorf("missing Content-Length")
|
||||
}
|
||||
body := make([]byte, contentLength)
|
||||
if _, err := io.ReadFull(reader, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
56
infra/smartflow-mcp-server/internal/mcp/server_test.go
Normal file
56
infra/smartflow-mcp-server/internal/mcp/server_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools"
|
||||
)
|
||||
|
||||
type dummyTool struct{}
|
||||
|
||||
func (d dummyTool) Name() string { return "dummy" }
|
||||
func (d dummyTool) Description() string { return "dummy" }
|
||||
func (d dummyTool) InputSchema() map[string]any {
|
||||
return map[string]any{"type": "object"}
|
||||
}
|
||||
func (d dummyTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
}
|
||||
|
||||
func TestHandleInitialize(t *testing.T) {
|
||||
registry, err := tools.NewRegistry(dummyTool{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
out := bytes.NewBuffer(nil)
|
||||
s := NewServer(bytes.NewBuffer(nil), out, registry, nil, ratelimit.New(10, 10), "name", "ver", "2024-11-05", "unknown", time.Second)
|
||||
|
||||
req := request{JSONRPC: jsonRPCVersion, ID: json.RawMessage("1"), Method: "initialize", Params: json.RawMessage(`{"protocolVersion":"2024-11-05"}`)}
|
||||
if err := s.handleRequest(context.Background(), req); err != nil {
|
||||
t.Fatalf("handle initialize error: %v", err)
|
||||
}
|
||||
if !bytes.Contains(out.Bytes(), []byte("protocolVersion")) {
|
||||
t.Fatalf("response missing protocolVersion: %s", out.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadMessage(t *testing.T) {
|
||||
body := []byte(`{"jsonrpc":"2.0","id":1,"method":"ping"}`)
|
||||
frame := fmt.Sprintf("Content-Length: %d\r\n\r\n%s", len(body), string(body))
|
||||
input := bufio.NewReader(bytes.NewBufferString(frame))
|
||||
|
||||
msg, err := readMessage(input)
|
||||
if err != nil {
|
||||
t.Fatalf("read message failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(msg, body) {
|
||||
t.Fatalf("unexpected body: %s", string(msg))
|
||||
}
|
||||
}
|
||||
51
infra/smartflow-mcp-server/internal/ratelimit/limiter.go
Normal file
51
infra/smartflow-mcp-server/internal/ratelimit/limiter.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Limiter struct {
|
||||
mu sync.Mutex
|
||||
rate float64
|
||||
burst float64
|
||||
buckets map[string]*bucket
|
||||
}
|
||||
|
||||
type bucket struct {
|
||||
tokens float64
|
||||
last time.Time
|
||||
}
|
||||
|
||||
func New(rate, burst float64) *Limiter {
|
||||
return &Limiter{
|
||||
rate: rate,
|
||||
burst: burst,
|
||||
buckets: make(map[string]*bucket),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Limiter) Allow(key string) bool {
|
||||
now := time.Now()
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
b, ok := l.buckets[key]
|
||||
if !ok {
|
||||
l.buckets[key] = &bucket{tokens: l.burst - 1, last: now}
|
||||
return true
|
||||
}
|
||||
|
||||
elapsed := now.Sub(b.last).Seconds()
|
||||
b.tokens += elapsed * l.rate
|
||||
if b.tokens > l.burst {
|
||||
b.tokens = l.burst
|
||||
}
|
||||
b.last = now
|
||||
|
||||
if b.tokens < 1 {
|
||||
return false
|
||||
}
|
||||
b.tokens -= 1
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLimiter(t *testing.T) {
|
||||
l := New(2, 2)
|
||||
key := "user:tool"
|
||||
|
||||
if !l.Allow(key) {
|
||||
t.Fatal("first request should pass")
|
||||
}
|
||||
if !l.Allow(key) {
|
||||
t.Fatal("second request should pass")
|
||||
}
|
||||
if l.Allow(key) {
|
||||
t.Fatal("third request should be rate limited")
|
||||
}
|
||||
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
if !l.Allow(key) {
|
||||
t.Fatal("request should pass after token refill")
|
||||
}
|
||||
}
|
||||
27
infra/smartflow-mcp-server/internal/security/redact.go
Normal file
27
infra/smartflow-mcp-server/internal/security/redact.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
singleQuotedString = regexp.MustCompile(`'([^'\\]|\\.)*'`)
|
||||
doubleQuotedString = regexp.MustCompile(`"([^"\\]|\\.)*"`)
|
||||
numericLiteral = regexp.MustCompile(`\b\d+\b`)
|
||||
)
|
||||
|
||||
func RedactSQL(sql string) string {
|
||||
masked := singleQuotedString.ReplaceAllString(sql, "'***'")
|
||||
masked = doubleQuotedString.ReplaceAllString(masked, `"***"`)
|
||||
masked = numericLiteral.ReplaceAllString(masked, "?")
|
||||
return strings.TrimSpace(masked)
|
||||
}
|
||||
|
||||
func RedactKey(key string) string {
|
||||
key = strings.TrimSpace(key)
|
||||
if len(key) <= 4 {
|
||||
return "****"
|
||||
}
|
||||
return key[:2] + "***" + key[len(key)-2:]
|
||||
}
|
||||
193
infra/smartflow-mcp-server/internal/security/sql_validator.go
Normal file
193
infra/smartflow-mcp-server/internal/security/sql_validator.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
commentPattern = regexp.MustCompile(`(?s)/\*.*?\*/|--|#`)
|
||||
forbiddenWords = regexp.MustCompile(`(?i)\b(INSERT|UPDATE|DELETE|ALTER|DROP|TRUNCATE|CREATE|REPLACE|RENAME|GRANT|REVOKE|MERGE|CALL|EXEC|LOCK|UNLOCK|LOAD|OUTFILE|INFILE|HANDLER|SET|USE)\b`)
|
||||
fromJoinRef = regexp.MustCompile(`(?i)\b(?:FROM|JOIN)\s+([` + "`" + `"\w\.]+)`)
|
||||
describeRef = regexp.MustCompile(`(?i)\b(?:DESCRIBE|DESC)\s+([` + "`" + `"\w\.]+)`)
|
||||
showFromRef = regexp.MustCompile(`(?i)\bSHOW\b[\w\s]*\b(?:FROM|IN)\s+([` + "`" + `"\w\.]+)`)
|
||||
)
|
||||
|
||||
type SQLValidator struct {
|
||||
defaultDatabase string
|
||||
enforceWhitelist bool
|
||||
allowedDatabases map[string]struct{}
|
||||
allowedTables map[string]struct{}
|
||||
allowedTableOnly map[string]struct{}
|
||||
}
|
||||
|
||||
type sqlObjectRef struct {
|
||||
database string
|
||||
table string
|
||||
}
|
||||
|
||||
func NewSQLValidator(defaultDatabase string, enforceWhitelist bool, allowedDatabases []string, allowedTables []string) *SQLValidator {
|
||||
v := &SQLValidator{
|
||||
defaultDatabase: strings.ToLower(strings.TrimSpace(defaultDatabase)),
|
||||
enforceWhitelist: enforceWhitelist,
|
||||
allowedDatabases: make(map[string]struct{}),
|
||||
allowedTables: make(map[string]struct{}),
|
||||
allowedTableOnly: make(map[string]struct{}),
|
||||
}
|
||||
for _, db := range allowedDatabases {
|
||||
db = strings.ToLower(strings.TrimSpace(db))
|
||||
if db != "" {
|
||||
v.allowedDatabases[db] = struct{}{}
|
||||
}
|
||||
}
|
||||
for _, tbl := range allowedTables {
|
||||
tbl = strings.ToLower(strings.TrimSpace(strings.Trim(tbl, "`\"")))
|
||||
if tbl == "" {
|
||||
continue
|
||||
}
|
||||
v.allowedTables[tbl] = struct{}{}
|
||||
if dot := strings.LastIndex(tbl, "."); dot >= 0 && dot < len(tbl)-1 {
|
||||
v.allowedTableOnly[tbl[dot+1:]] = struct{}{}
|
||||
} else {
|
||||
v.allowedTableOnly[tbl] = struct{}{}
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *SQLValidator) ValidateReadOnlySQL(sql string) error {
|
||||
trimmed := strings.TrimSpace(sql)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("sql is required")
|
||||
}
|
||||
if strings.Contains(trimmed, ";") {
|
||||
return fmt.Errorf("semicolon is not allowed")
|
||||
}
|
||||
if commentPattern.MatchString(trimmed) {
|
||||
return fmt.Errorf("sql comments are not allowed")
|
||||
}
|
||||
|
||||
first := strings.ToUpper(firstToken(trimmed))
|
||||
switch first {
|
||||
case "SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN":
|
||||
default:
|
||||
return fmt.Errorf("only SELECT/SHOW/DESCRIBE/EXPLAIN are allowed")
|
||||
}
|
||||
|
||||
withoutLiterals := removeStringLiterals(trimmed)
|
||||
if forbiddenWords.MatchString(strings.ToUpper(withoutLiterals)) {
|
||||
return fmt.Errorf("dangerous sql keyword detected")
|
||||
}
|
||||
|
||||
refs := extractObjectRefs(withoutLiterals)
|
||||
if err := v.validateWhitelist(refs); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *SQLValidator) validateWhitelist(refs []sqlObjectRef) error {
|
||||
hasDBAllowlist := len(v.allowedDatabases) > 0
|
||||
hasTableAllowlist := len(v.allowedTables) > 0
|
||||
|
||||
if !v.enforceWhitelist && !hasDBAllowlist && !hasTableAllowlist {
|
||||
return nil
|
||||
}
|
||||
if len(refs) == 0 {
|
||||
if v.enforceWhitelist {
|
||||
return fmt.Errorf("sql does not contain explicit table reference under whitelist mode")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ref := range refs {
|
||||
db := ref.database
|
||||
tbl := ref.table
|
||||
if db == "" {
|
||||
db = v.defaultDatabase
|
||||
}
|
||||
|
||||
if hasDBAllowlist {
|
||||
if db == "" {
|
||||
return fmt.Errorf("database is not explicit and no default database set")
|
||||
}
|
||||
if _, ok := v.allowedDatabases[db]; !ok {
|
||||
return fmt.Errorf("database %s is not in whitelist", db)
|
||||
}
|
||||
}
|
||||
|
||||
if hasTableAllowlist {
|
||||
full := tbl
|
||||
if db != "" {
|
||||
full = db + "." + tbl
|
||||
}
|
||||
if _, ok := v.allowedTables[full]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := v.allowedTableOnly[tbl]; ok {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("table %s is not in whitelist", full)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractObjectRefs(sql string) []sqlObjectRef {
|
||||
refs := make([]sqlObjectRef, 0)
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
for _, re := range []*regexp.Regexp{fromJoinRef, describeRef, showFromRef} {
|
||||
matches := re.FindAllStringSubmatch(sql, -1)
|
||||
for _, m := range matches {
|
||||
if len(m) < 2 {
|
||||
continue
|
||||
}
|
||||
ref := normalizeRef(m[1])
|
||||
if ref.table == "" {
|
||||
continue
|
||||
}
|
||||
key := ref.database + "." + ref.table
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
}
|
||||
return refs
|
||||
}
|
||||
|
||||
func normalizeRef(raw string) sqlObjectRef {
|
||||
clean := strings.ToLower(strings.TrimSpace(raw))
|
||||
clean = strings.Trim(clean, "`\"'(),")
|
||||
clean = strings.TrimSpace(clean)
|
||||
if clean == "" {
|
||||
return sqlObjectRef{}
|
||||
}
|
||||
parts := strings.Split(clean, ".")
|
||||
if len(parts) == 1 {
|
||||
return sqlObjectRef{table: parts[0]}
|
||||
}
|
||||
return sqlObjectRef{database: parts[0], table: parts[len(parts)-1]}
|
||||
}
|
||||
|
||||
func firstToken(sql string) string {
|
||||
for i, r := range sql {
|
||||
if r == ' ' || r == '\n' || r == '\t' || r == '\r' {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
return sql[:i]
|
||||
}
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
func removeStringLiterals(sql string) string {
|
||||
masked := singleQuotedString.ReplaceAllString(sql, "''")
|
||||
masked = doubleQuotedString.ReplaceAllString(masked, `""`)
|
||||
return masked
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package security
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestValidateReadOnlySQL(t *testing.T) {
|
||||
validator := NewSQLValidator("smartflow", true, []string{"smartflow"}, []string{"smartflow.users", "smartflow.tasks"})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "allow select", sql: "SELECT id, name FROM users WHERE id = 1", wantErr: false},
|
||||
{name: "allow explain", sql: "EXPLAIN SELECT * FROM tasks", wantErr: false},
|
||||
{name: "reject insert", sql: "INSERT INTO users(name) VALUES('x')", wantErr: true},
|
||||
{name: "reject multi statement", sql: "SELECT * FROM users; SELECT * FROM tasks", wantErr: true},
|
||||
{name: "reject comment", sql: "SELECT * FROM users -- bypass", wantErr: true},
|
||||
{name: "reject not whitelisted table", sql: "SELECT * FROM orders", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := validator.ValidateReadOnlySQL(tc.sql)
|
||||
if tc.wantErr && err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
if !tc.wantErr && err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedact(t *testing.T) {
|
||||
masked := RedactSQL("SELECT * FROM users WHERE token='abc123' AND id=42")
|
||||
if masked == "" || masked == "SELECT * FROM users WHERE token='abc123' AND id=42" {
|
||||
t.Fatalf("redaction not applied: %s", masked)
|
||||
}
|
||||
|
||||
key := RedactKey("very-sensitive-key")
|
||||
if key == "very-sensitive-key" {
|
||||
t.Fatalf("key not redacted")
|
||||
}
|
||||
}
|
||||
198
infra/smartflow-mcp-server/internal/store/mysql.go
Normal file
198
infra/smartflow-mcp-server/internal/store/mysql.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfgpkg "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
type MySQLClient struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
type QueryColumn struct {
|
||||
Name string `json:"name"`
|
||||
DatabaseType string `json:"databaseType"`
|
||||
Nullable *bool `json:"nullable,omitempty"`
|
||||
ScanType string `json:"scanType,omitempty"`
|
||||
}
|
||||
|
||||
type QueryResult struct {
|
||||
Columns []QueryColumn `json:"columns"`
|
||||
Rows []map[string]any `json:"rows"`
|
||||
RowCount int `json:"rowCount"`
|
||||
Truncated bool `json:"truncated"`
|
||||
DurationMs int64 `json:"durationMs"`
|
||||
}
|
||||
|
||||
func NewMySQLClient(ctx context.Context, cfg cfgpkg.MySQLConfig) (*MySQLClient, error) {
|
||||
dsn := mysql.Config{
|
||||
User: cfg.User,
|
||||
Passwd: cfg.Password,
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
Net: "tcp",
|
||||
DBName: cfg.Database,
|
||||
AllowNativePasswords: true,
|
||||
ParseTime: true,
|
||||
}
|
||||
|
||||
applyMySQLParams(&dsn, cfg.Params)
|
||||
|
||||
db, err := sql.Open("mysql", dsn.FormatDSN())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open mysql: %w", err)
|
||||
}
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
db.SetMaxOpenConns(5)
|
||||
db.SetMaxIdleConns(5)
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := db.PingContext(pingCtx); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("ping mysql: %w", err)
|
||||
}
|
||||
|
||||
return &MySQLClient{db: db}, nil
|
||||
}
|
||||
|
||||
func applyMySQLParams(cfg *mysql.Config, raw string) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return
|
||||
}
|
||||
values, err := url.ParseQuery(raw)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cfg.Params == nil {
|
||||
cfg.Params = make(map[string]string)
|
||||
}
|
||||
for key, valueList := range values {
|
||||
if len(valueList) == 0 {
|
||||
continue
|
||||
}
|
||||
value := valueList[0]
|
||||
switch strings.ToLower(key) {
|
||||
case "parsetime":
|
||||
if parsed, err := strconv.ParseBool(value); err == nil {
|
||||
cfg.ParseTime = parsed
|
||||
}
|
||||
case "loc":
|
||||
if loc, err := time.LoadLocation(value); err == nil {
|
||||
cfg.Loc = loc
|
||||
}
|
||||
case "collation":
|
||||
cfg.Collation = value
|
||||
default:
|
||||
cfg.Params[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MySQLClient) Close() error {
|
||||
if c == nil || c.db == nil {
|
||||
return nil
|
||||
}
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func (c *MySQLClient) QueryReadOnly(ctx context.Context, query string, args []any, maxRows int) (QueryResult, error) {
|
||||
start := time.Now()
|
||||
rows, err := c.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return QueryResult{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columnTypes, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return QueryResult{}, err
|
||||
}
|
||||
|
||||
columns := make([]QueryColumn, 0, len(columnTypes))
|
||||
columnNames := make([]string, 0, len(columnTypes))
|
||||
for _, ct := range columnTypes {
|
||||
var nullablePtr *bool
|
||||
if nullable, ok := ct.Nullable(); ok {
|
||||
n := nullable
|
||||
nullablePtr = &n
|
||||
}
|
||||
scanType := ""
|
||||
if st := ct.ScanType(); st != nil {
|
||||
scanType = st.String()
|
||||
}
|
||||
columns = append(columns, QueryColumn{
|
||||
Name: ct.Name(),
|
||||
DatabaseType: ct.DatabaseTypeName(),
|
||||
Nullable: nullablePtr,
|
||||
ScanType: scanType,
|
||||
})
|
||||
columnNames = append(columnNames, ct.Name())
|
||||
}
|
||||
|
||||
resultRows := make([]map[string]any, 0)
|
||||
truncated := false
|
||||
for rows.Next() {
|
||||
if len(resultRows) >= maxRows {
|
||||
truncated = true
|
||||
break
|
||||
}
|
||||
scanned, err := scanRow(rows, len(columnNames))
|
||||
if err != nil {
|
||||
return QueryResult{}, err
|
||||
}
|
||||
rowMap := make(map[string]any, len(columnNames))
|
||||
for i, name := range columnNames {
|
||||
rowMap[name] = normalizeValue(scanned[i])
|
||||
}
|
||||
resultRows = append(resultRows, rowMap)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return QueryResult{}, err
|
||||
}
|
||||
|
||||
return QueryResult{
|
||||
Columns: columns,
|
||||
Rows: resultRows,
|
||||
RowCount: len(resultRows),
|
||||
Truncated: truncated,
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func scanRow(rows *sql.Rows, size int) ([]any, error) {
|
||||
dest := make([]any, size)
|
||||
holders := make([]any, size)
|
||||
for i := range dest {
|
||||
holders[i] = &dest[i]
|
||||
}
|
||||
if err := rows.Scan(holders...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
func normalizeValue(v any) any {
|
||||
switch val := v.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case []byte:
|
||||
return string(val)
|
||||
case time.Time:
|
||||
return val.Format(time.RFC3339Nano)
|
||||
default:
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
}
|
||||
189
infra/smartflow-mcp-server/internal/store/redis.go
Normal file
189
infra/smartflow-mcp-server/internal/store/redis.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfgpkg "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
type RedisClient struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
type RedisGetResult struct {
|
||||
Exists bool `json:"exists"`
|
||||
Key string `json:"key"`
|
||||
Type string `json:"type"`
|
||||
Value any `json:"value,omitempty"`
|
||||
Truncated bool `json:"truncated"`
|
||||
DurationMs int64 `json:"durationMs"`
|
||||
}
|
||||
|
||||
type RedisScanResult struct {
|
||||
Pattern string `json:"pattern"`
|
||||
Keys []string `json:"keys"`
|
||||
Returned int `json:"returned"`
|
||||
NextCursor uint64 `json:"nextCursor"`
|
||||
Truncated bool `json:"truncated"`
|
||||
DurationMs int64 `json:"durationMs"`
|
||||
}
|
||||
|
||||
func NewRedisClient(ctx context.Context, cfg cfgpkg.RedisConfig) (*RedisClient, error) {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
})
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := client.Ping(pingCtx).Err(); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, fmt.Errorf("ping redis: %w", err)
|
||||
}
|
||||
return &RedisClient{client: client}, nil
|
||||
}
|
||||
|
||||
func (c *RedisClient) Close() error {
|
||||
if c == nil || c.client == nil {
|
||||
return nil
|
||||
}
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
func (c *RedisClient) GetWithType(ctx context.Context, key string, maxItems int, maxStringBytes int) (RedisGetResult, error) {
|
||||
start := time.Now()
|
||||
t, err := c.client.Type(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
if t == "none" {
|
||||
return RedisGetResult{
|
||||
Exists: false,
|
||||
Key: key,
|
||||
Type: "none",
|
||||
Truncated: false,
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
result := RedisGetResult{Exists: true, Key: key, Type: t}
|
||||
switch t {
|
||||
case "string":
|
||||
v, err := c.client.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
if len(v) > maxStringBytes {
|
||||
result.Value = v[:maxStringBytes]
|
||||
result.Truncated = true
|
||||
} else {
|
||||
result.Value = v
|
||||
}
|
||||
case "list":
|
||||
vals, err := c.client.LRange(ctx, key, 0, int64(maxItems-1)).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
result.Value = vals
|
||||
if length, _ := c.client.LLen(ctx, key).Result(); length > int64(maxItems) {
|
||||
result.Truncated = true
|
||||
}
|
||||
case "set":
|
||||
vals, err := c.client.SMembers(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
if len(vals) > maxItems {
|
||||
result.Value = vals[:maxItems]
|
||||
result.Truncated = true
|
||||
} else {
|
||||
result.Value = vals
|
||||
}
|
||||
case "zset":
|
||||
vals, err := c.client.ZRangeWithScores(ctx, key, 0, int64(maxItems-1)).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
resultRows := make([]map[string]any, 0, len(vals))
|
||||
for _, item := range vals {
|
||||
resultRows = append(resultRows, map[string]any{"member": item.Member, "score": item.Score})
|
||||
}
|
||||
result.Value = resultRows
|
||||
if length, _ := c.client.ZCard(ctx, key).Result(); length > int64(maxItems) {
|
||||
result.Truncated = true
|
||||
}
|
||||
case "hash":
|
||||
vals, err := c.client.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
if len(vals) <= maxItems {
|
||||
result.Value = vals
|
||||
} else {
|
||||
trimmed := make(map[string]string, maxItems)
|
||||
count := 0
|
||||
for k, v := range vals {
|
||||
trimmed[k] = v
|
||||
count++
|
||||
if count >= maxItems {
|
||||
break
|
||||
}
|
||||
}
|
||||
result.Value = trimmed
|
||||
result.Truncated = true
|
||||
}
|
||||
default:
|
||||
raw, err := c.client.Dump(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
result.Value = strings.ToUpper(fmt.Sprintf("UNSUPPORTED_TYPE_%s_DUMP_SIZE_%d", t, len(raw)))
|
||||
}
|
||||
result.DurationMs = time.Since(start).Milliseconds()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *RedisClient) ScanKeys(ctx context.Context, pattern string, count int64, maxKeys int) (RedisScanResult, error) {
|
||||
start := time.Now()
|
||||
if pattern == "" {
|
||||
pattern = "*"
|
||||
}
|
||||
if count <= 0 {
|
||||
count = 20
|
||||
}
|
||||
|
||||
keys := make([]string, 0, maxKeys)
|
||||
var cursor uint64
|
||||
truncated := false
|
||||
for {
|
||||
batch, nextCursor, err := c.client.Scan(ctx, cursor, pattern, count).Result()
|
||||
if err != nil {
|
||||
return RedisScanResult{}, err
|
||||
}
|
||||
for _, key := range batch {
|
||||
if len(keys) >= maxKeys {
|
||||
truncated = true
|
||||
break
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
cursor = nextCursor
|
||||
if truncated || cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return RedisScanResult{
|
||||
Pattern: pattern,
|
||||
Keys: keys,
|
||||
Returned: len(keys),
|
||||
NextCursor: cursor,
|
||||
Truncated: truncated,
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store"
|
||||
)
|
||||
|
||||
func TestIntegrationMySQLReadOnlyTool(t *testing.T) {
|
||||
if os.Getenv("MCP_IT_RUN") != "1" {
|
||||
t.Skip("set MCP_IT_RUN=1 to run integration tests")
|
||||
}
|
||||
|
||||
port := 3306
|
||||
if p := os.Getenv("MYSQL_PORT"); p != "" {
|
||||
if n, err := strconv.Atoi(p); err == nil {
|
||||
port = n
|
||||
}
|
||||
}
|
||||
|
||||
mysqlCfg := config.MySQLConfig{
|
||||
Host: os.Getenv("MYSQL_HOST"),
|
||||
Port: port,
|
||||
User: os.Getenv("MYSQL_USER"),
|
||||
Password: os.Getenv("MYSQL_PASSWORD"),
|
||||
Database: os.Getenv("MYSQL_DATABASE"),
|
||||
Params: "charset=utf8mb4&parseTime=true&loc=Local",
|
||||
}
|
||||
if mysqlCfg.Host == "" || mysqlCfg.User == "" || mysqlCfg.Database == "" {
|
||||
t.Skip("missing MYSQL_* env for integration test")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
client, err := store.NewMySQLClient(ctx, mysqlCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("mysql not available: %v", err)
|
||||
}
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
validator := security.NewSQLValidator(mysqlCfg.Database, false, nil, nil)
|
||||
tool := NewMySQLReadOnlyTool(client, validator, 10)
|
||||
res, err := tool.Execute(ctx, map[string]any{"sql": "SELECT 1 AS ok"})
|
||||
if err != nil {
|
||||
t.Fatalf("tool execute failed: %v", err)
|
||||
}
|
||||
if res["rowCount"].(int) < 1 {
|
||||
t.Fatalf("expected rowCount >= 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationRedisTools(t *testing.T) {
|
||||
if os.Getenv("MCP_IT_RUN") != "1" {
|
||||
t.Skip("set MCP_IT_RUN=1 to run integration tests")
|
||||
}
|
||||
|
||||
db := 0
|
||||
if p := os.Getenv("REDIS_DB"); p != "" {
|
||||
if n, err := strconv.Atoi(p); err == nil {
|
||||
db = n
|
||||
}
|
||||
}
|
||||
redisCfg := config.RedisConfig{
|
||||
Addr: os.Getenv("REDIS_ADDR"),
|
||||
Password: os.Getenv("REDIS_PASSWORD"),
|
||||
DB: db,
|
||||
}
|
||||
if redisCfg.Addr == "" {
|
||||
t.Skip("REDIS_ADDR is empty")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
client, err := store.NewRedisClient(ctx, redisCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("redis not available: %v", err)
|
||||
}
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
getTool := NewRedisGetTool(client, 10, 128)
|
||||
if _, err := getTool.Execute(ctx, map[string]any{"key": "__integration_missing_key__"}); err != nil {
|
||||
t.Fatalf("redis_get failed: %v", err)
|
||||
}
|
||||
|
||||
scanTool := NewRedisScanTool(client, 10, 10)
|
||||
if _, err := scanTool.Execute(ctx, map[string]any{"pattern": "*", "count": float64(5)}); err != nil {
|
||||
t.Fatalf("redis_scan failed: %v", err)
|
||||
}
|
||||
}
|
||||
95
infra/smartflow-mcp-server/internal/tools/mysql_readonly.go
Normal file
95
infra/smartflow-mcp-server/internal/tools/mysql_readonly.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store"
|
||||
)
|
||||
|
||||
type MySQLReadOnlyTool struct {
|
||||
client *store.MySQLClient
|
||||
validator *security.SQLValidator
|
||||
maxRows int
|
||||
}
|
||||
|
||||
func NewMySQLReadOnlyTool(client *store.MySQLClient, validator *security.SQLValidator, maxRows int) *MySQLReadOnlyTool {
|
||||
return &MySQLReadOnlyTool{client: client, validator: validator, maxRows: maxRows}
|
||||
}
|
||||
|
||||
func (t *MySQLReadOnlyTool) Name() string {
|
||||
return "mysql_query_readonly"
|
||||
}
|
||||
|
||||
func (t *MySQLReadOnlyTool) Description() string {
|
||||
return "Execute read-only SQL on MySQL. Only SELECT/SHOW/DESCRIBE/EXPLAIN are allowed."
|
||||
}
|
||||
|
||||
func (t *MySQLReadOnlyTool) InputSchema() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"sql": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Read-only SQL statement",
|
||||
},
|
||||
"params": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Optional bind parameters",
|
||||
"items": map[string]any{
|
||||
"type": []string{"string", "number", "boolean", "null"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": []string{"sql"},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MySQLReadOnlyTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
|
||||
rawSQL, ok := args["sql"].(string)
|
||||
if !ok || rawSQL == "" {
|
||||
return nil, fmt.Errorf("sql must be a non-empty string")
|
||||
}
|
||||
if err := t.validator.ValidateReadOnlySQL(rawSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params, err := normalizeParams(args["params"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := t.client.QueryReadOnly(ctx, rawSQL, params, t.maxRows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]any{
|
||||
"columns": res.Columns,
|
||||
"rows": res.Rows,
|
||||
"rowCount": res.RowCount,
|
||||
"truncated": res.Truncated,
|
||||
"durationMs": res.DurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeParams(raw any) ([]any, error) {
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
arr, ok := raw.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("params must be an array")
|
||||
}
|
||||
out := make([]any, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
switch v := item.(type) {
|
||||
case string, float64, bool, nil:
|
||||
out = append(out, v)
|
||||
default:
|
||||
return nil, fmt.Errorf("params contains unsupported type")
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
130
infra/smartflow-mcp-server/internal/tools/redis_tools.go
Normal file
130
infra/smartflow-mcp-server/internal/tools/redis_tools.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store"
|
||||
)
|
||||
|
||||
type RedisGetTool struct {
|
||||
client *store.RedisClient
|
||||
valueMaxItems int
|
||||
maxStringBytes int
|
||||
}
|
||||
|
||||
func NewRedisGetTool(client *store.RedisClient, valueMaxItems int, maxStringBytes int) *RedisGetTool {
|
||||
return &RedisGetTool{client: client, valueMaxItems: valueMaxItems, maxStringBytes: maxStringBytes}
|
||||
}
|
||||
|
||||
func (t *RedisGetTool) Name() string {
|
||||
return "redis_get"
|
||||
}
|
||||
|
||||
func (t *RedisGetTool) Description() string {
|
||||
return "Get a Redis key by name and return its type and value."
|
||||
}
|
||||
|
||||
func (t *RedisGetTool) InputSchema() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"key": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Redis key",
|
||||
},
|
||||
},
|
||||
"required": []string{"key"},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *RedisGetTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
|
||||
key, ok := args["key"].(string)
|
||||
if !ok || key == "" {
|
||||
return nil, fmt.Errorf("key must be a non-empty string")
|
||||
}
|
||||
res, err := t.client.GetWithType(ctx, key, t.valueMaxItems, t.maxStringBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]any{
|
||||
"exists": res.Exists,
|
||||
"key": res.Key,
|
||||
"type": res.Type,
|
||||
"value": res.Value,
|
||||
"truncated": res.Truncated,
|
||||
"durationMs": res.DurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type RedisScanTool struct {
|
||||
client *store.RedisClient
|
||||
maxKeys int
|
||||
maxScanCount int
|
||||
}
|
||||
|
||||
func NewRedisScanTool(client *store.RedisClient, maxKeys int, maxScanCount int) *RedisScanTool {
|
||||
return &RedisScanTool{client: client, maxKeys: maxKeys, maxScanCount: maxScanCount}
|
||||
}
|
||||
|
||||
func (t *RedisScanTool) Name() string {
|
||||
return "redis_scan"
|
||||
}
|
||||
|
||||
func (t *RedisScanTool) Description() string {
|
||||
return "Scan Redis keys by pattern with capped result size."
|
||||
}
|
||||
|
||||
func (t *RedisScanTool) InputSchema() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"pattern": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Pattern, for example user:*",
|
||||
},
|
||||
"count": map[string]any{
|
||||
"type": "number",
|
||||
"description": "Optional scan count hint",
|
||||
},
|
||||
},
|
||||
"required": []string{"pattern"},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *RedisScanTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
|
||||
pattern, ok := args["pattern"].(string)
|
||||
if !ok || pattern == "" {
|
||||
return nil, fmt.Errorf("pattern must be a non-empty string")
|
||||
}
|
||||
|
||||
count := int64(20)
|
||||
if rawCount, ok := args["count"]; ok {
|
||||
number, ok := rawCount.(float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("count must be a number")
|
||||
}
|
||||
if number <= 0 {
|
||||
return nil, fmt.Errorf("count must be > 0")
|
||||
}
|
||||
count = int64(number)
|
||||
}
|
||||
if count > int64(t.maxScanCount) {
|
||||
count = int64(t.maxScanCount)
|
||||
}
|
||||
|
||||
res, err := t.client.ScanKeys(ctx, pattern, count, t.maxKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]any{
|
||||
"pattern": res.Pattern,
|
||||
"keys": res.Keys,
|
||||
"returned": res.Returned,
|
||||
"nextCursor": res.NextCursor,
|
||||
"truncated": res.Truncated,
|
||||
"durationMs": res.DurationMs,
|
||||
}, nil
|
||||
}
|
||||
53
infra/smartflow-mcp-server/internal/tools/registry.go
Normal file
53
infra/smartflow-mcp-server/internal/tools/registry.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
)
|
||||
|
||||
type Tool interface {
|
||||
Name() string
|
||||
Description() string
|
||||
InputSchema() map[string]any
|
||||
Execute(ctx context.Context, args map[string]any) (map[string]any, error)
|
||||
}
|
||||
|
||||
type Registry struct {
|
||||
tools map[string]Tool
|
||||
}
|
||||
|
||||
func NewRegistry(toolList ...Tool) (*Registry, error) {
|
||||
r := &Registry{tools: make(map[string]Tool, len(toolList))}
|
||||
for _, t := range toolList {
|
||||
name := t.Name()
|
||||
if _, exists := r.tools[name]; exists {
|
||||
return nil, fmt.Errorf("duplicated tool name: %s", name)
|
||||
}
|
||||
r.tools[name] = t
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *Registry) Find(name string) (Tool, bool) {
|
||||
t, ok := r.tools[name]
|
||||
return t, ok
|
||||
}
|
||||
|
||||
func (r *Registry) List() []map[string]any {
|
||||
out := make([]map[string]any, 0, len(r.tools))
|
||||
names := make([]string, 0, len(r.tools))
|
||||
for name := range r.tools {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
for _, name := range names {
|
||||
t := r.tools[name]
|
||||
out = append(out, map[string]any{
|
||||
"name": t.Name(),
|
||||
"description": t.Description(),
|
||||
"inputSchema": t.InputSchema(),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
Reference in New Issue
Block a user