263 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			263 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package repository_test
 | |
| 
 | |
| import (
 | |
| 	"encoding/json"
 | |
| 	"strconv"
 | |
| 	"testing"
 | |
| 
 | |
| 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/models"
 | |
| 	"git.huangwc.com/pig/pig-farm-controller/internal/infra/repository"
 | |
| 	"github.com/stretchr/testify/assert"
 | |
| 	"gorm.io/gorm"
 | |
| )
 | |
| 
 | |
| // createTestDevice 辅助函数,用于创建测试设备
 | |
| func createTestDevice(t *testing.T, db *gorm.DB, name string, deviceType models.DeviceType, parentID *uint) *models.Device {
 | |
| 	device := &models.Device{
 | |
| 		Name:     name,
 | |
| 		Type:     deviceType,
 | |
| 		ParentID: parentID,
 | |
| 		// 其他字段可以根据需要添加
 | |
| 	}
 | |
| 	err := db.Create(device).Error
 | |
| 	assert.NoError(t, err)
 | |
| 	return device
 | |
| }
 | |
| 
 | |
| func TestRepoCreate(t *testing.T) {
 | |
| 	db := setupTestDB(t)
 | |
| 	repo := repository.NewGormDeviceRepository(db)
 | |
| 
 | |
| 	loraProps, _ := json.Marshal(models.LoraProperties{LoraAddress: "0xABCD"})
 | |
| 
 | |
| 	t.Run("成功创建区域主控", func(t *testing.T) {
 | |
| 		device := &models.Device{
 | |
| 			Name:       "主控A",
 | |
| 			Type:       models.DeviceTypeAreaController,
 | |
| 			Location:   "猪舍1",
 | |
| 			Properties: loraProps,
 | |
| 		}
 | |
| 		err := repo.Create(device)
 | |
| 		assert.NoError(t, err)
 | |
| 		assert.NotZero(t, device.ID, "创建后应获得一个非零ID")
 | |
| 		assert.Nil(t, device.ParentID, "区域主控的 ParentID 应为 nil")
 | |
| 	})
 | |
| 
 | |
| 	t.Run("成功创建子设备", func(t *testing.T) {
 | |
| 		parent := createTestDevice(t, db, "父设备", models.DeviceTypeAreaController, nil)
 | |
| 		child := &models.Device{
 | |
| 			Name:     "子设备A",
 | |
| 			Type:     models.DeviceTypeDevice,
 | |
| 			ParentID: &parent.ID,
 | |
| 		}
 | |
| 		err := repo.Create(child)
 | |
| 		assert.NoError(t, err)
 | |
| 		assert.NotZero(t, child.ID)
 | |
| 		assert.NotNil(t, child.ParentID)
 | |
| 		assert.Equal(t, parent.ID, *child.ParentID)
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func TestRepoFindByID(t *testing.T) {
 | |
| 	db := setupTestDB(t)
 | |
| 	repo := repository.NewGormDeviceRepository(db)
 | |
| 
 | |
| 	device := createTestDevice(t, db, "测试设备", models.DeviceTypeAreaController, nil)
 | |
| 
 | |
| 	t.Run("成功通过ID查找", func(t *testing.T) {
 | |
| 		foundDevice, err := repo.FindByID(device.ID)
 | |
| 		assert.NoError(t, err)
 | |
| 		assert.NotNil(t, foundDevice)
 | |
| 		assert.Equal(t, device.ID, foundDevice.ID)
 | |
| 		assert.Equal(t, device.Name, foundDevice.Name)
 | |
| 	})
 | |
| 
 | |
| 	t.Run("查找不存在的ID", func(t *testing.T) {
 | |
| 		_, err := repo.FindByID(9999) // 不存在的ID
 | |
| 		assert.Error(t, err)
 | |
| 		assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
 | |
| 	})
 | |
| 
 | |
| 	t.Run("数据库查询失败", func(t *testing.T) {
 | |
| 		// 模拟数据库连接关闭,强制查询失败
 | |
| 		sqlDB, _ := db.DB()
 | |
| 		sqlDB.Close()
 | |
| 
 | |
| 		_, err := repo.FindByID(device.ID)
 | |
| 		assert.Error(t, err)
 | |
| 		assert.Contains(t, err.Error(), "database is closed")
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func TestRepoFindByIDString(t *testing.T) {
 | |
| 	db := setupTestDB(t)
 | |
| 	repo := repository.NewGormDeviceRepository(db)
 | |
| 
 | |
| 	device := createTestDevice(t, db, "测试设备", models.DeviceTypeAreaController, nil)
 | |
| 
 | |
| 	t.Run("成功通过字符串ID查找", func(t *testing.T) {
 | |
| 		idStr := strconv.FormatUint(uint64(device.ID), 10)
 | |
| 		foundDevice, err := repo.FindByIDString(idStr)
 | |
| 		assert.NoError(t, err)
 | |
| 		assert.NotNil(t, foundDevice)
 | |
| 		assert.Equal(t, device.ID, foundDevice.ID)
 | |
| 	})
 | |
| 
 | |
| 	t.Run("无效的字符串ID格式", func(t *testing.T) {
 | |
| 		_, err := repo.FindByIDString("invalid-id")
 | |
| 		assert.Error(t, err)
 | |
| 		assert.Contains(t, err.Error(), "无效的设备ID格式")
 | |
| 	})
 | |
| 
 | |
| 	t.Run("查找不存在的字符串ID", func(t *testing.T) {
 | |
| 		idStr := strconv.FormatUint(uint64(9999), 10) // 不存在的ID
 | |
| 		_, err := repo.FindByIDString(idStr)
 | |
| 		assert.Error(t, err)
 | |
| 		assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
 | |
| 	})
 | |
| 
 | |
| 	t.Run("数据库查询失败", func(t *testing.T) {
 | |
| 		sqlDB, _ := db.DB()
 | |
| 		sqlDB.Close()
 | |
| 
 | |
| 		idStr := strconv.FormatUint(uint64(device.ID), 10)
 | |
| 		_, err := repo.FindByIDString(idStr)
 | |
| 		assert.Error(t, err)
 | |
| 		assert.Contains(t, err.Error(), "database is closed")
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func TestRepoListAll(t *testing.T) {
 | |
| 	db := setupTestDB(t)
 | |
| 	repo := repository.NewGormDeviceRepository(db)
 | |
| 
 | |
| 	t.Run("成功获取空列表", func(t *testing.T) {
 | |
| 		devices, err := repo.ListAll()
 | |
| 		assert.NoError(t, err)
 | |
| 		assert.Empty(t, devices)
 | |
| 	})
 | |
| 
 | |
| 	t.Run("成功获取包含设备的列表", func(t *testing.T) {
 | |
| 		createTestDevice(t, db, "设备1", models.DeviceTypeAreaController, nil)
 | |
| 		createTestDevice(t, db, "设备2", models.DeviceTypeDevice, nil)
 | |
| 
 | |
| 		devices, err := repo.ListAll()
 | |
| 		assert.NoError(t, err)
 | |
| 		assert.Len(t, devices, 2)
 | |
| 		assert.Equal(t, "设备1", devices[0].Name)
 | |
| 		assert.Equal(t, "设备2", devices[1].Name)
 | |
| 	})
 | |
| 
 | |
| 	t.Run("数据库查询失败", func(t *testing.T) {
 | |
| 		sqlDB, _ := db.DB()
 | |
| 		sqlDB.Close()
 | |
| 
 | |
| 		_, err := repo.ListAll()
 | |
| 		assert.Error(t, err)
 | |
| 		assert.Contains(t, err.Error(), "database is closed")
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func TestRepoListByParentID(t *testing.T) {
 | |
| 	db := setupTestDB(t)
 | |
| 	repo := repository.NewGormDeviceRepository(db)
 | |
| 
 | |
| 	parent1 := createTestDevice(t, db, "父设备1", models.DeviceTypeAreaController, nil)
 | |
| 	parent2 := createTestDevice(t, db, "父设备2", models.DeviceTypeAreaController, nil)
 | |
| 	child1_1 := createTestDevice(t, db, "子设备1-1", models.DeviceTypeDevice, &parent1.ID)
 | |
| 	child1_2 := createTestDevice(t, db, "子设备1-2", models.DeviceTypeDevice, &parent1.ID)
 | |
| 	_ = createTestDevice(t, db, "子设备2-1", models.DeviceTypeDevice, &parent2.ID)
 | |
| 
 | |
| 	t.Run("成功通过父ID查找子设备", func(t *testing.T) {
 | |
| 		children, err := repo.ListByParentID(&parent1.ID)
 | |
| 		assert.NoError(t, err)
 | |
| 		assert.Len(t, children, 2)
 | |
| 		assert.Contains(t, []uint{child1_1.ID, child1_2.ID}, children[0].ID)
 | |
| 		assert.Contains(t, []uint{child1_1.ID, child1_2.ID}, children[1].ID)
 | |
| 	})
 | |
| 
 | |
| 	t.Run("成功通过nil父ID查找顶层设备", func(t *testing.T) {
 | |
| 		parents, err := repo.ListByParentID(nil)
 | |
| 		assert.NoError(t, err)
 | |
| 		assert.Len(t, parents, 2)
 | |
| 		assert.Contains(t, []uint{parent1.ID, parent2.ID}, parents[0].ID)
 | |
| 		assert.Contains(t, []uint{parent1.ID, parent2.ID}, parents[1].ID)
 | |
| 	})
 | |
| 
 | |
| 	t.Run("查找不存在的父ID", func(t *testing.T) {
 | |
| 		nonExistentParentID := uint(9999)
 | |
| 		children, err := repo.ListByParentID(&nonExistentParentID)
 | |
| 		assert.NoError(t, err) // GORM 在未找到时返回空列表而不是错误
 | |
| 		assert.Empty(t, children)
 | |
| 	})
 | |
| 
 | |
| 	t.Run("数据库查询失败", func(t *testing.T) {
 | |
| 		sqlDB, _ := db.DB()
 | |
| 		sqlDB.Close()
 | |
| 
 | |
| 		_, err := repo.ListByParentID(&parent1.ID)
 | |
| 		assert.Error(t, err)
 | |
| 		assert.Contains(t, err.Error(), "database is closed")
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func TestRepoUpdate(t *testing.T) {
 | |
| 	db := setupTestDB(t)
 | |
| 	repo := repository.NewGormDeviceRepository(db)
 | |
| 
 | |
| 	device := createTestDevice(t, db, "原始设备", models.DeviceTypeAreaController, nil)
 | |
| 
 | |
| 	t.Run("成功更新设备信息", func(t *testing.T) {
 | |
| 		device.Name = "更新后的设备"
 | |
| 		device.Location = "新地点"
 | |
| 		err := repo.Update(device)
 | |
| 		assert.NoError(t, err)
 | |
| 
 | |
| 		updatedDevice, err := repo.FindByID(device.ID)
 | |
| 		assert.NoError(t, err)
 | |
| 		assert.Equal(t, "更新后的设备", updatedDevice.Name)
 | |
| 		assert.Equal(t, "新地点", updatedDevice.Location)
 | |
| 	})
 | |
| 
 | |
| 	t.Run("数据库更新失败", func(t *testing.T) {
 | |
| 		sqlDB, _ := db.DB()
 | |
| 		sqlDB.Close()
 | |
| 
 | |
| 		device.Name = "更新失败的设备"
 | |
| 		err := repo.Update(device)
 | |
| 		assert.Error(t, err)
 | |
| 		assert.Contains(t, err.Error(), "database is closed")
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func TestRepoDelete(t *testing.T) {
 | |
| 	db := setupTestDB(t)
 | |
| 	repo := repository.NewGormDeviceRepository(db)
 | |
| 
 | |
| 	device := createTestDevice(t, db, "待删除设备", models.DeviceTypeAreaController, nil)
 | |
| 
 | |
| 	t.Run("成功删除设备", func(t *testing.T) {
 | |
| 		err := repo.Delete(device.ID)
 | |
| 		assert.NoError(t, err)
 | |
| 
 | |
| 		// 验证设备已被软删除
 | |
| 		_, err = repo.FindByID(device.ID)
 | |
| 		assert.Error(t, err, "删除后应无法找到设备")
 | |
| 		assert.ErrorIs(t, err, gorm.ErrRecordNotFound, "错误类型应为 RecordNotFound")
 | |
| 	})
 | |
| 
 | |
| 	t.Run("删除不存在的设备", func(t *testing.T) {
 | |
| 		err := repo.Delete(9999) // 不存在的ID
 | |
| 		assert.NoError(t, err)   // GORM 的 Delete 方法在删除不存在的记录时不会报错
 | |
| 	})
 | |
| 
 | |
| 	t.Run("数据库删除失败", func(t *testing.T) {
 | |
| 		sqlDB, _ := db.DB()
 | |
| 		sqlDB.Close()
 | |
| 
 | |
| 		err := repo.Delete(device.ID)
 | |
| 		assert.Error(t, err)
 | |
| 		assert.Contains(t, err.Error(), "database is closed")
 | |
| 	})
 | |
| }
 |