修正依赖问题

This commit is contained in:
2025-09-08 13:53:49 +08:00
parent 7e0fd53dd3
commit b63abe1d2d
164 changed files with 2155 additions and 1080 deletions

View File

@@ -1,3 +1,73 @@
# 5.5.5 (March 9, 2024)
Use spaces instead of parentheses for SQL sanitization.
This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as
`set foo to $1` where the substitution is taking place in a location where an arbitrary expression is not allowed.
# 5.5.4 (March 4, 2024)
Fix CVE-2024-27304
SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer
overflow in the calculated message size can cause the one large message to be sent as multiple messages under the
attacker's control.
Thanks to Paul Gerste for reporting this issue.
* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix)
* Fix simple protocol encoding of json.RawMessage
* Fix *Pipeline.getResults should close pipeline on error
* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman)
* Fix deallocation of invalidated cached statements in a transaction
* Handle invalid sslkey file
* Fix scan float4 into sql.Scanner
* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads.
# 5.5.3 (February 3, 2024)
* Fix: prepared statement already exists
* Improve CopyFrom auto-conversion of text-ish values
* Add ltree type support (Florent Viel)
* Make some properties of Batch and QueuedQuery public (Pavlo Golub)
* Add AppendRows function (Edoardo Spadolini)
* Optimize convert UUID [16]byte to string (Kirill Malikov)
* Fix: LargeObject Read and Write of more than ~1GB at a time (Mitar)
# 5.5.2 (January 13, 2024)
* Allow NamedArgs to start with underscore
* pgproto3: Maximum message body length support (jeremy.spriet)
* Upgrade golang.org/x/crypto to v0.17.0
* Add snake_case support to RowToStructByName (Tikhon Fedulov)
* Fix: update description cache after exec prepare (James Hartig)
* Fix: pipeline checks if it is closed (James Hartig and Ryan Fowler)
* Fix: normalize timeout / context errors during TLS startup (Samuel Stauffer)
* Add OnPgError for easier centralized error handling (James Hartig)
# 5.5.1 (December 9, 2023)
* Add CopyFromFunc helper function. (robford)
* Add PgConn.Deallocate method that uses PostgreSQL protocol Close message.
* pgx uses new PgConn.Deallocate method. This allows deallocating statements to work in a failed transaction. This fixes a case where the prepared statement map could become invalid.
* Fix: Prefer driver.Valuer over json.Marshaler for json fields. (Jacopo)
* Fix: simple protocol SQL sanitizer previously panicked if an invalid $0 placeholder was used. This now returns an error instead. (maksymnevajdev)
* Add pgtype.Numeric.ScanScientific (Eshton Robateau)
# 5.5.0 (November 4, 2023)
* Add CollectExactlyOneRow. (Julien GOTTELAND)
* Add OpenDBFromPool to create *database/sql.DB from *pgxpool.Pool. (Lev Zakharov)
* Prepare can automatically choose statement name based on sql. This makes it easier to explicitly manage prepared statements.
* Statement cache now uses deterministic, stable statement names.
* database/sql prepared statement names are deterministically generated.
* Fix: SendBatch wasn't respecting context cancellation.
* Fix: Timeout error from pipeline is now normalized.
* Fix: database/sql encoding json.RawMessage to []byte.
* CancelRequest: Wait for the cancel request to be acknowledged by the server. This should improve PgBouncer compatibility. (Anton Levakin)
* stdlib: Use Ping instead of CheckConn in ResetSession
* Add json.Marshaler and json.Unmarshaler for Float4, Float8 (Kirill Mironov)
# 5.4.3 (August 5, 2023)
* Fix: QCharArrayOID was defined with the wrong OID (Christoph Engelbert)

View File

@@ -79,20 +79,11 @@ echo "listen_addresses = '127.0.0.1'" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql
echo "port = $PGPORT" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
cat testsetup/postgresql_ssl.conf >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
cp testsetup/pg_hba.conf .testdb/$POSTGRESQL_DATA_DIR/pg_hba.conf
cp testsetup/ca.cnf .testdb
cp testsetup/localhost.cnf .testdb
cp testsetup/pgx_sslcert.cnf .testdb
cd .testdb
# Generate a CA public / private key pair.
openssl genrsa -out ca.key 4096
openssl req -x509 -config ca.cnf -new -nodes -key ca.key -sha256 -days 365 -subj '/O=pgx-test-root' -out ca.pem
# Generate the certificate for localhost (the server).
openssl genrsa -out localhost.key 2048
openssl req -new -config localhost.cnf -key localhost.key -out localhost.csr
openssl x509 -req -in localhost.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out localhost.crt -days 364 -sha256 -extfile localhost.cnf -extensions v3_req
# Generate CA, server, and encrypted client certificates.
go run ../testsetup/generate_certs.go
# Copy certificates to server directory and set permissions.
cp ca.pem $POSTGRESQL_DATA_DIR/root.crt
@@ -100,11 +91,6 @@ cp localhost.key $POSTGRESQL_DATA_DIR/server.key
chmod 600 $POSTGRESQL_DATA_DIR/server.key
cp localhost.crt $POSTGRESQL_DATA_DIR/server.crt
# Generate the certificate for client authentication.
openssl genrsa -des3 -out pgx_sslcert.key -passout pass:certpw 2048
openssl req -new -config pgx_sslcert.cnf -key pgx_sslcert.key -passin pass:certpw -out pgx_sslcert.csr
openssl x509 -req -in pgx_sslcert.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out pgx_sslcert.crt -days 363 -sha256 -extfile pgx_sslcert.cnf -extensions v3_req
cd ..
```

View File

@@ -86,9 +86,13 @@ It is also possible to use the `database/sql` interface and convert a connection
See CONTRIBUTING.md for setup instructions.
## Architecture
See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.com/watch?v=sXMSWhcHCf8) for a description of pgx architecture.
## Supported Go and PostgreSQL Versions
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.19 and higher and PostgreSQL 11 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.20 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
## Version Policy
@@ -116,6 +120,7 @@ pgerrcode contains constants for the PostgreSQL error codes.
* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
* [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos))
* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)

View File

@@ -10,8 +10,8 @@ import (
// QueuedQuery is a query that has been queued for execution via a Batch.
type QueuedQuery struct {
query string
arguments []any
SQL string
Arguments []any
fn batchItemFunc
sd *pgconn.StatementDescription
}
@@ -57,22 +57,24 @@ func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
// Batch queries are a way of bundling multiple queries together to avoid
// unnecessary network round trips. A Batch must only be sent once.
type Batch struct {
queuedQueries []*QueuedQuery
QueuedQueries []*QueuedQuery
}
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
// The only pgx option argument that is supported is QueryRewriter. Queries are executed using the
// connection's DefaultQueryExecMode.
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
qq := &QueuedQuery{
query: query,
arguments: arguments,
SQL: query,
Arguments: arguments,
}
b.queuedQueries = append(b.queuedQueries, qq)
b.QueuedQueries = append(b.QueuedQueries, qq)
return qq
}
// Len returns number of queries that have been queued so far.
func (b *Batch) Len() int {
return len(b.queuedQueries)
return len(b.QueuedQueries)
}
type BatchResults interface {
@@ -225,9 +227,9 @@ func (br *batchResults) Close() error {
}
// Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
if br.b.queuedQueries[br.qqIdx].fn != nil {
err := br.b.queuedQueries[br.qqIdx].fn(br)
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
if br.b.QueuedQueries[br.qqIdx].fn != nil {
err := br.b.QueuedQueries[br.qqIdx].fn(br)
if err != nil {
br.err = err
}
@@ -251,10 +253,10 @@ func (br *batchResults) earlyError() error {
}
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
bi := br.b.queuedQueries[br.qqIdx]
query = bi.query
args = bi.arguments
if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
bi := br.b.QueuedQueries[br.qqIdx]
query = bi.SQL
args = bi.Arguments
ok = true
br.qqIdx++
}
@@ -394,9 +396,9 @@ func (br *pipelineBatchResults) Close() error {
}
// Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
if br.b.queuedQueries[br.qqIdx].fn != nil {
err := br.b.queuedQueries[br.qqIdx].fn(br)
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
if br.b.QueuedQueries[br.qqIdx].fn != nil {
err := br.b.QueuedQueries[br.qqIdx].fn(br)
if err != nil {
br.err = err
}
@@ -420,10 +422,10 @@ func (br *pipelineBatchResults) earlyError() error {
}
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
bi := br.b.queuedQueries[br.qqIdx]
query = bi.query
args = bi.arguments
if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
bi := br.b.QueuedQueries[br.qqIdx]
query = bi.SQL
args = bi.Arguments
ok = true
br.qqIdx++
}

View File

@@ -2,6 +2,8 @@ package pgx
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"strconv"
@@ -35,7 +37,7 @@ type ConnConfig struct {
// DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol
// and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as
// PGBouncer. In this case it may be preferrable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same
// PGBouncer. In this case it may be preferable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same
// functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument.
DefaultQueryExecMode QueryExecMode
@@ -99,8 +101,12 @@ func (ident Identifier) Sanitize() string {
return strings.Join(parts, ".")
}
// ErrNoRows occurs when rows are expected but none are returned.
var ErrNoRows = errors.New("no rows in result set")
var (
// ErrNoRows occurs when rows are expected but none are returned.
ErrNoRows = errors.New("no rows in result set")
// ErrTooManyRows occurs when more rows than expected are returned.
ErrTooManyRows = errors.New("too many rows in result set")
)
var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
@@ -269,7 +275,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
return c, nil
}
// Close closes a connection. It is safe to call Close on a already closed
// Close closes a connection. It is safe to call Close on an already closed
// connection.
func (c *Conn) Close(ctx context.Context) error {
if c.IsClosed() {
@@ -280,12 +286,15 @@ func (c *Conn) Close(ctx context.Context) error {
return err
}
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
// for bound parameters. These placeholders are referenced positional as $1, $2, etc.
// Prepare creates a prepared statement with name and sql. sql can contain placeholders for bound parameters. These
// placeholders are referenced positionally as $1, $2, etc. name can be used instead of sql with Query, QueryRow, and
// Exec to execute the statement. It can also be used with Batch.Queue.
//
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same
// name and sql arguments. This allows a code path to Prepare and Query/Exec without
// concern for if the statement has already been prepared.
// The underlying PostgreSQL identifier for the prepared statement will be name if name != sql or a digest of sql if
// name == sql.
//
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same name and sql arguments. This
// allows a code path to Prepare and Query/Exec without concern for if the statement has already been prepared.
func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
if c.prepareTracer != nil {
ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql})
@@ -307,23 +316,48 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.Statem
}()
}
sd, err = c.pgConn.Prepare(ctx, name, sql, nil)
var psName, psKey string
if name == sql {
digest := sha256.Sum256([]byte(sql))
psName = "stmt_" + hex.EncodeToString(digest[0:24])
psKey = sql
} else {
psName = name
psKey = name
}
sd, err = c.pgConn.Prepare(ctx, psName, sql, nil)
if err != nil {
return nil, err
}
if name != "" {
c.preparedStatements[name] = sd
if psKey != "" {
c.preparedStatements[psKey] = sd
}
return sd, nil
}
// Deallocate released a prepared statement
// Deallocate releases a prepared statement. Calling Deallocate on a non-existent prepared statement will succeed.
func (c *Conn) Deallocate(ctx context.Context, name string) error {
delete(c.preparedStatements, name)
_, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll()
return err
var psName string
sd := c.preparedStatements[name]
if sd != nil {
psName = sd.Name
} else {
psName = name
}
err := c.pgConn.Deallocate(ctx, psName)
if err != nil {
return err
}
if sd != nil {
delete(c.preparedStatements, name)
}
return nil
}
// DeallocateAll releases all previously prepared statements from the server and client, where it also resets the statement and description cache.
@@ -441,7 +475,7 @@ optionLoop:
if queryRewriter != nil {
sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
if err != nil {
return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %v", err)
return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %w", err)
}
}
@@ -461,7 +495,7 @@ optionLoop:
}
sd := c.statementCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql)
if err != nil {
return pgconn.CommandTag{}, err
}
@@ -479,6 +513,7 @@ optionLoop:
if err != nil {
return pgconn.CommandTag{}, err
}
c.descriptionCache.Put(sd)
}
return c.execParams(ctx, sd, arguments)
@@ -573,13 +608,16 @@ type QueryExecMode int32
const (
_ QueryExecMode = iota
// Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single
// round trip after the statement is cached. This is the default.
// Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single round
// trip after the statement is cached. This is the default. If the database schema is modified or the search_path is
// changed after a statement is cached then the first execution of a previously cached query may fail. e.g. If the
// number of columns returned by a "SELECT *" changes or the type of a column is changed.
QueryExecModeCacheStatement
// Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the
// extended protocol. Queries are executed in a single round trip after the description is cached. If the database
// schema is modified or the search_path is changed this may result in undetected result decoding errors.
// Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the extended
// protocol. Queries are executed in a single round trip after the description is cached. If the database schema is
// modified or the search_path is changed after a statement is cached then the first execution of a previously cached
// query may fail. e.g. If the number of columns returned by a "SELECT *" changes or the type of a column is changed.
QueryExecModeCacheDescribe
// Get the statement description on every execution. This uses the extended protocol. Queries require two round trips
@@ -592,13 +630,13 @@ const (
// Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol
// with text formatted parameters and results. Queries are executed in a single round trip. Type mappings can be
// registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are
// unregistered or ambigious. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know
// unregistered or ambiguous. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know
// the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot.
QueryExecModeExec
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments.
// Queries are executed in a single round trip. Type mappings can be registered with
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious.
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous.
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use
// a map[string]string directly as an argument. This mode cannot.
//
@@ -705,7 +743,7 @@ optionLoop:
sql, args, err = queryRewriter.RewriteQuery(ctx, c, sql, args)
if err != nil {
rows := c.getRows(ctx, originalSQL, originalArgs)
err = fmt.Errorf("rewrite query failed: %v", err)
err = fmt.Errorf("rewrite query failed: %w", err)
rows.fatal(err)
return rows, err
}
@@ -815,7 +853,7 @@ func (c *Conn) getStatementDescription(
}
sd = c.statementCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql)
if err != nil {
return nil, err
}
@@ -865,15 +903,14 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
return &batchResults{ctx: ctx, conn: c, err: err}
}
mode := c.config.DefaultQueryExecMode
for _, bi := range b.queuedQueries {
for _, bi := range b.QueuedQueries {
var queryRewriter QueryRewriter
sql := bi.query
arguments := bi.arguments
sql := bi.SQL
arguments := bi.Arguments
optionLoop:
for len(arguments) > 0 {
// Update Batch.Queue function comment when additional options are implemented
switch arg := arguments[0].(type) {
case QueryRewriter:
queryRewriter = arg
@@ -887,21 +924,23 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
var err error
sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %v", err)}
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %w", err)}
}
}
bi.query = sql
bi.arguments = arguments
bi.SQL = sql
bi.Arguments = arguments
}
// TODO: changing mode per batch? Update Batch.Queue function comment when implemented
mode := c.config.DefaultQueryExecMode
if mode == QueryExecModeSimpleProtocol {
return c.sendBatchQueryExecModeSimpleProtocol(ctx, b)
}
// All other modes use extended protocol and thus can use prepared statements.
for _, bi := range b.queuedQueries {
if sd, ok := c.preparedStatements[bi.query]; ok {
for _, bi := range b.QueuedQueries {
if sd, ok := c.preparedStatements[bi.SQL]; ok {
bi.sd = sd
}
}
@@ -922,11 +961,11 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
var sb strings.Builder
for i, bi := range b.queuedQueries {
for i, bi := range b.QueuedQueries {
if i > 0 {
sb.WriteByte(';')
}
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
sql, err := c.sanitizeForSimpleQuery(bi.SQL, bi.Arguments...)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
@@ -945,21 +984,21 @@ func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batc
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
batch := &pgconn.Batch{}
for _, bi := range b.queuedQueries {
for _, bi := range b.QueuedQueries {
sd := bi.sd
if sd != nil {
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
err := c.eqb.Build(c.typeMap, sd, bi.Arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
} else {
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
err := c.eqb.Build(c.typeMap, nil, bi.Arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
batch.ExecParams(bi.SQL, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
}
}
@@ -984,18 +1023,18 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.queuedQueries {
for _, bi := range b.QueuedQueries {
if bi.sd == nil {
sd := c.statementCache.Get(bi.query)
sd := c.statementCache.Get(bi.SQL)
if sd != nil {
bi.sd = sd
} else {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
bi.sd = distinctNewQueries[idx]
} else {
sd = &pgconn.StatementDescription{
Name: stmtcache.NextStatementName(),
SQL: bi.query,
Name: stmtcache.StatementName(bi.SQL),
SQL: bi.SQL,
}
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd)
@@ -1016,17 +1055,17 @@ func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.queuedQueries {
for _, bi := range b.QueuedQueries {
if bi.sd == nil {
sd := c.descriptionCache.Get(bi.query)
sd := c.descriptionCache.Get(bi.SQL)
if sd != nil {
bi.sd = sd
} else {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
bi.sd = distinctNewQueries[idx]
} else {
sd = &pgconn.StatementDescription{
SQL: bi.query,
SQL: bi.SQL,
}
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd)
@@ -1043,13 +1082,13 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch)
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.queuedQueries {
for _, bi := range b.QueuedQueries {
if bi.sd == nil {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
bi.sd = distinctNewQueries[idx]
} else {
sd := &pgconn.StatementDescription{
SQL: bi.query,
SQL: bi.SQL,
}
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd)
@@ -1062,7 +1101,7 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch)
}
func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
pipeline := c.pgConn.StartPipeline(context.Background())
pipeline := c.pgConn.StartPipeline(ctx)
defer func() {
if pbr != nil && pbr.err != nil {
pipeline.Close()
@@ -1115,11 +1154,11 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
}
// Queue the queries.
for _, bi := range b.queuedQueries {
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
for _, bi := range b.QueuedQueries {
err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments)
if err != nil {
// we wrap the error so we the user can understand which query failed inside the batch
err = fmt.Errorf("error building query %s: %w", bi.query, err)
err = fmt.Errorf("error building query %s: %w", bi.SQL, err)
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
}
@@ -1164,7 +1203,15 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
return sanitize.SanitizeSQL(sql, valueArgs...)
}
// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration.
// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. typeName must be
// the name of a type where the underlying type(s) is already understood by pgx. It is for derived types. In particular,
// typeName must be one of the following:
// - An array type name of a type that is already registered. e.g. "_foo" when "foo" is registered.
// - A composite type name where all field types are already registered.
// - A domain type name where the base type is already registered.
// - An enum type name.
// - A range type name where the element type is already registered.
// - A multirange type name where the element type is already registered.
func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) {
var oid uint32
@@ -1307,17 +1354,17 @@ order by attnum`,
}
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
if c.pgConn.TxStatus() != 'I' {
if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' {
return nil
}
if c.descriptionCache != nil {
c.descriptionCache.HandleInvalidated()
c.descriptionCache.RemoveInvalidated()
}
var invalidatedStatements []*pgconn.StatementDescription
if c.statementCache != nil {
invalidatedStatements = c.statementCache.HandleInvalidated()
invalidatedStatements = c.statementCache.GetInvalidated()
}
if len(invalidatedStatements) == 0 {
@@ -1329,7 +1376,6 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
for _, sd := range invalidatedStatements {
pipeline.SendDeallocate(sd.Name)
delete(c.preparedStatements, sd.Name)
}
err := pipeline.Sync()
@@ -1342,5 +1388,10 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
}
c.statementCache.RemoveInvalidated()
for _, sd := range invalidatedStatements {
delete(c.preparedStatements, sd.Name)
}
return nil
}

