feat: 🚀 增强会话管理与缓存机制 * 会话 ID 空值兜底,若 `conversation_id` 为空时自动生成 UUID * 在响应头写入 `X-Conversation-ID`,供前端使用,保持同一会话状态 perf: ⚡ 会话状态缓存优化 * 当缓存未命中但 DB 已确认/创建会话后,调用 `SetConversationStatus` 回写 Redis * 缓存写回失败时记录日志,不中断聊天主流程,确保业务流畅性 fix: 🐛 修复历史消息顺序问题与编译错误 * 修复历史消息顺序问题,保证返回的 N 条历史消息按时间正序喂给模型 * 通过反转 `created_at desc` 查询结果的切片,确保模型输入顺序正确 * 修复 `fmt.Errorf` 参数不匹配问题,修正编译错误 * 整理 `agent-cache.go` 为标准 UTF-8 编码,避免 Go 编译报错 `invalid UTF-8 encoding` feat: 🛠️ 独立构建 MCP 服务器 * 使用 `Codex` 构建独立于后端的 MCP 服务器,简化与 Codex 的协作 * 通过该服务器方便 Codex 直接测试和查看 Redis 与 MySQL 中的数据
194 lines
4.9 KiB
Go
194 lines
4.9 KiB
Go
package security
|
|
|
|
import (
|
|
"fmt"
|
|
"regexp"
|
|
"strings"
|
|
)
|
|
|
|
// Patterns used to screen SQL text and to locate the objects it references.
var (
	// commentPattern matches a /* ... */ block comment (spanning lines,
	// shortest match), a "--" sequence, or a "#" character anywhere in the
	// text.
	commentPattern = regexp.MustCompile(`(?s)/\*.*?\*/|--|#`)

	// forbiddenWords matches, case-insensitively and as whole words, keywords
	// that modify data, schema, privileges or session state, or that move
	// data to or from files. DUMPFILE is listed next to OUTFILE because
	// SELECT ... INTO DUMPFILE writes a file just like INTO OUTFILE does.
	forbiddenWords = regexp.MustCompile(`(?i)\b(INSERT|UPDATE|DELETE|ALTER|DROP|TRUNCATE|CREATE|REPLACE|RENAME|GRANT|REVOKE|MERGE|CALL|EXEC|LOCK|UNLOCK|LOAD|OUTFILE|DUMPFILE|INFILE|HANDLER|SET|USE)\b`)

	// fromJoinRef captures the object name that follows FROM, JOIN or
	// STRAIGHT_JOIN. Whitespace after the keyword is optional so that a quoted
	// name written directly against the keyword (FROM`t`) is still captured;
	// the word boundary keeps identifiers such as "fromage" from matching.
	fromJoinRef = regexp.MustCompile(`(?i)\b(?:FROM|JOIN|STRAIGHT_JOIN)\b\s*([` + "`" + `"\w\.]+)`)

	// describeRef captures the object name of a DESCRIBE/DESC statement. It is
	// anchored to the start of the text so that the DESC sort direction in
	// "ORDER BY col DESC LIMIT n" is not mistaken for a table reference.
	describeRef = regexp.MustCompile(`(?i)^\s*(?:DESCRIBE|DESC)\b\s*([` + "`" + `"\w\.]+)`)

	// showFromRef captures the object name that follows FROM or IN inside a
	// SHOW statement; whitespace before a quoted name is optional here too.
	showFromRef = regexp.MustCompile(`(?i)\bSHOW\b[\w\s]*\b(?:FROM|IN)\b\s*([` + "`" + `"\w\.]+)`)
)
|
|
|
|
// SQLValidator checks SQL text against read-only rules and, optionally,
// against database and table allowlists. Build one with NewSQLValidator.
type SQLValidator struct {
	// defaultDatabase is substituted for references that carry no database
	// qualifier.
	defaultDatabase string

	// enforceWhitelist makes a statement without any extracted object
	// reference fail validation.
	enforceWhitelist bool

	// allowedDatabases is the set of permitted database names.
	allowedDatabases map[string]struct{}

	// allowedTables is the set of permitted table names exactly as
	// configured (qualified or not).
	allowedTables map[string]struct{}

	// allowedTableOnly is the set of permitted table names with any database
	// qualifier removed.
	allowedTableOnly map[string]struct{}
}
|
|
|
|
// sqlObjectRef is a table reference pulled out of a SQL statement.
type sqlObjectRef struct {
	// database is the qualifier written in front of the table name; it is
	// empty when the reference was unqualified.
	database string

	// table is the table name itself.
	table string
}
|
|
|
|
func NewSQLValidator(defaultDatabase string, enforceWhitelist bool, allowedDatabases []string, allowedTables []string) *SQLValidator {
|
|
v := &SQLValidator{
|
|
defaultDatabase: strings.ToLower(strings.TrimSpace(defaultDatabase)),
|
|
enforceWhitelist: enforceWhitelist,
|
|
allowedDatabases: make(map[string]struct{}),
|
|
allowedTables: make(map[string]struct{}),
|
|
allowedTableOnly: make(map[string]struct{}),
|
|
}
|
|
for _, db := range allowedDatabases {
|
|
db = strings.ToLower(strings.TrimSpace(db))
|
|
if db != "" {
|
|
v.allowedDatabases[db] = struct{}{}
|
|
}
|
|
}
|
|
for _, tbl := range allowedTables {
|
|
tbl = strings.ToLower(strings.TrimSpace(strings.Trim(tbl, "`\"")))
|
|
if tbl == "" {
|
|
continue
|
|
}
|
|
v.allowedTables[tbl] = struct{}{}
|
|
if dot := strings.LastIndex(tbl, "."); dot >= 0 && dot < len(tbl)-1 {
|
|
v.allowedTableOnly[tbl[dot+1:]] = struct{}{}
|
|
} else {
|
|
v.allowedTableOnly[tbl] = struct{}{}
|
|
}
|
|
}
|
|
return v
|
|
}
|
|
|
|
func (v *SQLValidator) ValidateReadOnlySQL(sql string) error {
|
|
trimmed := strings.TrimSpace(sql)
|
|
if trimmed == "" {
|
|
return fmt.Errorf("sql is required")
|
|
}
|
|
if strings.Contains(trimmed, ";") {
|
|
return fmt.Errorf("semicolon is not allowed")
|
|
}
|
|
if commentPattern.MatchString(trimmed) {
|
|
return fmt.Errorf("sql comments are not allowed")
|
|
}
|
|
|
|
first := strings.ToUpper(firstToken(trimmed))
|
|
switch first {
|
|
case "SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN":
|
|
default:
|
|
return fmt.Errorf("only SELECT/SHOW/DESCRIBE/EXPLAIN are allowed")
|
|
}
|
|
|
|
withoutLiterals := removeStringLiterals(trimmed)
|
|
if forbiddenWords.MatchString(strings.ToUpper(withoutLiterals)) {
|
|
return fmt.Errorf("dangerous sql keyword detected")
|
|
}
|
|
|
|
refs := extractObjectRefs(withoutLiterals)
|
|
if err := v.validateWhitelist(refs); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (v *SQLValidator) validateWhitelist(refs []sqlObjectRef) error {
|
|
hasDBAllowlist := len(v.allowedDatabases) > 0
|
|
hasTableAllowlist := len(v.allowedTables) > 0
|
|
|
|
if !v.enforceWhitelist && !hasDBAllowlist && !hasTableAllowlist {
|
|
return nil
|
|
}
|
|
if len(refs) == 0 {
|
|
if v.enforceWhitelist {
|
|
return fmt.Errorf("sql does not contain explicit table reference under whitelist mode")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
for _, ref := range refs {
|
|
db := ref.database
|
|
tbl := ref.table
|
|
if db == "" {
|
|
db = v.defaultDatabase
|
|
}
|
|
|
|
if hasDBAllowlist {
|
|
if db == "" {
|
|
return fmt.Errorf("database is not explicit and no default database set")
|
|
}
|
|
if _, ok := v.allowedDatabases[db]; !ok {
|
|
return fmt.Errorf("database %s is not in whitelist", db)
|
|
}
|
|
}
|
|
|
|
if hasTableAllowlist {
|
|
full := tbl
|
|
if db != "" {
|
|
full = db + "." + tbl
|
|
}
|
|
if _, ok := v.allowedTables[full]; ok {
|
|
continue
|
|
}
|
|
if _, ok := v.allowedTableOnly[tbl]; ok {
|
|
continue
|
|
}
|
|
return fmt.Errorf("table %s is not in whitelist", full)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func extractObjectRefs(sql string) []sqlObjectRef {
|
|
refs := make([]sqlObjectRef, 0)
|
|
seen := make(map[string]struct{})
|
|
|
|
for _, re := range []*regexp.Regexp{fromJoinRef, describeRef, showFromRef} {
|
|
matches := re.FindAllStringSubmatch(sql, -1)
|
|
for _, m := range matches {
|
|
if len(m) < 2 {
|
|
continue
|
|
}
|
|
ref := normalizeRef(m[1])
|
|
if ref.table == "" {
|
|
continue
|
|
}
|
|
key := ref.database + "." + ref.table
|
|
if _, ok := seen[key]; ok {
|
|
continue
|
|
}
|
|
seen[key] = struct{}{}
|
|
refs = append(refs, ref)
|
|
}
|
|
}
|
|
return refs
|
|
}
|
|
|
|
func normalizeRef(raw string) sqlObjectRef {
|
|
clean := strings.ToLower(strings.TrimSpace(raw))
|
|
clean = strings.Trim(clean, "`\"'(),")
|
|
clean = strings.TrimSpace(clean)
|
|
if clean == "" {
|
|
return sqlObjectRef{}
|
|
}
|
|
parts := strings.Split(clean, ".")
|
|
if len(parts) == 1 {
|
|
return sqlObjectRef{table: parts[0]}
|
|
}
|
|
return sqlObjectRef{database: parts[0], table: parts[len(parts)-1]}
|
|
}
|
|
|
|
// firstToken returns the first run of characters in sql that contains no
// space, newline, tab or carriage return. Leading separators are skipped, so
// the result never starts with whitespace; it is empty when sql holds nothing
// but separators.
func firstToken(sql string) string {
	const separators = " \n\t\r"

	sql = strings.TrimLeft(sql, separators)
	if end := strings.IndexAny(sql, separators); end >= 0 {
		return sql[:end]
	}
	return sql
}
|
|
|
|
func removeStringLiterals(sql string) string {
|
|
masked := singleQuotedString.ReplaceAllString(sql, "''")
|
|
masked = doubleQuotedString.ReplaceAllString(masked, `""`)
|
|
return masked
|
|
}
|