Files
smartmate/backend/dao/task-class.go
LoveLosita e5a4114202 Version: 0.2.5.dev.260211
feat: 🗑️ 新增删除任务类接口并实现级联删除

- 通过 task_class 与 task_item 两张表建立级联关系 🔗
- 删除 task_class 时自动删除关联的 task_item
- 保证数据一致性,避免产生孤立数据 
2026-02-11 18:44:13 +08:00

234 lines
6.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dao
import (
"context"
"errors"
"github.com/LoveLosita/smartflow/backend/model"
"github.com/LoveLosita/smartflow/backend/respond"
"gorm.io/gorm"
)
// TaskClassDAO is the data-access object for task classes (model.TaskClass)
// and their items (model.TaskClassItem).
type TaskClassDAO struct {
	db *gorm.DB // handle every query goes through; a transaction handle when built by WithTx/Transaction
}
// NewTaskClassDAO builds a TaskClassDAO whose queries all run on the given
// *gorm.DB handle.
func NewTaskClassDAO(db *gorm.DB) *TaskClassDAO {
	return &TaskClassDAO{db: db}
}
// WithTx returns a new TaskClassDAO whose queries run on tx (for example an
// open transaction). The receiver itself is left unchanged.
func (dao *TaskClassDAO) WithTx(tx *gorm.DB) *TaskClassDAO {
	txDAO := &TaskClassDAO{}
	txDAO.db = tx
	return txDAO
}
// AddOrUpdateTaskClass inserts or updates a task class on behalf of userID and
// returns the class ID.
//
// The UserID carried by taskClass is never trusted: it is overwritten with
// userID before anything is written. When taskClass.ID is 0 a new row is
// inserted. Otherwise the update is constrained to "id = ? AND user_id = ?" so
// one user cannot overwrite another user's class. If that update affects no
// row, respond.UserTaskClassForbidden is returned: the row either does not
// exist or belongs to someone else.
//
// NOTE(review): this uses Updates with a struct, which in GORM skips
// zero-valued fields. On MySQL without clientFoundRows, RowsAffected counts
// only rows whose values actually changed, so re-saving identical data may be
// reported as forbidden. Confirm the driver settings.
//
// NOTE(review): model.TaskClass has an Items association (it is preloaded in
// GetCompleteTaskClassByID). GORM upserts associations on Create/Updates by
// default, so any Items set on taskClass would be written here without
// per-item ownership checks. Confirm that callers leave Items empty, or add
// Omit.
func (dao *TaskClassDAO) AddOrUpdateTaskClass(userID int, taskClass *model.TaskClass) (int, error) {
	// Pin ownership to the authenticated caller, whatever the payload said.
	taskClass.UserID = &userID

	if taskClass.ID != 0 {
		res := dao.db.Model(&model.TaskClass{}).
			Where("id = ? AND user_id = ?", taskClass.ID, userID).
			Updates(taskClass)
		switch {
		case res.Error != nil:
			return 0, res.Error
		case res.RowsAffected == 0:
			// Nothing matched: missing row, or a row owned by another user.
			return 0, respond.UserTaskClassForbidden
		}
		return taskClass.ID, nil
	}

	if err := dao.db.Create(taskClass).Error; err != nil {
		return 0, err
	}
	return taskClass.ID, nil
}
// AddOrUpdateTaskClassItems inserts or updates a batch of task-class items on
// behalf of userID and refuses to touch data the user does not own.
//
// Rules:
//   - every item must carry a non-nil, non-zero CategoryID; otherwise
//     gorm.ErrRecordNotFound is returned;
//   - every referenced task class (CategoryID) must belong to userID;
//     otherwise respond.UserTaskClassForbidden is returned;
//   - items with ID == 0 are inserted in one batch;
//   - items with a non-zero ID must already exist under one of the referenced
//     task classes; otherwise respond.UserTaskClassForbidden is returned. For
//     these items only category_id is written.
//
// All ownership checks run before any row is written. The statements are not
// wrapped in a transaction here. Run this through Transaction/WithTx if the
// whole batch must be applied atomically.
func (dao *TaskClassDAO) AddOrUpdateTaskClassItems(userID int, items []model.TaskClassItem) error {
	if len(items) == 0 {
		return nil
	}

	// 1) Collect the distinct task classes referenced by the batch. A nil
	// CategoryID is treated like 0 (missing category) rather than dereferenced.
	categoryIDSet := make(map[int]struct{}, len(items))
	categoryIDs := make([]int, 0, len(items))
	for _, it := range items {
		if it.CategoryID == nil || *it.CategoryID == 0 {
			return gorm.ErrRecordNotFound
		}
		if _, ok := categoryIDSet[*it.CategoryID]; !ok {
			categoryIDSet[*it.CategoryID] = struct{}{}
			categoryIDs = append(categoryIDs, *it.CategoryID)
		}
	}

	// 2) Every referenced task class must belong to the current user.
	var ownedClasses int64
	if err := dao.db.Model(&model.TaskClass{}).
		Where("id IN ? AND user_id = ?", categoryIDs, userID).
		Count(&ownedClasses).Error; err != nil {
		return err
	}
	if ownedClasses != int64(len(categoryIDs)) {
		return respond.UserTaskClassForbidden
	}

	// 3) Split the batch into inserts and updates, de-duplicating update IDs
	// so the ownership COUNT below can be compared against a distinct set.
	var toCreate, toUpdate []model.TaskClassItem
	updateIDSet := make(map[int]struct{}, len(items))
	for _, it := range items {
		if it.ID == 0 {
			toCreate = append(toCreate, it)
			continue
		}
		toUpdate = append(toUpdate, it)
		updateIDSet[it.ID] = struct{}{}
	}

	// 4) Existing items must already live under one of the user's classes.
	// This is checked with a COUNT up front instead of RowsAffected of the
	// UPDATE, for two reasons. First, some drivers (e.g. MySQL without
	// clientFoundRows) report 0 affected rows when category_id is unchanged,
	// which would wrongly surface as "forbidden". Second, nothing is written
	// before a bad item in the batch is detected.
	if len(updateIDSet) > 0 {
		updateIDs := make([]int, 0, len(updateIDSet))
		for id := range updateIDSet {
			updateIDs = append(updateIDs, id)
		}
		var ownedItems int64
		if err := dao.db.Model(&model.TaskClassItem{}).
			Where("id IN ? AND category_id IN ?", updateIDs, categoryIDs).
			Count(&ownedItems).Error; err != nil {
			return err
		}
		if ownedItems != int64(len(updateIDs)) {
			return respond.UserTaskClassForbidden
		}
	}

	for _, it := range toUpdate {
		// Keep the category_id filter so the UPDATE itself can never cross
		// ownership, even if rows changed since the check above.
		if err := dao.db.Model(&model.TaskClassItem{}).
			Where("id = ? AND category_id IN ?", it.ID, categoryIDs).
			Updates(map[string]any{
				"category_id": it.CategoryID,
			}).Error; err != nil {
			return err
		}
	}

	if len(toCreate) > 0 {
		if err := dao.db.Create(&toCreate).Error; err != nil {
			return err
		}
	}
	return nil
}
// Transaction runs fn inside one database transaction so the service layer
// can group several DAO calls. fn receives a TaskClassDAO bound to the
// transaction handle. The transaction commits when fn returns nil, and rolls
// back when fn returns an error or panics.
func (dao *TaskClassDAO) Transaction(fn func(txDAO *TaskClassDAO) error) error {
	run := func(tx *gorm.DB) error {
		return fn(&TaskClassDAO{db: tx})
	}
	return dao.db.Transaction(run)
}
// GetUserTaskClasses returns every task class owned by userID. Items are not
// preloaded. On error the returned slice is nil.
//
// NOTE(review): unlike most methods in this DAO, this one takes no
// context.Context, so the query cannot be cancelled by the caller.
func (dao *TaskClassDAO) GetUserTaskClasses(userID int) ([]model.TaskClass, error) {
	var classes []model.TaskClass
	if err := dao.db.Where("user_id = ?", userID).Find(&classes).Error; err != nil {
		return nil, err
	}
	return classes, nil
}
// GetCompleteTaskClassByID loads the task class with the given id, together
// with its Items. It filters on both id and userID, so a class owned by
// another user is never returned. Preload issues two queries:
//
//	SELECT * FROM task_classes WHERE id = ? AND user_id = ?
//	SELECT * FROM task_class_items WHERE category_id = <id found by the first query>
//
// A missing or foreign class surfaces as the error returned by First.
func (dao *TaskClassDAO) GetCompleteTaskClassByID(ctx context.Context, id int, userID int) (*model.TaskClass, error) {
	query := dao.db.WithContext(ctx).
		Preload("Items").
		Where("id = ? AND user_id = ?", id, userID)

	taskClass := new(model.TaskClass)
	if err := query.First(taskClass).Error; err != nil {
		return nil, err
	}
	return taskClass, nil
}
// GetTaskClassItemByID fetches a single task-class item by its ID. A missing
// row is returned as the error produced by First.
//
// NOTE(review): there is no ownership filter here. Presumably the caller
// verifies that the item belongs to the current user; confirm at call sites.
func (dao *TaskClassDAO) GetTaskClassItemByID(ctx context.Context, id int) (*model.TaskClassItem, error) {
	item := &model.TaskClassItem{}
	if err := dao.db.WithContext(ctx).Where("id = ?", id).First(item).Error; err != nil {
		return nil, err
	}
	return item, nil
}
// GetTaskClassIDByTaskItemID returns the task class (category_id) that the
// item itemID belongs to. Only the category_id column is read.
//
// It returns respond.TaskClassItemNotFound in two cases: no item with that ID
// exists, or the item's category_id is NULL. An item without a category
// cannot be attributed to any task class, and dereferencing the nil pointer
// would otherwise panic. Other database errors are returned unchanged.
func (dao *TaskClassDAO) GetTaskClassIDByTaskItemID(ctx context.Context, itemID int) (int, error) {
	var item model.TaskClassItem
	res := dao.db.WithContext(ctx).
		Select("category_id").
		Where("id = ?", itemID).
		First(&item)
	if res.Error != nil {
		if errors.Is(res.Error, gorm.ErrRecordNotFound) {
			return 0, respond.TaskClassItemNotFound
		}
		return 0, res.Error
	}
	if item.CategoryID == nil {
		return 0, respond.TaskClassItemNotFound
	}
	return *item.CategoryID, nil
}
// GetTaskClassUserIDByID returns the owner (user_id) of the task class with
// the given ID. Only the user_id column is read.
//
// A missing class is reported with the error returned by First, unchanged, so
// existing callers that check for it keep working. A class whose user_id is
// NULL has no owner and therefore cannot belong to the caller. It is reported
// as respond.UserTaskClassForbidden rather than panicking on a nil pointer
// dereference.
func (dao *TaskClassDAO) GetTaskClassUserIDByID(ctx context.Context, taskClassID int) (int, error) {
	var taskClass model.TaskClass
	err := dao.db.WithContext(ctx).
		Select("user_id").
		Where("id = ?", taskClassID).
		First(&taskClass).Error
	if err != nil {
		return 0, err
	}
	if taskClass.UserID == nil {
		return 0, respond.UserTaskClassForbidden
	}
	return *taskClass.UserID, nil
}
// UpdateTaskClassItemEmbeddedTime writes embeddedTime into the embedded_time
// column of the item whose ID is taskID. Whether a row matched is not checked,
// so an unknown taskID is a silent no-op.
func (dao *TaskClassDAO) UpdateTaskClassItemEmbeddedTime(ctx context.Context, taskID int, embeddedTime *model.TargetTime) error {
	return dao.db.WithContext(ctx).
		Model(&model.TaskClassItem{}).
		Where("id = ?", taskID).
		Update("embedded_time", embeddedTime).
		Error
}
// DeleteTaskClassItemEmbeddedTime clears the embedded_time column (sets it to
// NULL) for the item whose ID is taskID. An unknown taskID is a silent no-op.
func (dao *TaskClassDAO) DeleteTaskClassItemEmbeddedTime(ctx context.Context, taskID int) error {
	q := dao.db.WithContext(ctx).Model(&model.TaskClassItem{}).Where("id = ?", taskID)
	return q.Update("embedded_time", nil).Error
}
// IfTaskClassItemArranged reports whether the item whose ID is taskID has an
// embedded_time set (non-NULL). Only that column is read. A missing item is
// returned as the error produced by First, together with false.
func (dao *TaskClassDAO) IfTaskClassItemArranged(ctx context.Context, taskID int) (bool, error) {
	var item model.TaskClassItem
	if err := dao.db.WithContext(ctx).
		Select("embedded_time").
		Where("id = ?", taskID).
		First(&item).Error; err != nil {
		return false, err
	}
	arranged := item.EmbeddedTime != nil
	return arranged, nil
}
// DeleteTaskClassItemByID deletes the item whose ID is id. Whether a row was
// deleted is not checked, so a non-existent ID is not an error. No ownership
// filter is applied here.
//
// NOTE(review): if model.TaskClassItem embeds gorm.DeletedAt, GORM turns this
// into a soft delete (an UPDATE). Confirm that this is the intended behaviour.
func (dao *TaskClassDAO) DeleteTaskClassItemByID(ctx context.Context, id int) error {
	res := dao.db.WithContext(ctx).Where("id = ?", id).Delete(&model.TaskClassItem{})
	return res.Error
}
// DeleteTaskClassByID deletes the task class whose ID is id. It returns
// respond.WrongTaskClassID when no row was deleted. Only one DELETE is issued
// here; removing the class's items presumably relies on a database-level
// cascade between task_class and task_item.
//
// NOTE(review): there is no user_id filter, so the caller is expected to
// verify ownership first. If model.TaskClass embeds gorm.DeletedAt, GORM
// performs a soft delete (an UPDATE) and a DB-level ON DELETE CASCADE will not
// fire. Confirm.
func (dao *TaskClassDAO) DeleteTaskClassByID(ctx context.Context, id int) error {
	result := dao.db.WithContext(ctx).Where("id = ?", id).Delete(&model.TaskClass{})
	switch {
	case result.Error != nil:
		return result.Error
	case result.RowsAffected == 0:
		return respond.WrongTaskClassID
	default:
		return nil
	}
}