View File

@@ -64,6 +64,33 @@ func (cts *copyFromSlice) Err() error {
return cts.err
}
// CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values.
// nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil,
// or it returns an error. If nxtf returns an error, the copy is aborted.
func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource {
return &copyFromFunc{next: nxtf}
}
type copyFromFunc struct {
next func() ([]any, error)
valueRow []any
err error
}
func (g *copyFromFunc) Next() bool {
g.valueRow, g.err = g.next()
// only return true if valueRow exists and no error
return g.valueRow != nil && g.err == nil
}
func (g *copyFromFunc) Values() ([]any, error) {
return g.valueRow, g.err
}
func (g *copyFromFunc) Err() error {
return g.err
}
// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
type CopyFromSource interface {
// Next returns true if there is another row and makes the next row data

View File

@@ -187,7 +187,7 @@ implemented on top of pgconn. The Conn.PgConn() method can be used to access thi
PgBouncer
By default pgx automatically uses prepared statements. Prepared statements are incompaptible with PgBouncer. This can be
By default pgx automatically uses prepared statements. Prepared statements are incompatible with PgBouncer. This can be
disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode.
*/
package pgx

View File

@@ -36,7 +36,7 @@ func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescri
for i := range args {
err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i])
if err != nil {
err = fmt.Errorf("failed to encode args[%d]: %v", i, err)
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
return err
}
}

View File

@@ -35,6 +35,11 @@ func (q *Query) Sanitize(args ...any) (string, error) {
str = part
case int:
argIdx := part - 1
if argIdx < 0 {
return "", fmt.Errorf("first sql argument must be > 0")
}
if argIdx >= len(args) {
return "", fmt.Errorf("insufficient arguments")
}
@@ -58,6 +63,10 @@ func (q *Query) Sanitize(args ...any) (string, error) {
return "", fmt.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
str = " " + str + " "
default:
return "", fmt.Errorf("invalid Part type: %T", part)
}

View File

@@ -34,7 +34,8 @@ func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
}
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or
// sd.SQL has been invalidated and HandleInvalidated has not been called yet.
func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
if sd.SQL == "" {
panic("cannot store statement description with empty SQL")
@@ -44,6 +45,13 @@ func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
return
}
// The statement may have been invalidated but not yet handled. Do not readd it to the cache.
for _, invalidSD := range c.invalidStmts {
if invalidSD.SQL == sd.SQL {
return
}
}
if c.l.Len() == c.cap {
c.invalidateOldest()
}
@@ -73,10 +81,16 @@ func (c *LRUCache) InvalidateAll() {
c.l = list.New()
}
func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription {
invalidStmts := c.invalidStmts
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
return c.invalidStmts
}
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *LRUCache) RemoveInvalidated() {
c.invalidStmts = nil
return invalidStmts
}
// Len returns the number of cached prepared statement descriptions.

View File

