Version: 0.4.4.dev.260307
feat: 🚀 增强会话管理与缓存机制 * 会话 ID 空值兜底,若 `conversation_id` 为空时自动生成 UUID * 在响应头写入 `X-Conversation-ID`,供前端使用,保持同一会话状态 perf: ⚡ 会话状态缓存优化 * 当缓存未命中但 DB 已确认/创建会话后,调用 `SetConversationStatus` 回写 Redis * 缓存写回失败时记录日志,不中断聊天主流程,确保业务流畅性 fix: 🐛 修复历史消息顺序问题与编译错误 * 修复历史消息顺序问题,保证返回的 N 条历史消息按时间正序喂给模型 * 通过反转 `created_at desc` 查询结果的切片,确保模型输入顺序正确 * 修复 `fmt.Errorf` 参数不匹配问题,修正编译错误 * 整理 `agent-cache.go` 为标准 UTF-8 编码,避免 Go 编译报错 `invalid UTF-8 encoding` feat: 🛠️ 独立构建 MCP 服务器 * 使用 `Codex` 构建独立于后端的 MCP 服务器,简化与 Codex 的协作 * 通过该服务器方便 Codex 直接测试和查看 Redis 与 MySQL 中的数据
This commit is contained in:
394
infra/smartflow-mcp-server/internal/mcp/server.go
Normal file
394
infra/smartflow-mcp-server/internal/mcp/server.go
Normal file
@@ -0,0 +1,394 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/audit"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
|
||||
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
jsonRPCVersion = "2.0"
|
||||
|
||||
errCodeParseError = -32700
|
||||
errCodeInvalidRequest = -32600
|
||||
errCodeMethodNotFound = -32601
|
||||
errCodeInvalidParams = -32602
|
||||
errCodeInternalError = -32603
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
reader *bufio.Reader
|
||||
writer io.Writer
|
||||
writeMu sync.Mutex
|
||||
registry *tools.Registry
|
||||
auditLogger *audit.Logger
|
||||
limiter *ratelimit.Limiter
|
||||
serverName string
|
||||
serverVersion string
|
||||
protocolVersion string
|
||||
defaultCaller string
|
||||
toolTimeout time.Duration
|
||||
}
|
||||
|
||||
type request struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id,omitempty"`
|
||||
Method string `json:"method"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type response struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id,omitempty"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error *respError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type respError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type initializeParams struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
}
|
||||
|
||||
type toolCallParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
Meta map[string]any `json:"_meta,omitempty"`
|
||||
}
|
||||
|
||||
func NewServer(
|
||||
in io.Reader,
|
||||
out io.Writer,
|
||||
registry *tools.Registry,
|
||||
auditLogger *audit.Logger,
|
||||
limiter *ratelimit.Limiter,
|
||||
serverName string,
|
||||
serverVersion string,
|
||||
protocolVersion string,
|
||||
defaultCaller string,
|
||||
toolTimeout time.Duration,
|
||||
) *Server {
|
||||
return &Server{
|
||||
reader: bufio.NewReader(in),
|
||||
writer: out,
|
||||
registry: registry,
|
||||
auditLogger: auditLogger,
|
||||
limiter: limiter,
|
||||
serverName: serverName,
|
||||
serverVersion: serverVersion,
|
||||
protocolVersion: protocolVersion,
|
||||
defaultCaller: defaultCaller,
|
||||
toolTimeout: toolTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Serve(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
body, err := readMessage(s.reader)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
log.Printf("read message failed: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
var req request
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
_ = s.writeError(nil, errCodeParseError, "invalid json")
|
||||
continue
|
||||
}
|
||||
if req.Method == "" || req.JSONRPC != jsonRPCVersion {
|
||||
_ = s.writeError(req.ID, errCodeInvalidRequest, "invalid json-rpc request")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.handleRequest(ctx, req); err != nil {
|
||||
log.Printf("handle request failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(ctx context.Context, req request) error {
|
||||
switch req.Method {
|
||||
case "initialize":
|
||||
return s.handleInitialize(req)
|
||||
case "notifications/initialized":
|
||||
return nil
|
||||
case "tools/list":
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.writeResult(req.ID, map[string]any{"tools": s.registry.List()})
|
||||
case "tools/call":
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.handleToolCall(ctx, req)
|
||||
case "ping":
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.writeResult(req.ID, map[string]any{})
|
||||
default:
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.writeError(req.ID, errCodeMethodNotFound, "method not found")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleInitialize(req request) error {
|
||||
if len(req.ID) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var params initializeParams
|
||||
if len(req.Params) > 0 {
|
||||
_ = json.Unmarshal(req.Params, ¶ms)
|
||||
}
|
||||
_ = params
|
||||
|
||||
return s.writeResult(req.ID, map[string]any{
|
||||
"protocolVersion": s.protocolVersion,
|
||||
"capabilities": map[string]any{
|
||||
"tools": map[string]any{"listChanged": false},
|
||||
},
|
||||
"serverInfo": map[string]any{
|
||||
"name": s.serverName,
|
||||
"version": s.serverVersion,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleToolCall(ctx context.Context, req request) error {
|
||||
var params toolCallParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
return s.writeError(req.ID, errCodeInvalidParams, "invalid tool call params")
|
||||
}
|
||||
if params.Name == "" {
|
||||
return s.writeError(req.ID, errCodeInvalidParams, "tool name is required")
|
||||
}
|
||||
if params.Arguments == nil {
|
||||
params.Arguments = map[string]any{}
|
||||
}
|
||||
|
||||
caller := extractCaller(params, s.defaultCaller)
|
||||
rateKey := fmt.Sprintf("%s:%s", caller, params.Name)
|
||||
if !s.limiter.Allow(rateKey) {
|
||||
result := buildToolErrorResult("rate limit exceeded")
|
||||
s.auditLogger.Log(audit.Record{
|
||||
Tool: params.Name,
|
||||
Caller: caller,
|
||||
Success: false,
|
||||
DurationMs: 0,
|
||||
Meta: sanitizeAuditInput(params.Name, params.Arguments),
|
||||
Error: "rate limit exceeded",
|
||||
})
|
||||
return s.writeResult(req.ID, result)
|
||||
}
|
||||
|
||||
tool, ok := s.registry.Find(params.Name)
|
||||
if !ok {
|
||||
result := buildToolErrorResult("tool not found")
|
||||
s.auditLogger.Log(audit.Record{
|
||||
Tool: params.Name,
|
||||
Caller: caller,
|
||||
Success: false,
|
||||
DurationMs: 0,
|
||||
Meta: sanitizeAuditInput(params.Name, params.Arguments),
|
||||
Error: "tool not found",
|
||||
})
|
||||
return s.writeResult(req.ID, result)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
toolCtx, cancel := context.WithTimeout(ctx, s.toolTimeout)
|
||||
defer cancel()
|
||||
|
||||
output, err := tool.Execute(toolCtx, params.Arguments)
|
||||
duration := time.Since(start).Milliseconds()
|
||||
if err != nil {
|
||||
errMsg := sanitizeError(err)
|
||||
s.auditLogger.Log(audit.Record{
|
||||
Tool: params.Name,
|
||||
Caller: caller,
|
||||
Success: false,
|
||||
DurationMs: duration,
|
||||
Meta: sanitizeAuditInput(params.Name, params.Arguments),
|
||||
Error: errMsg,
|
||||
})
|
||||
return s.writeResult(req.ID, buildToolErrorResult(errMsg))
|
||||
}
|
||||
|
||||
s.auditLogger.Log(audit.Record{
|
||||
Tool: params.Name,
|
||||
Caller: caller,
|
||||
Success: true,
|
||||
DurationMs: duration,
|
||||
Meta: sanitizeAuditInput(params.Name, params.Arguments),
|
||||
})
|
||||
return s.writeResult(req.ID, buildToolSuccessResult(output))
|
||||
}
|
||||
|
||||
func buildToolSuccessResult(output map[string]any) map[string]any {
|
||||
asJSON, _ := json.Marshal(output)
|
||||
return map[string]any{
|
||||
"content": []map[string]any{
|
||||
{"type": "text", "text": string(asJSON)},
|
||||
},
|
||||
"structuredContent": output,
|
||||
}
|
||||
}
|
||||
|
||||
func buildToolErrorResult(message string) map[string]any {
|
||||
return map[string]any{
|
||||
"content": []map[string]any{
|
||||
{"type": "text", "text": message},
|
||||
},
|
||||
"structuredContent": map[string]any{"error": message},
|
||||
"isError": true,
|
||||
}
|
||||
}
|
||||
|
||||
func extractCaller(params toolCallParams, fallback string) string {
|
||||
if params.Meta != nil {
|
||||
if v, ok := params.Meta["caller"]; ok {
|
||||
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
if v, ok := params.Arguments["caller"]; ok {
|
||||
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func sanitizeAuditInput(toolName string, args map[string]any) map[string]any {
|
||||
meta := map[string]any{}
|
||||
switch toolName {
|
||||
case "mysql_query_readonly":
|
||||
sql, _ := args["sql"].(string)
|
||||
meta["sql"] = security.RedactSQL(sql)
|
||||
if params, ok := args["params"].([]any); ok {
|
||||
meta["paramCount"] = len(params)
|
||||
}
|
||||
case "redis_get":
|
||||
if key, ok := args["key"].(string); ok {
|
||||
meta["key"] = security.RedactKey(key)
|
||||
}
|
||||
case "redis_scan":
|
||||
if pattern, ok := args["pattern"].(string); ok {
|
||||
if len(pattern) > 40 {
|
||||
meta["pattern"] = pattern[:40]
|
||||
} else {
|
||||
meta["pattern"] = pattern
|
||||
}
|
||||
}
|
||||
if count, ok := args["count"]; ok {
|
||||
meta["count"] = count
|
||||
}
|
||||
default:
|
||||
meta["args"] = "masked"
|
||||
}
|
||||
return meta
|
||||
}
|
||||
|
||||
func sanitizeError(err error) string {
|
||||
msg := err.Error()
|
||||
msg = strings.ReplaceAll(msg, "\n", " ")
|
||||
msg = strings.ReplaceAll(msg, "\r", " ")
|
||||
if len(msg) > 300 {
|
||||
return msg[:300]
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func (s *Server) writeResult(id json.RawMessage, result any) error {
|
||||
resp := response{JSONRPC: jsonRPCVersion, ID: id, Result: result}
|
||||
return s.writeMessage(resp)
|
||||
}
|
||||
|
||||
func (s *Server) writeError(id json.RawMessage, code int, message string) error {
|
||||
resp := response{
|
||||
JSONRPC: jsonRPCVersion,
|
||||
ID: id,
|
||||
Error: &respError{Code: code, Message: message},
|
||||
}
|
||||
return s.writeMessage(resp)
|
||||
}
|
||||
|
||||
func (s *Server) writeMessage(payload any) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(body))
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
if _, err := s.writer.Write([]byte(header)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.writer.Write(body)
|
||||
return err
|
||||
}
|
||||
|
||||
func readMessage(reader *bufio.Reader) ([]byte, error) {
|
||||
contentLength := -1
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
if line == "" {
|
||||
break
|
||||
}
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
headerName := strings.TrimSpace(parts[0])
|
||||
headerValue := strings.TrimSpace(parts[1])
|
||||
if strings.EqualFold(headerName, "Content-Length") {
|
||||
n, err := strconv.Atoi(headerValue)
|
||||
if err != nil || n < 0 {
|
||||
return nil, fmt.Errorf("invalid Content-Length")
|
||||
}
|
||||
contentLength = n
|
||||
}
|
||||
}
|
||||
if contentLength < 0 {
|
||||
return nil, fmt.Errorf("missing Content-Length")
|
||||
}
|
||||
body := make([]byte, contentLength)
|
||||
if _, err := io.ReadFull(reader, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
Reference in New Issue
Block a user