Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/pgmock/pgmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ func WaitForClose() Step {

func AcceptUnauthenticatedConnRequestSteps() []Step {
return []Step{
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersion30, Parameters: map[string]string{}}),
SendMessage(&pgproto3.AuthenticationOk{}),
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: []byte{0, 0, 0, 0}}),
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
}
}
47 changes: 47 additions & 0 deletions pgconn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ type Config struct {
// that you close on FATAL errors by returning false.
OnPgError PgErrorHandler

// MinProtocolVersion is the minimum acceptable PostgreSQL protocol version.
// If the server does not support at least this version, the connection will fail.
// Valid values: "3.0", "3.2", "latest". Defaults to "3.0".
MinProtocolVersion string

// MaxProtocolVersion is the maximum PostgreSQL protocol version to request from the server.
// Valid values: "3.0", "3.2", "latest". Defaults to "3.0" for compatibility.
MaxProtocolVersion string

createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}

Expand Down Expand Up @@ -213,6 +222,8 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// PGCONNECT_TIMEOUT
// PGTARGETSESSIONATTRS
// PGTZ
// PGMINPROTOCOLVERSION
// PGMAXPROTOCOLVERSION
//
// See http://www.postgresql.org/docs/current/static/libpq-envars.html for details on the meaning of environment variables.
//
Expand Down Expand Up @@ -338,6 +349,8 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
"target_session_attrs": {},
"service": {},
"servicefile": {},
"min_protocol_version": {},
"max_protocol_version": {},
}

// Adding kerberos configuration
Expand Down Expand Up @@ -430,6 +443,27 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
}

minProto, err := parseProtocolVersion(settings["min_protocol_version"])
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "invalid min_protocol_version", err: err}
}
maxProto, err := parseProtocolVersion(settings["max_protocol_version"])
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "invalid max_protocol_version", err: err}
}
if minProto > maxProto {
return nil, &ParseConfigError{ConnString: connString, msg: "min_protocol_version cannot be greater than max_protocol_version"}
}

config.MinProtocolVersion = settings["min_protocol_version"]
config.MaxProtocolVersion = settings["max_protocol_version"]
if config.MinProtocolVersion == "" {
config.MinProtocolVersion = "3.0"
}
if config.MaxProtocolVersion == "" {
config.MaxProtocolVersion = "3.0"
}

return config, nil
}

Expand Down Expand Up @@ -467,6 +501,8 @@ func parseEnvSettings() map[string]string {
"PGSERVICEFILE": "servicefile",
"PGTZ": "timezone",
"PGOPTIONS": "options",
"PGMINPROTOCOLVERSION": "min_protocol_version",
"PGMAXPROTOCOLVERSION": "max_protocol_version",
}

for envname, realname := range nameMap {
Expand Down Expand Up @@ -960,3 +996,14 @@ func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn

return nil
}

func parseProtocolVersion(s string) (uint32, error) {
switch s {
case "", "3.0":
return pgproto3.ProtocolVersion30, nil
case "3.2", "latest":
return pgproto3.ProtocolVersion32, nil
default:
return 0, fmt.Errorf("invalid protocol version: %q", s)
}
}
116 changes: 116 additions & 0 deletions pgconn/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1198,3 +1198,119 @@ func TestParseConfigExplicitEmptyUserDefaultsToOSUser(t *testing.T) {
)
}
}

