Version: 0.4.4.dev.260307

feat: 🚀 增强会话管理与缓存机制

* 会话 ID 空值兜底,若 `conversation_id` 为空时自动生成 UUID
* 在响应头写入 `X-Conversation-ID`,供前端使用,保持同一会话状态

perf:  会话状态缓存优化

* 当缓存未命中但 DB 已确认/创建会话后,调用 `SetConversationStatus` 回写 Redis
* 缓存写回失败时记录日志,不中断聊天主流程,确保业务流畅性

fix: 🐛 修复历史消息顺序问题与编译错误

* 修复历史消息顺序问题,保证返回的 N 条历史消息按时间正序喂给模型

  * 通过反转 `created_at desc` 查询结果的切片,确保模型输入顺序正确
* 修复 `fmt.Errorf` 参数不匹配问题,修正编译错误
* 整理 `agent-cache.go` 为标准 UTF-8 编码,避免 Go 编译报错 `invalid UTF-8 encoding`

feat: 🛠️ 独立构建 MCP 服务器

* 使用 `Codex` 构建独立于后端的 MCP 服务器,简化与 Codex 的协作
* 通过该服务器方便 Codex 直接测试和查看 Redis 与 MySQL 中的数据
This commit is contained in:
LoveLosita
2026-03-07 15:25:40 +08:00
parent 204e78d1fe
commit 26c350f378
27 changed files with 2274 additions and 17 deletions

View File

