diff --git a/internal/app/api/api.go b/internal/app/api/api.go index 06f2af8..c2f7165 100644 --- a/internal/app/api/api.go +++ b/internal/app/api/api.go @@ -91,17 +91,17 @@ func NewAPI(cfg config.ServerConfig, config: cfg, listenHandler: listenHandler, // 在 NewAPI 中初始化用户控制器,并将其作为 API 结构体的成员 - userController: user.NewController(baseCtx, userService), + userController: user.NewController(logs.AddCompName(baseCtx, "UserController"), userService), // 在 NewAPI 中初始化设备控制器,并将其作为 API 结构体的成员 - deviceController: device.NewController(baseCtx, deviceService), + deviceController: device.NewController(logs.AddCompName(baseCtx, "DeviceController"), deviceService), // 在 NewAPI 中初始化计划控制器,并将其作为 API 结构体的成员 - planController: plan.NewController(baseCtx, planService), + planController: plan.NewController(logs.AddCompName(baseCtx, "PlanController"), planService), // 在 NewAPI 中初始化猪场管理控制器 - pigFarmController: management.NewPigFarmController(baseCtx, pigFarmService), + pigFarmController: management.NewPigFarmController(logs.AddCompName(baseCtx, "PigFarmController"), pigFarmService), // 在 NewAPI 中初始化猪群控制器 - pigBatchController: management.NewPigBatchController(baseCtx, pigBatchService), + pigBatchController: management.NewPigBatchController(logs.AddCompName(baseCtx, "PigBatchController"), pigBatchService), // 在 NewAPI 中初始化数据监控控制器 - monitorController: monitor.NewController(baseCtx, monitorService), + monitorController: monitor.NewController(logs.AddCompName(baseCtx, "MonitorController"), monitorService), } api.setupRoutes() // 设置所有路由 diff --git a/internal/app/middleware/audit.go b/internal/app/middleware/audit.go index 3ebf0a0..9af8dea 100644 --- a/internal/app/middleware/audit.go +++ b/internal/app/middleware/audit.go @@ -15,7 +15,7 @@ import ( func AuditLogMiddleware(ctx context.Context, auditService service.AuditService) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - newCtx := logs.AddFuncName(ctx, c.Request().Context(), "AuditLogMiddleware") + newCtx := logs.AddFuncName(c.Request().Context(), ctx, "AuditLogMiddleware") // 首先执行请求链中的后续处理程序(即业务控制器) err := next(c) @@ -48,9 +48,12 @@ func AuditLogMiddleware(ctx context.Context, auditService service.AuditService) status, _ := c.Get(models.ContextAuditStatus.String()).(models.AuditStatus) resultDetails, _ := c.Get(models.ContextAuditResultDetails.String()).(string) + // 为异步任务创建一个分离的 Context,以防止原始请求的 Context 被取消 + detachedCtx := logs.DetachContext(newCtx) + // 调用审计服务记录日志(异步) auditService.LogAction( - newCtx, + detachedCtx, user, reqCtx, actionType, diff --git a/internal/app/middleware/auth.go b/internal/app/middleware/auth.go index eaa3628..1607e14 100644 --- a/internal/app/middleware/auth.go +++ b/internal/app/middleware/auth.go @@ -22,7 +22,7 @@ import ( func AuthMiddleware(ctx context.Context, tokenGenerator token.Generator, userRepo repository.UserRepository) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - reqCtx := logs.AddFuncName(ctx, c.Request().Context(), "AuthMiddleware") + reqCtx := logs.AddFuncName(c.Request().Context(), ctx, "AuthMiddleware") // 从 Authorization header 获取 token authHeader := c.Request().Header.Get("Authorization") diff --git a/internal/infra/database/postgres.go b/internal/infra/database/postgres.go index f11cd63..7b29ebd 100644 --- a/internal/infra/database/postgres.go +++ b/internal/infra/database/postgres.go @@ -47,7 +47,7 @@ func (ps *PostgresStorage) Connect(ctx context.Context) error { logger.Info("正在连接PostgreSQL数据库") // 创建 GORM 的 logger 适配器 - gormLogger := logs.NewGormLogger(logger) + gormLogger := logs.NewGormLogger(logs.GetLogger(logs.AddCompName(context.Background(), "GORM"))) var err error // 在 gorm.Open 时传入我们自定义的 logger diff --git a/internal/infra/logs/context.go b/internal/infra/logs/context.go index 32b875c..7a8c0c3 100644 --- a/internal/infra/logs/context.go +++ b/internal/infra/logs/context.go @@ -3,6 +3,7 @@ package logs import ( "context" "fmt" + "strings" ) // contextKey 是用于在 context.Context 中存储值的私有类型,避免键冲突。 @@ -61,3 +62,29 @@ func AddFuncName(upstreamCtx context.Context, selfCtx context.Context, funcName return newCtx } + +// DetachContext 创建一个“分离”的 Context。 +// 新的 Context 会继承原始 Context 中的所有值(特别是用于日志追踪的 chainKey 和 compNameKey), +// 但它会使用 context.Background() 作为其父级,从而“丢弃”原始 Context 的取消信号。 +// 这对于需要在请求结束后继续执行的异步任务(如记录审计日志)至关重要,可以防止出现 "context canceled" 错误。 +func DetachContext(ctx context.Context) context.Context { + detachedCtx := context.Background() + + // 复制我们关心的、用于日志追踪的所有值 + if val := ctx.Value(chainKey); val != nil { + detachedCtx = context.WithValue(detachedCtx, chainKey, val) + } + if val := ctx.Value(compNameKey); val != nil { + detachedCtx = context.WithValue(detachedCtx, compNameKey, val) + } + + return detachedCtx +} + +// 获取context中的调用链字符串 +func GetTraceStr(ctx context.Context) string { + if trace, ok := ctx.Value(chainKey).([]string); ok { + return strings.Join(trace, "->") + } + return "" +} diff --git a/internal/infra/logs/logs.go b/internal/infra/logs/logs.go index d60b05d..13e8c4d 100644 --- a/internal/infra/logs/logs.go +++ b/internal/infra/logs/logs.go @@ -91,13 +91,13 @@ func GetLogger(ctx context.Context) *Logger { return defaultLogger } - chain, ok := val.([]string) - if !ok || len(chain) == 0 { + chain := GetTraceStr(ctx) + if chain == "" { return defaultLogger } // 使用 With 方法创建带有 traceKey 字段的 Logger 副本 - newSugaredLogger := defaultLogger.With(traceKey, strings.Join(chain, "->")) + newSugaredLogger := defaultLogger.With(traceKey, chain) return &Logger{newSugaredLogger} } @@ -211,8 +211,11 @@ func (g *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql "elapsed", fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/1e6), } - // 获取带有调用链的 logger - logger := GetLogger(ctx) + // 附加调用链信息 + chain := GetTraceStr(ctx) + if chain != "" { + fields = append(fields, traceKey, chain) + } if err != nil { // 如果是 "record not found" 错误且我们配置了跳过,则直接返回 @@ -220,18 +223,18 @@ func (g *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql return } // 否则,记录为错误日志 - logger.With(fields...).Errorf("[GORM] error: %s", err) + defaultLogger.With(fields...).Errorf("[GORM] error: %s", err) return } // 如果查询时间超过慢查询阈值,则记录警告 if g.SlowThreshold != 0 && elapsed > g.SlowThreshold { - logger.With(fields...).Warnf("[GORM] slow query") + defaultLogger.With(fields...).Warnf("[GORM] slow query") return } // 正常情况,记录 Debug 级别的 SQL 查询 - logger.With(fields...).Debugf("[GORM] trace") + defaultLogger.With(fields...).Debugf("[GORM] trace") // --- 逻辑修复结束 --- }