func TestParseConfigProtocolVersion(t *testing.T) {
tests := []struct {
name string
connString string
envMin string
envMax string
expectedMin string
expectedMax string
expectError bool
expectedErrContain string
}{
{
name: "defaults to 3.0",
connString: "postgres://localhost/test",
expectedMin: "3.0",
expectedMax: "3.0",
},
{
name: "max_protocol_version=3.2",
connString: "postgres://localhost/test?max_protocol_version=3.2",
expectedMin: "3.0",
expectedMax: "3.2",
},
{
name: "min_protocol_version=3.2 and max_protocol_version=3.2",
connString: "postgres://localhost/test?min_protocol_version=3.2&max_protocol_version=3.2",
expectedMin: "3.2",
expectedMax: "3.2",
},
{
name: "max_protocol_version=latest",
connString: "postgres://localhost/test?max_protocol_version=latest",
expectedMin: "3.0",
expectedMax: "latest",
},
{
name: "min and max = latest",
connString: "postgres://localhost/test?min_protocol_version=latest&max_protocol_version=latest",
expectedMin: "latest",
expectedMax: "latest",
},
{
name: "invalid min_protocol_version",
connString: "postgres://localhost/test?min_protocol_version=2.0",
expectError: true,
expectedErrContain: "invalid min_protocol_version",
},
{
name: "invalid max_protocol_version",
connString: "postgres://localhost/test?max_protocol_version=4.0",
expectError: true,
expectedErrContain: "invalid max_protocol_version",
},
{
name: "min > max",
connString: "postgres://localhost/test?min_protocol_version=3.2&max_protocol_version=3.0",
expectError: true,
expectedErrContain: "min_protocol_version cannot be greater than max_protocol_version",
},
{
name: "environment variable PGMINPROTOCOLVERSION without matching max fails",
connString: "postgres://localhost/test",
envMin: "3.2",
expectError: true,
expectedErrContain: "min_protocol_version cannot be greater than max_protocol_version",
},
{
name: "environment variables PGMINPROTOCOLVERSION and PGMAXPROTOCOLVERSION together",
connString: "postgres://localhost/test",
envMin: "3.2",
envMax: "3.2",
expectedMin: "3.2",
expectedMax: "3.2",
},
{
name: "environment variable PGMAXPROTOCOLVERSION",
connString: "postgres://localhost/test",
envMax: "3.2",
expectedMin: "3.0",
expectedMax: "3.2",
},
{
name: "conn string overrides environment variable",
connString: "postgres://localhost/test?max_protocol_version=3.0",
envMax: "3.2",
expectedMin: "3.0",
expectedMax: "3.0",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Clear protocol version env vars
t.Setenv("PGMINPROTOCOLVERSION", "")
t.Setenv("PGMAXPROTOCOLVERSION", "")

if tt.envMin != "" {
t.Setenv("PGMINPROTOCOLVERSION", tt.envMin)
}
if tt.envMax != "" {
t.Setenv("PGMAXPROTOCOLVERSION", tt.envMax)
}

config, err := pgconn.ParseConfig(tt.connString)
if tt.expectError {
require.ErrorContains(t, err, tt.expectedErrContain)
return
}

require.NoError(t, err)
assert.Equal(t, tt.expectedMin, config.MinProtocolVersion, "MinProtocolVersion")
assert.Equal(t, tt.expectedMax, config.MaxProtocolVersion, "MaxProtocolVersion")
})
}
}
29 changes: 22 additions & 7 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ type NotificationHandler func(*PgConn, *Notification)
type PgConn struct {
conn net.Conn
pid uint32 // backend pid
secretKey uint32 // key to use to send a cancel query message to the server
secretKey []byte // key to use to send a cancel query message to the server
parameterStatuses map[string]string // parameters that have been reported by the server
txStatus byte
frontend *pgproto3.Frontend
Expand Down Expand Up @@ -319,6 +319,15 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
return e
}

maxProtocolVersion, err := parseProtocolVersion(config.MaxProtocolVersion)
if err != nil {
return nil, newPerDialConnectError("invalid max_protocol_version", err)
}
minProtocolVersion, err := parseProtocolVersion(config.MinProtocolVersion)
if err != nil {
return nil, newPerDialConnectError("invalid min_protocol_version", err)
}

pgConn.conn, err = config.DialFunc(ctx, connectConfig.network, connectConfig.address)
if err != nil {
return nil, newPerDialConnectError("dial error", err)
Expand Down Expand Up @@ -371,7 +380,7 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn)

startupMsg := pgproto3.StartupMessage{
ProtocolVersion: pgproto3.ProtocolVersionNumber,
ProtocolVersion: maxProtocolVersion,
Parameters: make(map[string]string),
}

