Files
LoveLosita 26c350f378 Version: 0.4.4.dev.260307
feat: 🚀 增强会话管理与缓存机制

* 会话 ID 空值兜底,若 `conversation_id` 为空时自动生成 UUID
* 在响应头写入 `X-Conversation-ID`,供前端使用,保持同一会话状态

perf: 会话状态缓存优化

* 当缓存未命中但 DB 已确认/创建会话后,调用 `SetConversationStatus` 回写 Redis
* 缓存写回失败时记录日志,不中断聊天主流程,确保业务流畅性

fix: 🐛 修复历史消息顺序问题与编译错误

* 修复历史消息顺序问题,保证返回的 N 条历史消息按时间正序喂给模型

  * 通过反转 `created_at desc` 查询结果的切片,确保模型输入顺序正确
* 修复 `fmt.Errorf` 参数不匹配问题,修正编译错误
* 整理 `agent-cache.go` 为标准 UTF-8 编码,避免 Go 编译报错 `invalid UTF-8 encoding`

feat: 🛠️ 独立构建 MCP 服务器

* 使用 `Codex` 构建独立于后端的 MCP 服务器,简化与 Codex 的协作
* 通过该服务器方便 Codex 直接测试和查看 Redis 与 MySQL 中的数据
2026-03-07 15:25:40 +08:00

395 lines
9.4 KiB
Go

package mcp
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"strconv"
"strings"
"sync"
"time"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/audit"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/ratelimit"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/security"
"github.com/LoveLosita/smartflow/infra/smartflow-mcp-server/internal/tools"
)
// jsonRPCVersion is the protocol version string every JSON-RPC 2.0 message carries.
const jsonRPCVersion = "2.0"

// Standard JSON-RPC 2.0 error codes used in error responses.
const (
	errCodeParseError     = -32700 // body is not valid JSON
	errCodeInvalidRequest = -32600 // valid JSON, but not a valid JSON-RPC request
	errCodeMethodNotFound = -32601 // method is not implemented
	errCodeInvalidParams  = -32602 // method params are malformed
	errCodeInternalError  = -32603 // unexpected server-side failure
)
// Server is an MCP server speaking JSON-RPC 2.0 over a byte stream with
// Content-Length framed messages. It answers initialize, ping and tools/list,
// and dispatches tools/call requests to a tool registry with per-caller rate
// limiting and audit logging.
type Server struct {
	reader          *bufio.Reader      // buffered input stream requests are read from
	writer          io.Writer          // output stream for responses; only written under writeMu
	writeMu         sync.Mutex         // serializes frame writes to writer
	registry        *tools.Registry    // tools listed by tools/list and executed by tools/call
	auditLogger     *audit.Logger      // receives a record for each tools/call attempt that passes param validation
	limiter         *ratelimit.Limiter // rate limits tools/call per "<caller>:<tool>" key
	serverName      string             // reported as serverInfo.name from initialize
	serverVersion   string             // reported as serverInfo.version from initialize
	protocolVersion string             // protocol version returned from initialize
	defaultCaller   string             // caller identity used when a tool call does not name one
	toolTimeout     time.Duration      // per-call time budget for tool execution
}
// request is one incoming JSON-RPC 2.0 message. A message without an id
// (len(ID) == 0) is a notification and never receives a response.
type request struct {
	JSONRPC string          `json:"jsonrpc"`          // must equal "2.0"
	ID      json.RawMessage `json:"id,omitempty"`     // kept raw so it is echoed back exactly as sent
	Method  string          `json:"method"`           // required; empty means an invalid request
	Params  json.RawMessage `json:"params,omitempty"` // decoded later by the method's own handler
}
// response is an outgoing JSON-RPC 2.0 response. Set exactly one of Result
// or Error; MarshalJSON enforces the wire shape.
type response struct {
	JSONRPC string          `json:"jsonrpc"`
	ID      json.RawMessage `json:"id"`
	Result  any             `json:"result,omitempty"`
	Error   *respError      `json:"error,omitempty"`
}

