Files
LoveLosita 26c350f378 Version: 0.4.4.dev.260307
feat: 🚀 增强会话管理与缓存机制

* 会话 ID 空值兜底,若 `conversation_id` 为空时自动生成 UUID
* 在响应头写入 `X-Conversation-ID`,供前端使用,保持同一会话状态

perf:  会话状态缓存优化

* 当缓存未命中但 DB 已确认/创建会话后,调用 `SetConversationStatus` 回写 Redis
* 缓存写回失败时记录日志,不中断聊天主流程,确保业务流畅性

fix: 🐛 修复历史消息顺序问题与编译错误

* 修复历史消息顺序问题,保证返回的 N 条历史消息按时间正序喂给模型

  * 通过反转 `created_at desc` 查询结果的切片,确保模型输入顺序正确
* 修复 `fmt.Errorf` 参数不匹配问题,修正编译错误
* 整理 `agent-cache.go` 为标准 UTF-8 编码,避免 Go 编译报错 `invalid UTF-8 encoding`

feat: 🛠️ 独立构建 MCP 服务器

* 使用 `Codex` 构建独立于后端的 MCP 服务器,简化与 Codex 的协作
* 通过该服务器方便 Codex 直接测试和查看 Redis 与 MySQL 中的数据
2026-03-07 15:25:40 +08:00

194 lines
4.9 KiB
Go

