重写task调度器之简单实现

This commit is contained in:
2025-09-16 23:31:36 +08:00
parent 3271f820d4
commit 23343a8558
2 changed files with 144 additions and 410 deletions

View File

@@ -1,223 +1,188 @@
package task package task
import ( import (
"container/heap"
"context" "context"
"errors"
"sync" "sync"
"time" "time"
"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs" "git.huangwc.com/pig/pig-farm-controller/internal/infra/models"
"git.huangwc.com/pig/pig-farm-controller/internal/infra/repository"
"gorm.io/gorm"
) )
// Task 代表一个任务接口 // Logger 定义了调度器期望的日志接口,方便替换为项目中的日志组件
// 所有任务都需要实现此接口 type Logger interface {
type Task interface { Printf(format string, v ...interface{})
// Execute 执行任务
Execute() error
// GetID 获取任务ID
GetID() string
// GetPriority 获取任务优先级
GetPriority() int
// IsDone 检查任务是否已完成
IsDone() bool
// GetDescription 获取任务说明
GetDescription() string
} }
// taskItem 任务队列中的元素 // ProgressTracker 在内存中跟踪正在运行的计划的完成进度
type taskItem struct { type ProgressTracker struct {
task Task mu sync.Mutex
priority int totalTasks map[uint]int // key: planExecutionLogID, value: total tasks
index int completedTasks map[uint]int // key: planExecutionLogID, value: completed tasks
} }
// Queue 代表任务队列 func NewProgressTracker() *ProgressTracker {
type Queue struct { return &ProgressTracker{
// queue 任务队列(按优先级排序) totalTasks: make(map[uint]int),
queue *priorityQueue completedTasks: make(map[uint]int),
// mutex 互斥锁
mutex sync.Mutex
// logger 日志记录器
logger *logs.Logger
}
// NewQueue 创建并返回一个新的任务队列实例。
func NewQueue(logger *logs.Logger) *Queue {
pq := make(priorityQueue, 0)
heap.Init(&pq)
return &Queue{
queue: &pq,
logger: logger,
} }
} }
// AddTask 向队列中添加任务 // StartTracking 开始跟踪一个新的计划执行
func (q *Queue) AddTask(task Task) { func (t *ProgressTracker) StartTracking(planLogID uint, total int) {
q.mutex.Lock() t.mu.Lock()
defer q.mutex.Unlock() defer t.mu.Unlock()
t.totalTasks[planLogID] = total
item := &taskItem{ t.completedTasks[planLogID] = 0
task: task,
priority: task.GetPriority(),
}
heap.Push(q.queue, item)
q.logger.Infow("任务已添加到队列", "任务ID", task.GetID(), "任务描述", task.GetDescription())
} }
// GetNextTask 获取下一个要执行的任务(优先级最高的任务) // Increment 将指定计划的完成计数加一
func (q *Queue) GetNextTask() Task { func (t *ProgressTracker) Increment(planLogID uint) {
q.mutex.Lock() t.mu.Lock()
defer q.mutex.Unlock() defer t.mu.Unlock()
t.completedTasks[planLogID]++
if q.queue.Len() == 0 {
return nil
}
item := heap.Pop(q.queue).(*taskItem)
q.logger.Infow("从队列中获取任务", "任务ID", item.task.GetID(), "任务描述", item.task.GetDescription())
return item.task
} }
// GetTaskCount 获取队列中的任务数量 // IsComplete 检查指定计划是否已完成所有任务
func (q *Queue) GetTaskCount() int { func (t *ProgressTracker) IsComplete(planLogID uint) bool {
q.mutex.Lock() t.mu.Lock()
defer q.mutex.Unlock() defer t.mu.Unlock()
return t.completedTasks[planLogID] >= t.totalTasks[planLogID]
return q.queue.Len()
} }
// priorityQueue 实现优先级队列 // StopTracking 停止跟踪一个计划,清理内存
type priorityQueue []*taskItem func (t *ProgressTracker) StopTracking(planLogID uint) {
t.mu.Lock()
func (pq priorityQueue) Len() int { return len(pq) } defer t.mu.Unlock()
delete(t.totalTasks, planLogID)
func (pq priorityQueue) Less(i, j int) bool { delete(t.completedTasks, planLogID)
return pq[i].priority < pq[j].priority
} }
func (pq priorityQueue) Swap(i, j int) { // Scheduler 是核心的、持久化的任务调度器
pq[i], pq[j] = pq[j], pq[i] type Scheduler struct {
pq[i].index = i logger Logger
pq[j].index = j pollingInterval time.Duration
workers int
pendingTaskRepo repository.PendingTaskRepository
progressTracker *ProgressTracker
taskChannel chan *models.TaskExecutionLog // 用于向 workers 派发任务的缓冲 channel
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
} }
func (pq *priorityQueue) Push(x interface{}) { // NewScheduler 创建一个新的调度器实例
n := len(*pq) func NewScheduler(pendingTaskRepo repository.PendingTaskRepository, logger Logger, interval time.Duration, numWorkers int) *Scheduler {
item := x.(*taskItem)
item.index = n
*pq = append(*pq, item)
}
func (pq *priorityQueue) Pop() interface{} {
old := *pq
n := len(old)
item := old[n-1]
old[n-1] = nil // 避免内存泄漏
item.index = -1 // 无效索引
*pq = old[0 : n-1]
return item
}
// Executor 代表任务执行器
type Executor struct {
// queue 任务队列
queue *Queue
// workers 工作协程数量
workers int
// ctx 执行上下文
ctx context.Context
// cancel 取消函数
cancel context.CancelFunc
// wg 等待组
wg sync.WaitGroup
// logger 日志记录器
logger *logs.Logger
}
// NewExecutor 创建并返回一个新的任务执行器实例。
func NewExecutor(workers int, logger *logs.Logger) *Executor {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &Scheduler{
return &Executor{ pendingTaskRepo: pendingTaskRepo,
queue: NewQueue(logger), // 将 logger 传递给 Queue logger: logger,
workers: workers, pollingInterval: interval,
ctx: ctx, workers: numWorkers,
cancel: cancel, progressTracker: NewProgressTracker(),
logger: logger, taskChannel: make(chan *models.TaskExecutionLog, numWorkers), // 缓冲大小与 worker 数量一致
ctx: ctx,
cancel: cancel,
} }
} }
// Start 启动任务执行器 // Start 启动调度器,包括主轮询循环和所有工作协程
func (e *Executor) Start() { func (s *Scheduler) Start() {
e.logger.Infow("正在启动任务执行器", "工作协程数", e.workers) s.logger.Printf("任务调度器正在启动,工作协程数: %d...", s.workers)
// 启动工作协程 // 启动工作协程
for i := 0; i < e.workers; i++ { s.wg.Add(s.workers)
e.wg.Add(1) for i := 0; i < s.workers; i++ {
go e.worker(i) go s.worker(i)
} }
e.logger.Info("任务执行器启动成功") // 启动主轮询循环
s.wg.Add(1)
go s.run()
s.logger.Printf("任务调度器已成功启动")
} }
// Stop 停止任务执行器 // Stop 优雅地停止调度器和所有工作协程
func (e *Executor) Stop() { func (s *Scheduler) Stop() {
e.logger.Info("正在停止任务执行器") s.logger.Printf("正在停止任务调度器...")
s.cancel() // 发出取消信号
// 取消上下文 s.wg.Wait() // 等待所有协程完成
e.cancel() s.logger.Printf("任务调度器已安全停止")
// 等待所有工作协程结束
e.wg.Wait()
e.logger.Info("任务执行器已停止")
} }
// SubmitTask 提交任务到执行器 // run 是主轮询循环,负责从数据库认领任务并派发
func (e *Executor) SubmitTask(task Task) { func (s *Scheduler) run() {
e.queue.AddTask(task) defer s.wg.Done()
e.logger.Infow("任务已提交", "任务ID", task.GetID(), "任务描述", task.GetDescription()) ticker := time.NewTicker(s.pollingInterval)
} defer ticker.Stop()
// worker 工作协程
func (e *Executor) worker(id int) {
defer e.wg.Done()
e.logger.Infow("工作协程已启动", "工作协程ID", id)
for { for {
select { select {
case <-e.ctx.Done(): case <-s.ctx.Done():
e.logger.Infow("工作协程已停止", "工作协程ID", id) close(s.taskChannel) // 关闭 channel让 workers 退出循环
return return
default: case <-ticker.C:
// 获取下一个任务 s.claimAndDispatch()
task := e.queue.GetNextTask()
if task != nil {
e.logger.Infow("工作协程正在执行任务", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription())
// 执行任务
if err := task.Execute(); err != nil {
e.logger.Errorw("任务执行失败", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription(), "错误", err)
} else {
e.logger.Infow("任务执行成功", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription())
}
} else {
// 没有任务时短暂休眠
time.Sleep(100 * time.Millisecond)
}
} }
} }
} }
// claimAndDispatch 认领一个任务并将其发送到派发通道
func (s *Scheduler) claimAndDispatch() {
claimedLog, err := s.pendingTaskRepo.ClaimNextDueTask()
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
s.logger.Printf("认领任务时发生错误: %v", err)
}
return
}
// 将认领到的任务发送到派发通道
// 如果所有 worker 都在忙,这里会阻塞,从而实现背压,防止队列无限增长
select {
case s.taskChannel <- claimedLog:
s.logger.Printf("成功认领并派发任务, 日志ID: %d, 任务ID: %d", claimedLog.ID, claimedLog.TaskID)
case <-s.ctx.Done():
// 如果在等待派发时调度器被停止,需要处理这个未派发的任务
// 简单的处理方式是忽略它,让清理器进程后续来处理这个 'running' 状态的任务
s.logger.Printf("在派发任务时调度器被停止, 日志ID: %d", claimedLog.ID)
}
}
// worker 是工作协程的实现
func (s *Scheduler) worker(id int) {
defer s.wg.Done()
s.logger.Printf("工作协程 #%d 已启动", id)
for claimedLog := range s.taskChannel {
s.processTask(id, claimedLog)
}
s.logger.Printf("工作协程 #%d 已停止", id)
}
// processTask 包含了处理单个任务的完整逻辑
func (s *Scheduler) processTask(workerID int, claimedLog *models.TaskExecutionLog) {
s.logger.Printf("工作协程 #%d 正在处理任务, 日志ID: %d, 任务ID: %d, 任务名称: %s",
workerID, claimedLog.ID, claimedLog.TaskID, claimedLog.Task.Name)
// 在这里,我们将根据 claimedLog.TaskID 或未来的 Task.Kind 来分发给不同的处理器
// 现在,我们只做一个模拟执行
time.Sleep(2 * time.Second) // 模拟任务执行耗时
// 任务执行完毕后,更新日志和进度
s.logger.Printf("工作协程 #%d 已完成任务, 日志ID: %d", workerID, claimedLog.ID)
// ----------------------------------------------------
// 未来的逻辑将在这里展开:
//
// 1. 调用 handler.Handle(claimedLog)
// 2. 根据 handler 返回的 error 更新日志为 'completed' 或 'failed'
// execLogRepo.UpdateTaskExecutionLog(...)
// 3. 如果成功,则 s.progressTracker.Increment(claimedLog.PlanExecutionLogID)
// 4. 检查 s.progressTracker.IsComplete(...),如果完成则执行计划收尾工作
//
// ----------------------------------------------------
}