@@ -2,18 +2,17 @@
package stmtcache
import (
"strconv"
"sync/atomic"
"crypto/sha256"
"encoding/hex"
"github.com/jackc/pgx/v5/pgconn"
)
var stmtCounter int64
// NextStatementName returns a statement name that will be unique for the lifetime of the program.
func NextStatementName() string {
n := atomic.AddInt64(&stmtCounter, 1)
return "stmtcache_" + strconv.FormatInt(n, 10)
// StatementName returns a statement name that will be stable for sql across multiple connections and program
// executions.
func StatementName(sql string) string {
digest := sha256.Sum256([]byte(sql))
return "stmtcache_" + hex.EncodeToString(digest[0:24])
}
// Cache caches statement descriptions.
@@ -30,8 +29,13 @@ type Cache interface {
// InvalidateAll invalidates all statement descriptions.
InvalidateAll()
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
HandleInvalidated() []*pgconn.StatementDescription
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
GetInvalidated() []*pgconn.StatementDescription
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
RemoveInvalidated()
// Len returns the number of cached prepared statement descriptions.
Len() int
@@ -39,19 +43,3 @@ type Cache interface {
// Cap returns the maximum number of cached prepared statement descriptions.
Cap() int
}
func IsStatementInvalid(err error) bool {
pgErr, ok := err.(*pgconn.PgError)
if !ok {
return false
}
// https://github.com/jackc/pgx/issues/1162
//
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
// have so it should be safe.
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
return possibleInvalidCachedPlanError
}

View File

@@ -54,10 +54,16 @@ func (c *UnlimitedCache) InvalidateAll() {
c.m = make(map[string]*pgconn.StatementDescription)
}
func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription {
invalidStmts := c.invalidStmts
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription {
return c.invalidStmts
}
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *UnlimitedCache) RemoveInvalidated() {
c.invalidStmts = nil
return invalidStmts
}
// Len returns the number of cached prepared statement descriptions.

View File

@@ -6,6 +6,11 @@ import (
"io"
)
// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of
// PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data
// in the message,maxLargeObjectMessageLength should be no larger than 1 GB - 1 KB.
var maxLargeObjectMessageLength = 1024*1024*1024 - 1024
// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it
// was created.
//
@@ -68,32 +73,64 @@ type LargeObject struct {
// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written.
func (o *LargeObject) Write(p []byte) (int, error) {
var n int
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n)
if err != nil {
return n, err
nTotal := 0
for {
expected := len(p) - nTotal
if expected == 0 {
break
} else if expected > maxLargeObjectMessageLength {
expected = maxLargeObjectMessageLength
}
var n int
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p[nTotal:nTotal+expected]).Scan(&n)
if err != nil {
return nTotal, err
}
if n < 0 {
return nTotal, errors.New("failed to write to large object")
}
nTotal += n
if n < expected {
return nTotal, errors.New("short write to large object")
} else if n > expected {
return nTotal, errors.New("invalid write to large object")
}
}
if n < 0 {
return 0, errors.New("failed to write to large object")
}
return n, nil
return nTotal, nil
}
// Read reads up to len(p) bytes into p returning the number of bytes read.
func (o *LargeObject) Read(p []byte) (int, error) {
var res []byte
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res)
copy(p, res)
if err != nil {
return len(res), err
nTotal := 0
for {
expected := len(p) - nTotal
if expected == 0 {
break
} else if expected > maxLargeObjectMessageLength {
expected = maxLargeObjectMessageLength
}
var res []byte
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res)
copy(p[nTotal:], res)
nTotal += len(res)
if err != nil {
return nTotal, err
}
if len(res) < expected {
return nTotal, io.EOF
} else if len(res) > expected {
return nTotal, errors.New("invalid read of large object")
}
}
if len(res) < len(p) {
err = io.EOF
}
return len(res), err
return nTotal, nil
}
// Seek moves the current location pointer to the new location specified by offset.

View File

@@ -14,6 +14,9 @@ import (
//
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2)
//
// Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be
// letters, numbers, or underscores.
type NamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface.
@@ -80,7 +83,7 @@ func rawState(l *sqlLexer) stateFn {
return doubleQuoteState
case '@':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if isLetter(nextRune) {
if isLetter(nextRune) || nextRune == '_' {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}

View File

@@ -47,7 +47,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
return err
}
// Receive server-first-message payload in a AuthenticationSASLContinue.
// Receive server-first-message payload in an AuthenticationSASLContinue.
saslContinue, err := c.rxSASLContinue()
if err != nil {
return err
@@ -67,7 +67,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
return err
}
// Receive server-final-message payload in a AuthenticationSASLFinal.
// Receive server-final-message payload in an AuthenticationSASLFinal.
saslFinal, err := c.rxSASLFinal()
if err != nil {
return err

View File

@@ -60,6 +60,11 @@ type Config struct {
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler
// OnPgError is a callback function called when a Postgres error is received by the server. The default handler will close
// the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure
// that you close on FATAL errors by returning false.
OnPgError PgErrorHandler
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
@@ -232,12 +237,12 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
}
} else {
connStringSettings, err = parseDSNSettings(connString)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as DSN", err: err}
}
}
}
@@ -246,7 +251,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if service, present := settings["service"]; present {
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err}
}
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
@@ -261,12 +266,19 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
return pgproto3.NewFrontend(r, w)
},
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
// we want to automatically close any fatal errors
if strings.EqualFold(pgErr.Severity, "FATAL") {
return false
}
return true
},
}
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err}
}
config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
@@ -328,7 +340,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
port, err := parsePort(portStr)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err}
}
var tlsConfigs []*tls.Config
@@ -340,7 +352,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
var err error
tlsConfigs, err = configTLS(settings, host, options)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err}
}
}
@@ -384,7 +396,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
case "any":
// do nothing
default:
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
}
return config, nil
@@ -709,6 +721,9 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
return nil, fmt.Errorf("unable to read sslkey: %w", err)
}
block, _ := pem.Decode(buf)
if block == nil {
return nil, errors.New("failed to decode sslkey")
}
var pemKey []byte
var decryptedKey []byte
var decryptedError error
@@ -809,7 +824,7 @@ func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
return d.DialContext
}
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-write.
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
@@ -824,7 +839,7 @@ func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgC
return nil
}
// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-only.
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
@@ -839,7 +854,7 @@ func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgCo
return nil
}
// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=standby.
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
@@ -854,7 +869,7 @@ func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgCon
return nil
}
// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=primary.
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
@@ -869,7 +884,7 @@ func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgCon
return nil
}
// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
// ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=prefer-standby.
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()

View File