Expand Down Expand Up @@ -452,6 +461,12 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
return pgConn, nil
case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse:
// handled by ReceiveMessage
case *pgproto3.NegotiateProtocolVersion:
serverVersion := pgproto3.ProtocolVersion30&0xFFFF0000 | uint32(msg.NewestMinorProtocol)
if serverVersion < minProtocolVersion {
pgConn.conn.Close()
return nil, newPerDialConnectError("server protocol version too low", nil)
}
case *pgproto3.ErrorResponse:
pgConn.conn.Close()
return nil, newPerDialConnectError("server error", ErrorResponseToPgError(msg))
Expand Down Expand Up @@ -641,7 +656,7 @@ func (pgConn *PgConn) TxStatus() byte {
}

// SecretKey returns the backend secret key used to send a cancel query message to the server.
func (pgConn *PgConn) SecretKey() uint32 {
func (pgConn *PgConn) SecretKey() []byte {
return pgConn.secretKey
}

Expand Down Expand Up @@ -1040,11 +1055,11 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
defer contextWatcher.Unwatch()
}

buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16)
buf := make([]byte, 12+len(pgConn.secretKey))
binary.BigEndian.PutUint32(buf[0:4], uint32(len(buf)))
binary.BigEndian.PutUint32(buf[4:8], 80877102)
binary.BigEndian.PutUint32(buf[8:12], pgConn.pid)
binary.BigEndian.PutUint32(buf[12:16], pgConn.secretKey)
copy(buf[12:], pgConn.secretKey)

if _, err := cancelConn.Write(buf); err != nil {
return fmt.Errorf("write to connection for cancellation: %w", err)
Expand Down Expand Up @@ -2077,7 +2092,7 @@ func (pgConn *PgConn) CustomData() map[string]any {
type HijackedConn struct {
Conn net.Conn
PID uint32 // backend pid
SecretKey uint32 // key to use to send a cancel query message to the server
SecretKey []byte // key to use to send a cancel query message to the server
ParameterStatuses map[string]string // parameters that have been reported by the server
TxStatus byte
Frontend *pgproto3.Frontend
Expand Down
41 changes: 38 additions & 3 deletions pgconn/pgconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,10 @@ func TestConnectTimeout(t *testing.T) {
t.Parallel()
script := &pgmock.Script{
Steps: []pgmock.Step{
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersion30, Parameters: map[string]string{}}),
pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
pgmockWaitStep(time.Millisecond * 500),
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: []byte{0, 0, 0, 0}}),
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
},
}
Expand Down Expand Up @@ -4112,7 +4112,7 @@ func TestSNISupport(t *testing.T) {
}

srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)))
srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)))
srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: []byte{0, 0, 0, 0}}).Encode(nil)))
srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)))

serverSNINameChan <- sniHost
Expand Down Expand Up @@ -4553,3 +4553,38 @@ func TestCancelRequestContextWatcherHandler(t *testing.T) {
})
}
}

func TestConnectProtocolVersion32(t *testing.T) {
t.Parallel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
config.MaxProtocolVersion = "3.2"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to use a constant instead of a literal?
Then usages could be found easier in IDEs and Agents.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't done that since I don't see it used for example for ssl mode flags. But I think it's really a project style decision then, more a question for @jackc then I think.


ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

pgConn, err := pgconn.ConnectConfig(ctx, config)
require.NoError(t, err)
defer closeConn(t, pgConn)

if pgConn.ParameterStatus("crdb_version") != "" {
t.Skip("CockroachDB does not support protocol version 3.2 yet")
}

result, err := pgConn.Exec(context.Background(), "show server_version_num").ReadAll()
require.NoError(t, err)
require.Len(t, result, 1)
require.Len(t, result[0].Rows, 1)
require.Len(t, result[0].Rows[0], 1)
pgVersion, err := strconv.Atoi(string(result[0].Rows[0][0]))
require.NoError(t, err)

// Check secret key length - PG18+ returns 32 bytes, older versions return 4
secretKey := pgConn.SecretKey()

if pgVersion < 180000 {
assert.Len(t, secretKey, 4)
} else {
assert.Len(t, secretKey, 32)
}
}
2 changes: 1 addition & 1 deletion pgproto3/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
code := binary.BigEndian.Uint32(buf)

switch code {
case ProtocolVersionNumber:
case ProtocolVersion30, ProtocolVersion32:
err = b.startupMessage.Decode(buf)
if err != nil {
return nil, err
Expand Down
Loading
Loading