// MarshalJSON encodes r as a spec-compliant JSON-RPC 2.0 response object.
//
// Plain struct tags are not sufficient:
//   - "id" is required in every response and must be null when the request
//     id is unknown (e.g. after a parse error), so a nil or empty ID is
//     written as null instead of being dropped;
//   - a success response must carry "result" even when it is an empty object
//     (the ping reply is {}), which omitempty would silently drop.
//
// When Error is non-nil the response is an error response and Result is
// ignored.
func (r response) MarshalJSON() ([]byte, error) {
	id := r.ID
	if len(id) == 0 {
		id = nil // a nil json.RawMessage encodes as null
	}
	if r.Error != nil {
		return json.Marshal(struct {
			JSONRPC string          `json:"jsonrpc"`
			ID      json.RawMessage `json:"id"`
			Error   *respError      `json:"error"`
		}{r.JSONRPC, id, r.Error})
	}
	return json.Marshal(struct {
		JSONRPC string          `json:"jsonrpc"`
		ID      json.RawMessage `json:"id"`
		Result  any             `json:"result"`
	}{r.JSONRPC, id, r.Result})
}

// respError is the "error" member of a JSON-RPC 2.0 error response.
type respError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}
// initializeParams models the params object of an MCP "initialize" request.
// Only the client's requested protocol version is captured; any other fields
// are ignored when decoding.
type initializeParams struct {
	ProtocolVersion string `json:"protocolVersion"` // version the client asked for
}
// toolCallParams models the params object of a "tools/call" request.
type toolCallParams struct {
	Name      string         `json:"name"`            // registered tool name; required
	Arguments map[string]any `json:"arguments"`       // tool input; nil when the client omits it
	Meta      map[string]any `json:"_meta,omitempty"` // optional request metadata; a "caller" entry names the caller
}
// NewServer returns a Server that reads requests from in (through a
// bufio.Reader) and writes responses to out, serving the tools in registry.
// auditLogger and limiter are used for tools/call auditing and rate limiting;
// serverName, serverVersion and protocolVersion are reported from initialize;
// defaultCaller names callers that do not identify themselves; toolTimeout is
// the per-call execution budget. No I/O happens until Serve is called.
func NewServer(
	in io.Reader,
	out io.Writer,
	registry *tools.Registry,
	auditLogger *audit.Logger,
	limiter *ratelimit.Limiter,
	serverName string,
	serverVersion string,
	protocolVersion string,
	defaultCaller string,
	toolTimeout time.Duration,
) *Server {
	srv := new(Server)

	// Transport.
	srv.reader = bufio.NewReader(in)
	srv.writer = out

	// Collaborators.
	srv.registry = registry
	srv.auditLogger = auditLogger
	srv.limiter = limiter

	// Identity and policy.
	srv.serverName, srv.serverVersion, srv.protocolVersion = serverName, serverVersion, protocolVersion
	srv.defaultCaller = defaultCaller
	srv.toolTimeout = toolTimeout

	return srv
}
// Serve runs the read/dispatch loop, handling one message at a time, until
// the input reaches EOF (returns nil) or ctx is cancelled (returns nil).
//
// Bad input does not stop the loop:
//   - frames that cannot be read are logged and skipped;
//   - bodies that are not valid JSON get an errCodeParseError reply with a
//     null id;
//   - messages without a method or with the wrong jsonrpc version get an
//     errCodeInvalidRequest reply.
//
// Read failures are tolerated only while they are sporadic. After
// maxConsecutiveReadErrors failures in a row, with no successful read in
// between, Serve returns the last error. Otherwise a reader that returns the
// same non-EOF error on every call would keep the loop spinning and logging
// forever.
//
// Cancellation is observed only between messages; a read already blocked
// waiting for input is not interrupted.
func (s *Server) Serve(ctx context.Context) error {
	// Generous enough that a noisy but live stream can resynchronize; low
	// enough that a permanently failing reader is detected quickly.
	const maxConsecutiveReadErrors = 32

	failures := 0
	for {
		select {
		case <-ctx.Done():
			return nil
		default:
		}
		body, err := readMessage(s.reader)
		if err != nil {
			if errors.Is(err, io.EOF) {
				return nil
			}
			failures++
			if failures >= maxConsecutiveReadErrors {
				return fmt.Errorf("read message: %d consecutive failures: %w", failures, err)
			}
			log.Printf("read message failed: %v", err)
			continue
		}
		failures = 0

		var req request
		if err := json.Unmarshal(body, &req); err != nil {
			// Best effort: a failed error reply is not fatal to the loop.
			_ = s.writeError(nil, errCodeParseError, "invalid json")
			continue
		}
		if req.Method == "" || req.JSONRPC != jsonRPCVersion {
			_ = s.writeError(req.ID, errCodeInvalidRequest, "invalid json-rpc request")
			continue
		}
		if err := s.handleRequest(ctx, req); err != nil {
			log.Printf("handle request failed: %v", err)
		}
	}
}
// handleRequest routes one validated JSON-RPC message to its handler.
//
// A message without an id is a notification and never gets a reply,
// whatever its method. Every message with an id gets exactly one response;
// methods this server does not implement are answered with
// errCodeMethodNotFound.
func (s *Server) handleRequest(ctx context.Context, req request) error {
	isNotification := len(req.ID) == 0
	if isNotification {
		return nil
	}
	switch req.Method {
	case "notifications/initialized":
		// A notification by definition; nothing to answer even if an id was sent.
		return nil
	case "initialize":
		return s.handleInitialize(req)
	case "tools/list":
		return s.writeResult(req.ID, map[string]any{"tools": s.registry.List()})
	case "tools/call":
		return s.handleToolCall(ctx, req)
	case "ping":
		return s.writeResult(req.ID, map[string]any{})
	}
	return s.writeError(req.ID, errCodeMethodNotFound, "method not found")
}
// handleInitialize answers the MCP "initialize" request. The reply carries
// this server's protocol version, its capabilities (tools, with no
// list-changed notifications) and serverInfo (name and version).
//
// The server speaks a single protocol version, so it always answers with
// s.protocolVersion whatever version the client requested. The params are
// still decoded so that a malformed params object is rejected with
// errCodeInvalidParams, the same way tools/call treats bad params.
// Messages without an id are notifications and get no reply.
func (s *Server) handleInitialize(req request) error {
	if len(req.ID) == 0 {
		return nil
	}
	if len(req.Params) > 0 {
		var params initializeParams
		if err := json.Unmarshal(req.Params, &params); err != nil {
			return s.writeError(req.ID, errCodeInvalidParams, "invalid initialize params")
		}
	}
	return s.writeResult(req.ID, map[string]any{
		"protocolVersion": s.protocolVersion,
		"capabilities": map[string]any{
			"tools": map[string]any{"listChanged": false},
		},
		"serverInfo": map[string]any{
			"name":    s.serverName,
			"version": s.serverVersion,
		},
	})
}
// handleToolCall executes an MCP "tools/call" request and writes its outcome.
//
// Malformed params or a missing tool name produce a JSON-RPC error
// (errCodeInvalidParams). Every later failure (rate limit exceeded, unknown
// tool, or an error returned by the tool) is sent as a normal result with
// isError set, so the client sees a tool-level failure. Each call that gets
// past param validation is written to the audit log with sanitized
// arguments.
//
// The rate-limit key is "<caller>:<tool>", with the caller taken from
// extractCaller. Each execution runs under s.toolTimeout. A non-positive
// timeout disables the deadline; it no longer yields a context that has
// already expired and fails every call.
func (s *Server) handleToolCall(ctx context.Context, req request) error {
	var params toolCallParams
	if err := json.Unmarshal(req.Params, &params); err != nil {
		return s.writeError(req.ID, errCodeInvalidParams, "invalid tool call params")
	}
	if params.Name == "" {
		return s.writeError(req.ID, errCodeInvalidParams, "tool name is required")
	}
	if params.Arguments == nil {
		params.Arguments = map[string]any{}
	}
	caller := extractCaller(params, s.defaultCaller)

	if !s.limiter.Allow(fmt.Sprintf("%s:%s", caller, params.Name)) {
		s.recordToolCall(params, caller, false, 0, "rate limit exceeded")
		return s.writeResult(req.ID, buildToolErrorResult("rate limit exceeded"))
	}
	tool, ok := s.registry.Find(params.Name)
	if !ok {
		s.recordToolCall(params, caller, false, 0, "tool not found")
		return s.writeResult(req.ID, buildToolErrorResult("tool not found"))
	}

	var (
		toolCtx context.Context
		cancel  context.CancelFunc
	)
	if s.toolTimeout > 0 {
		toolCtx, cancel = context.WithTimeout(ctx, s.toolTimeout)
	} else {
		toolCtx, cancel = context.WithCancel(ctx)
	}
	defer cancel()

	start := time.Now()
	output, err := tool.Execute(toolCtx, params.Arguments)
	duration := time.Since(start).Milliseconds()
	if err != nil {
		errMsg := sanitizeError(err)
		s.recordToolCall(params, caller, false, duration, errMsg)
		return s.writeResult(req.ID, buildToolErrorResult(errMsg))
	}
	s.recordToolCall(params, caller, true, duration, "")
	return s.writeResult(req.ID, buildToolSuccessResult(output))
}

