Files
pig-farm-controller/internal/infra/task/task.go

186 lines
5.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package task
import (
"context"
"errors"
"sync"
"time"
"git.huangwc.com/pig/pig-farm-controller/internal/infra/models"
"git.huangwc.com/pig/pig-farm-controller/internal/infra/repository"
"github.com/panjf2000/ants/v2"
"gorm.io/gorm"
)
// Logger 定义了调度器期望的日志接口,方便替换为项目中的日志组件
type Logger interface {
Printf(format string, v ...interface{})
}
// ProgressTracker 在内存中跟踪正在运行的计划的完成进度
type ProgressTracker struct {
mu sync.Mutex
totalTasks map[uint]int // key: planExecutionLogID, value: total tasks
completedTasks map[uint]int // key: planExecutionLogID, value: completed tasks
}
func NewProgressTracker() *ProgressTracker {
return &ProgressTracker{
totalTasks: make(map[uint]int),
completedTasks: make(map[uint]int),
}
}
// Scheduler 是核心的、持久化的任务调度器
type Scheduler struct {
logger Logger
pollingInterval time.Duration
workers int
pendingTaskRepo repository.PendingTaskRepository
progressTracker *ProgressTracker
pool *ants.Pool // 使用 ants 协程池来管理并发
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}
// NewScheduler 创建一个新的调度器实例
func NewScheduler(pendingTaskRepo repository.PendingTaskRepository, logger Logger, interval time.Duration, numWorkers int) *Scheduler {
ctx, cancel := context.WithCancel(context.Background())
return &Scheduler{
pendingTaskRepo: pendingTaskRepo,
logger: logger,
pollingInterval: interval,
workers: numWorkers,
progressTracker: NewProgressTracker(),
ctx: ctx,
cancel: cancel,
}
}
// Start 启动调度器,包括初始化协程池和启动主轮询循环
func (s *Scheduler) Start() {
s.logger.Printf("任务调度器正在启动,工作协程数: %d...", s.workers)
// 初始化 ants 协程池
pool, err := ants.NewPool(s.workers, ants.WithPanicHandler(func(err interface{}) {
s.logger.Printf("[严重] 任务执行时发生 panic: %v", err)
}))
if err != nil {
panic("初始化协程池失败: " + err.Error())
}
s.pool = pool
// 启动主轮询循环
s.wg.Add(1)
go s.run()
s.logger.Printf("任务调度器已成功启动")
}
// Stop 优雅地停止调度器
func (s *Scheduler) Stop() {
s.logger.Printf("正在停止任务调度器...")
s.cancel() // 1. 发出取消信号,停止主循环
s.wg.Wait() // 2. 等待主循环完成
s.pool.Release() // 3. 释放 ants 池 (等待所有已提交的任务执行完毕)
s.logger.Printf("任务调度器已安全停止")
}
// run 是主轮询循环,负责从数据库认领任务并提交到协程池
func (s *Scheduler) run() {
defer s.wg.Done()
ticker := time.NewTicker(s.pollingInterval)
defer ticker.Stop()
for {
select {
case <-s.ctx.Done():
return
case <-ticker.C:
s.claimAndSubmit()
}
}
}
// claimAndSubmit 认领一个任务并将其提交到 ants 协程池
func (s *Scheduler) claimAndSubmit() {
// ants 池的 Running() 数量可以用来提前判断是否繁忙,但这只是一个快照,
// 真正的阻塞和背压由 Submit() 方法保证。
if s.pool.Running() >= s.workers {
// 可选:如果所有 worker 都在忙,可以跳过本次数据库查询,以减轻数据库压力
return
}
claimedLog, err := s.pendingTaskRepo.ClaimNextDueTask()
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
s.logger.Printf("认领任务时发生错误: %v", err)
}
return
}
// 将任务处理逻辑作为一个函数提交给 ants 池。
// 如果池已满Submit 方法会阻塞,直到有协程空闲出来,这自然地实现了背压。
err = s.pool.Submit(func() {
s.processTask(claimedLog)
})
if err != nil {
// 如果在调度器停止期间提交任务,可能会发生此错误
s.logger.Printf("向协程池提交任务失败: %v", err)
// 可以在这里添加逻辑,将任务状态恢复为 pending
}
}
// processTask 包含了处理单个任务的完整逻辑
func (s *Scheduler) processTask(claimedLog *models.TaskExecutionLog) {
s.logger.Printf("开始处理任务, 日志ID: %d, 任务ID: %d, 任务名称: %s",
claimedLog.ID, claimedLog.TaskID, claimedLog.Task.Name)
// 在这里,我们将根据 claimedLog.TaskID 或未来的 Task.Kind 来分发给不同的处理器
// 现在,我们只做一个模拟执行
time.Sleep(2 * time.Second) // 模拟任务执行耗时
// 任务执行完毕后,更新日志和进度
s.logger.Printf("完成任务, 日志ID: %d", claimedLog.ID)
// ----------------------------------------------------
// 未来的逻辑将在这里展开:
//
// 1. 调用 handler.Handle(claimedLog)
// 2. 根据 handler 返回的 error 更新日志为 'completed' 或 'failed'
// execLogRepo.UpdateTaskExecutionLog(...)
// 3. 如果成功,则 s.progressTracker.Increment(claimedLog.PlanExecutionLogID)
// 4. 检查 s.progressTracker.IsComplete(...),如果完成则执行计划收尾工作
//
// ----------------------------------------------------
}
// ProgressTracker 的方法实现
func (t *ProgressTracker) StartTracking(planLogID uint, total int) {
t.mu.Lock()
defer t.mu.Unlock()
t.totalTasks[planLogID] = total
t.completedTasks[planLogID] = 0
}
func (t *ProgressTracker) Increment(planLogID uint) {
t.mu.Lock()
defer t.mu.Unlock()
t.completedTasks[planLogID]++
}
func (t *ProgressTracker) IsComplete(planLogID uint) bool {
t.mu.Lock()
defer t.mu.Unlock()
return t.completedTasks[planLogID] >= t.totalTasks[planLogID]
}
func (t *ProgressTracker) StopTracking(planLogID uint) {
t.mu.Lock()
defer t.mu.Unlock()
delete(t.totalTasks, planLogID)
delete(t.completedTasks, planLogID)
}