重写task调度器之简单实现
This commit is contained in:
@@ -1,223 +1,188 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs"
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/models"
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/repository"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Task 代表一个任务接口
|
||||
// 所有任务都需要实现此接口
|
||||
type Task interface {
|
||||
// Execute 执行任务
|
||||
Execute() error
|
||||
|
||||
// GetID 获取任务ID
|
||||
GetID() string
|
||||
|
||||
// GetPriority 获取任务优先级
|
||||
GetPriority() int
|
||||
|
||||
// IsDone 检查任务是否已完成
|
||||
IsDone() bool
|
||||
|
||||
// GetDescription 获取任务说明
|
||||
GetDescription() string
|
||||
// Logger 定义了调度器期望的日志接口,方便替换为项目中的日志组件
|
||||
type Logger interface {
|
||||
Printf(format string, v ...interface{})
|
||||
}
|
||||
|
||||
// taskItem 任务队列中的元素
|
||||
type taskItem struct {
|
||||
task Task
|
||||
priority int
|
||||
index int
|
||||
// ProgressTracker 在内存中跟踪正在运行的计划的完成进度
|
||||
type ProgressTracker struct {
|
||||
mu sync.Mutex
|
||||
totalTasks map[uint]int // key: planExecutionLogID, value: total tasks
|
||||
completedTasks map[uint]int // key: planExecutionLogID, value: completed tasks
|
||||
}
|
||||
|
||||
// Queue 代表任务队列
|
||||
type Queue struct {
|
||||
// queue 任务队列(按优先级排序)
|
||||
queue *priorityQueue
|
||||
|
||||
// mutex 互斥锁
|
||||
mutex sync.Mutex
|
||||
|
||||
// logger 日志记录器
|
||||
logger *logs.Logger
|
||||
}
|
||||
|
||||
// NewQueue 创建并返回一个新的任务队列实例。
|
||||
func NewQueue(logger *logs.Logger) *Queue {
|
||||
pq := make(priorityQueue, 0)
|
||||
heap.Init(&pq)
|
||||
|
||||
return &Queue{
|
||||
queue: &pq,
|
||||
logger: logger,
|
||||
func NewProgressTracker() *ProgressTracker {
|
||||
return &ProgressTracker{
|
||||
totalTasks: make(map[uint]int),
|
||||
completedTasks: make(map[uint]int),
|
||||
}
|
||||
}
|
||||
|
||||
// AddTask 向队列中添加任务
|
||||
func (q *Queue) AddTask(task Task) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
item := &taskItem{
|
||||
task: task,
|
||||
priority: task.GetPriority(),
|
||||
}
|
||||
heap.Push(q.queue, item)
|
||||
q.logger.Infow("任务已添加到队列", "任务ID", task.GetID(), "任务描述", task.GetDescription())
|
||||
// StartTracking 开始跟踪一个新的计划执行
|
||||
func (t *ProgressTracker) StartTracking(planLogID uint, total int) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.totalTasks[planLogID] = total
|
||||
t.completedTasks[planLogID] = 0
|
||||
}
|
||||
|
||||
// GetNextTask 获取下一个要执行的任务(优先级最高的任务)
|
||||
func (q *Queue) GetNextTask() Task {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
if q.queue.Len() == 0 {
|
||||
return nil
|
||||
// Increment 将指定计划的完成计数加一
|
||||
func (t *ProgressTracker) Increment(planLogID uint) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.completedTasks[planLogID]++
|
||||
}
|
||||
|
||||
item := heap.Pop(q.queue).(*taskItem)
|
||||
q.logger.Infow("从队列中获取任务", "任务ID", item.task.GetID(), "任务描述", item.task.GetDescription())
|
||||
return item.task
|
||||
// IsComplete 检查指定计划是否已完成所有任务
|
||||
func (t *ProgressTracker) IsComplete(planLogID uint) bool {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.completedTasks[planLogID] >= t.totalTasks[planLogID]
|
||||
}
|
||||
|
||||
// GetTaskCount 获取队列中的任务数量
|
||||
func (q *Queue) GetTaskCount() int {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
return q.queue.Len()
|
||||
// StopTracking 停止跟踪一个计划,清理内存
|
||||
func (t *ProgressTracker) StopTracking(planLogID uint) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
delete(t.totalTasks, planLogID)
|
||||
delete(t.completedTasks, planLogID)
|
||||
}
|
||||
|
||||
// priorityQueue 实现优先级队列
|
||||
type priorityQueue []*taskItem
|
||||
|
||||
func (pq priorityQueue) Len() int { return len(pq) }
|
||||
|
||||
func (pq priorityQueue) Less(i, j int) bool {
|
||||
return pq[i].priority < pq[j].priority
|
||||
}
|
||||
|
||||
func (pq priorityQueue) Swap(i, j int) {
|
||||
pq[i], pq[j] = pq[j], pq[i]
|
||||
pq[i].index = i
|
||||
pq[j].index = j
|
||||
}
|
||||
|
||||
func (pq *priorityQueue) Push(x interface{}) {
|
||||
n := len(*pq)
|
||||
item := x.(*taskItem)
|
||||
item.index = n
|
||||
*pq = append(*pq, item)
|
||||
}
|
||||
|
||||
func (pq *priorityQueue) Pop() interface{} {
|
||||
old := *pq
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
old[n-1] = nil // 避免内存泄漏
|
||||
item.index = -1 // 无效索引
|
||||
*pq = old[0 : n-1]
|
||||
return item
|
||||
}
|
||||
|
||||
// Executor 代表任务执行器
|
||||
type Executor struct {
|
||||
// queue 任务队列
|
||||
queue *Queue
|
||||
|
||||
// workers 工作协程数量
|
||||
// Scheduler 是核心的、持久化的任务调度器
|
||||
type Scheduler struct {
|
||||
logger Logger
|
||||
pollingInterval time.Duration
|
||||
workers int
|
||||
pendingTaskRepo repository.PendingTaskRepository
|
||||
progressTracker *ProgressTracker
|
||||
|
||||
// ctx 执行上下文
|
||||
ctx context.Context
|
||||
|
||||
// cancel 取消函数
|
||||
cancel context.CancelFunc
|
||||
|
||||
// wg 等待组
|
||||
taskChannel chan *models.TaskExecutionLog // 用于向 workers 派发任务的缓冲 channel
|
||||
wg sync.WaitGroup
|
||||
|
||||
// logger 日志记录器
|
||||
logger *logs.Logger
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewExecutor 创建并返回一个新的任务执行器实例。
|
||||
func NewExecutor(workers int, logger *logs.Logger) *Executor {
|
||||
// NewScheduler 创建一个新的调度器实例
|
||||
func NewScheduler(pendingTaskRepo repository.PendingTaskRepository, logger Logger, interval time.Duration, numWorkers int) *Scheduler {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &Executor{
|
||||
queue: NewQueue(logger), // 将 logger 传递给 Queue
|
||||
workers: workers,
|
||||
return &Scheduler{
|
||||
pendingTaskRepo: pendingTaskRepo,
|
||||
logger: logger,
|
||||
pollingInterval: interval,
|
||||
workers: numWorkers,
|
||||
progressTracker: NewProgressTracker(),
|
||||
taskChannel: make(chan *models.TaskExecutionLog, numWorkers), // 缓冲大小与 worker 数量一致
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动任务执行器
|
||||
func (e *Executor) Start() {
|
||||
e.logger.Infow("正在启动任务执行器", "工作协程数", e.workers)
|
||||
// Start 启动调度器,包括主轮询循环和所有工作协程
|
||||
func (s *Scheduler) Start() {
|
||||
s.logger.Printf("任务调度器正在启动,工作协程数: %d...", s.workers)
|
||||
|
||||
// 启动工作协程
|
||||
for i := 0; i < e.workers; i++ {
|
||||
e.wg.Add(1)
|
||||
go e.worker(i)
|
||||
// 启动工作协程池
|
||||
s.wg.Add(s.workers)
|
||||
for i := 0; i < s.workers; i++ {
|
||||
go s.worker(i)
|
||||
}
|
||||
|
||||
e.logger.Info("任务执行器启动成功")
|
||||
// 启动主轮询循环
|
||||
s.wg.Add(1)
|
||||
go s.run()
|
||||
|
||||
s.logger.Printf("任务调度器已成功启动")
|
||||
}
|
||||
|
||||
// Stop 停止任务执行器
|
||||
func (e *Executor) Stop() {
|
||||
e.logger.Info("正在停止任务执行器")
|
||||
|
||||
// 取消上下文
|
||||
e.cancel()
|
||||
|
||||
// 等待所有工作协程结束
|
||||
e.wg.Wait()
|
||||
|
||||
e.logger.Info("任务执行器已停止")
|
||||
// Stop 优雅地停止调度器和所有工作协程
|
||||
func (s *Scheduler) Stop() {
|
||||
s.logger.Printf("正在停止任务调度器...")
|
||||
s.cancel() // 发出取消信号
|
||||
s.wg.Wait() // 等待所有协程完成
|
||||
s.logger.Printf("任务调度器已安全停止")
|
||||
}
|
||||
|
||||
// SubmitTask 提交任务到执行器
|
||||
func (e *Executor) SubmitTask(task Task) {
|
||||
e.queue.AddTask(task)
|
||||
e.logger.Infow("任务已提交", "任务ID", task.GetID(), "任务描述", task.GetDescription())
|
||||
}
|
||||
|
||||
// worker 工作协程
|
||||
func (e *Executor) worker(id int) {
|
||||
defer e.wg.Done()
|
||||
|
||||
e.logger.Infow("工作协程已启动", "工作协程ID", id)
|
||||
// run 是主轮询循环,负责从数据库认领任务并派发
|
||||
func (s *Scheduler) run() {
|
||||
defer s.wg.Done()
|
||||
ticker := time.NewTicker(s.pollingInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-e.ctx.Done():
|
||||
e.logger.Infow("工作协程已停止", "工作协程ID", id)
|
||||
case <-s.ctx.Done():
|
||||
close(s.taskChannel) // 关闭 channel,让 workers 退出循环
|
||||
return
|
||||
default:
|
||||
// 获取下一个任务
|
||||
task := e.queue.GetNextTask()
|
||||
if task != nil {
|
||||
e.logger.Infow("工作协程正在执行任务", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription())
|
||||
case <-ticker.C:
|
||||
s.claimAndDispatch()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 执行任务
|
||||
if err := task.Execute(); err != nil {
|
||||
e.logger.Errorw("任务执行失败", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription(), "错误", err)
|
||||
} else {
|
||||
e.logger.Infow("任务执行成功", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription())
|
||||
// claimAndDispatch 认领一个任务并将其发送到派发通道
|
||||
func (s *Scheduler) claimAndDispatch() {
|
||||
claimedLog, err := s.pendingTaskRepo.ClaimNextDueTask()
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
s.logger.Printf("认领任务时发生错误: %v", err)
|
||||
}
|
||||
} else {
|
||||
// 没有任务时短暂休眠
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return
|
||||
}
|
||||
|
||||
// 将认领到的任务发送到派发通道
|
||||
// 如果所有 worker 都在忙,这里会阻塞,从而实现背压,防止队列无限增长
|
||||
select {
|
||||
case s.taskChannel <- claimedLog:
|
||||
s.logger.Printf("成功认领并派发任务, 日志ID: %d, 任务ID: %d", claimedLog.ID, claimedLog.TaskID)
|
||||
case <-s.ctx.Done():
|
||||
// 如果在等待派发时调度器被停止,需要处理这个未派发的任务
|
||||
// 简单的处理方式是忽略它,让清理器进程后续来处理这个 'running' 状态的任务
|
||||
s.logger.Printf("在派发任务时调度器被停止, 日志ID: %d", claimedLog.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// worker 是工作协程的实现
|
||||
func (s *Scheduler) worker(id int) {
|
||||
defer s.wg.Done()
|
||||
s.logger.Printf("工作协程 #%d 已启动", id)
|
||||
for claimedLog := range s.taskChannel {
|
||||
s.processTask(id, claimedLog)
|
||||
}
|
||||
s.logger.Printf("工作协程 #%d 已停止", id)
|
||||
}
|
||||
|
||||
// processTask 包含了处理单个任务的完整逻辑
|
||||
func (s *Scheduler) processTask(workerID int, claimedLog *models.TaskExecutionLog) {
|
||||
s.logger.Printf("工作协程 #%d 正在处理任务, 日志ID: %d, 任务ID: %d, 任务名称: %s",
|
||||
workerID, claimedLog.ID, claimedLog.TaskID, claimedLog.Task.Name)
|
||||
|
||||
// 在这里,我们将根据 claimedLog.TaskID 或未来的 Task.Kind 来分发给不同的处理器
|
||||
// 现在,我们只做一个模拟执行
|
||||
time.Sleep(2 * time.Second) // 模拟任务执行耗时
|
||||
|
||||
// 任务执行完毕后,更新日志和进度
|
||||
s.logger.Printf("工作协程 #%d 已完成任务, 日志ID: %d", workerID, claimedLog.ID)
|
||||
|
||||
// ----------------------------------------------------
|
||||
// 未来的逻辑将在这里展开:
|
||||
//
|
||||
// 1. 调用 handler.Handle(claimedLog)
|
||||
// 2. 根据 handler 返回的 error 更新日志为 'completed' 或 'failed'
|
||||
// execLogRepo.UpdateTaskExecutionLog(...)
|
||||
// 3. 如果成功,则 s.progressTracker.Increment(claimedLog.PlanExecutionLogID)
|
||||
// 4. 检查 s.progressTracker.IsComplete(...),如果完成则执行计划收尾工作
|
||||
//
|
||||
// ----------------------------------------------------
|
||||
}
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
package task_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -11,8 +9,6 @@ import (
|
||||
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/config"
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs"
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/task"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// testLogger 是一个用于所有测试用例的静默 logger 实例。
|
||||
@@ -79,230 +75,3 @@ func waitForWaitGroup(t *testing.T, wg *sync.WaitGroup, timeout time.Duration) {
|
||||
t.Fatal("等待任务完成超时")
|
||||
}
|
||||
}
|
||||
|
||||
// --- 任务队列测试 (无需更改) ---
|
||||
|
||||
func TestNewQueue(t *testing.T) {
|
||||
tq := task.NewQueue(testLogger)
|
||||
assert.NotNil(t, tq, "新创建的任务队列不应为 nil")
|
||||
assert.Equal(t, 0, tq.GetTaskCount(), "新创建的任务队列应为空")
|
||||
}
|
||||
|
||||
func TestQueue_AddTask(t *testing.T) {
|
||||
tq := task.NewQueue(testLogger)
|
||||
mockTask := &MockTask{id: "task1", priority: 1}
|
||||
|
||||
tq.AddTask(mockTask)
|
||||
assert.Equal(t, 1, tq.GetTaskCount(), "添加任务后,队列中的任务数应为 1")
|
||||
}
|
||||
|
||||
// ... (其他任务队列测试保持不变)
|
||||
func TestQueue_GetNextTask(t *testing.T) {
|
||||
t.Run("从空队列获取任务", func(t *testing.T) {
|
||||
tq := task.NewQueue(testLogger)
|
||||
nextTask := tq.GetNextTask()
|
||||
assert.Nil(t, nextTask, "从空队列中获取任务应返回 nil")
|
||||
})
|
||||
|
||||
t.Run("按优先级获取任务", func(t *testing.T) {
|
||||
tq := task.NewQueue(testLogger)
|
||||
task1 := &MockTask{id: "task1", priority: 10}
|
||||
task2 := &MockTask{id: "task2", priority: 1} // 优先级更高
|
||||
task3 := &MockTask{id: "task3", priority: 5}
|
||||
|
||||
tq.AddTask(task1)
|
||||
tq.AddTask(task2)
|
||||
tq.AddTask(task3)
|
||||
|
||||
assert.Equal(t, 3, tq.GetTaskCount(), "添加三个任务后,队列中的任务数应为 3")
|
||||
|
||||
nextTask := tq.GetNextTask()
|
||||
assert.NotNil(t, nextTask)
|
||||
assert.Equal(t, "task2", nextTask.GetID(), "应首先获取优先级最高的任务 (task2)")
|
||||
|
||||
nextTask = tq.GetNextTask()
|
||||
assert.NotNil(t, nextTask)
|
||||
assert.Equal(t, "task3", nextTask.GetID(), "应获取下一个优先级最高的任务 (task3)")
|
||||
|
||||
nextTask = tq.GetNextTask()
|
||||
assert.NotNil(t, nextTask)
|
||||
assert.Equal(t, "task1", nextTask.GetID(), "应最后获取优先级最低的任务 (task1)")
|
||||
|
||||
assert.Equal(t, 0, tq.GetTaskCount(), "获取所有任务后,队列应为空")
|
||||
})
|
||||
}
|
||||
|
||||
func TestQueue_Concurrency(t *testing.T) {
|
||||
tq := task.NewQueue(testLogger)
|
||||
var wg sync.WaitGroup
|
||||
taskCount := 100
|
||||
|
||||
wg.Add(taskCount)
|
||||
for i := 0; i < taskCount; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
tq.AddTask(&MockTask{id: fmt.Sprintf("task-%d", i), priority: i})
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, taskCount, tq.GetTaskCount(), "并发添加任务后,队列中的任务数应为 %d", taskCount)
|
||||
|
||||
wg.Add(taskCount)
|
||||
for i := 0; i < taskCount; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
task := tq.GetNextTask()
|
||||
assert.NotNil(t, task)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, 0, tq.GetTaskCount(), "并发获取所有任务后,队列应为空")
|
||||
}
|
||||
|
||||
// --- 执行器测试 (为可靠性重构) ---
|
||||
|
||||
func TestNewExecutor(t *testing.T) {
|
||||
executor := task.NewExecutor(5, testLogger)
|
||||
assert.NotNil(t, executor, "新创建的执行器不应为 nil")
|
||||
}
|
||||
|
||||
func TestExecutor_StartStop(t *testing.T) {
|
||||
executor := task.NewExecutor(2, testLogger)
|
||||
executor.Start()
|
||||
// 确保立即停止不会导致死锁或竞争条件。
|
||||
executor.Stop()
|
||||
}
|
||||
|
||||
// TestExecutor_SubmitAndExecuteTask 测试提交并执行单个任务 (已重构,更可靠)
|
||||
func TestExecutor_SubmitAndExecuteTask(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
executor := task.NewExecutor(1, testLogger)
|
||||
mockTask := &MockTask{
|
||||
id: "task1",
|
||||
priority: 1,
|
||||
execute: func() error {
|
||||
wg.Done() // 任务完成时通知 WaitGroup
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
executor.Start()
|
||||
executor.SubmitTask(mockTask)
|
||||
|
||||
// 等待任务完成,设置一个合理的超时时间
|
||||
waitForWaitGroup(t, &wg, 2*time.Second)
|
||||
|
||||
executor.Stop()
|
||||
|
||||
assert.Equal(t, int32(1), mockTask.ExecutedCount(), "任务应该已被执行")
|
||||
}
|
||||
|
||||
// TestExecutor_ExecuteMultipleTasks 测试执行多个任务 (已重构,更可靠)
|
||||
func TestExecutor_ExecuteMultipleTasks(t *testing.T) {
|
||||
taskCount := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(taskCount)
|
||||
|
||||
executor := task.NewExecutor(3, testLogger)
|
||||
mockTasks := make([]*MockTask, taskCount)
|
||||
for i := 0; i < taskCount; i++ {
|
||||
mockTasks[i] = &MockTask{
|
||||
id: fmt.Sprintf("task-%d", i),
|
||||
priority: i,
|
||||
execute: func() error {
|
||||
wg.Done() // 每个任务完成时都通知 WaitGroup
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
executor.Start()
|
||||
for _, task := range mockTasks {
|
||||
executor.SubmitTask(task)
|
||||
}
|
||||
|
||||
// 等待所有任务完成
|
||||
waitForWaitGroup(t, &wg, 2*time.Second)
|
||||
|
||||
executor.Stop()
|
||||
|
||||
var totalExecuted int32
|
||||
for _, task := range mockTasks {
|
||||
totalExecuted += task.ExecutedCount()
|
||||
}
|
||||
|
||||
assert.Equal(t, int32(taskCount), totalExecuted, "所有提交的任务都应该被执行")
|
||||
}
|
||||
|
||||
// TestExecutor_TaskExecutionError 测试任务执行失败的场景 (已重构,更可靠)
|
||||
func TestExecutor_TaskExecutionError(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2) // 我们期望两个任务都被执行
|
||||
|
||||
executor := task.NewExecutor(1, testLogger)
|
||||
errorTask := &MockTask{
|
||||
id: "errorTask",
|
||||
priority: 1,
|
||||
execute: func() error {
|
||||
wg.Done()
|
||||
return errors.New("执行失败")
|
||||
},
|
||||
}
|
||||
|
||||
successTask := &MockTask{
|
||||
id: "successTask",
|
||||
priority: 2, // 后执行
|
||||
execute: func() error {
|
||||
wg.Done()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
executor.Start()
|
||||
executor.SubmitTask(errorTask)
|
||||
executor.SubmitTask(successTask)
|
||||
|
||||
waitForWaitGroup(t, &wg, 2*time.Second)
|
||||
executor.Stop()
|
||||
|
||||
assert.Equal(t, int32(1), errorTask.ExecutedCount(), "失败的任务应该被执行一次")
|
||||
assert.Equal(t, int32(1), successTask.ExecutedCount(), "成功的任务也应该被执行")
|
||||
}
|
||||
|
||||
// TestExecutor_StopWithPendingTasks 测试停止执行器时仍有待处理任务 (已重构,更可靠)
|
||||
func TestExecutor_StopWithPendingTasks(t *testing.T) {
|
||||
executor := task.NewExecutor(1, testLogger)
|
||||
task1Started := make(chan struct{})
|
||||
|
||||
task1 := &MockTask{
|
||||
id: "task1",
|
||||
priority: 1,
|
||||
execute: func() error {
|
||||
close(task1Started) // 发送信号,通知测试 task1 已开始执行
|
||||
time.Sleep(200 * time.Millisecond) // 模拟耗时操作
|
||||
return nil
|
||||
},
|
||||
}
|
||||
task2 := &MockTask{id: "task2", priority: 2}
|
||||
|
||||
executor.Start()
|
||||
executor.SubmitTask(task1)
|
||||
executor.SubmitTask(task2)
|
||||
|
||||
// 等待 task1 开始执行的信号,而不是依赖不确定的 sleep
|
||||
select {
|
||||
case <-task1Started:
|
||||
// task1 已开始,可以安全地停止执行器了
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("等待 task1 启动超时")
|
||||
}
|
||||
|
||||
executor.Stop()
|
||||
|
||||
assert.Equal(t, int32(1), task1.ExecutedCount(), "task1 应该在停止前开始执行")
|
||||
assert.Equal(t, int32(0), task2.ExecutedCount(), "task2 不应该被执行,因为执行器已停止")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user