// recordToolCall writes one audit record for a tools/call attempt, with the
// arguments reduced by sanitizeAuditInput. errMsg is "" for successful calls.
func (s *Server) recordToolCall(params toolCallParams, caller string, success bool, durationMs int64, errMsg string) {
	s.auditLogger.Log(audit.Record{
		Tool:       params.Name,
		Caller:     caller,
		Success:    success,
		DurationMs: durationMs,
		Meta:       sanitizeAuditInput(params.Name, params.Arguments),
		Error:      errMsg,
	})
}
// buildToolSuccessResult wraps a tool's output as an MCP tools/call result:
// the output serialized as JSON in one text content item, plus the same map
// as structuredContent.
//
// A nil output is reported as an empty object. If the output cannot be
// encoded as JSON (e.g. it holds a NaN float, a channel or a func), an error
// result is returned instead. Without this, the same value would also fail
// when the whole response is marshaled, and the request would get no
// response at all.
func buildToolSuccessResult(output map[string]any) map[string]any {
	if output == nil {
		output = map[string]any{}
	}
	asJSON, err := json.Marshal(output)
	if err != nil {
		return buildToolErrorResult("encode tool output: " + err.Error())
	}
	return map[string]any{
		"content": []map[string]any{
			{"type": "text", "text": string(asJSON)},
		},
		"structuredContent": output,
	}
}

