Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 173 additions & 126 deletions pgutils/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package pgutils

import (
"context"
"errors"
"fmt"
"log"
"net"
"net/url"
"time"
"strings"

"database/sql"
"database/sql/driver"
Expand All @@ -20,109 +20,163 @@ import (
"github.com/lib/pq"
)

type baseConnectionStringProvider interface {
getBaseConnectionString(ctx context.Context) (string, error)
}

type PostgresqlConnector struct {
baseConnectionStringProvider
searchPath string
}
const defaultPostgresPort = "5432"

func (conn *PostgresqlConnector) WithSearchPath(searchPath string) *PostgresqlConnector {
return &PostgresqlConnector{
baseConnectionStringProvider: conn.baseConnectionStringProvider,
searchPath: searchPath,
}
// ConnectionStringProvider returns a Postgres connection string for use by clients
// that need a DSN (e.g., pq.Listener) or to build a connector.
type ConnectionStringProvider interface {
ConnectionString(ctx context.Context) (string, error)
}

func (conn *PostgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) {
dsn, err := conn.GetConnectionString(ctx)
if err != nil {
return nil, fmt.Errorf("get connection string: %w", err)
// NewConnectionStringProviderFromURL constructs a ConnectionStringProvider from a URL-form DSN.
//
// Standard Postgres example:
//
// postgres://user:pass@host:5432/dbname?sslmode=require
//
// IAM example 1:
//
// postgres+rds-iam://user@host:5432/dbname
//
// IAM example 2 (cross-account):
//
// postgres+rds-iam://user@host:5432/dbname?assume_role_arn=...&assume_role_session_name=...
//
// For postgres+rds-iam, the provider generates a fresh IAM auth token on each ConnectionString(ctx) call.
func NewConnectionStringProviderFromURL(ctx context.Context, rawURL string) (ConnectionStringProvider, error) {
if strings.TrimSpace(rawURL) == "" {
return nil, fmt.Errorf("rawURL cannot be empty")
}
pqConnector, err := pq.NewConnector(dsn)
u, err := url.Parse(rawURL)
if err != nil {
return nil, fmt.Errorf("create pq connector: %w", err)
return nil, fmt.Errorf("error parsing URL: %w", err)
}

return pqConnector.Connect(ctx)
switch u.Scheme {
case "postgres", "postgresql":
return &staticConnectionStringProvider{connectionString: u.String()}, nil
case "postgres+rds-iam":
return newIAMConnectionStringProviderFromURL(ctx, u)
default:
return nil, fmt.Errorf("unsupported URL scheme: %s", u.Scheme)
}
}

func (conn *PostgresqlConnector) GetConnectionString(ctx context.Context) (string, error) {
dsn, err := conn.getBaseConnectionString(ctx)
// NewConnectorFromURL constructs a driver.Connector from a URL-form DSN.
//
// Standard Postgres example:
//
// postgres://user:pass@host:5432/dbname
//
// IAM example 1:
//
// postgres+rds-iam://user@host:5432/dbname
//
// IAM example 2 (cross-account):
//
// postgres+rds-iam://user@host:5432/dbname?assume_role_arn=...&assume_role_session_name=...
//
// For postgres+rds-iam, each Connect(ctx) call uses a fresh IAM auth token.
func NewConnectorFromURL(ctx context.Context, rawURL string) (driver.Connector, error) {
provider, err := NewConnectionStringProviderFromURL(ctx, rawURL)
if err != nil {
return "", fmt.Errorf("get base connection string: %w", err)
}
if conn.searchPath == "" {
return dsn, nil
return nil, err
}
return &postgresqlConnector{connectionStringProvider: provider}, nil
}

// Add search path
u, err := url.Parse(dsn)
// AddSearchPathToURL returns a copy of u with search_path set in the query string.
// It returns an error if search_path is already present.
func AddSearchPathToURL(rawURL string, searchPath string) (string, error) {
if strings.TrimSpace(rawURL) == "" {
return "", fmt.Errorf("rawURL cannot be empty")
}
u, err := url.Parse(rawURL)
if err != nil {
return "", fmt.Errorf("parse DSN URL: %w", err)
return "", fmt.Errorf("error parsing URL: %w", err)
}
q := u.Query()

if u == nil {
return "", fmt.Errorf("URL cannot be nil")
}
if searchPath == "" {
uCopy := *u
return uCopy.String(), nil
}

uCopy := *u
q := uCopy.Query()
if v := q.Get("search_path"); v != "" {
return "", fmt.Errorf("search_path already set to %q", v)
}
q.Set("search_path", conn.searchPath) // url.Values will percent-encode commas as needed
u.RawQuery = q.Encode()
return u.String(), nil
q.Set("search_path", searchPath)
uCopy.RawQuery = q.Encode()
return uCopy.String(), nil
}

func (c *PostgresqlConnector) Driver() driver.Driver {
return &pq.Driver{}
type postgresqlConnector struct {
connectionStringProvider ConnectionStringProvider
}

type staticConnectionStringProvider struct {
connectionString string
func (c *postgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) {
dsn, err := c.connectionStringProvider.ConnectionString(ctx)
if err != nil {
return nil, fmt.Errorf("error getting connection string from provider: %w", err)
}
pqConnector, err := pq.NewConnector(dsn)
if err != nil {
return nil, fmt.Errorf("error creating pq connector: %w", err)
}

return pqConnector.Connect(ctx)
}

func (p *staticConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) {
return p.connectionString, nil
func (c *postgresqlConnector) Driver() driver.Driver {
return &pq.Driver{}
}

func NewPostgresqlConnectorFromConnectionString(connectionString string) *PostgresqlConnector {
return &PostgresqlConnector{
baseConnectionStringProvider: &staticConnectionStringProvider{connectionString},
// ConnectDB opens a connection using the connector and verifies it with a ping
func ConnectDB(conn driver.Connector) (*sqlx.DB, error) {
sqlDB := sql.OpenDB(conn)
db := sqlx.NewDb(sqlDB, "postgres")
if err := db.Ping(); err != nil {
db.Close()
return nil, err
}
return db, nil
}

type IAMAuthConfig struct {
RDSEndpoint string
User string
Database string

// Optional: cross-account role assumption.
// Set this to a role ARN in the RDS account (Account A) that has rds-db:connect.
AssumeRoleARN string

// Optional: if your trust policy requires an external ID.
AssumeRoleExternalID string

// Optional: override the default session name.
AssumeRoleSessionName string
// MustConnectDB is like ConnectDB but panics on error
func MustConnectDB(conn driver.Connector) *sqlx.DB {
db, err := ConnectDB(conn)
if err != nil {
panic(err)
}
return db
}

// Optional: override STS assume role duration.
// If zero, SDK default is used.
AssumeRoleDuration time.Duration
type staticConnectionStringProvider struct {
connectionString string
}

type iamAuthConnectionStringProvider struct {
IAMAuthConfig
func (p *staticConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) {
return p.connectionString, nil
}

region string
creds aws.CredentialsProvider
type rdsIAMConnectionStringProvider struct {
RDSEndpoint string
Region string
User string
Database string
CredentialsProvider aws.CredentialsProvider
}

func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) {
authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.region, p.User, p.creds)
func (p *rdsIAMConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) {
authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.Region, p.User, p.CredentialsProvider)
if err != nil {
return "", fmt.Errorf("building auth token: %w", err)
return "", fmt.Errorf("error building auth token: %w", err)
}
log.Printf("Signing RDS IAM token for \n Endpoint: %s \n User: %s \n Database: %s", p.RDSEndpoint, p.User, p.Database)
log.Printf("Signing RDS IAM token for Endpoint: %s User: %s Database: %s", p.RDSEndpoint, p.User, p.Database)

dsnURL := &url.URL{
Scheme: "postgresql",
Expand All @@ -134,9 +188,43 @@ func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Co
return dsnURL.String(), nil
}

func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) (*PostgresqlConnector, error) {
if cfg.RDSEndpoint == "" || cfg.User == "" || cfg.Database == "" {
return nil, errors.New("RDS endpoint, user, and database are required")
func newIAMConnectionStringProviderFromURL(ctx context.Context, u *url.URL) (ConnectionStringProvider, error) {
user := ""
if u.User != nil {
user = u.User.Username()
if _, hasPw := u.User.Password(); hasPw {
return nil, fmt.Errorf("postgres+rds-iam URL must not include a password")
}
}
if user == "" {
return nil, fmt.Errorf("postgres+rds-iam URL missing username")
}

host := u.Hostname()
if host == "" {
return nil, fmt.Errorf("postgres+rds-iam URL missing host")
}

port := u.Port()
if port == "" {
port = defaultPostgresPort
}

// Match libpq/psql defaulting: if dbname isn't specified, dbname defaults to username.
dbName := strings.TrimPrefix(u.Path, "/")
if dbName == "" {
dbName = user
}

q := u.Query()
supportedParams := map[string]struct{}{
"assume_role_arn": {},
"assume_role_session_name": {},
}
for k := range q {
if _, ok := supportedParams[k]; !ok {
return nil, fmt.Errorf("postgres+rds-iam URL has unsupported query parameter: %s", k)
}
}

awsCfg, err := awsconfig.LoadDefaultConfig(ctx)
Expand All @@ -145,70 +233,29 @@ func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig)
}

if awsCfg.Region == "" {
return nil, errors.New("AWS region is not configured")
return nil, fmt.Errorf("AWS region is not configured")
}

creds := awsCfg.Credentials

// Cross-account support:
// If AssumeRoleARN is set, assume a role in the RDS account (Account A)
// using the ECS task role creds from Account B as the source credentials.
if cfg.AssumeRoleARN != "" {
log.Printf("RDS IAM Assuming Role: %s for \n Endpoint: %s \n User: %s \n Database: %s", cfg.AssumeRoleARN, cfg.RDSEndpoint, cfg.User, cfg.Database)
assumeRoleARN := q.Get("assume_role_arn")
if assumeRoleARN != "" {
stsClient := sts.NewFromConfig(awsCfg)

sessionName := cfg.AssumeRoleSessionName
sessionName := q.Get("assume_role_session_name")
if sessionName == "" {
sessionName = "pgutils-rds-iam"
}

assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, cfg.AssumeRoleARN, func(assumeRoleOpts *stscreds.AssumeRoleOptions) {
assumeRoleOpts.RoleSessionName = sessionName

if cfg.AssumeRoleExternalID != "" {
assumeRoleOpts.ExternalID = aws.String(cfg.AssumeRoleExternalID)
}

if cfg.AssumeRoleDuration != 0 {
assumeRoleOpts.Duration = cfg.AssumeRoleDuration
}
log.Printf("RDS IAM Assuming Role: %s with session name: %s for Host: %s User: %s Database: %s", assumeRoleARN, sessionName, host, user, dbName)
assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, assumeRoleARN, func(opts *stscreds.AssumeRoleOptions) {
opts.RoleSessionName = sessionName
})

// Cache to avoid calling STS too frequently.
creds = aws.NewCredentialsCache(assumeProvider)
}

return &PostgresqlConnector{
baseConnectionStringProvider: &iamAuthConnectionStringProvider{
IAMAuthConfig: *cfg,
region: awsCfg.Region,
creds: creds,
},
return &rdsIAMConnectionStringProvider{
Region: awsCfg.Region,
RDSEndpoint: net.JoinHostPort(host, port),
User: user,
Database: dbName,
CredentialsProvider: creds,
}, nil
}

// Provides missing sqlx.OpenDB
func OpenDB(conn *PostgresqlConnector) *sqlx.DB {
sqlDB := sql.OpenDB(conn)
return sqlx.NewDb(sqlDB, "postgres")
}

// ConnectDB opens a connection using the connector and verifies it with a ping
func ConnectDB(conn *PostgresqlConnector) (*sqlx.DB, error) {
db := OpenDB(conn)
if err := db.Ping(); err != nil {
db.Close()
return nil, err
}
return db, nil
}

// MustConnectDB is like ConnectDB but panics on error
func MustConnectDB(conn *PostgresqlConnector) *sqlx.DB {
db, err := ConnectDB(conn)
if err != nil {
panic(err)
}
return db
}

Loading