@@ -57,22 +57,23 @@ func (pe *PgError) SQLState() string {
return pe.Code
}
type connectError struct {
config *Config
// ConnectError is the error returned when a connection attempt fails.
type ConnectError struct {
Config *Config // The configuration that was used in the connection attempt.
msg string
err error
}
func (e *connectError) Error() string {
func (e *ConnectError) Error() string {
sb := &strings.Builder{}
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.Config.Host, e.Config.User, e.Config.Database, e.msg)
if e.err != nil {
fmt.Fprintf(sb, " (%s)", e.err.Error())
}
return sb.String()
}
func (e *connectError) Unwrap() error {
func (e *ConnectError) Unwrap() error {
return e.err
}
@@ -88,33 +89,38 @@ func (e *connLockError) Error() string {
return e.status
}
type parseConfigError struct {
connString string
// ParseConfigError is the error returned when a connection string cannot be parsed.
type ParseConfigError struct {
ConnString string // The connection string that could not be parsed.
msg string
err error
}
func (e *parseConfigError) Error() string {
connString := redactPW(e.connString)
func (e *ParseConfigError) Error() string {
// Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only
// return a static string. That would ensure that the error message cannot leak a password. The ConnString field would
// allow access to the original string if desired and Unwrap would allow access to the underlying error.
connString := redactPW(e.ConnString)
if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
}
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
}
func (e *parseConfigError) Unwrap() error {
func (e *ParseConfigError) Unwrap() error {
return e.err
}
func normalizeTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
if ctx.Err() == context.Canceled {
// Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error.
return context.Canceled
} else if ctx.Err() == context.DeadlineExceeded {
return &errTimeout{err: ctx.Err()}
} else {
return &errTimeout{err: err}
return &errTimeout{err: netErr}
}
}
return err

View File

@@ -52,6 +52,12 @@ type LookupFunc func(ctx context.Context, host string) (addrs []string, err erro
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend
// PgErrorHandler is a function that handles errors returned from Postgres. This function must return true to keep
// the connection open. Returning false will cause the connection to be closed immediately. You should return
// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is
// aware of the origin of the error, but it must not invoke any query method.
type PgErrorHandler func(*PgConn, *PgError) bool
// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY
@@ -74,6 +80,7 @@ type PgConn struct {
frontend *pgproto3.Frontend
bgReader *bgreader.BGReader
slowWriteTimer *time.Timer
bgReaderStarted chan struct{}
config *Config
@@ -145,11 +152,11 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
ctx := octx
fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
if err != nil {
return nil, &connectError{config: config, msg: "hostname resolving error", err: err}
return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err}
}
if len(fallbackConfigs) == 0 {
return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
}
foundBestServer := false
@@ -171,7 +178,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
foundBestServer = true
break
} else if pgerr, ok := err.(*PgError); ok {
err = &connectError{config: config, msg: "server error", err: pgerr}
err = &ConnectError{Config: config, msg: "server error", err: pgerr}
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
@@ -182,7 +189,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
break
}
} else if cerr, ok := err.(*connectError); ok {
} else if cerr, ok := err.(*ConnectError); ok {
if _, ok := cerr.err.(*NotPreferredError); ok {
fallbackConfig = fc
}
@@ -192,7 +199,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
if !foundBestServer && fallbackConfig != nil {
pgConn, err = connect(ctx, config, fallbackConfig, true)
if pgerr, ok := err.(*PgError); ok {
err = &connectError{config: config, msg: "server error", err: pgerr}
err = &ConnectError{Config: config, msg: "server error", err: pgerr}
}
}
@@ -204,7 +211,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
err := config.AfterConnect(ctx, pgConn)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "AfterConnect error", err: err}
return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err}
}
}
@@ -276,7 +283,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
netConn, err := config.DialFunc(ctx, network, address)
if err != nil {
return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
}
pgConn.conn = netConn
@@ -288,7 +295,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil {
netConn.Close()
return nil, &connectError{config: config, msg: "tls error", err: err}
return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
}
pgConn.conn = nbTLSConn
@@ -301,8 +308,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.parameterStatuses = make(map[string]string)
pgConn.status = connStatusConnecting
pgConn.bgReader = bgreader.New(pgConn.conn)
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start)
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64),
func() {
pgConn.bgReader.Start()
pgConn.bgReaderStarted <- struct{}{}
},
)
pgConn.slowWriteTimer.Stop()
pgConn.bgReaderStarted = make(chan struct{})
pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn)
startupMsg := pgproto3.StartupMessage{
@@ -323,7 +336,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.frontend.Send(&startupMsg)
if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
}
for {
@@ -333,7 +346,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if err, ok := err.(*PgError); ok {
return nil, err
}
return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
}
switch msg := msg.(type) {
@@ -346,26 +359,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
err = pgConn.txPasswordMessage(pgConn.config.Password)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationSASL:
err = pgConn.scramAuth(msg.AuthMechanisms)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed SASL auth", err: err}
return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err}
}
case *pgproto3.AuthenticationGSS:
err = pgConn.gssAuth()
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed GSS auth", err: err}
return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err}
}
case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle
@@ -383,7 +396,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return pgConn, nil
}
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err}
}
}
return pgConn, nil
@@ -394,7 +407,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return nil, ErrorResponseToPgError(msg)
default:
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "received unexpected message", err: err}
return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err}
}
}
}
@@ -540,11 +553,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
case *pgproto3.ParameterStatus:
pgConn.parameterStatuses[msg.Name] = msg.Value
case *pgproto3.ErrorResponse:
if msg.Severity == "FATAL" {
err := ErrorResponseToPgError(msg)
if pgConn.config.OnPgError != nil && !pgConn.config.OnPgError(pgConn, err) {
pgConn.status = connStatusClosed
pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return.
close(pgConn.cleanupDone)
return nil, ErrorResponseToPgError(msg)
return nil, err
}
case *pgproto3.NoticeResponse:
if pgConn.config.OnNotice != nil {
@@ -593,7 +607,7 @@ func (pgConn *PgConn) Frontend() *pgproto3.Frontend {
return pgConn.frontend
}
// Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by
// Close closes a connection. It is safe to call Close on an already closed connection. Close attempts a clean close by
// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The
// underlying net.Conn.Close() will always be called regardless of any other errors.
func (pgConn *PgConn) Close(ctx context.Context) error {
@@ -806,6 +820,9 @@ type StatementDescription struct {
// Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This
// allows Prepare to also to describe statements without creating a server-side prepared statement.
//
// Prepare does not send a PREPARE statement to the server. It uses the PostgreSQL Parse and Describe protocol messages
// directly.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) {
if err := pgConn.lock(); err != nil {
return nil, err
@@ -862,6 +879,52 @@ readloop:
return psd, nil
}
// Deallocate deallocates a prepared statement.
//
// Deallocate does not send a DEALLOCATE statement to the server. It uses the PostgreSQL Close protocol message
// directly. This has slightly different behavior than executing DEALLOCATE statement.
// - Deallocate can succeed in an aborted transaction.
// - Deallocating a non-existent prepared statement is not an error.
func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error {
if err := pgConn.lock(); err != nil {
return err
}
defer pgConn.unlock()
if ctx != context.Background() {
select {
case <-ctx.Done():
return newContextAlreadyDoneError(ctx)
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
pgConn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
pgConn.frontend.SendSync(&pgproto3.Sync{})
err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
return err
}
for {
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
return normalizeTimeoutError(ctx, err)
}
switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
return ErrorResponseToPgError(msg)
case *pgproto3.ReadyForQuery:
return nil
}
}
}
// ErrorResponseToPgError converts a wire protocol error message to a *PgError.
func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
return &PgError{
@@ -935,16 +998,21 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16)
binary.BigEndian.PutUint32(buf[4:8], 80877102)
binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid))
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
// Postgres will process the request and close the connection
// so when don't need to read the reply
// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.6.7.10
_, err = cancelConn.Write(buf)
return err
binary.BigEndian.PutUint32(buf[8:12], pgConn.pid)
binary.BigEndian.PutUint32(buf[12:16], pgConn.secretKey)
if _, err := cancelConn.Write(buf); err != nil {
return fmt.Errorf("write to connection for cancellation: %w", err)
}
// Wait for the cancel request to be acknowledged by the server.
// It copies the behavior of the libpq: https://github.com/postgres/postgres/blob/REL_16_0/src/interfaces/libpq/fe-connect.c#L4946-L4960
_, _ = cancelConn.Read(buf)
return nil
}
// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not
// WaitForNotification waits for a LISTEN/NOTIFY message to be received. It returns an error if a notification was not
// received.
func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
if err := pgConn.lock(); err != nil {
@@ -1606,25 +1674,55 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
type Batch struct {
buf []byte
err error
}
// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
}
// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
batch.buf = (&pgproto3.Execute{}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf)
if batch.err != nil {
return
}
}
// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing
// multiple queries in a single round trip than using pipeline mode.
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
if batch.err != nil {
return &MultiResultReader{
closed: true,
err: batch.err,
}
}
if err := pgConn.lock(); err != nil {
return &MultiResultReader{
closed: true,
@@ -1650,7 +1748,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
pgConn.contextWatcher.Watch(ctx)
}
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
if batch.err != nil {
multiResult.closed = true
multiResult.err = batch.err
pgConn.unlock()
return multiResult
}
pgConn.enterPotentialWriteReadDeadlock()
defer pgConn.exitPotentialWriteReadDeadlock()
@@ -1732,10 +1836,16 @@ func (pgConn *PgConn) enterPotentialWriteReadDeadlock() {
// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock.
func (pgConn *PgConn) exitPotentialWriteReadDeadlock() {
// The state of the timer is not relevant upon exiting the potential slow write. It may both
// fire (due to a slow write), or not fire (due to a fast write).
_ = pgConn.slowWriteTimer.Stop()
pgConn.bgReader.Stop()
if !pgConn.slowWriteTimer.Stop() {
// The timer starts its function in a separate goroutine. It is necessary to ensure the background reader has
// started before calling Stop. Otherwise, the background reader may not be stopped. That on its own is not a
// serious problem. But what is a serious problem is that the background reader may start at an inopportune time in
// a subsequent query. For example, if a subsequent query was canceled then a deadline may be set on the net.Conn to
// interrupt an in-progress read. After the read is interrupted, but before the deadline is cleared, the background
// reader could start and read a deadline error. Then the next query would receive the an unexpected deadline error.
<-pgConn.bgReaderStarted
pgConn.bgReader.Stop()
}
}
func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error {
@@ -1764,7 +1874,7 @@ func (pgConn *PgConn) SyncConn(ctx context.Context) error {
}
}
// This should never happen. Only way I can imagine this occuring is if the server is constantly sending data such as
// This should never happen. Only way I can imagine this occurring is if the server is constantly sending data such as
// LISTEN/NOTIFY or log notifications such that we never can get an empty buffer.
return errors.New("SyncConn: conn never synchronized")
}
@@ -1830,8 +1940,14 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
pgConn.contextWatcher = newContextWatcher(pgConn.conn)
pgConn.bgReader = bgreader.New(pgConn.conn)
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start)
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64),
func() {
pgConn.bgReader.Start()
pgConn.bgReaderStarted <- struct{}{}
},
)
pgConn.slowWriteTimer.Stop()
pgConn.bgReaderStarted = make(chan struct{})
pgConn.frontend = hc.Config.BuildFrontend(pgConn.bgReader, pgConn.conn)
return pgConn, nil
@@ -1973,6 +2089,13 @@ func (p *Pipeline) Flush() error {
// Sync establishes a synchronization point and flushes the queued requests.
func (p *Pipeline) Sync() error {
if p.closed {
if p.err != nil {
return p.err
}
return errors.New("pipeline closed")
}
p.conn.frontend.SendSync(&pgproto3.Sync{})
err := p.Flush()
if err != nil {
@@ -1989,14 +2112,28 @@ func (p *Pipeline) Sync() error {
// *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. If no
// results are available, results and err will both be nil.
func (p *Pipeline) GetResults() (results any, err error) {
if p.closed {
if p.err != nil {
return nil, p.err
}
return nil, errors.New("pipeline closed")
}
if p.expectedReadyForQueryCount == 0 {
return nil, nil
}
return p.getResults()
}
func (p *Pipeline) getResults() (results any, err error) {
for {
msg, err := p.conn.receiveMessage()
if err != nil {
return nil, err
p.closed = true
p.err = err
p.conn.asyncClose()
return nil, normalizeTimeoutError(p.ctx, err)
}
switch msg := msg.(type) {
@@ -2018,7 +2155,8 @@ func (p *Pipeline) GetResults() (results any, err error) {
case *pgproto3.ParseComplete:
peekedMsg, err := p.conn.peekMessage()
if err != nil {
return nil, err
p.conn.asyncClose()
return nil, normalizeTimeoutError(p.ctx, err)
}
if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok {
return p.getResultsPrepare()
@@ -2078,6 +2216,7 @@ func (p *Pipeline) Close() error {
if p.closed {
return p.err
}
p.closed = true
if p.pendingSync {
@@ -2090,7 +2229,7 @@ func (p *Pipeline) Close() error {
}
for p.expectedReadyForQueryCount > 0 {
_, err := p.GetResults()
_, err := p.getResults()
if err != nil {
p.err = err
var pgErr *PgError

View File

@@ -1,6 +1,6 @@
# pgproto3
Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3.
pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more.

View File

@@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
return nil
}
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 4)
func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeGSS)
return dst
return finishMessage(dst, sp)
}
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {

View File

@@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
return nil
}
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
dst = append(dst, a.Data...)
return dst
return finishMessage(dst, sp)
}
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {

View File

@@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 12)
func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
dst = append(dst, src.Salt[:]...)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationOk) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeOk)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASL)
for _, s := range src.AuthMechanisms {
@@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
}
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Unmarshaler.

View File

@@ -16,7 +16,8 @@ type Backend struct {
// before it is actually transmitted (i.e. before Flush).
tracer *tracer
wbuf []byte
wbuf []byte
encodeError error
// Frontend message flyweights
bind Bind
@@ -38,6 +39,7 @@ type Backend struct {
terminate Terminate
bodyLen int
maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error.
msgType byte
partialMsg bool
authType uint32
@@ -54,11 +56,21 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
return &Backend{cr: cr, w: w}
}
// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
// called.
// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error
// encountered will be returned from Flush.
func (b *Backend) Send(msg BackendMessage) {
if b.encodeError != nil {
return
}
prevLen := len(b.wbuf)
b.wbuf = msg.Encode(b.wbuf)
newBuf, err := msg.Encode(b.wbuf)
if err != nil {
b.encodeError = err
return
}
b.wbuf = newBuf
if b.tracer != nil {
b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
}
@@ -66,6 +78,12 @@ func (b *Backend) Send(msg BackendMessage) {
// Flush writes any pending messages to the frontend (i.e. the client).
func (b *Backend) Flush() error {
if err := b.encodeError; err != nil {
b.encodeError = nil
b.wbuf = b.wbuf[:0]
return &writeError{err: err, safeToRetry: true}
}
n, err := b.w.Write(b.wbuf)
const maxLen = 1024
@@ -158,6 +176,9 @@ func (b *Backend) Receive() (FrontendMessage, error) {
b.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen {
return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen}
}
b.partialMsg = true
}
@@ -260,3 +281,12 @@ func (b *Backend) SetAuthType(authType uint32) error {
return nil
}
// SetMaxBodyLen sets the maximum length of a message body in octets. If a message body exceeds this length, Receive will return
// an error. This is useful for protecting against malicious clients that send large messages with the intent of
// causing memory exhaustion.
// The default value is 0.
// If maxBodyLen is 0, then no maximum is enforced.
func (b *Backend) SetMaxBodyLen(maxBodyLen int) {
b.maxBodyLen = maxBodyLen
}

View File

@@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BackendKeyData) Encode(dst []byte) []byte {
dst = append(dst, 'K')
dst = pgio.AppendUint32(dst, 12)
func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'K')
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -5,7 +5,9 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -108,21 +110,25 @@ func (dst *Bind) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Bind) Encode(dst []byte) []byte {
dst = append(dst, 'B')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *Bind) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'B')
dst = append(dst, src.DestinationPortal...)
dst = append(dst, 0)
dst = append(dst, src.PreparedStatement...)
dst = append(dst, 0)
if len(src.ParameterFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many parameter format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
for _, fc := range src.ParameterFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
if len(src.Parameters) > math.MaxUint16 {
return nil, errors.New("too many parameters")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
for _, p := range src.Parameters {
if p == nil {
@@ -134,14 +140,15 @@ func (src *Bind) Encode(dst []byte) []byte {
dst = append(dst, p...)
}
if len(src.ResultFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many result format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
for _, fc := range src.ResultFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BindComplete) Encode(dst []byte) []byte {
return append(dst, '2', 0, 0, 0, 4)
func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '2', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *CancelRequest) Encode(dst []byte) []byte {
func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
dst = pgio.AppendInt32(dst, 16)
dst = pgio.AppendInt32(dst, cancelRequestCode)
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst
return dst, nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -4,8 +4,6 @@ import (
"bytes"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
type Close struct {
@@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Close) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *Close) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'C')
dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CloseComplete) Encode(dst []byte) []byte {
return append(dst, '3', 0, 0, 0, 4)
func (src *CloseComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '3', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -3,8 +3,6 @@ package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type CommandComplete struct {
@@ -31,17 +29,11 @@ func (dst *CommandComplete) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CommandComplete) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *CommandComplete) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'C')
dst = append(dst, src.CommandTag...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -44,19 +45,18 @@ func (dst *CopyBothResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyBothResponse) Encode(dst []byte) []byte {
dst = append(dst, 'W')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'W')
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -3,8 +3,6 @@ package pgproto3
import (
"encoding/hex"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type CopyData struct {
@@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyData) Encode(dst []byte) []byte {
dst = append(dst, 'd')
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
func (src *CopyData) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'd')
dst = append(dst, src.Data...)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyDone) Encode(dst []byte) []byte {
return append(dst, 'c', 0, 0, 0, 4)
func (src *CopyDone) Encode(dst []byte) ([]byte, error) {
return append(dst, 'c', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -3,8 +3,6 @@ package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type CopyFail struct {
@@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyFail) Encode(dst []byte) []byte {
dst = append(dst, 'f')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *CopyFail) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'f')
dst = append(dst, src.Message...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -44,20 +45,19 @@ func (dst *CopyInResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyInResponse) Encode(dst []byte) []byte {
dst = append(dst, 'G')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'G')
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -43,21 +44,20 @@ func (dst *CopyOutResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyOutResponse) Encode(dst []byte) []byte {
dst = append(dst, 'H')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'H')
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -4,6 +4,8 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -63,11 +65,12 @@ func (dst *DataRow) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *DataRow) Encode(dst []byte) []byte {
dst = append(dst, 'D')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *DataRow) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'D')
if len(src.Values) > math.MaxUint16 {
return nil, errors.New("too many values")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
for _, v := range src.Values {
if v == nil {
@@ -79,9 +82,7 @@ func (src *DataRow) Encode(dst []byte) []byte {
dst = append(dst, v...)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -4,8 +4,6 @@ import (
"bytes"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
type Describe struct {
@@ -37,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Describe) Encode(dst []byte) []byte {
dst = append(dst, 'D')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *Describe) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'D')
dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -1,7 +1,7 @@
// Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
// Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3.
//
// The primary interfaces are Frontend and Backend. They correspond to a client and server respectively. Messages are
// sent with Send (or a specialized Send variant). Messages are automatically bufferred to minimize small writes. Call
// sent with Send (or a specialized Send variant). Messages are automatically buffered to minimize small writes. Call
// Flush to ensure a message has actually been sent.
//
// The Trace method of Frontend and Backend can be used to examine the wire-level message traffic. It outputs in a

View File

@@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
return append(dst, 'I', 0, 0, 0, 4)
func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) {
return append(dst, 'I', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -2,7 +2,6 @@ package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"strconv"
)
@@ -111,119 +110,113 @@ func (dst *ErrorResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ErrorResponse) Encode(dst []byte) []byte {
return append(dst, src.marshalBinary('E')...)
func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'E')
dst = src.appendFields(dst)
return finishMessage(dst, sp)
}
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte(typeByte)
buf.Write(bigEndian.Uint32(0))
func (src *ErrorResponse) appendFields(dst []byte) []byte {
if src.Severity != "" {
buf.WriteByte('S')
buf.WriteString(src.Severity)
buf.WriteByte(0)
dst = append(dst, 'S')
dst = append(dst, src.Severity...)
dst = append(dst, 0)
}
if src.SeverityUnlocalized != "" {
buf.WriteByte('V')
buf.WriteString(src.SeverityUnlocalized)
buf.WriteByte(0)
dst = append(dst, 'V')
dst = append(dst, src.SeverityUnlocalized...)
dst = append(dst, 0)
}
if src.Code != "" {
buf.WriteByte('C')
buf.WriteString(src.Code)
buf.WriteByte(0)
dst = append(dst, 'C')
dst = append(dst, src.Code...)
dst = append(dst, 0)
}
if src.Message != "" {
buf.WriteByte('M')
buf.WriteString(src.Message)
buf.WriteByte(0)
dst = append(dst, 'M')
dst = append(dst, src.Message...)
dst = append(dst, 0)
}
if src.Detail != "" {
buf.WriteByte('D')
buf.WriteString(src.Detail)
buf.WriteByte(0)
dst = append(dst, 'D')
dst = append(dst, src.Detail...)
dst = append(dst, 0)
}
if src.Hint != "" {
buf.WriteByte('H')
buf.WriteString(src.Hint)
buf.WriteByte(0)
dst = append(dst, 'H')
dst = append(dst, src.Hint...)
dst = append(dst, 0)
}
if src.Position != 0 {
buf.WriteByte('P')
buf.WriteString(strconv.Itoa(int(src.Position)))
buf.WriteByte(0)
dst = append(dst, 'P')
dst = append(dst, strconv.Itoa(int(src.Position))...)
dst = append(dst, 0)
}
if src.InternalPosition != 0 {
buf.WriteByte('p')
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
buf.WriteByte(0)
dst = append(dst, 'p')
dst = append(dst, strconv.Itoa(int(src.InternalPosition))...)
dst = append(dst, 0)
}
if src.InternalQuery != "" {
buf.WriteByte('q')
buf.WriteString(src.InternalQuery)
buf.WriteByte(0)
dst = append(dst, 'q')
dst = append(dst, src.InternalQuery...)
dst = append(dst, 0)
}
if src.Where != "" {
buf.WriteByte('W')
buf.WriteString(src.Where)
buf.WriteByte(0)
dst = append(dst, 'W')
dst = append(dst, src.Where...)
dst = append(dst, 0)
}
if src.SchemaName != "" {
buf.WriteByte('s')
buf.WriteString(src.SchemaName)
buf.WriteByte(0)
dst = append(dst, 's')
dst = append(dst, src.SchemaName...)
dst = append(dst, 0)
}
if src.TableName != "" {
buf.WriteByte('t')
buf.WriteString(src.TableName)
buf.WriteByte(0)
dst = append(dst, 't')
dst = append(dst, src.TableName...)
dst = append(dst, 0)
}
if src.ColumnName != "" {
buf.WriteByte('c')
buf.WriteString(src.ColumnName)
buf.WriteByte(0)
dst = append(dst, 'c')
dst = append(dst, src.ColumnName...)
dst = append(dst, 0)
}
if src.DataTypeName != "" {
buf.WriteByte('d')
buf.WriteString(src.DataTypeName)
buf.WriteByte(0)
dst = append(dst, 'd')
dst = append(dst, src.DataTypeName...)
dst = append(dst, 0)
}
if src.ConstraintName != "" {
buf.WriteByte('n')
buf.WriteString(src.ConstraintName)
buf.WriteByte(0)
dst = append(dst, 'n')
dst = append(dst, src.ConstraintName...)
dst = append(dst, 0)
}
if src.File != "" {
buf.WriteByte('F')
buf.WriteString(src.File)
buf.WriteByte(0)
dst = append(dst, 'F')
dst = append(dst, src.File...)
dst = append(dst, 0)
}
if src.Line != 0 {
buf.WriteByte('L')
buf.WriteString(strconv.Itoa(int(src.Line)))
buf.WriteByte(0)
dst = append(dst, 'L')
dst = append(dst, strconv.Itoa(int(src.Line))...)
dst = append(dst, 0)
}
if src.Routine != "" {
buf.WriteByte('R')
buf.WriteString(src.Routine)
buf.WriteByte(0)
dst = append(dst, 'R')
dst = append(dst, src.Routine...)
dst = append(dst, 0)
}
for k, v := range src.UnknownFields {
buf.WriteByte(k)
buf.WriteString(v)
buf.WriteByte(0)
dst = append(dst, k)
dst = append(dst, v...)
dst = append(dst, 0)
}
buf.WriteByte(0)
dst = append(dst, 0)
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes()
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Execute) Encode(dst []byte) []byte {
dst = append(dst, 'E')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *Execute) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'E')
dst = append(dst, src.Portal...)
dst = append(dst, 0)
dst = pgio.AppendUint32(dst, src.MaxRows)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -20,8 +20,8 @@ func (dst *Flush) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Flush) Encode(dst []byte) []byte {
return append(dst, 'H', 0, 0, 0, 4)
func (src *Flush) Encode(dst []byte) ([]byte, error) {
return append(dst, 'H', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -18,7 +18,8 @@ type Frontend struct {
// idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq.
tracer *tracer
wbuf []byte
wbuf []byte
encodeError error
// Backend message flyweights
authenticationOk AuthenticationOk
@@ -64,16 +65,26 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend {
return &Frontend{cr: cr, w: w}
}
// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is
// called.
// Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error
// encountered will be returned from Flush.
//
// Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods
// such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an
// extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden
// behind an interface.
func (f *Frontend) Send(msg FrontendMessage) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg)
}
@@ -81,6 +92,12 @@ func (f *Frontend) Send(msg FrontendMessage) {
// Flush writes any pending messages to the backend (i.e. the server).
func (f *Frontend) Flush() error {
if err := f.encodeError; err != nil {
f.encodeError = nil
f.wbuf = f.wbuf[:0]
return &writeError{err: err, safeToRetry: true}
}
if len(f.wbuf) == 0 {
return nil
}
@@ -116,71 +133,141 @@ func (f *Frontend) Untrace() {
f.tracer = nil
}
// SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until
// Flush is called.
// SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendBind(msg *Bind) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until
// Flush is called.
// SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendParse(msg *Parse) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until
// Flush is called.
// SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendClose(msg *Close) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until
// Flush is called.
// SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is
// called. Any error encountered will be returned from Flush.
func (f *Frontend) SendDescribe(msg *Describe) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendExecute sends a Execute message to the backend (i.e. the server). The message is not guaranteed to be written until
// Flush is called.
// SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called.
// Any error encountered will be returned from Flush.
func (f *Frontend) SendExecute(msg *Execute) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until
// Flush is called.
// SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendSync(msg *Sync) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until
// Flush is called.
// SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendQuery(msg *Query) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg)
}

View File

@@ -2,6 +2,8 @@ package pgproto3
import (
"encoding/binary"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -71,15 +73,21 @@ func (dst *FunctionCall) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *FunctionCall) Encode(dst []byte) []byte {
dst = append(dst, 'F')
sp := len(dst)
dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end
func (src *FunctionCall) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'F')
dst = pgio.AppendUint32(dst, src.Function)
if len(src.ArgFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many arg format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
for _, argFormatCode := range src.ArgFormatCodes {
dst = pgio.AppendUint16(dst, argFormatCode)
}
if len(src.Arguments) > math.MaxUint16 {
return nil, errors.New("too many arguments")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
for _, argument := range src.Arguments {
if argument == nil {
@@ -90,6 +98,5 @@ func (src *FunctionCall) Encode(dst []byte) []byte {
}
}
dst = pgio.AppendUint16(dst, src.ResultFormatCode)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}

View File

@@ -39,10 +39,8 @@ func (dst *FunctionCallResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *FunctionCallResponse) Encode(dst []byte) []byte {
dst = append(dst, 'V')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'V')
if src.Result == nil {
dst = pgio.AppendInt32(dst, -1)
@@ -51,9 +49,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte {
dst = append(dst, src.Result...)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -31,10 +31,10 @@ func (dst *GSSEncRequest) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *GSSEncRequest) Encode(dst []byte) []byte {
func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) {
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendInt32(dst, gssEncReqNumber)
return dst
return dst, nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -2,8 +2,6 @@ package pgproto3
import (
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type GSSResponse struct {
@@ -18,11 +16,10 @@ func (g *GSSResponse) Decode(data []byte) error {
return nil
}
func (g *GSSResponse) Encode(dst []byte) []byte {
dst = append(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
func (g *GSSResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'p')
dst = append(dst, g.Data...)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -20,8 +20,8 @@ func (dst *NoData) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NoData) Encode(dst []byte) []byte {
return append(dst, 'n', 0, 0, 0, 4)
func (src *NoData) Encode(dst []byte) ([]byte, error) {
return append(dst, 'n', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -12,6 +12,8 @@ func (dst *NoticeResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NoticeResponse) Encode(dst []byte) []byte {
return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'N')
dst = (*ErrorResponse)(src).appendFields(dst)
return finishMessage(dst, sp)
}

View File

@@ -45,20 +45,14 @@ func (dst *NotificationResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NotificationResponse) Encode(dst []byte) []byte {
dst = append(dst, 'A')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'A')
dst = pgio.AppendUint32(dst, src.PID)
dst = append(dst, src.Channel...)
dst = append(dst, 0)
dst = append(dst, src.Payload...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -4,6 +4,8 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -39,19 +41,18 @@ func (dst *ParameterDescription) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParameterDescription) Encode(dst []byte) []byte {
dst = append(dst, 't')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 't')
if len(src.ParameterOIDs) > math.MaxUint16 {
return nil, errors.New("too many parameter oids")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -3,8 +3,6 @@ package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type ParameterStatus struct {
@@ -37,19 +35,13 @@ func (dst *ParameterStatus) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParameterStatus) Encode(dst []byte) []byte {
dst = append(dst, 'S')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'S')
dst = append(dst, src.Name...)
dst = append(dst, 0)
dst = append(dst, src.Value...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -4,6 +4,8 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -52,24 +54,23 @@ func (dst *Parse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Parse) Encode(dst []byte) []byte {
dst = append(dst, 'P')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *Parse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'P')
dst = append(dst, src.Name...)
dst = append(dst, 0)
dst = append(dst, src.Query...)
dst = append(dst, 0)
if len(src.ParameterOIDs) > math.MaxUint16 {
return nil, errors.New("too many parameter oids")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -20,8 +20,8 @@ func (dst *ParseComplete) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParseComplete) Encode(dst []byte) []byte {
return append(dst, '1', 0, 0, 0, 4)
func (src *ParseComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '1', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -3,8 +3,6 @@ package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type PasswordMessage struct {
@@ -32,14 +30,11 @@ func (dst *PasswordMessage) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *PasswordMessage) Encode(dst []byte) []byte {
dst = append(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
func (src *PasswordMessage) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'p')
dst = append(dst, src.Password...)
dst = append(dst, 0)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -4,8 +4,14 @@ import (
"encoding/hex"
"errors"
"fmt"
"github.com/jackc/pgx/v5/internal/pgio"
)
// maxMessageBodyLen is the maximum length of a message body in bytes. See PG_LARGE_MESSAGE_LIMIT in the PostgreSQL
// source. It is defined as (MaxAllocSize - 1). MaxAllocSize is defined as 0x3fffffff.
const maxMessageBodyLen = (0x3fffffff - 1)
// Message is the interface implemented by an object that can decode and encode
// a particular PostgreSQL message.
type Message interface {
@@ -14,7 +20,7 @@ type Message interface {
Decode(data []byte) error
// Encode appends itself to dst and returns the new buffer.
Encode(dst []byte) []byte
Encode(dst []byte) ([]byte, error)
}
// FrontendMessage is a message sent by the frontend (i.e. the client).
@@ -70,6 +76,15 @@ func (e *writeError) Unwrap() error {
return e.err
}
type ExceededMaxBodyLenErr struct {
MaxExpectedBodyLen int
ActualBodyLen int
}
func (e *ExceededMaxBodyLenErr) Error() string {
return fmt.Sprintf("invalid body length: expected at most %d, but got %d", e.MaxExpectedBodyLen, e.ActualBodyLen)
}
// getValueFromJSON gets the value from a protocol message representation in JSON.
func getValueFromJSON(v map[string]string) ([]byte, error) {
if v == nil {
@@ -83,3 +98,23 @@ func getValueFromJSON(v map[string]string) ([]byte, error) {
}
return nil, errors.New("unknown protocol representation")
}
// beginMessage begines a new message of type t. It appends the message type and a placeholder for the message length to
// dst. It returns the new buffer and the position of the message length placeholder.
func beginMessage(dst []byte, t byte) ([]byte, int) {
dst = append(dst, t)
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
return dst, sp
}
// finishMessage finishes a message that was started with beginMessage. It computes the message length and writes it to
// dst[sp]. If the message length is too large it returns an error. Otherwise it returns the final message buffer.
func finishMessage(dst []byte, sp int) ([]byte, error) {
messageBodyLen := len(dst[sp:])
if messageBodyLen > maxMessageBodyLen {
return nil, errors.New("message body too large")
}
pgio.SetInt32(dst[sp:], int32(messageBodyLen))
return dst, nil
}

View File

@@ -20,8 +20,8 @@ func (dst *PortalSuspended) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *PortalSuspended) Encode(dst []byte) []byte {
return append(dst, 's', 0, 0, 0, 4)
func (src *PortalSuspended) Encode(dst []byte) ([]byte, error) {
return append(dst, 's', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -3,8 +3,6 @@ package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type Query struct {
@@ -28,14 +26,11 @@ func (dst *Query) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Query) Encode(dst []byte) []byte {
dst = append(dst, 'Q')
dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))
func (src *Query) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'Q')
dst = append(dst, src.String...)
dst = append(dst, 0)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -25,8 +25,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ReadyForQuery) Encode(dst []byte) []byte {
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus)
func (src *ReadyForQuery) Encode(dst []byte) ([]byte, error) {
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -4,6 +4,8 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -99,11 +101,12 @@ func (dst *RowDescription) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *RowDescription) Encode(dst []byte) []byte {
dst = append(dst, 'T')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *RowDescription) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'T')
if len(src.Fields) > math.MaxUint16 {
return nil, errors.New("too many fields")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
for _, fd := range src.Fields {
dst = append(dst, fd.Name...)
@@ -117,9 +120,7 @@ func (src *RowDescription) Encode(dst []byte) []byte {
dst = pgio.AppendInt16(dst, fd.Format)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -39,10 +39,8 @@ func (dst *SASLInitialResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *SASLInitialResponse) Encode(dst []byte) []byte {
dst = append(dst, 'p')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *SASLInitialResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'p')
dst = append(dst, []byte(src.AuthMechanism)...)
dst = append(dst, 0)
@@ -50,9 +48,7 @@ func (src *SASLInitialResponse) Encode(dst []byte) []byte {
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -3,8 +3,6 @@ package pgproto3
import (
"encoding/hex"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type SASLResponse struct {
@@ -22,13 +20,10 @@ func (dst *SASLResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *SASLResponse) Encode(dst []byte) []byte {
dst = append(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
func (src *SASLResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'p')
dst = append(dst, src.Data...)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -31,10 +31,10 @@ func (dst *SSLRequest) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *SSLRequest) Encode(dst []byte) []byte {
func (src *SSLRequest) Encode(dst []byte) ([]byte, error) {
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendInt32(dst, sslRequestNumber)
return dst
return dst, nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -38,14 +38,14 @@ func (dst *StartupMessage) Decode(src []byte) error {
for {
idx := bytes.IndexByte(src[rp:], 0)
if idx < 0 {
return &invalidMessageFormatErr{messageType: "StartupMesage"}
return &invalidMessageFormatErr{messageType: "StartupMessage"}
}
key := string(src[rp : rp+idx])
rp += idx + 1
idx = bytes.IndexByte(src[rp:], 0)
if idx < 0 {
return &invalidMessageFormatErr{messageType: "StartupMesage"}
return &invalidMessageFormatErr{messageType: "StartupMessage"}
}
value := string(src[rp : rp+idx])
rp += idx + 1
@@ -64,7 +64,7 @@ func (dst *StartupMessage) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *StartupMessage) Encode(dst []byte) []byte {
func (src *StartupMessage) Encode(dst []byte) ([]byte, error) {
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
@@ -77,9 +77,7 @@ func (src *StartupMessage) Encode(dst []byte) []byte {
}
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -20,8 +20,8 @@ func (dst *Sync) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Sync) Encode(dst []byte) []byte {
return append(dst, 'S', 0, 0, 0, 4)
func (src *Sync) Encode(dst []byte) ([]byte, error) {
return append(dst, 'S', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -20,8 +20,8 @@ func (dst *Terminate) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Terminate) Encode(dst []byte) []byte {
return append(dst, 'X', 0, 0, 0, 4)
func (src *Terminate) Encode(dst []byte) ([]byte, error) {
return append(dst, 'X', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@@ -110,7 +110,7 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) {
r, _, err := buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid array: %v", err)
return nil, fmt.Errorf("invalid array: %w", err)
}
var explicitDimensions []ArrayDimension
@@ -122,7 +122,7 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) {
for {
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid array: %v", err)
return nil, fmt.Errorf("invalid array: %w", err)
}
if r == '=' {
@@ -133,12 +133,12 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) {
lower, err := arrayParseInteger(buf)
if err != nil {
return nil, fmt.Errorf("invalid array: %v", err)
return nil, fmt.Errorf("invalid array: %w", err)
}
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid array: %v", err)
return nil, fmt.Errorf("invalid array: %w", err)
}
if r != ':' {
@@ -147,12 +147,12 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) {
upper, err := arrayParseInteger(buf)
if err != nil {
return nil, fmt.Errorf("invalid array: %v", err)
return nil, fmt.Errorf("invalid array: %w", err)
}
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid array: %v", err)
return nil, fmt.Errorf("invalid array: %w", err)
}
if r != ']' {
@@ -164,12 +164,12 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) {
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid array: %v", err)
return nil, fmt.Errorf("invalid array: %w", err)
}
}
if r != '{' {
return nil, fmt.Errorf("invalid array, expected '{': %v", err)
return nil, fmt.Errorf("invalid array, expected '{' got %v", r)
}
implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}}
@@ -178,7 +178,7 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) {
for {
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid array: %v", err)
return nil, fmt.Errorf("invalid array: %w", err)
}
if r == '{' {
@@ -195,7 +195,7 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) {
for {
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid array: %v", err)
return nil, fmt.Errorf("invalid array: %w", err)
}
switch r {
@@ -214,7 +214,7 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) {
buf.UnreadRune()
value, quoted, err := arrayParseValue(buf)
if err != nil {
return nil, fmt.Errorf("invalid array value: %v", err)
return nil, fmt.Errorf("invalid array value: %w", err)
}
if currentDim == counterDim {
implicitDimensions[currentDim].Length++

View File

@@ -176,8 +176,10 @@ func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst any) error {
bitLen := int32(binary.BigEndian.Uint32(src))
rp := 4
buf := make([]byte, len(src[rp:]))
copy(buf, src[rp:])
return scanner.ScanBits(Bits{Bytes: src[rp:], Len: bitLen, Valid: true})
return scanner.ScanBits(Bits{Bytes: buf, Len: bitLen, Valid: true})
}
type scanPlanTextAnyToBitsScanner struct{}

View File

@@ -231,7 +231,7 @@ func (w *uint64Wrapper) ScanNumeric(v Numeric) error {
bi, err := v.toBigInt()
if err != nil {
return fmt.Errorf("cannot scan into *uint64: %v", err)
return fmt.Errorf("cannot scan into *uint64: %w", err)
}
if !bi.IsUint64() {
@@ -284,7 +284,7 @@ func (w *uintWrapper) ScanNumeric(v Numeric) error {
bi, err := v.toBigInt()
if err != nil {
return fmt.Errorf("cannot scan into *uint: %v", err)
return fmt.Errorf("cannot scan into *uint: %w", err)
}
if !bi.IsUint64() {

View File

@@ -282,17 +282,17 @@ func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst any) error {
if match != nil {
year, err := strconv.ParseInt(match[1], 10, 32)
if err != nil {
return fmt.Errorf("BUG: cannot parse date that regexp matched (year): %v", err)
return fmt.Errorf("BUG: cannot parse date that regexp matched (year): %w", err)
}
month, err := strconv.ParseInt(match[2], 10, 32)
if err != nil {
return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %v", err)
return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %w", err)
}
day, err := strconv.ParseInt(match[3], 10, 32)
if err != nil {
return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %v", err)
return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %w", err)
}
// BC matched

View File

@@ -67,7 +67,7 @@ See example_custom_type_test.go for an example of a custom type for the PostgreS
Sometimes pgx supports a PostgreSQL type such as numeric but the Go type is in an external package that does not have
pgx support such as github.com/shopspring/decimal. These types can be registered with pgtype with custom conversion
logic. See https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid for a example
logic. See https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid for example
integrations.
New PostgreSQL Type Support
@@ -149,7 +149,7 @@ Overview of Scanning Implementation
The first step is to use the OID to lookup the correct Codec. If the OID is unavailable, Map will try to find the OID
from previous calls of Map.RegisterDefaultPgType. The Map will call the Codec's PlanScan method to get a plan for
scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types are
interfaces rather than explicit types. For example, PointCodec can use any Go type that implments the PointScanner and
interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner and
PointValuer interfaces.
If a Go value is not supported directly by a Codec then Map will try wrapping it with additional logic and try again.

View File

@@ -3,6 +3,7 @@ package pgtype
import (
"database/sql/driver"
"encoding/binary"
"encoding/json"
"fmt"
"math"
"strconv"
@@ -65,6 +66,29 @@ func (f Float4) Value() (driver.Value, error) {
return float64(f.Float32), nil
}
func (f Float4) MarshalJSON() ([]byte, error) {
if !f.Valid {
return []byte("null"), nil
}
return json.Marshal(f.Float32)
}
func (f *Float4) UnmarshalJSON(b []byte) error {
var n *float32
err := json.Unmarshal(b, &n)
if err != nil {
return err
}
if n == nil {
*f = Float4{}
} else {
*f = Float4{Float32: *n, Valid: true}
}
return nil
}
type Float4Codec struct{}
func (Float4Codec) FormatSupported(format int16) bool {
@@ -273,12 +297,12 @@ func (c Float4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr
return nil, nil
}
var n float64
var n float32
err := codecScan(c, m, oid, format, src, &n)
if err != nil {
return nil, err
}
return n, nil
return float64(n), nil
}
func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {

View File

@@ -74,6 +74,29 @@ func (f Float8) Value() (driver.Value, error) {
return f.Float64, nil
}
func (f Float8) MarshalJSON() ([]byte, error) {
if !f.Valid {
return []byte("null"), nil
}
return json.Marshal(f.Float64)
}
func (f *Float8) UnmarshalJSON(b []byte) error {
var n *float64
err := json.Unmarshal(b, &n)
if err != nil {
return err
}
if n == nil {
*f = Float8{}
} else {
*f = Float8{Float64: *n, Valid: true}
}
return nil
}
type Float8Codec struct{}
func (Float8Codec) FormatSupported(format int16) bool {
@@ -109,13 +132,6 @@ func (Float8Codec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
return nil
}
func (f *Float8) MarshalJSON() ([]byte, error) {
if !f.Valid {
return []byte("null"), nil
}
return json.Marshal(f.Float64)
}
type encodePlanFloat8CodecBinaryFloat64 struct{}
func (encodePlanFloat8CodecBinaryFloat64) Encode(value any, buf []byte) (newBuf []byte, err error) {

View File

@@ -156,7 +156,7 @@ func (scanPlanBinaryInetToNetipPrefixScanner) Scan(src []byte, dst any) error {
}
if len(src) != 8 && len(src) != 20 {
return fmt.Errorf("Received an invalid size for a inet: %d", len(src))
return fmt.Errorf("Received an invalid size for an inet: %d", len(src))
}
// ignore family

View File

@@ -179,7 +179,7 @@ func (scanPlanBinaryIntervalToIntervalScanner) Scan(src []byte, dst any) error {
}
if len(src) != 16 {
return fmt.Errorf("Received an invalid size for a interval: %d", len(src))
return fmt.Errorf("Received an invalid size for an interval: %d", len(src))
}
microseconds := int64(binary.BigEndian.Uint64(src))
@@ -242,21 +242,21 @@ func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst any) error {
return fmt.Errorf("bad interval minute format: %s", timeParts[1])
}
secondParts := strings.SplitN(timeParts[2], ".", 2)
sec, secFrac, secFracFound := strings.Cut(timeParts[2], ".")
seconds, err := strconv.ParseInt(secondParts[0], 10, 64)
seconds, err := strconv.ParseInt(sec, 10, 64)
if err != nil {
return fmt.Errorf("bad interval second format: %s", secondParts[0])
return fmt.Errorf("bad interval second format: %s", sec)
}
var uSeconds int64
if len(secondParts) == 2 {
uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64)
if secFracFound {
uSeconds, err = strconv.ParseInt(secFrac, 10, 64)
if err != nil {
return fmt.Errorf("bad interval decimal format: %s", secondParts[1])
return fmt.Errorf("bad interval decimal format: %s", secFrac)
}
for i := 0; i < 6-len(secondParts[1]); i++ {
for i := 0; i < 6-len(secFrac); i++ {
uSeconds *= 10
}
}

View File

@@ -25,18 +25,26 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
case []byte:
return encodePlanJSONCodecEitherFormatByteSlice{}
// Handle json.RawMessage specifically because if it is run through json.Marshal it may be mutated.
// e.g. `{"foo": "bar"}` -> `{"foo":"bar"}`.
case json.RawMessage:
return encodePlanJSONCodecEitherFormatJSONRawMessage{}
// Cannot rely on driver.Valuer being handled later because anything can be marshalled.
//
// https://github.com/jackc/pgx/issues/1430
//
// Check for driver.Valuer must come before json.Marshaler so that it is guaranteed to beused
// when both are implemented https://github.com/jackc/pgx/issues/1805
case driver.Valuer:
return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format}
// Must come before trying wrap encode plans because a pointer to a struct may be unwrapped to a struct that can be
// marshalled.
//
// https://github.com/jackc/pgx/issues/1681
case json.Marshaler:
return encodePlanJSONCodecEitherFormatMarshal{}
// Cannot rely on driver.Valuer being handled later because anything can be marshalled.
//
// https://github.com/jackc/pgx/issues/1430
case driver.Valuer:
return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format}
}
// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the
@@ -76,6 +84,18 @@ func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value any, buf []byte) (n
return buf, nil
}
type encodePlanJSONCodecEitherFormatJSONRawMessage struct{}
func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byte) (newBuf []byte, err error) {
jsonBytes := value.(json.RawMessage)
if jsonBytes == nil {
return nil, nil
}
buf = append(buf, jsonBytes...)
return buf, nil
}
type encodePlanJSONCodecEitherFormatMarshal struct{}
func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {

View File

@@ -339,18 +339,18 @@ func parseUntypedTextMultirange(src []byte) ([]string, error) {
r, _, err := buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid array: %v", err)
return nil, fmt.Errorf("invalid array: %w", err)
}
if r != '{' {
return nil, fmt.Errorf("invalid multirange, expected '{': %v", err)
return nil, fmt.Errorf("invalid multirange, expected '{' got %v", r)
}
parseValueLoop:
for {
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid multirange: %v", err)
return nil, fmt.Errorf("invalid multirange: %w", err)
}
switch r {
@@ -361,7 +361,7 @@ parseValueLoop:
buf.UnreadRune()
value, err := parseRange(buf)
if err != nil {
return nil, fmt.Errorf("invalid multirange value: %v", err)
return nil, fmt.Errorf("invalid multirange value: %w", err)
}
elements = append(elements, value)
}

View File

@@ -119,6 +119,26 @@ func (n Numeric) Int64Value() (Int8, error) {
return Int8{Int64: bi.Int64(), Valid: true}, nil
}
func (n *Numeric) ScanScientific(src string) error {
if !strings.ContainsAny("eE", src) {
return scanPlanTextAnyToNumericScanner{}.Scan([]byte(src), n)
}
if bigF, ok := new(big.Float).SetString(string(src)); ok {
smallF, _ := bigF.Float64()
src = strconv.FormatFloat(smallF, 'f', -1, 64)
}
num, exp, err := parseNumericString(src)
if err != nil {
return err
}
*n = Numeric{Int: num, Exp: exp, Valid: true}
return nil
}
func (n *Numeric) toBigInt() (*big.Int, error) {
if n.Exp == 0 {
return n.Int, nil

View File

@@ -81,6 +81,8 @@ const (
IntervalOID = 1186
IntervalArrayOID = 1187
NumericArrayOID = 1231
TimetzOID = 1266
TimetzArrayOID = 1270
BitOID = 1560
BitArrayOID = 1561
VarbitOID = 1562
@@ -559,7 +561,7 @@ func TryFindUnderlyingTypeScanPlan(dst any) (plan WrappedScanPlanNextSetter, nex
}
}
if nextDstType != nil && dstValue.Type() != nextDstType {
if nextDstType != nil && dstValue.Type() != nextDstType && dstValue.CanConvert(nextDstType) {
return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true
}
@@ -1358,6 +1360,8 @@ var kindToTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{
reflect.Bool: reflect.TypeOf(false),
}
var byteSliceType = reflect.TypeOf([]byte{})
type underlyingTypeEncodePlan struct {
nextValueType reflect.Type
next EncodePlan
@@ -1372,6 +1376,10 @@ func (plan *underlyingTypeEncodePlan) Encode(value any, buf []byte) (newBuf []by
// TryWrapFindUnderlyingTypeEncodePlan tries to convert to a Go builtin type. e.g. If value was of type MyString and
// MyString was defined as a string then a wrapper plan would be returned that converts MyString to string.
func TryWrapFindUnderlyingTypeEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) {
if value == nil {
return nil, nil, false
}
if _, ok := value.(driver.Valuer); ok {
return nil, nil, false
}
@@ -1387,6 +1395,15 @@ func TryWrapFindUnderlyingTypeEncodePlan(value any) (plan WrappedEncodePlanNextS
return &underlyingTypeEncodePlan{nextValueType: nextValueType}, refValue.Convert(nextValueType).Interface(), true
}
// []byte is a special case. It is a slice but we treat it as a scalar type. In the case of a named type like
// json.RawMessage which is defined as []byte the underlying type should be considered as []byte. But any other slice
// does not have a special underlying type.
//
// https://github.com/jackc/pgx/issues/1763
if refValue.Type() != byteSliceType && refValue.Type().AssignableTo(byteSliceType) {
return &underlyingTypeEncodePlan{nextValueType: byteSliceType}, refValue.Convert(byteSliceType).Interface(), true
}
return nil, nil, false
}

View File

@@ -1,6 +1,7 @@
package pgtype
import (
"encoding/json"
"net"
"net/netip"
"reflect"
@@ -173,6 +174,7 @@ func initDefaultMap() {
registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz")
registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval")
registerDefaultPgTypeVariants[string](defaultMap, "text")
registerDefaultPgTypeVariants[json.RawMessage](defaultMap, "json")
registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea")
registerDefaultPgTypeVariants[net.IP](defaultMap, "inet")

View File

@@ -50,17 +50,17 @@ func parsePoint(src []byte) (*Point, error) {
if src[0] == '"' && src[len(src)-1] == '"' {
src = src[1 : len(src)-1]
}
parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2)
if len(parts) < 2 {
sx, sy, found := strings.Cut(string(src[1:len(src)-1]), ",")
if !found {
return nil, fmt.Errorf("invalid format for point")
}
x, err := strconv.ParseFloat(parts[0], 64)
x, err := strconv.ParseFloat(sx, 64)
if err != nil {
return nil, err
}
y, err := strconv.ParseFloat(parts[1], 64)
y, err := strconv.ParseFloat(sy, 64)
if err != nil {
return nil, err
}
@@ -247,17 +247,17 @@ func (scanPlanTextAnyToPointScanner) Scan(src []byte, dst any) error {
return fmt.Errorf("invalid length for point: %v", len(src))
}
parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2)
if len(parts) < 2 {
sx, sy, found := strings.Cut(string(src[1:len(src)-1]), ",")
if !found {
return fmt.Errorf("invalid format for point")
}
x, err := strconv.ParseFloat(parts[0], 64)
x, err := strconv.ParseFloat(sx, 64)
if err != nil {
return err
}
y, err := strconv.ParseFloat(parts[1], 64)
y, err := strconv.ParseFloat(sy, 64)
if err != nil {
return err
}

View File

@@ -40,7 +40,7 @@ func parseUntypedTextRange(src string) (*untypedTextRange, error) {
r, _, err := buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid lower bound: %v", err)
return nil, fmt.Errorf("invalid lower bound: %w", err)
}
switch r {
case '(':
@@ -53,7 +53,7 @@ func parseUntypedTextRange(src string) (*untypedTextRange, error) {
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid lower value: %v", err)
return nil, fmt.Errorf("invalid lower value: %w", err)
}
buf.UnreadRune()
@@ -62,13 +62,13 @@ func parseUntypedTextRange(src string) (*untypedTextRange, error) {
} else {
utr.Lower, err = rangeParseValue(buf)
if err != nil {
return nil, fmt.Errorf("invalid lower value: %v", err)
return nil, fmt.Errorf("invalid lower value: %w", err)
}
}
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("missing range separator: %v", err)
return nil, fmt.Errorf("missing range separator: %w", err)
}
if r != ',' {
return nil, fmt.Errorf("missing range separator: %v", r)
@@ -76,7 +76,7 @@ func parseUntypedTextRange(src string) (*untypedTextRange, error) {
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid upper value: %v", err)
return nil, fmt.Errorf("invalid upper value: %w", err)
}
if r == ')' || r == ']' {
@@ -85,12 +85,12 @@ func parseUntypedTextRange(src string) (*untypedTextRange, error) {
buf.UnreadRune()
utr.Upper, err = rangeParseValue(buf)
if err != nil {
return nil, fmt.Errorf("invalid upper value: %v", err)
return nil, fmt.Errorf("invalid upper value: %w", err)
}
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("missing upper bound: %v", err)
return nil, fmt.Errorf("missing upper bound: %w", err)
}
switch r {
case ')':

View File

@@ -120,7 +120,7 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt
buf, err = lowerPlan.Encode(lower, buf)
if err != nil {
return nil, fmt.Errorf("failed to encode %v as element of range: %v", lower, err)
return nil, fmt.Errorf("failed to encode %v as element of range: %w", lower, err)
}
if buf == nil {
return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
@@ -144,7 +144,7 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt
buf, err = upperPlan.Encode(upper, buf)
if err != nil {
return nil, fmt.Errorf("failed to encode %v as element of range: %v", upper, err)
return nil, fmt.Errorf("failed to encode %v as element of range: %w", upper, err)
}
if buf == nil {
return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
@@ -194,7 +194,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte)
buf, err = lowerPlan.Encode(lower, buf)
if err != nil {
return nil, fmt.Errorf("failed to encode %v as element of range: %v", lower, err)
return nil, fmt.Errorf("failed to encode %v as element of range: %w", lower, err)
}
if buf == nil {
return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
@@ -215,7 +215,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte)
buf, err = upperPlan.Encode(upper, buf)
if err != nil {
return nil, fmt.Errorf("failed to encode %v as element of range: %v", upper, err)
return nil, fmt.Errorf("failed to encode %v as element of range: %w", upper, err)
}
if buf == nil {
return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
@@ -282,7 +282,7 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) erro
err = lowerPlan.Scan(ubr.Lower, lowerTarget)
if err != nil {
return fmt.Errorf("cannot scan into %v from range element: %v", lowerTarget, err)
return fmt.Errorf("cannot scan into %v from range element: %w", lowerTarget, err)
}
}
@@ -294,7 +294,7 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) erro
err = upperPlan.Scan(ubr.Upper, upperTarget)
if err != nil {
return fmt.Errorf("cannot scan into %v from range element: %v", upperTarget, err)
return fmt.Errorf("cannot scan into %v from range element: %w", upperTarget, err)
}
}
@@ -332,7 +332,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error
err = lowerPlan.Scan([]byte(utr.Lower), lowerTarget)
if err != nil {
return fmt.Errorf("cannot scan into %v from range element: %v", lowerTarget, err)
return fmt.Errorf("cannot scan into %v from range element: %w", lowerTarget, err)
}
}
@@ -344,7 +344,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error
err = upperPlan.Scan([]byte(utr.Upper), upperTarget)
if err != nil {
return fmt.Errorf("cannot scan into %v from range element: %v", upperTarget, err)
return fmt.Errorf("cannot scan into %v from range element: %w", upperTarget, err)
}
}

View File

@@ -205,17 +205,17 @@ func (scanPlanTextAnyToTIDScanner) Scan(src []byte, dst any) error {
return fmt.Errorf("invalid length for tid: %v", len(src))
}
parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2)
if len(parts) < 2 {
block, offset, found := strings.Cut(string(src[1:len(src)-1]), ",")
if !found {
return fmt.Errorf("invalid format for tid")
}
blockNumber, err := strconv.ParseUint(parts[0], 10, 32)
blockNumber, err := strconv.ParseUint(block, 10, 32)
if err != nil {
return err
}
offsetNumber, err := strconv.ParseUint(parts[1], 10, 16)
offsetNumber, err := strconv.ParseUint(offset, 10, 16)
if err != nil {
return err
}

View File

@@ -52,7 +52,19 @@ func parseUUID(src string) (dst [16]byte, err error) {
// encodeUUID converts a uuid byte array to UUID standard string form.
func encodeUUID(src [16]byte) string {
return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16])
var buf [36]byte
hex.Encode(buf[0:8], src[:4])
buf[8] = '-'
hex.Encode(buf[9:13], src[4:6])
buf[13] = '-'
hex.Encode(buf[14:18], src[6:8])
buf[18] = '-'
hex.Encode(buf[19:23], src[8:10])
buf[23] = '-'
hex.Encode(buf[24:], src[10:])
return string(buf[:])
}
// Scan implements the database/sql Scanner interface.

View File

@@ -8,7 +8,6 @@ import (
"strings"
"time"
"github.com/jackc/pgx/v5/internal/stmtcache"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
)
@@ -17,7 +16,8 @@ import (
// the *Conn can be used again. Rows are closed by explicitly calling Close(),
// calling Next() until it returns false, or when a fatal error occurs.
//
// Once a Rows is closed the only methods that may be called are Close(), Err(), and CommandTag().
// Once a Rows is closed the only methods that may be called are Close(), Err(),
// and CommandTag().
//
// Rows is an interface instead of a struct to allow tests to mock Query. However,
// adding a method to an interface is technically a breaking change. Because of this
@@ -41,8 +41,15 @@ type Rows interface {
FieldDescriptions() []pgconn.FieldDescription
// Next prepares the next row for reading. It returns true if there is another
// row and false if no more rows are available. It automatically closes rows
// when all rows are read.
// row and false if no more rows are available or a fatal error has occurred.
// It automatically closes rows when all rows are read.
//
// Callers should check rows.Err() after rows.Next() returns false to detect
// whether result-set reading ended prematurely due to an error. See
// Conn.Query for details.
//
// For simpler error handling, consider using the higher-level pgx v5
// CollectRows() and ForEachRow() helpers instead.
Next() bool
// Scan reads the values from the current row into dest values positionally.
@@ -166,14 +173,12 @@ func (rows *baseRows) Close() {
}
if rows.err != nil && rows.conn != nil && rows.sql != "" {
if stmtcache.IsStatementInvalid(rows.err) {
if sc := rows.conn.statementCache; sc != nil {
sc.Invalidate(rows.sql)
}
if sc := rows.conn.statementCache; sc != nil {
sc.Invalidate(rows.sql)
}
if sc := rows.conn.descriptionCache; sc != nil {
sc.Invalidate(rows.sql)
}
if sc := rows.conn.descriptionCache; sc != nil {
sc.Invalidate(rows.sql)
}
}
@@ -412,12 +417,10 @@ type CollectableRow interface {
// RowToFunc is a function that scans or otherwise converts row to a T.
type RowToFunc[T any] func(row CollectableRow) (T, error)
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
// AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T.
func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
defer rows.Close()
slice := []T{}
for rows.Next() {
value, err := fn(rows)
if err != nil {
@@ -433,6 +436,11 @@ func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
return slice, nil
}
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
return AppendRows([]T{}, rows, fn)
}
// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.
// CollectOneRow is to CollectRows as QueryRow is to Query.
func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
@@ -457,6 +465,39 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
return value, rows.Err()
}
// CollectExactlyOneRow calls fn for the first row in rows and returns the result.
// - If no rows are found returns an error where errors.Is(ErrNoRows) is true.
// - If more than 1 row is found returns an error where errors.Is(ErrTooManyRows) is true.
func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
defer rows.Close()
var (
err error
value T
)
if !rows.Next() {
if err = rows.Err(); err != nil {
return value, err
}
return value, ErrNoRows
}
value, err = fn(rows)
if err != nil {
return value, err
}
if rows.Next() {
var zero T
return zero, ErrTooManyRows
}
return value, rows.Err()
}
// RowTo returns a T scanned from row.
func RowTo[T any](row CollectableRow) (T, error) {
var value T
@@ -496,7 +537,7 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error {
}
// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row
// has fields. The row and T fields will by matched by position. If the "db" struct tag is "-" then the field will be
// has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then the field will be
// ignored.
func RowToStructByPos[T any](row CollectableRow) (T, error) {
var value T
@@ -505,7 +546,7 @@ func RowToStructByPos[T any](row CollectableRow) (T, error) {
}
// RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a
// public fields as row has fields. The row and T fields will by matched by position. If the "db" struct tag is "-" then
// public fields as row has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then
// the field will be ignored.
func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
var value T
@@ -560,7 +601,7 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val
}
// RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public
// fields as row has fields. The row and T fields will by matched by name. The match is case-insensitive. The database
// fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
func RowToStructByName[T any](row CollectableRow) (T, error) {
var value T
@@ -569,7 +610,7 @@ func RowToStructByName[T any](row CollectableRow) (T, error) {
}
// RowToAddrOfStructByName returns the address of a T scanned from row. T must be a struct. T must have the same number
// of named public fields as row has fields. The row and T fields will by matched by name. The match is
// of named public fields as row has fields. The row and T fields will be matched by name. The match is
// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-"
// then the field will be ignored.
func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
@@ -579,7 +620,7 @@ func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
}
// RowToStructByNameLax returns a T scanned from row. T must be a struct. T must have greater than or equal number of named public
// fields as row has fields. The row and T fields will by matched by name. The match is case-insensitive. The database
// fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
var value T
@@ -588,7 +629,7 @@ func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
}
// RowToAddrOfStructByNameLax returns the address of a T scanned from row. T must be a struct. T must have greater than or
// equal number of named public fields as row has fields. The row and T fields will by matched by name. The match is
// equal number of named public fields as row has fields. The row and T fields will be matched by name. The match is
// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-"
// then the field will be ignored.
func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) {
@@ -629,7 +670,12 @@ const structTagKey = "db"
func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
i = -1
for i, desc := range fldDescs {
if strings.EqualFold(desc.Name, field) {
// Snake case support.
field = strings.ReplaceAll(field, "_", "")
descName := strings.ReplaceAll(desc.Name, "_", "")
if strings.EqualFold(descName, field) {
return i
}
}
@@ -650,7 +696,7 @@ func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, s
// Field is unexported, skip it.
continue
}
// Handle anoymous struct embedding, but do not try to handle embedded pointers.
// Handle anonymous struct embedding, but do not try to handle embedded pointers.
if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs)
if err != nil {
@@ -659,7 +705,7 @@ func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, s
} else {
dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
if dbTagPresent {
dbTag = strings.Split(dbTag, ",")[0]
dbTag, _, _ = strings.Cut(dbTag, ",")
}
if dbTag == "-" {
// Field is ignored, skip it.

View File

@@ -14,12 +14,21 @@
// return err
// }
//
// Or from a *pgxpool.Pool.
//
// pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL"))
// if err != nil {
// return err
// }
//
// db := stdlib.OpenDBFromPool(pool)
//
// Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the
// pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used
// with sql.Open.
//
// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
// connConfig.Logger = myLogger
// connConfig.Tracer = &tracelog.TraceLog{Logger: myLogger, LogLevel: tracelog.LogLevelInfo}
// connStr := stdlib.RegisterConnConfig(connConfig)
// db, _ := sql.Open("pgx", connStr)
//
@@ -74,6 +83,7 @@ import (
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
)
// Only intrinsic types should be binary format with database/sql.
@@ -125,14 +135,14 @@ func contains(list []string, y string) bool {
type OptionOpenDB func(*connector)
// OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will
// be used to connect, so only its immediate members should be modified.
// be used to connect, so only its immediate members should be modified. Used only if db is opened with *pgx.ConnConfig.
func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB {
return func(dc *connector) {
dc.BeforeConnect = bc
}
}
// OptionAfterConnect provides a callback for after connect.
// OptionAfterConnect provides a callback for after connect. Used only if db is opened with *pgx.ConnConfig.
func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB {
return func(dc *connector) {
dc.AfterConnect = ac
@@ -191,13 +201,42 @@ func GetConnector(config pgx.ConnConfig, opts ...OptionOpenDB) driver.Connector
return c
}
// GetPoolConnector creates a new driver.Connector from the given *pgxpool.Pool. By using this be sure to set the
// maximum idle connections of the *sql.DB created with this connector to zero since they must be managed from the
// *pgxpool.Pool. This is required to avoid acquiring all the connections from the pgxpool and starving any direct
// users of the pgxpool.
func GetPoolConnector(pool *pgxpool.Pool, opts ...OptionOpenDB) driver.Connector {
c := connector{
pool: pool,
ResetSession: func(context.Context, *pgx.Conn) error { return nil }, // noop reset session by default
driver: pgxDriver,
}
for _, opt := range opts {
opt(&c)
}
return c
}
func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
c := GetConnector(config, opts...)
return sql.OpenDB(c)
}
// OpenDBFromPool creates a new *sql.DB from the given *pgxpool.Pool. Note that this method automatically sets the
// maximum number of idle connections in *sql.DB to zero, since they must be managed from the *pgxpool.Pool. This is
// required to avoid acquiring all the connections from the pgxpool and starving any direct users of the pgxpool.
func OpenDBFromPool(pool *pgxpool.Pool, opts ...OptionOpenDB) *sql.DB {
c := GetPoolConnector(pool, opts...)
db := sql.OpenDB(c)
db.SetMaxIdleConns(0)
return db
}
type connector struct {
pgx.ConnConfig
pool *pgxpool.Pool
BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection
AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection
ResetSession func(context.Context, *pgx.Conn) error // function is called before a connection is reused
@@ -207,25 +246,53 @@ type connector struct {
// Connect implement driver.Connector interface
func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
var (
err error
conn *pgx.Conn
connConfig pgx.ConnConfig
conn *pgx.Conn
close func(context.Context) error
err error
)
// Create a shallow copy of the config, so that BeforeConnect can safely modify it
connConfig := c.ConnConfig
if err = c.BeforeConnect(ctx, &connConfig); err != nil {
return nil, err
if c.pool == nil {
// Create a shallow copy of the config, so that BeforeConnect can safely modify it
connConfig = c.ConnConfig
if err = c.BeforeConnect(ctx, &connConfig); err != nil {
return nil, err
}
if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil {
return nil, err
}
if err = c.AfterConnect(ctx, conn); err != nil {
return nil, err
}
close = conn.Close
} else {
var pconn *pgxpool.Conn
pconn, err = c.pool.Acquire(ctx)
if err != nil {
return nil, err
}
conn = pconn.Conn()
close = func(_ context.Context) error {
pconn.Release()
return nil
}
}
if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil {
return nil, err
}
if err = c.AfterConnect(ctx, conn); err != nil {
return nil, err
}
return &Conn{conn: conn, driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession}, nil
return &Conn{
conn: conn,
close: close,
driver: c.driver,
connConfig: connConfig,
resetSessionFunc: c.ResetSession,
psRefCounts: make(map[*pgconn.StatementDescription]int),
}, nil
}
// Driver implement driver.Connector interface
@@ -302,9 +369,11 @@ func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
c := &Conn{
conn: conn,
close: conn.Close,
driver: dc.driver,
connConfig: *connConfig,
resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil },
psRefCounts: make(map[*pgconn.StatementDescription]int),
}
return c, nil
@@ -326,11 +395,19 @@ func UnregisterConnConfig(connStr string) {
type Conn struct {
conn *pgx.Conn
psCount int64 // Counter used for creating unique prepared statement names
close func(context.Context) error
driver *Driver
connConfig pgx.ConnConfig
resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused
lastResetSessionTime time.Time
// psRefCounts contains reference counts for prepared statements. Prepare uses the underlying pgx logic to generate
// deterministic statement names from the statement text. If this query has already been prepared then the existing
// *pgconn.StatementDescription will be returned. However, this means that if Close is called on the returned Stmt
// then the underlying prepared statement will be closed even when the underlying prepared statement is still in use
// by another database/sql Stmt. To prevent this psRefCounts keeps track of how many database/sql statements are using
// the same underlying statement and only closes the underlying statement when the reference count reaches 0.
psRefCounts map[*pgconn.StatementDescription]int
}
// Conn returns the underlying *pgx.Conn
@@ -347,13 +424,11 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
return nil, driver.ErrBadConn
}
name := fmt.Sprintf("pgx_%d", c.psCount)
c.psCount++
sd, err := c.conn.Prepare(ctx, name, query)
sd, err := c.conn.Prepare(ctx, query, query)
if err != nil {
return nil, err
}
c.psRefCounts[sd]++
return &Stmt{sd: sd, conn: c}, nil
}
@@ -361,7 +436,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
func (c *Conn) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
return c.conn.Close(ctx)
return c.close(ctx)
}
func (c *Conn) Begin() (driver.Tx, error) {
@@ -470,7 +545,7 @@ func (c *Conn) ResetSession(ctx context.Context) error {
now := time.Now()
if now.Sub(c.lastResetSessionTime) > time.Second {
if err := c.conn.PgConn().CheckConn(); err != nil {
if err := c.conn.PgConn().Ping(ctx); err != nil {
return driver.ErrBadConn
}
}
@@ -487,7 +562,16 @@ type Stmt struct {
func (s *Stmt) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
return s.conn.conn.Deallocate(ctx, s.sd.Name)
refCount := s.conn.psRefCounts[s.sd]
if refCount == 1 {
delete(s.conn.psRefCounts, s.sd)
} else {
s.conn.psRefCounts[s.sd]--
return nil
}
return s.conn.conn.Deallocate(ctx, s.sd.SQL)
}
func (s *Stmt) NumInput() int {
@@ -499,7 +583,7 @@ func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
}
func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
return s.conn.ExecContext(ctx, s.sd.Name, argsV)
return s.conn.ExecContext(ctx, s.sd.SQL, argsV)
}
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
@@ -507,7 +591,7 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
}
func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
return s.conn.QueryContext(ctx, s.sd.Name, argsV)
return s.conn.QueryContext(ctx, s.sd.SQL, argsV)
}
type rowValueFunc func(src []byte) (driver.Value, error)
@@ -753,7 +837,7 @@ func (r *Rows) Next(dest []driver.Value) error {
var err error
dest[i], err = r.valueFuncs[i](rv)
if err != nil {
return fmt.Errorf("convert field %d failed: %v", i, err)
return fmt.Errorf("convert field %d failed: %w", i, err)
}
} else {
dest[i] = nil

View File

@@ -55,7 +55,11 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, er
func tryScanStringCopyValueThenEncode(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) {
s, ok := arg.(string)
if !ok {
return nil, errors.New("not a string")
textBuf, err := m.Encode(oid, TextFormatCode, arg, nil)
if err != nil {
return nil, errors.New("not a string and cannot be encoded as text")
}
s = string(textBuf)
}
var v any