@@ -3,11 +3,13 @@ package api
import (
"io"
"net/http"
"strings"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/respond"
"github.com/LoveLosita/smartflow/backend/service"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type AgentHandler struct {
@@ -33,9 +35,18 @@ func (api *AgentHandler) ChatAgent(c *gin.Context) {
c.JSON(http.StatusBadRequest, respond.WrongParamType)
return
}
// 兼容:如果前端没传会话 ID后端兜底创建一个
conversationID := strings.TrimSpace(req.ConversationID)
if conversationID == "" {
conversationID = uuid.NewString()
}
// 把最终生效的会话 ID 回传给前端,方便后续继续同一会话
c.Writer.Header().Set("X-Conversation-ID", conversationID)
userID := c.GetInt("user_id") // 从上下文中获取用户 ID
// 3. 调用 Service 层的聊天方法,获取输出通道和错误通道
outChan, errChan := api.svc.AgentChat(c.Request.Context(), req.Message, req.Thinking, userID, req.ConversationID)
outChan, errChan := api.svc.AgentChat(c.Request.Context(), req.Message, req.Thinking, userID, conversationID)
// 4. 循环转发消息/错误
c.Stream(func(w io.Writer) bool {
select {

View File

@@ -95,7 +95,7 @@ func (m *AgentCache) BackfillHistory(ctx context.Context, sessionID string, mess
for i, msg := range messages {
data, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("marshal failed at index %d: %w", err)
return fmt.Errorf("marshal failed at index %d: %w", i, err)
}
values[i] = data
}

View File

@@ -44,10 +44,18 @@ func (a *AgentDAO) CreateNewChat(userID int, chatID string) (int64, error) {
func (a *AgentDAO) GetUserChatHistories(ctx context.Context, userID, limit int, chatID string) ([]model.ChatHistory, error) {
var histories []model.ChatHistory
err := a.db.WithContext(ctx).Where("user_id = ? AND chat_id = ?", userID, chatID).Order("created_at desc").Limit(limit).Find(&histories).Error
err := a.db.WithContext(ctx).
Where("user_id = ? AND chat_id = ?", userID, chatID).
Order("created_at desc").
Limit(limit).
Find(&histories).Error
if err != nil {
return nil, err
}
// 保留“最近 N 条”的前提下,反转为时间正序,便于模型消费
for i, j := 0, len(histories)-1; i < j; i, j = i+1, j-1 {
histories[i], histories[j] = histories[j], histories[i]
}
return histories, nil
}

View File

@@ -3,12 +3,14 @@ package service
import (
"context"
"log"
"strings"
"github.com/LoveLosita/smartflow/backend/agent"
"github.com/LoveLosita/smartflow/backend/conv"
"github.com/LoveLosita/smartflow/backend/dao"
"github.com/LoveLosita/smartflow/backend/inits"
"github.com/cloudwego/eino/schema"
"github.com/google/uuid"
)
type AgentService struct {
@@ -25,10 +27,21 @@ func NewAgentService(aiHub *inits.AIHub, repo *dao.AgentDAO, agentRedis *dao.Age
}
}
func normalizeConversationID(chatID string) string {
trimmed := strings.TrimSpace(chatID)
if trimmed == "" {
return uuid.NewString()
}
return trimmed
}
func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThinking bool, userID int, chatID string) (<-chan string, <-chan error) {
//1. 创建一个输出通道
outChan := make(chan string, 5)
errChan := make(chan error, 1)
//补充:会话 ID 兜底,避免上层漏传
chatID = normalizeConversationID(chatID)
//2. 先确保这个会话存在(如果不存在就创建一个新的)
//先看看缓存里面有没有这个会话
result, err := s.agentCache.GetConversationStatus(ctx, chatID)
@@ -49,20 +62,23 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin
}
if !innerResult {
//如果会话不存在,先创建一个新的会话
_, err := s.repo.CreateNewChat(userID, chatID)
if err != nil {
if _, err = s.repo.CreateNewChat(userID, chatID); err != nil {
errChan <- err
close(outChan)
close(errChan)
return outChan, errChan
}
}
//补充:把“会话存在”状态回写缓存,后续请求可直接命中
if err = s.agentCache.SetConversationStatus(ctx, chatID); err != nil {
//缓存回写失败不影响主流程
log.Printf("failed to set conversation status cache for %s: %v", chatID, err)
}
}
//能走到这里,要么缓存里有这个会话,要么数据库里有这个会话了
//4. 提取出历史消息,构建上下文
//先尝试从缓存里拿历史消息
var chatHistory []*schema.Message
chatHistory, err = s.agentCache.GetHistory(ctx, chatID)
chatHistory, err := s.agentCache.GetHistory(ctx, chatID)
if err != nil {
errChan <- err
close(outChan)
@@ -82,8 +98,7 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin
//再转换成 Eino 的消息格式
chatHistory = conv.ToEinoMessages(histories)
//把历史消息放到缓存里,方便下次直接拿
err = s.agentCache.BackfillHistory(ctx, chatID, chatHistory)
if err != nil {
if err = s.agentCache.BackfillHistory(ctx, chatID, chatHistory); err != nil {
errChan <- err
close(outChan)
close(errChan)
@@ -93,16 +108,18 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin
//3. 将用户消息异步落缓存和库
go func() {
//这里先不管落库成功与否了,毕竟不想因为落库失败而影响用户的聊天体验
_ = s.agentCache.PushMessage(ctx, chatID, &schema.Message{
Role: "user",
bg := context.Background()
_ = s.agentCache.PushMessage(bg, chatID, &schema.Message{
Role: schema.User,
Content: userMessage,
})
_ = s.repo.SaveChatHistory(ctx, userID, chatID, "user", userMessage)
_ = s.repo.SaveChatHistory(bg, userID, chatID, "user", userMessage)
}()
//5. 启动一个 goroutine 来处理聊天逻辑
go func() {
defer close(outChan) // 确保在函数结束时关闭通道
defer close(errChan)
//3. 调用 StreamChat 函数进行流式聊天
fullText, err := agent.StreamChat(ctx, s.AIHub.Worker, userMessage, ifThinking, chatHistory, outChan)
if err != nil {
@@ -111,13 +128,13 @@ func (s *AgentService) AgentChat(ctx context.Context, userMessage string, ifThin
}
//4. 将 AI 的回复异步落缓存和库
go func() {
_ = s.agentCache.PushMessage(ctx, chatID, &schema.Message{
Role: "assistant",
bg := context.Background()
_ = s.agentCache.PushMessage(bg, chatID, &schema.Message{
Role: schema.Assistant,
Content: fullText,
})
err = s.repo.SaveChatHistory(context.Background(), userID, chatID, "assistant", fullText)
if err != nil {
log.Printf("Failed to save chat history to database: %v", err)
if saveErr := s.repo.SaveChatHistory(bg, userID, chatID, "assistant", fullText); saveErr != nil {
log.Printf("failed to save chat history to database: %v", saveErr)
return
}
}()

View File

@@ -0,0 +1,10 @@
root = true
[*]
charset = utf-8
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true
[*.md]
trim_trailing_whitespace = false

View File

@@ -0,0 +1,46 @@
# =========================
# MCP server metadata
# =========================
MCP_SERVER_NAME=smartflow-mcp-server
MCP_SERVER_VERSION=0.1.0
MCP_PROTOCOL_VERSION=2024-11-05
MCP_DEFAULT_CALLER=codex
# =========================
# Governance & safety
# =========================
MCP_TOOL_TIMEOUT_MS=5000
MCP_RATE_LIMIT_RPS=5
MCP_RATE_LIMIT_BURST=10
MCP_MAX_RESULT_ROWS=500
MCP_ENFORCE_WHITELIST=false
MCP_AUDIT_LOG_PATH=logs/audit.log
# Redis scan/value caps
MCP_REDIS_SCAN_MAX_KEYS=200
MCP_REDIS_SCAN_MAX_COUNT=200
MCP_REDIS_VALUE_MAX_ITEMS=100
MCP_REDIS_MAX_STRING_BYTES=4096
# =========================
# MySQL
# =========================
MYSQL_HOST=127.0.0.1
MYSQL_PORT=3306
MYSQL_USER=readonly_user
MYSQL_PASSWORD=replace_me
MYSQL_DATABASE=smartflow
MYSQL_PARAMS=charset=utf8mb4&parseTime=true&loc=Local
# Comma-separated whitelist (optional)
# Example: MYSQL_ALLOWED_DATABASES=smartflow,analytics
MYSQL_ALLOWED_DATABASES=smartflow
# Example: MYSQL_ALLOWED_TABLES=smartflow.users,smartflow.tasks
MYSQL_ALLOWED_TABLES=smartflow.users,smartflow.tasks
# =========================
# Redis
# =========================
REDIS_ADDR=127.0.0.1:6379
REDIS_PASSWORD=
REDIS_DB=0

2
infra/smartflow-mcp-server/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
.env
logs/*.log

View File

@@ -0,0 +1,213 @@
# smartflow-mcp-server (MVP)
用于让 Codex 通过 MCPstdio只读访问 MySQL 与 Redis面向接口联调与测试。
## 1. 功能范围(第一阶段)
只实现 3 个只读工具:
1. `mysql_query_readonly`
2. `redis_get`
3. `redis_scan`
未实现任何写操作工具。
## 2. 目录结构
```text
infra/smartflow-mcp-server
├─ cmd/server/main.go
├─ internal
│ ├─ audit
│ ├─ config
│ ├─ envutil
│ ├─ mcp
│ ├─ ratelimit
│ ├─ security
│ ├─ store
│ └─ tools
├─ .env.example
├─ go.mod
└─ README.md
```
## 3. 快速启动
```bash
go mod tidy
go test ./...
go run ./cmd/server
```
服务采用 stdio MCP 协议,不会启动 HTTP 端口。
## 4. 配置说明(全部来自环境变量)
复制并编辑:
```bash
cp .env.example .env
```
关键变量:
- `MYSQL_HOST` / `MYSQL_PORT` / `MYSQL_USER` / `MYSQL_PASSWORD` / `MYSQL_DATABASE`
- `REDIS_ADDR` / `REDIS_PASSWORD` / `REDIS_DB`
- `MYSQL_ALLOWED_DATABASES`:逗号分隔
- `MYSQL_ALLOWED_TABLES`:逗号分隔,支持 `db.table``table`
- `MCP_ENFORCE_WHITELIST``true` 时无明确表引用会拒绝执行
- `MCP_TOOL_TIMEOUT_MS`:单次工具调用超时
- `MCP_RATE_LIMIT_RPS` + `MCP_RATE_LIMIT_BURST`:基础令牌桶限流
- `MCP_MAX_RESULT_ROWS`MySQL 最大返回行数
- `MCP_REDIS_SCAN_MAX_KEYS``redis_scan` 最大返回 key 数
- `MCP_AUDIT_LOG_PATH`:审计日志路径
## 5. 工具说明
### 5.1 `mysql_query_readonly`
输入:
```json
{
"sql": "SELECT id, name FROM users WHERE id = ?",
"params": [1]
}
```
安全限制:
- 仅允许 `SELECT` / `SHOW` / `DESCRIBE` / `EXPLAIN`
- 禁止分号 `;`(多语句)
- 禁止注释 `--` / `#` / `/* */`
- 禁止 DDL/DML 关键字(`INSERT`/`UPDATE`/`DELETE`/`ALTER`/`DROP`/`TRUNCATE` 等)
- 支持库/表白名单校验
输出(结构化):
- `columns`
- `rows`
- `rowCount`
- `truncated`
- `durationMs`
### 5.2 `redis_get`
输入:
```json
{
"key": "user:1001"
}
```
输出:
- `exists`
- `key`
- `type`
- `value`
- `truncated`
- `durationMs`
### 5.3 `redis_scan`
输入:
```json
{
"pattern": "user:*",
"count": 50
}
```
输出:
- `pattern`
- `keys`
- `returned`
- `nextCursor`
- `truncated`
- `durationMs`
## 6. 审计日志
每次工具调用会记录JSON 行格式):
- 时间
- 工具名
- 调用方caller
- 是否成功
- 耗时
- 脱敏后的输入摘要
- 错误信息(截断)
敏感字段处理:
- SQL 字符串字面量与数字会脱敏
- Redis key 仅保留前后少量字符
## 7. Codex MCP 配置示例stdio
可按客户端配置格式接入,示例:
```json
{
"mcpServers": {
"smartflow-db-readonly": {
"command": "go",
"args": ["run", "./cmd/server"],
"cwd": "E:/SmartFlow-Agent/infra/smartflow-mcp-server",
"env": {
"MYSQL_HOST": "127.0.0.1",
"MYSQL_PORT": "3306",
"MYSQL_USER": "readonly_user",
"MYSQL_PASSWORD": "replace_me",
"MYSQL_DATABASE": "smartflow",
"MYSQL_ALLOWED_DATABASES": "smartflow",
"MYSQL_ALLOWED_TABLES": "smartflow.users,smartflow.tasks",
"REDIS_ADDR": "127.0.0.1:6379",
"REDIS_DB": "0",
"MCP_TOOL_TIMEOUT_MS": "5000",
"MCP_RATE_LIMIT_RPS": "5",
"MCP_RATE_LIMIT_BURST": "10",
"MCP_MAX_RESULT_ROWS": "500",
"MCP_REDIS_SCAN_MAX_KEYS": "200",
"MCP_AUDIT_LOG_PATH": "logs/audit.log"
}
}
}
}
```
## 8. 安全限制生效示例
- SQL 多语句:`SELECT 1; SELECT 2` -> 被拒绝semicolon is not allowed
- SQL 注释绕过:`SELECT * FROM users --x` -> 被拒绝sql comments are not allowed
- 写操作:`DELETE FROM users` -> 被拒绝dangerous sql keyword detected
- 白名单外表:`SELECT * FROM admin.secret` -> 被拒绝table not in whitelist
- Redis 大范围扫描:`redis_scan` 返回数量受 `MCP_REDIS_SCAN_MAX_KEYS` 限制
## 9. 风险说明MVP 已知边界)
1. SQL 校验采用关键字与模式匹配,不是完整 SQL AST 解析,建议二阶段引入 AST 级校验。
2. `SHOW DATABASES` 等无显式表引用语句在非严格模式下可执行;生产建议开启 `MCP_ENFORCE_WHITELIST=true`
3. Redis 复杂类型返回做了截断保护,但仍建议在生产环境设置更小上限。
## 10. 常见问题FAQ
### Q1: 启动时报 `MYSQL_USER and MYSQL_DATABASE are required`
检查环境变量是否正确加载,建议先确认 `.env` 存在于 `infra/smartflow-mcp-server`
### Q2: 为什么调用工具报限流
默认启用了令牌桶限流,调大 `MCP_RATE_LIMIT_RPS``MCP_RATE_LIMIT_BURST` 即可。
### Q3: 为什么 `redis_scan` 返回不全
是预期行为,结果数被 `MCP_REDIS_SCAN_MAX_KEYS` 限制,避免全量扫描拖垮 Redis。
### Q4: 审计日志在哪里
默认在 `logs/audit.log`,可用 `MCP_AUDIT_LOG_PATH` 自定义。

View File

@@ -0,0 +1,90 @@
package main
import (
"context"
"log"
"os"
"os/signal"
"syscall"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/audit"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/envutil"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/mcp"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools"
)
func main() {
if err := envutil.LoadDotEnv(".env"); err != nil {
log.Fatalf("load .env failed: %v", err)
}
cfg, err := config.LoadFromEnv()
if err != nil {
log.Fatalf("load config failed: %v", err)
}
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
auditLogger, err := audit.New(cfg.AuditLogPath)
if err != nil {
log.Fatalf("init audit logger failed: %v", err)
}
defer func() {
_ = auditLogger.Close()
}()
mysqlClient, err := store.NewMySQLClient(ctx, cfg.MySQL)
if err != nil {
log.Fatalf("init mysql failed: %v", err)
}
defer func() {
_ = mysqlClient.Close()
}()
redisClient, err := store.NewRedisClient(ctx, cfg.Redis)
if err != nil {
log.Fatalf("init redis failed: %v", err)
}
defer func() {
_ = redisClient.Close()
}()
sqlValidator := security.NewSQLValidator(
cfg.MySQL.Database,
cfg.EnforceWhitelist,
cfg.MySQL.AllowedDatabases,
cfg.MySQL.AllowedTables,
)
registry, err := tools.NewRegistry(
tools.NewMySQLReadOnlyTool(mysqlClient, sqlValidator, cfg.MaxResultRows),
tools.NewRedisGetTool(redisClient, cfg.RedisValueMaxItems, cfg.RedisMaxStringBytes),
tools.NewRedisScanTool(redisClient, cfg.RedisScanMaxKeys, cfg.RedisScanMaxCount),
)
if err != nil {
log.Fatalf("init tool registry failed: %v", err)
}
limiter := ratelimit.New(cfg.RateLimitRPS, cfg.RateLimitBurst)
server := mcp.NewServer(
os.Stdin,
os.Stdout,
registry,
auditLogger,
limiter,
cfg.ServerName,
cfg.ServerVersion,
cfg.ProtocolVersion,
cfg.DefaultCaller,
cfg.ToolTimeout,
)
if err := server.Serve(ctx); err != nil {
log.Fatalf("mcp server exited with error: %v", err)
}
}

View File

@@ -0,0 +1,14 @@
module github.com/LoveLosita/smartflow/infra/smartflow-mcp-server
go 1.23.4
require (
github.com/go-redis/redis/v8 v8.11.5
github.com/go-sql-driver/mysql v1.8.1
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
)

View File

@@ -0,0 +1,28 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=

View File

@@ -0,0 +1,60 @@
package audit
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
)
type Logger struct {
mu sync.Mutex
file *os.File
}
type Record struct {
Timestamp time.Time `json:"timestamp"`
Tool string `json:"tool"`
Caller string `json:"caller"`
Success bool `json:"success"`
DurationMs int64 `json:"duration_ms"`
Meta map[string]any `json:"meta,omitempty"`
Error string `json:"error,omitempty"`
}
func New(path string) (*Logger, error) {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, fmt.Errorf("create audit dir: %w", err)
}
f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
if err != nil {
return nil, fmt.Errorf("open audit log file: %w", err)
}
return &Logger{file: f}, nil
}
func (l *Logger) Close() error {
if l == nil || l.file == nil {
return nil
}
return l.file.Close()
}
func (l *Logger) Log(record Record) {
if l == nil || l.file == nil {
return
}
if record.Timestamp.IsZero() {
record.Timestamp = time.Now()
}
body, err := json.Marshal(record)
if err != nil {
return
}
l.mu.Lock()
defer l.mu.Unlock()
_, _ = l.file.Write(append(body, '\n'))
}

View File

@@ -0,0 +1,163 @@
package config
import (
"fmt"
"os"
"strconv"
"strings"
"time"
)
type Config struct {
ServerName string
ServerVersion string
ProtocolVersion string
DefaultCaller string
ToolTimeout time.Duration
RateLimitRPS float64
RateLimitBurst float64
MaxResultRows int
AuditLogPath string
EnforceWhitelist bool
RedisScanMaxKeys int
RedisScanMaxCount int
RedisValueMaxItems int
RedisMaxStringBytes int
MySQL MySQLConfig
Redis RedisConfig
}
type MySQLConfig struct {
Host string
Port int
User string
Password string
Database string
Params string
AllowedDatabases []string
AllowedTables []string
}
type RedisConfig struct {
Addr string
Password string
DB int
}
func LoadFromEnv() (Config, error) {
cfg := Config{
ServerName: getEnv("MCP_SERVER_NAME", "smartflow-mcp-server"),
ServerVersion: getEnv("MCP_SERVER_VERSION", "0.1.0"),
ProtocolVersion: getEnv("MCP_PROTOCOL_VERSION", "2024-11-05"),
DefaultCaller: getEnv("MCP_DEFAULT_CALLER", "unknown"),
ToolTimeout: getEnvDurationMS("MCP_TOOL_TIMEOUT_MS", 5000),
RateLimitRPS: getEnvFloat("MCP_RATE_LIMIT_RPS", 5),
RateLimitBurst: getEnvFloat("MCP_RATE_LIMIT_BURST", 10),
MaxResultRows: getEnvInt("MCP_MAX_RESULT_ROWS", 500),
AuditLogPath: getEnv("MCP_AUDIT_LOG_PATH", "logs/audit.log"),
EnforceWhitelist: getEnvBool("MCP_ENFORCE_WHITELIST", false),
RedisScanMaxKeys: getEnvInt("MCP_REDIS_SCAN_MAX_KEYS", 200),
RedisScanMaxCount: getEnvInt("MCP_REDIS_SCAN_MAX_COUNT", 200),
RedisValueMaxItems: getEnvInt("MCP_REDIS_VALUE_MAX_ITEMS", 100),
RedisMaxStringBytes: getEnvInt("MCP_REDIS_MAX_STRING_BYTES", 4096),
MySQL: MySQLConfig{
Host: getEnv("MYSQL_HOST", "127.0.0.1"),
Port: getEnvInt("MYSQL_PORT", 3306),
User: getEnv("MYSQL_USER", ""),
Password: getEnv("MYSQL_PASSWORD", ""),
Database: getEnv("MYSQL_DATABASE", ""),
Params: getEnv("MYSQL_PARAMS", "charset=utf8mb4&parseTime=true&loc=Local"),
AllowedDatabases: splitCommaList(getEnv("MYSQL_ALLOWED_DATABASES", "")),
AllowedTables: splitCommaList(getEnv("MYSQL_ALLOWED_TABLES", "")),
},
Redis: RedisConfig{
Addr: getEnv("REDIS_ADDR", "127.0.0.1:6379"),
Password: getEnv("REDIS_PASSWORD", ""),
DB: getEnvInt("REDIS_DB", 0),
},
}
if cfg.MySQL.User == "" || cfg.MySQL.Database == "" {
return Config{}, fmt.Errorf("MYSQL_USER and MYSQL_DATABASE are required")
}
if cfg.Redis.Addr == "" {
return Config{}, fmt.Errorf("REDIS_ADDR is required")
}
if cfg.MaxResultRows <= 0 {
return Config{}, fmt.Errorf("MCP_MAX_RESULT_ROWS must be > 0")
}
if cfg.RedisScanMaxKeys <= 0 {
return Config{}, fmt.Errorf("MCP_REDIS_SCAN_MAX_KEYS must be > 0")
}
if cfg.RedisScanMaxCount <= 0 {
return Config{}, fmt.Errorf("MCP_REDIS_SCAN_MAX_COUNT must be > 0")
}
if cfg.ToolTimeout <= 0 {
return Config{}, fmt.Errorf("MCP_TOOL_TIMEOUT_MS must be > 0")
}
if cfg.RateLimitRPS <= 0 || cfg.RateLimitBurst <= 0 {
return Config{}, fmt.Errorf("MCP_RATE_LIMIT_RPS and MCP_RATE_LIMIT_BURST must be > 0")
}
return cfg, nil
}
func getEnv(key string, defaultValue string) string {
if v, ok := os.LookupEnv(key); ok {
return strings.TrimSpace(v)
}
return defaultValue
}
func getEnvInt(key string, defaultValue int) int {
v := getEnv(key, "")
if v == "" {
return defaultValue
}
n, err := strconv.Atoi(v)
if err != nil {
return defaultValue
}
return n
}
func getEnvFloat(key string, defaultValue float64) float64 {
v := getEnv(key, "")
if v == "" {
return defaultValue
}
n, err := strconv.ParseFloat(v, 64)
if err != nil {
return defaultValue
}
return n
}
func getEnvBool(key string, defaultValue bool) bool {
v := strings.ToLower(getEnv(key, ""))
if v == "" {
return defaultValue
}
return v == "1" || v == "true" || v == "yes" || v == "on"
}
func getEnvDurationMS(key string, defaultValueMs int) time.Duration {
ms := getEnvInt(key, defaultValueMs)
return time.Duration(ms) * time.Millisecond
}
func splitCommaList(raw string) []string {
if strings.TrimSpace(raw) == "" {
return nil
}
parts := strings.Split(raw, ",")
out := make([]string, 0, len(parts))
for _, p := range parts {
trimmed := strings.TrimSpace(p)
if trimmed != "" {
out = append(out, strings.ToLower(trimmed))
}
}
return out
}

View File

@@ -0,0 +1,44 @@
package envutil
import (
"bufio"
"fmt"
"os"
"strings"
)
func LoadDotEnv(path string) error {
f, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("open .env: %w", err)
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
i := strings.Index(line, "=")
if i <= 0 {
continue
}
key := strings.TrimSpace(line[:i])
value := strings.TrimSpace(line[i+1:])
value = strings.Trim(value, "\"")
if key == "" {
continue
}
if _, exists := os.LookupEnv(key); !exists {
_ = os.Setenv(key, value)
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("scan .env: %w", err)
}
return nil
}

View File

@@ -0,0 +1,394 @@
package mcp
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"strconv"
"strings"
"sync"
"time"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/audit"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools"
)
const (
jsonRPCVersion = "2.0"
errCodeParseError = -32700
errCodeInvalidRequest = -32600
errCodeMethodNotFound = -32601
errCodeInvalidParams = -32602
errCodeInternalError = -32603
)
type Server struct {
reader *bufio.Reader
writer io.Writer
writeMu sync.Mutex
registry *tools.Registry
auditLogger *audit.Logger
limiter *ratelimit.Limiter
serverName string
serverVersion string
protocolVersion string
defaultCaller string
toolTimeout time.Duration
}
type request struct {
JSONRPC string `json:"jsonrpc"`
ID json.RawMessage `json:"id,omitempty"`
Method string `json:"method"`
Params json.RawMessage `json:"params,omitempty"`
}
type response struct {
JSONRPC string `json:"jsonrpc"`
ID json.RawMessage `json:"id,omitempty"`
Result any `json:"result,omitempty"`
Error *respError `json:"error,omitempty"`
}
type respError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type initializeParams struct {
ProtocolVersion string `json:"protocolVersion"`
}
type toolCallParams struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
Meta map[string]any `json:"_meta,omitempty"`
}
func NewServer(
in io.Reader,
out io.Writer,
registry *tools.Registry,
auditLogger *audit.Logger,
limiter *ratelimit.Limiter,
serverName string,
serverVersion string,
protocolVersion string,
defaultCaller string,
toolTimeout time.Duration,
) *Server {
return &Server{
reader: bufio.NewReader(in),
writer: out,
registry: registry,
auditLogger: auditLogger,
limiter: limiter,
serverName: serverName,
serverVersion: serverVersion,
protocolVersion: protocolVersion,
defaultCaller: defaultCaller,
toolTimeout: toolTimeout,
}
}
func (s *Server) Serve(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
default:
}
body, err := readMessage(s.reader)
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
log.Printf("read message failed: %v", err)
continue
}
var req request
if err := json.Unmarshal(body, &req); err != nil {
_ = s.writeError(nil, errCodeParseError, "invalid json")
continue
}
if req.Method == "" || req.JSONRPC != jsonRPCVersion {
_ = s.writeError(req.ID, errCodeInvalidRequest, "invalid json-rpc request")
continue
}
if err := s.handleRequest(ctx, req); err != nil {
log.Printf("handle request failed: %v", err)
}
}
}
func (s *Server) handleRequest(ctx context.Context, req request) error {
switch req.Method {
case "initialize":
return s.handleInitialize(req)
case "notifications/initialized":
return nil
case "tools/list":
if len(req.ID) == 0 {
return nil
}
return s.writeResult(req.ID, map[string]any{"tools": s.registry.List()})
case "tools/call":
if len(req.ID) == 0 {
return nil
}
return s.handleToolCall(ctx, req)
case "ping":
if len(req.ID) == 0 {
return nil
}
return s.writeResult(req.ID, map[string]any{})
default:
if len(req.ID) == 0 {
return nil
}
return s.writeError(req.ID, errCodeMethodNotFound, "method not found")
}
}
func (s *Server) handleInitialize(req request) error {
if len(req.ID) == 0 {
return nil
}
var params initializeParams
if len(req.Params) > 0 {
_ = json.Unmarshal(req.Params, &params)
}
_ = params
return s.writeResult(req.ID, map[string]any{
"protocolVersion": s.protocolVersion,
"capabilities": map[string]any{
"tools": map[string]any{"listChanged": false},
},
"serverInfo": map[string]any{
"name": s.serverName,
"version": s.serverVersion,
},
})
}
func (s *Server) handleToolCall(ctx context.Context, req request) error {
var params toolCallParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return s.writeError(req.ID, errCodeInvalidParams, "invalid tool call params")
}
if params.Name == "" {
return s.writeError(req.ID, errCodeInvalidParams, "tool name is required")
}
if params.Arguments == nil {
params.Arguments = map[string]any{}
}
caller := extractCaller(params, s.defaultCaller)
rateKey := fmt.Sprintf("%s:%s", caller, params.Name)
if !s.limiter.Allow(rateKey) {
result := buildToolErrorResult("rate limit exceeded")
s.auditLogger.Log(audit.Record{
Tool: params.Name,
Caller: caller,
Success: false,
DurationMs: 0,
Meta: sanitizeAuditInput(params.Name, params.Arguments),
Error: "rate limit exceeded",
})
return s.writeResult(req.ID, result)
}
tool, ok := s.registry.Find(params.Name)
if !ok {
result := buildToolErrorResult("tool not found")
s.auditLogger.Log(audit.Record{
Tool: params.Name,
Caller: caller,
Success: false,
DurationMs: 0,
Meta: sanitizeAuditInput(params.Name, params.Arguments),
Error: "tool not found",
})
return s.writeResult(req.ID, result)
}
start := time.Now()
toolCtx, cancel := context.WithTimeout(ctx, s.toolTimeout)
defer cancel()
output, err := tool.Execute(toolCtx, params.Arguments)
duration := time.Since(start).Milliseconds()
if err != nil {
errMsg := sanitizeError(err)
s.auditLogger.Log(audit.Record{
Tool: params.Name,
Caller: caller,
Success: false,
DurationMs: duration,
Meta: sanitizeAuditInput(params.Name, params.Arguments),
Error: errMsg,
})
return s.writeResult(req.ID, buildToolErrorResult(errMsg))
}
s.auditLogger.Log(audit.Record{
Tool: params.Name,
Caller: caller,
Success: true,
DurationMs: duration,
Meta: sanitizeAuditInput(params.Name, params.Arguments),
})
return s.writeResult(req.ID, buildToolSuccessResult(output))
}
func buildToolSuccessResult(output map[string]any) map[string]any {
asJSON, _ := json.Marshal(output)
return map[string]any{
"content": []map[string]any{
{"type": "text", "text": string(asJSON)},
},
"structuredContent": output,
}
}
func buildToolErrorResult(message string) map[string]any {
return map[string]any{
"content": []map[string]any{
{"type": "text", "text": message},
},
"structuredContent": map[string]any{"error": message},
"isError": true,
}
}
func extractCaller(params toolCallParams, fallback string) string {
if params.Meta != nil {
if v, ok := params.Meta["caller"]; ok {
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
return strings.TrimSpace(s)
}
}
}
if v, ok := params.Arguments["caller"]; ok {
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
return strings.TrimSpace(s)
}
}
return fallback
}
func sanitizeAuditInput(toolName string, args map[string]any) map[string]any {
meta := map[string]any{}
switch toolName {
case "mysql_query_readonly":
sql, _ := args["sql"].(string)
meta["sql"] = security.RedactSQL(sql)
if params, ok := args["params"].([]any); ok {
meta["paramCount"] = len(params)
}
case "redis_get":
if key, ok := args["key"].(string); ok {
meta["key"] = security.RedactKey(key)
}
case "redis_scan":
if pattern, ok := args["pattern"].(string); ok {
if len(pattern) > 40 {
meta["pattern"] = pattern[:40]
} else {
meta["pattern"] = pattern
}
}
if count, ok := args["count"]; ok {
meta["count"] = count
}
default:
meta["args"] = "masked"
}
return meta
}
func sanitizeError(err error) string {
msg := err.Error()
msg = strings.ReplaceAll(msg, "\n", " ")
msg = strings.ReplaceAll(msg, "\r", " ")
if len(msg) > 300 {
return msg[:300]
}
return msg
}
func (s *Server) writeResult(id json.RawMessage, result any) error {
resp := response{JSONRPC: jsonRPCVersion, ID: id, Result: result}
return s.writeMessage(resp)
}
func (s *Server) writeError(id json.RawMessage, code int, message string) error {
resp := response{
JSONRPC: jsonRPCVersion,
ID: id,
Error: &respError{Code: code, Message: message},
}
return s.writeMessage(resp)
}
func (s *Server) writeMessage(payload any) error {
body, err := json.Marshal(payload)
if err != nil {
return err
}
header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(body))
s.writeMu.Lock()
defer s.writeMu.Unlock()
if _, err := s.writer.Write([]byte(header)); err != nil {
return err
}
_, err = s.writer.Write(body)
return err
}
func readMessage(reader *bufio.Reader) ([]byte, error) {
contentLength := -1
for {
line, err := reader.ReadString('\n')
if err != nil {
return nil, err
}
line = strings.TrimRight(line, "\r\n")
if line == "" {
break
}
parts := strings.SplitN(line, ":", 2)
if len(parts) != 2 {
continue
}
headerName := strings.TrimSpace(parts[0])
headerValue := strings.TrimSpace(parts[1])
if strings.EqualFold(headerName, "Content-Length") {
n, err := strconv.Atoi(headerValue)
if err != nil || n < 0 {
return nil, fmt.Errorf("invalid Content-Length")
}
contentLength = n
}
}
if contentLength < 0 {
return nil, fmt.Errorf("missing Content-Length")
}
body := make([]byte, contentLength)
if _, err := io.ReadFull(reader, body); err != nil {
return nil, err
}
return body, nil
}

View File

@@ -0,0 +1,56 @@
package mcp
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"testing"
"time"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools"
)
type dummyTool struct{}
func (d dummyTool) Name() string { return "dummy" }
func (d dummyTool) Description() string { return "dummy" }
func (d dummyTool) InputSchema() map[string]any {
return map[string]any{"type": "object"}
}
func (d dummyTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
return map[string]any{"ok": true}, nil
}
func TestHandleInitialize(t *testing.T) {
registry, err := tools.NewRegistry(dummyTool{})
if err != nil {
t.Fatal(err)
}
out := bytes.NewBuffer(nil)
s := NewServer(bytes.NewBuffer(nil), out, registry, nil, ratelimit.New(10, 10), "name", "ver", "2024-11-05", "unknown", time.Second)
req := request{JSONRPC: jsonRPCVersion, ID: json.RawMessage("1"), Method: "initialize", Params: json.RawMessage(`{"protocolVersion":"2024-11-05"}`)}
if err := s.handleRequest(context.Background(), req); err != nil {
t.Fatalf("handle initialize error: %v", err)
}
if !bytes.Contains(out.Bytes(), []byte("protocolVersion")) {
t.Fatalf("response missing protocolVersion: %s", out.String())
}
}
func TestReadMessage(t *testing.T) {
body := []byte(`{"jsonrpc":"2.0","id":1,"method":"ping"}`)
frame := fmt.Sprintf("Content-Length: %d\r\n\r\n%s", len(body), string(body))
input := bufio.NewReader(bytes.NewBufferString(frame))
msg, err := readMessage(input)
if err != nil {
t.Fatalf("read message failed: %v", err)
}
if !bytes.Equal(msg, body) {
t.Fatalf("unexpected body: %s", string(msg))
}
}

View File

@@ -0,0 +1,51 @@
package ratelimit
import (
"sync"
"time"
)
type Limiter struct {
mu sync.Mutex
rate float64
burst float64
buckets map[string]*bucket
}
type bucket struct {
tokens float64
last time.Time
}
func New(rate, burst float64) *Limiter {
return &Limiter{
rate: rate,
burst: burst,
buckets: make(map[string]*bucket),
}
}
func (l *Limiter) Allow(key string) bool {
now := time.Now()
l.mu.Lock()
defer l.mu.Unlock()
b, ok := l.buckets[key]
if !ok {
l.buckets[key] = &bucket{tokens: l.burst - 1, last: now}
return true
}
elapsed := now.Sub(b.last).Seconds()
b.tokens += elapsed * l.rate
if b.tokens > l.burst {
b.tokens = l.burst
}
b.last = now
if b.tokens < 1 {
return false
}
b.tokens -= 1
return true
}

View File

@@ -0,0 +1,26 @@
package ratelimit
import (
"testing"
"time"
)
func TestLimiter(t *testing.T) {
l := New(2, 2)
key := "user:tool"
if !l.Allow(key) {
t.Fatal("first request should pass")
}
if !l.Allow(key) {
t.Fatal("second request should pass")
}
if l.Allow(key) {
t.Fatal("third request should be rate limited")
}
time.Sleep(600 * time.Millisecond)
if !l.Allow(key) {
t.Fatal("request should pass after token refill")
}
}

View File

@@ -0,0 +1,27 @@
package security
import (
"regexp"
"strings"
)
var (
singleQuotedString = regexp.MustCompile(`'([^'\\]|\\.)*'`)
doubleQuotedString = regexp.MustCompile(`"([^"\\]|\\.)*"`)
numericLiteral = regexp.MustCompile(`\b\d+\b`)
)
func RedactSQL(sql string) string {
masked := singleQuotedString.ReplaceAllString(sql, "'***'")
masked = doubleQuotedString.ReplaceAllString(masked, `"***"`)
masked = numericLiteral.ReplaceAllString(masked, "?")
return strings.TrimSpace(masked)
}
func RedactKey(key string) string {
key = strings.TrimSpace(key)
if len(key) <= 4 {
return "****"
}
return key[:2] + "***" + key[len(key)-2:]
}

View File

@@ -0,0 +1,193 @@
package security
import (
"fmt"
"regexp"
"strings"
)
var (
commentPattern = regexp.MustCompile(`(?s)/\*.*?\*/|--|#`)
forbiddenWords = regexp.MustCompile(`(?i)\b(INSERT|UPDATE|DELETE|ALTER|DROP|TRUNCATE|CREATE|REPLACE|RENAME|GRANT|REVOKE|MERGE|CALL|EXEC|LOCK|UNLOCK|LOAD|OUTFILE|INFILE|HANDLER|SET|USE)\b`)
fromJoinRef = regexp.MustCompile(`(?i)\b(?:FROM|JOIN)\s+([` + "`" + `"\w\.]+)`)
describeRef = regexp.MustCompile(`(?i)\b(?:DESCRIBE|DESC)\s+([` + "`" + `"\w\.]+)`)
showFromRef = regexp.MustCompile(`(?i)\bSHOW\b[\w\s]*\b(?:FROM|IN)\s+([` + "`" + `"\w\.]+)`)
)
type SQLValidator struct {
defaultDatabase string
enforceWhitelist bool
allowedDatabases map[string]struct{}
allowedTables map[string]struct{}
allowedTableOnly map[string]struct{}
}
type sqlObjectRef struct {
database string
table string
}
func NewSQLValidator(defaultDatabase string, enforceWhitelist bool, allowedDatabases []string, allowedTables []string) *SQLValidator {
v := &SQLValidator{
defaultDatabase: strings.ToLower(strings.TrimSpace(defaultDatabase)),
enforceWhitelist: enforceWhitelist,
allowedDatabases: make(map[string]struct{}),
allowedTables: make(map[string]struct{}),
allowedTableOnly: make(map[string]struct{}),
}
for _, db := range allowedDatabases {
db = strings.ToLower(strings.TrimSpace(db))
if db != "" {
v.allowedDatabases[db] = struct{}{}
}
}
for _, tbl := range allowedTables {
tbl = strings.ToLower(strings.TrimSpace(strings.Trim(tbl, "`\"")))
if tbl == "" {
continue
}
v.allowedTables[tbl] = struct{}{}
if dot := strings.LastIndex(tbl, "."); dot >= 0 && dot < len(tbl)-1 {
v.allowedTableOnly[tbl[dot+1:]] = struct{}{}
} else {
v.allowedTableOnly[tbl] = struct{}{}
}
}
return v
}
func (v *SQLValidator) ValidateReadOnlySQL(sql string) error {
trimmed := strings.TrimSpace(sql)
if trimmed == "" {
return fmt.Errorf("sql is required")
}
if strings.Contains(trimmed, ";") {
return fmt.Errorf("semicolon is not allowed")
}
if commentPattern.MatchString(trimmed) {
return fmt.Errorf("sql comments are not allowed")
}
first := strings.ToUpper(firstToken(trimmed))
switch first {
case "SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN":
default:
return fmt.Errorf("only SELECT/SHOW/DESCRIBE/EXPLAIN are allowed")
}
withoutLiterals := removeStringLiterals(trimmed)
if forbiddenWords.MatchString(strings.ToUpper(withoutLiterals)) {
return fmt.Errorf("dangerous sql keyword detected")
}
refs := extractObjectRefs(withoutLiterals)
if err := v.validateWhitelist(refs); err != nil {
return err
}
return nil
}
func (v *SQLValidator) validateWhitelist(refs []sqlObjectRef) error {
hasDBAllowlist := len(v.allowedDatabases) > 0
hasTableAllowlist := len(v.allowedTables) > 0
if !v.enforceWhitelist && !hasDBAllowlist && !hasTableAllowlist {
return nil
}
if len(refs) == 0 {
if v.enforceWhitelist {
return fmt.Errorf("sql does not contain explicit table reference under whitelist mode")
}
return nil
}
for _, ref := range refs {
db := ref.database
tbl := ref.table
if db == "" {
db = v.defaultDatabase
}
if hasDBAllowlist {
if db == "" {
return fmt.Errorf("database is not explicit and no default database set")
}
if _, ok := v.allowedDatabases[db]; !ok {
return fmt.Errorf("database %s is not in whitelist", db)
}
}
if hasTableAllowlist {
full := tbl
if db != "" {
full = db + "." + tbl
}
if _, ok := v.allowedTables[full]; ok {
continue
}
if _, ok := v.allowedTableOnly[tbl]; ok {
continue
}
return fmt.Errorf("table %s is not in whitelist", full)
}
}
return nil
}
func extractObjectRefs(sql string) []sqlObjectRef {
refs := make([]sqlObjectRef, 0)
seen := make(map[string]struct{})
for _, re := range []*regexp.Regexp{fromJoinRef, describeRef, showFromRef} {
matches := re.FindAllStringSubmatch(sql, -1)
for _, m := range matches {
if len(m) < 2 {
continue
}
ref := normalizeRef(m[1])
if ref.table == "" {
continue
}
key := ref.database + "." + ref.table
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
refs = append(refs, ref)
}
}
return refs
}
func normalizeRef(raw string) sqlObjectRef {
clean := strings.ToLower(strings.TrimSpace(raw))
clean = strings.Trim(clean, "`\"'(),")
clean = strings.TrimSpace(clean)
if clean == "" {
return sqlObjectRef{}
}
parts := strings.Split(clean, ".")
if len(parts) == 1 {
return sqlObjectRef{table: parts[0]}
}
return sqlObjectRef{database: parts[0], table: parts[len(parts)-1]}
}
func firstToken(sql string) string {
for i, r := range sql {
if r == ' ' || r == '\n' || r == '\t' || r == '\r' {
if i == 0 {
continue
}
return sql[:i]
}
}
return sql
}
func removeStringLiterals(sql string) string {
masked := singleQuotedString.ReplaceAllString(sql, "''")
masked = doubleQuotedString.ReplaceAllString(masked, `""`)
return masked
}

View File

@@ -0,0 +1,44 @@
package security
import "testing"
func TestValidateReadOnlySQL(t *testing.T) {
validator := NewSQLValidator("smartflow", true, []string{"smartflow"}, []string{"smartflow.users", "smartflow.tasks"})
tests := []struct {
name string
sql string
wantErr bool
}{
{name: "allow select", sql: "SELECT id, name FROM users WHERE id = 1", wantErr: false},
{name: "allow explain", sql: "EXPLAIN SELECT * FROM tasks", wantErr: false},
{name: "reject insert", sql: "INSERT INTO users(name) VALUES('x')", wantErr: true},
{name: "reject multi statement", sql: "SELECT * FROM users; SELECT * FROM tasks", wantErr: true},
{name: "reject comment", sql: "SELECT * FROM users -- bypass", wantErr: true},
{name: "reject not whitelisted table", sql: "SELECT * FROM orders", wantErr: true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := validator.ValidateReadOnlySQL(tc.sql)
if tc.wantErr && err == nil {
t.Fatalf("expected error, got nil")
}
if !tc.wantErr && err != nil {
t.Fatalf("expected no error, got %v", err)
}
})
}
}
func TestRedact(t *testing.T) {
masked := RedactSQL("SELECT * FROM users WHERE token='abc123' AND id=42")
if masked == "" || masked == "SELECT * FROM users WHERE token='abc123' AND id=42" {
t.Fatalf("redaction not applied: %s", masked)
}
key := RedactKey("very-sensitive-key")
if key == "very-sensitive-key" {
t.Fatalf("key not redacted")
}
}

View File

@@ -0,0 +1,198 @@
package store
import (
"context"
"database/sql"
"fmt"
"net/url"
"reflect"
"strconv"
"strings"
"time"
cfgpkg "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
"github.com/go-sql-driver/mysql"
)
type MySQLClient struct {
db *sql.DB
}
type QueryColumn struct {
Name string `json:"name"`
DatabaseType string `json:"databaseType"`
Nullable *bool `json:"nullable,omitempty"`
ScanType string `json:"scanType,omitempty"`
}
type QueryResult struct {
Columns []QueryColumn `json:"columns"`
Rows []map[string]any `json:"rows"`
RowCount int `json:"rowCount"`
Truncated bool `json:"truncated"`
DurationMs int64 `json:"durationMs"`
}
func NewMySQLClient(ctx context.Context, cfg cfgpkg.MySQLConfig) (*MySQLClient, error) {
dsn := mysql.Config{
User: cfg.User,
Passwd: cfg.Password,
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Net: "tcp",
DBName: cfg.Database,
AllowNativePasswords: true,
ParseTime: true,
}
applyMySQLParams(&dsn, cfg.Params)
db, err := sql.Open("mysql", dsn.FormatDSN())
if err != nil {
return nil, fmt.Errorf("open mysql: %w", err)
}
db.SetConnMaxLifetime(5 * time.Minute)
db.SetMaxOpenConns(5)
db.SetMaxIdleConns(5)
pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
if err := db.PingContext(pingCtx); err != nil {
_ = db.Close()
return nil, fmt.Errorf("ping mysql: %w", err)
}
return &MySQLClient{db: db}, nil
}
func applyMySQLParams(cfg *mysql.Config, raw string) {
if strings.TrimSpace(raw) == "" {
return
}
values, err := url.ParseQuery(raw)
if err != nil {
return
}
if cfg.Params == nil {
cfg.Params = make(map[string]string)
}
for key, valueList := range values {
if len(valueList) == 0 {
continue
}
value := valueList[0]
switch strings.ToLower(key) {
case "parsetime":
if parsed, err := strconv.ParseBool(value); err == nil {
cfg.ParseTime = parsed
}
case "loc":
if loc, err := time.LoadLocation(value); err == nil {
cfg.Loc = loc
}
case "collation":
cfg.Collation = value
default:
cfg.Params[key] = value
}
}
}
func (c *MySQLClient) Close() error {
if c == nil || c.db == nil {
return nil
}
return c.db.Close()
}
func (c *MySQLClient) QueryReadOnly(ctx context.Context, query string, args []any, maxRows int) (QueryResult, error) {
start := time.Now()
rows, err := c.db.QueryContext(ctx, query, args...)
if err != nil {
return QueryResult{}, err
}
defer rows.Close()
columnTypes, err := rows.ColumnTypes()
if err != nil {
return QueryResult{}, err
}
columns := make([]QueryColumn, 0, len(columnTypes))
columnNames := make([]string, 0, len(columnTypes))
for _, ct := range columnTypes {
var nullablePtr *bool
if nullable, ok := ct.Nullable(); ok {
n := nullable
nullablePtr = &n
}
scanType := ""
if st := ct.ScanType(); st != nil {
scanType = st.String()
}
columns = append(columns, QueryColumn{
Name: ct.Name(),
DatabaseType: ct.DatabaseTypeName(),
Nullable: nullablePtr,
ScanType: scanType,
})
columnNames = append(columnNames, ct.Name())
}
resultRows := make([]map[string]any, 0)
truncated := false
for rows.Next() {
if len(resultRows) >= maxRows {
truncated = true
break
}
scanned, err := scanRow(rows, len(columnNames))
if err != nil {
return QueryResult{}, err
}
rowMap := make(map[string]any, len(columnNames))
for i, name := range columnNames {
rowMap[name] = normalizeValue(scanned[i])
}
resultRows = append(resultRows, rowMap)
}
if err := rows.Err(); err != nil {
return QueryResult{}, err
}
return QueryResult{
Columns: columns,
Rows: resultRows,
RowCount: len(resultRows),
Truncated: truncated,
DurationMs: time.Since(start).Milliseconds(),
}, nil
}
func scanRow(rows *sql.Rows, size int) ([]any, error) {
dest := make([]any, size)
holders := make([]any, size)
for i := range dest {
holders[i] = &dest[i]
}
if err := rows.Scan(holders...); err != nil {
return nil, err
}
return dest, nil
}
func normalizeValue(v any) any {
switch val := v.(type) {
case nil:
return nil
case []byte:
return string(val)
case time.Time:
return val.Format(time.RFC3339Nano)
default:
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Ptr && rv.IsNil() {
return nil
}
return v
}
}

View File

@@ -0,0 +1,189 @@
package store
import (
"context"
"fmt"
"strings"
"time"
cfgpkg "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
"github.com/go-redis/redis/v8"
)
type RedisClient struct {
client *redis.Client
}
type RedisGetResult struct {
Exists bool `json:"exists"`
Key string `json:"key"`
Type string `json:"type"`
Value any `json:"value,omitempty"`
Truncated bool `json:"truncated"`
DurationMs int64 `json:"durationMs"`
}
type RedisScanResult struct {
Pattern string `json:"pattern"`
Keys []string `json:"keys"`
Returned int `json:"returned"`
NextCursor uint64 `json:"nextCursor"`
Truncated bool `json:"truncated"`
DurationMs int64 `json:"durationMs"`
}
func NewRedisClient(ctx context.Context, cfg cfgpkg.RedisConfig) (*RedisClient, error) {
client := redis.NewClient(&redis.Options{
Addr: cfg.Addr,
Password: cfg.Password,
DB: cfg.DB,
})
pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
if err := client.Ping(pingCtx).Err(); err != nil {
_ = client.Close()
return nil, fmt.Errorf("ping redis: %w", err)
}
return &RedisClient{client: client}, nil
}
func (c *RedisClient) Close() error {
if c == nil || c.client == nil {
return nil
}
return c.client.Close()
}
func (c *RedisClient) GetWithType(ctx context.Context, key string, maxItems int, maxStringBytes int) (RedisGetResult, error) {
start := time.Now()
t, err := c.client.Type(ctx, key).Result()
if err != nil {
return RedisGetResult{}, err
}
if t == "none" {
return RedisGetResult{
Exists: false,
Key: key,
Type: "none",
Truncated: false,
DurationMs: time.Since(start).Milliseconds(),
}, nil
}
result := RedisGetResult{Exists: true, Key: key, Type: t}
switch t {
case "string":
v, err := c.client.Get(ctx, key).Result()
if err != nil {
return RedisGetResult{}, err
}
if len(v) > maxStringBytes {
result.Value = v[:maxStringBytes]
result.Truncated = true
} else {
result.Value = v
}
case "list":
vals, err := c.client.LRange(ctx, key, 0, int64(maxItems-1)).Result()
if err != nil {
return RedisGetResult{}, err
}
result.Value = vals
if length, _ := c.client.LLen(ctx, key).Result(); length > int64(maxItems) {
result.Truncated = true
}
case "set":
vals, err := c.client.SMembers(ctx, key).Result()
if err != nil {
return RedisGetResult{}, err
}
if len(vals) > maxItems {
result.Value = vals[:maxItems]
result.Truncated = true
} else {
result.Value = vals
}
case "zset":
vals, err := c.client.ZRangeWithScores(ctx, key, 0, int64(maxItems-1)).Result()
if err != nil {
return RedisGetResult{}, err
}
resultRows := make([]map[string]any, 0, len(vals))
for _, item := range vals {
resultRows = append(resultRows, map[string]any{"member": item.Member, "score": item.Score})
}
result.Value = resultRows
if length, _ := c.client.ZCard(ctx, key).Result(); length > int64(maxItems) {
result.Truncated = true
}
case "hash":
vals, err := c.client.HGetAll(ctx, key).Result()
if err != nil {
return RedisGetResult{}, err
}
if len(vals) <= maxItems {
result.Value = vals
} else {
trimmed := make(map[string]string, maxItems)
count := 0
for k, v := range vals {
trimmed[k] = v
count++
if count >= maxItems {
break
}
}
result.Value = trimmed
result.Truncated = true
}
default:
raw, err := c.client.Dump(ctx, key).Result()
if err != nil {
return RedisGetResult{}, err
}
result.Value = strings.ToUpper(fmt.Sprintf("UNSUPPORTED_TYPE_%s_DUMP_SIZE_%d", t, len(raw)))
}
result.DurationMs = time.Since(start).Milliseconds()
return result, nil
}
func (c *RedisClient) ScanKeys(ctx context.Context, pattern string, count int64, maxKeys int) (RedisScanResult, error) {
start := time.Now()
if pattern == "" {
pattern = "*"
}
if count <= 0 {
count = 20
}
keys := make([]string, 0, maxKeys)
var cursor uint64
truncated := false
for {
batch, nextCursor, err := c.client.Scan(ctx, cursor, pattern, count).Result()
if err != nil {
return RedisScanResult{}, err
}
for _, key := range batch {
if len(keys) >= maxKeys {
truncated = true
break
}
keys = append(keys, key)
}
cursor = nextCursor
if truncated || cursor == 0 {
break
}
}
return RedisScanResult{
Pattern: pattern,
Keys: keys,
Returned: len(keys),
NextCursor: cursor,
Truncated: truncated,
DurationMs: time.Since(start).Milliseconds(),
}, nil
}

View File

@@ -0,0 +1,95 @@
package tools
import (
"context"
"os"
"strconv"
"testing"
"time"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store"
)
func TestIntegrationMySQLReadOnlyTool(t *testing.T) {
if os.Getenv("MCP_IT_RUN") != "1" {
t.Skip("set MCP_IT_RUN=1 to run integration tests")
}
port := 3306
if p := os.Getenv("MYSQL_PORT"); p != "" {
if n, err := strconv.Atoi(p); err == nil {
port = n
}
}
mysqlCfg := config.MySQLConfig{
Host: os.Getenv("MYSQL_HOST"),
Port: port,
User: os.Getenv("MYSQL_USER"),
Password: os.Getenv("MYSQL_PASSWORD"),
Database: os.Getenv("MYSQL_DATABASE"),
Params: "charset=utf8mb4&parseTime=true&loc=Local",
}
if mysqlCfg.Host == "" || mysqlCfg.User == "" || mysqlCfg.Database == "" {
t.Skip("missing MYSQL_* env for integration test")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
client, err := store.NewMySQLClient(ctx, mysqlCfg)
if err != nil {
t.Fatalf("mysql not available: %v", err)
}
defer func() { _ = client.Close() }()
validator := security.NewSQLValidator(mysqlCfg.Database, false, nil, nil)
tool := NewMySQLReadOnlyTool(client, validator, 10)
res, err := tool.Execute(ctx, map[string]any{"sql": "SELECT 1 AS ok"})
if err != nil {
t.Fatalf("tool execute failed: %v", err)
}
if res["rowCount"].(int) < 1 {
t.Fatalf("expected rowCount >= 1")
}
}
func TestIntegrationRedisTools(t *testing.T) {
if os.Getenv("MCP_IT_RUN") != "1" {
t.Skip("set MCP_IT_RUN=1 to run integration tests")
}
db := 0
if p := os.Getenv("REDIS_DB"); p != "" {
if n, err := strconv.Atoi(p); err == nil {
db = n
}
}
redisCfg := config.RedisConfig{
Addr: os.Getenv("REDIS_ADDR"),
Password: os.Getenv("REDIS_PASSWORD"),
DB: db,
}
if redisCfg.Addr == "" {
t.Skip("REDIS_ADDR is empty")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
client, err := store.NewRedisClient(ctx, redisCfg)
if err != nil {
t.Fatalf("redis not available: %v", err)
}
defer func() { _ = client.Close() }()
getTool := NewRedisGetTool(client, 10, 128)
if _, err := getTool.Execute(ctx, map[string]any{"key": "__integration_missing_key__"}); err != nil {
t.Fatalf("redis_get failed: %v", err)
}
scanTool := NewRedisScanTool(client, 10, 10)
if _, err := scanTool.Execute(ctx, map[string]any{"pattern": "*", "count": float64(5)}); err != nil {
t.Fatalf("redis_scan failed: %v", err)
}
}

View File

@@ -0,0 +1,95 @@
package tools
import (
"context"
"fmt"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store"
)
type MySQLReadOnlyTool struct {
client *store.MySQLClient
validator *security.SQLValidator
maxRows int
}
func NewMySQLReadOnlyTool(client *store.MySQLClient, validator *security.SQLValidator, maxRows int) *MySQLReadOnlyTool {
return &MySQLReadOnlyTool{client: client, validator: validator, maxRows: maxRows}
}
func (t *MySQLReadOnlyTool) Name() string {
return "mysql_query_readonly"
}
func (t *MySQLReadOnlyTool) Description() string {
return "Execute read-only SQL on MySQL. Only SELECT/SHOW/DESCRIBE/EXPLAIN are allowed."
}
func (t *MySQLReadOnlyTool) InputSchema() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"sql": map[string]any{
"type": "string",
"description": "Read-only SQL statement",
},
"params": map[string]any{
"type": "array",
"description": "Optional bind parameters",
"items": map[string]any{
"type": []string{"string", "number", "boolean", "null"},
},
},
},
"required": []string{"sql"},
"additionalProperties": false,
}
}
func (t *MySQLReadOnlyTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
rawSQL, ok := args["sql"].(string)
if !ok || rawSQL == "" {
return nil, fmt.Errorf("sql must be a non-empty string")
}
if err := t.validator.ValidateReadOnlySQL(rawSQL); err != nil {
return nil, err
}
params, err := normalizeParams(args["params"])
if err != nil {
return nil, err
}
res, err := t.client.QueryReadOnly(ctx, rawSQL, params, t.maxRows)
if err != nil {
return nil, err
}
return map[string]any{
"columns": res.Columns,
"rows": res.Rows,
"rowCount": res.RowCount,
"truncated": res.Truncated,
"durationMs": res.DurationMs,
}, nil
}
func normalizeParams(raw any) ([]any, error) {
if raw == nil {
return nil, nil
}
arr, ok := raw.([]any)
if !ok {
return nil, fmt.Errorf("params must be an array")
}
out := make([]any, 0, len(arr))
for _, item := range arr {
switch v := item.(type) {
case string, float64, bool, nil:
out = append(out, v)
default:
return nil, fmt.Errorf("params contains unsupported type")
}
}
return out, nil
}

View File

@@ -0,0 +1,130 @@
package tools
import (
"context"
"fmt"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store"
)
type RedisGetTool struct {
client *store.RedisClient
valueMaxItems int
maxStringBytes int
}
func NewRedisGetTool(client *store.RedisClient, valueMaxItems int, maxStringBytes int) *RedisGetTool {
return &RedisGetTool{client: client, valueMaxItems: valueMaxItems, maxStringBytes: maxStringBytes}
}
func (t *RedisGetTool) Name() string {
return "redis_get"
}
func (t *RedisGetTool) Description() string {
return "Get a Redis key by name and return its type and value."
}
func (t *RedisGetTool) InputSchema() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"key": map[string]any{
"type": "string",
"description": "Redis key",
},
},
"required": []string{"key"},
"additionalProperties": false,
}
}
func (t *RedisGetTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
key, ok := args["key"].(string)
if !ok || key == "" {
return nil, fmt.Errorf("key must be a non-empty string")
}
res, err := t.client.GetWithType(ctx, key, t.valueMaxItems, t.maxStringBytes)
if err != nil {
return nil, err
}
return map[string]any{
"exists": res.Exists,
"key": res.Key,
"type": res.Type,
"value": res.Value,
"truncated": res.Truncated,
"durationMs": res.DurationMs,
}, nil
}
type RedisScanTool struct {
client *store.RedisClient
maxKeys int
maxScanCount int
}
func NewRedisScanTool(client *store.RedisClient, maxKeys int, maxScanCount int) *RedisScanTool {
return &RedisScanTool{client: client, maxKeys: maxKeys, maxScanCount: maxScanCount}
}
func (t *RedisScanTool) Name() string {
return "redis_scan"
}
func (t *RedisScanTool) Description() string {
return "Scan Redis keys by pattern with capped result size."
}
func (t *RedisScanTool) InputSchema() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"pattern": map[string]any{
"type": "string",
"description": "Pattern, for example user:*",
},
"count": map[string]any{
"type": "number",
"description": "Optional scan count hint",
},
},
"required": []string{"pattern"},
"additionalProperties": false,
}
}
func (t *RedisScanTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
pattern, ok := args["pattern"].(string)
if !ok || pattern == "" {
return nil, fmt.Errorf("pattern must be a non-empty string")
}
count := int64(20)
if rawCount, ok := args["count"]; ok {
number, ok := rawCount.(float64)
if !ok {
return nil, fmt.Errorf("count must be a number")
}
if number <= 0 {
return nil, fmt.Errorf("count must be > 0")
}
count = int64(number)
}
if count > int64(t.maxScanCount) {
count = int64(t.maxScanCount)
}
res, err := t.client.ScanKeys(ctx, pattern, count, t.maxKeys)
if err != nil {
return nil, err
}
return map[string]any{
"pattern": res.Pattern,
"keys": res.Keys,
"returned": res.Returned,
"nextCursor": res.NextCursor,
"truncated": res.Truncated,
"durationMs": res.DurationMs,
}, nil
}

View File

@@ -0,0 +1,53 @@
package tools
import (
"context"
"fmt"
"sort"
)
type Tool interface {
Name() string
Description() string
InputSchema() map[string]any
Execute(ctx context.Context, args map[string]any) (map[string]any, error)
}
type Registry struct {
tools map[string]Tool
}
func NewRegistry(toolList ...Tool) (*Registry, error) {
r := &Registry{tools: make(map[string]Tool, len(toolList))}
for _, t := range toolList {
name := t.Name()
if _, exists := r.tools[name]; exists {
return nil, fmt.Errorf("duplicated tool name: %s", name)
}
r.tools[name] = t
}
return r, nil
}
func (r *Registry) Find(name string) (Tool, bool) {
t, ok := r.tools[name]
return t, ok
}
func (r *Registry) List() []map[string]any {
out := make([]map[string]any, 0, len(r.tools))
names := make([]string, 0, len(r.tools))
for name := range r.tools {
names = append(names, name)
}
sort.Strings(names)
for _, name := range names {
t := r.tools[name]
out = append(out, map[string]any{
"name": t.Name(),
"description": t.Description(),
"inputSchema": t.InputSchema(),
})
}
return out
}