后端: 1. Credit 价格规则补齐利润率与实际计费单价语义:新增 `profit_rate_bps` 与 `charge_*_price_micros` 展示字段,下沉共享价格推导 helper,tokenstore rpc/client/proto/model/default rule 全链路同步,LLM usage 扣费统一改按加价后的 charge 单价换算。 2. task-class 更新链路修正全量覆盖与归属校验:`runtime/conv` 保留 item id,DAO 更新前显式校验 task-class 与 item 归属,改用显式字段 map 落库 nil/空切片/零值,避免 `RowsAffected=0` 误判越权,同时补齐任务项可编辑字段更新。 3. GormCache task-class 失效补空 user_id 保护:更新语句缺少模型上下文时直接跳过失效,避免缓存插件因空指针影响主事务。 前端: 4. 课表中心补齐任务类编辑能力:新增 `updateTaskClass` API,创建弹窗支持编辑态回填与 item id 提交,日程页支持先拉详情再编辑并在保存后刷新任务类详情与列表。 5. 计划广场详情补点赞交互与奖励提示:详情页新增点赞/取消点赞按钮、奖励反馈文案与计数展示,论坛类型补 `reward_hint`,评论区与帖子作者头像统一接入兜底头像工具。 6. 品牌与展示细节收口:侧边栏与 favicon 切到项目 logo,首页标题改为 `SmartMate`,主面板缩放上限微调,论坛列表头像显示与整体品牌观感同步统一。
456 lines
14 KiB
Go
456 lines
14 KiB
Go
package dao
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"time"
|
||
|
||
"github.com/LoveLosita/smartflow/backend/services/runtime/model"
|
||
"github.com/LoveLosita/smartflow/backend/shared/respond"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type TaskClassDAO struct {
|
||
// 这是一个口袋,用来装数据库连接实例
|
||
db *gorm.DB
|
||
}
|
||
|
||
// NewTaskClassDAO 创建TaskClassDAO实例
|
||
// NewTaskClassDAO 接收一个 *gorm.DB,并把它塞进结构体的口袋里
|
||
func NewTaskClassDAO(db *gorm.DB) *TaskClassDAO {
|
||
return &TaskClassDAO{
|
||
db: db,
|
||
}
|
||
}
|
||
|
||
func (dao *TaskClassDAO) WithTx(tx *gorm.DB) *TaskClassDAO {
|
||
return &TaskClassDAO{
|
||
db: tx,
|
||
}
|
||
}
|
||
|
||
// AddOrUpdateTaskClass 为指定用户添加/更新任务类(防越权:更新时限定 user_id)
|
||
func (dao *TaskClassDAO) AddOrUpdateTaskClass(userID int, taskClass *model.TaskClass) (int, error) {
|
||
// 不信任入参里的 UserID,强制使用当前登录用户
|
||
taskClass.UserID = &userID
|
||
|
||
// 新增:ID == 0 直接插入
|
||
if taskClass.ID == 0 {
|
||
if err := dao.db.Create(taskClass).Error; err != nil {
|
||
return 0, err
|
||
}
|
||
return taskClass.ID, nil
|
||
}
|
||
|
||
// 1. 先显式校验任务类归属,避免“更新值与库内完全相同”时 RowsAffected=0 被误判成越权。
|
||
// 2. 这里只负责校验 id/user_id 是否匹配,不负责判断具体字段有没有变化。
|
||
// 3. 若确实不存在或不属于当前用户,统一返回 UserTaskClassForbidden。
|
||
if err := dao.ensureTaskClassOwned(userID, taskClass.ID); err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
// 1. 更新语义是“前端提交什么,就以什么覆盖数据库当前值”。
|
||
// 2. 因此这里不能再直接用 struct Updates,否则 nil/空切片 等零值字段会被 GORM 跳过,刷新后看起来像“没更新”。
|
||
// 3. 统一改成显式字段映射,保证可选字段清空、排除数组清空都能真正落库。
|
||
tx := dao.db.Model(&model.TaskClass{UserID: &userID}).
|
||
Where("id = ? AND user_id = ?", taskClass.ID, userID).
|
||
Updates(buildTaskClassUpdateMap(taskClass))
|
||
if tx.Error != nil {
|
||
return 0, tx.Error
|
||
}
|
||
return taskClass.ID, nil
|
||
}
|
||
|
||
func (dao *TaskClassDAO) AddOrUpdateTaskClassItems(userID int, items []model.TaskClassItem) error {
|
||
if len(items) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 1) 校验这些 items 关联的 task_class(category_id)都属于当前用户
|
||
categoryIDSet := make(map[int]struct{}, len(items))
|
||
var categoryIDs []int
|
||
for _, it := range items {
|
||
if *it.CategoryID == 0 {
|
||
return gorm.ErrRecordNotFound
|
||
}
|
||
if _, ok := categoryIDSet[*it.CategoryID]; !ok {
|
||
categoryIDSet[*it.CategoryID] = struct{}{}
|
||
categoryIDs = append(categoryIDs, *it.CategoryID)
|
||
}
|
||
}
|
||
|
||
var count int64
|
||
if err := dao.db.Model(&model.TaskClass{}).
|
||
Where("id IN ? AND user_id = ?", categoryIDs, userID).
|
||
Count(&count).Error; err != nil {
|
||
return err
|
||
}
|
||
if count != int64(len(categoryIDs)) {
|
||
return respond.UserTaskClassForbidden
|
||
}
|
||
|
||
// 2. 收集本次要更新的已有 item,先做一次归属校验。
|
||
// 2.1 这里单独校验是为了避免“只更新成原值”时 RowsAffected=0 被误判为越权。
|
||
// 2.2 校验通过后,后面的 UPDATE 只关心数据库错误,不再用 RowsAffected 判权限。
|
||
// 2.3 若请求里混入了不属于这些 task_class 的 item_id,统一返回 UserTaskClassForbidden。
|
||
existingItemIDs := make([]int, 0, len(items))
|
||
existingItemIDSet := make(map[int]struct{}, len(items))
|
||
for _, it := range items {
|
||
if it.ID == 0 {
|
||
continue
|
||
}
|
||
if _, exists := existingItemIDSet[it.ID]; exists {
|
||
continue
|
||
}
|
||
existingItemIDSet[it.ID] = struct{}{}
|
||
existingItemIDs = append(existingItemIDs, it.ID)
|
||
}
|
||
if err := dao.ensureTaskClassItemsOwnedByCategories(existingItemIDs, categoryIDs); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 3) 新增与更新分开处理:新增直接插入;已有 item 按现有契约更新可编辑字段。
|
||
var toCreate []model.TaskClassItem
|
||
for _, it := range items {
|
||
if it.ID == 0 {
|
||
toCreate = append(toCreate, it)
|
||
continue
|
||
}
|
||
|
||
tx := dao.db.Model(&model.TaskClassItem{}).
|
||
Where("id = ? AND category_id IN ?", it.ID, categoryIDs).
|
||
Updates(map[string]any{
|
||
"category_id": it.CategoryID,
|
||
"order": it.Order,
|
||
"content": it.Content,
|
||
"embedded_time": it.EmbeddedTime,
|
||
})
|
||
if tx.Error != nil {
|
||
return tx.Error
|
||
}
|
||
}
|
||
|
||
if len(toCreate) > 0 {
|
||
if err := dao.db.Create(&toCreate).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ensureTaskClassOwned 只负责校验 task_class 是否属于当前用户。
|
||
func (dao *TaskClassDAO) ensureTaskClassOwned(userID int, taskClassID int) error {
|
||
var count int64
|
||
if err := dao.db.Model(&model.TaskClass{}).
|
||
Where("id = ? AND user_id = ?", taskClassID, userID).
|
||
Count(&count).Error; err != nil {
|
||
return err
|
||
}
|
||
if count == 0 {
|
||
return respond.UserTaskClassForbidden
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ensureTaskClassItemsOwnedByCategories 只负责校验一批 item 是否都挂在允许的 task_class 下。
|
||
func (dao *TaskClassDAO) ensureTaskClassItemsOwnedByCategories(itemIDs []int, categoryIDs []int) error {
|
||
if len(itemIDs) == 0 {
|
||
return nil
|
||
}
|
||
|
||
var count int64
|
||
if err := dao.db.Model(&model.TaskClassItem{}).
|
||
Where("id IN ? AND category_id IN ?", itemIDs, categoryIDs).
|
||
Count(&count).Error; err != nil {
|
||
return err
|
||
}
|
||
if count != int64(len(itemIDs)) {
|
||
return respond.UserTaskClassForbidden
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// buildTaskClassUpdateMap 负责把“全量更新”请求转换成显式列更新。
|
||
func buildTaskClassUpdateMap(taskClass *model.TaskClass) map[string]any {
|
||
return map[string]any{
|
||
"name": nullableStringValue(taskClass.Name),
|
||
"mode": nullableStringValue(taskClass.Mode),
|
||
"start_date": nullableTimeValue(taskClass.StartDate),
|
||
"end_date": nullableTimeValue(taskClass.EndDate),
|
||
"subject_type": nullableStringValue(taskClass.SubjectType),
|
||
"difficulty_level": nullableStringValue(taskClass.DifficultyLevel),
|
||
"cognitive_intensity": nullableStringValue(taskClass.CognitiveIntensity),
|
||
"total_slots": nullableIntValue(taskClass.TotalSlots),
|
||
"allow_filler_course": nullableBoolValue(taskClass.AllowFillerCourse),
|
||
"strategy": nullableStringValue(taskClass.Strategy),
|
||
"excluded_slots": taskClass.ExcludedSlots,
|
||
"excluded_days_of_week": taskClass.ExcludedDaysOfWeek,
|
||
}
|
||
}
|
||
|
||
func nullableStringValue(value *string) any {
|
||
if value == nil {
|
||
return nil
|
||
}
|
||
return *value
|
||
}
|
||
|
||
func nullableIntValue(value *int) any {
|
||
if value == nil {
|
||
return nil
|
||
}
|
||
return *value
|
||
}
|
||
|
||
func nullableBoolValue(value *bool) any {
|
||
if value == nil {
|
||
return nil
|
||
}
|
||
return *value
|
||
}
|
||
|
||
func nullableTimeValue(value *time.Time) any {
|
||
if value == nil {
|
||
return nil
|
||
}
|
||
return *value
|
||
}
|
||
|
||
// Transaction 在一个事务中执行传入的函数,供 service 层复用(自动提交/回滚)
|
||
// 规则:fn 返回 nil -> commit;fn 返回 error 或 panic -> rollback
|
||
func (dao *TaskClassDAO) Transaction(fn func(txDAO *TaskClassDAO) error) error {
|
||
return dao.db.Transaction(func(tx *gorm.DB) error {
|
||
return fn(NewTaskClassDAO(tx))
|
||
})
|
||
}
|
||
|
||
func (dao *TaskClassDAO) GetUserTaskClasses(userID int) ([]model.TaskClass, error) {
|
||
var taskClasses []model.TaskClass
|
||
err := dao.db.Where("user_id = ?", userID).Find(&taskClasses).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return taskClasses, nil
|
||
}
|
||
|
||
// GetCompleteTaskClassByID 带着 ID 和 UserID 去取,防越权
|
||
func (dao *TaskClassDAO) GetCompleteTaskClassByID(ctx context.Context, id int, userID int) (*model.TaskClass, error) {
|
||
var taskClass model.TaskClass
|
||
|
||
// 1. 使用 Preload("Items") 自动执行两条 SQL 并组装
|
||
// SQL A: SELECT * FROM task_classes WHERE id = ? AND user_id = ?
|
||
// SQL B: SELECT * FROM task_class_items WHERE category_id = (SQL A 的 ID)
|
||
err := dao.db.WithContext(ctx).
|
||
Preload("Items").
|
||
Where("id = ? AND user_id = ?", id, userID).
|
||
First(&taskClass).Error
|
||
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &taskClass, nil
|
||
}
|
||
|
||
// GetCompleteTaskClassesByIDs 批量获取“完整任务类”(含 Items)。
|
||
//
|
||
// 职责边界:
|
||
// 1. 负责按 user_id + ids 过滤,保证数据归属安全;
|
||
// 2. 负责预加载 Items,供智能粗排直接使用;
|
||
// 3. 不负责排序策略,返回结果顺序由 service 层决定;
|
||
// 4. 若存在任一 id 不存在或不属于该用户,返回 WrongTaskClassID。
|
||
func (dao *TaskClassDAO) GetCompleteTaskClassesByIDs(ctx context.Context, userID int, ids []int) ([]model.TaskClass, error) {
|
||
if len(ids) == 0 {
|
||
return []model.TaskClass{}, nil
|
||
}
|
||
|
||
// 1. 先做去重与合法值过滤,避免无效 ID 放大数据库压力。
|
||
uniqueIDs := make([]int, 0, len(ids))
|
||
seen := make(map[int]struct{}, len(ids))
|
||
for _, id := range ids {
|
||
if id <= 0 {
|
||
continue
|
||
}
|
||
if _, exists := seen[id]; exists {
|
||
continue
|
||
}
|
||
seen[id] = struct{}{}
|
||
uniqueIDs = append(uniqueIDs, id)
|
||
}
|
||
if len(uniqueIDs) == 0 {
|
||
return nil, respond.WrongTaskClassID
|
||
}
|
||
|
||
// 2. 批量查询并预加载任务项。
|
||
var taskClasses []model.TaskClass
|
||
err := dao.db.WithContext(ctx).
|
||
Preload("Items").
|
||
Where("user_id = ? AND id IN ?", userID, uniqueIDs).
|
||
Find(&taskClasses).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 3. 数量校验:少一条都视为“存在非法/越权 ID”,统一按业务错误返回。
|
||
if len(taskClasses) != len(uniqueIDs) {
|
||
return nil, respond.WrongTaskClassID
|
||
}
|
||
return taskClasses, nil
|
||
}
|
||
|
||
func (dao *TaskClassDAO) GetTaskClassItemByID(ctx context.Context, id int) (*model.TaskClassItem, error) {
|
||
var item model.TaskClassItem
|
||
err := dao.db.WithContext(ctx).
|
||
Where("id = ?", id).
|
||
First(&item).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &item, nil
|
||
}
|
||
|
||
func (dao *TaskClassDAO) GetTaskClassIDByTaskItemID(ctx context.Context, itemID int) (int, error) {
|
||
var item model.TaskClassItem
|
||
res := dao.db.WithContext(ctx).
|
||
Select("category_id").
|
||
Where("id = ?", itemID).
|
||
First(&item)
|
||
if res.Error != nil {
|
||
if errors.Is(res.Error, gorm.ErrRecordNotFound) {
|
||
return 0, respond.TaskClassItemNotFound
|
||
}
|
||
return 0, res.Error
|
||
}
|
||
return *item.CategoryID, nil
|
||
}
|
||
|
||
func (dao *TaskClassDAO) GetTaskClassUserIDByID(ctx context.Context, taskClassID int) (int, error) {
|
||
var taskClass model.TaskClass
|
||
err := dao.db.WithContext(ctx).
|
||
Select("user_id").
|
||
Where("id = ?", taskClassID).
|
||
First(&taskClass).Error
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return *taskClass.UserID, nil
|
||
}
|
||
|
||
func (dao *TaskClassDAO) UpdateTaskClassItemEmbeddedTime(ctx context.Context, taskID int, embeddedTime *model.TargetTime) error {
|
||
err := dao.db.WithContext(ctx).
|
||
Model(&model.TaskClassItem{}).
|
||
Where("id = ?", taskID).
|
||
Update("embedded_time", embeddedTime).Error
|
||
return err
|
||
}
|
||
|
||
func (dao *TaskClassDAO) DeleteTaskClassItemEmbeddedTime(ctx context.Context, taskID int) error {
|
||
err := dao.db.WithContext(ctx).
|
||
Model(&model.TaskClassItem{}).
|
||
Where("id = ?", taskID).
|
||
Update("embedded_time", nil).Error
|
||
return err
|
||
}
|
||
|
||
func (dao *TaskClassDAO) IfTaskClassItemArranged(ctx context.Context, taskID int) (bool, error) {
|
||
var item model.TaskClassItem
|
||
err := dao.db.WithContext(ctx).
|
||
Select("embedded_time").
|
||
Where("id = ?", taskID).
|
||
First(&item).Error
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return item.EmbeddedTime != nil, nil
|
||
}
|
||
|
||
func (dao *TaskClassDAO) BatchCheckIfTaskClassItemsArranged(ctx context.Context, itemIDs []int) (bool, error) {
|
||
if len(itemIDs) == 0 {
|
||
return false, nil
|
||
}
|
||
var count int64
|
||
err := dao.db.WithContext(ctx).
|
||
Model(&model.TaskClassItem{}).
|
||
Where("id IN ? AND embedded_time IS NOT NULL", itemIDs).
|
||
Count(&count).Error
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return count > 0, nil
|
||
}
|
||
|
||
func (dao *TaskClassDAO) DeleteTaskClassItemByID(ctx context.Context, id int) error {
|
||
err := dao.db.WithContext(ctx).
|
||
Where("id = ?", id).
|
||
Delete(&model.TaskClassItem{}).Error
|
||
return err
|
||
}
|
||
|
||
func (dao *TaskClassDAO) DeleteTaskClassByID(ctx context.Context, id int, userID int) error {
|
||
// 1. 删除时显式把 user_id 挂到 Model 上,供 GORM 缓存失效插件读取。
|
||
// 2. 业务层已经完成归属校验,这里仍带上 user_id 条件,避免极端并发下误删其它用户数据。
|
||
// 3. 若仍存在 task_items 外键依赖,GORM 会返回数据库错误并回滚,本函数不吞掉该错误。
|
||
res := dao.db.WithContext(ctx).
|
||
Model(&model.TaskClass{UserID: &userID}).
|
||
Where("id = ? AND user_id = ?", id, userID).
|
||
Delete(&model.TaskClass{})
|
||
if res.Error != nil {
|
||
return res.Error
|
||
}
|
||
if res.RowsAffected == 0 {
|
||
return respond.WrongTaskClassID
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (dao *TaskClassDAO) BatchUpdateTaskClassItemEmbeddedTime(ctx context.Context, itemIDs []int, updates []*model.TargetTime) error {
|
||
if len(itemIDs) == 0 {
|
||
return nil
|
||
}
|
||
if len(itemIDs) != len(updates) {
|
||
return errors.New("itemIDs length mismatch updates length")
|
||
}
|
||
|
||
// 单条 SQL 批量更新:UPDATE ... SET embedded_time = CASE id WHEN ? THEN ? ... END WHERE id IN (?)
|
||
caseSQL := "CASE id"
|
||
args := make([]any, 0, len(itemIDs)*2)
|
||
for i, id := range itemIDs {
|
||
caseSQL += " WHEN ? THEN ?"
|
||
args = append(args, id, updates[i])
|
||
}
|
||
caseSQL += " END"
|
||
|
||
res := dao.db.WithContext(ctx).
|
||
Model(&model.TaskClassItem{}).
|
||
Where("id IN ?", itemIDs).
|
||
Update("embedded_time", gorm.Expr(caseSQL, args...))
|
||
|
||
return res.Error
|
||
}
|
||
|
||
func (dao *TaskClassDAO) ValidateTaskItemIDsBelongToTaskClass(ctx context.Context, taskClassID int, itemIDs []int) (bool, error) {
|
||
if len(itemIDs) == 0 {
|
||
return true, nil
|
||
}
|
||
|
||
var count int64
|
||
err := dao.db.WithContext(ctx).
|
||
Model(&model.TaskClassItem{}).
|
||
Where("id IN ? AND category_id = ?", itemIDs, taskClassID).
|
||
Count(&count).Error
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return count == int64(len(itemIDs)), nil
|
||
}
|
||
|
||
func (dao *TaskClassDAO) GetTaskClassItemsByIDs(ctx context.Context, itemIDs []int) ([]model.TaskClassItem, error) {
|
||
var items []model.TaskClassItem
|
||
err := dao.db.WithContext(ctx).
|
||
Where("id IN ?", itemIDs).
|
||
Find(&items).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return items, nil
|
||
}
|