Version: 0.4.4.dev.260307
feat: 🚀 增强会话管理与缓存机制 * 会话 ID 空值兜底,若 `conversation_id` 为空时自动生成 UUID * 在响应头写入 `X-Conversation-ID`,供前端使用,保持同一会话状态 perf: ⚡ 会话状态缓存优化 * 当缓存未命中但 DB 已确认/创建会话后,调用 `SetConversationStatus` 回写 Redis * 缓存写回失败时记录日志,不中断聊天主流程,确保业务流畅性 fix: 🐛 修复历史消息顺序问题与编译错误 * 修复历史消息顺序问题,保证返回的 N 条历史消息按时间正序喂给模型 * 通过反转 `created_at desc` 查询结果的切片,确保模型输入顺序正确 * 修复 `fmt.Errorf` 参数不匹配问题,修正编译错误 * 整理 `agent-cache.go` 为标准 UTF-8 编码,避免 Go 编译报错 `invalid UTF-8 encoding` feat: 🛠️ 独立构建 MCP 服务器 * 使用 `Codex` 构建独立于后端的 MCP 服务器,简化与 Codex 的协作 * 通过该服务器方便 Codex 直接测试和查看 Redis 与 MySQL 中的数据
This commit is contained in:
60
infra/smartflow-mcp-server/internal/audit/logger.go
Normal file
60
infra/smartflow-mcp-server/internal/audit/logger.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
mu sync.Mutex
|
||||
file *os.File
|
||||
}
|
||||
|
||||
type Record struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Tool string `json:"tool"`
|
||||
Caller string `json:"caller"`
|
||||
Success bool `json:"success"`
|
||||
DurationMs int64 `json:"duration_ms"`
|
||||
Meta map[string]any `json:"meta,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func New(path string) (*Logger, error) {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create audit dir: %w", err)
|
||||
}
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open audit log file: %w", err)
|
||||
}
|
||||
return &Logger{file: f}, nil
|
||||
}
|
||||
|
||||
func (l *Logger) Close() error {
|
||||
if l == nil || l.file == nil {
|
||||
return nil
|
||||
}
|
||||
return l.file.Close()
|
||||
}
|
||||
|
||||
func (l *Logger) Log(record Record) {
|
||||
if l == nil || l.file == nil {
|
||||
return
|
||||
}
|
||||
if record.Timestamp.IsZero() {
|
||||
record.Timestamp = time.Now()
|
||||
}
|
||||
body, err := json.Marshal(record)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
_, _ = l.file.Write(append(body, '\n'))
|
||||
}
|
||||
163
infra/smartflow-mcp-server/internal/config/config.go
Normal file
163
infra/smartflow-mcp-server/internal/config/config.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ServerName string
|
||||
ServerVersion string
|
||||
ProtocolVersion string
|
||||
DefaultCaller string
|
||||
ToolTimeout time.Duration
|
||||
RateLimitRPS float64
|
||||
RateLimitBurst float64
|
||||
MaxResultRows int
|
||||
AuditLogPath string
|
||||
EnforceWhitelist bool
|
||||
RedisScanMaxKeys int
|
||||
RedisScanMaxCount int
|
||||
RedisValueMaxItems int
|
||||
RedisMaxStringBytes int
|
||||
|
||||
MySQL MySQLConfig
|
||||
Redis RedisConfig
|
||||
}
|
||||
|
||||
type MySQLConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
Database string
|
||||
Params string
|
||||
AllowedDatabases []string
|
||||
AllowedTables []string
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
Addr string
|
||||
Password string
|
||||
DB int
|
||||
}
|
||||
|
||||
func LoadFromEnv() (Config, error) {
|
||||
cfg := Config{
|
||||
ServerName: getEnv("MCP_SERVER_NAME", "smartflow-mcp-server"),
|
||||
ServerVersion: getEnv("MCP_SERVER_VERSION", "0.1.0"),
|
||||
ProtocolVersion: getEnv("MCP_PROTOCOL_VERSION", "2024-11-05"),
|
||||
DefaultCaller: getEnv("MCP_DEFAULT_CALLER", "unknown"),
|
||||
ToolTimeout: getEnvDurationMS("MCP_TOOL_TIMEOUT_MS", 5000),
|
||||
RateLimitRPS: getEnvFloat("MCP_RATE_LIMIT_RPS", 5),
|
||||
RateLimitBurst: getEnvFloat("MCP_RATE_LIMIT_BURST", 10),
|
||||
MaxResultRows: getEnvInt("MCP_MAX_RESULT_ROWS", 500),
|
||||
AuditLogPath: getEnv("MCP_AUDIT_LOG_PATH", "logs/audit.log"),
|
||||
EnforceWhitelist: getEnvBool("MCP_ENFORCE_WHITELIST", false),
|
||||
RedisScanMaxKeys: getEnvInt("MCP_REDIS_SCAN_MAX_KEYS", 200),
|
||||
RedisScanMaxCount: getEnvInt("MCP_REDIS_SCAN_MAX_COUNT", 200),
|
||||
RedisValueMaxItems: getEnvInt("MCP_REDIS_VALUE_MAX_ITEMS", 100),
|
||||
RedisMaxStringBytes: getEnvInt("MCP_REDIS_MAX_STRING_BYTES", 4096),
|
||||
MySQL: MySQLConfig{
|
||||
Host: getEnv("MYSQL_HOST", "127.0.0.1"),
|
||||
Port: getEnvInt("MYSQL_PORT", 3306),
|
||||
User: getEnv("MYSQL_USER", ""),
|
||||
Password: getEnv("MYSQL_PASSWORD", ""),
|
||||
Database: getEnv("MYSQL_DATABASE", ""),
|
||||
Params: getEnv("MYSQL_PARAMS", "charset=utf8mb4&parseTime=true&loc=Local"),
|
||||
AllowedDatabases: splitCommaList(getEnv("MYSQL_ALLOWED_DATABASES", "")),
|
||||
AllowedTables: splitCommaList(getEnv("MYSQL_ALLOWED_TABLES", "")),
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Addr: getEnv("REDIS_ADDR", "127.0.0.1:6379"),
|
||||
Password: getEnv("REDIS_PASSWORD", ""),
|
||||
DB: getEnvInt("REDIS_DB", 0),
|
||||
},
|
||||
}
|
||||
|
||||
if cfg.MySQL.User == "" || cfg.MySQL.Database == "" {
|
||||
return Config{}, fmt.Errorf("MYSQL_USER and MYSQL_DATABASE are required")
|
||||
}
|
||||
if cfg.Redis.Addr == "" {
|
||||
return Config{}, fmt.Errorf("REDIS_ADDR is required")
|
||||
}
|
||||
if cfg.MaxResultRows <= 0 {
|
||||
return Config{}, fmt.Errorf("MCP_MAX_RESULT_ROWS must be > 0")
|
||||
}
|
||||
if cfg.RedisScanMaxKeys <= 0 {
|
||||
return Config{}, fmt.Errorf("MCP_REDIS_SCAN_MAX_KEYS must be > 0")
|
||||
}
|
||||
if cfg.RedisScanMaxCount <= 0 {
|
||||
return Config{}, fmt.Errorf("MCP_REDIS_SCAN_MAX_COUNT must be > 0")
|
||||
}
|
||||
if cfg.ToolTimeout <= 0 {
|
||||
return Config{}, fmt.Errorf("MCP_TOOL_TIMEOUT_MS must be > 0")
|
||||
}
|
||||
if cfg.RateLimitRPS <= 0 || cfg.RateLimitBurst <= 0 {
|
||||
return Config{}, fmt.Errorf("MCP_RATE_LIMIT_RPS and MCP_RATE_LIMIT_BURST must be > 0")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func getEnv(key string, defaultValue string) string {
|
||||
if v, ok := os.LookupEnv(key); ok {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvInt(key string, defaultValue int) int {
|
||||
v := getEnv(key, "")
|
||||
if v == "" {
|
||||
return defaultValue
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func getEnvFloat(key string, defaultValue float64) float64 {
|
||||
v := getEnv(key, "")
|
||||
if v == "" {
|
||||
return defaultValue
|
||||
}
|
||||
n, err := strconv.ParseFloat(v, 64)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func getEnvBool(key string, defaultValue bool) bool {
|
||||
v := strings.ToLower(getEnv(key, ""))
|
||||
if v == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return v == "1" || v == "true" || v == "yes" || v == "on"
|
||||
}
|
||||
|
||||
func getEnvDurationMS(key string, defaultValueMs int) time.Duration {
|
||||
ms := getEnvInt(key, defaultValueMs)
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
|
||||
func splitCommaList(raw string) []string {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(raw, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
trimmed := strings.TrimSpace(p)
|
||||
if trimmed != "" {
|
||||
out = append(out, strings.ToLower(trimmed))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
44
infra/smartflow-mcp-server/internal/envutil/loader.go
Normal file
44
infra/smartflow-mcp-server/internal/envutil/loader.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package envutil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func LoadDotEnv(path string) error {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("open .env: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
i := strings.Index(line, "=")
|
||||
if i <= 0 {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(line[:i])
|
||||
value := strings.TrimSpace(line[i+1:])
|
||||
value = strings.Trim(value, "\"")
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := os.LookupEnv(key); !exists {
|
||||
_ = os.Setenv(key, value)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("scan .env: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
394
infra/smartflow-mcp-server/internal/mcp/server.go
Normal file
394
infra/smartflow-mcp-server/internal/mcp/server.go
Normal file
@@ -0,0 +1,394 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/audit"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
jsonRPCVersion = "2.0"
|
||||
|
||||
errCodeParseError = -32700
|
||||
errCodeInvalidRequest = -32600
|
||||
errCodeMethodNotFound = -32601
|
||||
errCodeInvalidParams = -32602
|
||||
errCodeInternalError = -32603
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
reader *bufio.Reader
|
||||
writer io.Writer
|
||||
writeMu sync.Mutex
|
||||
registry *tools.Registry
|
||||
auditLogger *audit.Logger
|
||||
limiter *ratelimit.Limiter
|
||||
serverName string
|
||||
serverVersion string
|
||||
protocolVersion string
|
||||
defaultCaller string
|
||||
toolTimeout time.Duration
|
||||
}
|
||||
|
||||
type request struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id,omitempty"`
|
||||
Method string `json:"method"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type response struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id,omitempty"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error *respError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type respError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type initializeParams struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
}
|
||||
|
||||
type toolCallParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
Meta map[string]any `json:"_meta,omitempty"`
|
||||
}
|
||||
|
||||
func NewServer(
|
||||
in io.Reader,
|
||||
out io.Writer,
|
||||
registry *tools.Registry,
|
||||
auditLogger *audit.Logger,
|
||||
limiter *ratelimit.Limiter,
|
||||
serverName string,
|
||||
serverVersion string,
|
||||
protocolVersion string,
|
||||
defaultCaller string,
|
||||
toolTimeout time.Duration,
|
||||
) *Server {
|
||||
return &Server{
|
||||
reader: bufio.NewReader(in),
|
||||
writer: out,
|
||||
registry: registry,
|
||||
auditLogger: auditLogger,
|
||||
limiter: limiter,
|
||||
serverName: serverName,
|
||||
serverVersion: serverVersion,
|
||||
protocolVersion: protocolVersion,
|
||||
defaultCaller: defaultCaller,
|
||||
toolTimeout: toolTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Serve(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
body, err := readMessage(s.reader)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
log.Printf("read message failed: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
var req request
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
_ = s.writeError(nil, errCodeParseError, "invalid json")
|
||||
continue
|
||||
}
|
||||
if req.Method == "" || req.JSONRPC != jsonRPCVersion {
|
||||
_ = s.writeError(req.ID, errCodeInvalidRequest, "invalid json-rpc request")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.handleRequest(ctx, req); err != nil {
|
||||
log.Printf("handle request failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(ctx context.Context, req request) error {
|
||||
switch req.Method {
|
||||
case "initialize":
|
||||
return s.handleInitialize(req)
|
||||
case "notifications/initialized":
|
||||
return nil
|
||||
case "tools/list":
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.writeResult(req.ID, map[string]any{"tools": s.registry.List()})
|
||||
case "tools/call":
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.handleToolCall(ctx, req)
|
||||
case "ping":
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.writeResult(req.ID, map[string]any{})
|
||||
default:
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.writeError(req.ID, errCodeMethodNotFound, "method not found")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleInitialize(req request) error {
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var params initializeParams
|
||||
if len(req.Params) > 0 {
|
||||
_ = json.Unmarshal(req.Params, ¶ms)
|
||||
}
|
||||
_ = params
|
||||
|
||||
return s.writeResult(req.ID, map[string]any{
|
||||
"protocolVersion": s.protocolVersion,
|
||||
"capabilities": map[string]any{
|
||||
"tools": map[string]any{"listChanged": false},
|
||||
},
|
||||
"serverInfo": map[string]any{
|
||||
"name": s.serverName,
|
||||
"version": s.serverVersion,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleToolCall(ctx context.Context, req request) error {
|
||||
var params toolCallParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
return s.writeError(req.ID, errCodeInvalidParams, "invalid tool call params")
|
||||
}
|
||||
if params.Name == "" {
|
||||
return s.writeError(req.ID, errCodeInvalidParams, "tool name is required")
|
||||
}
|
||||
if params.Arguments == nil {
|
||||
params.Arguments = map[string]any{}
|
||||
}
|
||||
|
||||
caller := extractCaller(params, s.defaultCaller)
|
||||
rateKey := fmt.Sprintf("%s:%s", caller, params.Name)
|
||||
if !s.limiter.Allow(rateKey) {
|
||||
result := buildToolErrorResult("rate limit exceeded")
|
||||
s.auditLogger.Log(audit.Record{
|
||||
Tool: params.Name,
|
||||
Caller: caller,
|
||||
Success: false,
|
||||
DurationMs: 0,
|
||||
Meta: sanitizeAuditInput(params.Name, params.Arguments),
|
||||
Error: "rate limit exceeded",
|
||||
})
|
||||
return s.writeResult(req.ID, result)
|
||||
}
|
||||
|
||||
tool, ok := s.registry.Find(params.Name)
|
||||
if !ok {
|
||||
result := buildToolErrorResult("tool not found")
|
||||
s.auditLogger.Log(audit.Record{
|
||||
Tool: params.Name,
|
||||
Caller: caller,
|
||||
Success: false,
|
||||
DurationMs: 0,
|
||||
Meta: sanitizeAuditInput(params.Name, params.Arguments),
|
||||
Error: "tool not found",
|
||||
})
|
||||
return s.writeResult(req.ID, result)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
toolCtx, cancel := context.WithTimeout(ctx, s.toolTimeout)
|
||||
defer cancel()
|
||||
|
||||
output, err := tool.Execute(toolCtx, params.Arguments)
|
||||
duration := time.Since(start).Milliseconds()
|
||||
if err != nil {
|
||||
errMsg := sanitizeError(err)
|
||||
s.auditLogger.Log(audit.Record{
|
||||
Tool: params.Name,
|
||||
Caller: caller,
|
||||
Success: false,
|
||||
DurationMs: duration,
|
||||
Meta: sanitizeAuditInput(params.Name, params.Arguments),
|
||||
Error: errMsg,
|
||||
})
|
||||
return s.writeResult(req.ID, buildToolErrorResult(errMsg))
|
||||
}
|
||||
|
||||
s.auditLogger.Log(audit.Record{
|
||||
Tool: params.Name,
|
||||
Caller: caller,
|
||||
Success: true,
|
||||
DurationMs: duration,
|
||||
Meta: sanitizeAuditInput(params.Name, params.Arguments),
|
||||
})
|
||||
return s.writeResult(req.ID, buildToolSuccessResult(output))
|
||||
}
|
||||
|
||||
func buildToolSuccessResult(output map[string]any) map[string]any {
|
||||
asJSON, _ := json.Marshal(output)
|
||||
return map[string]any{
|
||||
"content": []map[string]any{
|
||||
{"type": "text", "text": string(asJSON)},
|
||||
},
|
||||
"structuredContent": output,
|
||||
}
|
||||
}
|
||||
|
||||
func buildToolErrorResult(message string) map[string]any {
|
||||
return map[string]any{
|
||||
"content": []map[string]any{
|
||||
{"type": "text", "text": message},
|
||||
},
|
||||
"structuredContent": map[string]any{"error": message},
|
||||
"isError": true,
|
||||
}
|
||||
}
|
||||
|
||||
func extractCaller(params toolCallParams, fallback string) string {
|
||||
if params.Meta != nil {
|
||||
if v, ok := params.Meta["caller"]; ok {
|
||||
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
if v, ok := params.Arguments["caller"]; ok {
|
||||
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func sanitizeAuditInput(toolName string, args map[string]any) map[string]any {
|
||||
meta := map[string]any{}
|
||||
switch toolName {
|
||||
case "mysql_query_readonly":
|
||||
sql, _ := args["sql"].(string)
|
||||
meta["sql"] = security.RedactSQL(sql)
|
||||
if params, ok := args["params"].([]any); ok {
|
||||
meta["paramCount"] = len(params)
|
||||
}
|
||||
case "redis_get":
|
||||
if key, ok := args["key"].(string); ok {
|
||||
meta["key"] = security.RedactKey(key)
|
||||
}
|
||||
case "redis_scan":
|
||||
if pattern, ok := args["pattern"].(string); ok {
|
||||
if len(pattern) > 40 {
|
||||
meta["pattern"] = pattern[:40]
|
||||
} else {
|
||||
meta["pattern"] = pattern
|
||||
}
|
||||
}
|
||||
if count, ok := args["count"]; ok {
|
||||
meta["count"] = count
|
||||
}
|
||||
default:
|
||||
meta["args"] = "masked"
|
||||
}
|
||||
return meta
|
||||
}
|
||||
|
||||
func sanitizeError(err error) string {
|
||||
msg := err.Error()
|
||||
msg = strings.ReplaceAll(msg, "\n", " ")
|
||||
msg = strings.ReplaceAll(msg, "\r", " ")
|
||||
if len(msg) > 300 {
|
||||
return msg[:300]
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func (s *Server) writeResult(id json.RawMessage, result any) error {
|
||||
resp := response{JSONRPC: jsonRPCVersion, ID: id, Result: result}
|
||||
return s.writeMessage(resp)
|
||||
}
|
||||
|
||||
func (s *Server) writeError(id json.RawMessage, code int, message string) error {
|
||||
resp := response{
|
||||
JSONRPC: jsonRPCVersion,
|
||||
ID: id,
|
||||
Error: &respError{Code: code, Message: message},
|
||||
}
|
||||
return s.writeMessage(resp)
|
||||
}
|
||||
|
||||
func (s *Server) writeMessage(payload any) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(body))
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
if _, err := s.writer.Write([]byte(header)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.writer.Write(body)
|
||||
return err
|
||||
}
|
||||
|
||||
func readMessage(reader *bufio.Reader) ([]byte, error) {
|
||||
contentLength := -1
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
if line == "" {
|
||||
break
|
||||
}
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
headerName := strings.TrimSpace(parts[0])
|
||||
headerValue := strings.TrimSpace(parts[1])
|
||||
if strings.EqualFold(headerName, "Content-Length") {
|
||||
n, err := strconv.Atoi(headerValue)
|
||||
if err != nil || n < 0 {
|
||||
return nil, fmt.Errorf("invalid Content-Length")
|
||||
}
|
||||
contentLength = n
|
||||
}
|
||||
}
|
||||
if contentLength < 0 {
|
||||
return nil, fmt.Errorf("missing Content-Length")
|
||||
}
|
||||
body := make([]byte, contentLength)
|
||||
if _, err := io.ReadFull(reader, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
56
infra/smartflow-mcp-server/internal/mcp/server_test.go
Normal file
56
infra/smartflow-mcp-server/internal/mcp/server_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools"
|
||||
)
|
||||
|
||||
type dummyTool struct{}
|
||||
|
||||
func (d dummyTool) Name() string { return "dummy" }
|
||||
func (d dummyTool) Description() string { return "dummy" }
|
||||
func (d dummyTool) InputSchema() map[string]any {
|
||||
return map[string]any{"type": "object"}
|
||||
}
|
||||
func (d dummyTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
|
||||
return map[string]any{"ok": true}, nil
|
||||
}
|
||||
|
||||
func TestHandleInitialize(t *testing.T) {
|
||||
registry, err := tools.NewRegistry(dummyTool{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
out := bytes.NewBuffer(nil)
|
||||
s := NewServer(bytes.NewBuffer(nil), out, registry, nil, ratelimit.New(10, 10), "name", "ver", "2024-11-05", "unknown", time.Second)
|
||||
|
||||
req := request{JSONRPC: jsonRPCVersion, ID: json.RawMessage("1"), Method: "initialize", Params: json.RawMessage(`{"protocolVersion":"2024-11-05"}`)}
|
||||
if err := s.handleRequest(context.Background(), req); err != nil {
|
||||
t.Fatalf("handle initialize error: %v", err)
|
||||
}
|
||||
if !bytes.Contains(out.Bytes(), []byte("protocolVersion")) {
|
||||
t.Fatalf("response missing protocolVersion: %s", out.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadMessage(t *testing.T) {
|
||||
body := []byte(`{"jsonrpc":"2.0","id":1,"method":"ping"}`)
|
||||
frame := fmt.Sprintf("Content-Length: %d\r\n\r\n%s", len(body), string(body))
|
||||
input := bufio.NewReader(bytes.NewBufferString(frame))
|
||||
|
||||
msg, err := readMessage(input)
|
||||
if err != nil {
|
||||
t.Fatalf("read message failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(msg, body) {
|
||||
t.Fatalf("unexpected body: %s", string(msg))
|
||||
}
|
||||
}
|
||||
51
infra/smartflow-mcp-server/internal/ratelimit/limiter.go
Normal file
51
infra/smartflow-mcp-server/internal/ratelimit/limiter.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Limiter struct {
|
||||
mu sync.Mutex
|
||||
rate float64
|
||||
burst float64
|
||||
buckets map[string]*bucket
|
||||
}
|
||||
|
||||
type bucket struct {
|
||||
tokens float64
|
||||
last time.Time
|
||||
}
|
||||
|
||||
func New(rate, burst float64) *Limiter {
|
||||
return &Limiter{
|
||||
rate: rate,
|
||||
burst: burst,
|
||||
buckets: make(map[string]*bucket),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Limiter) Allow(key string) bool {
|
||||
now := time.Now()
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
b, ok := l.buckets[key]
|
||||
if !ok {
|
||||
l.buckets[key] = &bucket{tokens: l.burst - 1, last: now}
|
||||
return true
|
||||
}
|
||||
|
||||
elapsed := now.Sub(b.last).Seconds()
|
||||
b.tokens += elapsed * l.rate
|
||||
if b.tokens > l.burst {
|
||||
b.tokens = l.burst
|
||||
}
|
||||
b.last = now
|
||||
|
||||
if b.tokens < 1 {
|
||||
return false
|
||||
}
|
||||
b.tokens -= 1
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLimiter(t *testing.T) {
|
||||
l := New(2, 2)
|
||||
key := "user:tool"
|
||||
|
||||
if !l.Allow(key) {
|
||||
t.Fatal("first request should pass")
|
||||
}
|
||||
if !l.Allow(key) {
|
||||
t.Fatal("second request should pass")
|
||||
}
|
||||
if l.Allow(key) {
|
||||
t.Fatal("third request should be rate limited")
|
||||
}
|
||||
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
if !l.Allow(key) {
|
||||
t.Fatal("request should pass after token refill")
|
||||
}
|
||||
}
|
||||
27
infra/smartflow-mcp-server/internal/security/redact.go
Normal file
27
infra/smartflow-mcp-server/internal/security/redact.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
singleQuotedString = regexp.MustCompile(`'([^'\\]|\\.)*'`)
|
||||
doubleQuotedString = regexp.MustCompile(`"([^"\\]|\\.)*"`)
|
||||
numericLiteral = regexp.MustCompile(`\b\d+\b`)
|
||||
)
|
||||
|
||||
func RedactSQL(sql string) string {
|
||||
masked := singleQuotedString.ReplaceAllString(sql, "'***'")
|
||||
masked = doubleQuotedString.ReplaceAllString(masked, `"***"`)
|
||||
masked = numericLiteral.ReplaceAllString(masked, "?")
|
||||
return strings.TrimSpace(masked)
|
||||
}
|
||||
|
||||
func RedactKey(key string) string {
|
||||
key = strings.TrimSpace(key)
|
||||
if len(key) <= 4 {
|
||||
return "****"
|
||||
}
|
||||
return key[:2] + "***" + key[len(key)-2:]
|
||||
}
|
||||
193
infra/smartflow-mcp-server/internal/security/sql_validator.go
Normal file
193
infra/smartflow-mcp-server/internal/security/sql_validator.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
commentPattern = regexp.MustCompile(`(?s)/\*.*?\*/|--|#`)
|
||||
forbiddenWords = regexp.MustCompile(`(?i)\b(INSERT|UPDATE|DELETE|ALTER|DROP|TRUNCATE|CREATE|REPLACE|RENAME|GRANT|REVOKE|MERGE|CALL|EXEC|LOCK|UNLOCK|LOAD|OUTFILE|INFILE|HANDLER|SET|USE)\b`)
|
||||
fromJoinRef = regexp.MustCompile(`(?i)\b(?:FROM|JOIN)\s+([` + "`" + `"\w\.]+)`)
|
||||
describeRef = regexp.MustCompile(`(?i)\b(?:DESCRIBE|DESC)\s+([` + "`" + `"\w\.]+)`)
|
||||
showFromRef = regexp.MustCompile(`(?i)\bSHOW\b[\w\s]*\b(?:FROM|IN)\s+([` + "`" + `"\w\.]+)`)
|
||||
)
|
||||
|
||||
type SQLValidator struct {
|
||||
defaultDatabase string
|
||||
enforceWhitelist bool
|
||||
allowedDatabases map[string]struct{}
|
||||
allowedTables map[string]struct{}
|
||||
allowedTableOnly map[string]struct{}
|
||||
}
|
||||
|
||||
type sqlObjectRef struct {
|
||||
database string
|
||||
table string
|
||||
}
|
||||
|
||||
func NewSQLValidator(defaultDatabase string, enforceWhitelist bool, allowedDatabases []string, allowedTables []string) *SQLValidator {
|
||||
v := &SQLValidator{
|
||||
defaultDatabase: strings.ToLower(strings.TrimSpace(defaultDatabase)),
|
||||
enforceWhitelist: enforceWhitelist,
|
||||
allowedDatabases: make(map[string]struct{}),
|
||||
allowedTables: make(map[string]struct{}),
|
||||
allowedTableOnly: make(map[string]struct{}),
|
||||
}
|
||||
for _, db := range allowedDatabases {
|
||||
db = strings.ToLower(strings.TrimSpace(db))
|
||||
if db != "" {
|
||||
v.allowedDatabases[db] = struct{}{}
|
||||
}
|
||||
}
|
||||
for _, tbl := range allowedTables {
|
||||
tbl = strings.ToLower(strings.TrimSpace(strings.Trim(tbl, "`\"")))
|
||||
if tbl == "" {
|
||||
continue
|
||||
}
|
||||
v.allowedTables[tbl] = struct{}{}
|
||||
if dot := strings.LastIndex(tbl, "."); dot >= 0 && dot < len(tbl)-1 {
|
||||
v.allowedTableOnly[tbl[dot+1:]] = struct{}{}
|
||||
} else {
|
||||
v.allowedTableOnly[tbl] = struct{}{}
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *SQLValidator) ValidateReadOnlySQL(sql string) error {
|
||||
trimmed := strings.TrimSpace(sql)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("sql is required")
|
||||
}
|
||||
if strings.Contains(trimmed, ";") {
|
||||
return fmt.Errorf("semicolon is not allowed")
|
||||
}
|
||||
if commentPattern.MatchString(trimmed) {
|
||||
return fmt.Errorf("sql comments are not allowed")
|
||||
}
|
||||
|
||||
first := strings.ToUpper(firstToken(trimmed))
|
||||
switch first {
|
||||
case "SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN":
|
||||
default:
|
||||
return fmt.Errorf("only SELECT/SHOW/DESCRIBE/EXPLAIN are allowed")
|
||||
}
|
||||
|
||||
withoutLiterals := removeStringLiterals(trimmed)
|
||||
if forbiddenWords.MatchString(strings.ToUpper(withoutLiterals)) {
|
||||
return fmt.Errorf("dangerous sql keyword detected")
|
||||
}
|
||||
|
||||
refs := extractObjectRefs(withoutLiterals)
|
||||
if err := v.validateWhitelist(refs); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *SQLValidator) validateWhitelist(refs []sqlObjectRef) error {
|
||||
hasDBAllowlist := len(v.allowedDatabases) > 0
|
||||
hasTableAllowlist := len(v.allowedTables) > 0
|
||||
|
||||
if !v.enforceWhitelist && !hasDBAllowlist && !hasTableAllowlist {
|
||||
return nil
|
||||
}
|
||||
if len(refs) == 0 {
|
||||
if v.enforceWhitelist {
|
||||
return fmt.Errorf("sql does not contain explicit table reference under whitelist mode")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ref := range refs {
|
||||
db := ref.database
|
||||
tbl := ref.table
|
||||
if db == "" {
|
||||
db = v.defaultDatabase
|
||||
}
|
||||
|
||||
if hasDBAllowlist {
|
||||
if db == "" {
|
||||
return fmt.Errorf("database is not explicit and no default database set")
|
||||
}
|
||||
if _, ok := v.allowedDatabases[db]; !ok {
|
||||
return fmt.Errorf("database %s is not in whitelist", db)
|
||||
}
|
||||
}
|
||||
|
||||
if hasTableAllowlist {
|
||||
full := tbl
|
||||
if db != "" {
|
||||
full = db + "." + tbl
|
||||
}
|
||||
if _, ok := v.allowedTables[full]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := v.allowedTableOnly[tbl]; ok {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("table %s is not in whitelist", full)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractObjectRefs(sql string) []sqlObjectRef {
|
||||
refs := make([]sqlObjectRef, 0)
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
for _, re := range []*regexp.Regexp{fromJoinRef, describeRef, showFromRef} {
|
||||
matches := re.FindAllStringSubmatch(sql, -1)
|
||||
for _, m := range matches {
|
||||
if len(m) < 2 {
|
||||
continue
|
||||
}
|
||||
ref := normalizeRef(m[1])
|
||||
if ref.table == "" {
|
||||
continue
|
||||
}
|
||||
key := ref.database + "." + ref.table
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
}
|
||||
return refs
|
||||
}
|
||||
|
||||
func normalizeRef(raw string) sqlObjectRef {
|
||||
clean := strings.ToLower(strings.TrimSpace(raw))
|
||||
clean = strings.Trim(clean, "`\"'(),")
|
||||
clean = strings.TrimSpace(clean)
|
||||
if clean == "" {
|
||||
return sqlObjectRef{}
|
||||
}
|
||||
parts := strings.Split(clean, ".")
|
||||
if len(parts) == 1 {
|
||||
return sqlObjectRef{table: parts[0]}
|
||||
}
|
||||
return sqlObjectRef{database: parts[0], table: parts[len(parts)-1]}
|
||||
}
|
||||
|
||||
func firstToken(sql string) string {
|
||||
for i, r := range sql {
|
||||
if r == ' ' || r == '\n' || r == '\t' || r == '\r' {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
return sql[:i]
|
||||
}
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
func removeStringLiterals(sql string) string {
|
||||
masked := singleQuotedString.ReplaceAllString(sql, "''")
|
||||
masked = doubleQuotedString.ReplaceAllString(masked, `""`)
|
||||
return masked
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package security
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestValidateReadOnlySQL(t *testing.T) {
|
||||
validator := NewSQLValidator("smartflow", true, []string{"smartflow"}, []string{"smartflow.users", "smartflow.tasks"})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "allow select", sql: "SELECT id, name FROM users WHERE id = 1", wantErr: false},
|
||||
{name: "allow explain", sql: "EXPLAIN SELECT * FROM tasks", wantErr: false},
|
||||
{name: "reject insert", sql: "INSERT INTO users(name) VALUES('x')", wantErr: true},
|
||||
{name: "reject multi statement", sql: "SELECT * FROM users; SELECT * FROM tasks", wantErr: true},
|
||||
{name: "reject comment", sql: "SELECT * FROM users -- bypass", wantErr: true},
|
||||
{name: "reject not whitelisted table", sql: "SELECT * FROM orders", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := validator.ValidateReadOnlySQL(tc.sql)
|
||||
if tc.wantErr && err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
if !tc.wantErr && err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedact(t *testing.T) {
|
||||
masked := RedactSQL("SELECT * FROM users WHERE token='abc123' AND id=42")
|
||||
if masked == "" || masked == "SELECT * FROM users WHERE token='abc123' AND id=42" {
|
||||
t.Fatalf("redaction not applied: %s", masked)
|
||||
}
|
||||
|
||||
key := RedactKey("very-sensitive-key")
|
||||
if key == "very-sensitive-key" {
|
||||
t.Fatalf("key not redacted")
|
||||
}
|
||||
}
|
||||
198
infra/smartflow-mcp-server/internal/store/mysql.go
Normal file
198
infra/smartflow-mcp-server/internal/store/mysql.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfgpkg "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
type MySQLClient struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
type QueryColumn struct {
|
||||
Name string `json:"name"`
|
||||
DatabaseType string `json:"databaseType"`
|
||||
Nullable *bool `json:"nullable,omitempty"`
|
||||
ScanType string `json:"scanType,omitempty"`
|
||||
}
|
||||
|
||||
type QueryResult struct {
|
||||
Columns []QueryColumn `json:"columns"`
|
||||
Rows []map[string]any `json:"rows"`
|
||||
RowCount int `json:"rowCount"`
|
||||
Truncated bool `json:"truncated"`
|
||||
DurationMs int64 `json:"durationMs"`
|
||||
}
|
||||
|
||||
func NewMySQLClient(ctx context.Context, cfg cfgpkg.MySQLConfig) (*MySQLClient, error) {
|
||||
dsn := mysql.Config{
|
||||
User: cfg.User,
|
||||
Passwd: cfg.Password,
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
Net: "tcp",
|
||||
DBName: cfg.Database,
|
||||
AllowNativePasswords: true,
|
||||
ParseTime: true,
|
||||
}
|
||||
|
||||
applyMySQLParams(&dsn, cfg.Params)
|
||||
|
||||
db, err := sql.Open("mysql", dsn.FormatDSN())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open mysql: %w", err)
|
||||
}
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
db.SetMaxOpenConns(5)
|
||||
db.SetMaxIdleConns(5)
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := db.PingContext(pingCtx); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("ping mysql: %w", err)
|
||||
}
|
||||
|
||||
return &MySQLClient{db: db}, nil
|
||||
}
|
||||
|
||||
func applyMySQLParams(cfg *mysql.Config, raw string) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return
|
||||
}
|
||||
values, err := url.ParseQuery(raw)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cfg.Params == nil {
|
||||
cfg.Params = make(map[string]string)
|
||||
}
|
||||
for key, valueList := range values {
|
||||
if len(valueList) == 0 {
|
||||
continue
|
||||
}
|
||||
value := valueList[0]
|
||||
switch strings.ToLower(key) {
|
||||
case "parsetime":
|
||||
if parsed, err := strconv.ParseBool(value); err == nil {
|
||||
cfg.ParseTime = parsed
|
||||
}
|
||||
case "loc":
|
||||
if loc, err := time.LoadLocation(value); err == nil {
|
||||
cfg.Loc = loc
|
||||
}
|
||||
case "collation":
|
||||
cfg.Collation = value
|
||||
default:
|
||||
cfg.Params[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MySQLClient) Close() error {
|
||||
if c == nil || c.db == nil {
|
||||
return nil
|
||||
}
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func (c *MySQLClient) QueryReadOnly(ctx context.Context, query string, args []any, maxRows int) (QueryResult, error) {
|
||||
start := time.Now()
|
||||
rows, err := c.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return QueryResult{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columnTypes, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return QueryResult{}, err
|
||||
}
|
||||
|
||||
columns := make([]QueryColumn, 0, len(columnTypes))
|
||||
columnNames := make([]string, 0, len(columnTypes))
|
||||
for _, ct := range columnTypes {
|
||||
var nullablePtr *bool
|
||||
if nullable, ok := ct.Nullable(); ok {
|
||||
n := nullable
|
||||
nullablePtr = &n
|
||||
}
|
||||
scanType := ""
|
||||
if st := ct.ScanType(); st != nil {
|
||||
scanType = st.String()
|
||||
}
|
||||
columns = append(columns, QueryColumn{
|
||||
Name: ct.Name(),
|
||||
DatabaseType: ct.DatabaseTypeName(),
|
||||
Nullable: nullablePtr,
|
||||
ScanType: scanType,
|
||||
})
|
||||
columnNames = append(columnNames, ct.Name())
|
||||
}
|
||||
|
||||
resultRows := make([]map[string]any, 0)
|
||||
truncated := false
|
||||
for rows.Next() {
|
||||
if len(resultRows) >= maxRows {
|
||||
truncated = true
|
||||
break
|
||||
}
|
||||
scanned, err := scanRow(rows, len(columnNames))
|
||||
if err != nil {
|
||||
return QueryResult{}, err
|
||||
}
|
||||
rowMap := make(map[string]any, len(columnNames))
|
||||
for i, name := range columnNames {
|
||||
rowMap[name] = normalizeValue(scanned[i])
|
||||
}
|
||||
resultRows = append(resultRows, rowMap)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return QueryResult{}, err
|
||||
}
|
||||
|
||||
return QueryResult{
|
||||
Columns: columns,
|
||||
Rows: resultRows,
|
||||
RowCount: len(resultRows),
|
||||
Truncated: truncated,
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func scanRow(rows *sql.Rows, size int) ([]any, error) {
|
||||
dest := make([]any, size)
|
||||
holders := make([]any, size)
|
||||
for i := range dest {
|
||||
holders[i] = &dest[i]
|
||||
}
|
||||
if err := rows.Scan(holders...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
func normalizeValue(v any) any {
|
||||
switch val := v.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case []byte:
|
||||
return string(val)
|
||||
case time.Time:
|
||||
return val.Format(time.RFC3339Nano)
|
||||
default:
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
}
|
||||
189
infra/smartflow-mcp-server/internal/store/redis.go
Normal file
189
infra/smartflow-mcp-server/internal/store/redis.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfgpkg "github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
type RedisClient struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
type RedisGetResult struct {
|
||||
Exists bool `json:"exists"`
|
||||
Key string `json:"key"`
|
||||
Type string `json:"type"`
|
||||
Value any `json:"value,omitempty"`
|
||||
Truncated bool `json:"truncated"`
|
||||
DurationMs int64 `json:"durationMs"`
|
||||
}
|
||||
|
||||
type RedisScanResult struct {
|
||||
Pattern string `json:"pattern"`
|
||||
Keys []string `json:"keys"`
|
||||
Returned int `json:"returned"`
|
||||
NextCursor uint64 `json:"nextCursor"`
|
||||
Truncated bool `json:"truncated"`
|
||||
DurationMs int64 `json:"durationMs"`
|
||||
}
|
||||
|
||||
func NewRedisClient(ctx context.Context, cfg cfgpkg.RedisConfig) (*RedisClient, error) {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
})
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := client.Ping(pingCtx).Err(); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, fmt.Errorf("ping redis: %w", err)
|
||||
}
|
||||
return &RedisClient{client: client}, nil
|
||||
}
|
||||
|
||||
func (c *RedisClient) Close() error {
|
||||
if c == nil || c.client == nil {
|
||||
return nil
|
||||
}
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
func (c *RedisClient) GetWithType(ctx context.Context, key string, maxItems int, maxStringBytes int) (RedisGetResult, error) {
|
||||
start := time.Now()
|
||||
t, err := c.client.Type(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
if t == "none" {
|
||||
return RedisGetResult{
|
||||
Exists: false,
|
||||
Key: key,
|
||||
Type: "none",
|
||||
Truncated: false,
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
result := RedisGetResult{Exists: true, Key: key, Type: t}
|
||||
switch t {
|
||||
case "string":
|
||||
v, err := c.client.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
if len(v) > maxStringBytes {
|
||||
result.Value = v[:maxStringBytes]
|
||||
result.Truncated = true
|
||||
} else {
|
||||
result.Value = v
|
||||
}
|
||||
case "list":
|
||||
vals, err := c.client.LRange(ctx, key, 0, int64(maxItems-1)).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
result.Value = vals
|
||||
if length, _ := c.client.LLen(ctx, key).Result(); length > int64(maxItems) {
|
||||
result.Truncated = true
|
||||
}
|
||||
case "set":
|
||||
vals, err := c.client.SMembers(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
if len(vals) > maxItems {
|
||||
result.Value = vals[:maxItems]
|
||||
result.Truncated = true
|
||||
} else {
|
||||
result.Value = vals
|
||||
}
|
||||
case "zset":
|
||||
vals, err := c.client.ZRangeWithScores(ctx, key, 0, int64(maxItems-1)).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
resultRows := make([]map[string]any, 0, len(vals))
|
||||
for _, item := range vals {
|
||||
resultRows = append(resultRows, map[string]any{"member": item.Member, "score": item.Score})
|
||||
}
|
||||
result.Value = resultRows
|
||||
if length, _ := c.client.ZCard(ctx, key).Result(); length > int64(maxItems) {
|
||||
result.Truncated = true
|
||||
}
|
||||
case "hash":
|
||||
vals, err := c.client.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
if len(vals) <= maxItems {
|
||||
result.Value = vals
|
||||
} else {
|
||||
trimmed := make(map[string]string, maxItems)
|
||||
count := 0
|
||||
for k, v := range vals {
|
||||
trimmed[k] = v
|
||||
count++
|
||||
if count >= maxItems {
|
||||
break
|
||||
}
|
||||
}
|
||||
result.Value = trimmed
|
||||
result.Truncated = true
|
||||
}
|
||||
default:
|
||||
raw, err := c.client.Dump(ctx, key).Result()
|
||||
if err != nil {
|
||||
return RedisGetResult{}, err
|
||||
}
|
||||
result.Value = strings.ToUpper(fmt.Sprintf("UNSUPPORTED_TYPE_%s_DUMP_SIZE_%d", t, len(raw)))
|
||||
}
|
||||
result.DurationMs = time.Since(start).Milliseconds()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *RedisClient) ScanKeys(ctx context.Context, pattern string, count int64, maxKeys int) (RedisScanResult, error) {
|
||||
start := time.Now()
|
||||
if pattern == "" {
|
||||
pattern = "*"
|
||||
}
|
||||
if count <= 0 {
|
||||
count = 20
|
||||
}
|
||||
|
||||
keys := make([]string, 0, maxKeys)
|
||||
var cursor uint64
|
||||
truncated := false
|
||||
for {
|
||||
batch, nextCursor, err := c.client.Scan(ctx, cursor, pattern, count).Result()
|
||||
if err != nil {
|
||||
return RedisScanResult{}, err
|
||||
}
|
||||
for _, key := range batch {
|
||||
if len(keys) >= maxKeys {
|
||||
truncated = true
|
||||
break
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
cursor = nextCursor
|
||||
if truncated || cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return RedisScanResult{
|
||||
Pattern: pattern,
|
||||
Keys: keys,
|
||||
Returned: len(keys),
|
||||
NextCursor: cursor,
|
||||
Truncated: truncated,
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/config"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store"
|
||||
)
|
||||
|
||||
func TestIntegrationMySQLReadOnlyTool(t *testing.T) {
|
||||
if os.Getenv("MCP_IT_RUN") != "1" {
|
||||
t.Skip("set MCP_IT_RUN=1 to run integration tests")
|
||||
}
|
||||
|
||||
port := 3306
|
||||
if p := os.Getenv("MYSQL_PORT"); p != "" {
|
||||
if n, err := strconv.Atoi(p); err == nil {
|
||||
port = n
|
||||
}
|
||||
}
|
||||
|
||||
mysqlCfg := config.MySQLConfig{
|
||||
Host: os.Getenv("MYSQL_HOST"),
|
||||
Port: port,
|
||||
User: os.Getenv("MYSQL_USER"),
|
||||
Password: os.Getenv("MYSQL_PASSWORD"),
|
||||
Database: os.Getenv("MYSQL_DATABASE"),
|
||||
Params: "charset=utf8mb4&parseTime=true&loc=Local",
|
||||
}
|
||||
if mysqlCfg.Host == "" || mysqlCfg.User == "" || mysqlCfg.Database == "" {
|
||||
t.Skip("missing MYSQL_* env for integration test")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
client, err := store.NewMySQLClient(ctx, mysqlCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("mysql not available: %v", err)
|
||||
}
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
validator := security.NewSQLValidator(mysqlCfg.Database, false, nil, nil)
|
||||
tool := NewMySQLReadOnlyTool(client, validator, 10)
|
||||
res, err := tool.Execute(ctx, map[string]any{"sql": "SELECT 1 AS ok"})
|
||||
if err != nil {
|
||||
t.Fatalf("tool execute failed: %v", err)
|
||||
}
|
||||
if res["rowCount"].(int) < 1 {
|
||||
t.Fatalf("expected rowCount >= 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationRedisTools(t *testing.T) {
|
||||
if os.Getenv("MCP_IT_RUN") != "1" {
|
||||
t.Skip("set MCP_IT_RUN=1 to run integration tests")
|
||||
}
|
||||
|
||||
db := 0
|
||||
if p := os.Getenv("REDIS_DB"); p != "" {
|
||||
if n, err := strconv.Atoi(p); err == nil {
|
||||
db = n
|
||||
}
|
||||
}
|
||||
redisCfg := config.RedisConfig{
|
||||
Addr: os.Getenv("REDIS_ADDR"),
|
||||
Password: os.Getenv("REDIS_PASSWORD"),
|
||||
DB: db,
|
||||
}
|
||||
if redisCfg.Addr == "" {
|
||||
t.Skip("REDIS_ADDR is empty")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
client, err := store.NewRedisClient(ctx, redisCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("redis not available: %v", err)
|
||||
}
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
getTool := NewRedisGetTool(client, 10, 128)
|
||||
if _, err := getTool.Execute(ctx, map[string]any{"key": "__integration_missing_key__"}); err != nil {
|
||||
t.Fatalf("redis_get failed: %v", err)
|
||||
}
|
||||
|
||||
scanTool := NewRedisScanTool(client, 10, 10)
|
||||
if _, err := scanTool.Execute(ctx, map[string]any{"pattern": "*", "count": float64(5)}); err != nil {
|
||||
t.Fatalf("redis_scan failed: %v", err)
|
||||
}
|
||||
}
|
||||
95
infra/smartflow-mcp-server/internal/tools/mysql_readonly.go
Normal file
95
infra/smartflow-mcp-server/internal/tools/mysql_readonly.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store"
|
||||
)
|
||||
|
||||
type MySQLReadOnlyTool struct {
|
||||
client *store.MySQLClient
|
||||
validator *security.SQLValidator
|
||||
maxRows int
|
||||
}
|
||||
|
||||
func NewMySQLReadOnlyTool(client *store.MySQLClient, validator *security.SQLValidator, maxRows int) *MySQLReadOnlyTool {
|
||||
return &MySQLReadOnlyTool{client: client, validator: validator, maxRows: maxRows}
|
||||
}
|
||||
|
||||
func (t *MySQLReadOnlyTool) Name() string {
|
||||
return "mysql_query_readonly"
|
||||
}
|
||||
|
||||
func (t *MySQLReadOnlyTool) Description() string {
|
||||
return "Execute read-only SQL on MySQL. Only SELECT/SHOW/DESCRIBE/EXPLAIN are allowed."
|
||||
}
|
||||
|
||||
func (t *MySQLReadOnlyTool) InputSchema() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"sql": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Read-only SQL statement",
|
||||
},
|
||||
"params": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Optional bind parameters",
|
||||
"items": map[string]any{
|
||||
"type": []string{"string", "number", "boolean", "null"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": []string{"sql"},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MySQLReadOnlyTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
|
||||
rawSQL, ok := args["sql"].(string)
|
||||
if !ok || rawSQL == "" {
|
||||
return nil, fmt.Errorf("sql must be a non-empty string")
|
||||
}
|
||||
if err := t.validator.ValidateReadOnlySQL(rawSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params, err := normalizeParams(args["params"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := t.client.QueryReadOnly(ctx, rawSQL, params, t.maxRows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]any{
|
||||
"columns": res.Columns,
|
||||
"rows": res.Rows,
|
||||
"rowCount": res.RowCount,
|
||||
"truncated": res.Truncated,
|
||||
"durationMs": res.DurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeParams(raw any) ([]any, error) {
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
arr, ok := raw.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("params must be an array")
|
||||
}
|
||||
out := make([]any, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
switch v := item.(type) {
|
||||
case string, float64, bool, nil:
|
||||
out = append(out, v)
|
||||
default:
|
||||
return nil, fmt.Errorf("params contains unsupported type")
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
130
infra/smartflow-mcp-server/internal/tools/redis_tools.go
Normal file
130
infra/smartflow-mcp-server/internal/tools/redis_tools.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/store"
|
||||
)
|
||||
|
||||
type RedisGetTool struct {
|
||||
client *store.RedisClient
|
||||
valueMaxItems int
|
||||
maxStringBytes int
|
||||
}
|
||||
|
||||
func NewRedisGetTool(client *store.RedisClient, valueMaxItems int, maxStringBytes int) *RedisGetTool {
|
||||
return &RedisGetTool{client: client, valueMaxItems: valueMaxItems, maxStringBytes: maxStringBytes}
|
||||
}
|
||||
|
||||
func (t *RedisGetTool) Name() string {
|
||||
return "redis_get"
|
||||
}
|
||||
|
||||
func (t *RedisGetTool) Description() string {
|
||||
return "Get a Redis key by name and return its type and value."
|
||||
}
|
||||
|
||||
func (t *RedisGetTool) InputSchema() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"key": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Redis key",
|
||||
},
|
||||
},
|
||||
"required": []string{"key"},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *RedisGetTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
|
||||
key, ok := args["key"].(string)
|
||||
if !ok || key == "" {
|
||||
return nil, fmt.Errorf("key must be a non-empty string")
|
||||
}
|
||||
res, err := t.client.GetWithType(ctx, key, t.valueMaxItems, t.maxStringBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]any{
|
||||
"exists": res.Exists,
|
||||
"key": res.Key,
|
||||
"type": res.Type,
|
||||
"value": res.Value,
|
||||
"truncated": res.Truncated,
|
||||
"durationMs": res.DurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type RedisScanTool struct {
|
||||
client *store.RedisClient
|
||||
maxKeys int
|
||||
maxScanCount int
|
||||
}
|
||||
|
||||
func NewRedisScanTool(client *store.RedisClient, maxKeys int, maxScanCount int) *RedisScanTool {
|
||||
return &RedisScanTool{client: client, maxKeys: maxKeys, maxScanCount: maxScanCount}
|
||||
}
|
||||
|
||||
func (t *RedisScanTool) Name() string {
|
||||
return "redis_scan"
|
||||
}
|
||||
|
||||
func (t *RedisScanTool) Description() string {
|
||||
return "Scan Redis keys by pattern with capped result size."
|
||||
}
|
||||
|
||||
func (t *RedisScanTool) InputSchema() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"pattern": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Pattern, for example user:*",
|
||||
},
|
||||
"count": map[string]any{
|
||||
"type": "number",
|
||||
"description": "Optional scan count hint",
|
||||
},
|
||||
},
|
||||
"required": []string{"pattern"},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *RedisScanTool) Execute(ctx context.Context, args map[string]any) (map[string]any, error) {
|
||||
pattern, ok := args["pattern"].(string)
|
||||
if !ok || pattern == "" {
|
||||
return nil, fmt.Errorf("pattern must be a non-empty string")
|
||||
}
|
||||
|
||||
count := int64(20)
|
||||
if rawCount, ok := args["count"]; ok {
|
||||
number, ok := rawCount.(float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("count must be a number")
|
||||
}
|
||||
if number <= 0 {
|
||||
return nil, fmt.Errorf("count must be > 0")
|
||||
}
|
||||
count = int64(number)
|
||||
}
|
||||
if count > int64(t.maxScanCount) {
|
||||
count = int64(t.maxScanCount)
|
||||
}
|
||||
|
||||
res, err := t.client.ScanKeys(ctx, pattern, count, t.maxKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]any{
|
||||
"pattern": res.Pattern,
|
||||
"keys": res.Keys,
|
||||
"returned": res.Returned,
|
||||
"nextCursor": res.NextCursor,
|
||||
"truncated": res.Truncated,
|
||||
"durationMs": res.DurationMs,
|
||||
}, nil
|
||||
}
|
||||
53
infra/smartflow-mcp-server/internal/tools/registry.go
Normal file
53
infra/smartflow-mcp-server/internal/tools/registry.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
)
|
||||
|
||||
type Tool interface {
|
||||
Name() string
|
||||
Description() string
|
||||
InputSchema() map[string]any
|
||||
Execute(ctx context.Context, args map[string]any) (map[string]any, error)
|
||||
}
|
||||
|
||||
type Registry struct {
|
||||
tools map[string]Tool
|
||||
}
|
||||
|
||||
func NewRegistry(toolList ...Tool) (*Registry, error) {
|
||||
r := &Registry{tools: make(map[string]Tool, len(toolList))}
|
||||
for _, t := range toolList {
|
||||
name := t.Name()
|
||||
if _, exists := r.tools[name]; exists {
|
||||
return nil, fmt.Errorf("duplicated tool name: %s", name)
|
||||
}
|
||||
r.tools[name] = t
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *Registry) Find(name string) (Tool, bool) {
|
||||
t, ok := r.tools[name]
|
||||
return t, ok
|
||||
}
|
||||
|
||||
func (r *Registry) List() []map[string]any {
|
||||
out := make([]map[string]any, 0, len(r.tools))
|
||||
names := make([]string, 0, len(r.tools))
|
||||
for name := range r.tools {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
for _, name := range names {
|
||||
t := r.tools[name]
|
||||
out = append(out, map[string]any{
|
||||
"name": t.Name(),
|
||||
"description": t.Description(),
|
||||
"inputSchema": t.InputSchema(),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
Reference in New Issue
Block a user