1. 调整目录结构
2. 实现user_controller
This commit is contained in:
142
internal/infra/config/config.go
Normal file
142
internal/infra/config/config.go
Normal file
@@ -0,0 +1,142 @@
|
||||
// Package config 提供配置文件读取和解析功能
|
||||
// 支持YAML格式的配置文件解析
|
||||
// 包含服务器和数据库相关配置
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// Config 代表应用的完整配置结构
|
||||
type Config struct {
|
||||
// App 应用基础配置
|
||||
App AppConfig `yaml:"app"`
|
||||
|
||||
// Server 服务器配置
|
||||
Server ServerConfig `yaml:"server"`
|
||||
|
||||
// Log 日志配置
|
||||
Log LogConfig `yaml:"log"`
|
||||
|
||||
// Database 数据库配置
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
|
||||
// WebSocket WebSocket配置
|
||||
WebSocket WebSocketConfig `yaml:"websocket"`
|
||||
|
||||
// Heartbeat 心跳配置
|
||||
Heartbeat HeartbeatConfig `yaml:"heartbeat"`
|
||||
}
|
||||
|
||||
// AppConfig 代表应用基础配置
|
||||
type AppConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Version string `yaml:"version"`
|
||||
}
|
||||
|
||||
// ServerConfig 代表服务器配置
|
||||
type ServerConfig struct {
|
||||
// Port 服务器监听端口
|
||||
Port int `yaml:"port"`
|
||||
// Mode 服务器运行模式
|
||||
Mode string `yaml:"mode"`
|
||||
}
|
||||
|
||||
// LogConfig 代表日志配置
|
||||
type LogConfig struct {
|
||||
Level string `yaml:"level"`
|
||||
Format string `yaml:"format"`
|
||||
EnableFile bool `yaml:"enable_file"`
|
||||
FilePath string `yaml:"file_path"`
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
Compress bool `yaml:"compress"`
|
||||
}
|
||||
|
||||
// DatabaseConfig 代表数据库配置
|
||||
type DatabaseConfig struct {
|
||||
// Host 数据库主机地址
|
||||
Host string `yaml:"host"`
|
||||
|
||||
// Port 数据库端口
|
||||
Port int `yaml:"port"`
|
||||
|
||||
// Username 数据库用户名
|
||||
Username string `yaml:"username"`
|
||||
|
||||
// Password 数据库密码
|
||||
Password string `yaml:"password"`
|
||||
|
||||
// DBName 数据库名称
|
||||
DBName string `yaml:"dbname"`
|
||||
|
||||
// SSLMode SSL模式
|
||||
SSLMode string `yaml:"sslmode"`
|
||||
|
||||
// MaxOpenConns 最大开放连接数
|
||||
MaxOpenConns int `yaml:"max_open_conns"`
|
||||
|
||||
// MaxIdleConns 最大空闲连接数
|
||||
MaxIdleConns int `yaml:"max_idle_conns"`
|
||||
|
||||
// ConnMaxLifetime 连接最大生命周期(秒)
|
||||
ConnMaxLifetime int `yaml:"conn_max_lifetime"`
|
||||
}
|
||||
|
||||
// WebSocketConfig 代表WebSocket配置
|
||||
type WebSocketConfig struct {
|
||||
// Timeout WebSocket请求超时时间(秒)
|
||||
Timeout int `yaml:"timeout"`
|
||||
|
||||
// HeartbeatInterval 心跳检测间隔(秒), 如果超过这个时间没有消息往来系统会自动发送一个心跳包维持长链接
|
||||
HeartbeatInterval int `yaml:"heartbeat_interval"`
|
||||
}
|
||||
|
||||
// HeartbeatConfig 代表心跳配置
|
||||
type HeartbeatConfig struct {
|
||||
// Interval 心跳间隔(秒)
|
||||
Interval int `yaml:"interval"`
|
||||
|
||||
// Concurrency 请求并发数
|
||||
Concurrency int `yaml:"concurrency"`
|
||||
}
|
||||
|
||||
// NewConfig 创建并返回一个新的配置实例
|
||||
func NewConfig() *Config {
|
||||
// 默认值可以在这里设置,但我们优先使用配置文件中的值
|
||||
return &Config{}
|
||||
}
|
||||
|
||||
// Load 从指定路径加载配置文件
|
||||
func (c *Config) Load(path string) error {
|
||||
// 读取配置文件
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("配置文件读取失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析YAML配置
|
||||
if err := yaml.Unmarshal(data, c); err != nil {
|
||||
return fmt.Errorf("配置文件解析失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDatabaseConnectionString 获取数据库连接字符串
|
||||
func (c *Config) GetDatabaseConnectionString() string {
|
||||
// 构建PostgreSQL连接字符串
|
||||
return fmt.Sprintf(
|
||||
"user=%s password=%s dbname=%s host=%s port=%d sslmode=%s",
|
||||
c.Database.Username,
|
||||
c.Database.Password,
|
||||
c.Database.DBName,
|
||||
c.Database.Host,
|
||||
c.Database.Port,
|
||||
c.Database.SSLMode,
|
||||
)
|
||||
}
|
||||
117
internal/infra/database/postgres.go
Normal file
117
internal/infra/database/postgres.go
Normal file
@@ -0,0 +1,117 @@
|
||||
// Package database 提供基于PostgreSQL的数据存储功能
|
||||
// 使用GORM作为ORM库来操作数据库
|
||||
// 实现与PostgreSQL数据库的连接和基本操作
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PostgresStorage 代表基于PostgreSQL的存储实现
|
||||
// 使用GORM作为ORM库
|
||||
type PostgresStorage struct {
|
||||
db *gorm.DB
|
||||
connectionString string
|
||||
maxOpenConns int
|
||||
maxIdleConns int
|
||||
connMaxLifetime int
|
||||
logger *logs.Logger // 依赖注入的 logger
|
||||
}
|
||||
|
||||
// NewPostgresStorage 创建并返回一个新的PostgreSQL存储实例
|
||||
// 它接收一个 logger 实例,而不是自己创建
|
||||
func NewPostgresStorage(connectionString string, maxOpenConns, maxIdleConns, connMaxLifetime int, logger *logs.Logger) *PostgresStorage {
|
||||
return &PostgresStorage{
|
||||
connectionString: connectionString,
|
||||
maxOpenConns: maxOpenConns,
|
||||
maxIdleConns: maxIdleConns,
|
||||
connMaxLifetime: connMaxLifetime,
|
||||
logger: logger, // 注入 logger
|
||||
}
|
||||
}
|
||||
|
||||
// Connect 建立与PostgreSQL数据库的连接
|
||||
// 使用GORM建立数据库连接,并使用自定义的 logger 接管 GORM 日志
|
||||
func (ps *PostgresStorage) Connect() error {
|
||||
ps.logger.Info("正在连接PostgreSQL数据库")
|
||||
|
||||
// 创建 GORM 的 logger 适配器
|
||||
gormLogger := logs.NewGormLogger(ps.logger)
|
||||
|
||||
var err error
|
||||
// 在 gorm.Open 时传入我们自定义的 logger
|
||||
ps.db, err = gorm.Open(postgres.Open(ps.connectionString), &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
})
|
||||
if err != nil {
|
||||
ps.logger.Errorw("数据库连接失败", "error", err)
|
||||
return fmt.Errorf("数据库连接失败: %w", err) // 使用 %w 进行错误包装
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
sqlDB, err := ps.db.DB()
|
||||
if err != nil {
|
||||
ps.logger.Errorw("获取数据库实例失败", "error", err)
|
||||
return fmt.Errorf("获取数据库实例失败: %w", err)
|
||||
}
|
||||
|
||||
if err = sqlDB.Ping(); err != nil {
|
||||
ps.logger.Errorw("数据库连接测试失败", "error", err)
|
||||
return fmt.Errorf("数据库连接测试失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置连接池参数
|
||||
sqlDB.SetMaxOpenConns(ps.maxOpenConns)
|
||||
sqlDB.SetMaxIdleConns(ps.maxIdleConns)
|
||||
sqlDB.SetConnMaxLifetime(time.Duration(ps.connMaxLifetime) * time.Second)
|
||||
|
||||
ps.logger.Info("PostgreSQL数据库连接成功")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect 断开与PostgreSQL数据库的连接
|
||||
// 安全地关闭所有数据库连接
|
||||
func (ps *PostgresStorage) Disconnect() error {
|
||||
if ps.db != nil {
|
||||
ps.logger.Info("正在断开PostgreSQL数据库连接")
|
||||
|
||||
sqlDB, err := ps.db.DB()
|
||||
if err != nil {
|
||||
ps.logger.Errorw("获取数据库实例失败", "error", err)
|
||||
return fmt.Errorf("获取数据库实例失败: %w", err)
|
||||
}
|
||||
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
ps.logger.Errorw("关闭数据库连接失败", "error", err)
|
||||
return fmt.Errorf("关闭数据库连接失败: %w", err)
|
||||
}
|
||||
ps.logger.Info("PostgreSQL数据库连接已断开")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDB 获取GORM数据库实例
|
||||
// 用于执行具体的数据库操作
|
||||
func (ps *PostgresStorage) GetDB() *gorm.DB {
|
||||
return ps.db
|
||||
}
|
||||
|
||||
// Migrate 执行数据库迁移
|
||||
func (ps *PostgresStorage) Migrate(models ...interface{}) error {
|
||||
if len(models) == 0 {
|
||||
ps.logger.Info("没有需要迁移的数据库模型,跳过迁移步骤")
|
||||
return nil
|
||||
}
|
||||
ps.logger.Info("正在自动迁移数据库表结构")
|
||||
if err := ps.db.AutoMigrate(models...); err != nil {
|
||||
ps.logger.Errorw("数据库表结构迁移失败", "error", err)
|
||||
return fmt.Errorf("数据库表结构迁移失败: %w", err)
|
||||
}
|
||||
ps.logger.Info("数据库表结构迁移完成")
|
||||
return nil
|
||||
}
|
||||
53
internal/infra/database/storage.go
Normal file
53
internal/infra/database/storage.go
Normal file
@@ -0,0 +1,53 @@
|
||||
// Package database 提供统一的数据存储接口
|
||||
// 定义存储接口规范,支持多种存储后端实现
|
||||
// 当前支持PostgreSQL实现
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/config"
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Storage 代表统一的存储接口
|
||||
// 所有存储实现都需要实现此接口定义的方法
|
||||
type Storage interface {
|
||||
// Connect 建立与存储后端的连接
|
||||
Connect() error
|
||||
|
||||
// Disconnect 断开与存储后端的连接
|
||||
Disconnect() error
|
||||
|
||||
// GetDB 获取数据库实例
|
||||
GetDB() *gorm.DB
|
||||
|
||||
// Migrate 执行数据库迁移
|
||||
// 参数为需要迁移的 GORM 模型
|
||||
Migrate(models ...interface{}) error
|
||||
}
|
||||
|
||||
// NewStorage 创建并返回一个存储实例
|
||||
// 根据配置返回相应的存储实现
|
||||
func NewStorage(cfg config.DatabaseConfig, logger *logs.Logger) Storage {
|
||||
// 构建数据库连接字符串
|
||||
connectionString := fmt.Sprintf(
|
||||
"user=%s password=%s dbname=%s host=%s port=%d sslmode=%s",
|
||||
cfg.Username,
|
||||
cfg.Password,
|
||||
cfg.DBName,
|
||||
cfg.Host,
|
||||
cfg.Port,
|
||||
cfg.SSLMode,
|
||||
)
|
||||
|
||||
// 当前默认返回PostgreSQL存储实现,并将 logger 注入
|
||||
return NewPostgresStorage(
|
||||
connectionString,
|
||||
cfg.MaxOpenConns,
|
||||
cfg.MaxIdleConns,
|
||||
cfg.ConnMaxLifetime,
|
||||
logger,
|
||||
)
|
||||
}
|
||||
165
internal/infra/logs/logs.go
Normal file
165
internal/infra/logs/logs.go
Normal file
@@ -0,0 +1,165 @@
|
||||
// Package logs 提供了高度可配置的日志功能,基于 uber-go/zap 实现。
|
||||
// 它支持将日志同时输出到控制台和文件,并提供日志滚动归档功能。
|
||||
// 该包还特别为 Gin 和 GORM 框架提供了开箱即用的日志接管能力。
|
||||
package logs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/config"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
gormlogger "gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// Logger 是一个封装了 zap.SugaredLogger 的日志记录器。
|
||||
// 它提供了结构化日志记录的各种方法,并实现了 io.Writer 接口以兼容 Gin。
|
||||
type Logger struct {
|
||||
*zap.SugaredLogger
|
||||
}
|
||||
|
||||
// NewLogger 根据提供的配置创建一个新的 Logger 实例。
|
||||
// 这是实现依赖注入的关键,在应用启动时调用一次。
|
||||
func NewLogger(cfg config.LogConfig) *Logger {
|
||||
// 1. 设置日志编码器
|
||||
encoder := getEncoder(cfg.Format)
|
||||
|
||||
// 2. 设置日志写入器 (支持文件和控制台)
|
||||
writeSyncer := getWriteSyncer(cfg)
|
||||
|
||||
// 3. 设置日志级别
|
||||
level := zap.NewAtomicLevel()
|
||||
if err := level.UnmarshalText([]byte(cfg.Level)); err != nil {
|
||||
level.SetLevel(zap.InfoLevel) // 解析失败则默认为 Info 级别
|
||||
}
|
||||
|
||||
// 4. 创建 Zap 核心
|
||||
core := zapcore.NewCore(encoder, writeSyncer, level)
|
||||
|
||||
// 5. 构建 Logger
|
||||
// zap.AddCaller() 会记录调用日志的代码行
|
||||
// zap.AddCallerSkip(1) 可以向上跳一层调用栈,如果我们将 logger.Info 等方法再封装一层,这个选项会很有用
|
||||
zapLogger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1))
|
||||
|
||||
return &Logger{zapLogger.Sugar()}
|
||||
}
|
||||
|
||||
// getEncoder 根据指定的格式返回一个 zapcore.Encoder。
|
||||
func getEncoder(format string) zapcore.Encoder {
|
||||
encoderConfig := zap.NewProductionEncoderConfig()
|
||||
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder // 时间格式: 2006-01-02T15:04:05.000Z0700
|
||||
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder // 日志级别大写: INFO
|
||||
|
||||
if format == "json" {
|
||||
return zapcore.NewJSONEncoder(encoderConfig)
|
||||
}
|
||||
// 默认或 "console"
|
||||
return zapcore.NewConsoleEncoder(encoderConfig)
|
||||
}
|
||||
|
||||
// getWriteSyncer 根据配置创建日志写入目标。
|
||||
func getWriteSyncer(cfg config.LogConfig) zapcore.WriteSyncer {
|
||||
writers := []zapcore.WriteSyncer{os.Stdout}
|
||||
|
||||
if cfg.EnableFile {
|
||||
// 使用 lumberjack 实现日志滚动
|
||||
fileWriter := &lumberjack.Logger{
|
||||
Filename: cfg.FilePath,
|
||||
MaxSize: cfg.MaxSize,
|
||||
MaxBackups: cfg.MaxBackups,
|
||||
MaxAge: cfg.MaxAge,
|
||||
Compress: cfg.Compress,
|
||||
}
|
||||
writers = append(writers, zapcore.AddSync(fileWriter))
|
||||
}
|
||||
|
||||
return zapcore.NewMultiWriteSyncer(writers...)
|
||||
}
|
||||
|
||||
// Write 实现了 io.Writer 接口,用于接管 Gin 的默认输出。
|
||||
// Gin 的日志(如 [GIN-debug] Listening and serving HTTP on :8080)会通过这个方法写入。
|
||||
func (l *Logger) Write(p []byte) (n int, err error) {
|
||||
msg := strings.TrimSpace(string(p))
|
||||
if msg != "" {
|
||||
l.Info(msg) // 使用我们自己的 logger 来打印 Gin 的日志
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// --- GORM 日志适配器 ---
|
||||
|
||||
// GormLogger 是一个实现了 gormlogger.Interface 的适配器,
|
||||
// 它将 GORM 的日志重定向到我们的 zap Logger 中。
|
||||
type GormLogger struct {
|
||||
ZapLogger *Logger
|
||||
SlowThreshold time.Duration
|
||||
SkipErrRecordNotFound bool // 是否跳过 "record not found" 错误
|
||||
}
|
||||
|
||||
// NewGormLogger 创建一个新的 GORM 日志记录器实例。
|
||||
func NewGormLogger(zapLogger *Logger) *GormLogger {
|
||||
return &GormLogger{
|
||||
ZapLogger: zapLogger,
|
||||
SlowThreshold: 200 * time.Millisecond, // 慢查询阈值,超过200ms则警告
|
||||
SkipErrRecordNotFound: true, // 通常我们不关心 "record not found" 错误
|
||||
}
|
||||
}
|
||||
|
||||
// LogMode 设置日志模式,这里我们总是使用 zap 的级别控制,所以这个方法可以为空。
|
||||
func (g *GormLogger) LogMode(level gormlogger.LogLevel) gormlogger.Interface {
|
||||
// GORM 的 LogLevel 在这里不起作用,因为我们完全由 Zap 控制
|
||||
return g
|
||||
}
|
||||
|
||||
// Info 打印 Info 级别的日志。
|
||||
func (g *GormLogger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
g.ZapLogger.Infof(msg, data...)
|
||||
}
|
||||
|
||||
// Warn 打印 Warn 级别的日志。
|
||||
func (g *GormLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
g.ZapLogger.Warnf(msg, data...)
|
||||
}
|
||||
|
||||
// Error 打印 Error 级别的日志。
|
||||
func (g *GormLogger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
g.ZapLogger.Errorf(msg, data...)
|
||||
}
|
||||
|
||||
// Trace 打印 SQL 查询日志,这是 GORM 日志的核心。
|
||||
func (g *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
elapsed := time.Since(begin)
|
||||
sql, rows := fc()
|
||||
|
||||
fields := []interface{}{
|
||||
"sql", sql,
|
||||
"rows", rows,
|
||||
"elapsed", fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/1e6),
|
||||
}
|
||||
|
||||
// --- 逻辑修复开始 ---
|
||||
if err != nil {
|
||||
// 如果是 "record not found" 错误且我们配置了跳过,则直接返回
|
||||
if g.SkipErrRecordNotFound && strings.Contains(err.Error(), "record not found") {
|
||||
return
|
||||
}
|
||||
// 否则,记录为错误日志
|
||||
g.ZapLogger.With(fields...).Errorf("[GORM] error: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 如果查询时间超过慢查询阈值,则记录警告
|
||||
if g.SlowThreshold != 0 && elapsed > g.SlowThreshold {
|
||||
g.ZapLogger.With(fields...).Warnf("[GORM] slow query")
|
||||
return
|
||||
}
|
||||
|
||||
// 正常情况,记录 Debug 级别的 SQL 查询
|
||||
g.ZapLogger.With(fields...).Debugf("[GORM] trace")
|
||||
// --- 逻辑修复结束 ---
|
||||
}
|
||||
128
internal/infra/logs/logs_test.go
Normal file
128
internal/infra/logs/logs_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package logs_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/config"
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// captureOutput 是一个辅助函数,用于捕获 logger 的输出到内存缓冲区
|
||||
func captureOutput(cfg config.LogConfig) (*logs.Logger, *bytes.Buffer) {
|
||||
var buf bytes.Buffer
|
||||
// 使用一个简单的 Console Encoder 进行测试,方便断言字符串
|
||||
encoderConfig := zap.NewDevelopmentEncoderConfig()
|
||||
encoderConfig.EncodeTime = nil // 忽略时间,避免测试结果不一致
|
||||
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
|
||||
encoderConfig.EncodeCaller = nil // 忽略调用者信息
|
||||
encoder := zapcore.NewConsoleEncoder(encoderConfig)
|
||||
|
||||
writer := zapcore.AddSync(&buf)
|
||||
|
||||
level := zap.NewAtomicLevel()
|
||||
_ = level.UnmarshalText([]byte(cfg.Level))
|
||||
|
||||
core := zapcore.NewCore(encoder, writer, level)
|
||||
// 在测试中我们直接操作 zap Logger,而不是通过封装的 NewLogger,以注入内存 writer
|
||||
zapLogger := zap.New(core)
|
||||
|
||||
logger := &logs.Logger{SugaredLogger: zapLogger.Sugar()}
|
||||
return logger, &buf
|
||||
}
|
||||
|
||||
func TestNewLogger(t *testing.T) {
|
||||
t.Run("Constructor does not panic", func(t *testing.T) {
|
||||
// 测试 Console 格式
|
||||
cfgConsole := config.LogConfig{Level: "info", Format: "console"}
|
||||
assert.NotPanics(t, func() { logs.NewLogger(cfgConsole) })
|
||||
|
||||
// 测试 JSON 格式
|
||||
cfgJSON := config.LogConfig{Level: "info", Format: "json"}
|
||||
assert.NotPanics(t, func() { logs.NewLogger(cfgJSON) })
|
||||
|
||||
// 测试文件日志启用
|
||||
// 不实际写入文件,只确保构造函数能正常运行
|
||||
cfgFile := config.LogConfig{
|
||||
Level: "info",
|
||||
EnableFile: true,
|
||||
FilePath: "test.log",
|
||||
}
|
||||
assert.NotPanics(t, func() { logs.NewLogger(cfgFile) })
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogger_Write_ForGin(t *testing.T) {
|
||||
logger, buf := captureOutput(config.LogConfig{Level: "info"})
|
||||
|
||||
ginLog := "[GIN-debug] Listening and serving HTTP on :8080\n"
|
||||
_, err := logger.Write([]byte(ginLog))
|
||||
|
||||
assert.NoError(t, err)
|
||||
output := buf.String()
|
||||
// logger.Write 会将 gin 的日志转为 info 级别
|
||||
assert.Contains(t, output, "INFO")
|
||||
assert.Contains(t, output, strings.TrimSpace(ginLog))
|
||||
}
|
||||
|
||||
func TestGormLogger(t *testing.T) {
|
||||
logger, buf := captureOutput(config.LogConfig{Level: "debug"}) // 设置为 debug 以捕获所有级别
|
||||
gormLogger := logs.NewGormLogger(logger)
|
||||
|
||||
// 模拟 GORM 的 Trace 调用参数
|
||||
ctx := context.Background()
|
||||
sql := "SELECT * FROM users WHERE id = 1"
|
||||
rows := int64(1)
|
||||
fc := func() (string, int64) {
|
||||
return sql, rows
|
||||
}
|
||||
|
||||
t.Run("Slow Query", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
// 模拟一个耗时超过 200ms 的查询
|
||||
begin := time.Now().Add(-300 * time.Millisecond)
|
||||
gormLogger.Trace(ctx, begin, fc, nil)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "WARN", "应包含 WARN 级别")
|
||||
assert.Contains(t, output, "[GORM] slow query", "应包含慢查询信息")
|
||||
// 修复:不再检查严格的 JSON 格式,只检查关键内容
|
||||
assert.Contains(t, output, "SELECT * FROM users WHERE id = 1", "应包含 SQL 语句")
|
||||
})
|
||||
|
||||
t.Run("Error Query", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
queryError := errors.New("syntax error")
|
||||
gormLogger.Trace(ctx, time.Now(), fc, queryError)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "ERROR")
|
||||
assert.Contains(t, output, "[GORM] error: syntax error")
|
||||
})
|
||||
|
||||
t.Run("Record Not Found Error is Skipped", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
queryError := errors.New("record not found") // 模拟 GORM 的 RecordNotFound 错误
|
||||
gormLogger.Trace(ctx, time.Now(), fc, queryError)
|
||||
|
||||
// 在修复 logs.go 中的 bug 后,这里应该为空
|
||||
assert.Empty(t, buf.String(), "开启 SkipErrRecordNotFound 后,record not found 错误不应产生任何日志")
|
||||
})
|
||||
|
||||
t.Run("Normal Query", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
// 模拟一个快速查询
|
||||
gormLogger.Trace(ctx, time.Now(), fc, nil)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "DEBUG") // 正常查询是 Debug 级别
|
||||
assert.Contains(t, output, "[GORM] trace")
|
||||
})
|
||||
}
|
||||
55
internal/infra/models/user.go
Normal file
55
internal/infra/models/user.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Package models 定义了应用的数据模型,例如用户、产品等。
|
||||
package models
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// User 代表系统中的用户模型
|
||||
type User struct {
|
||||
// gorm.Model 内嵌了 ID, CreatedAt, UpdatedAt, 和 DeletedAt
|
||||
// DeletedAt 字段的存在自动为 GORM 开启了软删除模式
|
||||
gorm.Model
|
||||
|
||||
// Username 是用户的登录名,应该是唯一的
|
||||
// 修正了 gorm 标签的拼写错误 (移除了 gorm 后面的冒号)
|
||||
Username string `gorm:"unique;not null" json:"username"`
|
||||
|
||||
// Password 存储的是加密后的密码哈希,而不是明文
|
||||
// json:"-" 标签确保此字段在序列化为 JSON 时被忽略,防止密码泄露
|
||||
Password string `gorm:"not null" json:"-"`
|
||||
}
|
||||
|
||||
// TableName 自定义 User 模型对应的数据库表名
|
||||
// GORM 默认会使用复数形式 "users",但显式定义是一种好习惯
|
||||
func (User) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
// --- GORM Hooks ---
|
||||
|
||||
// BeforeCreate 是一个 GORM 钩子,在创建用户记录前自动调用。
|
||||
// 这是哈希初始密码最可靠的地方。
|
||||
func (u *User) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
// 如果密码不为空,则执行哈希
|
||||
if u.Password != "" {
|
||||
// 使用 bcrypt 对密码进行哈希
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 将明文密码替换为哈希值
|
||||
u.Password = string(hashedPassword)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Helper Methods ---
|
||||
|
||||
// CheckPassword 用于验证输入的明文密码是否与数据库中存储的哈希匹配
|
||||
func (u *User) CheckPassword(plainPassword string) bool {
|
||||
// bcrypt.CompareHashAndPassword 会安全地比较哈希和明文,能有效防止时序攻击
|
||||
err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(plainPassword))
|
||||
return err == nil
|
||||
}
|
||||
44
internal/infra/models/user_test.go
Normal file
44
internal/infra/models/user_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
// Package models_test 包含对 models 包的单元测试
|
||||
package models_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func TestUser_CheckPassword(t *testing.T) {
|
||||
plainPassword := "my-secret-password"
|
||||
|
||||
// 1. 生成一个密码哈希用于测试
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(plainPassword), bcrypt.DefaultCost)
|
||||
assert.NoError(t, err, "生成密码哈希不应出错")
|
||||
|
||||
user := &models.User{
|
||||
Password: string(hashedPassword),
|
||||
}
|
||||
|
||||
t.Run("密码正确", func(t *testing.T) {
|
||||
// 2. 使用正确的明文密码进行校验
|
||||
match := user.CheckPassword(plainPassword)
|
||||
assert.True(t, match, "正确的密码应该校验通过")
|
||||
})
|
||||
|
||||
t.Run("密码错误", func(t *testing.T) {
|
||||
// 3. 使用错误的明文密码进行校验
|
||||
match := user.CheckPassword("wrong-password")
|
||||
assert.False(t, match, "错误的密码应该校验失败")
|
||||
})
|
||||
|
||||
t.Run("空密码", func(t *testing.T) {
|
||||
// 4. 使用空字符串作为密码进行校验
|
||||
match := user.CheckPassword("")
|
||||
assert.False(t, match, "空密码应该校验失败")
|
||||
})
|
||||
}
|
||||
|
||||
// 注意:BeforeSave 钩子是一个 GORM 框架的回调,它的正确性
|
||||
// 将在 repository 的集成测试中,通过实际创建一个用户来得到验证,
|
||||
// 而不是在这里进行孤立的、脆弱的单元测试。
|
||||
49
internal/infra/repository/user_repository.go
Normal file
49
internal/infra/repository/user_repository.go
Normal file
@@ -0,0 +1,49 @@
|
||||
// Package repository 提供了数据访问的仓库实现
|
||||
package repository
|
||||
|
||||
import (
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/models"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UserRepository 定义了与用户模型相关的数据库操作接口
|
||||
// 这是为了让业务逻辑层依赖于抽象,而不是具体的数据库实现
|
||||
type UserRepository interface {
|
||||
Create(user *models.User) error
|
||||
FindByUsername(username string) (*models.User, error)
|
||||
FindByID(id uint) (*models.User, error)
|
||||
}
|
||||
|
||||
// gormUserRepository 是 UserRepository 的 GORM 实现
|
||||
type gormUserRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormUserRepository 创建一个新的 UserRepository GORM 实现实例
|
||||
func NewGormUserRepository(db *gorm.DB) UserRepository {
|
||||
return &gormUserRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建一个新的用户记录
|
||||
func (r *gormUserRepository) Create(user *models.User) error {
|
||||
// BeforeSave 钩子会在这里被自动触发
|
||||
return r.db.Create(user).Error
|
||||
}
|
||||
|
||||
// FindByUsername 根据用户名查找用户
|
||||
func (r *gormUserRepository) FindByUsername(username string) (*models.User, error) {
|
||||
var user models.User
|
||||
if err := r.db.Where("username = ?", username).First(&user).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// FindByID 根据 ID 查找用户
|
||||
func (r *gormUserRepository) FindByID(id uint) (*models.User, error) {
|
||||
var user models.User
|
||||
if err := r.db.First(&user, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
93
internal/infra/repository/user_repository_test.go
Normal file
93
internal/infra/repository/user_repository_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
// Package repository_test 包含对 repository 包的集成测试
|
||||
package repository_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/models"
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/repository"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// setupTestDB 是一个辅助函数,用于为每个测试创建一个
|
||||
// 干净的、内存中的 SQLite 数据库实例。
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
// "file::memory:?cache=shared" 是 GORM 连接内存 SQLite 的标准方式
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||
assert.NoError(t, err, "连接内存数据库不应出错")
|
||||
|
||||
// 自动迁移 User 表结构
|
||||
err = db.AutoMigrate(&models.User{})
|
||||
assert.NoError(t, err, "数据库迁移不应出错")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestGormUserRepository(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := repository.NewGormUserRepository(db)
|
||||
|
||||
plainPassword := "my-secret-password"
|
||||
userToCreate := &models.User{
|
||||
Username: "testuser",
|
||||
Password: plainPassword, // 我们提供的是明文密码
|
||||
}
|
||||
|
||||
t.Run("Create - 成功创建并验证密码哈希", func(t *testing.T) {
|
||||
err := repo.Create(userToCreate)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证用户已被创建
|
||||
assert.NotZero(t, userToCreate.ID)
|
||||
|
||||
// 从数据库中直接取回记录,以验证 BeforeSave 钩子是否生效
|
||||
var savedUser models.User
|
||||
db.First(&savedUser, userToCreate.ID)
|
||||
|
||||
// 验证密码字段存储的不是明文
|
||||
assert.NotEqual(t, plainPassword, savedUser.Password, "数据库中存储的密码不应是明文")
|
||||
|
||||
// 验证存储的哈希是正确的
|
||||
assert.True(t, savedUser.CheckPassword(plainPassword), "存储的密码哈希应该能与原明文匹配")
|
||||
})
|
||||
|
||||
t.Run("Create - 用户名冲突", func(t *testing.T) {
|
||||
// 尝试创建一个同名用户
|
||||
duplicateUser := &models.User{Username: "testuser", Password: "anypassword"}
|
||||
err := repo.Create(duplicateUser)
|
||||
|
||||
// 我们期望一个错误,因为用户名是唯一的
|
||||
assert.Error(t, err, "创建同名用户应该返回错误")
|
||||
// 更精确地,可以检查是否是唯一键冲突错误
|
||||
assert.Contains(t, err.Error(), "UNIQUE constraint failed: users.username", "错误信息应包含唯一键冲突")
|
||||
})
|
||||
|
||||
t.Run("FindByUsername - 找到用户", func(t *testing.T) {
|
||||
foundUser, err := repo.FindByUsername("testuser")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, foundUser)
|
||||
assert.Equal(t, userToCreate.ID, foundUser.ID)
|
||||
assert.Equal(t, "testuser", foundUser.Username)
|
||||
})
|
||||
|
||||
t.Run("FindByUsername - 未找到用户", func(t *testing.T) {
|
||||
_, err := repo.FindByUsername("nonexistent")
|
||||
assert.Error(t, err, "查找不存在的用户应该返回错误")
|
||||
assert.ErrorIs(t, err, gorm.ErrRecordNotFound, "错误类型应为 gorm.ErrRecordNotFound")
|
||||
})
|
||||
|
||||
t.Run("FindByID - 找到用户", func(t *testing.T) {
|
||||
foundUser, err := repo.FindByID(userToCreate.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, foundUser)
|
||||
assert.Equal(t, userToCreate.ID, foundUser.ID)
|
||||
})
|
||||
|
||||
t.Run("FindByID - 未找到用户", func(t *testing.T) {
|
||||
_, err := repo.FindByID(99999)
|
||||
assert.Error(t, err, "查找不存在的ID应该返回错误")
|
||||
assert.ErrorIs(t, err, gorm.ErrRecordNotFound, "错误类型应为 gorm.ErrRecordNotFound")
|
||||
})
|
||||
}
|
||||
222
internal/infra/task/task.go
Normal file
222
internal/infra/task/task.go
Normal file
@@ -0,0 +1,222 @@
|
||||
// Package task 提供任务队列和执行框架
|
||||
// 负责管理任务队列、调度和执行各种控制任务
|
||||
package task
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs"
|
||||
)
|
||||
|
||||
// Task 代表一个任务接口
|
||||
// 所有任务都需要实现此接口
|
||||
type Task interface {
|
||||
// Execute 执行任务
|
||||
Execute() error
|
||||
|
||||
// GetID 获取任务ID
|
||||
GetID() string
|
||||
|
||||
// GetPriority 获取任务优先级
|
||||
GetPriority() int
|
||||
|
||||
// IsDone 检查任务是否已完成
|
||||
IsDone() bool
|
||||
}
|
||||
|
||||
// taskItem 任务队列中的元素
|
||||
type taskItem struct {
|
||||
task Task
|
||||
priority int
|
||||
index int
|
||||
}
|
||||
|
||||
// TaskQueue 代表任务队列
|
||||
type TaskQueue struct {
|
||||
// queue 任务队列(按优先级排序)
|
||||
queue *priorityQueue
|
||||
|
||||
// mutex 互斥锁
|
||||
mutex sync.Mutex
|
||||
|
||||
// logger 日志记录器
|
||||
logger *logs.Logger
|
||||
}
|
||||
|
||||
// NewTaskQueue 创建并返回一个新的任务队列实例。
|
||||
func NewTaskQueue(logger *logs.Logger) *TaskQueue {
|
||||
pq := make(priorityQueue, 0)
|
||||
heap.Init(&pq)
|
||||
|
||||
return &TaskQueue{
|
||||
queue: &pq,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// AddTask 向队列中添加任务
|
||||
func (tq *TaskQueue) AddTask(task Task) {
|
||||
tq.mutex.Lock()
|
||||
defer tq.mutex.Unlock()
|
||||
|
||||
item := &taskItem{
|
||||
task: task,
|
||||
priority: task.GetPriority(),
|
||||
}
|
||||
heap.Push(tq.queue, item)
|
||||
tq.logger.Infow("任务已添加到队列", "taskID", task.GetID())
|
||||
}
|
||||
|
||||
// GetNextTask 获取下一个要执行的任务(优先级最高的任务)
|
||||
func (tq *TaskQueue) GetNextTask() Task {
|
||||
tq.mutex.Lock()
|
||||
defer tq.mutex.Unlock()
|
||||
|
||||
if tq.queue.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
item := heap.Pop(tq.queue).(*taskItem)
|
||||
tq.logger.Infow("从队列中获取任务", "taskID", item.task.GetID())
|
||||
return item.task
|
||||
}
|
||||
|
||||
// GetTaskCount 获取队列中的任务数量
|
||||
func (tq *TaskQueue) GetTaskCount() int {
|
||||
tq.mutex.Lock()
|
||||
defer tq.mutex.Unlock()
|
||||
|
||||
return tq.queue.Len()
|
||||
}
|
||||
|
||||
// priorityQueue 实现优先级队列
|
||||
type priorityQueue []*taskItem
|
||||
|
||||
func (pq priorityQueue) Len() int { return len(pq) }
|
||||
|
||||
func (pq priorityQueue) Less(i, j int) bool {
|
||||
return pq[i].priority < pq[j].priority
|
||||
}
|
||||
|
||||
func (pq priorityQueue) Swap(i, j int) {
|
||||
pq[i], pq[j] = pq[j], pq[i]
|
||||
pq[i].index = i
|
||||
pq[j].index = j
|
||||
}
|
||||
|
||||
func (pq *priorityQueue) Push(x interface{}) {
|
||||
n := len(*pq)
|
||||
item := x.(*taskItem)
|
||||
item.index = n
|
||||
*pq = append(*pq, item)
|
||||
}
|
||||
|
||||
func (pq *priorityQueue) Pop() interface{} {
|
||||
old := *pq
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
old[n-1] = nil // 避免内存泄漏
|
||||
item.index = -1 // 无效索引
|
||||
*pq = old[0 : n-1]
|
||||
return item
|
||||
}
|
||||
|
||||
// Executor 代表任务执行器
|
||||
type Executor struct {
|
||||
// taskQueue 任务队列
|
||||
taskQueue *TaskQueue
|
||||
|
||||
// workers 工作协程数量
|
||||
workers int
|
||||
|
||||
// ctx 执行上下文
|
||||
ctx context.Context
|
||||
|
||||
// cancel 取消函数
|
||||
cancel context.CancelFunc
|
||||
|
||||
// wg 等待组
|
||||
wg sync.WaitGroup
|
||||
|
||||
// logger 日志记录器
|
||||
logger *logs.Logger
|
||||
}
|
||||
|
||||
// NewExecutor 创建并返回一个新的任务执行器实例。
|
||||
func NewExecutor(workers int, logger *logs.Logger) *Executor {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &Executor{
|
||||
taskQueue: NewTaskQueue(logger), // 将 logger 传递给 TaskQueue
|
||||
workers: workers,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动任务执行器
|
||||
func (e *Executor) Start() {
|
||||
e.logger.Infow("正在启动任务执行器", "workers", e.workers)
|
||||
|
||||
// 启动工作协程
|
||||
for i := 0; i < e.workers; i++ {
|
||||
e.wg.Add(1)
|
||||
go e.worker(i)
|
||||
}
|
||||
|
||||
e.logger.Info("任务执行器启动成功")
|
||||
}
|
||||
|
||||
// Stop 停止任务执行器
|
||||
func (e *Executor) Stop() {
|
||||
e.logger.Info("正在停止任务执行器")
|
||||
|
||||
// 取消上下文
|
||||
e.cancel()
|
||||
|
||||
// 等待所有工作协程结束
|
||||
e.wg.Wait()
|
||||
|
||||
e.logger.Info("任务执行器已停止")
|
||||
}
|
||||
|
||||
// SubmitTask 提交任务到执行器
|
||||
func (e *Executor) SubmitTask(task Task) {
|
||||
e.taskQueue.AddTask(task)
|
||||
e.logger.Infow("任务已提交", "taskID", task.GetID())
|
||||
}
|
||||
|
||||
// worker 工作协程
|
||||
func (e *Executor) worker(id int) {
|
||||
defer e.wg.Done()
|
||||
|
||||
e.logger.Infow("工作协程已启动", "workerID", id)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-e.ctx.Done():
|
||||
e.logger.Infow("工作协程已停止", "workerID", id)
|
||||
return
|
||||
default:
|
||||
// 获取下一个任务
|
||||
task := e.taskQueue.GetNextTask()
|
||||
if task != nil {
|
||||
e.logger.Infow("工作协程正在执行任务", "workerID", id, "taskID", task.GetID())
|
||||
|
||||
// 执行任务
|
||||
if err := task.Execute(); err != nil {
|
||||
e.logger.Errorw("任务执行失败", "workerID", id, "taskID", task.GetID(), "error", err)
|
||||
} else {
|
||||
e.logger.Infow("任务执行成功", "workerID", id, "taskID", task.GetID())
|
||||
}
|
||||
} else {
|
||||
// 没有任务时短暂休眠
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
304
internal/infra/task/task_test.go
Normal file
304
internal/infra/task/task_test.go
Normal file
@@ -0,0 +1,304 @@
|
||||
// Package task_test 包含对 task 包的单元测试
|
||||
package task_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/config"
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs"
|
||||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/task"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// testLogger 是一个用于所有测试用例的静默 logger 实例。
|
||||
var testLogger *logs.Logger
|
||||
|
||||
func init() {
|
||||
// 使用 "fatal" 级别来创建一个在测试期间不会产生任何输出的 logger。
|
||||
// 这避免了在运行 `go test` 时被日志淹没。
|
||||
cfg := config.LogConfig{Level: "fatal"}
|
||||
testLogger = logs.NewLogger(cfg)
|
||||
}
|
||||
|
||||
// MockTask 用于测试的模拟任务
|
||||
type MockTask struct {
|
||||
id string
|
||||
priority int
|
||||
isDone bool
|
||||
execute func() error
|
||||
executed int32 // 使用原子操作来跟踪执行次数
|
||||
}
|
||||
|
||||
// Execute 实现了 Task 接口,并确保每次调用都增加执行计数
|
||||
func (m *MockTask) Execute() error {
|
||||
atomic.AddInt32(&m.executed, 1)
|
||||
if m.execute != nil {
|
||||
return m.execute()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTask) GetID() string {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m *MockTask) GetPriority() int {
|
||||
return m.priority
|
||||
}
|
||||
|
||||
func (m *MockTask) IsDone() bool {
|
||||
return m.isDone
|
||||
}
|
||||
|
||||
// ExecutedCount 返回任务被执行的次数
|
||||
func (m *MockTask) ExecutedCount() int32 {
|
||||
return atomic.LoadInt32(&m.executed)
|
||||
}
|
||||
|
||||
// --- Helper function for robust waiting ---
|
||||
func waitForWaitGroup(t *testing.T, wg *sync.WaitGroup, timeout time.Duration) {
|
||||
waitChan := make(chan struct{})
|
||||
go func() {
|
||||
defer close(waitChan)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-waitChan:
|
||||
// Wait succeeded
|
||||
case <-time.After(timeout):
|
||||
t.Fatal("timed out waiting for tasks to complete")
|
||||
}
|
||||
}
|
||||
|
||||
// --- TaskQueue Tests (No changes needed) ---
|
||||
|
||||
func TestNewTaskQueue(t *testing.T) {
|
||||
tq := task.NewTaskQueue(testLogger)
|
||||
assert.NotNil(t, tq, "新创建的任务队列不应为 nil")
|
||||
assert.Equal(t, 0, tq.GetTaskCount(), "新创建的任务队列应为空")
|
||||
}
|
||||
|
||||
func TestTaskQueue_AddTask(t *testing.T) {
|
||||
tq := task.NewTaskQueue(testLogger)
|
||||
mockTask := &MockTask{id: "task1", priority: 1}
|
||||
|
||||
tq.AddTask(mockTask)
|
||||
assert.Equal(t, 1, tq.GetTaskCount(), "添加任务后,队列中的任务数应为 1")
|
||||
}
|
||||
|
||||
// ... (other TaskQueue tests remain the same)
|
||||
func TestTaskQueue_GetNextTask(t *testing.T) {
|
||||
t.Run("从空队列获取任务", func(t *testing.T) {
|
||||
tq := task.NewTaskQueue(testLogger)
|
||||
nextTask := tq.GetNextTask()
|
||||
assert.Nil(t, nextTask, "从空队列中获取任务应返回 nil")
|
||||
})
|
||||
|
||||
t.Run("按优先级获取任务", func(t *testing.T) {
|
||||
tq := task.NewTaskQueue(testLogger)
|
||||
task1 := &MockTask{id: "task1", priority: 10}
|
||||
task2 := &MockTask{id: "task2", priority: 1} // 优先级更高
|
||||
task3 := &MockTask{id: "task3", priority: 5}
|
||||
|
||||
tq.AddTask(task1)
|
||||
tq.AddTask(task2)
|
||||
tq.AddTask(task3)
|
||||
|
||||
assert.Equal(t, 3, tq.GetTaskCount(), "添加三个任务后,队列中的任务数应为 3")
|
||||
|
||||
nextTask := tq.GetNextTask()
|
||||
assert.NotNil(t, nextTask)
|
||||
assert.Equal(t, "task2", nextTask.GetID(), "应首先获取优先级最高的任务 (task2)")
|
||||
|
||||
nextTask = tq.GetNextTask()
|
||||
assert.NotNil(t, nextTask)
|
||||
assert.Equal(t, "task3", nextTask.GetID(), "应获取下一个优先级最高的任务 (task3)")
|
||||
|
||||
nextTask = tq.GetNextTask()
|
||||
assert.NotNil(t, nextTask)
|
||||
assert.Equal(t, "task1", nextTask.GetID(), "应最后获取优先级最低的任务 (task1)")
|
||||
|
||||
assert.Equal(t, 0, tq.GetTaskCount(), "获取所有任务后,队列应为空")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTaskQueue_Concurrency(t *testing.T) {
|
||||
tq := task.NewTaskQueue(testLogger)
|
||||
var wg sync.WaitGroup
|
||||
taskCount := 100
|
||||
|
||||
wg.Add(taskCount)
|
||||
for i := 0; i < taskCount; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
tq.AddTask(&MockTask{id: fmt.Sprintf("task-%d", i), priority: i})
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, taskCount, tq.GetTaskCount(), "并发添加任务后,队列中的任务数应为 %d", taskCount)
|
||||
|
||||
wg.Add(taskCount)
|
||||
for i := 0; i < taskCount; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
task := tq.GetNextTask()
|
||||
assert.NotNil(t, task)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, 0, tq.GetTaskCount(), "并发获取所有任务后,队列应为空")
|
||||
}
|
||||
|
||||
// --- Executor Tests (Refactored for reliability) ---
|
||||
|
||||
func TestNewExecutor(t *testing.T) {
|
||||
executor := task.NewExecutor(5, testLogger)
|
||||
assert.NotNil(t, executor, "新创建的执行器不应为 nil")
|
||||
}
|
||||
|
||||
func TestExecutor_StartStop(t *testing.T) {
|
||||
executor := task.NewExecutor(2, testLogger)
|
||||
executor.Start()
|
||||
// 确保立即停止不会导致死锁或竞争条件。
|
||||
executor.Stop()
|
||||
}
|
||||
|
||||
// TestExecutor_SubmitAndExecuteTask 测试提交并执行单个任务 (已重构,更可靠)
|
||||
func TestExecutor_SubmitAndExecuteTask(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
executor := task.NewExecutor(1, testLogger)
|
||||
mockTask := &MockTask{
|
||||
id: "task1",
|
||||
priority: 1,
|
||||
execute: func() error {
|
||||
wg.Done() // 任务完成时通知 WaitGroup
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
executor.Start()
|
||||
executor.SubmitTask(mockTask)
|
||||
|
||||
// 等待任务完成,设置一个合理的超时时间
|
||||
waitForWaitGroup(t, &wg, 2*time.Second)
|
||||
|
||||
executor.Stop()
|
||||
|
||||
assert.Equal(t, int32(1), mockTask.ExecutedCount(), "任务应该已被执行")
|
||||
}
|
||||
|
||||
// TestExecutor_ExecuteMultipleTasks 测试执行多个任务 (已重构,更可靠)
|
||||
func TestExecutor_ExecuteMultipleTasks(t *testing.T) {
|
||||
taskCount := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(taskCount)
|
||||
|
||||
executor := task.NewExecutor(3, testLogger)
|
||||
mockTasks := make([]*MockTask, taskCount)
|
||||
for i := 0; i < taskCount; i++ {
|
||||
mockTasks[i] = &MockTask{
|
||||
id: fmt.Sprintf("task-%d", i),
|
||||
priority: i,
|
||||
execute: func() error {
|
||||
wg.Done() // 每个任务完成时都通知 WaitGroup
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
executor.Start()
|
||||
for _, task := range mockTasks {
|
||||
executor.SubmitTask(task)
|
||||
}
|
||||
|
||||
// 等待所有任务完成
|
||||
waitForWaitGroup(t, &wg, 2*time.Second)
|
||||
|
||||
executor.Stop()
|
||||
|
||||
var totalExecuted int32
|
||||
for _, task := range mockTasks {
|
||||
totalExecuted += task.ExecutedCount()
|
||||
}
|
||||
|
||||
assert.Equal(t, int32(taskCount), totalExecuted, "所有提交的任务都应该被执行")
|
||||
}
|
||||
|
||||
// TestExecutor_TaskExecutionError 测试任务执行失败的场景 (已重构,更可靠)
|
||||
func TestExecutor_TaskExecutionError(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2) // 我们期望两个任务都被执行
|
||||
|
||||
executor := task.NewExecutor(1, testLogger)
|
||||
errorTask := &MockTask{
|
||||
id: "errorTask",
|
||||
priority: 1,
|
||||
execute: func() error {
|
||||
wg.Done()
|
||||
return errors.New("执行失败")
|
||||
},
|
||||
}
|
||||
|
||||
successTask := &MockTask{
|
||||
id: "successTask",
|
||||
priority: 2, // 后执行
|
||||
execute: func() error {
|
||||
wg.Done()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
executor.Start()
|
||||
executor.SubmitTask(errorTask)
|
||||
executor.SubmitTask(successTask)
|
||||
|
||||
waitForWaitGroup(t, &wg, 2*time.Second)
|
||||
executor.Stop()
|
||||
|
||||
assert.Equal(t, int32(1), errorTask.ExecutedCount(), "失败的任务应该被执行一次")
|
||||
assert.Equal(t, int32(1), successTask.ExecutedCount(), "成功的任务也应该被执行")
|
||||
}
|
||||
|
||||
// TestExecutor_StopWithPendingTasks 测试停止执行器时仍有待处理任务 (已重构,更可靠)
|
||||
func TestExecutor_StopWithPendingTasks(t *testing.T) {
|
||||
executor := task.NewExecutor(1, testLogger)
|
||||
task1Started := make(chan struct{})
|
||||
|
||||
task1 := &MockTask{
|
||||
id: "task1",
|
||||
priority: 1,
|
||||
execute: func() error {
|
||||
close(task1Started) // 发送信号,通知测试 task1 已开始执行
|
||||
time.Sleep(200 * time.Millisecond) // 模拟耗时操作
|
||||
return nil
|
||||
},
|
||||
}
|
||||
task2 := &MockTask{id: "task2", priority: 2}
|
||||
|
||||
executor.Start()
|
||||
executor.SubmitTask(task1)
|
||||
executor.SubmitTask(task2)
|
||||
|
||||
// 等待 task1 开始执行的信号,而不是依赖不确定的 sleep
|
||||
select {
|
||||
case <-task1Started:
|
||||
// task1 已开始,可以安全地停止执行器了
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("timed out waiting for task1 to start")
|
||||
}
|
||||
|
||||
executor.Stop()
|
||||
|
||||
assert.Equal(t, int32(1), task1.ExecutedCount(), "task1 应该在停止前开始执行")
|
||||
assert.Equal(t, int32(0), task2.ExecutedCount(), "task2 不应该被执行,因为执行器已停止")
|
||||
}
|
||||
Reference in New Issue
Block a user