// buildToolErrorResult builds an MCP tools/call result reporting a tool-level
// failure. The message appears as a text content item and under
// structuredContent.error, and isError is set.
func buildToolErrorResult(message string) map[string]any {
	return map[string]any{
		"content": []map[string]any{
			{"type": "text", "text": message},
		},
		"structuredContent": map[string]any{"error": message},
		"isError":           true,
	}
}
// extractCaller resolves the identity used to rate-limit and audit a tool
// call. It looks for a "caller" string in params.Meta first and then in
// params.Arguments. The first one that is non-blank wins and is returned with
// surrounding whitespace trimmed. If neither holds one, fallback is returned.
//
// Both sources come from the client, so the result is a self-declared
// identity, not an authenticated one.
func extractCaller(params toolCallParams, fallback string) string {
	// Indexing a nil map is safe and simply misses, so no nil checks are needed.
	for _, source := range [...]map[string]any{params.Meta, params.Arguments} {
		name, isString := source["caller"].(string)
		if !isString {
			continue
		}
		if trimmed := strings.TrimSpace(name); trimmed != "" {
			return trimmed
		}
	}
	return fallback
}
// sanitizeAuditInput builds the audit-log view of a tool call's arguments,
// keeping only what is safe to record:
//   - mysql_query_readonly: the SQL after security.RedactSQL, plus the number
//     of bound params (never their values);
//   - redis_get: the key after security.RedactKey;
//   - redis_scan: the pattern, capped at 40 bytes, and the count as given;
//   - any other tool: the single marker {"args": "masked"}.
//
// The pattern cap backs off to a UTF-8 rune boundary, so a multi-byte
// character is never split into invalid UTF-8.
func sanitizeAuditInput(toolName string, args map[string]any) map[string]any {
	const maxPatternBytes = 40

	meta := map[string]any{}
	switch toolName {
	case "mysql_query_readonly":
		sql, _ := args["sql"].(string)
		meta["sql"] = security.RedactSQL(sql)
		if params, ok := args["params"].([]any); ok {
			meta["paramCount"] = len(params)
		}
	case "redis_get":
		if key, ok := args["key"].(string); ok {
			meta["key"] = security.RedactKey(key)
		}
	case "redis_scan":
		if pattern, ok := args["pattern"].(string); ok {
			if len(pattern) > maxPatternBytes {
				// Ranging over a string yields each rune's start offset; keep
				// the last start that still fits so the cut never lands inside a rune.
				cut := 0
				for i := range pattern {
					if i > maxPatternBytes {
						break
					}
					cut = i
				}
				pattern = pattern[:cut]
			}
			meta["pattern"] = pattern
		}
		if count, ok := args["count"]; ok {
			meta["count"] = count
		}
	default:
		meta["args"] = "masked"
	}
	return meta
}
func sanitizeError(err error) string {
msg := err.Error()
msg = strings.ReplaceAll(msg, "\n", " ")
msg = strings.ReplaceAll(msg, "\r", " ")
if len(msg) > 300 {
return msg[:300]
}
return msg
}
// writeResult sends a success response carrying result for the request
// identified by id.
func (s *Server) writeResult(id json.RawMessage, result any) error {
	return s.writeMessage(response{
		JSONRPC: jsonRPCVersion,
		ID:      id,
		Result:  result,
	})
}
// writeError sends a JSON-RPC error response with the given code and message.
// Callers pass a nil id when the request's id is unknown, e.g. for a body
// that could not be parsed.
func (s *Server) writeError(id json.RawMessage, code int, message string) error {
	reply := response{JSONRPC: jsonRPCVersion, ID: id}
	reply.Error = &respError{Code: code, Message: message}
	return s.writeMessage(reply)
}
// writeMessage encodes payload as JSON and writes it as a single frame:
// "Content-Length: <n>\r\n\r\n" followed by the n-byte body.
//
// Header and body are assembled into one buffer and handed to the writer in
// a single Write call while holding writeMu. Frames from concurrent callers
// therefore never interleave, and an unbuffered writer sees one write per
// frame instead of two.
//
// NOTE(review): the MCP stdio transport specification frames messages as
// newline-delimited JSON; Content-Length headers are the LSP convention.
// Confirm the intended client speaks this framing.
func (s *Server) writeMessage(payload any) error {
	body, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("encode json-rpc message: %w", err)
	}

	// 40 bytes covers "Content-Length: " + up to 20 digits + "\r\n\r\n".
	frame := make([]byte, 0, len(body)+40)
	frame = append(frame, "Content-Length: "...)
	frame = strconv.AppendInt(frame, int64(len(body)), 10)
	frame = append(frame, "\r\n\r\n"...)
	frame = append(frame, body...)

	s.writeMu.Lock()
	defer s.writeMu.Unlock()
	_, err = s.writer.Write(frame)
	return err
}
// maxMessageBytes is the largest message body readMessage accepts. A larger
// declared Content-Length is rejected, not allocated, so a corrupt or hostile
// header cannot crash the process with a giant allocation.
const maxMessageBytes = 64 << 20 // 64 MiB

