From 23343a8558cd5e98e1738c3bc0c549690cbc866a Mon Sep 17 00:00:00 2001 From: huang <1724659546@qq.com> Date: Tue, 16 Sep 2025 23:31:36 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E5=86=99task=E8=B0=83=E5=BA=A6?= =?UTF-8?q?=E5=99=A8=E4=B9=8B=E7=AE=80=E5=8D=95=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/infra/task/task.go | 323 ++++++++++++++----------------- internal/infra/task/task_test.go | 231 ---------------------- 2 files changed, 144 insertions(+), 410 deletions(-) diff --git a/internal/infra/task/task.go b/internal/infra/task/task.go index 9ae418a..cc6605c 100644 --- a/internal/infra/task/task.go +++ b/internal/infra/task/task.go @@ -1,223 +1,188 @@ package task import ( - "container/heap" "context" + "errors" "sync" "time" - "git.huangwc.com/pig/pig-farm-controller/internal/infra/logs" + "git.huangwc.com/pig/pig-farm-controller/internal/infra/models" + "git.huangwc.com/pig/pig-farm-controller/internal/infra/repository" + "gorm.io/gorm" ) -// Task 代表一个任务接口 -// 所有任务都需要实现此接口 -type Task interface { - // Execute 执行任务 - Execute() error - - // GetID 获取任务ID - GetID() string - - // GetPriority 获取任务优先级 - GetPriority() int - - // IsDone 检查任务是否已完成 - IsDone() bool - - // GetDescription 获取任务说明 - GetDescription() string +// Logger 定义了调度器期望的日志接口,方便替换为项目中的日志组件 +type Logger interface { + Printf(format string, v ...interface{}) } -// taskItem 任务队列中的元素 -type taskItem struct { - task Task - priority int - index int +// ProgressTracker 在内存中跟踪正在运行的计划的完成进度 +type ProgressTracker struct { + mu sync.Mutex + totalTasks map[uint]int // key: planExecutionLogID, value: total tasks + completedTasks map[uint]int // key: planExecutionLogID, value: completed tasks } -// Queue 代表任务队列 -type Queue struct { - // queue 任务队列(按优先级排序) - queue *priorityQueue - - // mutex 互斥锁 - mutex sync.Mutex - - // logger 日志记录器 - logger *logs.Logger -} - -// NewQueue 创建并返回一个新的任务队列实例。 -func NewQueue(logger *logs.Logger) *Queue { - pq := make(priorityQueue, 0) - heap.Init(&pq) - - return &Queue{ - queue: &pq, - logger: logger, +func NewProgressTracker() *ProgressTracker { + return &ProgressTracker{ + totalTasks: make(map[uint]int), + completedTasks: make(map[uint]int), } } -// AddTask 向队列中添加任务 -func (q *Queue) AddTask(task Task) { - q.mutex.Lock() - defer q.mutex.Unlock() - - item := &taskItem{ - task: task, - priority: task.GetPriority(), - } - heap.Push(q.queue, item) - q.logger.Infow("任务已添加到队列", "任务ID", task.GetID(), "任务描述", task.GetDescription()) +// StartTracking 开始跟踪一个新的计划执行 +func (t *ProgressTracker) StartTracking(planLogID uint, total int) { + t.mu.Lock() + defer t.mu.Unlock() + t.totalTasks[planLogID] = total + t.completedTasks[planLogID] = 0 } -// GetNextTask 获取下一个要执行的任务(优先级最高的任务) -func (q *Queue) GetNextTask() Task { - q.mutex.Lock() - defer q.mutex.Unlock() - - if q.queue.Len() == 0 { - return nil - } - - item := heap.Pop(q.queue).(*taskItem) - q.logger.Infow("从队列中获取任务", "任务ID", item.task.GetID(), "任务描述", item.task.GetDescription()) - return item.task +// Increment 将指定计划的完成计数加一 +func (t *ProgressTracker) Increment(planLogID uint) { + t.mu.Lock() + defer t.mu.Unlock() + t.completedTasks[planLogID]++ } -// GetTaskCount 获取队列中的任务数量 -func (q *Queue) GetTaskCount() int { - q.mutex.Lock() - defer q.mutex.Unlock() - - return q.queue.Len() +// IsComplete 检查指定计划是否已完成所有任务 +func (t *ProgressTracker) IsComplete(planLogID uint) bool { + t.mu.Lock() + defer t.mu.Unlock() + return t.completedTasks[planLogID] >= t.totalTasks[planLogID] } -// priorityQueue 实现优先级队列 -type priorityQueue []*taskItem - -func (pq priorityQueue) Len() int { return len(pq) } - -func (pq priorityQueue) Less(i, j int) bool { - return pq[i].priority < pq[j].priority +// StopTracking 停止跟踪一个计划,清理内存 +func (t *ProgressTracker) StopTracking(planLogID uint) { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.totalTasks, planLogID) + delete(t.completedTasks, planLogID) } -func (pq priorityQueue) Swap(i, j int) { - pq[i], pq[j] = pq[j], pq[i] - pq[i].index = i - pq[j].index = j +// Scheduler 是核心的、持久化的任务调度器 +type Scheduler struct { + logger Logger + pollingInterval time.Duration + workers int + pendingTaskRepo repository.PendingTaskRepository + progressTracker *ProgressTracker + + taskChannel chan *models.TaskExecutionLog // 用于向 workers 派发任务的缓冲 channel + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc } -func (pq *priorityQueue) Push(x interface{}) { - n := len(*pq) - item := x.(*taskItem) - item.index = n - *pq = append(*pq, item) -} - -func (pq *priorityQueue) Pop() interface{} { - old := *pq - n := len(old) - item := old[n-1] - old[n-1] = nil // 避免内存泄漏 - item.index = -1 // 无效索引 - *pq = old[0 : n-1] - return item -} - -// Executor 代表任务执行器 -type Executor struct { - // queue 任务队列 - queue *Queue - - // workers 工作协程数量 - workers int - - // ctx 执行上下文 - ctx context.Context - - // cancel 取消函数 - cancel context.CancelFunc - - // wg 等待组 - wg sync.WaitGroup - - // logger 日志记录器 - logger *logs.Logger -} - -// NewExecutor 创建并返回一个新的任务执行器实例。 -func NewExecutor(workers int, logger *logs.Logger) *Executor { +// NewScheduler 创建一个新的调度器实例 +func NewScheduler(pendingTaskRepo repository.PendingTaskRepository, logger Logger, interval time.Duration, numWorkers int) *Scheduler { ctx, cancel := context.WithCancel(context.Background()) - - return &Executor{ - queue: NewQueue(logger), // 将 logger 传递给 Queue - workers: workers, - ctx: ctx, - cancel: cancel, - logger: logger, + return &Scheduler{ + pendingTaskRepo: pendingTaskRepo, + logger: logger, + pollingInterval: interval, + workers: numWorkers, + progressTracker: NewProgressTracker(), + taskChannel: make(chan *models.TaskExecutionLog, numWorkers), // 缓冲大小与 worker 数量一致 + ctx: ctx, + cancel: cancel, } } -// Start 启动任务执行器 -func (e *Executor) Start() { - e.logger.Infow("正在启动任务执行器", "工作协程数", e.workers) +// Start 启动调度器,包括主轮询循环和所有工作协程 +func (s *Scheduler) Start() { + s.logger.Printf("任务调度器正在启动,工作协程数: %d...", s.workers) - // 启动工作协程 - for i := 0; i < e.workers; i++ { - e.wg.Add(1) - go e.worker(i) + // 启动工作协程池 + s.wg.Add(s.workers) + for i := 0; i < s.workers; i++ { + go s.worker(i) } - e.logger.Info("任务执行器启动成功") + // 启动主轮询循环 + s.wg.Add(1) + go s.run() + + s.logger.Printf("任务调度器已成功启动") } -// Stop 停止任务执行器 -func (e *Executor) Stop() { - e.logger.Info("正在停止任务执行器") - - // 取消上下文 - e.cancel() - - // 等待所有工作协程结束 - e.wg.Wait() - - e.logger.Info("任务执行器已停止") +// Stop 优雅地停止调度器和所有工作协程 +func (s *Scheduler) Stop() { + s.logger.Printf("正在停止任务调度器...") + s.cancel() // 发出取消信号 + s.wg.Wait() // 等待所有协程完成 + s.logger.Printf("任务调度器已安全停止") } -// SubmitTask 提交任务到执行器 -func (e *Executor) SubmitTask(task Task) { - e.queue.AddTask(task) - e.logger.Infow("任务已提交", "任务ID", task.GetID(), "任务描述", task.GetDescription()) -} - -// worker 工作协程 -func (e *Executor) worker(id int) { - defer e.wg.Done() - - e.logger.Infow("工作协程已启动", "工作协程ID", id) +// run 是主轮询循环,负责从数据库认领任务并派发 +func (s *Scheduler) run() { + defer s.wg.Done() + ticker := time.NewTicker(s.pollingInterval) + defer ticker.Stop() for { select { - case <-e.ctx.Done(): - e.logger.Infow("工作协程已停止", "工作协程ID", id) + case <-s.ctx.Done(): + close(s.taskChannel) // 关闭 channel,让 workers 退出循环 return - default: - // 获取下一个任务 - task := e.queue.GetNextTask() - if task != nil { - e.logger.Infow("工作协程正在执行任务", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription()) - - // 执行任务 - if err := task.Execute(); err != nil { - e.logger.Errorw("任务执行失败", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription(), "错误", err) - } else { - e.logger.Infow("任务执行成功", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription()) - } - } else { - // 没有任务时短暂休眠 - time.Sleep(100 * time.Millisecond) - } + case <-ticker.C: + s.claimAndDispatch() } } } + +// claimAndDispatch 认领一个任务并将其发送到派发通道 +func (s *Scheduler) claimAndDispatch() { + claimedLog, err := s.pendingTaskRepo.ClaimNextDueTask() + if err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + s.logger.Printf("认领任务时发生错误: %v", err) + } + return + } + + // 将认领到的任务发送到派发通道 + // 如果所有 worker 都在忙,这里会阻塞,从而实现背压,防止队列无限增长 + select { + case s.taskChannel <- claimedLog: + s.logger.Printf("成功认领并派发任务, 日志ID: %d, 任务ID: %d", claimedLog.ID, claimedLog.TaskID) + case <-s.ctx.Done(): + // 如果在等待派发时调度器被停止,需要处理这个未派发的任务 + // 简单的处理方式是忽略它,让清理器进程后续来处理这个 'running' 状态的任务 + s.logger.Printf("在派发任务时调度器被停止, 日志ID: %d", claimedLog.ID) + } +} + +// worker 是工作协程的实现 +func (s *Scheduler) worker(id int) { + defer s.wg.Done() + s.logger.Printf("工作协程 #%d 已启动", id) + for claimedLog := range s.taskChannel { + s.processTask(id, claimedLog) + } + s.logger.Printf("工作协程 #%d 已停止", id) +} + +// processTask 包含了处理单个任务的完整逻辑 +func (s *Scheduler) processTask(workerID int, claimedLog *models.TaskExecutionLog) { + s.logger.Printf("工作协程 #%d 正在处理任务, 日志ID: %d, 任务ID: %d, 任务名称: %s", + workerID, claimedLog.ID, claimedLog.TaskID, claimedLog.Task.Name) + + // 在这里,我们将根据 claimedLog.TaskID 或未来的 Task.Kind 来分发给不同的处理器 + // 现在,我们只做一个模拟执行 + time.Sleep(2 * time.Second) // 模拟任务执行耗时 + + // 任务执行完毕后,更新日志和进度 + s.logger.Printf("工作协程 #%d 已完成任务, 日志ID: %d", workerID, claimedLog.ID) + + // ---------------------------------------------------- + // 未来的逻辑将在这里展开: + // + // 1. 调用 handler.Handle(claimedLog) + // 2. 根据 handler 返回的 error 更新日志为 'completed' 或 'failed' + // execLogRepo.UpdateTaskExecutionLog(...) + // 3. 如果成功,则 s.progressTracker.Increment(claimedLog.PlanExecutionLogID) + // 4. 检查 s.progressTracker.IsComplete(...),如果完成则执行计划收尾工作 + // + // ---------------------------------------------------- +} diff --git a/internal/infra/task/task_test.go b/internal/infra/task/task_test.go index 52fb78d..2cc19e1 100644 --- a/internal/infra/task/task_test.go +++ b/internal/infra/task/task_test.go @@ -2,8 +2,6 @@ package task_test import ( - "errors" - "fmt" "sync" "sync/atomic" "testing" @@ -11,8 +9,6 @@ import ( "git.huangwc.com/pig/pig-farm-controller/internal/infra/config" "git.huangwc.com/pig/pig-farm-controller/internal/infra/logs" - "git.huangwc.com/pig/pig-farm-controller/internal/infra/task" - "github.com/stretchr/testify/assert" ) // testLogger 是一个用于所有测试用例的静默 logger 实例。 @@ -79,230 +75,3 @@ func waitForWaitGroup(t *testing.T, wg *sync.WaitGroup, timeout time.Duration) { t.Fatal("等待任务完成超时") } } - -// --- 任务队列测试 (无需更改) --- - -func TestNewQueue(t *testing.T) { - tq := task.NewQueue(testLogger) - assert.NotNil(t, tq, "新创建的任务队列不应为 nil") - assert.Equal(t, 0, tq.GetTaskCount(), "新创建的任务队列应为空") -} - -func TestQueue_AddTask(t *testing.T) { - tq := task.NewQueue(testLogger) - mockTask := &MockTask{id: "task1", priority: 1} - - tq.AddTask(mockTask) - assert.Equal(t, 1, tq.GetTaskCount(), "添加任务后,队列中的任务数应为 1") -} - -// ... (其他任务队列测试保持不变) -func TestQueue_GetNextTask(t *testing.T) { - t.Run("从空队列获取任务", func(t *testing.T) { - tq := task.NewQueue(testLogger) - nextTask := tq.GetNextTask() - assert.Nil(t, nextTask, "从空队列中获取任务应返回 nil") - }) - - t.Run("按优先级获取任务", func(t *testing.T) { - tq := task.NewQueue(testLogger) - task1 := &MockTask{id: "task1", priority: 10} - task2 := &MockTask{id: "task2", priority: 1} // 优先级更高 - task3 := &MockTask{id: "task3", priority: 5} - - tq.AddTask(task1) - tq.AddTask(task2) - tq.AddTask(task3) - - assert.Equal(t, 3, tq.GetTaskCount(), "添加三个任务后,队列中的任务数应为 3") - - nextTask := tq.GetNextTask() - assert.NotNil(t, nextTask) - assert.Equal(t, "task2", nextTask.GetID(), "应首先获取优先级最高的任务 (task2)") - - nextTask = tq.GetNextTask() - assert.NotNil(t, nextTask) - assert.Equal(t, "task3", nextTask.GetID(), "应获取下一个优先级最高的任务 (task3)") - - nextTask = tq.GetNextTask() - assert.NotNil(t, nextTask) - assert.Equal(t, "task1", nextTask.GetID(), "应最后获取优先级最低的任务 (task1)") - - assert.Equal(t, 0, tq.GetTaskCount(), "获取所有任务后,队列应为空") - }) -} - -func TestQueue_Concurrency(t *testing.T) { - tq := task.NewQueue(testLogger) - var wg sync.WaitGroup - taskCount := 100 - - wg.Add(taskCount) - for i := 0; i < taskCount; i++ { - go func(i int) { - defer wg.Done() - tq.AddTask(&MockTask{id: fmt.Sprintf("task-%d", i), priority: i}) - }(i) - } - wg.Wait() - - assert.Equal(t, taskCount, tq.GetTaskCount(), "并发添加任务后,队列中的任务数应为 %d", taskCount) - - wg.Add(taskCount) - for i := 0; i < taskCount; i++ { - go func() { - defer wg.Done() - task := tq.GetNextTask() - assert.NotNil(t, task) - }() - } - wg.Wait() - - assert.Equal(t, 0, tq.GetTaskCount(), "并发获取所有任务后,队列应为空") -} - -// --- 执行器测试 (为可靠性重构) --- - -func TestNewExecutor(t *testing.T) { - executor := task.NewExecutor(5, testLogger) - assert.NotNil(t, executor, "新创建的执行器不应为 nil") -} - -func TestExecutor_StartStop(t *testing.T) { - executor := task.NewExecutor(2, testLogger) - executor.Start() - // 确保立即停止不会导致死锁或竞争条件。 - executor.Stop() -} - -// TestExecutor_SubmitAndExecuteTask 测试提交并执行单个任务 (已重构,更可靠) -func TestExecutor_SubmitAndExecuteTask(t *testing.T) { - var wg sync.WaitGroup - wg.Add(1) - - executor := task.NewExecutor(1, testLogger) - mockTask := &MockTask{ - id: "task1", - priority: 1, - execute: func() error { - wg.Done() // 任务完成时通知 WaitGroup - return nil - }, - } - - executor.Start() - executor.SubmitTask(mockTask) - - // 等待任务完成,设置一个合理的超时时间 - waitForWaitGroup(t, &wg, 2*time.Second) - - executor.Stop() - - assert.Equal(t, int32(1), mockTask.ExecutedCount(), "任务应该已被执行") -} - -// TestExecutor_ExecuteMultipleTasks 测试执行多个任务 (已重构,更可靠) -func TestExecutor_ExecuteMultipleTasks(t *testing.T) { - taskCount := 10 - var wg sync.WaitGroup - wg.Add(taskCount) - - executor := task.NewExecutor(3, testLogger) - mockTasks := make([]*MockTask, taskCount) - for i := 0; i < taskCount; i++ { - mockTasks[i] = &MockTask{ - id: fmt.Sprintf("task-%d", i), - priority: i, - execute: func() error { - wg.Done() // 每个任务完成时都通知 WaitGroup - return nil - }, - } - } - - executor.Start() - for _, task := range mockTasks { - executor.SubmitTask(task) - } - - // 等待所有任务完成 - waitForWaitGroup(t, &wg, 2*time.Second) - - executor.Stop() - - var totalExecuted int32 - for _, task := range mockTasks { - totalExecuted += task.ExecutedCount() - } - - assert.Equal(t, int32(taskCount), totalExecuted, "所有提交的任务都应该被执行") -} - -// TestExecutor_TaskExecutionError 测试任务执行失败的场景 (已重构,更可靠) -func TestExecutor_TaskExecutionError(t *testing.T) { - var wg sync.WaitGroup - wg.Add(2) // 我们期望两个任务都被执行 - - executor := task.NewExecutor(1, testLogger) - errorTask := &MockTask{ - id: "errorTask", - priority: 1, - execute: func() error { - wg.Done() - return errors.New("执行失败") - }, - } - - successTask := &MockTask{ - id: "successTask", - priority: 2, // 后执行 - execute: func() error { - wg.Done() - return nil - }, - } - - executor.Start() - executor.SubmitTask(errorTask) - executor.SubmitTask(successTask) - - waitForWaitGroup(t, &wg, 2*time.Second) - executor.Stop() - - assert.Equal(t, int32(1), errorTask.ExecutedCount(), "失败的任务应该被执行一次") - assert.Equal(t, int32(1), successTask.ExecutedCount(), "成功的任务也应该被执行") -} - -// TestExecutor_StopWithPendingTasks 测试停止执行器时仍有待处理任务 (已重构,更可靠) -func TestExecutor_StopWithPendingTasks(t *testing.T) { - executor := task.NewExecutor(1, testLogger) - task1Started := make(chan struct{}) - - task1 := &MockTask{ - id: "task1", - priority: 1, - execute: func() error { - close(task1Started) // 发送信号,通知测试 task1 已开始执行 - time.Sleep(200 * time.Millisecond) // 模拟耗时操作 - return nil - }, - } - task2 := &MockTask{id: "task2", priority: 2} - - executor.Start() - executor.SubmitTask(task1) - executor.SubmitTask(task2) - - // 等待 task1 开始执行的信号,而不是依赖不确定的 sleep - select { - case <-task1Started: - // task1 已开始,可以安全地停止执行器了 - case <-time.After(1 * time.Second): - t.Fatal("等待 task1 启动超时") - } - - executor.Stop() - - assert.Equal(t, int32(1), task1.ExecutedCount(), "task1 应该在停止前开始执行") - assert.Equal(t, int32(0), task2.ExecutedCount(), "task2 不应该被执行,因为执行器已停止") -}