View File

@@ -2,8 +2,6 @@
package task_test package task_test
import ( import (
"errors"
"fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
@@ -11,8 +9,6 @@ import (
"git.huangwc.com/pig/pig-farm-controller/internal/infra/config" "git.huangwc.com/pig/pig-farm-controller/internal/infra/config"
"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs" "git.huangwc.com/pig/pig-farm-controller/internal/infra/logs"
"git.huangwc.com/pig/pig-farm-controller/internal/infra/task"
"github.com/stretchr/testify/assert"
) )
// testLogger 是一个用于所有测试用例的静默 logger 实例。 // testLogger 是一个用于所有测试用例的静默 logger 实例。
@@ -79,230 +75,3 @@ func waitForWaitGroup(t *testing.T, wg *sync.WaitGroup, timeout time.Duration) {
t.Fatal("等待任务完成超时") t.Fatal("等待任务完成超时")
} }
} }
// --- 任务队列测试 (无需更改) ---
func TestNewQueue(t *testing.T) {
tq := task.NewQueue(testLogger)
assert.NotNil(t, tq, "新创建的任务队列不应为 nil")
assert.Equal(t, 0, tq.GetTaskCount(), "新创建的任务队列应为空")
}
func TestQueue_AddTask(t *testing.T) {
tq := task.NewQueue(testLogger)
mockTask := &MockTask{id: "task1", priority: 1}
tq.AddTask(mockTask)
assert.Equal(t, 1, tq.GetTaskCount(), "添加任务后,队列中的任务数应为 1")
}
// ... (其他任务队列测试保持不变)
func TestQueue_GetNextTask(t *testing.T) {
t.Run("从空队列获取任务", func(t *testing.T) {
tq := task.NewQueue(testLogger)
nextTask := tq.GetNextTask()
assert.Nil(t, nextTask, "从空队列中获取任务应返回 nil")
})
t.Run("按优先级获取任务", func(t *testing.T) {
tq := task.NewQueue(testLogger)
task1 := &MockTask{id: "task1", priority: 10}
task2 := &MockTask{id: "task2", priority: 1} // 优先级更高
task3 := &MockTask{id: "task3", priority: 5}
tq.AddTask(task1)
tq.AddTask(task2)
tq.AddTask(task3)
assert.Equal(t, 3, tq.GetTaskCount(), "添加三个任务后,队列中的任务数应为 3")
nextTask := tq.GetNextTask()
assert.NotNil(t, nextTask)
assert.Equal(t, "task2", nextTask.GetID(), "应首先获取优先级最高的任务 (task2)")
nextTask = tq.GetNextTask()
assert.NotNil(t, nextTask)
assert.Equal(t, "task3", nextTask.GetID(), "应获取下一个优先级最高的任务 (task3)")
nextTask = tq.GetNextTask()
assert.NotNil(t, nextTask)
assert.Equal(t, "task1", nextTask.GetID(), "应最后获取优先级最低的任务 (task1)")
assert.Equal(t, 0, tq.GetTaskCount(), "获取所有任务后,队列应为空")
})
}
func TestQueue_Concurrency(t *testing.T) {
tq := task.NewQueue(testLogger)
var wg sync.WaitGroup
taskCount := 100
wg.Add(taskCount)
for i := 0; i < taskCount; i++ {
go func(i int) {
defer wg.Done()
tq.AddTask(&MockTask{id: fmt.Sprintf("task-%d", i), priority: i})
}(i)
}
wg.Wait()
assert.Equal(t, taskCount, tq.GetTaskCount(), "并发添加任务后,队列中的任务数应为 %d", taskCount)
wg.Add(taskCount)
for i := 0; i < taskCount; i++ {
go func() {
defer wg.Done()
task := tq.GetNextTask()
assert.NotNil(t, task)
}()
}
wg.Wait()
assert.Equal(t, 0, tq.GetTaskCount(), "并发获取所有任务后,队列应为空")
}
// --- 执行器测试 (为可靠性重构) ---
func TestNewExecutor(t *testing.T) {
executor := task.NewExecutor(5, testLogger)
assert.NotNil(t, executor, "新创建的执行器不应为 nil")
}
func TestExecutor_StartStop(t *testing.T) {
executor := task.NewExecutor(2, testLogger)
executor.Start()
// 确保立即停止不会导致死锁或竞争条件。
executor.Stop()
}
// TestExecutor_SubmitAndExecuteTask 测试提交并执行单个任务 (已重构,更可靠)
func TestExecutor_SubmitAndExecuteTask(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
executor := task.NewExecutor(1, testLogger)
mockTask := &MockTask{
id: "task1",
priority: 1,
execute: func() error {
wg.Done() // 任务完成时通知 WaitGroup
return nil
},
}
executor.Start()
executor.SubmitTask(mockTask)
// 等待任务完成,设置一个合理的超时时间
waitForWaitGroup(t, &wg, 2*time.Second)
executor.Stop()
assert.Equal(t, int32(1), mockTask.ExecutedCount(), "任务应该已被执行")
}
// TestExecutor_ExecuteMultipleTasks 测试执行多个任务 (已重构,更可靠)
func TestExecutor_ExecuteMultipleTasks(t *testing.T) {
taskCount := 10
var wg sync.WaitGroup
wg.Add(taskCount)
executor := task.NewExecutor(3, testLogger)
mockTasks := make([]*MockTask, taskCount)
for i := 0; i < taskCount; i++ {
mockTasks[i] = &MockTask{
id: fmt.Sprintf("task-%d", i),
priority: i,
execute: func() error {
wg.Done() // 每个任务完成时都通知 WaitGroup
return nil
},
}
}
executor.Start()
for _, task := range mockTasks {
executor.SubmitTask(task)
}
// 等待所有任务完成
waitForWaitGroup(t, &wg, 2*time.Second)
executor.Stop()
var totalExecuted int32
for _, task := range mockTasks {
totalExecuted += task.ExecutedCount()
}
assert.Equal(t, int32(taskCount), totalExecuted, "所有提交的任务都应该被执行")
}
// TestExecutor_TaskExecutionError 测试任务执行失败的场景 (已重构,更可靠)
func TestExecutor_TaskExecutionError(t *testing.T) {
var wg sync.WaitGroup
wg.Add(2) // 我们期望两个任务都被执行
executor := task.NewExecutor(1, testLogger)
errorTask := &MockTask{
id: "errorTask",
priority: 1,
execute: func() error {
wg.Done()
return errors.New("执行失败")
},
}
successTask := &MockTask{
id: "successTask",
priority: 2, // 后执行
execute: func() error {
wg.Done()
return nil
},
}
executor.Start()
executor.SubmitTask(errorTask)
executor.SubmitTask(successTask)
waitForWaitGroup(t, &wg, 2*time.Second)
executor.Stop()
assert.Equal(t, int32(1), errorTask.ExecutedCount(), "失败的任务应该被执行一次")
assert.Equal(t, int32(1), successTask.ExecutedCount(), "成功的任务也应该被执行")
}
// TestExecutor_StopWithPendingTasks 测试停止执行器时仍有待处理任务 (已重构,更可靠)
func TestExecutor_StopWithPendingTasks(t *testing.T) {
executor := task.NewExecutor(1, testLogger)
task1Started := make(chan struct{})
task1 := &MockTask{
id: "task1",
priority: 1,
execute: func() error {
close(task1Started) // 发送信号,通知测试 task1 已开始执行
time.Sleep(200 * time.Millisecond) // 模拟耗时操作
return nil
},
}
task2 := &MockTask{id: "task2", priority: 2}
executor.Start()
executor.SubmitTask(task1)
executor.SubmitTask(task2)
// 等待 task1 开始执行的信号,而不是依赖不确定的 sleep
select {
case <-task1Started:
// task1 已开始,可以安全地停止执行器了
case <-time.After(1 * time.Second):
t.Fatal("等待 task1 启动超时")
}
executor.Stop()
assert.Equal(t, int32(1), task1.ExecutedCount(), "task1 应该在停止前开始执行")
assert.Equal(t, int32(0), task2.ExecutedCount(), "task2 不应该被执行,因为执行器已停止")
}