重写task调度器之简单实现
This commit is contained in:
		| @@ -1,223 +1,188 @@ | ||||
| package task | ||||
|  | ||||
| import ( | ||||
| 	"container/heap" | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs" | ||||
| 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/models" | ||||
| 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/repository" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // Task 代表一个任务接口 | ||||
| // 所有任务都需要实现此接口 | ||||
| type Task interface { | ||||
| 	// Execute 执行任务 | ||||
| 	Execute() error | ||||
|  | ||||
| 	// GetID 获取任务ID | ||||
| 	GetID() string | ||||
|  | ||||
| 	// GetPriority 获取任务优先级 | ||||
| 	GetPriority() int | ||||
|  | ||||
| 	// IsDone 检查任务是否已完成 | ||||
| 	IsDone() bool | ||||
|  | ||||
| 	// GetDescription 获取任务说明 | ||||
| 	GetDescription() string | ||||
| // Logger 定义了调度器期望的日志接口,方便替换为项目中的日志组件 | ||||
| type Logger interface { | ||||
| 	Printf(format string, v ...interface{}) | ||||
| } | ||||
|  | ||||
| // taskItem 任务队列中的元素 | ||||
| type taskItem struct { | ||||
| 	task     Task | ||||
| 	priority int | ||||
| 	index    int | ||||
| // ProgressTracker 在内存中跟踪正在运行的计划的完成进度 | ||||
| type ProgressTracker struct { | ||||
| 	mu             sync.Mutex | ||||
| 	totalTasks     map[uint]int // key: planExecutionLogID, value: total tasks | ||||
| 	completedTasks map[uint]int // key: planExecutionLogID, value: completed tasks | ||||
| } | ||||
|  | ||||
| // Queue 代表任务队列 | ||||
| type Queue struct { | ||||
| 	// queue 任务队列(按优先级排序) | ||||
| 	queue *priorityQueue | ||||
|  | ||||
| 	// mutex 互斥锁 | ||||
| 	mutex sync.Mutex | ||||
|  | ||||
| 	// logger 日志记录器 | ||||
| 	logger *logs.Logger | ||||
| } | ||||
|  | ||||
| // NewQueue 创建并返回一个新的任务队列实例。 | ||||
| func NewQueue(logger *logs.Logger) *Queue { | ||||
| 	pq := make(priorityQueue, 0) | ||||
| 	heap.Init(&pq) | ||||
|  | ||||
| 	return &Queue{ | ||||
| 		queue:  &pq, | ||||
| 		logger: logger, | ||||
| func NewProgressTracker() *ProgressTracker { | ||||
| 	return &ProgressTracker{ | ||||
| 		totalTasks:     make(map[uint]int), | ||||
| 		completedTasks: make(map[uint]int), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // AddTask 向队列中添加任务 | ||||
| func (q *Queue) AddTask(task Task) { | ||||
| 	q.mutex.Lock() | ||||
| 	defer q.mutex.Unlock() | ||||
|  | ||||
| 	item := &taskItem{ | ||||
| 		task:     task, | ||||
| 		priority: task.GetPriority(), | ||||
| 	} | ||||
| 	heap.Push(q.queue, item) | ||||
| 	q.logger.Infow("任务已添加到队列", "任务ID", task.GetID(), "任务描述", task.GetDescription()) | ||||
| // StartTracking 开始跟踪一个新的计划执行 | ||||
| func (t *ProgressTracker) StartTracking(planLogID uint, total int) { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	t.totalTasks[planLogID] = total | ||||
| 	t.completedTasks[planLogID] = 0 | ||||
| } | ||||
|  | ||||
| // GetNextTask 获取下一个要执行的任务(优先级最高的任务) | ||||
| func (q *Queue) GetNextTask() Task { | ||||
| 	q.mutex.Lock() | ||||
| 	defer q.mutex.Unlock() | ||||
|  | ||||
| 	if q.queue.Len() == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	item := heap.Pop(q.queue).(*taskItem) | ||||
| 	q.logger.Infow("从队列中获取任务", "任务ID", item.task.GetID(), "任务描述", item.task.GetDescription()) | ||||
| 	return item.task | ||||
| // Increment 将指定计划的完成计数加一 | ||||
| func (t *ProgressTracker) Increment(planLogID uint) { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	t.completedTasks[planLogID]++ | ||||
| } | ||||
|  | ||||
| // GetTaskCount 获取队列中的任务数量 | ||||
| func (q *Queue) GetTaskCount() int { | ||||
| 	q.mutex.Lock() | ||||
| 	defer q.mutex.Unlock() | ||||
|  | ||||
| 	return q.queue.Len() | ||||
| // IsComplete 检查指定计划是否已完成所有任务 | ||||
| func (t *ProgressTracker) IsComplete(planLogID uint) bool { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	return t.completedTasks[planLogID] >= t.totalTasks[planLogID] | ||||
| } | ||||
|  | ||||
| // priorityQueue 实现优先级队列 | ||||
| type priorityQueue []*taskItem | ||||
|  | ||||
| func (pq priorityQueue) Len() int { return len(pq) } | ||||
|  | ||||
| func (pq priorityQueue) Less(i, j int) bool { | ||||
| 	return pq[i].priority < pq[j].priority | ||||
| // StopTracking 停止跟踪一个计划,清理内存 | ||||
| func (t *ProgressTracker) StopTracking(planLogID uint) { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	delete(t.totalTasks, planLogID) | ||||
| 	delete(t.completedTasks, planLogID) | ||||
| } | ||||
|  | ||||
| func (pq priorityQueue) Swap(i, j int) { | ||||
| 	pq[i], pq[j] = pq[j], pq[i] | ||||
| 	pq[i].index = i | ||||
| 	pq[j].index = j | ||||
| // Scheduler 是核心的、持久化的任务调度器 | ||||
| type Scheduler struct { | ||||
| 	logger          Logger | ||||
| 	pollingInterval time.Duration | ||||
| 	workers         int | ||||
| 	pendingTaskRepo repository.PendingTaskRepository | ||||
| 	progressTracker *ProgressTracker | ||||
|  | ||||
| 	taskChannel chan *models.TaskExecutionLog // 用于向 workers 派发任务的缓冲 channel | ||||
| 	wg          sync.WaitGroup | ||||
| 	ctx         context.Context | ||||
| 	cancel      context.CancelFunc | ||||
| } | ||||
|  | ||||
| func (pq *priorityQueue) Push(x interface{}) { | ||||
| 	n := len(*pq) | ||||
| 	item := x.(*taskItem) | ||||
| 	item.index = n | ||||
| 	*pq = append(*pq, item) | ||||
| } | ||||
|  | ||||
| func (pq *priorityQueue) Pop() interface{} { | ||||
| 	old := *pq | ||||
| 	n := len(old) | ||||
| 	item := old[n-1] | ||||
| 	old[n-1] = nil  // 避免内存泄漏 | ||||
| 	item.index = -1 // 无效索引 | ||||
| 	*pq = old[0 : n-1] | ||||
| 	return item | ||||
| } | ||||
|  | ||||
| // Executor 代表任务执行器 | ||||
| type Executor struct { | ||||
| 	// queue 任务队列 | ||||
| 	queue *Queue | ||||
|  | ||||
| 	// workers 工作协程数量 | ||||
| 	workers int | ||||
|  | ||||
| 	// ctx 执行上下文 | ||||
| 	ctx context.Context | ||||
|  | ||||
| 	// cancel 取消函数 | ||||
| 	cancel context.CancelFunc | ||||
|  | ||||
| 	// wg 等待组 | ||||
| 	wg sync.WaitGroup | ||||
|  | ||||
| 	// logger 日志记录器 | ||||
| 	logger *logs.Logger | ||||
| } | ||||
|  | ||||
| // NewExecutor 创建并返回一个新的任务执行器实例。 | ||||
| func NewExecutor(workers int, logger *logs.Logger) *Executor { | ||||
| // NewScheduler 创建一个新的调度器实例 | ||||
| func NewScheduler(pendingTaskRepo repository.PendingTaskRepository, logger Logger, interval time.Duration, numWorkers int) *Scheduler { | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
|  | ||||
| 	return &Executor{ | ||||
| 		queue:   NewQueue(logger), // 将 logger 传递给 Queue | ||||
| 		workers: workers, | ||||
| 		ctx:     ctx, | ||||
| 		cancel:  cancel, | ||||
| 		logger:  logger, | ||||
| 	return &Scheduler{ | ||||
| 		pendingTaskRepo: pendingTaskRepo, | ||||
| 		logger:          logger, | ||||
| 		pollingInterval: interval, | ||||
| 		workers:         numWorkers, | ||||
| 		progressTracker: NewProgressTracker(), | ||||
| 		taskChannel:     make(chan *models.TaskExecutionLog, numWorkers), // 缓冲大小与 worker 数量一致 | ||||
| 		ctx:             ctx, | ||||
| 		cancel:          cancel, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Start 启动任务执行器 | ||||
| func (e *Executor) Start() { | ||||
| 	e.logger.Infow("正在启动任务执行器", "工作协程数", e.workers) | ||||
| // Start 启动调度器,包括主轮询循环和所有工作协程 | ||||
| func (s *Scheduler) Start() { | ||||
| 	s.logger.Printf("任务调度器正在启动,工作协程数: %d...", s.workers) | ||||
|  | ||||
| 	// 启动工作协程 | ||||
| 	for i := 0; i < e.workers; i++ { | ||||
| 		e.wg.Add(1) | ||||
| 		go e.worker(i) | ||||
| 	// 启动工作协程池 | ||||
| 	s.wg.Add(s.workers) | ||||
| 	for i := 0; i < s.workers; i++ { | ||||
| 		go s.worker(i) | ||||
| 	} | ||||
|  | ||||
| 	e.logger.Info("任务执行器启动成功") | ||||
| 	// 启动主轮询循环 | ||||
| 	s.wg.Add(1) | ||||
| 	go s.run() | ||||
|  | ||||
| 	s.logger.Printf("任务调度器已成功启动") | ||||
| } | ||||
|  | ||||
| // Stop 停止任务执行器 | ||||
| func (e *Executor) Stop() { | ||||
| 	e.logger.Info("正在停止任务执行器") | ||||
|  | ||||
| 	// 取消上下文 | ||||
| 	e.cancel() | ||||
|  | ||||
| 	// 等待所有工作协程结束 | ||||
| 	e.wg.Wait() | ||||
|  | ||||
| 	e.logger.Info("任务执行器已停止") | ||||
| // Stop 优雅地停止调度器和所有工作协程 | ||||
| func (s *Scheduler) Stop() { | ||||
| 	s.logger.Printf("正在停止任务调度器...") | ||||
| 	s.cancel()  // 发出取消信号 | ||||
| 	s.wg.Wait() // 等待所有协程完成 | ||||
| 	s.logger.Printf("任务调度器已安全停止") | ||||
| } | ||||
|  | ||||
| // SubmitTask 提交任务到执行器 | ||||
| func (e *Executor) SubmitTask(task Task) { | ||||
| 	e.queue.AddTask(task) | ||||
| 	e.logger.Infow("任务已提交", "任务ID", task.GetID(), "任务描述", task.GetDescription()) | ||||
| } | ||||
|  | ||||
| // worker 工作协程 | ||||
| func (e *Executor) worker(id int) { | ||||
| 	defer e.wg.Done() | ||||
|  | ||||
| 	e.logger.Infow("工作协程已启动", "工作协程ID", id) | ||||
| // run 是主轮询循环,负责从数据库认领任务并派发 | ||||
| func (s *Scheduler) run() { | ||||
| 	defer s.wg.Done() | ||||
| 	ticker := time.NewTicker(s.pollingInterval) | ||||
| 	defer ticker.Stop() | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-e.ctx.Done(): | ||||
| 			e.logger.Infow("工作协程已停止", "工作协程ID", id) | ||||
| 		case <-s.ctx.Done(): | ||||
| 			close(s.taskChannel) // 关闭 channel,让 workers 退出循环 | ||||
| 			return | ||||
| 		default: | ||||
| 			// 获取下一个任务 | ||||
| 			task := e.queue.GetNextTask() | ||||
| 			if task != nil { | ||||
| 				e.logger.Infow("工作协程正在执行任务", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription()) | ||||
|  | ||||
| 				// 执行任务 | ||||
| 				if err := task.Execute(); err != nil { | ||||
| 					e.logger.Errorw("任务执行失败", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription(), "错误", err) | ||||
| 				} else { | ||||
| 					e.logger.Infow("任务执行成功", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription()) | ||||
| 				} | ||||
| 			} else { | ||||
| 				// 没有任务时短暂休眠 | ||||
| 				time.Sleep(100 * time.Millisecond) | ||||
| 			} | ||||
| 		case <-ticker.C: | ||||
| 			s.claimAndDispatch() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // claimAndDispatch 认领一个任务并将其发送到派发通道 | ||||
| func (s *Scheduler) claimAndDispatch() { | ||||
| 	claimedLog, err := s.pendingTaskRepo.ClaimNextDueTask() | ||||
| 	if err != nil { | ||||
| 		if !errors.Is(err, gorm.ErrRecordNotFound) { | ||||
| 			s.logger.Printf("认领任务时发生错误: %v", err) | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 将认领到的任务发送到派发通道 | ||||
| 	// 如果所有 worker 都在忙,这里会阻塞,从而实现背压,防止队列无限增长 | ||||
| 	select { | ||||
| 	case s.taskChannel <- claimedLog: | ||||
| 		s.logger.Printf("成功认领并派发任务, 日志ID: %d, 任务ID: %d", claimedLog.ID, claimedLog.TaskID) | ||||
| 	case <-s.ctx.Done(): | ||||
| 		// 如果在等待派发时调度器被停止,需要处理这个未派发的任务 | ||||
| 		// 简单的处理方式是忽略它,让清理器进程后续来处理这个 'running' 状态的任务 | ||||
| 		s.logger.Printf("在派发任务时调度器被停止, 日志ID: %d", claimedLog.ID) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // worker 是工作协程的实现 | ||||
| func (s *Scheduler) worker(id int) { | ||||
| 	defer s.wg.Done() | ||||
| 	s.logger.Printf("工作协程 #%d 已启动", id) | ||||
| 	for claimedLog := range s.taskChannel { | ||||
| 		s.processTask(id, claimedLog) | ||||
| 	} | ||||
| 	s.logger.Printf("工作协程 #%d 已停止", id) | ||||
| } | ||||
|  | ||||
| // processTask 包含了处理单个任务的完整逻辑 | ||||
| func (s *Scheduler) processTask(workerID int, claimedLog *models.TaskExecutionLog) { | ||||
| 	s.logger.Printf("工作协程 #%d 正在处理任务, 日志ID: %d, 任务ID: %d, 任务名称: %s", | ||||
| 		workerID, claimedLog.ID, claimedLog.TaskID, claimedLog.Task.Name) | ||||
|  | ||||
| 	// 在这里,我们将根据 claimedLog.TaskID 或未来的 Task.Kind 来分发给不同的处理器 | ||||
| 	// 现在,我们只做一个模拟执行 | ||||
| 	time.Sleep(2 * time.Second) // 模拟任务执行耗时 | ||||
|  | ||||
| 	// 任务执行完毕后,更新日志和进度 | ||||
| 	s.logger.Printf("工作协程 #%d 已完成任务, 日志ID: %d", workerID, claimedLog.ID) | ||||
|  | ||||
| 	// ---------------------------------------------------- | ||||
| 	// 未来的逻辑将在这里展开: | ||||
| 	// | ||||
| 	// 1. 调用 handler.Handle(claimedLog) | ||||
| 	// 2. 根据 handler 返回的 error 更新日志为 'completed' 或 'failed' | ||||
| 	//    execLogRepo.UpdateTaskExecutionLog(...) | ||||
| 	// 3. 如果成功,则 s.progressTracker.Increment(claimedLog.PlanExecutionLogID) | ||||
| 	// 4. 检查 s.progressTracker.IsComplete(...),如果完成则执行计划收尾工作 | ||||
| 	// | ||||
| 	// ---------------------------------------------------- | ||||
| } | ||||
|   | ||||
| @@ -2,8 +2,6 @@ | ||||
| package task_test | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| @@ -11,8 +9,6 @@ import ( | ||||
|  | ||||
| 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/config" | ||||
| 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs" | ||||
| 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/task" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| // testLogger 是一个用于所有测试用例的静默 logger 实例。 | ||||
| @@ -79,230 +75,3 @@ func waitForWaitGroup(t *testing.T, wg *sync.WaitGroup, timeout time.Duration) { | ||||
| 		t.Fatal("等待任务完成超时") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // --- 任务队列测试 (无需更改) --- | ||||
|  | ||||
| func TestNewQueue(t *testing.T) { | ||||
| 	tq := task.NewQueue(testLogger) | ||||
| 	assert.NotNil(t, tq, "新创建的任务队列不应为 nil") | ||||
| 	assert.Equal(t, 0, tq.GetTaskCount(), "新创建的任务队列应为空") | ||||
| } | ||||
|  | ||||
| func TestQueue_AddTask(t *testing.T) { | ||||
| 	tq := task.NewQueue(testLogger) | ||||
| 	mockTask := &MockTask{id: "task1", priority: 1} | ||||
|  | ||||
| 	tq.AddTask(mockTask) | ||||
| 	assert.Equal(t, 1, tq.GetTaskCount(), "添加任务后,队列中的任务数应为 1") | ||||
| } | ||||
|  | ||||
| // ... (其他任务队列测试保持不变) | ||||
| func TestQueue_GetNextTask(t *testing.T) { | ||||
| 	t.Run("从空队列获取任务", func(t *testing.T) { | ||||
| 		tq := task.NewQueue(testLogger) | ||||
| 		nextTask := tq.GetNextTask() | ||||
| 		assert.Nil(t, nextTask, "从空队列中获取任务应返回 nil") | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("按优先级获取任务", func(t *testing.T) { | ||||
| 		tq := task.NewQueue(testLogger) | ||||
| 		task1 := &MockTask{id: "task1", priority: 10} | ||||
| 		task2 := &MockTask{id: "task2", priority: 1} // 优先级更高 | ||||
| 		task3 := &MockTask{id: "task3", priority: 5} | ||||
|  | ||||
| 		tq.AddTask(task1) | ||||
| 		tq.AddTask(task2) | ||||
| 		tq.AddTask(task3) | ||||
|  | ||||
| 		assert.Equal(t, 3, tq.GetTaskCount(), "添加三个任务后,队列中的任务数应为 3") | ||||
|  | ||||
| 		nextTask := tq.GetNextTask() | ||||
| 		assert.NotNil(t, nextTask) | ||||
| 		assert.Equal(t, "task2", nextTask.GetID(), "应首先获取优先级最高的任务 (task2)") | ||||
|  | ||||
| 		nextTask = tq.GetNextTask() | ||||
| 		assert.NotNil(t, nextTask) | ||||
| 		assert.Equal(t, "task3", nextTask.GetID(), "应获取下一个优先级最高的任务 (task3)") | ||||
|  | ||||
| 		nextTask = tq.GetNextTask() | ||||
| 		assert.NotNil(t, nextTask) | ||||
| 		assert.Equal(t, "task1", nextTask.GetID(), "应最后获取优先级最低的任务 (task1)") | ||||
|  | ||||
| 		assert.Equal(t, 0, tq.GetTaskCount(), "获取所有任务后,队列应为空") | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestQueue_Concurrency(t *testing.T) { | ||||
| 	tq := task.NewQueue(testLogger) | ||||
| 	var wg sync.WaitGroup | ||||
| 	taskCount := 100 | ||||
|  | ||||
| 	wg.Add(taskCount) | ||||
| 	for i := 0; i < taskCount; i++ { | ||||
| 		go func(i int) { | ||||
| 			defer wg.Done() | ||||
| 			tq.AddTask(&MockTask{id: fmt.Sprintf("task-%d", i), priority: i}) | ||||
| 		}(i) | ||||
| 	} | ||||
| 	wg.Wait() | ||||
|  | ||||
| 	assert.Equal(t, taskCount, tq.GetTaskCount(), "并发添加任务后,队列中的任务数应为 %d", taskCount) | ||||
|  | ||||
| 	wg.Add(taskCount) | ||||
| 	for i := 0; i < taskCount; i++ { | ||||
| 		go func() { | ||||
| 			defer wg.Done() | ||||
| 			task := tq.GetNextTask() | ||||
| 			assert.NotNil(t, task) | ||||
| 		}() | ||||
| 	} | ||||
| 	wg.Wait() | ||||
|  | ||||
| 	assert.Equal(t, 0, tq.GetTaskCount(), "并发获取所有任务后,队列应为空") | ||||
| } | ||||
|  | ||||
| // --- 执行器测试 (为可靠性重构) --- | ||||
|  | ||||
| func TestNewExecutor(t *testing.T) { | ||||
| 	executor := task.NewExecutor(5, testLogger) | ||||
| 	assert.NotNil(t, executor, "新创建的执行器不应为 nil") | ||||
| } | ||||
|  | ||||
| func TestExecutor_StartStop(t *testing.T) { | ||||
| 	executor := task.NewExecutor(2, testLogger) | ||||
| 	executor.Start() | ||||
| 	// 确保立即停止不会导致死锁或竞争条件。 | ||||
| 	executor.Stop() | ||||
| } | ||||
|  | ||||
| // TestExecutor_SubmitAndExecuteTask 测试提交并执行单个任务 (已重构,更可靠) | ||||
| func TestExecutor_SubmitAndExecuteTask(t *testing.T) { | ||||
| 	var wg sync.WaitGroup | ||||
| 	wg.Add(1) | ||||
|  | ||||
| 	executor := task.NewExecutor(1, testLogger) | ||||
| 	mockTask := &MockTask{ | ||||
| 		id:       "task1", | ||||
| 		priority: 1, | ||||
| 		execute: func() error { | ||||
| 			wg.Done() // 任务完成时通知 WaitGroup | ||||
| 			return nil | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	executor.Start() | ||||
| 	executor.SubmitTask(mockTask) | ||||
|  | ||||
| 	// 等待任务完成,设置一个合理的超时时间 | ||||
| 	waitForWaitGroup(t, &wg, 2*time.Second) | ||||
|  | ||||
| 	executor.Stop() | ||||
|  | ||||
| 	assert.Equal(t, int32(1), mockTask.ExecutedCount(), "任务应该已被执行") | ||||
| } | ||||
|  | ||||
| // TestExecutor_ExecuteMultipleTasks 测试执行多个任务 (已重构,更可靠) | ||||
| func TestExecutor_ExecuteMultipleTasks(t *testing.T) { | ||||
| 	taskCount := 10 | ||||
| 	var wg sync.WaitGroup | ||||
| 	wg.Add(taskCount) | ||||
|  | ||||
| 	executor := task.NewExecutor(3, testLogger) | ||||
| 	mockTasks := make([]*MockTask, taskCount) | ||||
| 	for i := 0; i < taskCount; i++ { | ||||
| 		mockTasks[i] = &MockTask{ | ||||
| 			id:       fmt.Sprintf("task-%d", i), | ||||
| 			priority: i, | ||||
| 			execute: func() error { | ||||
| 				wg.Done() // 每个任务完成时都通知 WaitGroup | ||||
| 				return nil | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	executor.Start() | ||||
| 	for _, task := range mockTasks { | ||||
| 		executor.SubmitTask(task) | ||||
| 	} | ||||
|  | ||||
| 	// 等待所有任务完成 | ||||
| 	waitForWaitGroup(t, &wg, 2*time.Second) | ||||
|  | ||||
| 	executor.Stop() | ||||
|  | ||||
| 	var totalExecuted int32 | ||||
| 	for _, task := range mockTasks { | ||||
| 		totalExecuted += task.ExecutedCount() | ||||
| 	} | ||||
|  | ||||
| 	assert.Equal(t, int32(taskCount), totalExecuted, "所有提交的任务都应该被执行") | ||||
| } | ||||
|  | ||||
| // TestExecutor_TaskExecutionError 测试任务执行失败的场景 (已重构,更可靠) | ||||
| func TestExecutor_TaskExecutionError(t *testing.T) { | ||||
| 	var wg sync.WaitGroup | ||||
| 	wg.Add(2) // 我们期望两个任务都被执行 | ||||
|  | ||||
| 	executor := task.NewExecutor(1, testLogger) | ||||
| 	errorTask := &MockTask{ | ||||
| 		id:       "errorTask", | ||||
| 		priority: 1, | ||||
| 		execute: func() error { | ||||
| 			wg.Done() | ||||
| 			return errors.New("执行失败") | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	successTask := &MockTask{ | ||||
| 		id:       "successTask", | ||||
| 		priority: 2, // 后执行 | ||||
| 		execute: func() error { | ||||
| 			wg.Done() | ||||
| 			return nil | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	executor.Start() | ||||
| 	executor.SubmitTask(errorTask) | ||||
| 	executor.SubmitTask(successTask) | ||||
|  | ||||
| 	waitForWaitGroup(t, &wg, 2*time.Second) | ||||
| 	executor.Stop() | ||||
|  | ||||
| 	assert.Equal(t, int32(1), errorTask.ExecutedCount(), "失败的任务应该被执行一次") | ||||
| 	assert.Equal(t, int32(1), successTask.ExecutedCount(), "成功的任务也应该被执行") | ||||
| } | ||||
|  | ||||
| // TestExecutor_StopWithPendingTasks 测试停止执行器时仍有待处理任务 (已重构,更可靠) | ||||
| func TestExecutor_StopWithPendingTasks(t *testing.T) { | ||||
| 	executor := task.NewExecutor(1, testLogger) | ||||
| 	task1Started := make(chan struct{}) | ||||
|  | ||||
| 	task1 := &MockTask{ | ||||
| 		id:       "task1", | ||||
| 		priority: 1, | ||||
| 		execute: func() error { | ||||
| 			close(task1Started)                // 发送信号,通知测试 task1 已开始执行 | ||||
| 			time.Sleep(200 * time.Millisecond) // 模拟耗时操作 | ||||
| 			return nil | ||||
| 		}, | ||||
| 	} | ||||
| 	task2 := &MockTask{id: "task2", priority: 2} | ||||
|  | ||||
| 	executor.Start() | ||||
| 	executor.SubmitTask(task1) | ||||
| 	executor.SubmitTask(task2) | ||||
|  | ||||
| 	// 等待 task1 开始执行的信号,而不是依赖不确定的 sleep | ||||
| 	select { | ||||
| 	case <-task1Started: | ||||
| 		// task1 已开始,可以安全地停止执行器了 | ||||
| 	case <-time.After(1 * time.Second): | ||||
| 		t.Fatal("等待 task1 启动超时") | ||||
| 	} | ||||
|  | ||||
| 	executor.Stop() | ||||
|  | ||||
| 	assert.Equal(t, int32(1), task1.ExecutedCount(), "task1 应该在停止前开始执行") | ||||
| 	assert.Equal(t, int32(0), task2.ExecutedCount(), "task2 不应该被执行,因为执行器已停止") | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user