Files
pig-farm-controller/internal/infra/task/task.go

189 lines
5.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package task
import (
"context"
"errors"
"sync"
"time"
"git.huangwc.com/pig/pig-farm-controller/internal/infra/models"
"git.huangwc.com/pig/pig-farm-controller/internal/infra/repository"
"gorm.io/gorm"
)
// Logger 定义了调度器期望的日志接口,方便替换为项目中的日志组件
type Logger interface {
Printf(format string, v ...interface{})
}
// ProgressTracker 在内存中跟踪正在运行的计划的完成进度
type ProgressTracker struct {
mu sync.Mutex
totalTasks map[uint]int // key: planExecutionLogID, value: total tasks
completedTasks map[uint]int // key: planExecutionLogID, value: completed tasks
}
func NewProgressTracker() *ProgressTracker {
return &ProgressTracker{
totalTasks: make(map[uint]int),
completedTasks: make(map[uint]int),
}
}
// StartTracking 开始跟踪一个新的计划执行
func (t *ProgressTracker) StartTracking(planLogID uint, total int) {
t.mu.Lock()
defer t.mu.Unlock()
t.totalTasks[planLogID] = total
t.completedTasks[planLogID] = 0
}
// Increment 将指定计划的完成计数加一
func (t *ProgressTracker) Increment(planLogID uint) {
t.mu.Lock()
defer t.mu.Unlock()
t.completedTasks[planLogID]++
}
// IsComplete 检查指定计划是否已完成所有任务
func (t *ProgressTracker) IsComplete(planLogID uint) bool {
t.mu.Lock()
defer t.mu.Unlock()
return t.completedTasks[planLogID] >= t.totalTasks[planLogID]
}
// StopTracking 停止跟踪一个计划,清理内存
func (t *ProgressTracker) StopTracking(planLogID uint) {
t.mu.Lock()
defer t.mu.Unlock()
delete(t.totalTasks, planLogID)
delete(t.completedTasks, planLogID)
}
// Scheduler 是核心的、持久化的任务调度器
type Scheduler struct {
logger Logger
pollingInterval time.Duration
workers int
pendingTaskRepo repository.PendingTaskRepository
progressTracker *ProgressTracker
taskChannel chan *models.TaskExecutionLog // 用于向 workers 派发任务的缓冲 channel
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}
// NewScheduler 创建一个新的调度器实例
func NewScheduler(pendingTaskRepo repository.PendingTaskRepository, logger Logger, interval time.Duration, numWorkers int) *Scheduler {
ctx, cancel := context.WithCancel(context.Background())
return &Scheduler{
pendingTaskRepo: pendingTaskRepo,
logger: logger,
pollingInterval: interval,
workers: numWorkers,
progressTracker: NewProgressTracker(),
taskChannel: make(chan *models.TaskExecutionLog, numWorkers), // 缓冲大小与 worker 数量一致
ctx: ctx,
cancel: cancel,
}
}
// Start 启动调度器,包括主轮询循环和所有工作协程
func (s *Scheduler) Start() {
s.logger.Printf("任务调度器正在启动,工作协程数: %d...", s.workers)
// 启动工作协程池
s.wg.Add(s.workers)
for i := 0; i < s.workers; i++ {
go s.worker(i)
}
// 启动主轮询循环
s.wg.Add(1)
go s.run()
s.logger.Printf("任务调度器已成功启动")
}
// Stop 优雅地停止调度器和所有工作协程
func (s *Scheduler) Stop() {
s.logger.Printf("正在停止任务调度器...")
s.cancel() // 发出取消信号
s.wg.Wait() // 等待所有协程完成
s.logger.Printf("任务调度器已安全停止")
}
// run 是主轮询循环,负责从数据库认领任务并派发
func (s *Scheduler) run() {
defer s.wg.Done()
ticker := time.NewTicker(s.pollingInterval)
defer ticker.Stop()
for {
select {
case <-s.ctx.Done():
close(s.taskChannel) // 关闭 channel让 workers 退出循环
return
case <-ticker.C:
s.claimAndDispatch()
}
}
}
// claimAndDispatch 认领一个任务并将其发送到派发通道
func (s *Scheduler) claimAndDispatch() {
claimedLog, err := s.pendingTaskRepo.ClaimNextDueTask()
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
s.logger.Printf("认领任务时发生错误: %v", err)
}
return
}
// 将认领到的任务发送到派发通道
// 如果所有 worker 都在忙,这里会阻塞,从而实现背压,防止队列无限增长
select {
case s.taskChannel <- claimedLog:
s.logger.Printf("成功认领并派发任务, 日志ID: %d, 任务ID: %d", claimedLog.ID, claimedLog.TaskID)
case <-s.ctx.Done():
// 如果在等待派发时调度器被停止,需要处理这个未派发的任务
// 简单的处理方式是忽略它,让清理器进程后续来处理这个 'running' 状态的任务
s.logger.Printf("在派发任务时调度器被停止, 日志ID: %d", claimedLog.ID)
}
}
// worker 是工作协程的实现
func (s *Scheduler) worker(id int) {
defer s.wg.Done()
s.logger.Printf("工作协程 #%d 已启动", id)
for claimedLog := range s.taskChannel {
s.processTask(id, claimedLog)
}
s.logger.Printf("工作协程 #%d 已停止", id)
}
// processTask 包含了处理单个任务的完整逻辑
func (s *Scheduler) processTask(workerID int, claimedLog *models.TaskExecutionLog) {
s.logger.Printf("工作协程 #%d 正在处理任务, 日志ID: %d, 任务ID: %d, 任务名称: %s",
workerID, claimedLog.ID, claimedLog.TaskID, claimedLog.Task.Name)
// 在这里,我们将根据 claimedLog.TaskID 或未来的 Task.Kind 来分发给不同的处理器
// 现在,我们只做一个模拟执行
time.Sleep(2 * time.Second) // 模拟任务执行耗时
// 任务执行完毕后,更新日志和进度
s.logger.Printf("工作协程 #%d 已完成任务, 日志ID: %d", workerID, claimedLog.ID)
// ----------------------------------------------------
// 未来的逻辑将在这里展开:
//
// 1. 调用 handler.Handle(claimedLog)
// 2. 根据 handler 返回的 error 更新日志为 'completed' 或 'failed'
// execLogRepo.UpdateTaskExecutionLog(...)
// 3. 如果成功,则 s.progressTracker.Increment(claimedLog.PlanExecutionLogID)
// 4. 检查 s.progressTracker.IsComplete(...),如果完成则执行计划收尾工作
//
// ----------------------------------------------------
}