增加websocket支持

This commit is contained in:
2025-09-08 13:47:13 +08:00
parent 9caedd697d
commit 7e0fd53dd3
336 changed files with 9020 additions and 20356 deletions

View File

@@ -19,7 +19,6 @@ import (
"github.com/jackc/pgpassfile"
"github.com/jackc/pgservicefile"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3"
)
@@ -40,12 +39,7 @@ type Config struct {
DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc
// BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called
// when a context passed to a PgConn method is canceled.
BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
KerberosSrvName string
KerberosSpn string
@@ -66,17 +60,12 @@ type Config struct {
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler
// OnPgError is a callback function called when a Postgres error is received by the server. The default handler will close
// the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure
// that you close on FATAL errors by returning false.
OnPgError PgErrorHandler
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
// ParseConfigOptions contains options that control how a config is built such as GetSSLPassword.
type ParseConfigOptions struct {
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the libpq function
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function
// PQsetSSLKeyPassHook_OpenSSL.
GetSSLPassword GetSSLPasswordFunc
}
@@ -118,14 +107,6 @@ type FallbackConfig struct {
TLSConfig *tls.Config // nil disables TLS
}
// connectOneConfig is the configuration for a single attempt to connect to a single host.
type connectOneConfig struct {
network string
address string
originalHostname string // original hostname before resolving
tlsConfig *tls.Config // nil disables TLS
}
// isAbsolutePath checks if the provided value is an absolute path either
// beginning with a forward slash (as on Linux-based systems) or with a capital
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
@@ -160,11 +141,11 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format. See
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be empty
// to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
//
// # Example Keyword/Value
// # Example DSN
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
//
// # Example URL
@@ -183,7 +164,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
//
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
// via database URL or keyword/value:
// via database URL or DSN:
//
// PGHOST
// PGPORT
@@ -247,16 +228,16 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
connStringSettings := make(map[string]string)
if connString != "" {
var err error
// connString may be a database URL or in PostgreSQL keyword/value format
// connString may be a database URL or a DSN
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
}
} else {
connStringSettings, err = parseKeywordValueSettings(connString)
connStringSettings, err = parseDSNSettings(connString)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as keyword/value", err: err}
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
}
}
}
@@ -265,7 +246,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if service, present := settings["service"]; present {
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err}
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
}
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
@@ -280,22 +261,12 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
return pgproto3.NewFrontend(r, w)
},
BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler {
return &DeadlineContextWatcherHandler{Conn: pgConn.conn}
},
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
// we want to automatically close any fatal errors
if strings.EqualFold(pgErr.Severity, "FATAL") {
return false
}
return true
},
}
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err}
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
}
config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
@@ -357,7 +328,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
port, err := parsePort(portStr)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err}
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
}
var tlsConfigs []*tls.Config
@@ -369,7 +340,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
var err error
tlsConfigs, err = configTLS(settings, host, options)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err}
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
}
}
@@ -413,7 +384,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
case "any":
// do nothing
default:
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
}
return config, nil
@@ -467,17 +438,14 @@ func parseEnvSettings() map[string]string {
func parseURLSettings(connString string) (map[string]string, error) {
settings := make(map[string]string)
parsedURL, err := url.Parse(connString)
url, err := url.Parse(connString)
if err != nil {
if urlErr := new(url.Error); errors.As(err, &urlErr) {
return nil, urlErr.Err
}
return nil, err
}
if parsedURL.User != nil {
settings["user"] = parsedURL.User.Username()
if password, present := parsedURL.User.Password(); present {
if url.User != nil {
settings["user"] = url.User.Username()
if password, present := url.User.Password(); present {
settings["password"] = password
}
}
@@ -485,7 +453,7 @@ func parseURLSettings(connString string) (map[string]string, error) {
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
var hosts []string
var ports []string
for _, host := range strings.Split(parsedURL.Host, ",") {
for _, host := range strings.Split(url.Host, ",") {
if host == "" {
continue
}
@@ -511,7 +479,7 @@ func parseURLSettings(connString string) (map[string]string, error) {
settings["port"] = strings.Join(ports, ",")
}
database := strings.TrimLeft(parsedURL.Path, "/")
database := strings.TrimLeft(url.Path, "/")
if database != "" {
settings["database"] = database
}
@@ -520,7 +488,7 @@ func parseURLSettings(connString string) (map[string]string, error) {
"dbname": "database",
}
for k, v := range parsedURL.Query() {
for k, v := range url.Query() {
if k2, present := nameMap[k]; present {
k = k2
}
@@ -537,7 +505,7 @@ func isIPOnly(host string) bool {
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
func parseKeywordValueSettings(s string) (map[string]string, error) {
func parseDSNSettings(s string) (map[string]string, error) {
settings := make(map[string]string)
nameMap := map[string]string{
@@ -548,7 +516,7 @@ func parseKeywordValueSettings(s string) (map[string]string, error) {
var key, val string
eqIdx := strings.IndexRune(s, '=')
if eqIdx < 0 {
return nil, errors.New("invalid keyword/value")
return nil, errors.New("invalid dsn")
}
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
@@ -600,7 +568,7 @@ func parseKeywordValueSettings(s string) (map[string]string, error) {
}
if key == "" {
return nil, errors.New("invalid keyword/value")
return nil, errors.New("invalid dsn")
}
settings[key] = val
@@ -657,36 +625,6 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
tlsConfig := &tls.Config{}
if sslrootcert != "" {
var caCertPool *x509.CertPool
if sslrootcert == "system" {
var err error
caCertPool, err = x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("unable to load system certificate pool: %w", err)
}
sslmode = "verify-full"
} else {
caCertPool = x509.NewCertPool()
caPath := sslrootcert
caCert, err := os.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("unable to read CA file: %w", err)
}
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, errors.New("unable to add CA to cert pool")
}
}
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
}
switch sslmode {
case "disable":
return []*tls.Config{nil}, nil
@@ -744,6 +682,23 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
return nil, errors.New("sslmode is invalid")
}
if sslrootcert != "" {
caCertPool := x509.NewCertPool()
caPath := sslrootcert
caCert, err := os.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("unable to read CA file: %w", err)
}
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, errors.New("unable to add CA to cert pool")
}
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
}
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
}
@@ -754,9 +709,6 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
return nil, fmt.Errorf("unable to read sslkey: %w", err)
}
block, _ := pem.Decode(buf)
if block == nil {
return nil, errors.New("failed to decode sslkey")
}
var pemKey []byte
var decryptedKey []byte
var decryptedError error
@@ -833,8 +785,7 @@ func parsePort(s string) (uint16, error) {
}
func makeDefaultDialer() *net.Dialer {
// rely on GOLANG KeepAlive settings
return &net.Dialer{}
return &net.Dialer{KeepAlive: 5 * time.Minute}
}
func makeDefaultResolver() *net.Resolver {
@@ -858,7 +809,7 @@ func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
return d.DialContext
}
// ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-write.
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
@@ -873,7 +824,7 @@ func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgC
return nil
}
// ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-only.
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
@@ -888,7 +839,7 @@ func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgCo
return nil
}
// ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=standby.
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
@@ -903,7 +854,7 @@ func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgCon
return nil
}
// ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=primary.
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
@@ -918,7 +869,7 @@ func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgCon
return nil
}
// ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=prefer-standby.
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()