From 1bcbd41bec404af2b90a9f3c5b35b3b3e2019b5d Mon Sep 17 00:00:00 2001 From: LoveLosita <2810873701@qq.com> Date: Wed, 4 Feb 2026 22:08:58 +0800 Subject: [PATCH] =?UTF-8?q?Version:0.0.7.dev.260204=20feat:=20=E2=9C=85=20?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=E8=8E=B7=E5=8F=96=E5=AE=8C=E6=95=B4=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E7=B1=BB=E4=B8=8E=E4=BF=AE=E6=94=B9=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E7=B1=BB=E6=8E=A5=E5=8F=A3=E5=B9=B6=E5=AE=8C=E6=88=90=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增获取完整任务类接口 📋 - 实现创建任务类的逆向数据转换函数 🔄 - 工程量较大,涉及完整数据结构还原 🏗️ - 新增修改任务类接口 ✏️ - 调整 service 层 AddOrUpdateTaskClass 函数逻辑 - 复用创建任务类的大部分实现,并通过 method 区分创建/更新操作 ♻️ - 更新 dao 层操作逻辑 🗄️ - 增加防止越权修改其它用户任务类的机制 🔒 - 两个接口代码量巨大,但均已测试通过 🧪💪 --- backend/api/task-class.go | 62 +++++++++++++++++++- backend/conv/task-class.go | 72 +++++++++++++++++++++++ backend/dao/task-class.go | 105 +++++++++++++++++++++++++++++++--- backend/model/task-class.go | 9 +-- backend/respond/respond.go | 10 ++++ backend/routers/routers.go | 2 + backend/service/task-class.go | 25 ++++++-- 7 files changed, 269 insertions(+), 16 deletions(-) diff --git a/backend/api/task-class.go b/backend/api/task-class.go index 65bd264..99b21fa 100644 --- a/backend/api/task-class.go +++ b/backend/api/task-class.go @@ -4,12 +4,14 @@ import ( "context" "errors" "net/http" + "strconv" "time" "github.com/LoveLosita/smartflow/backend/model" "github.com/LoveLosita/smartflow/backend/respond" "github.com/LoveLosita/smartflow/backend/service" "github.com/gin-gonic/gin" + "gorm.io/gorm" ) type TaskClassHandler struct { @@ -23,6 +25,11 @@ func NewTaskClassHandler(svc *service.TaskClassService) *TaskClassHandler { } } +const ( + create = 0 + update = 1 +) + func (api *TaskClassHandler) UserAddTaskClass(c *gin.Context) { var req model.UserAddTaskClassRequest err := c.ShouldBindJSON(&req) @@ -34,7 +41,7 @@ func (api *TaskClassHandler) UserAddTaskClass(c *gin.Context) { // 创建一个带 1 秒超时的上下文 ctx, cancel := context.WithTimeout(c.Request.Context(), 1*time.Second) defer cancel() // 记得释放资源 - err = api.svc.AddTaskClass(ctx, &req, userIDInterface) + err = api.svc.AddOrUpdateTaskClass(ctx, &req, userIDInterface, create, 0) if err != nil { if errors.Is(err, respond.WrongParamType) { c.JSON(http.StatusBadRequest, respond.WrongParamType) @@ -57,3 +64,56 @@ func (api *TaskClassHandler) UserGetTaskClassInfos(c *gin.Context) { } c.JSON(http.StatusOK, respond.RespWithData(respond.Ok, resp)) } + +func (api *TaskClassHandler) UserGetCompleteTaskClass(c *gin.Context) { + taskClassID := c.Query("task_class_id") + //将taskClassID转换为int + intTaskClassID, err := strconv.Atoi(taskClassID) + if err != nil { + c.JSON(http.StatusBadRequest, respond.WrongParamType) + return + } + if taskClassID == "" { + c.JSON(http.StatusBadRequest, respond.MissingParam) + return + } + userIDInterface := c.GetInt("user_id") + // 创建一个带 1 秒超时的上下文 + ctx, cancel := context.WithTimeout(c.Request.Context(), 1*time.Second) + defer cancel() // 记得释放资源 + resp, err := api.svc.GetUserCompleteTaskClass(ctx, userIDInterface, intTaskClassID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + c.JSON(http.StatusNotFound, respond.UserTaskClassNotFound) + return + } + c.JSON(http.StatusInternalServerError, respond.InternalError(err)) + return + } + c.JSON(http.StatusOK, respond.RespWithData(respond.Ok, resp)) +} + +func (api *TaskClassHandler) UserUpdateTaskClass(c *gin.Context) { + var req model.UserAddTaskClassRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.JSON(http.StatusBadRequest, respond.WrongParamType) + return + } + taskClassID := c.Query("task_class_id") + //将taskClassID转换为int + intTaskClassID, err := strconv.Atoi(taskClassID) + userIDInterface := c.GetInt("user_id") + // 创建一个带 1 秒超时的上下文 + ctx, cancel := context.WithTimeout(c.Request.Context(), 1*time.Second) + defer cancel() // 记得释放资源 + err = api.svc.AddOrUpdateTaskClass(ctx, &req, userIDInterface, update, intTaskClassID) + if err != nil { + if errors.Is(err, respond.WrongParamType) || errors.Is(err, respond.UserTaskClassForbidden) { + c.JSON(http.StatusBadRequest, err) + return + } + c.JSON(http.StatusInternalServerError, respond.InternalError(err)) + } + c.JSON(http.StatusOK, respond.Ok) +} diff --git a/backend/conv/task-class.go b/backend/conv/task-class.go index 111f528..7c3a3a0 100644 --- a/backend/conv/task-class.go +++ b/backend/conv/task-class.go @@ -1,6 +1,8 @@ package conv import ( + "encoding/json" + "errors" "time" "github.com/LoveLosita/smartflow/backend/model" @@ -94,3 +96,73 @@ func TaskClassModelToResponse(taskClasses []model.TaskClass) *model.UserGetTaskC } return &resp } + +func ProcessUserGetCompleteTaskClassRequest(taskClass *model.TaskClass) (*model.UserAddTaskClassRequest, error) { + if taskClass == nil { + return nil, errors.New("源数据对象不可为空") + } + // 1. 映射基础信息 (处理指针解引用) + req := &model.UserAddTaskClassRequest{ + Name: safeStr(taskClass.Name), + Mode: safeStr(taskClass.Mode), + StartDate: formatTime(taskClass.StartDate), + EndDate: formatTime(taskClass.EndDate), + } + // 2. 映射配置信息 (Config Section) + req.Config = model.UserAddTaskClassConfig{ + TotalSlots: safeInt(taskClass.TotalSlots), + AllowFillerCourse: safeBool(taskClass.AllowFillerCourse), + Strategy: safeStr(taskClass.Strategy), + } + // 3. 处理 ExcludedSlots JSON 字符串 -> []int + if taskClass.ExcludedSlots != nil && *taskClass.ExcludedSlots != "" { + var excluded []int + // 直接使用标准反序列化,比手动处理 rune 字符要健壮得多 + if err := json.Unmarshal([]byte(*taskClass.ExcludedSlots), &excluded); err == nil { + req.Config.ExcludedSlots = excluded + } + } + // 4. 映射子项信息 (Items Section) + // 此时 items 已经通过 Preload 加载到了 taskClass.Items 中 + req.Items = make([]model.UserAddTaskClassItemRequest, 0, len(taskClass.Items)) + for _, item := range taskClass.Items { + itemReq := model.UserAddTaskClassItemRequest{ + Order: safeInt(item.Order), + Content: safeStr(item.Content), + EmbeddedTime: item.EmbeddedTime, // 结构体指针直接复用 + } + req.Items = append(req.Items, itemReq) + } + return req, nil +} + +// --- 🛡️ 辅助工具函数:保持代码清爽并防止 Panic --- + +func safeStr(s *string) string { + if s == nil { + return "" + } + return *s +} + +func safeInt(i *int) int { + if i == nil { + return 0 + } + return *i +} + +func safeBool(b *bool) bool { + if b == nil { + return true + } + return *b +} + +func formatTime(t *time.Time) string { + if t == nil { + return "" + } + // 务必使用 2006-01-02 格式以匹配前端校验 + return t.Format("2006-01-02") +} diff --git a/backend/dao/task-class.go b/backend/dao/task-class.go index 0ac587f..6122d45 100644 --- a/backend/dao/task-class.go +++ b/backend/dao/task-class.go @@ -1,7 +1,10 @@ package dao import ( + "context" + "github.com/LoveLosita/smartflow/backend/model" + "github.com/LoveLosita/smartflow/backend/respond" "gorm.io/gorm" ) @@ -18,17 +21,87 @@ func NewTaskClassDAO(db *gorm.DB) *TaskClassDAO { } } -// AddTaskClass 为指定用户添加任务类 -func (dao *TaskClassDAO) AddTaskClass(taskClass *model.TaskClass) (int, error) { - err := dao.db.Create(taskClass).Error - if err != nil { - return 0, err +// AddOrUpdateTaskClass 为指定用户添加/更新任务类(防越权:更新时限定 user_id) +func (dao *TaskClassDAO) AddOrUpdateTaskClass(userID int, taskClass *model.TaskClass) (int, error) { + // 不信任入参里的 UserID,强制使用当前登录用户 + taskClass.UserID = &userID + + // 新增:ID == 0 直接插入 + if taskClass.ID == 0 { + if err := dao.db.Create(taskClass).Error; err != nil { + return 0, err + } + return taskClass.ID, nil + } + // 更新:必须同时匹配 id + user_id,否则不会更新任何行(避免覆盖他人数据) + tx := dao.db.Model(&model.TaskClass{}). + Where("id = ? AND user_id = ?", taskClass.ID, userID). + Updates(taskClass) + if tx.Error != nil { + return 0, tx.Error + } + if tx.RowsAffected == 0 { + // 未匹配到记录:要么不存在,要么不属于该用户 + return 0, respond.UserTaskClassForbidden } return taskClass.ID, nil } -func (dao *TaskClassDAO) AddTaskClassItems(items []model.TaskClassItem) error { - return dao.db.Create(&items).Error +func (dao *TaskClassDAO) AddOrUpdateTaskClassItems(userID int, items []model.TaskClassItem) error { + if len(items) == 0 { + return nil + } + + // 1) 校验这些 items 关联的 task_class(category_id)都属于当前用户 + categoryIDSet := make(map[int]struct{}, len(items)) + var categoryIDs []int + for _, it := range items { + if *it.CategoryID == 0 { + return gorm.ErrRecordNotFound + } + if _, ok := categoryIDSet[*it.CategoryID]; !ok { + categoryIDSet[*it.CategoryID] = struct{}{} + categoryIDs = append(categoryIDs, *it.CategoryID) + } + } + + var count int64 + if err := dao.db.Model(&model.TaskClass{}). + Where("id IN ? AND user_id = ?", categoryIDs, userID). + Count(&count).Error; err != nil { + return err + } + if count != int64(len(categoryIDs)) { + return respond.UserTaskClassForbidden + } + + // 2) 新增与更新分开处理:新增不受影响;更新时限定 category_id(防越权) + var toCreate []model.TaskClassItem + for _, it := range items { + if it.ID == 0 { + toCreate = append(toCreate, it) + continue + } + + tx := dao.db.Model(&model.TaskClassItem{}). + Where("id = ? AND category_id IN ?", it.ID, categoryIDs). + Updates(map[string]any{ + "category_id": it.CategoryID, + }) + if tx.Error != nil { + return tx.Error + } + if tx.RowsAffected == 0 { + return respond.UserTaskClassForbidden + } + } + + if len(toCreate) > 0 { + if err := dao.db.Create(&toCreate).Error; err != nil { + return err + } + } + return nil } // Transaction 在一个事务中执行传入的函数,供 service 层复用(自动提交/回滚) @@ -47,3 +120,21 @@ func (dao *TaskClassDAO) GetUserTaskClasses(userID int) ([]model.TaskClass, erro } return taskClasses, nil } + +// GetCompleteTaskClassByID 带着 ID 和 UserID 去取,防越权 +func (dao *TaskClassDAO) GetCompleteTaskClassByID(ctx context.Context, id int, userID int) (*model.TaskClass, error) { + var taskClass model.TaskClass + + // 1. 使用 Preload("Items") 自动执行两条 SQL 并组装 + // SQL A: SELECT * FROM task_classes WHERE id = ? AND user_id = ? + // SQL B: SELECT * FROM task_class_items WHERE category_id = (SQL A 的 ID) + err := dao.db.WithContext(ctx). + Preload("Items"). + Where("id = ? AND user_id = ?", id, userID). + First(&taskClass).Error + + if err != nil { + return nil, err + } + return &taskClass, nil +} diff --git a/backend/model/task-class.go b/backend/model/task-class.go index 1bd6efc..80cbde3 100644 --- a/backend/model/task-class.go +++ b/backend/model/task-class.go @@ -17,10 +17,11 @@ type TaskClass struct { StartDate *time.Time `gorm:"column:start_date"` EndDate *time.Time `gorm:"column:end_date"` //section 3 - TotalSlots *int `gorm:"column:total_slots;comment:分配的总节数"` - AllowFillerCourse *bool `gorm:"column:allow_filler_course;default:true"` - Strategy *string `gorm:"column:strategy;type:enum('steady','rapid')"` - ExcludedSlots *string `gorm:"column:excluded_slots;type:json;comment:不想要的时段切片"` + TotalSlots *int `gorm:"column:total_slots;comment:分配的总节数"` + AllowFillerCourse *bool `gorm:"column:allow_filler_course;default:true"` + Strategy *string `gorm:"column:strategy;type:enum('steady','rapid')"` + ExcludedSlots *string `gorm:"column:excluded_slots;type:json;comment:不想要的时段切片"` + Items []TaskClassItem `gorm:"foreignKey:CategoryID;references:ID"` // 一对多关联:一个 TaskClass 有多个 TaskClassItem } // TableName 设定 TaskClass 的表名为 task_classes diff --git a/backend/respond/respond.go b/backend/respond/respond.go index a212e14..9104a4f 100644 --- a/backend/respond/respond.go +++ b/backend/respond/respond.go @@ -138,4 +138,14 @@ var ( //请求相关的响应 Status: "40019", Info: "wrong course info", } + + UserTaskClassNotFound = Response{ //用户任务类未找到 + Status: "40020", + Info: "user task class not found", + } + + UserTaskClassForbidden = Response{ //用户任务类禁止访问 + Status: "40021", + Info: "user task class forbidden", + } ) diff --git a/backend/routers/routers.go b/backend/routers/routers.go index a5216b2..2c74740 100644 --- a/backend/routers/routers.go +++ b/backend/routers/routers.go @@ -65,6 +65,8 @@ func RegisterRouters(handlers *api.ApiHandlers, cache *dao.CacheDAO) *gin.Engine taskClassGroup.Use(middleware.JWTTokenAuth(cache)) taskClassGroup.POST("/add", handlers.TaskClassHandler.UserAddTaskClass) taskClassGroup.GET("/list", handlers.TaskClassHandler.UserGetTaskClassInfos) + taskClassGroup.GET("/get", handlers.TaskClassHandler.UserGetCompleteTaskClass) + taskClassGroup.PUT("/update", handlers.TaskClassHandler.UserUpdateTaskClass) } } // 初始化Gin引擎 diff --git a/backend/service/task-class.go b/backend/service/task-class.go index c67b855..6884d4e 100644 --- a/backend/service/task-class.go +++ b/backend/service/task-class.go @@ -24,16 +24,19 @@ func NewTaskClassService(taskClassRepo *dao.TaskClassDAO, cacheRepo *dao.CacheDA } } -// AddTaskClass 为指定用户添加任务类 -func (sv *TaskClassService) AddTaskClass(ctx context.Context, req *model.UserAddTaskClassRequest, userID int) error { +// AddOrUpdateTaskClass 为指定用户添加任务类 +func (sv *TaskClassService) AddOrUpdateTaskClass(ctx context.Context, req *model.UserAddTaskClassRequest, userID int, method int, targetTaskClassID int) error { // 1) 先写数据库(事务内) if err := sv.taskClassRepo.Transaction(func(txDAO *dao.TaskClassDAO) error { taskClass, items, err := conv.ProcessUserAddTaskClassRequest(req, userID) if err != nil { return err } + if method == 1 { // 更新操作 + taskClass.ID = targetTaskClassID + } - taskClassID, err := txDAO.AddTaskClass(taskClass) + taskClassID, err := txDAO.AddOrUpdateTaskClass(userID, taskClass) if err != nil { return err } @@ -41,7 +44,7 @@ func (sv *TaskClassService) AddTaskClass(ctx context.Context, req *model.UserAdd for i := range items { items[i].CategoryID = &taskClassID } - if err := txDAO.AddTaskClassItems(items); err != nil { + if err := txDAO.AddOrUpdateTaskClassItems(userID, items); err != nil { return err } return nil @@ -79,3 +82,17 @@ func (sv *TaskClassService) GetUserTaskClassInfos(ctx context.Context, userID in } return resp, nil } + +func (sv *TaskClassService) GetUserCompleteTaskClass(ctx context.Context, userID int, taskClassID int) (*model.UserAddTaskClassRequest, error) { + //1.查询数据库 + taskClass, err := sv.taskClassRepo.GetCompleteTaskClassByID(ctx, taskClassID, userID) + if err != nil { + return nil, err + } + //2.转换为响应结构体 + resp, err := conv.ProcessUserGetCompleteTaskClassRequest(taskClass) + if err != nil { + return nil, err + } + return resp, nil +}