From cfa26170e22ff1e96ac9d4d5cd6bc34095541a22 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Wed, 28 Jan 2026 23:24:05 +0100 Subject: [PATCH 1/2] Implement support for protocol 3.2 This adds support for the new protocol 3.2 that is available in Postgres 18+. The specific change is that secret keys are now variable length encoded and no longer a fixed uint32. This is to both improve security and to provide room for additional metadata for middleware. Signed-off-by: Dirkjan Bussink --- internal/pgmock/pgmock.go | 4 +- pgconn/config.go | 47 +++++++ pgconn/config_test.go | 116 ++++++++++++++++ pgconn/pgconn.go | 29 +++- pgconn/pgconn_test.go | 40 +++++- pgproto3/backend.go | 2 +- pgproto3/backend_key_data.go | 33 ++++- pgproto3/backend_key_data_test.go | 87 ++++++++++++ pgproto3/backend_test.go | 28 +++- pgproto3/cancel_request.go | 45 +++++-- pgproto3/cancel_request_test.go | 142 ++++++++++++++++++++ pgproto3/frontend.go | 3 + pgproto3/json_test.go | 46 ++++++- pgproto3/negotiate_protocol_version.go | 86 ++++++++++++ pgproto3/negotiate_protocol_version_test.go | 89 ++++++++++++ pgproto3/startup_message.go | 10 +- 16 files changed, 765 insertions(+), 42 deletions(-) create mode 100644 pgproto3/backend_key_data_test.go create mode 100644 pgproto3/cancel_request_test.go create mode 100644 pgproto3/negotiate_protocol_version.go create mode 100644 pgproto3/negotiate_protocol_version_test.go diff --git a/internal/pgmock/pgmock.go b/internal/pgmock/pgmock.go index c82d7ffc8..f5de5f885 100644 --- a/internal/pgmock/pgmock.go +++ b/internal/pgmock/pgmock.go @@ -128,9 +128,9 @@ func WaitForClose() Step { func AcceptUnauthenticatedConnRequestSteps() []Step { return []Step{ - ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersion30, Parameters: map[string]string{}}), SendMessage(&pgproto3.AuthenticationOk{}), - SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: []byte{0, 0, 0, 0}}), SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), } } diff --git a/pgconn/config.go b/pgconn/config.go index 42bc2e92b..d61187bd6 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -83,6 +83,15 @@ type Config struct { // that you close on FATAL errors by returning false. OnPgError PgErrorHandler + // MinProtocolVersion is the minimum acceptable PostgreSQL protocol version. + // If the server does not support at least this version, the connection will fail. + // Valid values: "3.0", "3.2", "latest". Defaults to "3.0". + MinProtocolVersion string + + // MaxProtocolVersion is the maximum PostgreSQL protocol version to request from the server. + // Valid values: "3.0", "3.2", "latest". Defaults to "3.0" for compatibility. + MaxProtocolVersion string + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -213,6 +222,8 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGCONNECT_TIMEOUT // PGTARGETSESSIONATTRS // PGTZ +// PGMINPROTOCOLVERSION +// PGMAXPROTOCOLVERSION // // See http://www.postgresql.org/docs/current/static/libpq-envars.html for details on the meaning of environment variables. // @@ -338,6 +349,8 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con "target_session_attrs": {}, "service": {}, "servicefile": {}, + "min_protocol_version": {}, + "max_protocol_version": {}, } // Adding kerberos configuration @@ -430,6 +443,27 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} } + minProto, err := parseProtocolVersion(settings["min_protocol_version"]) + if err != nil { + return nil, &ParseConfigError{ConnString: connString, msg: "invalid min_protocol_version", err: err} + } + maxProto, err := parseProtocolVersion(settings["max_protocol_version"]) + if err != nil { + return nil, &ParseConfigError{ConnString: connString, msg: "invalid max_protocol_version", err: err} + } + if minProto > maxProto { + return nil, &ParseConfigError{ConnString: connString, msg: "min_protocol_version cannot be greater than max_protocol_version"} + } + + config.MinProtocolVersion = settings["min_protocol_version"] + config.MaxProtocolVersion = settings["max_protocol_version"] + if config.MinProtocolVersion == "" { + config.MinProtocolVersion = "3.0" + } + if config.MaxProtocolVersion == "" { + config.MaxProtocolVersion = "3.0" + } + return config, nil } @@ -467,6 +501,8 @@ func parseEnvSettings() map[string]string { "PGSERVICEFILE": "servicefile", "PGTZ": "timezone", "PGOPTIONS": "options", + "PGMINPROTOCOLVERSION": "min_protocol_version", + "PGMAXPROTOCOLVERSION": "max_protocol_version", } for envname, realname := range nameMap { @@ -960,3 +996,14 @@ func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn return nil } + +func parseProtocolVersion(s string) (uint32, error) { + switch s { + case "", "3.0": + return pgproto3.ProtocolVersion30, nil + case "3.2", "latest": + return pgproto3.ProtocolVersion32, nil + default: + return 0, fmt.Errorf("invalid protocol version: %q", s) + } +} diff --git a/pgconn/config_test.go b/pgconn/config_test.go index 9665a6e15..2a2b1e284 100644 --- a/pgconn/config_test.go +++ b/pgconn/config_test.go @@ -1198,3 +1198,119 @@ func TestParseConfigExplicitEmptyUserDefaultsToOSUser(t *testing.T) { ) } } + +func TestParseConfigProtocolVersion(t *testing.T) { + tests := []struct { + name string + connString string + envMin string + envMax string + expectedMin string + expectedMax string + expectError bool + expectedErrContain string + }{ + { + name: "defaults to 3.0", + connString: "postgres://localhost/test", + expectedMin: "3.0", + expectedMax: "3.0", + }, + { + name: "max_protocol_version=3.2", + connString: "postgres://localhost/test?max_protocol_version=3.2", + expectedMin: "3.0", + expectedMax: "3.2", + }, + { + name: "min_protocol_version=3.2 and max_protocol_version=3.2", + connString: "postgres://localhost/test?min_protocol_version=3.2&max_protocol_version=3.2", + expectedMin: "3.2", + expectedMax: "3.2", + }, + { + name: "max_protocol_version=latest", + connString: "postgres://localhost/test?max_protocol_version=latest", + expectedMin: "3.0", + expectedMax: "latest", + }, + { + name: "min and max = latest", + connString: "postgres://localhost/test?min_protocol_version=latest&max_protocol_version=latest", + expectedMin: "latest", + expectedMax: "latest", + }, + { + name: "invalid min_protocol_version", + connString: "postgres://localhost/test?min_protocol_version=2.0", + expectError: true, + expectedErrContain: "invalid min_protocol_version", + }, + { + name: "invalid max_protocol_version", + connString: "postgres://localhost/test?max_protocol_version=4.0", + expectError: true, + expectedErrContain: "invalid max_protocol_version", + }, + { + name: "min > max", + connString: "postgres://localhost/test?min_protocol_version=3.2&max_protocol_version=3.0", + expectError: true, + expectedErrContain: "min_protocol_version cannot be greater than max_protocol_version", + }, + { + name: "environment variable PGMINPROTOCOLVERSION without matching max fails", + connString: "postgres://localhost/test", + envMin: "3.2", + expectError: true, + expectedErrContain: "min_protocol_version cannot be greater than max_protocol_version", + }, + { + name: "environment variables PGMINPROTOCOLVERSION and PGMAXPROTOCOLVERSION together", + connString: "postgres://localhost/test", + envMin: "3.2", + envMax: "3.2", + expectedMin: "3.2", + expectedMax: "3.2", + }, + { + name: "environment variable PGMAXPROTOCOLVERSION", + connString: "postgres://localhost/test", + envMax: "3.2", + expectedMin: "3.0", + expectedMax: "3.2", + }, + { + name: "conn string overrides environment variable", + connString: "postgres://localhost/test?max_protocol_version=3.0", + envMax: "3.2", + expectedMin: "3.0", + expectedMax: "3.0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Clear protocol version env vars + t.Setenv("PGMINPROTOCOLVERSION", "") + t.Setenv("PGMAXPROTOCOLVERSION", "") + + if tt.envMin != "" { + t.Setenv("PGMINPROTOCOLVERSION", tt.envMin) + } + if tt.envMax != "" { + t.Setenv("PGMAXPROTOCOLVERSION", tt.envMax) + } + + config, err := pgconn.ParseConfig(tt.connString) + if tt.expectError { + require.ErrorContains(t, err, tt.expectedErrContain) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expectedMin, config.MinProtocolVersion, "MinProtocolVersion") + assert.Equal(t, tt.expectedMax, config.MaxProtocolVersion, "MaxProtocolVersion") + }) + } +} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index ec27ec34d..ab87db6f2 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -77,7 +77,7 @@ type NotificationHandler func(*PgConn, *Notification) type PgConn struct { conn net.Conn pid uint32 // backend pid - secretKey uint32 // key to use to send a cancel query message to the server + secretKey []byte // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server txStatus byte frontend *pgproto3.Frontend @@ -319,6 +319,15 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo return e } + maxProtocolVersion, err := parseProtocolVersion(config.MaxProtocolVersion) + if err != nil { + return nil, newPerDialConnectError("invalid max_protocol_version", err) + } + minProtocolVersion, err := parseProtocolVersion(config.MinProtocolVersion) + if err != nil { + return nil, newPerDialConnectError("invalid min_protocol_version", err) + } + pgConn.conn, err = config.DialFunc(ctx, connectConfig.network, connectConfig.address) if err != nil { return nil, newPerDialConnectError("dial error", err) @@ -371,7 +380,7 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn) startupMsg := pgproto3.StartupMessage{ - ProtocolVersion: pgproto3.ProtocolVersionNumber, + ProtocolVersion: maxProtocolVersion, Parameters: make(map[string]string), } @@ -452,6 +461,12 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo return pgConn, nil case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse: // handled by ReceiveMessage + case *pgproto3.NegotiateProtocolVersion: + serverVersion := pgproto3.ProtocolVersion30&0xFFFF0000 | uint32(msg.NewestMinorProtocol) + if serverVersion < minProtocolVersion { + pgConn.conn.Close() + return nil, newPerDialConnectError("server protocol version too low", nil) + } case *pgproto3.ErrorResponse: pgConn.conn.Close() return nil, newPerDialConnectError("server error", ErrorResponseToPgError(msg)) @@ -641,7 +656,7 @@ func (pgConn *PgConn) TxStatus() byte { } // SecretKey returns the backend secret key used to send a cancel query message to the server. -func (pgConn *PgConn) SecretKey() uint32 { +func (pgConn *PgConn) SecretKey() []byte { return pgConn.secretKey } @@ -1040,11 +1055,11 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { defer contextWatcher.Unwatch() } - buf := make([]byte, 16) - binary.BigEndian.PutUint32(buf[0:4], 16) + buf := make([]byte, 12+len(pgConn.secretKey)) + binary.BigEndian.PutUint32(buf[0:4], uint32(len(buf))) binary.BigEndian.PutUint32(buf[4:8], 80877102) binary.BigEndian.PutUint32(buf[8:12], pgConn.pid) - binary.BigEndian.PutUint32(buf[12:16], pgConn.secretKey) + copy(buf[12:], pgConn.secretKey) if _, err := cancelConn.Write(buf); err != nil { return fmt.Errorf("write to connection for cancellation: %w", err) @@ -2077,7 +2092,7 @@ func (pgConn *PgConn) CustomData() map[string]any { type HijackedConn struct { Conn net.Conn PID uint32 // backend pid - SecretKey uint32 // key to use to send a cancel query message to the server + SecretKey []byte // key to use to send a cancel query message to the server ParameterStatuses map[string]string // parameters that have been reported by the server TxStatus byte Frontend *pgproto3.Frontend diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 3fd43a031..448de49e6 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -215,10 +215,10 @@ func TestConnectTimeout(t *testing.T) { t.Parallel() script := &pgmock.Script{ Steps: []pgmock.Step{ - pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersion30, Parameters: map[string]string{}}), pgmock.SendMessage(&pgproto3.AuthenticationOk{}), pgmockWaitStep(time.Millisecond * 500), - pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: []byte{0, 0, 0, 0}}), pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), }, } @@ -4112,7 +4112,7 @@ func TestSNISupport(t *testing.T) { } srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))) - srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))) + srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: []byte{0, 0, 0, 0}}).Encode(nil))) srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))) serverSNINameChan <- sniHost @@ -4553,3 +4553,37 @@ func TestCancelRequestContextWatcherHandler(t *testing.T) { }) } } + +func TestConnectProtocolVersion32(t *testing.T) { + t.Parallel() + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE") + " max_protocol_version=3.2") + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("CockroachDB does not support protocol version 3.2 yet") + } + + result, err := pgConn.Exec(context.Background(), "show server_version_num").ReadAll() + require.NoError(t, err) + require.Len(t, result, 1) + require.Len(t, result[0].Rows, 1) + require.Len(t, result[0].Rows[0], 1) + pgVersion, err := strconv.Atoi(string(result[0].Rows[0][0])) + require.NoError(t, err) + + // Check secret key length - PG18+ returns 32 bytes, older versions return 4 + secretKey := pgConn.SecretKey() + + if pgVersion < 180000 { + assert.Len(t, secretKey, 4) + } else { + assert.Len(t, secretKey, 32) + } +} diff --git a/pgproto3/backend.go b/pgproto3/backend.go index d9d0f370c..e211c99b5 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -137,7 +137,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { code := binary.BigEndian.Uint32(buf) switch code { - case ProtocolVersionNumber: + case ProtocolVersion30, ProtocolVersion32: err = b.startupMessage.Decode(buf) if err != nil { return nil, err diff --git a/pgproto3/backend_key_data.go b/pgproto3/backend_key_data.go index 23f5da677..c73b2da0c 100644 --- a/pgproto3/backend_key_data.go +++ b/pgproto3/backend_key_data.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/hex" "encoding/json" "github.com/jackc/pgx/v5/internal/pgio" @@ -9,7 +10,7 @@ import ( type BackendKeyData struct { ProcessID uint32 - SecretKey uint32 + SecretKey []byte } // Backend identifies this message as sendable by the PostgreSQL backend. @@ -18,12 +19,13 @@ func (*BackendKeyData) Backend() {} // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *BackendKeyData) Decode(src []byte) error { - if len(src) != 8 { + if len(src) < 8 { return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} } dst.ProcessID = binary.BigEndian.Uint32(src[:4]) - dst.SecretKey = binary.BigEndian.Uint32(src[4:]) + dst.SecretKey = make([]byte, len(src)-4) + copy(dst.SecretKey, src[4:]) return nil } @@ -32,7 +34,7 @@ func (dst *BackendKeyData) Decode(src []byte) error { func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 'K') dst = pgio.AppendUint32(dst, src.ProcessID) - dst = pgio.AppendUint32(dst, src.SecretKey) + dst = append(dst, src.SecretKey...) return finishMessage(dst, sp) } @@ -41,10 +43,29 @@ func (src BackendKeyData) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ProcessID uint32 - SecretKey uint32 + SecretKey string }{ Type: "BackendKeyData", ProcessID: src.ProcessID, - SecretKey: src.SecretKey, + SecretKey: hex.EncodeToString(src.SecretKey), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *BackendKeyData) UnmarshalJSON(data []byte) error { + var msg struct { + ProcessID uint32 + SecretKey string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.ProcessID = msg.ProcessID + secretKey, err := hex.DecodeString(msg.SecretKey) + if err != nil { + return err + } + dst.SecretKey = secretKey + return nil +} diff --git a/pgproto3/backend_key_data_test.go b/pgproto3/backend_key_data_test.go new file mode 100644 index 000000000..ab52a2f7f --- /dev/null +++ b/pgproto3/backend_key_data_test.go @@ -0,0 +1,87 @@ +package pgproto3 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBackendKeyDataDecodeProtocol30(t *testing.T) { + // Protocol 3.0: 8 bytes (4 for ProcessID, 4 for SecretKey) + src := []byte{ + 0x00, 0x00, 0x22, 0xA0, // ProcessID: 8864 + 0xD9, 0x0C, 0xAE, 0xDB, // SecretKey + } + + var msg BackendKeyData + err := msg.Decode(src) + require.NoError(t, err) + assert.Equal(t, uint32(8864), msg.ProcessID) + expectedKey := []byte{0xD9, 0x0C, 0xAE, 0xDB} + assert.Equal(t, expectedKey, msg.SecretKey) +} + +func TestBackendKeyDataDecodeProtocol32(t *testing.T) { + // Protocol 3.2: variable-length key (using 32 bytes here) + secretKey := []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, + } + + src := append([]byte{0x00, 0x00, 0x22, 0xA0}, secretKey...) // ProcessID: 8864 + + var msg BackendKeyData + err := msg.Decode(src) + require.NoError(t, err) + + assert.Equal(t, uint32(8864), msg.ProcessID) + assert.Equal(t, secretKey, msg.SecretKey) +} + +func TestBackendKeyDataEncodeProtocol30(t *testing.T) { + msg := BackendKeyData{ + ProcessID: 8864, + SecretKey: []byte{0xD9, 0x0C, 0xAE, 0xDB}, + } + + buf, err := msg.Encode(nil) + require.NoError(t, err) + + expected := []byte{ + 'K', // message type + 0x00, 0x00, 0x00, 0x0C, // length: 12 (4 + 8) + 0x00, 0x00, 0x22, 0xA0, // ProcessID: 8864 + 0xD9, 0x0C, 0xAE, 0xDB, // SecretKey + } + + assert.Equal(t, expected, buf) +} + +func TestBackendKeyDataEncodeProtocol32(t *testing.T) { + secretKey := []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, + } + + msg := BackendKeyData{ + ProcessID: 8864, + SecretKey: secretKey, + } + + buf, err := msg.Encode(nil) + require.NoError(t, err) + + // 'K' + 4 byte length + 4 byte ProcessID + 32 byte SecretKey = 41 bytes total, length field is 40 + expected := append([]byte{ + 'K', // message type + 0x00, 0x00, 0x00, 0x28, // length: 40 (4 + 4 + 32) + 0x00, 0x00, 0x22, 0xA0, // ProcessID: 8864 + }, secretKey...) + + assert.Equal(t, expected, buf) +} diff --git a/pgproto3/backend_test.go b/pgproto3/backend_test.go index 5107ef76a..17e347835 100644 --- a/pgproto3/backend_test.go +++ b/pgproto3/backend_test.go @@ -64,12 +64,28 @@ func TestBackendReceiveUnexpectedEOF(t *testing.T) { func TestStartupMessage(t *testing.T) { t.Parallel() - t.Run("valid StartupMessage", func(t *testing.T) { + t.Run("valid StartupMessage 3.0", func(t *testing.T) { want := &pgproto3.StartupMessage{ - ProtocolVersion: pgproto3.ProtocolVersionNumber, - Parameters: map[string]string{ - "username": "tester", - }, + ProtocolVersion: pgproto3.ProtocolVersion30, + Parameters: map[string]string{"username": "tester"}, + } + dst, err := want.Encode([]byte{}) + require.NoError(t, err) + + server := &interruptReader{} + server.push(dst) + + backend := pgproto3.NewBackend(server, nil) + + msg, err := backend.ReceiveStartupMessage() + require.NoError(t, err) + require.Equal(t, want, msg) + }) + + t.Run("valid StartupMessage 3.2", func(t *testing.T) { + want := &pgproto3.StartupMessage{ + ProtocolVersion: pgproto3.ProtocolVersion32, + Parameters: map[string]string{"username": "tester"}, } dst, err := want.Encode([]byte{}) require.NoError(t, err) @@ -107,7 +123,7 @@ func TestStartupMessage(t *testing.T) { server := &interruptReader{} dst := []byte{} dst = pgio.AppendUint32(dst, tt.packetLen) - dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber) + dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersion30) server.push(dst) backend := pgproto3.NewBackend(server, nil) diff --git a/pgproto3/cancel_request.go b/pgproto3/cancel_request.go index 6b52dd977..63ebe5c47 100644 --- a/pgproto3/cancel_request.go +++ b/pgproto3/cancel_request.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/hex" "encoding/json" "errors" @@ -12,35 +13,42 @@ const cancelRequestCode = 80877102 type CancelRequest struct { ProcessID uint32 - SecretKey uint32 + SecretKey []byte } // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*CancelRequest) Frontend() {} func (dst *CancelRequest) Decode(src []byte) error { - if len(src) != 12 { - return errors.New("bad cancel request size") + if len(src) < 12 { + return errors.New("cancel request too short") + } + if len(src) > 264 { + return errors.New("cancel request too long") } requestCode := binary.BigEndian.Uint32(src) - if requestCode != cancelRequestCode { return errors.New("bad cancel request code") } dst.ProcessID = binary.BigEndian.Uint32(src[4:]) - dst.SecretKey = binary.BigEndian.Uint32(src[8:]) + dst.SecretKey = make([]byte, len(src)-8) + copy(dst.SecretKey, src[8:]) return nil } // Encode encodes src into dst. dst will include the 4 byte message length. func (src *CancelRequest) Encode(dst []byte) ([]byte, error) { - dst = pgio.AppendInt32(dst, 16) + if len(src.SecretKey) > 256 { + return nil, errors.New("secret key too long") + } + msgLen := int32(12 + len(src.SecretKey)) + dst = pgio.AppendInt32(dst, msgLen) dst = pgio.AppendInt32(dst, cancelRequestCode) dst = pgio.AppendUint32(dst, src.ProcessID) - dst = pgio.AppendUint32(dst, src.SecretKey) + dst = append(dst, src.SecretKey...) return dst, nil } @@ -49,10 +57,29 @@ func (src CancelRequest) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ProcessID uint32 - SecretKey uint32 + SecretKey string }{ Type: "CancelRequest", ProcessID: src.ProcessID, - SecretKey: src.SecretKey, + SecretKey: hex.EncodeToString(src.SecretKey), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CancelRequest) UnmarshalJSON(data []byte) error { + var msg struct { + ProcessID uint32 + SecretKey string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.ProcessID = msg.ProcessID + secretKey, err := hex.DecodeString(msg.SecretKey) + if err != nil { + return err + } + dst.SecretKey = secretKey + return nil +} diff --git a/pgproto3/cancel_request_test.go b/pgproto3/cancel_request_test.go new file mode 100644 index 000000000..488e0b7b0 --- /dev/null +++ b/pgproto3/cancel_request_test.go @@ -0,0 +1,142 @@ +package pgproto3 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCancelRequestDecode(t *testing.T) { + secretKey32 := []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, + } + + tests := []struct { + name string + src []byte + expectedProcessID uint32 + expectedSecretKey []byte + expectError bool + }{ + { + name: "Protocol 3.0 (16 bytes total)", + src: []byte{ + 0x04, 0xD2, 0x16, 0x2E, // cancelRequestCode: 80877102 + 0x00, 0x00, 0x22, 0xA0, // ProcessID: 8864 + 0xD9, 0x0C, 0xAE, 0xDB, // SecretKey + }, + expectedProcessID: 8864, + expectedSecretKey: []byte{0xD9, 0x0C, 0xAE, 0xDB}, + }, + { + name: "Protocol 3.2 (variable-length 32-byte key)", + src: append([]byte{ + 0x04, 0xD2, 0x16, 0x2E, // cancelRequestCode: 80877102 + 0x00, 0x00, 0x22, 0xA0, // ProcessID: 8864 + }, secretKey32...), + expectedProcessID: 8864, + expectedSecretKey: secretKey32, + }, + { + name: "invalid length (too short)", + src: []byte{ + 0x00, 0x00, 0x00, 0x00, // invalid length + 0x00, 0x00, 0x22, 0xA0, // ProcessID: 8864 + 0xD9, 0x0C, 0xAE, 0xDB, // SecretKey + }, + expectError: true, + }, + { + name: "invalid length (too long)", + src: append([]byte{ + 0x00, 0x00, 0x01, 0x09, // invalid length: 265 + 0x04, 0xD2, 0x16, 0x2E, // cancelRequestCode: 80877102 + 0x00, 0x00, 0x22, 0xA0, // ProcessID: 8864 + }, make([]byte, 257)...), // 257 bytes secret key (1 byte too many) + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var msg CancelRequest + err := msg.Decode(tt.src) + + if tt.expectError { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expectedProcessID, msg.ProcessID) + assert.Equal(t, tt.expectedSecretKey, msg.SecretKey) + }) + } +} + +func TestCancelRequestEncode(t *testing.T) { + secretKey32 := []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, + } + + tests := []struct { + name string + msg CancelRequest + expected []byte + expectError bool + }{ + { + name: "Protocol 3.0 (4-byte key)", + msg: CancelRequest{ + ProcessID: 8864, + SecretKey: []byte{0xD9, 0x0C, 0xAE, 0xDB}, + }, + expected: []byte{ + 0x00, 0x00, 0x00, 0x10, // length: 16 + 0x04, 0xD2, 0x16, 0x2E, // cancelRequestCode: 80877102 + 0x00, 0x00, 0x22, 0xA0, // ProcessID: 8864 + 0xD9, 0x0C, 0xAE, 0xDB, // SecretKey + }, + }, + { + name: "Protocol 3.2 (32-byte key)", + msg: CancelRequest{ + ProcessID: 8864, + SecretKey: secretKey32, + }, + // 4 byte length + 4 byte code + 4 byte ProcessID + 32 byte SecretKey = 44 bytes total + expected: append([]byte{ + 0x00, 0x00, 0x00, 0x2C, // length: 44 (12 + 32) + 0x04, 0xD2, 0x16, 0x2E, // cancelRequestCode: 80877102 + 0x00, 0x00, 0x22, 0xA0, // ProcessID: 8864 + }, secretKey32...), + }, + { + name: "Too long secret key", + msg: CancelRequest{ + ProcessID: 8864, + SecretKey: make([]byte, 257), // 1 byte too many + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf, err := tt.msg.Encode(nil) + if tt.expectError { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.expected, buf) + }) + } +} diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 056e547cd..23e10aadc 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -52,6 +52,7 @@ type Frontend struct { readyForQuery ReadyForQuery rowDescription RowDescription portalSuspended PortalSuspended + negotiateProtocolVersion NegotiateProtocolVersion bodyLen int maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error. @@ -383,6 +384,8 @@ func (f *Frontend) Receive() (BackendMessage, error) { msg = &f.copyBothResponse case 'Z': msg = &f.readyForQuery + case 'v': + msg = &f.negotiateProtocolVersion default: return nil, fmt.Errorf("unknown message type: %c", f.msgType) } diff --git a/pgproto3/json_test.go b/pgproto3/json_test.go index 677221249..bdf8574ce 100644 --- a/pgproto3/json_test.go +++ b/pgproto3/json_test.go @@ -93,11 +93,29 @@ func TestJSONUnmarshalAuthenticationSASLFinal(t *testing.T) { } } -func TestJSONUnmarshalBackendKeyData(t *testing.T) { - data := []byte(`{"Type":"BackendKeyData","ProcessID":8864,"SecretKey":3641487067}`) +func TestJSONUnmarshalBackendKeyData30(t *testing.T) { + // SecretKey is now hex-encoded: d90caedb = 3641487067 in big-endian + data := []byte(`{"Type":"BackendKeyData","ProcessID":8864,"SecretKey":"d90caedb"}`) want := BackendKeyData{ ProcessID: 8864, - SecretKey: 3641487067, + SecretKey: []byte{0xd9, 0x0c, 0xae, 0xdb}, + } + + var got BackendKeyData + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled BackendKeyData struct doesn't match expected value") + } +} + +func TestJSONUnmarshalBackendKeyData32(t *testing.T) { + // 32-byte key as hex + data := []byte(`{"Type":"BackendKeyData","ProcessID":8864,"SecretKey":"0102030405060708091011121314151617181920212223242526272829303132"}`) + want := BackendKeyData{ + ProcessID: 8864, + SecretKey: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x30, 0x31, 0x32}, } var got BackendKeyData @@ -367,10 +385,28 @@ func TestJSONUnmarshalBind(t *testing.T) { } func TestJSONUnmarshalCancelRequest(t *testing.T) { - data := []byte(`{"Type":"CancelRequest","ProcessID":8864,"SecretKey":3641487067}`) + // SecretKey is now hex-encoded: d90caedb = 3641487067 in big-endian + data := []byte(`{"Type":"CancelRequest","ProcessID":8864,"SecretKey":"d90caedb"}`) + want := CancelRequest{ + ProcessID: 8864, + SecretKey: []byte{0xd9, 0x0c, 0xae, 0xdb}, + } + + var got CancelRequest + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CancelRequest struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCancelRequestLongKey(t *testing.T) { + // 32-byte key as hex + data := []byte(`{"Type":"CancelRequest","ProcessID":8864,"SecretKey":"0102030405060708091011121314151617181920212223242526272829303132"}`) want := CancelRequest{ ProcessID: 8864, - SecretKey: 3641487067, + SecretKey: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x30, 0x31, 0x32}, } var got CancelRequest diff --git a/pgproto3/negotiate_protocol_version.go b/pgproto3/negotiate_protocol_version.go new file mode 100644 index 000000000..6e96554f2 --- /dev/null +++ b/pgproto3/negotiate_protocol_version.go @@ -0,0 +1,86 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type NegotiateProtocolVersion struct { + NewestMinorProtocol uint32 + UnrecognizedOptions []string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*NegotiateProtocolVersion) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *NegotiateProtocolVersion) Decode(src []byte) error { + if len(src) < 8 { + return &invalidMessageLenErr{messageType: "NegotiateProtocolVersion", expectedLen: 8, actualLen: len(src)} + } + + dst.NewestMinorProtocol = binary.BigEndian.Uint32(src[:4]) + optionCount := int(binary.BigEndian.Uint32(src[4:8])) + + rp := 8 + dst.UnrecognizedOptions = make([]string, 0, optionCount) + for i := 0; i < optionCount; i++ { + if rp >= len(src) { + return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"} + } + end := rp + for end < len(src) && src[end] != 0 { + end++ + } + if end >= len(src) { + return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"} + } + dst.UnrecognizedOptions = append(dst.UnrecognizedOptions, string(src[rp:end])) + rp = end + 1 + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *NegotiateProtocolVersion) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'v') + dst = pgio.AppendUint32(dst, src.NewestMinorProtocol) + dst = pgio.AppendUint32(dst, uint32(len(src.UnrecognizedOptions))) + for _, option := range src.UnrecognizedOptions { + dst = append(dst, option...) + dst = append(dst, 0) + } + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src NegotiateProtocolVersion) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + NewestMinorProtocol uint32 + UnrecognizedOptions []string + }{ + Type: "NegotiateProtocolVersion", + NewestMinorProtocol: src.NewestMinorProtocol, + UnrecognizedOptions: src.UnrecognizedOptions, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *NegotiateProtocolVersion) UnmarshalJSON(data []byte) error { + var msg struct { + NewestMinorProtocol uint32 + UnrecognizedOptions []string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.NewestMinorProtocol = msg.NewestMinorProtocol + dst.UnrecognizedOptions = msg.UnrecognizedOptions + return nil +} diff --git a/pgproto3/negotiate_protocol_version_test.go b/pgproto3/negotiate_protocol_version_test.go new file mode 100644 index 000000000..41469e3f7 --- /dev/null +++ b/pgproto3/negotiate_protocol_version_test.go @@ -0,0 +1,89 @@ +package pgproto3 + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNegotiateProtocolVersionDecode(t *testing.T) { + src := []byte{ + 0x00, 0x00, 0x00, 0x00, // NewestMinorProtocol: 0 + 0x00, 0x00, 0x00, 0x02, // Option count: 2 + 'o', 'p', 't', '1', 0x00, // "opt1" + 'o', 'p', 't', '2', 0x00, // "opt2" + } + + var msg NegotiateProtocolVersion + err := msg.Decode(src) + require.NoError(t, err) + + assert.Equal(t, uint32(0), msg.NewestMinorProtocol) + assert.Equal(t, []string{"opt1", "opt2"}, msg.UnrecognizedOptions) +} + +func TestNegotiateProtocolVersionDecodeNoOptions(t *testing.T) { + // Message: minor version 2, no unrecognized options + src := []byte{ + 0x00, 0x00, 0x00, 0x02, // NewestMinorProtocol: 2 + 0x00, 0x00, 0x00, 0x00, // Option count: 0 + } + + var msg NegotiateProtocolVersion + err := msg.Decode(src) + require.NoError(t, err) + assert.Equal(t, uint32(2), msg.NewestMinorProtocol) + assert.Equal(t, 0, len(msg.UnrecognizedOptions)) +} + +func TestNegotiateProtocolVersionEncode(t *testing.T) { + msg := NegotiateProtocolVersion{ + NewestMinorProtocol: 0, + UnrecognizedOptions: []string{"opt1", "opt2"}, + } + + buf, err := msg.Encode(nil) + require.NoError(t, err) + + expected := []byte{ + 'v', // message type + 0x00, 0x00, 0x00, 0x16, // length: 22 (4 for length + 4 + 4 + 5 + 5) + 0x00, 0x00, 0x00, 0x00, // NewestMinorProtocol: 0 + 0x00, 0x00, 0x00, 0x02, // Option count: 2 + 'o', 'p', 't', '1', 0x00, + 'o', 'p', 't', '2', 0x00, + } + + require.Equal(t, expected, buf) +} + +func TestNegotiateProtocolVersionJSON(t *testing.T) { + msg := NegotiateProtocolVersion{ + NewestMinorProtocol: 0, + UnrecognizedOptions: []string{"opt1"}, + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var decoded NegotiateProtocolVersion + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, msg, decoded) +} + +func TestJSONUnmarshalNegotiateProtocolVersion(t *testing.T) { + data := []byte(`{"Type":"NegotiateProtocolVersion","NewestMinorProtocol":0,"UnrecognizedOptions":["opt1"]}`) + want := NegotiateProtocolVersion{ + NewestMinorProtocol: 0, + UnrecognizedOptions: []string{"opt1"}, + } + + var got NegotiateProtocolVersion + err := json.Unmarshal(data, &got) + require.NoError(t, err) + assert.Equal(t, want, got) +} diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go index 3af4587d8..6caab3ee4 100644 --- a/pgproto3/startup_message.go +++ b/pgproto3/startup_message.go @@ -10,7 +10,11 @@ import ( "github.com/jackc/pgx/v5/internal/pgio" ) -const ProtocolVersionNumber = 196608 // 3.0 +const ( + ProtocolVersion30 = 196608 // 3.0 + ProtocolVersion32 = 196610 // 3.2 + ProtocolVersionNumber = ProtocolVersion30 // Default is still 3.0 +) type StartupMessage struct { ProtocolVersion uint32 @@ -30,8 +34,8 @@ func (dst *StartupMessage) Decode(src []byte) error { dst.ProtocolVersion = binary.BigEndian.Uint32(src) rp := 4 - if dst.ProtocolVersion != ProtocolVersionNumber { - return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) + if dst.ProtocolVersion != ProtocolVersion30 && dst.ProtocolVersion != ProtocolVersion32 { + return fmt.Errorf("Bad startup message version number. Expected %d or %d, got %d", ProtocolVersion30, ProtocolVersion32, dst.ProtocolVersion) } dst.Parameters = make(map[string]string) From 7a143fdafa8028267f7d14cdfff7bbeda8cf4440 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Wed, 11 Feb 2026 20:04:31 +0100 Subject: [PATCH 2/2] Set protocol version separate Signed-off-by: Dirkjan Bussink --- pgconn/pgconn_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 448de49e6..0c8562def 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -4556,8 +4556,9 @@ func TestCancelRequestContextWatcherHandler(t *testing.T) { func TestConnectProtocolVersion32(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE") + " max_protocol_version=3.2") + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) + config.MaxProtocolVersion = "3.2" ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel()