From 7e90184c7bb898953b8c234ccd8965e76fb2a450 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Mon, 22 Dec 2025 16:52:53 -0800 Subject: [PATCH 1/6] init --- pgutils/connector.go | 62 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/pgutils/connector.go b/pgutils/connector.go index 21dce91..f757874 100644 --- a/pgutils/connector.go +++ b/pgutils/connector.go @@ -6,6 +6,8 @@ import ( "fmt" "log" "net/url" + "slices" + "strings" "time" "database/sql" @@ -20,6 +22,8 @@ import ( "github.com/lib/pq" ) +const defaultPostgresPort = "5432" + type baseConnectionStringProvider interface { getBaseConnectionString(ctx context.Context) (string, error) } @@ -212,3 +216,61 @@ func MustConnectDB(conn *PostgresqlConnector) *sqlx.DB { return db } +func NewPostgresqlConnectorFromDSN(ctx context.Context, dsn string) (*PostgresqlConnector, error) { + u, err := url.Parse(dsn) + if err != nil || u.Scheme != "postgres+rds-iam" { + // Not our custom scheme: hand off to existing DSN handling. + return NewPostgresqlConnectorFromConnectionString(dsn), nil + } + + user := "" + if u.User != nil { + user = u.User.Username() + if _, hasPw := u.User.Password(); hasPw { + return nil, fmt.Errorf("postgres+rds-iam DSN must not include a password") + } + } + if user == "" { + return nil, fmt.Errorf("postgres+rds-iam DSN missing username") + } + + host := u.Hostname() + if host == "" { + return nil, fmt.Errorf("postgres+rds-iam DSN missing host") + } + + port := u.Port() + if port == "" { + port = defaultPostgresPort + } + + // Match libpq/psql defaulting: if dbname isn't specified, dbname defaults to username. + dbName := strings.TrimPrefix(u.Path, "/") + if dbName == "" { + dbName = user + } + + q := u.Query() + supportedParams := []string{"assume_role_arn", "assume_role_session_name"} + for k := range q { + if !slices.Contains(supportedParams, k) { + return nil, fmt.Errorf("postgres+rds-iam DSN has unsupported query parameter: %s", k) + } + } + + assumeRoleARN := q.Get("assume_role_arn") + assumeRoleSessionName := q.Get("assume_role_session_name") + + cfg := &IAMAuthConfig{ + RDSEndpoint: host + ":" + port, + User: user, + Database: dbName, + } + + if assumeRoleARN != "" { + cfg.AssumeRoleARN = assumeRoleARN + cfg.AssumeRoleSessionName = assumeRoleSessionName + } + + return NewPostgresqlConnectorWithIAMAuth(ctx, cfg) +} From fe2a24e573261cad057f0bc044679cb364e33fef Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Tue, 23 Dec 2025 02:47:08 -0800 Subject: [PATCH 2/6] Error handling --- pgutils/connector.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pgutils/connector.go b/pgutils/connector.go index f757874..da1c603 100644 --- a/pgutils/connector.go +++ b/pgutils/connector.go @@ -218,7 +218,11 @@ func MustConnectDB(conn *PostgresqlConnector) *sqlx.DB { func NewPostgresqlConnectorFromDSN(ctx context.Context, dsn string) (*PostgresqlConnector, error) { u, err := url.Parse(dsn) - if err != nil || u.Scheme != "postgres+rds-iam" { + if err != nil { + return nil, fmt.Errorf("filed to parse DSN: %w", err) + } + + if u.Scheme != "postgres+rds-iam" { // Not our custom scheme: hand off to existing DSN handling. return NewPostgresqlConnectorFromConnectionString(dsn), nil } From ddb9337edb94125e82631b5b883e936bc5fb122d Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Tue, 23 Dec 2025 03:23:10 -0800 Subject: [PATCH 3/6] Some comments --- pgutils/connector.go | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/pgutils/connector.go b/pgutils/connector.go index da1c603..71a10c0 100644 --- a/pgutils/connector.go +++ b/pgutils/connector.go @@ -216,14 +216,27 @@ func MustConnectDB(conn *PostgresqlConnector) *sqlx.DB { return db } +// NewPostgresqlConnectorFromDSN constructs a PostgresqlConnector from either a normal +// Postgres DSN/connection string or the custom postgres+rds-iam DSN used for RDS IAM auth. +// +// IAM example 1: postgres+rds-iam://@[:]/ +// +// Optional query params (for cross-account IAM): +// - assume_role_arn: role ARN to assume. +// - assume_role_session_name: only used when assume_role_arn is set; defaults to "pgutils-rds-iam" if omitted. +// +// IAM example 2: postgres+rds-iam://@[:]/?assume_role_arn=...&assume_role_session_name=... func NewPostgresqlConnectorFromDSN(ctx context.Context, dsn string) (*PostgresqlConnector, error) { + if dsn == "" { + return nil, errors.New("DSN cannot be empty") + } + u, err := url.Parse(dsn) if err != nil { - return nil, fmt.Errorf("filed to parse DSN: %w", err) + return nil, fmt.Errorf("failed to parse DSN: %w", err) } - if u.Scheme != "postgres+rds-iam" { - // Not our custom scheme: hand off to existing DSN handling. + if u.Scheme != "postgres+rds-iam" { // Not our custom scheme: hand off to existing DSN handling. return NewPostgresqlConnectorFromConnectionString(dsn), nil } @@ -255,25 +268,23 @@ func NewPostgresqlConnectorFromDSN(ctx context.Context, dsn string) (*Postgresql } q := u.Query() - supportedParams := []string{"assume_role_arn", "assume_role_session_name"} + supportedParams := []string{"assume_role_arn", "assume_role_session_name", "assume_role_external_id", "assume_role_duration"} for k := range q { if !slices.Contains(supportedParams, k) { return nil, fmt.Errorf("postgres+rds-iam DSN has unsupported query parameter: %s", k) } } - assumeRoleARN := q.Get("assume_role_arn") - assumeRoleSessionName := q.Get("assume_role_session_name") - cfg := &IAMAuthConfig{ RDSEndpoint: host + ":" + port, User: user, Database: dbName, } + assumeRoleARN := q.Get("assume_role_arn") if assumeRoleARN != "" { cfg.AssumeRoleARN = assumeRoleARN - cfg.AssumeRoleSessionName = assumeRoleSessionName + cfg.AssumeRoleSessionName = q.Get("assume_role_session_name") } return NewPostgresqlConnectorWithIAMAuth(ctx, cfg) From 3cd2b7f8a728bf7120fecbd19677e00c1f616871 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Tue, 23 Dec 2025 03:30:39 -0800 Subject: [PATCH 4/6] Remove unsupported (for now) parameters --- pgutils/connector.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgutils/connector.go b/pgutils/connector.go index 71a10c0..9bdf605 100644 --- a/pgutils/connector.go +++ b/pgutils/connector.go @@ -268,7 +268,7 @@ func NewPostgresqlConnectorFromDSN(ctx context.Context, dsn string) (*Postgresql } q := u.Query() - supportedParams := []string{"assume_role_arn", "assume_role_session_name", "assume_role_external_id", "assume_role_duration"} + supportedParams := []string{"assume_role_arn", "assume_role_session_name"} for k := range q { if !slices.Contains(supportedParams, k) { return nil, fmt.Errorf("postgres+rds-iam DSN has unsupported query parameter: %s", k) From b8b7cfe5a3bdf6220798d8e62415a760e3123f53 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Thu, 29 Jan 2026 18:08:08 -0800 Subject: [PATCH 5/6] Refactor pgutils to be only URL driven. --- pgutils/connector.go | 326 +++++++++++++++++++------------------------ pgutils/listener.go | 7 +- 2 files changed, 147 insertions(+), 186 deletions(-) diff --git a/pgutils/connector.go b/pgutils/connector.go index 9bdf605..89e8c15 100644 --- a/pgutils/connector.go +++ b/pgutils/connector.go @@ -2,13 +2,11 @@ package pgutils import ( "context" - "errors" "fmt" "log" + "net" "net/url" - "slices" "strings" - "time" "database/sql" "database/sql/driver" @@ -24,182 +22,115 @@ import ( const defaultPostgresPort = "5432" -type baseConnectionStringProvider interface { - getBaseConnectionString(ctx context.Context) (string, error) +// ConnectionStringProvider returns a Postgres connection string for use by clients +// that need a DSN (e.g., pq.Listener) or to build a connector. +type ConnectionStringProvider interface { + ConnectionString(ctx context.Context) (string, error) } -type PostgresqlConnector struct { - baseConnectionStringProvider - searchPath string -} - -func (conn *PostgresqlConnector) WithSearchPath(searchPath string) *PostgresqlConnector { - return &PostgresqlConnector{ - baseConnectionStringProvider: conn.baseConnectionStringProvider, - searchPath: searchPath, - } -} - -func (conn *PostgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) { - dsn, err := conn.GetConnectionString(ctx) - if err != nil { - return nil, fmt.Errorf("get connection string: %w", err) +// NewConnectionStringProviderFromURL constructs a ConnectionStringProvider from a URL-form DSN. +// +// Standard Postgres example: +// +// postgres://user:pass@host:5432/dbname?sslmode=require +// +// IAM example 1: +// +// postgres+rds-iam://user@host:5432/dbname +// +// IAM example 2 (cross-account): +// +// postgres+rds-iam://user@host:5432/dbname?assume_role_arn=...&assume_role_session_name=... +// +// For postgres+rds-iam, the provider generates a fresh IAM auth token on each ConnectionString(ctx) call. +func NewConnectionStringProviderFromURL(ctx context.Context, rawURL string) (ConnectionStringProvider, error) { + if strings.TrimSpace(rawURL) == "" { + return nil, fmt.Errorf("rawURL cannot be empty") } - pqConnector, err := pq.NewConnector(dsn) + u, err := url.Parse(rawURL) if err != nil { - return nil, fmt.Errorf("create pq connector: %w", err) + return nil, fmt.Errorf("error parsing URL: %w", err) } - return pqConnector.Connect(ctx) + switch u.Scheme { + case "postgres", "postgresql": + return &staticConnectionStringProvider{connectionString: u.String()}, nil + case "postgres+rds-iam": + return newIAMConnectionStringProviderFromURL(ctx, u) + default: + return nil, fmt.Errorf("unsupported URL scheme: %s", u.Scheme) + } } -func (conn *PostgresqlConnector) GetConnectionString(ctx context.Context) (string, error) { - dsn, err := conn.getBaseConnectionString(ctx) +// NewConnectorFromURL constructs a driver.Connector from a URL-form DSN. +// +// Standard Postgres example: +// +// postgres://user:pass@host:5432/dbname +// +// IAM example 1: +// +// postgres+rds-iam://user@host:5432/dbname +// +// IAM example 2 (cross-account): +// +// postgres+rds-iam://user@host:5432/dbname?assume_role_arn=...&assume_role_session_name=... +// +// For postgres+rds-iam, each Connect(ctx) call uses a fresh IAM auth token. +func NewConnectorFromURL(ctx context.Context, rawURL string) (driver.Connector, error) { + provider, err := NewConnectionStringProviderFromURL(ctx, rawURL) if err != nil { - return "", fmt.Errorf("get base connection string: %w", err) - } - if conn.searchPath == "" { - return dsn, nil + return nil, err } + return &postgresqlConnector{connectionStringProvider: provider}, nil +} - // Add search path - u, err := url.Parse(dsn) - if err != nil { - return "", fmt.Errorf("parse DSN URL: %w", err) +// AddSearchPathToURL returns a copy of u with search_path set in the query string. +// It returns an error if search_path is already present. +func AddSearchPathToURL(u *url.URL, searchPath string) (*url.URL, error) { + if u == nil { + return nil, fmt.Errorf("URL cannot be nil") } - q := u.Query() - if v := q.Get("search_path"); v != "" { - return "", fmt.Errorf("search_path already set to %q", v) + if searchPath == "" { + uCopy := *u + return &uCopy, nil } - q.Set("search_path", conn.searchPath) // url.Values will percent-encode commas as needed - u.RawQuery = q.Encode() - return u.String(), nil -} -func (c *PostgresqlConnector) Driver() driver.Driver { - return &pq.Driver{} -} - -type staticConnectionStringProvider struct { - connectionString string -} - -func (p *staticConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) { - return p.connectionString, nil -} - -func NewPostgresqlConnectorFromConnectionString(connectionString string) *PostgresqlConnector { - return &PostgresqlConnector{ - baseConnectionStringProvider: &staticConnectionStringProvider{connectionString}, + uCopy := *u + q := uCopy.Query() + if v := q.Get("search_path"); v != "" { + return nil, fmt.Errorf("search_path already set to %q", v) } + q.Set("search_path", searchPath) + uCopy.RawQuery = q.Encode() + return &uCopy, nil } -type IAMAuthConfig struct { - RDSEndpoint string - User string - Database string - - // Optional: cross-account role assumption. - // Set this to a role ARN in the RDS account (Account A) that has rds-db:connect. - AssumeRoleARN string - - // Optional: if your trust policy requires an external ID. - AssumeRoleExternalID string - - // Optional: override the default session name. - AssumeRoleSessionName string - - // Optional: override STS assume role duration. - // If zero, SDK default is used. - AssumeRoleDuration time.Duration -} - -type iamAuthConnectionStringProvider struct { - IAMAuthConfig - - region string - creds aws.CredentialsProvider +type postgresqlConnector struct { + connectionStringProvider ConnectionStringProvider } -func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) { - authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.region, p.User, p.creds) +func (c *postgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) { + dsn, err := c.connectionStringProvider.ConnectionString(ctx) if err != nil { - return "", fmt.Errorf("building auth token: %w", err) + return nil, fmt.Errorf("error getting connection string from provider: %w", err) } - log.Printf("Signing RDS IAM token for \n Endpoint: %s \n User: %s \n Database: %s", p.RDSEndpoint, p.User, p.Database) - - dsnURL := &url.URL{ - Scheme: "postgresql", - User: url.UserPassword(p.User, authToken), - Host: p.RDSEndpoint, - Path: "/" + p.Database, - } - - return dsnURL.String(), nil -} - -func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) (*PostgresqlConnector, error) { - if cfg.RDSEndpoint == "" || cfg.User == "" || cfg.Database == "" { - return nil, errors.New("RDS endpoint, user, and database are required") - } - - awsCfg, err := awsconfig.LoadDefaultConfig(ctx) + pqConnector, err := pq.NewConnector(dsn) if err != nil { - return nil, fmt.Errorf("load AWS config: %w", err) - } - - if awsCfg.Region == "" { - return nil, errors.New("AWS region is not configured") + return nil, fmt.Errorf("error creating pq connector: %w", err) } - creds := awsCfg.Credentials - - // Cross-account support: - // If AssumeRoleARN is set, assume a role in the RDS account (Account A) - // using the ECS task role creds from Account B as the source credentials. - if cfg.AssumeRoleARN != "" { - log.Printf("RDS IAM Assuming Role: %s for \n Endpoint: %s \n User: %s \n Database: %s", cfg.AssumeRoleARN, cfg.RDSEndpoint, cfg.User, cfg.Database) - stsClient := sts.NewFromConfig(awsCfg) - - sessionName := cfg.AssumeRoleSessionName - if sessionName == "" { - sessionName = "pgutils-rds-iam" - } - - assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, cfg.AssumeRoleARN, func(assumeRoleOpts *stscreds.AssumeRoleOptions) { - assumeRoleOpts.RoleSessionName = sessionName - - if cfg.AssumeRoleExternalID != "" { - assumeRoleOpts.ExternalID = aws.String(cfg.AssumeRoleExternalID) - } - - if cfg.AssumeRoleDuration != 0 { - assumeRoleOpts.Duration = cfg.AssumeRoleDuration - } - }) - - // Cache to avoid calling STS too frequently. - creds = aws.NewCredentialsCache(assumeProvider) - } - - return &PostgresqlConnector{ - baseConnectionStringProvider: &iamAuthConnectionStringProvider{ - IAMAuthConfig: *cfg, - region: awsCfg.Region, - creds: creds, - }, - }, nil + return pqConnector.Connect(ctx) } -// Provides missing sqlx.OpenDB -func OpenDB(conn *PostgresqlConnector) *sqlx.DB { - sqlDB := sql.OpenDB(conn) - return sqlx.NewDb(sqlDB, "postgres") +func (c *postgresqlConnector) Driver() driver.Driver { + return &pq.Driver{} } // ConnectDB opens a connection using the connector and verifies it with a ping -func ConnectDB(conn *PostgresqlConnector) (*sqlx.DB, error) { - db := OpenDB(conn) +func ConnectDB(conn driver.Connector) (*sqlx.DB, error) { + sqlDB := sql.OpenDB(conn) + db := sqlx.NewDb(sqlDB, "postgres") if err := db.Ping(); err != nil { db.Close() return nil, err @@ -208,7 +139,7 @@ func ConnectDB(conn *PostgresqlConnector) (*sqlx.DB, error) { } // MustConnectDB is like ConnectDB but panics on error -func MustConnectDB(conn *PostgresqlConnector) *sqlx.DB { +func MustConnectDB(conn driver.Connector) *sqlx.DB { db, err := ConnectDB(conn) if err != nil { panic(err) @@ -216,44 +147,54 @@ func MustConnectDB(conn *PostgresqlConnector) *sqlx.DB { return db } -// NewPostgresqlConnectorFromDSN constructs a PostgresqlConnector from either a normal -// Postgres DSN/connection string or the custom postgres+rds-iam DSN used for RDS IAM auth. -// -// IAM example 1: postgres+rds-iam://@[:]/ -// -// Optional query params (for cross-account IAM): -// - assume_role_arn: role ARN to assume. -// - assume_role_session_name: only used when assume_role_arn is set; defaults to "pgutils-rds-iam" if omitted. -// -// IAM example 2: postgres+rds-iam://@[:]/?assume_role_arn=...&assume_role_session_name=... -func NewPostgresqlConnectorFromDSN(ctx context.Context, dsn string) (*PostgresqlConnector, error) { - if dsn == "" { - return nil, errors.New("DSN cannot be empty") - } +type staticConnectionStringProvider struct { + connectionString string +} + +func (p *staticConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) { + return p.connectionString, nil +} + +type rdsIAMConnectionStringProvider struct { + RDSEndpoint string + Region string + User string + Database string + CredentialsProvider aws.CredentialsProvider +} - u, err := url.Parse(dsn) +func (p *rdsIAMConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) { + authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.Region, p.User, p.CredentialsProvider) if err != nil { - return nil, fmt.Errorf("failed to parse DSN: %w", err) + return "", fmt.Errorf("error building auth token: %w", err) } + log.Printf("Signing RDS IAM token for Endpoint: %s User: %s Database: %s", p.RDSEndpoint, p.User, p.Database) - if u.Scheme != "postgres+rds-iam" { // Not our custom scheme: hand off to existing DSN handling. - return NewPostgresqlConnectorFromConnectionString(dsn), nil + dsnURL := &url.URL{ + Scheme: "postgresql", + User: url.UserPassword(p.User, authToken), + Host: p.RDSEndpoint, + Path: "/" + p.Database, } + return dsnURL.String(), nil +} + +func newIAMConnectionStringProviderFromURL(ctx context.Context, u *url.URL) (ConnectionStringProvider, error) { user := "" if u.User != nil { user = u.User.Username() if _, hasPw := u.User.Password(); hasPw { - return nil, fmt.Errorf("postgres+rds-iam DSN must not include a password") + return nil, fmt.Errorf("postgres+rds-iam URL must not include a password") } } if user == "" { - return nil, fmt.Errorf("postgres+rds-iam DSN missing username") + return nil, fmt.Errorf("postgres+rds-iam URL missing username") } host := u.Hostname() if host == "" { - return nil, fmt.Errorf("postgres+rds-iam DSN missing host") + return nil, fmt.Errorf("postgres+rds-iam URL missing host") } port := u.Port() @@ -268,24 +209,45 @@ func NewPostgresqlConnectorFromDSN(ctx context.Context, dsn string) (*Postgresql } q := u.Query() - supportedParams := []string{"assume_role_arn", "assume_role_session_name"} + supportedParams := map[string]struct{}{ + "assume_role_arn": {}, + "assume_role_session_name": {}, + } for k := range q { - if !slices.Contains(supportedParams, k) { - return nil, fmt.Errorf("postgres+rds-iam DSN has unsupported query parameter: %s", k) + if _, ok := supportedParams[k]; !ok { + return nil, fmt.Errorf("postgres+rds-iam URL has unsupported query parameter: %s", k) } } - cfg := &IAMAuthConfig{ - RDSEndpoint: host + ":" + port, - User: user, - Database: dbName, + awsCfg, err := awsconfig.LoadDefaultConfig(ctx) + if err != nil { + return nil, fmt.Errorf("load AWS config: %w", err) + } + + if awsCfg.Region == "" { + return nil, fmt.Errorf("AWS region is not configured") } + creds := awsCfg.Credentials assumeRoleARN := q.Get("assume_role_arn") if assumeRoleARN != "" { - cfg.AssumeRoleARN = assumeRoleARN - cfg.AssumeRoleSessionName = q.Get("assume_role_session_name") + log.Printf("RDS IAM Assuming Role: %s for Host: %s User: %s Database: %s", assumeRoleARN, host, user, dbName) + stsClient := sts.NewFromConfig(awsCfg) + sessionName := q.Get("assume_role_session_name") + if sessionName == "" { + sessionName = "pgutils-rds-iam" + } + assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, assumeRoleARN, func(opts *stscreds.AssumeRoleOptions) { + opts.RoleSessionName = sessionName + }) + creds = aws.NewCredentialsCache(assumeProvider) } - return NewPostgresqlConnectorWithIAMAuth(ctx, cfg) + return &rdsIAMConnectionStringProvider{ + Region: awsCfg.Region, + RDSEndpoint: net.JoinHostPort(host, port), + User: user, + Database: dbName, + CredentialsProvider: creds, + }, nil } diff --git a/pgutils/listener.go b/pgutils/listener.go index 958462c..d1a7d06 100644 --- a/pgutils/listener.go +++ b/pgutils/listener.go @@ -69,7 +69,7 @@ func listenerEventToString(t pq.ListenerEventType) string { // The callback is invoked from the listener goroutine; it MUST NOT block // for long periods. If you need to do heavy work, offload it to another // goroutine. -func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string, callback func(*pq.Notification), onClose func()) error { +func Listen(ctx context.Context, provider ConnectionStringProvider, pgChannelName string, callback func(*pq.Notification), onClose func()) error { if callback == nil { return fmt.Errorf("listener callback cannot be nil") } @@ -77,9 +77,9 @@ func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string reconnectEventCh := make(chan struct{}, 1) // We just need a single reconnect event to trigger, so buffer size of 1 makeListener := func() (*pq.Listener, error) { - url, err := conn.GetConnectionString(ctx) + url, err := provider.ConnectionString(ctx) if err != nil { - return nil, fmt.Errorf("get url: %w", err) + return nil, fmt.Errorf("error getting connection string from provider: %w", err) } cb := func(t pq.ListenerEventType, e error) { @@ -174,4 +174,3 @@ func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string return nil } - From 5d0d0a3193a567c5dabf61b6fd9f20622ce50797 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Thu, 29 Jan 2026 20:03:54 -0800 Subject: [PATCH 6/6] Update add to search path --- pgutils/connector.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/pgutils/connector.go b/pgutils/connector.go index 89e8c15..30b314c 100644 --- a/pgutils/connector.go +++ b/pgutils/connector.go @@ -87,23 +87,31 @@ func NewConnectorFromURL(ctx context.Context, rawURL string) (driver.Connector, // AddSearchPathToURL returns a copy of u with search_path set in the query string. // It returns an error if search_path is already present. -func AddSearchPathToURL(u *url.URL, searchPath string) (*url.URL, error) { +func AddSearchPathToURL(rawURL string, searchPath string) (string, error) { + if strings.TrimSpace(rawURL) == "" { + return "", fmt.Errorf("rawURL cannot be empty") + } + u, err := url.Parse(rawURL) + if err != nil { + return "", fmt.Errorf("error parsing URL: %w", err) + } + if u == nil { - return nil, fmt.Errorf("URL cannot be nil") + return "", fmt.Errorf("URL cannot be nil") } if searchPath == "" { uCopy := *u - return &uCopy, nil + return uCopy.String(), nil } uCopy := *u q := uCopy.Query() if v := q.Get("search_path"); v != "" { - return nil, fmt.Errorf("search_path already set to %q", v) + return "", fmt.Errorf("search_path already set to %q", v) } q.Set("search_path", searchPath) uCopy.RawQuery = q.Encode() - return &uCopy, nil + return uCopy.String(), nil } type postgresqlConnector struct { @@ -231,12 +239,12 @@ func newIAMConnectionStringProviderFromURL(ctx context.Context, u *url.URL) (Con creds := awsCfg.Credentials assumeRoleARN := q.Get("assume_role_arn") if assumeRoleARN != "" { - log.Printf("RDS IAM Assuming Role: %s for Host: %s User: %s Database: %s", assumeRoleARN, host, user, dbName) stsClient := sts.NewFromConfig(awsCfg) sessionName := q.Get("assume_role_session_name") if sessionName == "" { sessionName = "pgutils-rds-iam" } + log.Printf("RDS IAM Assuming Role: %s with session name: %s for Host: %s User: %s Database: %s", assumeRoleARN, sessionName, host, user, dbName) assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, assumeRoleARN, func(opts *stscreds.AssumeRoleOptions) { opts.RoleSessionName = sessionName })