重写task调度器之简单实现
This commit is contained in:
@@ -1,223 +1,188 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs"
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/models"
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/repository"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Task 代表一个任务接口
|
||||
// 所有任务都需要实现此接口
|
||||
type Task interface {
|
||||
// Execute 执行任务
|
||||
Execute() error
|
||||
|
||||
// GetID 获取任务ID
|
||||
GetID() string
|
||||
|
||||
// GetPriority 获取任务优先级
|
||||
GetPriority() int
|
||||
|
||||
// IsDone 检查任务是否已完成
|
||||
IsDone() bool
|
||||
|
||||
// GetDescription 获取任务说明
|
||||
GetDescription() string
|
||||
// Logger 定义了调度器期望的日志接口,方便替换为项目中的日志组件
|
||||
type Logger interface {
|
||||
Printf(format string, v ...interface{})
|
||||
}
|
||||
|
||||
// taskItem 任务队列中的元素
|
||||
type taskItem struct {
|
||||
task Task
|
||||
priority int
|
||||
index int
|
||||
// ProgressTracker 在内存中跟踪正在运行的计划的完成进度
|
||||
type ProgressTracker struct {
|
||||
mu sync.Mutex
|
||||
totalTasks map[uint]int // key: planExecutionLogID, value: total tasks
|
||||
completedTasks map[uint]int // key: planExecutionLogID, value: completed tasks
|
||||
}
|
||||
|
||||
// Queue 代表任务队列
|
||||
type Queue struct {
|
||||
// queue 任务队列(按优先级排序)
|
||||
queue *priorityQueue
|
||||
|
||||
// mutex 互斥锁
|
||||
mutex sync.Mutex
|
||||
|
||||
// logger 日志记录器
|
||||
logger *logs.Logger
|
||||
}
|
||||
|
||||
// NewQueue 创建并返回一个新的任务队列实例。
|
||||
func NewQueue(logger *logs.Logger) *Queue {
|
||||
pq := make(priorityQueue, 0)
|
||||
heap.Init(&pq)
|
||||
|
||||
return &Queue{
|
||||
queue: &pq,
|
||||
logger: logger,
|
||||
func NewProgressTracker() *ProgressTracker {
|
||||
return &ProgressTracker{
|
||||
totalTasks: make(map[uint]int),
|
||||
completedTasks: make(map[uint]int),
|
||||
}
|
||||
}
|
||||
|
||||
// AddTask 向队列中添加任务
|
||||
func (q *Queue) AddTask(task Task) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
item := &taskItem{
|
||||
task: task,
|
||||
priority: task.GetPriority(),
|
||||
}
|
||||
heap.Push(q.queue, item)
|
||||
q.logger.Infow("任务已添加到队列", "任务ID", task.GetID(), "任务描述", task.GetDescription())
|
||||
// StartTracking 开始跟踪一个新的计划执行
|
||||
func (t *ProgressTracker) StartTracking(planLogID uint, total int) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.totalTasks[planLogID] = total
|
||||
t.completedTasks[planLogID] = 0
|
||||
}
|
||||
|
||||
// GetNextTask 获取下一个要执行的任务(优先级最高的任务)
|
||||
func (q *Queue) GetNextTask() Task {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
if q.queue.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
item := heap.Pop(q.queue).(*taskItem)
|
||||
q.logger.Infow("从队列中获取任务", "任务ID", item.task.GetID(), "任务描述", item.task.GetDescription())
|
||||
return item.task
|
||||
// Increment 将指定计划的完成计数加一
|
||||
func (t *ProgressTracker) Increment(planLogID uint) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.completedTasks[planLogID]++
|
||||
}
|
||||
|
||||
// GetTaskCount 获取队列中的任务数量
|
||||
func (q *Queue) GetTaskCount() int {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
return q.queue.Len()
|
||||
// IsComplete 检查指定计划是否已完成所有任务
|
||||
func (t *ProgressTracker) IsComplete(planLogID uint) bool {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.completedTasks[planLogID] >= t.totalTasks[planLogID]
|
||||
}
|
||||
|
||||
// priorityQueue 实现优先级队列
|
||||
type priorityQueue []*taskItem
|
||||
|
||||
func (pq priorityQueue) Len() int { return len(pq) }
|
||||
|
||||
func (pq priorityQueue) Less(i, j int) bool {
|
||||
return pq[i].priority < pq[j].priority
|
||||
// StopTracking 停止跟踪一个计划,清理内存
|
||||
func (t *ProgressTracker) StopTracking(planLogID uint) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
delete(t.totalTasks, planLogID)
|
||||
delete(t.completedTasks, planLogID)
|
||||
}
|
||||
|
||||
func (pq priorityQueue) Swap(i, j int) {
|
||||
pq[i], pq[j] = pq[j], pq[i]
|
||||
pq[i].index = i
|
||||
pq[j].index = j
|
||||
// Scheduler 是核心的、持久化的任务调度器
|
||||
type Scheduler struct {
|
||||
logger Logger
|
||||
pollingInterval time.Duration
|
||||
workers int
|
||||
pendingTaskRepo repository.PendingTaskRepository
|
||||
progressTracker *ProgressTracker
|
||||
|
||||
taskChannel chan *models.TaskExecutionLog // 用于向 workers 派发任务的缓冲 channel
|
||||
wg sync.WaitGroup
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (pq *priorityQueue) Push(x interface{}) {
|
||||
n := len(*pq)
|
||||
item := x.(*taskItem)
|
||||
item.index = n
|
||||
*pq = append(*pq, item)
|
||||
}
|
||||
|
||||
func (pq *priorityQueue) Pop() interface{} {
|
||||
old := *pq
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
old[n-1] = nil // 避免内存泄漏
|
||||
item.index = -1 // 无效索引
|
||||
*pq = old[0 : n-1]
|
||||
return item
|
||||
}
|
||||
|
||||
// Executor 代表任务执行器
|
||||
type Executor struct {
|
||||
// queue 任务队列
|
||||
queue *Queue
|
||||
|
||||
// workers 工作协程数量
|
||||
workers int
|
||||
|
||||
// ctx 执行上下文
|
||||
ctx context.Context
|
||||
|
||||
// cancel 取消函数
|
||||
cancel context.CancelFunc
|
||||
|
||||
// wg 等待组
|
||||
wg sync.WaitGroup
|
||||
|
||||
// logger 日志记录器
|
||||
logger *logs.Logger
|
||||
}
|
||||
|
||||
// NewExecutor 创建并返回一个新的任务执行器实例。
|
||||
func NewExecutor(workers int, logger *logs.Logger) *Executor {
|
||||
// NewScheduler 创建一个新的调度器实例
|
||||
func NewScheduler(pendingTaskRepo repository.PendingTaskRepository, logger Logger, interval time.Duration, numWorkers int) *Scheduler {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &Executor{
|
||||
queue: NewQueue(logger), // 将 logger 传递给 Queue
|
||||
workers: workers,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
return &Scheduler{
|
||||
pendingTaskRepo: pendingTaskRepo,
|
||||
logger: logger,
|
||||
pollingInterval: interval,
|
||||
workers: numWorkers,
|
||||
progressTracker: NewProgressTracker(),
|
||||
taskChannel: make(chan *models.TaskExecutionLog, numWorkers), // 缓冲大小与 worker 数量一致
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动任务执行器
|
||||
func (e *Executor) Start() {
|
||||
e.logger.Infow("正在启动任务执行器", "工作协程数", e.workers)
|
||||
// Start 启动调度器,包括主轮询循环和所有工作协程
|
||||
func (s *Scheduler) Start() {
|
||||
s.logger.Printf("任务调度器正在启动,工作协程数: %d...", s.workers)
|
||||
|
||||
// 启动工作协程
|
||||
for i := 0; i < e.workers; i++ {
|
||||
e.wg.Add(1)
|
||||
go e.worker(i)
|
||||
// 启动工作协程池
|
||||
s.wg.Add(s.workers)
|
||||
for i := 0; i < s.workers; i++ {
|
||||
go s.worker(i)
|
||||
}
|
||||
|
||||
e.logger.Info("任务执行器启动成功")
|
||||
// 启动主轮询循环
|
||||
s.wg.Add(1)
|
||||
go s.run()
|
||||
|
||||
s.logger.Printf("任务调度器已成功启动")
|
||||
}
|
||||
|
||||
// Stop 停止任务执行器
|
||||
func (e *Executor) Stop() {
|
||||
e.logger.Info("正在停止任务执行器")
|
||||
|
||||
// 取消上下文
|
||||
e.cancel()
|
||||
|
||||
// 等待所有工作协程结束
|
||||
e.wg.Wait()
|
||||
|
||||
e.logger.Info("任务执行器已停止")
|
||||
// Stop 优雅地停止调度器和所有工作协程
|
||||
func (s *Scheduler) Stop() {
|
||||
s.logger.Printf("正在停止任务调度器...")
|
||||
s.cancel() // 发出取消信号
|
||||
s.wg.Wait() // 等待所有协程完成
|
||||
s.logger.Printf("任务调度器已安全停止")
|
||||
}
|
||||
|
||||
// SubmitTask 提交任务到执行器
|
||||
func (e *Executor) SubmitTask(task Task) {
|
||||
e.queue.AddTask(task)
|
||||
e.logger.Infow("任务已提交", "任务ID", task.GetID(), "任务描述", task.GetDescription())
|
||||
}
|
||||
|
||||
// worker 工作协程
|
||||
func (e *Executor) worker(id int) {
|
||||
defer e.wg.Done()
|
||||
|
||||
e.logger.Infow("工作协程已启动", "工作协程ID", id)
|
||||
// run 是主轮询循环,负责从数据库认领任务并派发
|
||||
func (s *Scheduler) run() {
|
||||
defer s.wg.Done()
|
||||
ticker := time.NewTicker(s.pollingInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-e.ctx.Done():
|
||||
e.logger.Infow("工作协程已停止", "工作协程ID", id)
|
||||
case <-s.ctx.Done():
|
||||
close(s.taskChannel) // 关闭 channel,让 workers 退出循环
|
||||
return
|
||||
default:
|
||||
// 获取下一个任务
|
||||
task := e.queue.GetNextTask()
|
||||
if task != nil {
|
||||
e.logger.Infow("工作协程正在执行任务", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription())
|
||||
|
||||
// 执行任务
|
||||
if err := task.Execute(); err != nil {
|
||||
e.logger.Errorw("任务执行失败", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription(), "错误", err)
|
||||
} else {
|
||||
e.logger.Infow("任务执行成功", "工作协程ID", id, "任务ID", task.GetID(), "任务描述", task.GetDescription())
|
||||
}
|
||||
} else {
|
||||
// 没有任务时短暂休眠
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
case <-ticker.C:
|
||||
s.claimAndDispatch()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// claimAndDispatch 认领一个任务并将其发送到派发通道
|
||||
func (s *Scheduler) claimAndDispatch() {
|
||||
claimedLog, err := s.pendingTaskRepo.ClaimNextDueTask()
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
s.logger.Printf("认领任务时发生错误: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 将认领到的任务发送到派发通道
|
||||
// 如果所有 worker 都在忙,这里会阻塞,从而实现背压,防止队列无限增长
|
||||
select {
|
||||
case s.taskChannel <- claimedLog:
|
||||
s.logger.Printf("成功认领并派发任务, 日志ID: %d, 任务ID: %d", claimedLog.ID, claimedLog.TaskID)
|
||||
case <-s.ctx.Done():
|
||||
// 如果在等待派发时调度器被停止,需要处理这个未派发的任务
|
||||
// 简单的处理方式是忽略它,让清理器进程后续来处理这个 'running' 状态的任务
|
||||
s.logger.Printf("在派发任务时调度器被停止, 日志ID: %d", claimedLog.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// worker 是工作协程的实现
|
||||
func (s *Scheduler) worker(id int) {
|
||||
defer s.wg.Done()
|
||||
s.logger.Printf("工作协程 #%d 已启动", id)
|
||||
for claimedLog := range s.taskChannel {
|
||||
s.processTask(id, claimedLog)
|
||||
}
|
||||
s.logger.Printf("工作协程 #%d 已停止", id)
|
||||
}
|
||||
|
||||
// processTask 包含了处理单个任务的完整逻辑
|
||||
func (s *Scheduler) processTask(workerID int, claimedLog *models.TaskExecutionLog) {
|
||||
s.logger.Printf("工作协程 #%d 正在处理任务, 日志ID: %d, 任务ID: %d, 任务名称: %s",
|
||||
workerID, claimedLog.ID, claimedLog.TaskID, claimedLog.Task.Name)
|
||||
|
||||
// 在这里,我们将根据 claimedLog.TaskID 或未来的 Task.Kind 来分发给不同的处理器
|
||||
// 现在,我们只做一个模拟执行
|
||||
time.Sleep(2 * time.Second) // 模拟任务执行耗时
|
||||
|
||||
// 任务执行完毕后,更新日志和进度
|
||||
s.logger.Printf("工作协程 #%d 已完成任务, 日志ID: %d", workerID, claimedLog.ID)
|
||||
|
||||
// ----------------------------------------------------
|
||||
// 未来的逻辑将在这里展开:
|
||||
//
|
||||
// 1. 调用 handler.Handle(claimedLog)
|
||||
// 2. 根据 handler 返回的 error 更新日志为 'completed' 或 'failed'
|
||||
// execLogRepo.UpdateTaskExecutionLog(...)
|
||||
// 3. 如果成功,则 s.progressTracker.Increment(claimedLog.PlanExecutionLogID)
|
||||
// 4. 检查 s.progressTracker.IsComplete(...),如果完成则执行计划收尾工作
|
||||
//
|
||||
// ----------------------------------------------------
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user