// readMessage reads one Content-Length framed message from reader and
// returns its body. It applies the default size cap, maxMessageBytes. See
// readMessageLimit for the framing rules and error behavior.
//
// NOTE(review): the MCP stdio transport specification frames messages as
// newline-delimited JSON; this reader accepts only LSP-style Content-Length
// headers. Confirm the intended client speaks this framing.
func readMessage(reader *bufio.Reader) ([]byte, error) {
	return readMessageLimit(reader, maxMessageBytes)
}

// readMessageLimit reads one frame: header lines terminated by LF or CRLF,
// a blank line, then exactly Content-Length body bytes. The Content-Length
// header name is matched case-insensitively; other header lines, and lines
// without a colon, are ignored. If the header appears more than once, the
// last value wins.
//
// Errors:
//   - I/O errors from reader are returned as-is (io.EOF at a clean end of
//     stream);
//   - a non-numeric or negative Content-Length, or a missing one, returns an
//     error;
//   - a declared length above limit (which must be non-negative) is consumed
//     and discarded before an error is returned, so the next call starts at
//     the following frame boundary.
func readMessageLimit(reader *bufio.Reader, limit int) ([]byte, error) {
	contentLength := -1
	for {
		line, err := reader.ReadString('\n')
		if err != nil {
			return nil, err
		}
		line = strings.TrimRight(line, "\r\n")
		if line == "" {
			break // a blank line ends the header section
		}
		name, value, found := strings.Cut(line, ":")
		if !found || !strings.EqualFold(strings.TrimSpace(name), "Content-Length") {
			continue
		}
		value = strings.TrimSpace(value)
		n, err := strconv.Atoi(value)
		if err != nil || n < 0 {
			return nil, fmt.Errorf("invalid Content-Length %q", value)
		}
		contentLength = n
	}
	if contentLength < 0 {
		return nil, errors.New("missing Content-Length")
	}
	if contentLength > limit {
		// Skip the oversized body to stay in frame sync; a stream that ends
		// early surfaces as io.EOF from CopyN.
		if _, err := io.CopyN(io.Discard, reader, int64(contentLength)); err != nil {
			return nil, err
		}
		return nil, fmt.Errorf("message body of %d bytes exceeds limit of %d bytes", contentLength, limit)
	}
	body := make([]byte, contentLength)
	if _, err := io.ReadFull(reader, body); err != nil {
		return nil, err
	}
	return body, nil
}