package security
import (
"fmt"
"regexp"
"strings"
)
// Regular expressions used to vet SQL text. They are compiled once, when the
// package is initialised.
var (
	// commentPattern matches any SQL comment marker: a /* ... */ block
	// (which may span lines), a "--" sequence, or a "#".
	commentPattern = regexp.MustCompile(`(?s)/\*.*?\*/|--|#`)

	// forbiddenWords matches, case-insensitively and as whole words, the
	// keywords a read-only statement must not contain: data and schema
	// changes, privilege changes, locking, file access and session changes.
	// DUMPFILE sits next to OUTFILE because SELECT ... INTO DUMPFILE writes a
	// file on the server in the same way SELECT ... INTO OUTFILE does.
	forbiddenWords = regexp.MustCompile(`(?i)\b(INSERT|UPDATE|DELETE|ALTER|DROP|TRUNCATE|CREATE|REPLACE|RENAME|GRANT|REVOKE|MERGE|CALL|EXEC|LOCK|UNLOCK|LOAD|OUTFILE|DUMPFILE|INFILE|HANDLER|SET|USE)\b`)

	// fromJoinRef captures the identifier that follows FROM or JOIN.
	fromJoinRef = regexp.MustCompile(`(?i)\b(?:FROM|JOIN)\s+([` + "`" + `"\w\.]+)`)

	// describeRef captures the identifier that follows a leading DESCRIBE or
	// DESC. It is anchored to the start of the statement: DESC is also the
	// sort-direction keyword, and an unanchored pattern reads
	// "ORDER BY id DESC LIMIT 10" as a DESCRIBE of a table named "limit".
	describeRef = regexp.MustCompile(`(?i)^\s*(?:DESCRIBE|DESC)\s+([` + "`" + `"\w\.]+)`)

	// showFromRef captures the identifier that follows FROM or IN inside a
	// SHOW statement.
	showFromRef = regexp.MustCompile(`(?i)\bSHOW\b[\w\s]*\b(?:FROM|IN)\s+([` + "`" + `"\w\.]+)`)
)
// SQLValidator carries the settings used to decide whether a SQL statement
// may run: a default database, a whitelist-enforcement flag, and the
// database and table allowlists. Build one with NewSQLValidator.
type SQLValidator struct {
	// defaultDatabase is the database assumed for a table reference that
	// does not name one. It may be empty.
	defaultDatabase string

	// enforceWhitelist makes whitelist validation mandatory: a statement
	// with no table reference is then rejected.
	enforceWhitelist bool

	// allowedDatabases is the set of permitted database names.
	allowedDatabases map[string]struct{}

	// allowedTables is the set of permitted table entries as configured,
	// either bare ("tbl") or qualified ("db.tbl").
	allowedTables map[string]struct{}

	// allowedTableOnly is the set of bare table names derived from the
	// configured entries (the part after the last dot, when there is one).
	allowedTableOnly map[string]struct{}
}
// sqlObjectRef is one table reference pulled out of a SQL statement.
type sqlObjectRef struct {
	// database is the qualifying database name, or "" when the reference
	// did not name one.
	database string

	// table is the referenced table name.
	table string
}
// NewSQLValidator builds a SQLValidator from raw configuration values.
//
// defaultDatabase is the database assumed for references that name none.
// enforceWhitelist is stored as given. allowedDatabases and allowedTables
// are the allowlists; every name is trimmed, stripped of surrounding
// backticks or double quotes, and lower-cased, and blank entries are dropped.
//
// Each table entry is recorded in full and, additionally, under its bare
// table name (the part after the last dot, or the whole entry when it has no
// dot or ends with one).
func NewSQLValidator(defaultDatabase string, enforceWhitelist bool, allowedDatabases []string, allowedTables []string) *SQLValidator {
	// normalize trims whitespace before identifier quotes. Doing it the
	// other way round leaves the quotes on padded input such as " `users` ".
	normalize := func(name string) string {
		name = strings.Trim(strings.TrimSpace(name), "`\"")
		return strings.ToLower(strings.TrimSpace(name))
	}

	v := &SQLValidator{
		defaultDatabase:  normalize(defaultDatabase),
		enforceWhitelist: enforceWhitelist,
		allowedDatabases: make(map[string]struct{}, len(allowedDatabases)),
		allowedTables:    make(map[string]struct{}, len(allowedTables)),
		allowedTableOnly: make(map[string]struct{}, len(allowedTables)),
	}

	for _, db := range allowedDatabases {
		if db = normalize(db); db != "" {
			v.allowedDatabases[db] = struct{}{}
		}
	}

	for _, tbl := range allowedTables {
		// Unquote every dotted segment separately so that "`db`.`tbl`" is
		// stored as "db.tbl" rather than keeping the inner backticks.
		segments := strings.Split(tbl, ".")
		for i, seg := range segments {
			segments[i] = normalize(seg)
		}
		tbl = strings.Join(segments, ".")
		if tbl == "" {
			continue
		}
		v.allowedTables[tbl] = struct{}{}

		bare := tbl
		if dot := strings.LastIndex(tbl, "."); dot >= 0 && dot < len(tbl)-1 {
			bare = tbl[dot+1:]
		}
		v.allowedTableOnly[bare] = struct{}{}
	}
	return v
}
// ValidateReadOnlySQL reports whether sql is acceptable as a read-only
// statement. It returns nil when the statement passes and otherwise an error
// naming the first failed check. The checks run in this order:
//
//  1. the statement is not blank;
//  2. it contains no semicolon;
//  3. it contains no comment marker;
//  4. its first word is SELECT, SHOW, DESCRIBE, DESC or EXPLAIN;
//  5. with string literals masked, it contains no forbidden keyword;
//  6. the table references extracted from it pass the whitelist.
func (v *SQLValidator) ValidateReadOnlySQL(sql string) error {
	query := strings.TrimSpace(sql)

	// The semicolon and comment checks look at the raw text, so a marker
	// inside a string literal is rejected as well.
	switch {
	case query == "":
		return fmt.Errorf("sql is required")
	case strings.Contains(query, ";"):
		return fmt.Errorf("semicolon is not allowed")
	case commentPattern.MatchString(query):
		return fmt.Errorf("sql comments are not allowed")
	}

	verb := strings.ToUpper(firstToken(query))
	if verb != "SELECT" && verb != "SHOW" && verb != "DESCRIBE" && verb != "DESC" && verb != "EXPLAIN" {
		return fmt.Errorf("only SELECT/SHOW/DESCRIBE/EXPLAIN are allowed")
	}

	// Keywords and table references are searched for with literals masked,
	// so text inside quotes does not count.
	masked := removeStringLiterals(query)
	if forbiddenWords.MatchString(strings.ToUpper(masked)) {
		return fmt.Errorf("dangerous sql keyword detected")
	}

	return v.validateWhitelist(extractObjectRefs(masked))
}
// validateWhitelist checks refs against the configured allowlists and returns
// nil when every reference is permitted.
//
// Nothing is checked when enforcement is off and both allowlists are empty.
// With no references at all, the result is an error only when enforcement is
// on. Otherwise each reference is checked in turn, taking the default
// database when it names none:
//
//   - if a database allowlist exists, the database must be known and listed;
//   - if a table allowlist exists, either "db.table" (or just "table" when no
//     database is known) must be listed, or the bare table name must be.
//
// NOTE(review): the bare-name lookup ignores the database, so a listed bare
// name admits that table in any database that passed the database check.
// Confirm that this is intended.
func (v *SQLValidator) validateWhitelist(refs []sqlObjectRef) error {
	restrictDBs := len(v.allowedDatabases) != 0
	restrictTables := len(v.allowedTables) != 0

	switch {
	case !(v.enforceWhitelist || restrictDBs || restrictTables):
		return nil
	case len(refs) == 0 && v.enforceWhitelist:
		return fmt.Errorf("sql does not contain explicit table reference under whitelist mode")
	case len(refs) == 0:
		return nil
	}

	for i := range refs {
		database, table := refs[i].database, refs[i].table
		if database == "" {
			database = v.defaultDatabase
		}

		if restrictDBs {
			if database == "" {
				return fmt.Errorf("database is not explicit and no default database set")
			}
			if _, listed := v.allowedDatabases[database]; !listed {
				return fmt.Errorf("database %s is not in whitelist", database)
			}
		}

		if !restrictTables {
			continue
		}

		qualified := table
		if database != "" {
			qualified = database + "." + table
		}
		_, byQualified := v.allowedTables[qualified]
		_, byBare := v.allowedTableOnly[table]
		if !byQualified && !byBare {
			return fmt.Errorf("table %s is not in whitelist", qualified)
		}
	}
	return nil
}
// extractObjectRefs collects the table references found in sql by the
// FROM/JOIN, DESCRIBE/DESC and SHOW ... FROM/IN patterns, applied in that
// order. Each captured identifier is normalised with normalizeRef. Captures
// with an empty table name are dropped and duplicates (same database and
// table) are reported once, in order of first appearance. The result is
// never nil.
//
// NOTE(review): only the identifier directly after the keyword is captured,
// so further tables in a comma-separated FROM list are not reported. Confirm
// whether callers rely on complete coverage.
func extractObjectRefs(sql string) []sqlObjectRef {
	found := []sqlObjectRef{}
	visited := map[string]struct{}{}

	patterns := [...]*regexp.Regexp{fromJoinRef, describeRef, showFromRef}
	for _, pattern := range patterns {
		for _, groups := range pattern.FindAllStringSubmatch(sql, -1) {
			if len(groups) < 2 {
				continue
			}
			candidate := normalizeRef(groups[1])
			id := candidate.database + "." + candidate.table
			_, duplicate := visited[id]
			if candidate.table == "" || duplicate {
				continue
			}
			visited[id] = struct{}{}
			found = append(found, candidate)
		}
	}
	return found
}
// normalizeRef turns a raw identifier captured from SQL text into a
// sqlObjectRef. The input is lower-cased, stray quotes, parentheses and
// commas are stripped from its ends, and it is split on ".". A single
// segment is a bare table; with several segments the first is the database
// and the last is the table. A blank input yields the zero sqlObjectRef.
//
// Quotes are removed from every segment, not only from the ends of the whole
// reference, so "`db`.`tbl`" normalises to database "db" and table "tbl".
func normalizeRef(raw string) sqlObjectRef {
	const quoteChars = "`\"'"

	clean := strings.ToLower(strings.TrimSpace(raw))
	clean = strings.TrimSpace(strings.Trim(clean, "`\"'(),"))
	if clean == "" {
		return sqlObjectRef{}
	}

	parts := strings.Split(clean, ".")
	for i, part := range parts {
		parts[i] = strings.TrimSpace(strings.Trim(part, quoteChars))
	}

	table := parts[len(parts)-1]
	if len(parts) == 1 {
		return sqlObjectRef{table: table}
	}
	return sqlObjectRef{database: parts[0], table: table}
}
// firstToken returns the first word of sql, where words are separated by
// space, newline, tab or carriage return. Leading separators are skipped, so
// the result never starts with whitespace. It returns "" when sql is empty
// or holds only separators.
func firstToken(sql string) string {
	const separators = " \n\t\r"

	rest := strings.TrimLeft(sql, separators)
	if end := strings.IndexAny(rest, separators); end >= 0 {
		return rest[:end]
	}
	return rest
}
// removeStringLiterals returns sql with every match of singleQuotedString
// replaced by '' and then every match of doubleQuotedString replaced by "".
//
// NOTE(review): both patterns are declared outside this view; presumably
// they match complete quoted literals. Verify against their definitions.
func removeStringLiterals(sql string) string {
	withoutSingle := singleQuotedString.ReplaceAllString(sql, "''")
	return doubleQuotedString.ReplaceAllString(withoutSingle, `""`)
}