From 11a0d2ab81feb3d6a678cc066f3d2754ddf35f0c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Nov 2025 18:39:44 -0500 Subject: [PATCH 01/15] POC for skipping Describe Portal when executing prepared statements --- bench_test.go | 80 +++++++++++++++++++++++++ pgconn/pgconn.go | 68 +++++++++++++++++++-- pgconn/pgconn_test.go | 136 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 278 insertions(+), 6 deletions(-) diff --git a/bench_test.go b/bench_test.go index f26251182..2af3424d2 100644 --- a/bench_test.go +++ b/bench_test.go @@ -158,6 +158,41 @@ func BenchmarkMinimalPgConnPreparedSelect(b *testing.B) { } } +func BenchmarkMinimalPgConnPreparedStatementDescriptionSelect(b *testing.B) { + conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) + defer closeConn(b, conn) + + pgConn := conn.PgConn() + + psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::int8", nil) + if err != nil { + b.Fatal(err) + } + + encodedBytes := make([]byte, 8) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + + rr := pgConn.ExecPreparedStatementDescription(context.Background(), psd, [][]byte{encodedBytes}, []int16{1}, []int16{1}) + if err != nil { + b.Fatal(err) + } + + for rr.NextRow() { + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[0], encodedBytes) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], encodedBytes) + } + } + } + _, err = rr.Close() + if err != nil { + b.Fatal(err) + } + } +} + func BenchmarkPointerPointerWithNullValues(b *testing.B) { conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) @@ -1263,6 +1298,51 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) { } } +func BenchmarkSelectRowsPgConnExecPreparedStatementDescription(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + psd, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil) + if err != nil { + b.Fatal(err) + } + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + formats := []struct { + name string + code int16 + }{ + {"text", pgx.TextFormatCode}, + {"binary - mostly", pgx.BinaryFormatCode}, + } + for _, format := range formats { + b.Run(format.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + rr := conn.PgConn().ExecPreparedStatementDescription( + context.Background(), + psd, + [][]byte{[]byte(strconv.FormatInt(rowCount, 10))}, + nil, + []int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code}, + ) + for rr.NextRow() { + rr.Values() + } + + _, err := rr.Close() + if err != nil { + b.Fatal(err) + } + } + }) + } + }) + } +} + type queryRecorder struct { conn net.Conn writeBuf []byte diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index a18f18741..123b94ce1 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -23,6 +23,7 @@ import ( "github.com/jackc/pgx/v5/pgconn/ctxwatch" "github.com/jackc/pgx/v5/pgconn/internal/bgreader" "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" ) const ( @@ -1159,7 +1160,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] pgConn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) pgConn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) - pgConn.execExtendedSuffix(result) + pgConn.execExtendedSuffix(result, nil, nil) return result } @@ -1184,7 +1185,37 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa pgConn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) - pgConn.execExtendedSuffix(result) + pgConn.execExtendedSuffix(result, nil, nil) + + return result +} + +// ExecPreparedStatementDescription enqueues the execution of a prepared statement via the PostgreSQL extended query +// protocol. +// +// This differs from ExecPrepared in that it takes a *StatementDescription instead of just the prepared statement name. +// Because it has the *StatementDescription it can avoid the Describe Portal message that ExecPrepared must send to get +// the result column descriptions. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all params are text format. ExecPrepared will panic if len(paramFormats) is not +// 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or binary +// format. If resultFormats is nil all results will be in text format. +// +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecPreparedStatementDescription(ctx context.Context, statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + pgConn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: statementDescription.Name, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + + pgConn.execExtendedSuffix(result, statementDescription, resultFormats) return result } @@ -1224,8 +1255,10 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } -func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { - pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) +func (pgConn *PgConn) execExtendedSuffix(result *ResultReader, statementDescription *StatementDescription, resultFormats []int16) { + if statementDescription == nil { + pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + } pgConn.frontend.SendExecute(&pgproto3.Execute{}) pgConn.frontend.SendSync(&pgproto3.Sync{}) @@ -1239,7 +1272,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { return } - result.readUntilRowDescription() + result.readUntilRowDescription(statementDescription, resultFormats) } // CopyTo executes the copy command sql and copies the results to w. @@ -1656,13 +1689,36 @@ func (rr *ResultReader) Close() (CommandTag, error) { // readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any // error will be stored in the ResultReader. -func (rr *ResultReader) readUntilRowDescription() { +func (rr *ResultReader) readUntilRowDescription(statementDescription *StatementDescription, resultFormats []int16) { for !rr.commandConcluded { // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. // This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are // manually used to construct a query that does not issue a describe statement. msg, _ := rr.pgConn.peekMessage() if _, ok := msg.(*pgproto3.DataRow); ok { + if statementDescription != nil { + rr.fieldDescriptions = statementDescription.Fields + // Adjust field descriptions for resultFormats + if len(resultFormats) == 0 { + // No format codes provided, default to text format + for i := range rr.fieldDescriptions { + rr.fieldDescriptions[i].Format = pgtype.TextFormatCode + } + } else if len(resultFormats) == 1 { + // Single format code applies to all columns + for i := range rr.fieldDescriptions { + rr.fieldDescriptions[i].Format = resultFormats[0] + } + } else if len(resultFormats) == len(rr.fieldDescriptions) { + // One format code per column + for i := range rr.fieldDescriptions { + rr.fieldDescriptions[i].Format = resultFormats[i] + } + } else { + // This should be impossible to reach as the mismatch would have been caught earlier. + rr.concludeCommand(CommandTag{}, fmt.Errorf("mismatched result format codes length")) + } + } return } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 001b6345e..3463b659d 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1435,6 +1435,142 @@ func TestConnExecPreparedEmptySQL(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecPreparedStatementDescription(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text as msg", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPreparedStatementDescription(ctx, psd, [][]byte{[]byte("Hello, world")}, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, "msg", result.FieldDescriptions()[0].Name) + + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", commandTag.String()) + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +type byteCounterConn struct { + conn net.Conn + bytesRead int + bytesWritten int +} + +func (cbn *byteCounterConn) Read(b []byte) (n int, err error) { + n, err = cbn.conn.Read(b) + cbn.bytesRead += n + return n, err +} + +func (cbn *byteCounterConn) Write(b []byte) (n int, err error) { + n, err = cbn.conn.Write(b) + cbn.bytesWritten += n + return n, err +} + +func (cbn *byteCounterConn) Close() error { + return cbn.conn.Close() +} + +func (cbn *byteCounterConn) LocalAddr() net.Addr { + return cbn.conn.LocalAddr() +} + +func (cbn *byteCounterConn) RemoteAddr() net.Addr { + return cbn.conn.RemoteAddr() +} + +func (cbn *byteCounterConn) SetDeadline(t time.Time) error { + return cbn.conn.SetDeadline(t) +} + +func (cbn *byteCounterConn) SetReadDeadline(t time.Time) error { + return cbn.conn.SetReadDeadline(t) +} + +func (cbn *byteCounterConn) SetWriteDeadline(t time.Time) error { + return cbn.conn.SetWriteDeadline(t) +} + +func TestConnExecPreparedStatementDescriptionNetworkUsage(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + var counterConn *byteCounterConn + config.AfterNetConnect = func(ctx context.Context, config *pgconn.Config, conn net.Conn) (net.Conn, error) { + counterConn = &byteCounterConn{conn: conn} + return counterConn, nil + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + require.NotNil(t, counterConn) + + psd, err := pgConn.Prepare(ctx, "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 9) + + counterConn.bytesWritten = 0 + counterConn.bytesRead = 0 + + result := pgConn.ExecPrepared(ctx, + psd.Name, + [][]byte{[]byte("1")}, + nil, + []int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode}, + ).Read() + require.NoError(t, result.Err) + withDescribeBytesWritten := counterConn.bytesWritten + withDescribeBytesRead := counterConn.bytesRead + + counterConn.bytesWritten = 0 + counterConn.bytesRead = 0 + + result = pgConn.ExecPreparedStatementDescription( + ctx, + psd, + [][]byte{[]byte("1")}, + nil, + []int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode}, + ).Read() + require.NoError(t, result.Err) + noDescribeBytesWritten := counterConn.bytesWritten + noDescribeBytesRead := counterConn.bytesRead + + assert.Equal(t, 61, withDescribeBytesWritten) + assert.Equal(t, 54, noDescribeBytesWritten) + assert.Equal(t, 391, withDescribeBytesRead) + assert.Equal(t, 153, noDescribeBytesRead) + + ensureConnValid(t, pgConn) +} + func TestConnExecBatch(t *testing.T) { t.Parallel() From d68eeb8a061519fc724cb24c9027065dc788ef9d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Nov 2025 07:14:43 -0600 Subject: [PATCH 02/15] Skip test on CockroachDB --- pgconn/pgconn_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 3463b659d..ffd3824e5 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1530,6 +1530,10 @@ func TestConnExecPreparedStatementDescriptionNetworkUsage(t *testing.T) { defer closeConn(t, pgConn) require.NotNil(t, counterConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server uses different number of bytes for same operations") + } + psd, err := pgConn.Prepare(ctx, "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil) require.NoError(t, err) require.NotNil(t, psd) From 308961aaba214333b516a58374b5b68c2c4789b4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 26 Dec 2025 15:07:21 -0600 Subject: [PATCH 03/15] Add ExecPreparedStatementDescription to pgconn.Batch Refactor row / field description handling. --- pgconn/pgconn.go | 161 +++++++++++++++++++++++++++++++++--------- pgconn/pgconn_test.go | 25 ++++++- 2 files changed, 152 insertions(+), 34 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 123b94ce1..f01d08d6d 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -829,13 +829,15 @@ type FieldDescription struct { Format int16 } -func (pgConn *PgConn) convertRowDescription(dst []FieldDescription, rd *pgproto3.RowDescription) []FieldDescription { - if cap(dst) >= len(rd.Fields) { - dst = dst[:len(rd.Fields):len(rd.Fields)] +func (pgConn *PgConn) getFieldDescriptionSlice(n int) []FieldDescription { + if cap(pgConn.fieldDescriptions) >= n { + return pgConn.fieldDescriptions[:n:n] } else { - dst = make([]FieldDescription, len(rd.Fields)) + return make([]FieldDescription, n) } +} +func convertRowDescription(dst []FieldDescription, rd *pgproto3.RowDescription) { for i := range rd.Fields { dst[i].Name = string(rd.Fields[i].Name) dst[i].TableOID = rd.Fields[i].TableOID @@ -845,8 +847,6 @@ func (pgConn *PgConn) convertRowDescription(dst []FieldDescription, rd *pgproto3 dst[i].TypeModifier = rd.Fields[i].TypeModifier dst[i].Format = rd.Fields[i].Format } - - return dst } type StatementDescription struct { @@ -910,7 +910,8 @@ readloop: psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: - psd.Fields = pgConn.convertRowDescription(nil, msg) + psd.Fields = make([]FieldDescription, len(msg.Fields)) + convertRowDescription(psd.Fields, msg) case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: @@ -1475,6 +1476,10 @@ type MultiResultReader struct { rr *ResultReader + // Data from when the batch was queued. + statementDescriptions []*StatementDescription + resultFormats [][]int16 + closed bool err error } @@ -1516,6 +1521,59 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) // NextResult returns advances the MultiResultReader to the next result and returns true if a result is available. func (mrr *MultiResultReader) NextResult() bool { for !mrr.closed && mrr.err == nil { + msg, _ := mrr.pgConn.peekMessage() + if _, ok := msg.(*pgproto3.DataRow); ok { + if len(mrr.statementDescriptions) > 0 { + rr := ResultReader{ + pgConn: mrr.pgConn, + multiResultReader: mrr, + ctx: mrr.ctx, + } + + // This result corresponds to a prepared statement description that was provided when queuing the batch. + sd := mrr.statementDescriptions[0] + mrr.statementDescriptions = mrr.statementDescriptions[1:] + + resultFormats := mrr.resultFormats[0] + mrr.resultFormats = mrr.resultFormats[1:] + + sdFields := sd.Fields + rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(sdFields)) + + switch { + case len(resultFormats) == 0: + // No format codes provided means text format for all columns. + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] + rr.fieldDescriptions[i].Format = pgtype.TextFormatCode + } + case len(resultFormats) == 1: + // Single format code applies to all columns. + format := resultFormats[0] + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] + rr.fieldDescriptions[i].Format = format + } + case len(resultFormats) == len(sdFields): + // One format code per column. + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] + rr.fieldDescriptions[i].Format = resultFormats[i] + } + default: + // This should not occur if Bind validation is correct, but handle gracefully + rr.concludeCommand(CommandTag{}, fmt.Errorf("result format codes length %d does not match field count %d", len(resultFormats), len(sdFields))) + } + + mrr.pgConn.resultReader = rr + mrr.rr = &mrr.pgConn.resultReader + return true + } + + mrr.err = fmt.Errorf("unexpected DataRow message without preceding RowDescription") + return false + } + msg, err := mrr.receiveMessage() if err != nil { return false @@ -1527,8 +1585,9 @@ func (mrr *MultiResultReader) NextResult() bool { pgConn: mrr.pgConn, multiResultReader: mrr, ctx: mrr.ctx, - fieldDescriptions: mrr.pgConn.convertRowDescription(mrr.pgConn.fieldDescriptions[:], msg), + fieldDescriptions: mrr.pgConn.getFieldDescriptionSlice(len(msg.Fields)), } + convertRowDescription(mrr.pgConn.resultReader.fieldDescriptions, msg) mrr.rr = &mrr.pgConn.resultReader return true @@ -1691,32 +1750,38 @@ func (rr *ResultReader) Close() (CommandTag, error) { // error will be stored in the ResultReader. func (rr *ResultReader) readUntilRowDescription(statementDescription *StatementDescription, resultFormats []int16) { for !rr.commandConcluded { - // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. - // This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are + // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. This + // is expected if statementDescription is not nil, but it is also possible if SendBytes and ReceiveResults are // manually used to construct a query that does not issue a describe statement. msg, _ := rr.pgConn.peekMessage() if _, ok := msg.(*pgproto3.DataRow); ok { if statementDescription != nil { - rr.fieldDescriptions = statementDescription.Fields - // Adjust field descriptions for resultFormats - if len(resultFormats) == 0 { - // No format codes provided, default to text format - for i := range rr.fieldDescriptions { + sdFields := statementDescription.Fields + rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(sdFields)) + + switch { + case len(resultFormats) == 0: + // No format codes provided means text format for all columns. + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] rr.fieldDescriptions[i].Format = pgtype.TextFormatCode } - } else if len(resultFormats) == 1 { - // Single format code applies to all columns - for i := range rr.fieldDescriptions { - rr.fieldDescriptions[i].Format = resultFormats[0] + case len(resultFormats) == 1: + // Single format code applies to all columns. + format := resultFormats[0] + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] + rr.fieldDescriptions[i].Format = format } - } else if len(resultFormats) == len(rr.fieldDescriptions) { - // One format code per column - for i := range rr.fieldDescriptions { + case len(resultFormats) == len(sdFields): + // One format code per column. + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] rr.fieldDescriptions[i].Format = resultFormats[i] } - } else { - // This should be impossible to reach as the mismatch would have been caught earlier. - rr.concludeCommand(CommandTag{}, fmt.Errorf("mismatched result format codes length")) + default: + // This should not occur if Bind validation is correct, but handle gracefully + rr.concludeCommand(CommandTag{}, fmt.Errorf("result format codes length %d does not match field count %d", len(resultFormats), len(sdFields))) } } return @@ -1751,7 +1816,8 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error switch msg := msg.(type) { case *pgproto3.RowDescription: - rr.fieldDescriptions = rr.pgConn.convertRowDescription(rr.pgConn.fieldDescriptions[:], msg) + rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(msg.Fields)) + convertRowDescription(rr.fieldDescriptions, msg) case *pgproto3.CommandComplete: rr.concludeCommand(rr.pgConn.makeCommandTag(msg.CommandTag), nil) case *pgproto3.EmptyQueryResponse: @@ -1785,8 +1851,10 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { - buf []byte - err error + buf []byte + statementDescriptions []*StatementDescription + resultFormats [][]int16 + err error } // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. @@ -1824,6 +1892,31 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor } } +// ExecPreparedStatementDescription appends an ExecPreparedStatementDescription command to the batch. See +// PgConn.ExecPrepared for parameter descriptions. +// +// This differs from ExecPrepared in that it takes a *StatementDescription instead of just the prepared statement name. +// Because it has the *StatementDescription it can avoid the Describe Portal message that ExecPrepared must send to get +// the result column descriptions. +func (batch *Batch) ExecPreparedStatementDescription(statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) { + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: statementDescription.Name, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.statementDescriptions = append(batch.statementDescriptions, statementDescription) + batch.resultFormats = append(batch.resultFormats, resultFormats) + + batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf) + if batch.err != nil { + return + } +} + // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing // multiple queries in a single round trip than using pipeline mode. @@ -1843,8 +1936,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR } pgConn.multiResultReader = MultiResultReader{ - pgConn: pgConn, - ctx: ctx, + pgConn: pgConn, + ctx: ctx, + statementDescriptions: batch.statementDescriptions, + resultFormats: batch.resultFormats, } multiResult := &pgConn.multiResultReader @@ -2402,8 +2497,9 @@ func (p *Pipeline) getResults() (results any, err error) { pgConn: p.conn, pipeline: p, ctx: p.ctx, - fieldDescriptions: p.conn.convertRowDescription(p.conn.fieldDescriptions[:], msg), + fieldDescriptions: p.conn.getFieldDescriptionSlice(len(msg.Fields)), } + convertRowDescription(p.conn.resultReader.fieldDescriptions, msg) return &p.conn.resultReader, nil case *pgproto3.CommandComplete: p.conn.resultReader = ResultReader{ @@ -2449,7 +2545,8 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: - psd.Fields = p.conn.convertRowDescription(nil, msg) + psd.Fields = make([]FieldDescription, len(msg.Fields)) + convertRowDescription(psd.Fields, msg) return psd, nil // NoData is returned instead of RowDescription when there is no expected result. e.g. An INSERT without a RETURNING diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index ffd3824e5..988a642a9 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1588,14 +1588,20 @@ func TestConnExecBatch(t *testing.T) { _, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil) require.NoError(t, err) + sd, err := pgConn.Prepare(ctx, "ps2", "select $1::text as name, $2::int as age", nil) + require.NoError(t, err) + batch := &pgconn.Batch{} batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecPreparedStatementDescription(sd, [][]byte{[]byte("ExecPreparedStatementDescription 1"), []byte("42")}, nil, nil) + batch.ExecPreparedStatementDescription(sd, [][]byte{[]byte("ExecPreparedStatementDescription 2"), []byte("43")}, nil, []int16{pgx.BinaryFormatCode}) + batch.ExecPreparedStatementDescription(sd, [][]byte{[]byte("ExecPreparedStatementDescription 3"), []byte("44")}, nil, []int16{pgx.TextFormatCode, pgx.BinaryFormatCode}) batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) results, err := pgConn.ExecBatch(ctx, batch).ReadAll() require.NoError(t, err) - require.Len(t, results, 3) + require.Len(t, results, 6) require.Len(t, results[0].Rows, 1) require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) @@ -1606,7 +1612,22 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", results[1].CommandTag.String()) require.Len(t, results[2].Rows, 1) - require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + require.Equal(t, "ExecPreparedStatementDescription 1", string(results[2].Rows[0][0])) + require.Equal(t, "42", string(results[2].Rows[0][1])) + assert.Equal(t, "SELECT 1", results[2].CommandTag.String()) + + require.Len(t, results[3].Rows, 1) + require.Equal(t, "ExecPreparedStatementDescription 2", string(results[3].Rows[0][0])) + require.Equal(t, []byte{0, 0, 0, 43}, results[3].Rows[0][1]) + assert.Equal(t, "SELECT 1", results[3].CommandTag.String()) + + require.Len(t, results[4].Rows, 1) + require.Equal(t, "ExecPreparedStatementDescription 3", string(results[4].Rows[0][0])) + require.Equal(t, []byte{0, 0, 0, 44}, results[4].Rows[0][1]) + assert.Equal(t, "SELECT 1", results[4].CommandTag.String()) + + require.Len(t, results[5].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[5].Rows[0][0])) assert.Equal(t, "SELECT 1", results[2].CommandTag.String()) } From 4f1d5351397f42f479fe930785ba3a3465ae903a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 26 Dec 2025 15:15:11 -0600 Subject: [PATCH 04/15] Make test compatible with CockroachDB --- pgconn/pgconn_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 988a642a9..c775329dd 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1588,7 +1588,7 @@ func TestConnExecBatch(t *testing.T) { _, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil) require.NoError(t, err) - sd, err := pgConn.Prepare(ctx, "ps2", "select $1::text as name, $2::int as age", nil) + sd, err := pgConn.Prepare(ctx, "ps2", "select $1::text as name, $2::bigint as age", nil) require.NoError(t, err) batch := &pgconn.Batch{} @@ -1618,12 +1618,12 @@ func TestConnExecBatch(t *testing.T) { require.Len(t, results[3].Rows, 1) require.Equal(t, "ExecPreparedStatementDescription 2", string(results[3].Rows[0][0])) - require.Equal(t, []byte{0, 0, 0, 43}, results[3].Rows[0][1]) + require.Equal(t, []byte{0, 0, 0, 0, 0, 0, 0, 43}, results[3].Rows[0][1]) assert.Equal(t, "SELECT 1", results[3].CommandTag.String()) require.Len(t, results[4].Rows, 1) require.Equal(t, "ExecPreparedStatementDescription 3", string(results[4].Rows[0][0])) - require.Equal(t, []byte{0, 0, 0, 44}, results[4].Rows[0][1]) + require.Equal(t, []byte{0, 0, 0, 0, 0, 0, 0, 44}, results[4].Rows[0][1]) assert.Equal(t, "SELECT 1", results[4].CommandTag.String()) require.Len(t, results[5].Rows, 1) From d2a8c84da9308e6aaba8f5593a48da6bab3d3639 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Dec 2025 08:43:21 -0600 Subject: [PATCH 05/15] Rename ExecPreparedStatementDescription to ExecStatement --- bench_test.go | 6 +++--- pgconn/pgconn.go | 12 +++++------- pgconn/pgconn_test.go | 20 ++++++++++---------- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/bench_test.go b/bench_test.go index 2af3424d2..9b3fa5c61 100644 --- a/bench_test.go +++ b/bench_test.go @@ -174,7 +174,7 @@ func BenchmarkMinimalPgConnPreparedStatementDescriptionSelect(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - rr := pgConn.ExecPreparedStatementDescription(context.Background(), psd, [][]byte{encodedBytes}, []int16{1}, []int16{1}) + rr := pgConn.ExecStatement(context.Background(), psd, [][]byte{encodedBytes}, []int16{1}, []int16{1}) if err != nil { b.Fatal(err) } @@ -1298,7 +1298,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) { } } -func BenchmarkSelectRowsPgConnExecPreparedStatementDescription(b *testing.B) { +func BenchmarkSelectRowsPgConnExecStatement(b *testing.B) { conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(b, conn) @@ -1321,7 +1321,7 @@ func BenchmarkSelectRowsPgConnExecPreparedStatementDescription(b *testing.B) { for _, format := range formats { b.Run(format.name, func(b *testing.B) { for i := 0; i < b.N; i++ { - rr := conn.PgConn().ExecPreparedStatementDescription( + rr := conn.PgConn().ExecStatement( context.Background(), psd, [][]byte{[]byte(strconv.FormatInt(rowCount, 10))}, diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index f01d08d6d..4db02c30e 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1191,10 +1191,9 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } -// ExecPreparedStatementDescription enqueues the execution of a prepared statement via the PostgreSQL extended query -// protocol. +// ExecStatement enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. // -// This differs from ExecPrepared in that it takes a *StatementDescription instead of just the prepared statement name. +// This differs from ExecPrepared in that it takes a *StatementDescription instead of the prepared statement name. // Because it has the *StatementDescription it can avoid the Describe Portal message that ExecPrepared must send to get // the result column descriptions. // @@ -1208,7 +1207,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa // format. If resultFormats is nil all results will be in text format. // // ResultReader must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecPreparedStatementDescription(ctx context.Context, statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) *ResultReader { +func (pgConn *PgConn) ExecStatement(ctx context.Context, statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) *ResultReader { result := pgConn.execExtendedPrefix(ctx, paramValues) if result.closed { return result @@ -1892,13 +1891,12 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor } } -// ExecPreparedStatementDescription appends an ExecPreparedStatementDescription command to the batch. See -// PgConn.ExecPrepared for parameter descriptions. +// ExecStatement appends an ExecStatement command to the batch. See PgConn.ExecPrepared for parameter descriptions. // // This differs from ExecPrepared in that it takes a *StatementDescription instead of just the prepared statement name. // Because it has the *StatementDescription it can avoid the Describe Portal message that ExecPrepared must send to get // the result column descriptions. -func (batch *Batch) ExecPreparedStatementDescription(statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) { +func (batch *Batch) ExecStatement(statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) { if batch.err != nil { return } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index c775329dd..04b69b13a 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1435,7 +1435,7 @@ func TestConnExecPreparedEmptySQL(t *testing.T) { ensureConnValid(t, pgConn) } -func TestConnExecPreparedStatementDescription(t *testing.T) { +func TestConnExecStatement(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) @@ -1451,7 +1451,7 @@ func TestConnExecPreparedStatementDescription(t *testing.T) { assert.Len(t, psd.ParamOIDs, 1) assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPreparedStatementDescription(ctx, psd, [][]byte{[]byte("Hello, world")}, nil, nil) + result := pgConn.ExecStatement(ctx, psd, [][]byte{[]byte("Hello, world")}, nil, nil) require.Len(t, result.FieldDescriptions(), 1) assert.Equal(t, "msg", result.FieldDescriptions()[0].Name) @@ -1510,7 +1510,7 @@ func (cbn *byteCounterConn) SetWriteDeadline(t time.Time) error { return cbn.conn.SetWriteDeadline(t) } -func TestConnExecPreparedStatementDescriptionNetworkUsage(t *testing.T) { +func TestConnExecStatementNetworkUsage(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) @@ -1556,7 +1556,7 @@ func TestConnExecPreparedStatementDescriptionNetworkUsage(t *testing.T) { counterConn.bytesWritten = 0 counterConn.bytesRead = 0 - result = pgConn.ExecPreparedStatementDescription( + result = pgConn.ExecStatement( ctx, psd, [][]byte{[]byte("1")}, @@ -1595,9 +1595,9 @@ func TestConnExecBatch(t *testing.T) { batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) - batch.ExecPreparedStatementDescription(sd, [][]byte{[]byte("ExecPreparedStatementDescription 1"), []byte("42")}, nil, nil) - batch.ExecPreparedStatementDescription(sd, [][]byte{[]byte("ExecPreparedStatementDescription 2"), []byte("43")}, nil, []int16{pgx.BinaryFormatCode}) - batch.ExecPreparedStatementDescription(sd, [][]byte{[]byte("ExecPreparedStatementDescription 3"), []byte("44")}, nil, []int16{pgx.TextFormatCode, pgx.BinaryFormatCode}) + batch.ExecStatement(sd, [][]byte{[]byte("ExecStatement 1"), []byte("42")}, nil, nil) + batch.ExecStatement(sd, [][]byte{[]byte("ExecStatement 2"), []byte("43")}, nil, []int16{pgx.BinaryFormatCode}) + batch.ExecStatement(sd, [][]byte{[]byte("ExecStatement 3"), []byte("44")}, nil, []int16{pgx.TextFormatCode, pgx.BinaryFormatCode}) batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) results, err := pgConn.ExecBatch(ctx, batch).ReadAll() require.NoError(t, err) @@ -1612,17 +1612,17 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", results[1].CommandTag.String()) require.Len(t, results[2].Rows, 1) - require.Equal(t, "ExecPreparedStatementDescription 1", string(results[2].Rows[0][0])) + require.Equal(t, "ExecStatement 1", string(results[2].Rows[0][0])) require.Equal(t, "42", string(results[2].Rows[0][1])) assert.Equal(t, "SELECT 1", results[2].CommandTag.String()) require.Len(t, results[3].Rows, 1) - require.Equal(t, "ExecPreparedStatementDescription 2", string(results[3].Rows[0][0])) + require.Equal(t, "ExecStatement 2", string(results[3].Rows[0][0])) require.Equal(t, []byte{0, 0, 0, 0, 0, 0, 0, 43}, results[3].Rows[0][1]) assert.Equal(t, "SELECT 1", results[3].CommandTag.String()) require.Len(t, results[4].Rows, 1) - require.Equal(t, "ExecPreparedStatementDescription 3", string(results[4].Rows[0][0])) + require.Equal(t, "ExecStatement 3", string(results[4].Rows[0][0])) require.Equal(t, []byte{0, 0, 0, 0, 0, 0, 0, 44}, results[4].Rows[0][1]) assert.Equal(t, "SELECT 1", results[4].CommandTag.String()) From fabf278a46da3a8877eff87c89a678b48ef6534c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 31 Dec 2025 11:45:20 -0600 Subject: [PATCH 06/15] Add Pipeline.SendQueryStatement Queue execution of prepared statement without describing portal in pipeline mode. Ugly and hacky but works for now. --- pgconn/pgconn.go | 135 ++++++++++++++++++++++++++++++++++++++---- pgconn/pgconn_test.go | 51 ++++++++++++++++ 2 files changed, 173 insertions(+), 13 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 4db02c30e..9c52b7a25 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -2172,9 +2172,10 @@ func Construct(hc *HijackedConn) (*PgConn, error) { // Pipeline represents a connection in pipeline mode. // -// SendPrepare, SendQueryParams, and SendQueryPrepared queue requests to the server. These requests are not written until -// pipeline is flushed by Flush or Sync. Sync must be called after the last request is queued. Requests between -// synchronization points are implicitly transactional unless explicit transaction control statements have been issued. +// SendPrepare, SendQueryParams, SendQueryPrepared, and SendQueryStatement queue requests to the server. These requests +// are not written until pipeline is flushed by Flush or Sync. Sync must be called after the last request is queued. +// Requests between synchronization points are implicitly transactional unless explicit transaction control statements +// have been issued. // // The context the pipeline was started with is in effect for the entire life of the Pipeline. // @@ -2203,6 +2204,7 @@ const ( pipelinePrepare pipelineQueryParams pipelineQueryPrepared + pipelineQueryStatement pipelineDeallocate pipelineSyncRequest pipelineFlushRequest @@ -2216,6 +2218,8 @@ type pipelineRequestEvent struct { type pipelineState struct { requestEventQueue list.List + statementDescriptionsQueue list.List + resultFormatsQueue list.List lastRequestType pipelineRequestType pgErr *PgError expectedReadyForQueryCount int @@ -2223,6 +2227,8 @@ type pipelineState struct { func (s *pipelineState) Init() { s.requestEventQueue.Init() + s.statementDescriptionsQueue.Init() + s.resultFormatsQueue.Init() s.lastRequestType = pipelineNil } @@ -2287,6 +2293,29 @@ func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType { } } +func (s *pipelineState) PushBackStatementData(sd *StatementDescription, resultFormats []int16) { + s.statementDescriptionsQueue.PushBack(sd) + s.resultFormatsQueue.PushBack(resultFormats) +} + +func (s *pipelineState) ExtractFrontStatementData() (*StatementDescription, []int16) { + sdElem := s.statementDescriptionsQueue.Front() + var sd *StatementDescription + if sdElem != nil { + s.statementDescriptionsQueue.Remove(sdElem) + sd = sdElem.Value.(*StatementDescription) + } + + rfElem := s.resultFormatsQueue.Front() + var resultFormats []int16 + if rfElem != nil { + s.resultFormatsQueue.Remove(rfElem) + resultFormats = rfElem.Value.([]int16) + } + + return sd, resultFormats +} + func (s *pipelineState) HandleError(err *PgError) { s.pgErr = err } @@ -2329,6 +2358,8 @@ func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline { return pipeline } + pgConn.resultReader = ResultReader{closed: true} + pgConn.pipeline = Pipeline{ conn: pgConn, ctx: ctx, @@ -2398,6 +2429,18 @@ func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, para p.state.PushBackRequestType(pipelineQueryPrepared) } +// SendQueryStatement is the pipeline version of *PgConn.ExecStatement. +func (p *Pipeline) SendQueryStatement(statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) { + if p.closed { + return + } + + p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: statementDescription.Name, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(pipelineQueryStatement) + p.state.PushBackStatementData(statementDescription, resultFormats) +} + // SendFlushRequest sends a request for the server to flush its output buffer. // // The server flushes its output buffer automatically as a result of Sync being called, @@ -2472,15 +2515,75 @@ func (p *Pipeline) GetResults() (results any, err error) { return nil, errors.New("pipeline closed") } - if p.state.ExtractFrontRequestType() == pipelineNil { - return nil, nil - } - return p.getResults() } func (p *Pipeline) getResults() (results any, err error) { + if !p.conn.resultReader.closed { + _, err := p.conn.resultReader.Close() + if err != nil { + return nil, err + } + } + + // Get the current request type. Skip over flush requests. + var currentRequestType pipelineRequestType for { + currentRequestType = p.state.ExtractFrontRequestType() + if currentRequestType == pipelineNil { + return nil, nil + } + + if currentRequestType != pipelineFlushRequest { + break + } + } + + for { + if currentRequestType == pipelineQueryStatement { + msg, _ := p.conn.peekMessage() + if _, ok := msg.(*pgproto3.DataRow); ok { + sd, resultFormats := p.state.ExtractFrontStatementData() + if sd != nil || resultFormats != nil { + sdFields := sd.Fields + rr := ResultReader{ + pgConn: p.conn, + pipeline: p, + ctx: p.ctx, + fieldDescriptions: p.conn.getFieldDescriptionSlice(len(sdFields)), + } + + switch { + case len(resultFormats) == 0: + // No format codes provided means text format for all columns. + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] + rr.fieldDescriptions[i].Format = pgtype.TextFormatCode + } + case len(resultFormats) == 1: + // Single format code applies to all columns. + format := resultFormats[0] + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] + rr.fieldDescriptions[i].Format = format + } + case len(resultFormats) == len(sdFields): + // One format code per column. + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] + rr.fieldDescriptions[i].Format = resultFormats[i] + } + default: + // This should not occur if Bind validation is correct, but handle gracefully + return nil, fmt.Errorf("result format codes length %d does not match field count %d", len(resultFormats), len(sdFields)) + } + + p.conn.resultReader = rr + return &p.conn.resultReader, nil + } + } + } + msg, err := p.conn.receiveMessage() if err != nil { p.closed = true @@ -2491,6 +2594,11 @@ func (p *Pipeline) getResults() (results any, err error) { switch msg := msg.(type) { case *pgproto3.RowDescription: + if currentRequestType != pipelineQueryParams && currentRequestType != pipelineQueryPrepared { + p.conn.asyncClose() + return nil, fmt.Errorf("unexpected RowDescription for request type %d", currentRequestType) + } + p.conn.resultReader = ResultReader{ pgConn: p.conn, pipeline: p, @@ -2507,12 +2615,7 @@ func (p *Pipeline) getResults() (results any, err error) { } return &p.conn.resultReader, nil case *pgproto3.ParseComplete: - peekedMsg, err := p.conn.peekMessage() - if err != nil { - p.conn.asyncClose() - return nil, normalizeTimeoutError(p.ctx, err) - } - if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok { + if currentRequestType == pipelinePrepare { return p.getResultsPrepare() } case *pgproto3.CloseComplete: @@ -2521,8 +2624,14 @@ func (p *Pipeline) getResults() (results any, err error) { p.state.HandleReadyForQuery() return &PipelineSync{}, nil case *pgproto3.ErrorResponse: + // Error message that is received while expecting a Sync message still consumes the expected Sync. Put it back. + if currentRequestType == pipelineSyncRequest { + p.state.requestEventQueue.PushFront(pipelineRequestEvent{RequestType: pipelineSyncRequest, WasSentToServer: true, BeforeFlushOrSync: true}) + } + pgErr := ErrorResponseToPgError(msg) p.state.HandleError(pgErr) + p.conn.resultReader.closed = true return nil, pgErr } } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 04b69b13a..3fd43a031 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -3760,6 +3760,47 @@ func TestPipelineFlushWithError(t *testing.T) { ensureConnValid(t, pgConn) } +func TestPipelineGetResultsHandlesPartiallyReadResults(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + sd, err := pgConn.Prepare(ctx, "ps", "select n from generate_series($1::int, $2::int) n", nil) + require.NoError(t, err) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendQueryStatement(sd, [][]byte{[]byte("1"), []byte("3")}, nil, nil) + pipeline.SendQueryStatement(sd, [][]byte{[]byte("5"), []byte("7")}, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + require.True(t, rr.NextRow()) + require.Equal(t, "1", string(rr.Values()[0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + require.True(t, rr.NextRow()) + require.Equal(t, "5", string(rr.Values()[0])) + require.True(t, rr.NextRow()) + require.Equal(t, "6", string(rr.Values()[0])) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + func TestPipelineCloseReadsUnreadResults(t *testing.T) { t.Parallel() @@ -3770,6 +3811,9 @@ func TestPipelineCloseReadsUnreadResults(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + sd, err := pgConn.Prepare(ctx, "ps", "select $1::text as msg", nil) + require.NoError(t, err) + pipeline := pgConn.StartPipeline(ctx) pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) @@ -3782,6 +3826,11 @@ func TestPipelineCloseReadsUnreadResults(t *testing.T) { err = pipeline.Sync() require.NoError(t, err) + pipeline.SendQueryStatement(sd, [][]byte{[]byte("6")}, nil, nil) + pipeline.SendQueryStatement(sd, [][]byte{[]byte("7")}, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + results, err := pipeline.GetResults() require.NoError(t, err) rr, ok := results.(*pgconn.ResultReader) @@ -4269,6 +4318,8 @@ func TestFatalErrorReceivedInPipelineMode(t *testing.T) { steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.SendMessage(&pgproto3.ParseComplete{})) + steps = append(steps, pgmock.SendMessage(&pgproto3.ParameterDescription{})) steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ {Name: []byte("mock")}, }})) From ceba377b1c02b668c1db9a9b84fb0711ba0c9932 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 31 Dec 2025 15:34:32 -0600 Subject: [PATCH 07/15] Allow preloading ResultReader row values in pipeline mode --- pgconn/pgconn.go | 78 +++++++++++++++++++++++++----------------------- 1 file changed, 40 insertions(+), 38 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 9c52b7a25..e1cca7aaa 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1638,6 +1638,7 @@ type ResultReader struct { fieldDescriptions []FieldDescription rowValues [][]byte commandTag CommandTag + preloaded bool commandConcluded bool closed bool err error @@ -1679,6 +1680,11 @@ func (rr *ResultReader) Read() *Result { // NextRow advances the ResultReader to the next row and returns true if a row is available. func (rr *ResultReader) NextRow() bool { + if rr.preloaded { + rr.preloaded = false + return true + } + for !rr.commandConcluded { msg, err := rr.receiveMessage() if err != nil { @@ -1695,6 +1701,11 @@ func (rr *ResultReader) NextRow() bool { return false } +func (rr *ResultReader) preloadRowValues(values [][]byte) { + rr.rowValues = values + rr.preloaded = true +} + // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until // the ResultReader is closed. It may return nil (for example, if the query did not return a result set or an error was // encountered.) @@ -2526,23 +2537,37 @@ func (p *Pipeline) getResults() (results any, err error) { } } - // Get the current request type. Skip over flush requests. - var currentRequestType pipelineRequestType + currentRequestType := p.state.ExtractFrontRequestType() + if currentRequestType == pipelineNil { + return nil, nil + } + for { - currentRequestType = p.state.ExtractFrontRequestType() - if currentRequestType == pipelineNil { - return nil, nil + msg, err := p.conn.receiveMessage() + if err != nil { + p.closed = true + p.err = err + p.conn.asyncClose() + return nil, normalizeTimeoutError(p.ctx, err) } - if currentRequestType != pipelineFlushRequest { - break - } - } + switch msg := msg.(type) { + case *pgproto3.RowDescription: + if currentRequestType != pipelineQueryParams && currentRequestType != pipelineQueryPrepared { + p.conn.asyncClose() + return nil, fmt.Errorf("unexpected RowDescription for request type %d", currentRequestType) + } - for { - if currentRequestType == pipelineQueryStatement { - msg, _ := p.conn.peekMessage() - if _, ok := msg.(*pgproto3.DataRow); ok { + p.conn.resultReader = ResultReader{ + pgConn: p.conn, + pipeline: p, + ctx: p.ctx, + fieldDescriptions: p.conn.getFieldDescriptionSlice(len(msg.Fields)), + } + convertRowDescription(p.conn.resultReader.fieldDescriptions, msg) + return &p.conn.resultReader, nil + case *pgproto3.DataRow: + if currentRequestType == pipelineQueryStatement { sd, resultFormats := p.state.ExtractFrontStatementData() if sd != nil || resultFormats != nil { sdFields := sd.Fields @@ -2578,35 +2603,12 @@ func (p *Pipeline) getResults() (results any, err error) { return nil, fmt.Errorf("result format codes length %d does not match field count %d", len(resultFormats), len(sdFields)) } + rr.preloadRowValues(msg.Values) + p.conn.resultReader = rr return &p.conn.resultReader, nil } } - } - - msg, err := p.conn.receiveMessage() - if err != nil { - p.closed = true - p.err = err - p.conn.asyncClose() - return nil, normalizeTimeoutError(p.ctx, err) - } - - switch msg := msg.(type) { - case *pgproto3.RowDescription: - if currentRequestType != pipelineQueryParams && currentRequestType != pipelineQueryPrepared { - p.conn.asyncClose() - return nil, fmt.Errorf("unexpected RowDescription for request type %d", currentRequestType) - } - - p.conn.resultReader = ResultReader{ - pgConn: p.conn, - pipeline: p, - ctx: p.ctx, - fieldDescriptions: p.conn.getFieldDescriptionSlice(len(msg.Fields)), - } - convertRowDescription(p.conn.resultReader.fieldDescriptions, msg) - return &p.conn.resultReader, nil case *pgproto3.CommandComplete: p.conn.resultReader = ResultReader{ commandTag: p.conn.makeCommandTag(msg.CommandTag), From d21f10154e48620338a1ba6ae74bd26d587e3fc2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 31 Dec 2025 16:42:37 -0600 Subject: [PATCH 08/15] Refactor Pipeline.getResults() Now delegates to specific methods for each request type. --- pgconn/pgconn.go | 265 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 189 insertions(+), 76 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index e1cca7aaa..02176db04 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -2538,9 +2538,30 @@ func (p *Pipeline) getResults() (results any, err error) { } currentRequestType := p.state.ExtractFrontRequestType() - if currentRequestType == pipelineNil { + switch currentRequestType { + case pipelineNil: return nil, nil + case pipelinePrepare: + return p.getResultsPrepare() + case pipelineQueryParams: + return p.getResultsQueryParams() + case pipelineQueryPrepared: + return p.getResultsQueryPrepared() + case pipelineQueryStatement: + return p.getResultsQueryStatement() + case pipelineDeallocate: + return p.getResultsDeallocate() + case pipelineSyncRequest: + return p.getResultsSync() + case pipelineFlushRequest: + return nil, errors.New("BUG: pipelineFlushRequest should not be in request queue") + default: + return nil, errors.New("BUG: unknown pipeline request type") } +} + +func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { + psd := &StatementDescription{} for { msg, err := p.conn.receiveMessage() @@ -2552,12 +2573,84 @@ func (p *Pipeline) getResults() (results any, err error) { } switch msg := msg.(type) { + case *pgproto3.ParseComplete: + case *pgproto3.ParameterDescription: + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: - if currentRequestType != pipelineQueryParams && currentRequestType != pipelineQueryPrepared { - p.conn.asyncClose() - return nil, fmt.Errorf("unexpected RowDescription for request type %d", currentRequestType) + psd.Fields = make([]FieldDescription, len(msg.Fields)) + convertRowDescription(psd.Fields, msg) + return psd, nil + + // NoData is returned instead of RowDescription when there is no expected result. e.g. An INSERT without a RETURNING + // clause. + case *pgproto3.NoData: + return psd, nil + + // These should never happen here. But don't take chances that could lead to a deadlock. + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return nil, pgErr + case *pgproto3.CommandComplete: + p.conn.asyncClose() + return nil, errors.New("BUG: received CommandComplete while handling Describe") + case *pgproto3.ReadyForQuery: + p.conn.asyncClose() + return nil, errors.New("BUG: received ReadyForQuery while handling Describe") + } + } +} + +func (p *Pipeline) getResultsQueryParams() (*ResultReader, error) { + for { + msg, err := p.conn.receiveMessage() + if err != nil { + p.closed = true + p.err = err + p.conn.asyncClose() + return nil, normalizeTimeoutError(p.ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ParseComplete: + case *pgproto3.RowDescription: + p.conn.resultReader = ResultReader{ + pgConn: p.conn, + pipeline: p, + ctx: p.ctx, + fieldDescriptions: p.conn.getFieldDescriptionSlice(len(msg.Fields)), + } + convertRowDescription(p.conn.resultReader.fieldDescriptions, msg) + return &p.conn.resultReader, nil + case *pgproto3.CommandComplete: + p.conn.resultReader = ResultReader{ + commandTag: p.conn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, } + return &p.conn.resultReader, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + } + } +} +func (p *Pipeline) getResultsQueryPrepared() (*ResultReader, error) { + for { + msg, err := p.conn.receiveMessage() + if err != nil { + p.closed = true + p.err = err + p.conn.asyncClose() + return nil, normalizeTimeoutError(p.ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: p.conn.resultReader = ResultReader{ pgConn: p.conn, pipeline: p, @@ -2566,48 +2659,73 @@ func (p *Pipeline) getResults() (results any, err error) { } convertRowDescription(p.conn.resultReader.fieldDescriptions, msg) return &p.conn.resultReader, nil + case *pgproto3.CommandComplete: + p.conn.resultReader = ResultReader{ + commandTag: p.conn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + } + return &p.conn.resultReader, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + } + } +} + +func (p *Pipeline) getResultsQueryStatement() (*ResultReader, error) { + for { + msg, err := p.conn.receiveMessage() + if err != nil { + p.closed = true + p.err = err + p.conn.asyncClose() + return nil, normalizeTimeoutError(p.ctx, err) + } + + switch msg := msg.(type) { case *pgproto3.DataRow: - if currentRequestType == pipelineQueryStatement { - sd, resultFormats := p.state.ExtractFrontStatementData() - if sd != nil || resultFormats != nil { - sdFields := sd.Fields - rr := ResultReader{ - pgConn: p.conn, - pipeline: p, - ctx: p.ctx, - fieldDescriptions: p.conn.getFieldDescriptionSlice(len(sdFields)), - } + sd, resultFormats := p.state.ExtractFrontStatementData() + if sd != nil || resultFormats != nil { + sdFields := sd.Fields + rr := ResultReader{ + pgConn: p.conn, + pipeline: p, + ctx: p.ctx, + fieldDescriptions: p.conn.getFieldDescriptionSlice(len(sdFields)), + } - switch { - case len(resultFormats) == 0: - // No format codes provided means text format for all columns. - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = pgtype.TextFormatCode - } - case len(resultFormats) == 1: - // Single format code applies to all columns. - format := resultFormats[0] - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = format - } - case len(resultFormats) == len(sdFields): - // One format code per column. - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = resultFormats[i] - } - default: - // This should not occur if Bind validation is correct, but handle gracefully - return nil, fmt.Errorf("result format codes length %d does not match field count %d", len(resultFormats), len(sdFields)) + switch { + case len(resultFormats) == 0: + // No format codes provided means text format for all columns. + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] + rr.fieldDescriptions[i].Format = pgtype.TextFormatCode + } + case len(resultFormats) == 1: + // Single format code applies to all columns. + format := resultFormats[0] + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] + rr.fieldDescriptions[i].Format = format + } + case len(resultFormats) == len(sdFields): + // One format code per column. + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] + rr.fieldDescriptions[i].Format = resultFormats[i] } + default: + // This should not occur if Bind validation is correct, but handle gracefully + return nil, fmt.Errorf("result format codes length %d does not match field count %d", len(resultFormats), len(sdFields)) + } - rr.preloadRowValues(msg.Values) + rr.preloadRowValues(msg.Values) - p.conn.resultReader = rr - return &p.conn.resultReader, nil - } + p.conn.resultReader = rr + return &p.conn.resultReader, nil } case *pgproto3.CommandComplete: p.conn.resultReader = ResultReader{ @@ -2616,21 +2734,7 @@ func (p *Pipeline) getResults() (results any, err error) { closed: true, } return &p.conn.resultReader, nil - case *pgproto3.ParseComplete: - if currentRequestType == pipelinePrepare { - return p.getResultsPrepare() - } - case *pgproto3.CloseComplete: - return &CloseComplete{}, nil - case *pgproto3.ReadyForQuery: - p.state.HandleReadyForQuery() - return &PipelineSync{}, nil case *pgproto3.ErrorResponse: - // Error message that is received while expecting a Sync message still consumes the expected Sync. Put it back. - if currentRequestType == pipelineSyncRequest { - p.state.requestEventQueue.PushFront(pipelineRequestEvent{RequestType: pipelineSyncRequest, WasSentToServer: true, BeforeFlushOrSync: true}) - } - pgErr := ErrorResponseToPgError(msg) p.state.HandleError(pgErr) p.conn.resultReader.closed = true @@ -2639,41 +2743,50 @@ func (p *Pipeline) getResults() (results any, err error) { } } -func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { - psd := &StatementDescription{} - +func (p *Pipeline) getResultsDeallocate() (*CloseComplete, error) { for { msg, err := p.conn.receiveMessage() if err != nil { + p.closed = true + p.err = err p.conn.asyncClose() return nil, normalizeTimeoutError(p.ctx, err) } switch msg := msg.(type) { - case *pgproto3.ParameterDescription: - psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) - copy(psd.ParamOIDs, msg.ParameterOIDs) - case *pgproto3.RowDescription: - psd.Fields = make([]FieldDescription, len(msg.Fields)) - convertRowDescription(psd.Fields, msg) - return psd, nil - - // NoData is returned instead of RowDescription when there is no expected result. e.g. An INSERT without a RETURNING - // clause. - case *pgproto3.NoData: - return psd, nil - - // These should never happen here. But don't take chances that could lead to a deadlock. + case *pgproto3.CloseComplete: + return &CloseComplete{}, nil case *pgproto3.ErrorResponse: pgErr := ErrorResponseToPgError(msg) p.state.HandleError(pgErr) + p.conn.resultReader.closed = true return nil, pgErr - case *pgproto3.CommandComplete: + } + } +} + +func (p *Pipeline) getResultsSync() (*PipelineSync, error) { + for { + msg, err := p.conn.receiveMessage() + if err != nil { + p.closed = true + p.err = err p.conn.asyncClose() - return nil, errors.New("BUG: received CommandComplete while handling Describe") + return nil, normalizeTimeoutError(p.ctx, err) + } + + switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - p.conn.asyncClose() - return nil, errors.New("BUG: received ReadyForQuery while handling Describe") + p.state.HandleReadyForQuery() + return &PipelineSync{}, nil + case *pgproto3.ErrorResponse: + // Error message that is received while expecting a Sync message still consumes the expected Sync. Put it back. + p.state.requestEventQueue.PushFront(pipelineRequestEvent{RequestType: pipelineSyncRequest, WasSentToServer: true, BeforeFlushOrSync: true}) + + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr } } } From 3172802f466b8082c200a5605b5f84da46915cc2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 31 Dec 2025 16:47:36 -0600 Subject: [PATCH 09/15] Extract Pipeline.receiveMessage --- pgconn/pgconn.go | 53 +++++++++++++++++++++--------------------------- 1 file changed, 23 insertions(+), 30 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 02176db04..51ad251e8 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -2564,12 +2564,9 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { psd := &StatementDescription{} for { - msg, err := p.conn.receiveMessage() + msg, err := p.receiveMessage() if err != nil { - p.closed = true - p.err = err - p.conn.asyncClose() - return nil, normalizeTimeoutError(p.ctx, err) + return nil, err } switch msg := msg.(type) { @@ -2604,12 +2601,9 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { func (p *Pipeline) getResultsQueryParams() (*ResultReader, error) { for { - msg, err := p.conn.receiveMessage() + msg, err := p.receiveMessage() if err != nil { - p.closed = true - p.err = err - p.conn.asyncClose() - return nil, normalizeTimeoutError(p.ctx, err) + return nil, err } switch msg := msg.(type) { @@ -2641,12 +2635,9 @@ func (p *Pipeline) getResultsQueryParams() (*ResultReader, error) { func (p *Pipeline) getResultsQueryPrepared() (*ResultReader, error) { for { - msg, err := p.conn.receiveMessage() + msg, err := p.receiveMessage() if err != nil { - p.closed = true - p.err = err - p.conn.asyncClose() - return nil, normalizeTimeoutError(p.ctx, err) + return nil, err } switch msg := msg.(type) { @@ -2677,12 +2668,9 @@ func (p *Pipeline) getResultsQueryPrepared() (*ResultReader, error) { func (p *Pipeline) getResultsQueryStatement() (*ResultReader, error) { for { - msg, err := p.conn.receiveMessage() + msg, err := p.receiveMessage() if err != nil { - p.closed = true - p.err = err - p.conn.asyncClose() - return nil, normalizeTimeoutError(p.ctx, err) + return nil, err } switch msg := msg.(type) { @@ -2745,12 +2733,9 @@ func (p *Pipeline) getResultsQueryStatement() (*ResultReader, error) { func (p *Pipeline) getResultsDeallocate() (*CloseComplete, error) { for { - msg, err := p.conn.receiveMessage() + msg, err := p.receiveMessage() if err != nil { - p.closed = true - p.err = err - p.conn.asyncClose() - return nil, normalizeTimeoutError(p.ctx, err) + return nil, err } switch msg := msg.(type) { @@ -2767,12 +2752,9 @@ func (p *Pipeline) getResultsDeallocate() (*CloseComplete, error) { func (p *Pipeline) getResultsSync() (*PipelineSync, error) { for { - msg, err := p.conn.receiveMessage() + msg, err := p.receiveMessage() if err != nil { - p.closed = true - p.err = err - p.conn.asyncClose() - return nil, normalizeTimeoutError(p.ctx, err) + return nil, err } switch msg := msg.(type) { @@ -2791,6 +2773,17 @@ func (p *Pipeline) getResultsSync() (*PipelineSync, error) { } } +func (p *Pipeline) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := p.conn.receiveMessage() + if err != nil { + p.closed = true + p.err = err + p.conn.asyncClose() + return nil, normalizeTimeoutError(p.ctx, err) + } + return msg, nil +} + // Close closes the pipeline and returns the connection to normal mode. func (p *Pipeline) Close() error { if p.closed { From e66f5eff021aefba0573c2a341aaef41723f7feb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 31 Dec 2025 20:51:00 -0600 Subject: [PATCH 10/15] Further refactoring of pipeline result handling --- pgconn/pgconn.go | 431 +++++++++++++++++++++++++++-------------------- 1 file changed, 252 insertions(+), 179 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 51ad251e8..a2f09dcf2 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -2561,227 +2561,300 @@ func (p *Pipeline) getResults() (results any, err error) { } func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { + err := p.receiveParseComplete("Prepare") + if err != nil { + return nil, err + } + psd := &StatementDescription{} - for { - msg, err := p.receiveMessage() - if err != nil { - return nil, err - } + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } - switch msg := msg.(type) { - case *pgproto3.ParseComplete: - case *pgproto3.ParameterDescription: - psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) - copy(psd.ParamOIDs, msg.ParameterOIDs) - case *pgproto3.RowDescription: - psd.Fields = make([]FieldDescription, len(msg.Fields)) - convertRowDescription(psd.Fields, msg) - return psd, nil + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("Prepare ParameterDescription", msg) + } - // NoData is returned instead of RowDescription when there is no expected result. e.g. An INSERT without a RETURNING - // clause. - case *pgproto3.NoData: - return psd, nil + msg, err = p.receiveMessage() + if err != nil { + return nil, err + } - // These should never happen here. But don't take chances that could lead to a deadlock. - case *pgproto3.ErrorResponse: - pgErr := ErrorResponseToPgError(msg) - p.state.HandleError(pgErr) - return nil, pgErr - case *pgproto3.CommandComplete: - p.conn.asyncClose() - return nil, errors.New("BUG: received CommandComplete while handling Describe") - case *pgproto3.ReadyForQuery: - p.conn.asyncClose() - return nil, errors.New("BUG: received ReadyForQuery while handling Describe") - } + switch msg := msg.(type) { + case *pgproto3.RowDescription: + psd.Fields = make([]FieldDescription, len(msg.Fields)) + convertRowDescription(psd.Fields, msg) + return psd, nil + + // NoData is returned instead of RowDescription when there is no expected result. e.g. An INSERT without a RETURNING + // clause. + case *pgproto3.NoData: + return psd, nil + + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("Prepare RowDescription", msg) } } func (p *Pipeline) getResultsQueryParams() (*ResultReader, error) { - for { - msg, err := p.receiveMessage() - if err != nil { - return nil, err - } + err := p.receiveParseComplete("QueryParams") + if err != nil { + return nil, err + } - switch msg := msg.(type) { - case *pgproto3.ParseComplete: - case *pgproto3.RowDescription: - p.conn.resultReader = ResultReader{ - pgConn: p.conn, - pipeline: p, - ctx: p.ctx, - fieldDescriptions: p.conn.getFieldDescriptionSlice(len(msg.Fields)), - } - convertRowDescription(p.conn.resultReader.fieldDescriptions, msg) - return &p.conn.resultReader, nil - case *pgproto3.CommandComplete: - p.conn.resultReader = ResultReader{ - commandTag: p.conn.makeCommandTag(msg.CommandTag), - commandConcluded: true, - closed: true, - } - return &p.conn.resultReader, nil - case *pgproto3.ErrorResponse: - pgErr := ErrorResponseToPgError(msg) - p.state.HandleError(pgErr) - p.conn.resultReader.closed = true - return nil, pgErr - } + err = p.receiveBindComplete("QueryParams") + if err != nil { + return nil, err } + + return p.receiveDescribedResultReader("QueryParams") } func (p *Pipeline) getResultsQueryPrepared() (*ResultReader, error) { - for { - msg, err := p.receiveMessage() - if err != nil { - return nil, err - } + err := p.receiveBindComplete("QueryPrepared") + if err != nil { + return nil, err + } - switch msg := msg.(type) { - case *pgproto3.RowDescription: - p.conn.resultReader = ResultReader{ - pgConn: p.conn, - pipeline: p, - ctx: p.ctx, - fieldDescriptions: p.conn.getFieldDescriptionSlice(len(msg.Fields)), + return p.receiveDescribedResultReader("QueryPrepared") +} + +func (p *Pipeline) getResultsQueryStatement() (*ResultReader, error) { + err := p.receiveBindComplete("QueryStatement") + if err != nil { + return nil, err + } + + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.DataRow: + sd, resultFormats := p.state.ExtractFrontStatementData() + if sd == nil { + return nil, errors.New("BUG: missing statement description or result formats for QueryStatement") + } + sdFields := sd.Fields + rr := ResultReader{ + pgConn: p.conn, + pipeline: p, + ctx: p.ctx, + fieldDescriptions: p.conn.getFieldDescriptionSlice(len(sdFields)), + } + + switch { + case len(resultFormats) == 0: + // No format codes provided means text format for all columns. + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] + rr.fieldDescriptions[i].Format = pgtype.TextFormatCode } - convertRowDescription(p.conn.resultReader.fieldDescriptions, msg) - return &p.conn.resultReader, nil - case *pgproto3.CommandComplete: - p.conn.resultReader = ResultReader{ - commandTag: p.conn.makeCommandTag(msg.CommandTag), - commandConcluded: true, - closed: true, + case len(resultFormats) == 1: + // Single format code applies to all columns. + format := resultFormats[0] + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] + rr.fieldDescriptions[i].Format = format } - return &p.conn.resultReader, nil - case *pgproto3.ErrorResponse: - pgErr := ErrorResponseToPgError(msg) - p.state.HandleError(pgErr) - p.conn.resultReader.closed = true - return nil, pgErr + case len(resultFormats) == len(sdFields): + // One format code per column. + for i := range sdFields { + rr.fieldDescriptions[i] = sdFields[i] + rr.fieldDescriptions[i].Format = resultFormats[i] + } + default: + // This should not occur if Bind validation is correct, but handle gracefully + return nil, fmt.Errorf("result format codes length %d does not match field count %d", len(resultFormats), len(sdFields)) + } + + rr.preloadRowValues(msg.Values) + + p.conn.resultReader = rr + return &p.conn.resultReader, nil + case *pgproto3.CommandComplete: + p.conn.resultReader = ResultReader{ + commandTag: p.conn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, } + return &p.conn.resultReader, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("QueryStatement", msg) } } -func (p *Pipeline) getResultsQueryStatement() (*ResultReader, error) { - for { - msg, err := p.receiveMessage() - if err != nil { - return nil, err - } +func (p *Pipeline) getResultsDeallocate() (*CloseComplete, error) { + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } - switch msg := msg.(type) { - case *pgproto3.DataRow: - sd, resultFormats := p.state.ExtractFrontStatementData() - if sd != nil || resultFormats != nil { - sdFields := sd.Fields - rr := ResultReader{ - pgConn: p.conn, - pipeline: p, - ctx: p.ctx, - fieldDescriptions: p.conn.getFieldDescriptionSlice(len(sdFields)), - } + switch msg := msg.(type) { + case *pgproto3.CloseComplete: + return &CloseComplete{}, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("Deallocate", msg) + } +} - switch { - case len(resultFormats) == 0: - // No format codes provided means text format for all columns. - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = pgtype.TextFormatCode - } - case len(resultFormats) == 1: - // Single format code applies to all columns. - format := resultFormats[0] - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = format - } - case len(resultFormats) == len(sdFields): - // One format code per column. - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = resultFormats[i] - } - default: - // This should not occur if Bind validation is correct, but handle gracefully - return nil, fmt.Errorf("result format codes length %d does not match field count %d", len(resultFormats), len(sdFields)) - } +func (p *Pipeline) getResultsSync() (*PipelineSync, error) { + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } - rr.preloadRowValues(msg.Values) + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + p.state.HandleReadyForQuery() + return &PipelineSync{}, nil + case *pgproto3.ErrorResponse: + // Error message that is received while expecting a Sync message still consumes the expected Sync. Put it back. + p.state.requestEventQueue.PushFront(pipelineRequestEvent{RequestType: pipelineSyncRequest, WasSentToServer: true, BeforeFlushOrSync: true}) - p.conn.resultReader = rr - return &p.conn.resultReader, nil - } - case *pgproto3.CommandComplete: - p.conn.resultReader = ResultReader{ - commandTag: p.conn.makeCommandTag(msg.CommandTag), - commandConcluded: true, - closed: true, - } - return &p.conn.resultReader, nil - case *pgproto3.ErrorResponse: - pgErr := ErrorResponseToPgError(msg) - p.state.HandleError(pgErr) - p.conn.resultReader.closed = true - return nil, pgErr - } + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("Sync", msg) } } -func (p *Pipeline) getResultsDeallocate() (*CloseComplete, error) { - for { - msg, err := p.receiveMessage() - if err != nil { - return nil, err - } +func (p *Pipeline) receiveParseComplete(errStr string) error { + msg, err := p.receiveMessage() + if err != nil { + return err + } - switch msg := msg.(type) { - case *pgproto3.CloseComplete: - return &CloseComplete{}, nil - case *pgproto3.ErrorResponse: - pgErr := ErrorResponseToPgError(msg) - p.state.HandleError(pgErr) - p.conn.resultReader.closed = true - return nil, pgErr + switch msg := msg.(type) { + case *pgproto3.ParseComplete: + return nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return pgErr + default: + return p.handleUnexpectedMessage(fmt.Sprintf("%s Parse", errStr), msg) + } +} + +func (p *Pipeline) receiveBindComplete(errStr string) error { + msg, err := p.receiveMessage() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.BindComplete: + return nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return pgErr + default: + return p.handleUnexpectedMessage(fmt.Sprintf("%s Bind", errStr), msg) + } +} + +func (p *Pipeline) receiveDescribedResultReader(errStr string) (*ResultReader, error) { + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + p.conn.resultReader = ResultReader{ + pgConn: p.conn, + pipeline: p, + ctx: p.ctx, + fieldDescriptions: p.conn.getFieldDescriptionSlice(len(msg.Fields)), + } + convertRowDescription(p.conn.resultReader.fieldDescriptions, msg) + return &p.conn.resultReader, nil + case *pgproto3.NoData: + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage(fmt.Sprintf("%s RowDescription or NoData", errStr), msg) + } + + msg, err = p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.CommandComplete: + p.conn.resultReader = ResultReader{ + commandTag: p.conn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, } + return &p.conn.resultReader, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage(fmt.Sprintf("%s CommandComplete", errStr), msg) } } -func (p *Pipeline) getResultsSync() (*PipelineSync, error) { +func (p *Pipeline) receiveMessage() (pgproto3.BackendMessage, error) { for { - msg, err := p.receiveMessage() + msg, err := p.conn.receiveMessage() if err != nil { - return nil, err + p.closed = true + p.err = err + p.conn.asyncClose() + return nil, normalizeTimeoutError(p.ctx, err) } switch msg := msg.(type) { - case *pgproto3.ReadyForQuery: - p.state.HandleReadyForQuery() - return &PipelineSync{}, nil - case *pgproto3.ErrorResponse: - // Error message that is received while expecting a Sync message still consumes the expected Sync. Put it back. - p.state.requestEventQueue.PushFront(pipelineRequestEvent{RequestType: pipelineSyncRequest, WasSentToServer: true, BeforeFlushOrSync: true}) - - pgErr := ErrorResponseToPgError(msg) - p.state.HandleError(pgErr) - p.conn.resultReader.closed = true - return nil, pgErr + case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse, *pgproto3.NotificationResponse: + // Filter these message types out in pipeline mode. The normal processing is handled by PgConn.receiveMessage. + default: + return msg, nil } } } -func (p *Pipeline) receiveMessage() (pgproto3.BackendMessage, error) { - msg, err := p.conn.receiveMessage() - if err != nil { - p.closed = true - p.err = err - p.conn.asyncClose() - return nil, normalizeTimeoutError(p.ctx, err) - } - return msg, nil +func (p *Pipeline) handleUnexpectedMessage(errStr string, msg pgproto3.BackendMessage) error { + p.closed = true + p.err = fmt.Errorf("pipeline: %s: received unexpected message type %T", errStr, msg) + p.conn.asyncClose() + return p.err } // Close closes the pipeline and returns the connection to normal mode. From 4b034335a875080c698b7452ad399b06423ca9e1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 31 Dec 2025 21:05:27 -0600 Subject: [PATCH 11/15] Extract combineFieldDescriptionsAndResultFormats --- pgconn/pgconn.go | 107 +++++++++++++++++------------------------------ 1 file changed, 38 insertions(+), 69 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index a2f09dcf2..5f573335d 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1539,29 +1539,9 @@ func (mrr *MultiResultReader) NextResult() bool { sdFields := sd.Fields rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(sdFields)) - switch { - case len(resultFormats) == 0: - // No format codes provided means text format for all columns. - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = pgtype.TextFormatCode - } - case len(resultFormats) == 1: - // Single format code applies to all columns. - format := resultFormats[0] - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = format - } - case len(resultFormats) == len(sdFields): - // One format code per column. - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = resultFormats[i] - } - default: - // This should not occur if Bind validation is correct, but handle gracefully - rr.concludeCommand(CommandTag{}, fmt.Errorf("result format codes length %d does not match field count %d", len(resultFormats), len(sdFields))) + err := combineFieldDescriptionsAndResultFormats(rr.fieldDescriptions, sdFields, resultFormats) + if err != nil { + rr.concludeCommand(CommandTag{}, err) } mrr.pgConn.resultReader = rr @@ -1769,29 +1749,9 @@ func (rr *ResultReader) readUntilRowDescription(statementDescription *StatementD sdFields := statementDescription.Fields rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(sdFields)) - switch { - case len(resultFormats) == 0: - // No format codes provided means text format for all columns. - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = pgtype.TextFormatCode - } - case len(resultFormats) == 1: - // Single format code applies to all columns. - format := resultFormats[0] - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = format - } - case len(resultFormats) == len(sdFields): - // One format code per column. - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = resultFormats[i] - } - default: - // This should not occur if Bind validation is correct, but handle gracefully - rr.concludeCommand(CommandTag{}, fmt.Errorf("result format codes length %d does not match field count %d", len(resultFormats), len(sdFields))) + err := combineFieldDescriptionsAndResultFormats(rr.fieldDescriptions, sdFields, resultFormats) + if err != nil { + rr.concludeCommand(CommandTag{}, err) } } return @@ -2658,29 +2618,9 @@ func (p *Pipeline) getResultsQueryStatement() (*ResultReader, error) { fieldDescriptions: p.conn.getFieldDescriptionSlice(len(sdFields)), } - switch { - case len(resultFormats) == 0: - // No format codes provided means text format for all columns. - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = pgtype.TextFormatCode - } - case len(resultFormats) == 1: - // Single format code applies to all columns. - format := resultFormats[0] - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = format - } - case len(resultFormats) == len(sdFields): - // One format code per column. - for i := range sdFields { - rr.fieldDescriptions[i] = sdFields[i] - rr.fieldDescriptions[i].Format = resultFormats[i] - } - default: - // This should not occur if Bind validation is correct, but handle gracefully - return nil, fmt.Errorf("result format codes length %d does not match field count %d", len(resultFormats), len(sdFields)) + err := combineFieldDescriptionsAndResultFormats(rr.fieldDescriptions, sdFields, resultFormats) + if err != nil { + return nil, err } rr.preloadRowValues(msg.Values) @@ -2959,3 +2899,32 @@ func (h *CancelRequestContextWatcherHandler) HandleUnwatchAfterCancel() { h.Conn.conn.SetDeadline(time.Time{}) } + +func combineFieldDescriptionsAndResultFormats(outputFields, inputFields []FieldDescription, resultFormats []int16) error { + switch { + case len(resultFormats) == 0: + // No format codes provided means text format for all columns. + for i := range inputFields { + outputFields[i] = inputFields[i] + outputFields[i].Format = pgtype.TextFormatCode + } + case len(resultFormats) == 1: + // Single format code applies to all columns. + format := resultFormats[0] + for i := range inputFields { + outputFields[i] = inputFields[i] + outputFields[i].Format = format + } + case len(resultFormats) == len(inputFields): + // One format code per column. + for i := range inputFields { + outputFields[i] = inputFields[i] + outputFields[i].Format = resultFormats[i] + } + default: + // This should not occur if Bind validation is correct, but handle gracefully + return fmt.Errorf("result format codes length %d does not match field count %d", len(resultFormats), len(inputFields)) + } + + return nil +} From e1638b6b215519510473298289eb860c7544b015 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 31 Dec 2025 21:22:59 -0600 Subject: [PATCH 12/15] Add test of pgx network usage before eliding Describe --- batch_test.go | 37 +++++++++++++++++++++++++ query_test.go | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/batch_test.go b/batch_test.go index d49b4e862..88fdccc78 100644 --- a/batch_test.go +++ b/batch_test.go @@ -1130,6 +1130,43 @@ func TestSendBatchHandlesTimeoutBetweenParseAndDescribe(t *testing.T) { }) } +func TestBatchNetworkUsage(t *testing.T) { + t.Parallel() + + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + var counterConn *byteCounterConn + config.AfterNetConnect = func(ctx context.Context, config *pgconn.Config, conn net.Conn) (net.Conn, error) { + counterConn = &byteCounterConn{conn: conn} + return counterConn, nil + } + + conn := mustConnect(t, config) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "Server uses different number of bytes for same operations") + + counterConn.bytesWritten = 0 + counterConn.bytesRead = 0 + + batch := &pgx.Batch{} + + for range 10 { + batch.Queue( + "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", + 1, + ) + } + + err := conn.SendBatch(context.Background(), batch).Close() + require.NoError(t, err) + + assert.Equal(t, 4116, counterConn.bytesRead) + assert.Equal(t, 1478, counterConn.bytesWritten) + + ensureConnValid(t, conn) +} + func ExampleConn_SendBatch() { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() diff --git a/query_test.go b/query_test.go index b9e01b49c..e1d5ae995 100644 --- a/query_test.go +++ b/query_test.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "net" "os" "strconv" "strings" @@ -2249,6 +2250,80 @@ func TestQueryWithProcedureParametersInAndOut(t *testing.T) { }) } +type byteCounterConn struct { + conn net.Conn + bytesRead int + bytesWritten int +} + +func (cbn *byteCounterConn) Read(b []byte) (n int, err error) { + n, err = cbn.conn.Read(b) + cbn.bytesRead += n + return n, err +} + +func (cbn *byteCounterConn) Write(b []byte) (n int, err error) { + n, err = cbn.conn.Write(b) + cbn.bytesWritten += n + return n, err +} + +func (cbn *byteCounterConn) Close() error { + return cbn.conn.Close() +} + +func (cbn *byteCounterConn) LocalAddr() net.Addr { + return cbn.conn.LocalAddr() +} + +func (cbn *byteCounterConn) RemoteAddr() net.Addr { + return cbn.conn.RemoteAddr() +} + +func (cbn *byteCounterConn) SetDeadline(t time.Time) error { + return cbn.conn.SetDeadline(t) +} + +func (cbn *byteCounterConn) SetReadDeadline(t time.Time) error { + return cbn.conn.SetReadDeadline(t) +} + +func (cbn *byteCounterConn) SetWriteDeadline(t time.Time) error { + return cbn.conn.SetWriteDeadline(t) +} + +func TestQueryNetworkUsage(t *testing.T) { + t.Parallel() + + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + var counterConn *byteCounterConn + config.AfterNetConnect = func(ctx context.Context, config *pgconn.Config, conn net.Conn) (net.Conn, error) { + counterConn = &byteCounterConn{conn: conn} + return counterConn, nil + } + + conn := mustConnect(t, config) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "Server uses different number of bytes for same operations") + + counterConn.bytesWritten = 0 + counterConn.bytesRead = 0 + + rows, _ := conn.Query( + context.Background(), + "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", + 1, + ) + rows.Close() + require.NoError(t, rows.Err()) + + assert.Equal(t, 651, counterConn.bytesRead) + assert.Equal(t, 434, counterConn.bytesWritten) + ensureConnValid(t, conn) +} + // This example uses Query without using any helpers to read the results. Normally CollectRows, ForEachRow, or another // helper function should be used. func ExampleConn_Query() { From c8400ae879f6ceb43434cc0676e255b64761f034 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 31 Dec 2025 21:43:29 -0600 Subject: [PATCH 13/15] Get field descriptions even when no rows --- pgconn/pgconn.go | 56 +++++++++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 5f573335d..ec27ec34d 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1740,11 +1740,12 @@ func (rr *ResultReader) Close() (CommandTag, error) { // error will be stored in the ResultReader. func (rr *ResultReader) readUntilRowDescription(statementDescription *StatementDescription, resultFormats []int16) { for !rr.commandConcluded { - // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. This - // is expected if statementDescription is not nil, but it is also possible if SendBytes and ReceiveResults are - // manually used to construct a query that does not issue a describe statement. - msg, _ := rr.pgConn.peekMessage() - if _, ok := msg.(*pgproto3.DataRow); ok { + msg, _ := rr.receiveMessage() + switch msg := msg.(type) { + case *pgproto3.RowDescription: + return + case *pgproto3.DataRow: + rr.preloadRowValues(msg.Values) if statementDescription != nil { sdFields := statementDescription.Fields rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(sdFields)) @@ -1755,11 +1756,16 @@ func (rr *ResultReader) readUntilRowDescription(statementDescription *StatementD } } return - } + case *pgproto3.CommandComplete: + if statementDescription != nil { + sdFields := statementDescription.Fields + rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(sdFields)) - // Consume the message - msg, _ = rr.receiveMessage() - if _, ok := msg.(*pgproto3.RowDescription); ok { + err := combineFieldDescriptionsAndResultFormats(rr.fieldDescriptions, sdFields, resultFormats) + if err != nil { + rr.concludeCommand(CommandTag{}, err) + } + } return } } @@ -2604,34 +2610,34 @@ func (p *Pipeline) getResultsQueryStatement() (*ResultReader, error) { return nil, err } + sd, resultFormats := p.state.ExtractFrontStatementData() + if sd == nil { + return nil, errors.New("BUG: missing statement description or result formats for QueryStatement") + } + sdFields := sd.Fields + fieldDescriptions := p.conn.getFieldDescriptionSlice(len(sdFields)) + err = combineFieldDescriptionsAndResultFormats(fieldDescriptions, sdFields, resultFormats) + if err != nil { + return nil, err + } + switch msg := msg.(type) { case *pgproto3.DataRow: - sd, resultFormats := p.state.ExtractFrontStatementData() - if sd == nil { - return nil, errors.New("BUG: missing statement description or result formats for QueryStatement") - } - sdFields := sd.Fields rr := ResultReader{ pgConn: p.conn, pipeline: p, ctx: p.ctx, - fieldDescriptions: p.conn.getFieldDescriptionSlice(len(sdFields)), + fieldDescriptions: fieldDescriptions, } - - err := combineFieldDescriptionsAndResultFormats(rr.fieldDescriptions, sdFields, resultFormats) - if err != nil { - return nil, err - } - rr.preloadRowValues(msg.Values) - p.conn.resultReader = rr return &p.conn.resultReader, nil case *pgproto3.CommandComplete: p.conn.resultReader = ResultReader{ - commandTag: p.conn.makeCommandTag(msg.CommandTag), - commandConcluded: true, - closed: true, + commandTag: p.conn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + fieldDescriptions: fieldDescriptions, } return &p.conn.resultReader, nil case *pgproto3.ErrorResponse: From a4213bc0f81b25b9ee06421b7373d4bae5c57284 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 31 Dec 2025 21:43:47 -0600 Subject: [PATCH 14/15] pgx uses ExecStatement instead of ExecPrepared This reduces redundant protocol messages when the statement description is already known. --- conn.go | 4 ++-- query_test.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index 1fc0b5680..3f6c441cf 100644 --- a/conn.go +++ b/conn.go @@ -610,7 +610,7 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription return pgconn.CommandTag{}, err } - result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + result := c.pgConn.ExecStatement(ctx, sd, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return result.CommandTag, result.Err } @@ -844,7 +844,7 @@ optionLoop: if !explicitPreparedStatement && mode == QueryExecModeCacheDescribe { rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, resultFormats) } else { - rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats) + rows.resultReader = c.pgConn.ExecStatement(ctx, sd, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats) } } else if mode == QueryExecModeExec { err := c.eqb.Build(c.typeMap, nil, args) diff --git a/query_test.go b/query_test.go index e1d5ae995..34d06444f 100644 --- a/query_test.go +++ b/query_test.go @@ -2319,8 +2319,8 @@ func TestQueryNetworkUsage(t *testing.T) { rows.Close() require.NoError(t, rows.Err()) - assert.Equal(t, 651, counterConn.bytesRead) - assert.Equal(t, 434, counterConn.bytesWritten) + assert.Equal(t, 413, counterConn.bytesRead) + assert.Equal(t, 427, counterConn.bytesWritten) ensureConnValid(t, conn) } From c3a17505ff2d016e41002629f57da8bf6859e762 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 31 Dec 2025 21:52:09 -0600 Subject: [PATCH 15/15] pgx batch uses SendQueryStatement --- batch_test.go | 11 ++++------- conn.go | 2 +- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/batch_test.go b/batch_test.go index 88fdccc78..df205ae3c 100644 --- a/batch_test.go +++ b/batch_test.go @@ -1043,15 +1043,12 @@ func TestSendBatchStatementTimeout(t *testing.T) { assert.NoError(t, err) // get pg_sleep results - rows, err := br.Query() - assert.NoError(t, err) + rows, _ := br.Query() // Consume rows and check error - for rows.Next() { - } + rows.Close() err = rows.Err() assert.ErrorContains(t, err, "(SQLSTATE 57014)") - rows.Close() // The last error should be repeated when closing the batch err = br.Close() @@ -1161,8 +1158,8 @@ func TestBatchNetworkUsage(t *testing.T) { err := conn.SendBatch(context.Background(), batch).Close() require.NoError(t, err) - assert.Equal(t, 4116, counterConn.bytesRead) - assert.Equal(t, 1478, counterConn.bytesWritten) + assert.Equal(t, 1736, counterConn.bytesRead) + assert.Equal(t, 1408, counterConn.bytesWritten) ensureConnValid(t, conn) } diff --git a/conn.go b/conn.go index 3f6c441cf..0823c79a7 100644 --- a/conn.go +++ b/conn.go @@ -1236,7 +1236,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d if bi.sd.Name == "" { pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats) } else { - pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + pipeline.SendQueryStatement(bi.sd, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) } }