重构 #4
| @@ -1,223 +1,188 @@ | |||||||
| package task | package task | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"container/heap" |  | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"errors" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs" | 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/models" | ||||||
|  | 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/repository" | ||||||
|  | 	"gorm.io/gorm" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Task 代表一个任务接口 | // Logger 定义了调度器期望的日志接口,方便替换为项目中的日志组件 | ||||||
| // 所有任务都需要实现此接口 | type Logger interface { | ||||||
| type Task interface { | 	Printf(format string, v ...interface{}) | ||||||
| 	// Execute 执行任务 |  | ||||||
| 	Execute() error |  | ||||||
|  |  | ||||||
| 	// GetID 获取任务ID |  | ||||||
| 	GetID() string |  | ||||||
|  |  | ||||||
| 	// GetPriority 获取任务优先级 |  | ||||||
| 	GetPriority() int |  | ||||||
|  |  | ||||||
| 	// IsDone 检查任务是否已完成 |  | ||||||
| 	IsDone() bool |  | ||||||
|  |  | ||||||
| 	// GetDescription 获取任务说明 |  | ||||||
| 	GetDescription() string |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // taskItem 任务队列中的元素 | // ProgressTracker 在内存中跟踪正在运行的计划的完成进度 | ||||||
| type taskItem struct { | type ProgressTracker struct { | ||||||
| 	task     Task | 	mu             sync.Mutex | ||||||
| 	priority int | 	totalTasks     map[uint]int // key: planExecutionLogID, value: total tasks | ||||||
| 	index    int | 	completedTasks map[uint]int // key: planExecutionLogID, value: completed tasks | ||||||
| } | } | ||||||
|  |  | ||||||
| // Queue 代表任务队列 | func NewProgressTracker() *ProgressTracker { | ||||||
| type Queue struct { | 	return &ProgressTracker{ | ||||||
| 	// queue 任务队列(按优先级排序) | 		totalTasks:     make(map[uint]int), | ||||||
| 	queue *priorityQueue | 		completedTasks: make(map[uint]int), | ||||||
|  |  | ||||||
| 	// mutex 互斥锁 |  | ||||||
| 	mutex sync.Mutex |  | ||||||
|  |  | ||||||
| 	// logger 日志记录器 |  | ||||||
| 	logger *logs.Logger |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // NewQueue 创建并返回一个新的任务队列实例。 |  | ||||||
| func NewQueue(logger *logs.Logger) *Queue { |  | ||||||
| 	pq := make(priorityQueue, 0) |  | ||||||
| 	heap.Init(&pq) |  | ||||||
|  |  | ||||||
| 	return &Queue{ |  | ||||||
| 		queue:  &pq, |  | ||||||
| 		logger: logger, |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // AddTask 向队列中添加任务 | // StartTracking 开始跟踪一个新的计划执行 | ||||||
| func (q *Queue) AddTask(task Task) { | func (t *ProgressTracker) StartTracking(planLogID uint, total int) { | ||||||
| 	q.mutex.Lock() | 	t.mu.Lock() | ||||||
| 	defer q.mutex.Unlock() | 	defer t.mu.Unlock() | ||||||
|  | 	t.totalTasks[planLogID] = total | ||||||
| 	item := &taskItem{ | 	t.completedTasks[planLogID] = 0 | ||||||
| 		task:     task, |  | ||||||
| 		priority: task.GetPriority(), |  | ||||||
| 	} |  | ||||||
| 	heap.Push(q.queue, item) |  | ||||||
| 	q.logger.Infow("任务已添加到队列", "任务ID", task.GetID(), "任务描述", task.GetDescription()) |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // GetNextTask 获取下一个要执行的任务(优先级最高的任务) | // Increment 将指定计划的完成计数加一 | ||||||
| func (q *Queue) GetNextTask() Task { | func (t *ProgressTracker) Increment(planLogID uint) { | ||||||
| 	q.mutex.Lock() | 	t.mu.Lock() | ||||||
| 	defer q.mutex.Unlock() | 	defer t.mu.Unlock() | ||||||
|  | 	t.completedTasks[planLogID]++ | ||||||
| 	if q.queue.Len() == 0 { |  | ||||||
| 		return nil |  | ||||||
| } | } | ||||||
|  |  | ||||||
| 	item := heap.Pop(q.queue).(*taskItem) | // IsComplete 检查指定计划是否已完成所有任务 | ||||||
| 	q.logger.Infow("从队列中获取任务", "任务ID", item.task.GetID(), "任务描述", item.task.GetDescription()) | func (t *ProgressTracker) IsComplete(planLogID uint) bool { | ||||||
| 	return item.task | 	t.mu.Lock() | ||||||
|  | 	defer t.mu.Unlock() | ||||||
|  | 	return t.completedTasks[planLogID] >= t.totalTasks[planLogID] | ||||||
| } | } | ||||||
|  |  | ||||||
| // GetTaskCount 获取队列中的任务数量 | // StopTracking 停止跟踪一个计划,清理内存 | ||||||
| func (q *Queue) GetTaskCount() int { | func (t *ProgressTracker) StopTracking(planLogID uint) { | ||||||
| 	q.mutex.Lock() | 	t.mu.Lock() | ||||||
| 	defer q.mutex.Unlock() | 	defer t.mu.Unlock() | ||||||
|  | 	delete(t.totalTasks, planLogID) | ||||||
| 	return q.queue.Len() | 	delete(t.completedTasks, planLogID) | ||||||
| } | } | ||||||
|  |  | ||||||
| // priorityQueue 实现优先级队列 | // Scheduler 是核心的、持久化的任务调度器 | ||||||
| type priorityQueue []*taskItem | type Scheduler struct { | ||||||
|  | 	logger          Logger | ||||||
| func (pq priorityQueue) Len() int { return len(pq) } | 	pollingInterval time.Duration | ||||||
|  |  | ||||||
| func (pq priorityQueue) Less(i, j int) bool { |  | ||||||
| 	return pq[i].priority < pq[j].priority |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (pq priorityQueue) Swap(i, j int) { |  | ||||||
| 	pq[i], pq[j] = pq[j], pq[i] |  | ||||||
| 	pq[i].index = i |  | ||||||
| 	pq[j].index = j |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (pq *priorityQueue) Push(x interface{}) { |  | ||||||
| 	n := len(*pq) |  | ||||||
| 	item := x.(*taskItem) |  | ||||||
| 	item.index = n |  | ||||||
| 	*pq = append(*pq, item) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (pq *priorityQueue) Pop() interface{} { |  | ||||||
| 	old := *pq |  | ||||||
| 	n := len(old) |  | ||||||
| 	item := old[n-1] |  | ||||||
| 	old[n-1] = nil  // 避免内存泄漏 |  | ||||||
| 	item.index = -1 // 无效索引 |  | ||||||
| 	*pq = old[0 : n-1] |  | ||||||
| 	return item |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Executor 代表任务执行器 |  | ||||||
| type Executor struct { |  | ||||||
| 	// queue 任务队列 |  | ||||||
| 	queue *Queue |  | ||||||
|  |  | ||||||
| 	// workers 工作协程数量 |  | ||||||
| 	workers         int | 	workers         int | ||||||
|  | 	pendingTaskRepo repository.PendingTaskRepository | ||||||
|  | 	progressTracker *ProgressTracker | ||||||
|  |  | ||||||
| 	// ctx 执行上下文 | 	taskChannel chan *models.TaskExecutionLog // 用于向 workers 派发任务的缓冲 channel | ||||||
| 	ctx context.Context |  | ||||||
|  |  | ||||||
| 	// cancel 取消函数 |  | ||||||
| 	cancel context.CancelFunc |  | ||||||
|  |  | ||||||
| 	// wg 等待组 |  | ||||||
| 	wg          sync.WaitGroup | 	wg          sync.WaitGroup | ||||||
|  | 	ctx         context.Context | ||||||
| 	// logger 日志记录器 | 	cancel      context.CancelFunc | ||||||
| 	logger *logs.Logger |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // NewExecutor 创建并返回一个新的任务执行器实例。 | // NewScheduler 创建一个新的调度器实例 | ||||||
| func NewExecutor(workers int, logger *logs.Logger) *Executor { | func NewScheduler(pendingTaskRepo repository.PendingTaskRepository, logger Logger, interval time.Duration, numWorkers int) *Scheduler { | ||||||
| 	ctx, cancel := context.WithCancel(context.Background()) | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
|  | 	return &Scheduler{ | ||||||
| 	return &Executor{ | 		pendingTaskRepo: pendingTaskRepo, | ||||||
| 		queue:   NewQueue(logger), // 将 logger 传递给 Queue | 		logger:          logger, | ||||||
| 		workers: workers, | 		pollingInterval: interval, | ||||||
|  | 		workers:         numWorkers, | ||||||
|  | 		progressTracker: NewProgressTracker(), | ||||||
|  | 		taskChannel:     make(chan *models.TaskExecutionLog, numWorkers), // 缓冲大小与 worker 数量一致 | ||||||
| 		ctx:             ctx, | 		ctx:             ctx, | ||||||
| 		cancel:          cancel, | 		cancel:          cancel, | ||||||
| 		logger:  logger, |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // Start 启动任务执行器 | // Start 启动调度器,包括主轮询循环和所有工作协程 | ||||||
| func (e *Executor) Start() { | func (s *Scheduler) Start() { | ||||||
| 	e.logger.Infow("正在启动任务执行器", "工作协程数", e.workers) | 	s.logger.Printf("任务调度器正在启动,工作协程数: %d...", s.workers) | ||||||
|  |  | ||||||
| 	// 启动工作协程 | 	// 启动工作协程池 | ||||||
| 	for i := 0; i < e.workers; i++ { | 	s.wg.Add(s.workers) | ||||||
| 		e.wg.Add(1) | 	for i := 0; i < s.workers; i++ { | ||||||
| 		go e.worker(i) | 		go s.worker(i) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	e.logger.Info("任务执行器启动成功") | 	// 启动主轮询循环 | ||||||
|  | 	s.wg.Add(1) | ||||||
|  | 	go s.run() | ||||||
|  |  | ||||||
|  | 	s.logger.Printf("任务调度器已成功启动") | ||||||
| } | } | ||||||
|  |  | ||||||
| // Stop 停止任务执行器 | // Stop 优雅地停止调度器和所有工作协程 | ||||||
| func (e *Executor) Stop() { | func (s *Scheduler) Stop() { | ||||||
| 	e.logger.Info("正在停止任务执行器") | 	s.logger.Printf("正在停止任务调度器...") | ||||||
|  | 	s.cancel()  // 发出取消信号 | ||||||
| 	// 取消上下文 | 	s.wg.Wait() // 等待所有协程完成 | ||||||
| 	e.cancel() | 	s.logger.Printf("任务调度器已安全停止") | ||||||
|  |  | ||||||
| 	// 等待所有工作协程结束 |  | ||||||
| 	e.wg.Wait() |  | ||||||
|  |  | ||||||
| 	e.logger.Info("任务执行器已停止") |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // SubmitTask 提交任务到执行器 | // run 是主轮询循环,负责从数据库认领任务并派发 | ||||||
| func (e *Executor) SubmitTask(task Task) { | func (s *Scheduler) run() { | ||||||
| 	e.queue.AddTask(task) | 	defer s.wg.Done() | ||||||
| 	e.logger.Infow("任务已提交", "任务ID", task.GetID(), "任务描述", task.GetDescription()) | 	ticker := time.NewTicker(s.pollingInterval) | ||||||
| } | 	defer ticker.Stop() | ||||||
|  |  | ||||||
| // worker 工作协程 |  | ||||||
| func (e *Executor) worker(id int) { |  | ||||||
| 	defer e.wg.Done() |  | ||||||
|  |  | ||||||
| 	e.logger.Infow("工作协程已启动", "工作协程ID", id) |  | ||||||
|  |  | ||||||
| 	for { | 	for { | ||||||
| 		select { | 		select { | ||||||
| 		case <-e.ctx.Done(): | 		case <-s.ctx.Done(): | ||||||
| 			e.logger.Infow("工作协程已停止", "工作协程ID", id) | 			close(s.taskChannel) // 关闭 channel,让 workers 退出循环 | ||||||
| 			return | 			return | ||||||
| 		default: | 		case <-ticker.C: | ||||||
| 			// 获取下一个任务 | 			s.claimAndDispatch() | ||||||
| 			task := e.queue.GetNextTask() | 		} | ||||||
| 			if task != nil { | 	} | ||||||
| 				e.logger.Infow("工作协程正在执行任务", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription()) | } | ||||||
|  |  | ||||||
| 				// 执行任务 | // claimAndDispatch 认领一个任务并将其发送到派发通道 | ||||||
| 				if err := task.Execute(); err != nil { | func (s *Scheduler) claimAndDispatch() { | ||||||
| 					e.logger.Errorw("任务执行失败", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription(), "错误", err) | 	claimedLog, err := s.pendingTaskRepo.ClaimNextDueTask() | ||||||
| 				} else { | 	if err != nil { | ||||||
| 					e.logger.Infow("任务执行成功", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription()) | 		if !errors.Is(err, gorm.ErrRecordNotFound) { | ||||||
|  | 			s.logger.Printf("认领任务时发生错误: %v", err) | ||||||
| 		} | 		} | ||||||
| 			} else { | 		return | ||||||
| 				// 没有任务时短暂休眠 | 	} | ||||||
| 				time.Sleep(100 * time.Millisecond) |  | ||||||
|  | 	// 将认领到的任务发送到派发通道 | ||||||
|  | 	// 如果所有 worker 都在忙,这里会阻塞,从而实现背压,防止队列无限增长 | ||||||
|  | 	select { | ||||||
|  | 	case s.taskChannel <- claimedLog: | ||||||
|  | 		s.logger.Printf("成功认领并派发任务, 日志ID: %d, 任务ID: %d", claimedLog.ID, claimedLog.TaskID) | ||||||
|  | 	case <-s.ctx.Done(): | ||||||
|  | 		// 如果在等待派发时调度器被停止,需要处理这个未派发的任务 | ||||||
|  | 		// 简单的处理方式是忽略它,让清理器进程后续来处理这个 'running' 状态的任务 | ||||||
|  | 		s.logger.Printf("在派发任务时调度器被停止, 日志ID: %d", claimedLog.ID) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // worker 是工作协程的实现 | ||||||
|  | func (s *Scheduler) worker(id int) { | ||||||
|  | 	defer s.wg.Done() | ||||||
|  | 	s.logger.Printf("工作协程 #%d 已启动", id) | ||||||
|  | 	for claimedLog := range s.taskChannel { | ||||||
|  | 		s.processTask(id, claimedLog) | ||||||
| 	} | 	} | ||||||
|  | 	s.logger.Printf("工作协程 #%d 已停止", id) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // processTask 包含了处理单个任务的完整逻辑 | ||||||
|  | func (s *Scheduler) processTask(workerID int, claimedLog *models.TaskExecutionLog) { | ||||||
|  | 	s.logger.Printf("工作协程 #%d 正在处理任务, 日志ID: %d, 任务ID: %d, 任务名称: %s", | ||||||
|  | 		workerID, claimedLog.ID, claimedLog.TaskID, claimedLog.Task.Name) | ||||||
|  |  | ||||||
|  | 	// 在这里,我们将根据 claimedLog.TaskID 或未来的 Task.Kind 来分发给不同的处理器 | ||||||
|  | 	// 现在,我们只做一个模拟执行 | ||||||
|  | 	time.Sleep(2 * time.Second) // 模拟任务执行耗时 | ||||||
|  |  | ||||||
|  | 	// 任务执行完毕后,更新日志和进度 | ||||||
|  | 	s.logger.Printf("工作协程 #%d 已完成任务, 日志ID: %d", workerID, claimedLog.ID) | ||||||
|  |  | ||||||
|  | 	// ---------------------------------------------------- | ||||||
|  | 	// 未来的逻辑将在这里展开: | ||||||
|  | 	// | ||||||
|  | 	// 1. 调用 handler.Handle(claimedLog) | ||||||
|  | 	// 2. 根据 handler 返回的 error 更新日志为 'completed' 或 'failed' | ||||||
|  | 	//    execLogRepo.UpdateTaskExecutionLog(...) | ||||||
|  | 	// 3. 如果成功,则 s.progressTracker.Increment(claimedLog.PlanExecutionLogID) | ||||||
|  | 	// 4. 检查 s.progressTracker.IsComplete(...),如果完成则执行计划收尾工作 | ||||||
|  | 	// | ||||||
|  | 	// ---------------------------------------------------- | ||||||
| } | } | ||||||
|   | |||||||
| @@ -2,8 +2,6 @@ | |||||||
| package task_test | package task_test | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"errors" |  | ||||||
| 	"fmt" |  | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| 	"testing" | 	"testing" | ||||||
| @@ -11,8 +9,6 @@ import ( | |||||||
|  |  | ||||||
| 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/config" | 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/config" | ||||||
| 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs" | 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs" | ||||||
| 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/task" |  | ||||||
| 	"github.com/stretchr/testify/assert" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // testLogger 是一个用于所有测试用例的静默 logger 实例。 | // testLogger 是一个用于所有测试用例的静默 logger 实例。 | ||||||
| @@ -79,230 +75,3 @@ func waitForWaitGroup(t *testing.T, wg *sync.WaitGroup, timeout time.Duration) { | |||||||
| 		t.Fatal("等待任务完成超时") | 		t.Fatal("等待任务完成超时") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // --- 任务队列测试 (无需更改) --- |  | ||||||
|  |  | ||||||
| func TestNewQueue(t *testing.T) { |  | ||||||
| 	tq := task.NewQueue(testLogger) |  | ||||||
| 	assert.NotNil(t, tq, "新创建的任务队列不应为 nil") |  | ||||||
| 	assert.Equal(t, 0, tq.GetTaskCount(), "新创建的任务队列应为空") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestQueue_AddTask(t *testing.T) { |  | ||||||
| 	tq := task.NewQueue(testLogger) |  | ||||||
| 	mockTask := &MockTask{id: "task1", priority: 1} |  | ||||||
|  |  | ||||||
| 	tq.AddTask(mockTask) |  | ||||||
| 	assert.Equal(t, 1, tq.GetTaskCount(), "添加任务后,队列中的任务数应为 1") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ... (其他任务队列测试保持不变) |  | ||||||
| func TestQueue_GetNextTask(t *testing.T) { |  | ||||||
| 	t.Run("从空队列获取任务", func(t *testing.T) { |  | ||||||
| 		tq := task.NewQueue(testLogger) |  | ||||||
| 		nextTask := tq.GetNextTask() |  | ||||||
| 		assert.Nil(t, nextTask, "从空队列中获取任务应返回 nil") |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	t.Run("按优先级获取任务", func(t *testing.T) { |  | ||||||
| 		tq := task.NewQueue(testLogger) |  | ||||||
| 		task1 := &MockTask{id: "task1", priority: 10} |  | ||||||
| 		task2 := &MockTask{id: "task2", priority: 1} // 优先级更高 |  | ||||||
| 		task3 := &MockTask{id: "task3", priority: 5} |  | ||||||
|  |  | ||||||
| 		tq.AddTask(task1) |  | ||||||
| 		tq.AddTask(task2) |  | ||||||
| 		tq.AddTask(task3) |  | ||||||
|  |  | ||||||
| 		assert.Equal(t, 3, tq.GetTaskCount(), "添加三个任务后,队列中的任务数应为 3") |  | ||||||
|  |  | ||||||
| 		nextTask := tq.GetNextTask() |  | ||||||
| 		assert.NotNil(t, nextTask) |  | ||||||
| 		assert.Equal(t, "task2", nextTask.GetID(), "应首先获取优先级最高的任务 (task2)") |  | ||||||
|  |  | ||||||
| 		nextTask = tq.GetNextTask() |  | ||||||
| 		assert.NotNil(t, nextTask) |  | ||||||
| 		assert.Equal(t, "task3", nextTask.GetID(), "应获取下一个优先级最高的任务 (task3)") |  | ||||||
|  |  | ||||||
| 		nextTask = tq.GetNextTask() |  | ||||||
| 		assert.NotNil(t, nextTask) |  | ||||||
| 		assert.Equal(t, "task1", nextTask.GetID(), "应最后获取优先级最低的任务 (task1)") |  | ||||||
|  |  | ||||||
| 		assert.Equal(t, 0, tq.GetTaskCount(), "获取所有任务后,队列应为空") |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestQueue_Concurrency(t *testing.T) { |  | ||||||
| 	tq := task.NewQueue(testLogger) |  | ||||||
| 	var wg sync.WaitGroup |  | ||||||
| 	taskCount := 100 |  | ||||||
|  |  | ||||||
| 	wg.Add(taskCount) |  | ||||||
| 	for i := 0; i < taskCount; i++ { |  | ||||||
| 		go func(i int) { |  | ||||||
| 			defer wg.Done() |  | ||||||
| 			tq.AddTask(&MockTask{id: fmt.Sprintf("task-%d", i), priority: i}) |  | ||||||
| 		}(i) |  | ||||||
| 	} |  | ||||||
| 	wg.Wait() |  | ||||||
|  |  | ||||||
| 	assert.Equal(t, taskCount, tq.GetTaskCount(), "并发添加任务后,队列中的任务数应为 %d", taskCount) |  | ||||||
|  |  | ||||||
| 	wg.Add(taskCount) |  | ||||||
| 	for i := 0; i < taskCount; i++ { |  | ||||||
| 		go func() { |  | ||||||
| 			defer wg.Done() |  | ||||||
| 			task := tq.GetNextTask() |  | ||||||
| 			assert.NotNil(t, task) |  | ||||||
| 		}() |  | ||||||
| 	} |  | ||||||
| 	wg.Wait() |  | ||||||
|  |  | ||||||
| 	assert.Equal(t, 0, tq.GetTaskCount(), "并发获取所有任务后,队列应为空") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // --- 执行器测试 (为可靠性重构) --- |  | ||||||
|  |  | ||||||
| func TestNewExecutor(t *testing.T) { |  | ||||||
| 	executor := task.NewExecutor(5, testLogger) |  | ||||||
| 	assert.NotNil(t, executor, "新创建的执行器不应为 nil") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestExecutor_StartStop(t *testing.T) { |  | ||||||
| 	executor := task.NewExecutor(2, testLogger) |  | ||||||
| 	executor.Start() |  | ||||||
| 	// 确保立即停止不会导致死锁或竞争条件。 |  | ||||||
| 	executor.Stop() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // TestExecutor_SubmitAndExecuteTask 测试提交并执行单个任务 (已重构,更可靠) |  | ||||||
| func TestExecutor_SubmitAndExecuteTask(t *testing.T) { |  | ||||||
| 	var wg sync.WaitGroup |  | ||||||
| 	wg.Add(1) |  | ||||||
|  |  | ||||||
| 	executor := task.NewExecutor(1, testLogger) |  | ||||||
| 	mockTask := &MockTask{ |  | ||||||
| 		id:       "task1", |  | ||||||
| 		priority: 1, |  | ||||||
| 		execute: func() error { |  | ||||||
| 			wg.Done() // 任务完成时通知 WaitGroup |  | ||||||
| 			return nil |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	executor.Start() |  | ||||||
| 	executor.SubmitTask(mockTask) |  | ||||||
|  |  | ||||||
| 	// 等待任务完成,设置一个合理的超时时间 |  | ||||||
| 	waitForWaitGroup(t, &wg, 2*time.Second) |  | ||||||
|  |  | ||||||
| 	executor.Stop() |  | ||||||
|  |  | ||||||
| 	assert.Equal(t, int32(1), mockTask.ExecutedCount(), "任务应该已被执行") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // TestExecutor_ExecuteMultipleTasks 测试执行多个任务 (已重构,更可靠) |  | ||||||
| func TestExecutor_ExecuteMultipleTasks(t *testing.T) { |  | ||||||
| 	taskCount := 10 |  | ||||||
| 	var wg sync.WaitGroup |  | ||||||
| 	wg.Add(taskCount) |  | ||||||
|  |  | ||||||
| 	executor := task.NewExecutor(3, testLogger) |  | ||||||
| 	mockTasks := make([]*MockTask, taskCount) |  | ||||||
| 	for i := 0; i < taskCount; i++ { |  | ||||||
| 		mockTasks[i] = &MockTask{ |  | ||||||
| 			id:       fmt.Sprintf("task-%d", i), |  | ||||||
| 			priority: i, |  | ||||||
| 			execute: func() error { |  | ||||||
| 				wg.Done() // 每个任务完成时都通知 WaitGroup |  | ||||||
| 				return nil |  | ||||||
| 			}, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	executor.Start() |  | ||||||
| 	for _, task := range mockTasks { |  | ||||||
| 		executor.SubmitTask(task) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// 等待所有任务完成 |  | ||||||
| 	waitForWaitGroup(t, &wg, 2*time.Second) |  | ||||||
|  |  | ||||||
| 	executor.Stop() |  | ||||||
|  |  | ||||||
| 	var totalExecuted int32 |  | ||||||
| 	for _, task := range mockTasks { |  | ||||||
| 		totalExecuted += task.ExecutedCount() |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	assert.Equal(t, int32(taskCount), totalExecuted, "所有提交的任务都应该被执行") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // TestExecutor_TaskExecutionError 测试任务执行失败的场景 (已重构,更可靠) |  | ||||||
| func TestExecutor_TaskExecutionError(t *testing.T) { |  | ||||||
| 	var wg sync.WaitGroup |  | ||||||
| 	wg.Add(2) // 我们期望两个任务都被执行 |  | ||||||
|  |  | ||||||
| 	executor := task.NewExecutor(1, testLogger) |  | ||||||
| 	errorTask := &MockTask{ |  | ||||||
| 		id:       "errorTask", |  | ||||||
| 		priority: 1, |  | ||||||
| 		execute: func() error { |  | ||||||
| 			wg.Done() |  | ||||||
| 			return errors.New("执行失败") |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	successTask := &MockTask{ |  | ||||||
| 		id:       "successTask", |  | ||||||
| 		priority: 2, // 后执行 |  | ||||||
| 		execute: func() error { |  | ||||||
| 			wg.Done() |  | ||||||
| 			return nil |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	executor.Start() |  | ||||||
| 	executor.SubmitTask(errorTask) |  | ||||||
| 	executor.SubmitTask(successTask) |  | ||||||
|  |  | ||||||
| 	waitForWaitGroup(t, &wg, 2*time.Second) |  | ||||||
| 	executor.Stop() |  | ||||||
|  |  | ||||||
| 	assert.Equal(t, int32(1), errorTask.ExecutedCount(), "失败的任务应该被执行一次") |  | ||||||
| 	assert.Equal(t, int32(1), successTask.ExecutedCount(), "成功的任务也应该被执行") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // TestExecutor_StopWithPendingTasks 测试停止执行器时仍有待处理任务 (已重构,更可靠) |  | ||||||
| func TestExecutor_StopWithPendingTasks(t *testing.T) { |  | ||||||
| 	executor := task.NewExecutor(1, testLogger) |  | ||||||
| 	task1Started := make(chan struct{}) |  | ||||||
|  |  | ||||||
| 	task1 := &MockTask{ |  | ||||||
| 		id:       "task1", |  | ||||||
| 		priority: 1, |  | ||||||
| 		execute: func() error { |  | ||||||
| 			close(task1Started)                // 发送信号,通知测试 task1 已开始执行 |  | ||||||
| 			time.Sleep(200 * time.Millisecond) // 模拟耗时操作 |  | ||||||
| 			return nil |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 	task2 := &MockTask{id: "task2", priority: 2} |  | ||||||
|  |  | ||||||
| 	executor.Start() |  | ||||||
| 	executor.SubmitTask(task1) |  | ||||||
| 	executor.SubmitTask(task2) |  | ||||||
|  |  | ||||||
| 	// 等待 task1 开始执行的信号,而不是依赖不确定的 sleep |  | ||||||
| 	select { |  | ||||||
| 	case <-task1Started: |  | ||||||
| 		// task1 已开始,可以安全地停止执行器了 |  | ||||||
| 	case <-time.After(1 * time.Second): |  | ||||||
| 		t.Fatal("等待 task1 启动超时") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	executor.Stop() |  | ||||||
|  |  | ||||||
| 	assert.Equal(t, int32(1), task1.ExecutedCount(), "task1 应该在停止前开始执行") |  | ||||||
| 	assert.Equal(t, int32(0), task2.ExecutedCount(), "task2 不应该被执行,因为执行器已停止") |  | ||||||
| } |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user