增加websocket支持
This commit is contained in:
		
							
								
								
									
										604
									
								
								vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										604
									
								
								vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -18,8 +18,8 @@ import ( | ||||
|  | ||||
| 	"github.com/jackc/pgx/v5/internal/iobufpool" | ||||
| 	"github.com/jackc/pgx/v5/internal/pgio" | ||||
| 	"github.com/jackc/pgx/v5/pgconn/ctxwatch" | ||||
| 	"github.com/jackc/pgx/v5/pgconn/internal/bgreader" | ||||
| 	"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" | ||||
| 	"github.com/jackc/pgx/v5/pgproto3" | ||||
| ) | ||||
|  | ||||
| @@ -52,12 +52,6 @@ type LookupFunc func(ctx context.Context, host string) (addrs []string, err erro | ||||
| // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. | ||||
| type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend | ||||
|  | ||||
| // PgErrorHandler is a function that handles errors returned from Postgres. This function must return true to keep | ||||
| // the connection open. Returning false will cause the connection to be closed immediately. You should return | ||||
| // false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is | ||||
| // aware of the origin of the error, but it must not invoke any query method. | ||||
| type PgErrorHandler func(*PgConn, *PgError) bool | ||||
|  | ||||
| // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at | ||||
| // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin | ||||
| // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY | ||||
| @@ -80,9 +74,6 @@ type PgConn struct { | ||||
| 	frontend          *pgproto3.Frontend | ||||
| 	bgReader          *bgreader.BGReader | ||||
| 	slowWriteTimer    *time.Timer | ||||
| 	bgReaderStarted   chan struct{} | ||||
|  | ||||
| 	customData map[string]any | ||||
|  | ||||
| 	config *Config | ||||
|  | ||||
| @@ -105,9 +96,8 @@ type PgConn struct { | ||||
| 	cleanupDone chan struct{} | ||||
| } | ||||
|  | ||||
| // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value | ||||
| // format) to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a | ||||
| // connect attempt. | ||||
| // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) | ||||
| // to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a connect attempt. | ||||
| func Connect(ctx context.Context, connString string) (*PgConn, error) { | ||||
| 	config, err := ParseConfig(connString) | ||||
| 	if err != nil { | ||||
| @@ -117,9 +107,9 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { | ||||
| 	return ConnectConfig(ctx, config) | ||||
| } | ||||
|  | ||||
| // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value | ||||
| // format) and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details. | ||||
| // ctx can be used to cancel a connect attempt. | ||||
| // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) | ||||
| // and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details. ctx can be | ||||
| // used to cancel a connect attempt. | ||||
| func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { | ||||
| 	config, err := ParseConfigWithOptions(connString, parseConfigOptions) | ||||
| 	if err != nil { | ||||
| @@ -134,46 +124,15 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio | ||||
| // | ||||
| // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An | ||||
| // authentication error will terminate the chain of attempts (like libpq: | ||||
| // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. | ||||
| func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) { | ||||
| // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, | ||||
| // if all attempts fail the last error is returned. | ||||
| func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err error) { | ||||
| 	// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from | ||||
| 	// zero values. | ||||
| 	if !config.createdByParseConfig { | ||||
| 		panic("config must be created by ParseConfig") | ||||
| 	} | ||||
|  | ||||
| 	var allErrors []error | ||||
|  | ||||
| 	connectConfigs, errs := buildConnectOneConfigs(ctx, config) | ||||
| 	if len(errs) > 0 { | ||||
| 		allErrors = append(allErrors, errs...) | ||||
| 	} | ||||
|  | ||||
| 	if len(connectConfigs) == 0 { | ||||
| 		return nil, &ConnectError{Config: config, err: fmt.Errorf("hostname resolving error: %w", errors.Join(allErrors...))} | ||||
| 	} | ||||
|  | ||||
| 	pgConn, errs := connectPreferred(ctx, config, connectConfigs) | ||||
| 	if len(errs) > 0 { | ||||
| 		allErrors = append(allErrors, errs...) | ||||
| 		return nil, &ConnectError{Config: config, err: errors.Join(allErrors...)} | ||||
| 	} | ||||
|  | ||||
| 	if config.AfterConnect != nil { | ||||
| 		err := config.AfterConnect(ctx, pgConn) | ||||
| 		if err != nil { | ||||
| 			pgConn.conn.Close() | ||||
| 			return nil, &ConnectError{Config: config, err: fmt.Errorf("AfterConnect error: %w", err)} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return pgConn, nil | ||||
| } | ||||
|  | ||||
| // buildConnectOneConfigs resolves hostnames and builds a list of connectOneConfigs to try connecting to. It returns a | ||||
| // slice of successfully resolved connectOneConfigs and a slice of errors. It is possible for both slices to contain | ||||
| // values if some hosts were successfully resolved and others were not. | ||||
| func buildConnectOneConfigs(ctx context.Context, config *Config) ([]*connectOneConfig, []error) { | ||||
| 	// Simplify usage by treating primary config and fallbacks the same. | ||||
| 	fallbackConfigs := []*FallbackConfig{ | ||||
| 		{ | ||||
| @@ -183,28 +142,95 @@ func buildConnectOneConfigs(ctx context.Context, config *Config) ([]*connectOneC | ||||
| 		}, | ||||
| 	} | ||||
| 	fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) | ||||
| 	ctx := octx | ||||
| 	fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) | ||||
| 	if err != nil { | ||||
| 		return nil, &connectError{config: config, msg: "hostname resolving error", err: err} | ||||
| 	} | ||||
|  | ||||
| 	var configs []*connectOneConfig | ||||
| 	if len(fallbackConfigs) == 0 { | ||||
| 		return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} | ||||
| 	} | ||||
|  | ||||
| 	var allErrors []error | ||||
| 	foundBestServer := false | ||||
| 	var fallbackConfig *FallbackConfig | ||||
| 	for i, fc := range fallbackConfigs { | ||||
| 		// ConnectTimeout restricts the whole connection process. | ||||
| 		if config.ConnectTimeout != 0 { | ||||
| 			// create new context first time or when previous host was different | ||||
| 			if i == 0 || (fallbackConfigs[i].Host != fallbackConfigs[i-1].Host) { | ||||
| 				var cancel context.CancelFunc | ||||
| 				ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) | ||||
| 				defer cancel() | ||||
| 			} | ||||
| 		} else { | ||||
| 			ctx = octx | ||||
| 		} | ||||
| 		pgConn, err = connect(ctx, config, fc, false) | ||||
| 		if err == nil { | ||||
| 			foundBestServer = true | ||||
| 			break | ||||
| 		} else if pgerr, ok := err.(*PgError); ok { | ||||
| 			err = &connectError{config: config, msg: "server error", err: pgerr} | ||||
| 			const ERRCODE_INVALID_PASSWORD = "28P01"                    // wrong password | ||||
| 			const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings | ||||
| 			const ERRCODE_INVALID_CATALOG_NAME = "3D000"                // db does not exist | ||||
| 			const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501"              // missing connect privilege | ||||
| 			if pgerr.Code == ERRCODE_INVALID_PASSWORD || | ||||
| 				pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && fc.TLSConfig != nil || | ||||
| 				pgerr.Code == ERRCODE_INVALID_CATALOG_NAME || | ||||
| 				pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { | ||||
| 				break | ||||
| 			} | ||||
| 		} else if cerr, ok := err.(*connectError); ok { | ||||
| 			if _, ok := cerr.err.(*NotPreferredError); ok { | ||||
| 				fallbackConfig = fc | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	for _, fb := range fallbackConfigs { | ||||
| 	if !foundBestServer && fallbackConfig != nil { | ||||
| 		pgConn, err = connect(ctx, config, fallbackConfig, true) | ||||
| 		if pgerr, ok := err.(*PgError); ok { | ||||
| 			err = &connectError{config: config, msg: "server error", err: pgerr} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError | ||||
| 	} | ||||
|  | ||||
| 	if config.AfterConnect != nil { | ||||
| 		err := config.AfterConnect(ctx, pgConn) | ||||
| 		if err != nil { | ||||
| 			pgConn.conn.Close() | ||||
| 			return nil, &connectError{config: config, msg: "AfterConnect error", err: err} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return pgConn, nil | ||||
| } | ||||
|  | ||||
| func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { | ||||
| 	var configs []*FallbackConfig | ||||
|  | ||||
| 	var lookupErrors []error | ||||
|  | ||||
| 	for _, fb := range fallbacks { | ||||
| 		// skip resolve for unix sockets | ||||
| 		if isAbsolutePath(fb.Host) { | ||||
| 			network, address := NetworkAddress(fb.Host, fb.Port) | ||||
| 			configs = append(configs, &connectOneConfig{ | ||||
| 				network:          network, | ||||
| 				address:          address, | ||||
| 				originalHostname: fb.Host, | ||||
| 				tlsConfig:        fb.TLSConfig, | ||||
| 			configs = append(configs, &FallbackConfig{ | ||||
| 				Host:      fb.Host, | ||||
| 				Port:      fb.Port, | ||||
| 				TLSConfig: fb.TLSConfig, | ||||
| 			}) | ||||
|  | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		ips, err := config.LookupFunc(ctx, fb.Host) | ||||
| 		ips, err := lookupFn(ctx, fb.Host) | ||||
| 		if err != nil { | ||||
| 			allErrors = append(allErrors, err) | ||||
| 			lookupErrors = append(lookupErrors, err) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| @@ -213,139 +239,70 @@ func buildConnectOneConfigs(ctx context.Context, config *Config) ([]*connectOneC | ||||
| 			if err == nil { | ||||
| 				port, err := strconv.ParseUint(splitPort, 10, 16) | ||||
| 				if err != nil { | ||||
| 					return nil, []error{fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)} | ||||
| 					return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err) | ||||
| 				} | ||||
| 				network, address := NetworkAddress(splitIP, uint16(port)) | ||||
| 				configs = append(configs, &connectOneConfig{ | ||||
| 					network:          network, | ||||
| 					address:          address, | ||||
| 					originalHostname: fb.Host, | ||||
| 					tlsConfig:        fb.TLSConfig, | ||||
| 				configs = append(configs, &FallbackConfig{ | ||||
| 					Host:      splitIP, | ||||
| 					Port:      uint16(port), | ||||
| 					TLSConfig: fb.TLSConfig, | ||||
| 				}) | ||||
| 			} else { | ||||
| 				network, address := NetworkAddress(ip, fb.Port) | ||||
| 				configs = append(configs, &connectOneConfig{ | ||||
| 					network:          network, | ||||
| 					address:          address, | ||||
| 					originalHostname: fb.Host, | ||||
| 					tlsConfig:        fb.TLSConfig, | ||||
| 				configs = append(configs, &FallbackConfig{ | ||||
| 					Host:      ip, | ||||
| 					Port:      fb.Port, | ||||
| 					TLSConfig: fb.TLSConfig, | ||||
| 				}) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return configs, allErrors | ||||
| } | ||||
|  | ||||
| // connectPreferred attempts to connect to the preferred host from connectOneConfigs. The connections are attempted in | ||||
| // order. If a connection is successful it is returned. If no connection is successful then all errors are returned. If | ||||
| // a connection attempt returns a [NotPreferredError], then that host will be used if no other hosts are successful. | ||||
| func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []*connectOneConfig) (*PgConn, []error) { | ||||
| 	octx := ctx | ||||
| 	var allErrors []error | ||||
|  | ||||
| 	var fallbackConnectOneConfig *connectOneConfig | ||||
| 	for i, c := range connectOneConfigs { | ||||
| 		// ConnectTimeout restricts the whole connection process. | ||||
| 		if config.ConnectTimeout != 0 { | ||||
| 			// create new context first time or when previous host was different | ||||
| 			if i == 0 || (connectOneConfigs[i].address != connectOneConfigs[i-1].address) { | ||||
| 				var cancel context.CancelFunc | ||||
| 				ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) | ||||
| 				defer cancel() | ||||
| 			} | ||||
| 		} else { | ||||
| 			ctx = octx | ||||
| 		} | ||||
|  | ||||
| 		pgConn, err := connectOne(ctx, config, c, false) | ||||
| 		if pgConn != nil { | ||||
| 			return pgConn, nil | ||||
| 		} | ||||
|  | ||||
| 		allErrors = append(allErrors, err) | ||||
|  | ||||
| 		var pgErr *PgError | ||||
| 		if errors.As(err, &pgErr) { | ||||
| 			const ERRCODE_INVALID_PASSWORD = "28P01"                    // wrong password | ||||
| 			const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings | ||||
| 			const ERRCODE_INVALID_CATALOG_NAME = "3D000"                // db does not exist | ||||
| 			const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501"              // missing connect privilege | ||||
| 			if pgErr.Code == ERRCODE_INVALID_PASSWORD || | ||||
| 				pgErr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && c.tlsConfig != nil || | ||||
| 				pgErr.Code == ERRCODE_INVALID_CATALOG_NAME || | ||||
| 				pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { | ||||
| 				return nil, allErrors | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		var npErr *NotPreferredError | ||||
| 		if errors.As(err, &npErr) { | ||||
| 			fallbackConnectOneConfig = c | ||||
| 		} | ||||
| 	// See https://github.com/jackc/pgx/issues/1464. When Go 1.20 can be used in pgx consider using errors.Join so all | ||||
| 	// errors are reported. | ||||
| 	if len(configs) == 0 && len(lookupErrors) > 0 { | ||||
| 		return nil, lookupErrors[0] | ||||
| 	} | ||||
|  | ||||
| 	if fallbackConnectOneConfig != nil { | ||||
| 		pgConn, err := connectOne(ctx, config, fallbackConnectOneConfig, true) | ||||
| 		if err == nil { | ||||
| 			return pgConn, nil | ||||
| 		} | ||||
| 		allErrors = append(allErrors, err) | ||||
| 	} | ||||
|  | ||||
| 	return nil, allErrors | ||||
| 	return configs, nil | ||||
| } | ||||
|  | ||||
| // connectOne makes one connection attempt to a single host. | ||||
| func connectOne(ctx context.Context, config *Config, connectConfig *connectOneConfig, | ||||
| func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig, | ||||
| 	ignoreNotPreferredErr bool, | ||||
| ) (*PgConn, error) { | ||||
| 	pgConn := new(PgConn) | ||||
| 	pgConn.config = config | ||||
| 	pgConn.cleanupDone = make(chan struct{}) | ||||
| 	pgConn.customData = make(map[string]any) | ||||
|  | ||||
| 	var err error | ||||
|  | ||||
| 	newPerDialConnectError := func(msg string, err error) *perDialConnectError { | ||||
| 		err = normalizeTimeoutError(ctx, err) | ||||
| 		e := &perDialConnectError{address: connectConfig.address, originalHostname: connectConfig.originalHostname, err: fmt.Errorf("%s: %w", msg, err)} | ||||
| 		return e | ||||
| 	} | ||||
|  | ||||
| 	pgConn.conn, err = config.DialFunc(ctx, connectConfig.network, connectConfig.address) | ||||
| 	network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) | ||||
| 	netConn, err := config.DialFunc(ctx, network, address) | ||||
| 	if err != nil { | ||||
| 		return nil, newPerDialConnectError("dial error", err) | ||||
| 		return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} | ||||
| 	} | ||||
|  | ||||
| 	if connectConfig.tlsConfig != nil { | ||||
| 		pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn}) | ||||
| 		pgConn.contextWatcher.Watch(ctx) | ||||
| 		tlsConn, err := startTLS(pgConn.conn, connectConfig.tlsConfig) | ||||
| 	pgConn.conn = netConn | ||||
| 	pgConn.contextWatcher = newContextWatcher(netConn) | ||||
| 	pgConn.contextWatcher.Watch(ctx) | ||||
|  | ||||
| 	if fallbackConfig.TLSConfig != nil { | ||||
| 		nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig) | ||||
| 		pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. | ||||
| 		if err != nil { | ||||
| 			pgConn.conn.Close() | ||||
| 			return nil, newPerDialConnectError("tls error", err) | ||||
| 			netConn.Close() | ||||
| 			return nil, &connectError{config: config, msg: "tls error", err: err} | ||||
| 		} | ||||
|  | ||||
| 		pgConn.conn = tlsConn | ||||
| 		pgConn.conn = nbTLSConn | ||||
| 		pgConn.contextWatcher = newContextWatcher(nbTLSConn) | ||||
| 		pgConn.contextWatcher.Watch(ctx) | ||||
| 	} | ||||
|  | ||||
| 	pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn)) | ||||
| 	pgConn.contextWatcher.Watch(ctx) | ||||
| 	defer pgConn.contextWatcher.Unwatch() | ||||
|  | ||||
| 	pgConn.parameterStatuses = make(map[string]string) | ||||
| 	pgConn.status = connStatusConnecting | ||||
| 	pgConn.bgReader = bgreader.New(pgConn.conn) | ||||
| 	pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), | ||||
| 		func() { | ||||
| 			pgConn.bgReader.Start() | ||||
| 			pgConn.bgReaderStarted <- struct{}{} | ||||
| 		}, | ||||
| 	) | ||||
| 	pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) | ||||
| 	pgConn.slowWriteTimer.Stop() | ||||
| 	pgConn.bgReaderStarted = make(chan struct{}) | ||||
| 	pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn) | ||||
|  | ||||
| 	startupMsg := pgproto3.StartupMessage{ | ||||
| @@ -366,7 +323,7 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo | ||||
| 	pgConn.frontend.Send(&startupMsg) | ||||
| 	if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { | ||||
| 		pgConn.conn.Close() | ||||
| 		return nil, newPerDialConnectError("failed to write startup message", err) | ||||
| 		return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} | ||||
| 	} | ||||
|  | ||||
| 	for { | ||||
| @@ -374,9 +331,9 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo | ||||
| 		if err != nil { | ||||
| 			pgConn.conn.Close() | ||||
| 			if err, ok := err.(*PgError); ok { | ||||
| 				return nil, newPerDialConnectError("server error", err) | ||||
| 				return nil, err | ||||
| 			} | ||||
| 			return nil, newPerDialConnectError("failed to receive message", err) | ||||
| 			return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} | ||||
| 		} | ||||
|  | ||||
| 		switch msg := msg.(type) { | ||||
| @@ -389,26 +346,26 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo | ||||
| 			err = pgConn.txPasswordMessage(pgConn.config.Password) | ||||
| 			if err != nil { | ||||
| 				pgConn.conn.Close() | ||||
| 				return nil, newPerDialConnectError("failed to write password message", err) | ||||
| 				return nil, &connectError{config: config, msg: "failed to write password message", err: err} | ||||
| 			} | ||||
| 		case *pgproto3.AuthenticationMD5Password: | ||||
| 			digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) | ||||
| 			err = pgConn.txPasswordMessage(digestedPassword) | ||||
| 			if err != nil { | ||||
| 				pgConn.conn.Close() | ||||
| 				return nil, newPerDialConnectError("failed to write password message", err) | ||||
| 				return nil, &connectError{config: config, msg: "failed to write password message", err: err} | ||||
| 			} | ||||
| 		case *pgproto3.AuthenticationSASL: | ||||
| 			err = pgConn.scramAuth(msg.AuthMechanisms) | ||||
| 			if err != nil { | ||||
| 				pgConn.conn.Close() | ||||
| 				return nil, newPerDialConnectError("failed SASL auth", err) | ||||
| 				return nil, &connectError{config: config, msg: "failed SASL auth", err: err} | ||||
| 			} | ||||
| 		case *pgproto3.AuthenticationGSS: | ||||
| 			err = pgConn.gssAuth() | ||||
| 			if err != nil { | ||||
| 				pgConn.conn.Close() | ||||
| 				return nil, newPerDialConnectError("failed GSS auth", err) | ||||
| 				return nil, &connectError{config: config, msg: "failed GSS auth", err: err} | ||||
| 			} | ||||
| 		case *pgproto3.ReadyForQuery: | ||||
| 			pgConn.status = connStatusIdle | ||||
| @@ -426,7 +383,7 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo | ||||
| 						return pgConn, nil | ||||
| 					} | ||||
| 					pgConn.conn.Close() | ||||
| 					return nil, newPerDialConnectError("ValidateConnect failed", err) | ||||
| 					return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} | ||||
| 				} | ||||
| 			} | ||||
| 			return pgConn, nil | ||||
| @@ -434,14 +391,21 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo | ||||
| 			// handled by ReceiveMessage | ||||
| 		case *pgproto3.ErrorResponse: | ||||
| 			pgConn.conn.Close() | ||||
| 			return nil, newPerDialConnectError("server error", ErrorResponseToPgError(msg)) | ||||
| 			return nil, ErrorResponseToPgError(msg) | ||||
| 		default: | ||||
| 			pgConn.conn.Close() | ||||
| 			return nil, newPerDialConnectError("received unexpected message", err) | ||||
| 			return nil, &connectError{config: config, msg: "received unexpected message", err: err} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { | ||||
| 	return ctxwatch.NewContextWatcher( | ||||
| 		func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, | ||||
| 		func() { conn.SetDeadline(time.Time{}) }, | ||||
| 	) | ||||
| } | ||||
|  | ||||
| func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { | ||||
| 	err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) | ||||
| 	if err != nil { | ||||
| @@ -576,12 +540,11 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { | ||||
| 	case *pgproto3.ParameterStatus: | ||||
| 		pgConn.parameterStatuses[msg.Name] = msg.Value | ||||
| 	case *pgproto3.ErrorResponse: | ||||
| 		err := ErrorResponseToPgError(msg) | ||||
| 		if pgConn.config.OnPgError != nil && !pgConn.config.OnPgError(pgConn, err) { | ||||
| 		if msg.Severity == "FATAL" { | ||||
| 			pgConn.status = connStatusClosed | ||||
| 			pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. | ||||
| 			close(pgConn.cleanupDone) | ||||
| 			return nil, err | ||||
| 			return nil, ErrorResponseToPgError(msg) | ||||
| 		} | ||||
| 	case *pgproto3.NoticeResponse: | ||||
| 		if pgConn.config.OnNotice != nil { | ||||
| @@ -630,7 +593,7 @@ func (pgConn *PgConn) Frontend() *pgproto3.Frontend { | ||||
| 	return pgConn.frontend | ||||
| } | ||||
|  | ||||
| // Close closes a connection. It is safe to call Close on an already closed connection. Close attempts a clean close by | ||||
| // Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by | ||||
| // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The | ||||
| // underlying net.Conn.Close() will always be called regardless of any other errors. | ||||
| func (pgConn *PgConn) Close(ctx context.Context) error { | ||||
| @@ -843,9 +806,6 @@ type StatementDescription struct { | ||||
|  | ||||
| // Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This | ||||
| // allows Prepare to also to describe statements without creating a server-side prepared statement. | ||||
| // | ||||
| // Prepare does not send a PREPARE statement to the server. It uses the PostgreSQL Parse and Describe protocol messages | ||||
| // directly. | ||||
| func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { | ||||
| 	if err := pgConn.lock(); err != nil { | ||||
| 		return nil, err | ||||
| @@ -902,73 +862,26 @@ readloop: | ||||
| 	return psd, nil | ||||
| } | ||||
|  | ||||
| // Deallocate deallocates a prepared statement. | ||||
| // | ||||
| // Deallocate does not send a DEALLOCATE statement to the server. It uses the PostgreSQL Close protocol message | ||||
| // directly. This has slightly different behavior than executing DEALLOCATE statement. | ||||
| //   - Deallocate can succeed in an aborted transaction. | ||||
| //   - Deallocating a non-existent prepared statement is not an error. | ||||
| func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error { | ||||
| 	if err := pgConn.lock(); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	defer pgConn.unlock() | ||||
|  | ||||
| 	if ctx != context.Background() { | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			return newContextAlreadyDoneError(ctx) | ||||
| 		default: | ||||
| 		} | ||||
| 		pgConn.contextWatcher.Watch(ctx) | ||||
| 		defer pgConn.contextWatcher.Unwatch() | ||||
| 	} | ||||
|  | ||||
| 	pgConn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name}) | ||||
| 	pgConn.frontend.SendSync(&pgproto3.Sync{}) | ||||
| 	err := pgConn.flushWithPotentialWriteReadDeadlock() | ||||
| 	if err != nil { | ||||
| 		pgConn.asyncClose() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	for { | ||||
| 		msg, err := pgConn.receiveMessage() | ||||
| 		if err != nil { | ||||
| 			pgConn.asyncClose() | ||||
| 			return normalizeTimeoutError(ctx, err) | ||||
| 		} | ||||
|  | ||||
| 		switch msg := msg.(type) { | ||||
| 		case *pgproto3.ErrorResponse: | ||||
| 			return ErrorResponseToPgError(msg) | ||||
| 		case *pgproto3.ReadyForQuery: | ||||
| 			return nil | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ErrorResponseToPgError converts a wire protocol error message to a *PgError. | ||||
| func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { | ||||
| 	return &PgError{ | ||||
| 		Severity:            msg.Severity, | ||||
| 		SeverityUnlocalized: msg.SeverityUnlocalized, | ||||
| 		Code:                string(msg.Code), | ||||
| 		Message:             string(msg.Message), | ||||
| 		Detail:              string(msg.Detail), | ||||
| 		Hint:                msg.Hint, | ||||
| 		Position:            msg.Position, | ||||
| 		InternalPosition:    msg.InternalPosition, | ||||
| 		InternalQuery:       string(msg.InternalQuery), | ||||
| 		Where:               string(msg.Where), | ||||
| 		SchemaName:          string(msg.SchemaName), | ||||
| 		TableName:           string(msg.TableName), | ||||
| 		ColumnName:          string(msg.ColumnName), | ||||
| 		DataTypeName:        string(msg.DataTypeName), | ||||
| 		ConstraintName:      msg.ConstraintName, | ||||
| 		File:                string(msg.File), | ||||
| 		Line:                msg.Line, | ||||
| 		Routine:             string(msg.Routine), | ||||
| 		Severity:         msg.Severity, | ||||
| 		Code:             string(msg.Code), | ||||
| 		Message:          string(msg.Message), | ||||
| 		Detail:           string(msg.Detail), | ||||
| 		Hint:             msg.Hint, | ||||
| 		Position:         msg.Position, | ||||
| 		InternalPosition: msg.InternalPosition, | ||||
| 		InternalQuery:    string(msg.InternalQuery), | ||||
| 		Where:            string(msg.Where), | ||||
| 		SchemaName:       string(msg.SchemaName), | ||||
| 		TableName:        string(msg.TableName), | ||||
| 		ColumnName:       string(msg.ColumnName), | ||||
| 		DataTypeName:     string(msg.DataTypeName), | ||||
| 		ConstraintName:   msg.ConstraintName, | ||||
| 		File:             string(msg.File), | ||||
| 		Line:             msg.Line, | ||||
| 		Routine:          string(msg.Routine), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -1011,7 +924,10 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { | ||||
| 	defer cancelConn.Close() | ||||
|  | ||||
| 	if ctx != context.Background() { | ||||
| 		contextWatcher := ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: cancelConn}) | ||||
| 		contextWatcher := ctxwatch.NewContextWatcher( | ||||
| 			func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, | ||||
| 			func() { cancelConn.SetDeadline(time.Time{}) }, | ||||
| 		) | ||||
| 		contextWatcher.Watch(ctx) | ||||
| 		defer contextWatcher.Unwatch() | ||||
| 	} | ||||
| @@ -1019,21 +935,16 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { | ||||
| 	buf := make([]byte, 16) | ||||
| 	binary.BigEndian.PutUint32(buf[0:4], 16) | ||||
| 	binary.BigEndian.PutUint32(buf[4:8], 80877102) | ||||
| 	binary.BigEndian.PutUint32(buf[8:12], pgConn.pid) | ||||
| 	binary.BigEndian.PutUint32(buf[12:16], pgConn.secretKey) | ||||
|  | ||||
| 	if _, err := cancelConn.Write(buf); err != nil { | ||||
| 		return fmt.Errorf("write to connection for cancellation: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	// Wait for the cancel request to be acknowledged by the server. | ||||
| 	// It copies the behavior of the libpq: https://github.com/postgres/postgres/blob/REL_16_0/src/interfaces/libpq/fe-connect.c#L4946-L4960 | ||||
| 	_, _ = cancelConn.Read(buf) | ||||
|  | ||||
| 	return nil | ||||
| 	binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) | ||||
| 	binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) | ||||
| 	// Postgres will process the request and close the connection | ||||
| 	// so when don't need to read the reply | ||||
| 	// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.6.7.10 | ||||
| 	_, err = cancelConn.Write(buf) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // WaitForNotification waits for a LISTEN/NOTIFY message to be received. It returns an error if a notification was not | ||||
| // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not | ||||
| // received. | ||||
| func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { | ||||
| 	if err := pgConn.lock(); err != nil { | ||||
| @@ -1544,10 +1455,8 @@ func (rr *ResultReader) Read() *Result { | ||||
| 		values := rr.Values() | ||||
| 		row := make([][]byte, len(values)) | ||||
| 		for i := range row { | ||||
| 			if values[i] != nil { | ||||
| 				row[i] = make([]byte, len(values[i])) | ||||
| 				copy(row[i], values[i]) | ||||
| 			} | ||||
| 			row[i] = make([]byte, len(values[i])) | ||||
| 			copy(row[i], values[i]) | ||||
| 		} | ||||
| 		br.Rows = append(br.Rows, row) | ||||
| 	} | ||||
| @@ -1697,55 +1606,25 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { | ||||
| // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. | ||||
| type Batch struct { | ||||
| 	buf []byte | ||||
| 	err error | ||||
| } | ||||
|  | ||||
| // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. | ||||
| func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { | ||||
| 	if batch.err != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) | ||||
| 	if batch.err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) | ||||
| 	batch.ExecPrepared("", paramValues, paramFormats, resultFormats) | ||||
| } | ||||
|  | ||||
| // ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. | ||||
| func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { | ||||
| 	if batch.err != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) | ||||
| 	if batch.err != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) | ||||
| 	if batch.err != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf) | ||||
| 	if batch.err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) | ||||
| 	batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) | ||||
| 	batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) | ||||
| } | ||||
|  | ||||
| // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a | ||||
| // transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing | ||||
| // multiple queries in a single round trip than using pipeline mode. | ||||
| func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { | ||||
| 	if batch.err != nil { | ||||
| 		return &MultiResultReader{ | ||||
| 			closed: true, | ||||
| 			err:    batch.err, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := pgConn.lock(); err != nil { | ||||
| 		return &MultiResultReader{ | ||||
| 			closed: true, | ||||
| @@ -1771,13 +1650,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR | ||||
| 		pgConn.contextWatcher.Watch(ctx) | ||||
| 	} | ||||
|  | ||||
| 	batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf) | ||||
| 	if batch.err != nil { | ||||
| 		multiResult.closed = true | ||||
| 		multiResult.err = batch.err | ||||
| 		pgConn.unlock() | ||||
| 		return multiResult | ||||
| 	} | ||||
| 	batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) | ||||
|  | ||||
| 	pgConn.enterPotentialWriteReadDeadlock() | ||||
| 	defer pgConn.exitPotentialWriteReadDeadlock() | ||||
| @@ -1859,16 +1732,10 @@ func (pgConn *PgConn) enterPotentialWriteReadDeadlock() { | ||||
|  | ||||
| // exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock. | ||||
| func (pgConn *PgConn) exitPotentialWriteReadDeadlock() { | ||||
| 	if !pgConn.slowWriteTimer.Stop() { | ||||
| 		// The timer starts its function in a separate goroutine. It is necessary to ensure the background reader has | ||||
| 		// started before calling Stop. Otherwise, the background reader may not be stopped. That on its own is not a | ||||
| 		// serious problem. But what is a serious problem is that the background reader may start at an inopportune time in | ||||
| 		// a subsequent query. For example, if a subsequent query was canceled then a deadline may be set on the net.Conn to | ||||
| 		// interrupt an in-progress read. After the read is interrupted, but before the deadline is cleared, the background | ||||
| 		// reader could start and read a deadline error. Then the next query would receive the an unexpected deadline error. | ||||
| 		<-pgConn.bgReaderStarted | ||||
| 		pgConn.bgReader.Stop() | ||||
| 	} | ||||
| 	// The state of the timer is not relevant upon exiting the potential slow write. It may both | ||||
| 	// fire (due to a slow write), or not fire (due to a fast write). | ||||
| 	_ = pgConn.slowWriteTimer.Stop() | ||||
| 	pgConn.bgReader.Stop() | ||||
| } | ||||
|  | ||||
| func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error { | ||||
| @@ -1897,16 +1764,11 @@ func (pgConn *PgConn) SyncConn(ctx context.Context) error { | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// This should never happen. Only way I can imagine this occurring is if the server is constantly sending data such as | ||||
| 	// This should never happen. Only way I can imagine this occuring is if the server is constantly sending data such as | ||||
| 	// LISTEN/NOTIFY or log notifications such that we never can get an empty buffer. | ||||
| 	return errors.New("SyncConn: conn never synchronized") | ||||
| } | ||||
|  | ||||
| // CustomData returns a map that can be used to associate custom data with the connection. | ||||
| func (pgConn *PgConn) CustomData() map[string]any { | ||||
| 	return pgConn.customData | ||||
| } | ||||
|  | ||||
| // HijackedConn is the result of hijacking a connection. | ||||
| // | ||||
| // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning | ||||
| @@ -1919,7 +1781,6 @@ type HijackedConn struct { | ||||
| 	TxStatus          byte | ||||
| 	Frontend          *pgproto3.Frontend | ||||
| 	Config            *Config | ||||
| 	CustomData        map[string]any | ||||
| } | ||||
|  | ||||
| // Hijack extracts the internal connection data. pgConn must be in an idle state. SyncConn should be called immediately | ||||
| @@ -1942,7 +1803,6 @@ func (pgConn *PgConn) Hijack() (*HijackedConn, error) { | ||||
| 		TxStatus:          pgConn.txStatus, | ||||
| 		Frontend:          pgConn.frontend, | ||||
| 		Config:            pgConn.config, | ||||
| 		CustomData:        pgConn.customData, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| @@ -1962,23 +1822,16 @@ func Construct(hc *HijackedConn) (*PgConn, error) { | ||||
| 		txStatus:          hc.TxStatus, | ||||
| 		frontend:          hc.Frontend, | ||||
| 		config:            hc.Config, | ||||
| 		customData:        hc.CustomData, | ||||
|  | ||||
| 		status: connStatusIdle, | ||||
|  | ||||
| 		cleanupDone: make(chan struct{}), | ||||
| 	} | ||||
|  | ||||
| 	pgConn.contextWatcher = ctxwatch.NewContextWatcher(hc.Config.BuildContextWatcherHandler(pgConn)) | ||||
| 	pgConn.contextWatcher = newContextWatcher(pgConn.conn) | ||||
| 	pgConn.bgReader = bgreader.New(pgConn.conn) | ||||
| 	pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), | ||||
| 		func() { | ||||
| 			pgConn.bgReader.Start() | ||||
| 			pgConn.bgReaderStarted <- struct{}{} | ||||
| 		}, | ||||
| 	) | ||||
| 	pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) | ||||
| 	pgConn.slowWriteTimer.Stop() | ||||
| 	pgConn.bgReaderStarted = make(chan struct{}) | ||||
| 	pgConn.frontend = hc.Config.BuildFrontend(pgConn.bgReader, pgConn.conn) | ||||
|  | ||||
| 	return pgConn, nil | ||||
| @@ -2120,13 +1973,6 @@ func (p *Pipeline) Flush() error { | ||||
|  | ||||
| // Sync establishes a synchronization point and flushes the queued requests. | ||||
| func (p *Pipeline) Sync() error { | ||||
| 	if p.closed { | ||||
| 		if p.err != nil { | ||||
| 			return p.err | ||||
| 		} | ||||
| 		return errors.New("pipeline closed") | ||||
| 	} | ||||
|  | ||||
| 	p.conn.frontend.SendSync(&pgproto3.Sync{}) | ||||
| 	err := p.Flush() | ||||
| 	if err != nil { | ||||
| @@ -2143,28 +1989,14 @@ func (p *Pipeline) Sync() error { | ||||
| // *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. If no | ||||
| // results are available, results and err will both be nil. | ||||
| func (p *Pipeline) GetResults() (results any, err error) { | ||||
| 	if p.closed { | ||||
| 		if p.err != nil { | ||||
| 			return nil, p.err | ||||
| 		} | ||||
| 		return nil, errors.New("pipeline closed") | ||||
| 	} | ||||
|  | ||||
| 	if p.expectedReadyForQueryCount == 0 { | ||||
| 		return nil, nil | ||||
| 	} | ||||
|  | ||||
| 	return p.getResults() | ||||
| } | ||||
|  | ||||
| func (p *Pipeline) getResults() (results any, err error) { | ||||
| 	for { | ||||
| 		msg, err := p.conn.receiveMessage() | ||||
| 		if err != nil { | ||||
| 			p.closed = true | ||||
| 			p.err = err | ||||
| 			p.conn.asyncClose() | ||||
| 			return nil, normalizeTimeoutError(p.ctx, err) | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		switch msg := msg.(type) { | ||||
| @@ -2186,8 +2018,7 @@ func (p *Pipeline) getResults() (results any, err error) { | ||||
| 		case *pgproto3.ParseComplete: | ||||
| 			peekedMsg, err := p.conn.peekMessage() | ||||
| 			if err != nil { | ||||
| 				p.conn.asyncClose() | ||||
| 				return nil, normalizeTimeoutError(p.ctx, err) | ||||
| 				return nil, err | ||||
| 			} | ||||
| 			if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok { | ||||
| 				return p.getResultsPrepare() | ||||
| @@ -2247,7 +2078,6 @@ func (p *Pipeline) Close() error { | ||||
| 	if p.closed { | ||||
| 		return p.err | ||||
| 	} | ||||
|  | ||||
| 	p.closed = true | ||||
|  | ||||
| 	if p.pendingSync { | ||||
| @@ -2260,7 +2090,7 @@ func (p *Pipeline) Close() error { | ||||
| 	} | ||||
|  | ||||
| 	for p.expectedReadyForQueryCount > 0 { | ||||
| 		_, err := p.getResults() | ||||
| 		_, err := p.GetResults() | ||||
| 		if err != nil { | ||||
| 			p.err = err | ||||
| 			var pgErr *PgError | ||||
| @@ -2276,71 +2106,3 @@ func (p *Pipeline) Close() error { | ||||
|  | ||||
| 	return p.err | ||||
| } | ||||
|  | ||||
| // DeadlineContextWatcherHandler handles canceled contexts by setting a deadline on a net.Conn. | ||||
| type DeadlineContextWatcherHandler struct { | ||||
| 	Conn net.Conn | ||||
|  | ||||
| 	// DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled. | ||||
| 	DeadlineDelay time.Duration | ||||
| } | ||||
|  | ||||
| func (h *DeadlineContextWatcherHandler) HandleCancel(ctx context.Context) { | ||||
| 	h.Conn.SetDeadline(time.Now().Add(h.DeadlineDelay)) | ||||
| } | ||||
|  | ||||
| func (h *DeadlineContextWatcherHandler) HandleUnwatchAfterCancel() { | ||||
| 	h.Conn.SetDeadline(time.Time{}) | ||||
| } | ||||
|  | ||||
| // CancelRequestContextWatcherHandler handles canceled contexts by sending a cancel request to the server. It also sets | ||||
| // a deadline on a net.Conn as a fallback. | ||||
| type CancelRequestContextWatcherHandler struct { | ||||
| 	Conn *PgConn | ||||
|  | ||||
| 	// CancelRequestDelay is the delay before sending the cancel request to the server. | ||||
| 	CancelRequestDelay time.Duration | ||||
|  | ||||
| 	// DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled. | ||||
| 	DeadlineDelay time.Duration | ||||
|  | ||||
| 	cancelFinishedChan             chan struct{} | ||||
| 	handleUnwatchAfterCancelCalled func() | ||||
| } | ||||
|  | ||||
| func (h *CancelRequestContextWatcherHandler) HandleCancel(context.Context) { | ||||
| 	h.cancelFinishedChan = make(chan struct{}) | ||||
| 	var handleUnwatchedAfterCancelCalledCtx context.Context | ||||
| 	handleUnwatchedAfterCancelCalledCtx, h.handleUnwatchAfterCancelCalled = context.WithCancel(context.Background()) | ||||
|  | ||||
| 	deadline := time.Now().Add(h.DeadlineDelay) | ||||
| 	h.Conn.conn.SetDeadline(deadline) | ||||
|  | ||||
| 	go func() { | ||||
| 		defer close(h.cancelFinishedChan) | ||||
|  | ||||
| 		select { | ||||
| 		case <-handleUnwatchedAfterCancelCalledCtx.Done(): | ||||
| 			return | ||||
| 		case <-time.After(h.CancelRequestDelay): | ||||
| 		} | ||||
|  | ||||
| 		cancelRequestCtx, cancel := context.WithDeadline(handleUnwatchedAfterCancelCalledCtx, deadline) | ||||
| 		defer cancel() | ||||
| 		h.Conn.CancelRequest(cancelRequestCtx) | ||||
|  | ||||
| 		// CancelRequest is inherently racy. Even though the cancel request has been received by the server at this point, | ||||
| 		// it hasn't necessarily been delivered to the other connection. If we immediately return and the connection is | ||||
| 		// immediately used then it is possible the CancelRequest will actually cancel our next query. The | ||||
| 		// TestCancelRequestContextWatcherHandler Stress test can produce this error without the sleep below. The sleep time | ||||
| 		// is arbitrary, but should be sufficient to prevent this error case. | ||||
| 		time.Sleep(100 * time.Millisecond) | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func (h *CancelRequestContextWatcherHandler) HandleUnwatchAfterCancel() { | ||||
| 	h.handleUnwatchAfterCancelCalled() | ||||
| 	<-h.cancelFinishedChan | ||||
|  | ||||
| 	h.Conn.conn.SetDeadline(time.Time{}) | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user