diff --git a/pgutils/connector.go b/pgutils/connector.go index 21dce91..30b314c 100644 --- a/pgutils/connector.go +++ b/pgutils/connector.go @@ -2,11 +2,11 @@ package pgutils import ( "context" - "errors" "fmt" "log" + "net" "net/url" - "time" + "strings" "database/sql" "database/sql/driver" @@ -20,109 +20,163 @@ import ( "github.com/lib/pq" ) -type baseConnectionStringProvider interface { - getBaseConnectionString(ctx context.Context) (string, error) -} - -type PostgresqlConnector struct { - baseConnectionStringProvider - searchPath string -} +const defaultPostgresPort = "5432" -func (conn *PostgresqlConnector) WithSearchPath(searchPath string) *PostgresqlConnector { - return &PostgresqlConnector{ - baseConnectionStringProvider: conn.baseConnectionStringProvider, - searchPath: searchPath, - } +// ConnectionStringProvider returns a Postgres connection string for use by clients +// that need a DSN (e.g., pq.Listener) or to build a connector. +type ConnectionStringProvider interface { + ConnectionString(ctx context.Context) (string, error) } -func (conn *PostgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) { - dsn, err := conn.GetConnectionString(ctx) - if err != nil { - return nil, fmt.Errorf("get connection string: %w", err) +// NewConnectionStringProviderFromURL constructs a ConnectionStringProvider from a URL-form DSN. +// +// Standard Postgres example: +// +// postgres://user:pass@host:5432/dbname?sslmode=require +// +// IAM example 1: +// +// postgres+rds-iam://user@host:5432/dbname +// +// IAM example 2 (cross-account): +// +// postgres+rds-iam://user@host:5432/dbname?assume_role_arn=...&assume_role_session_name=... +// +// For postgres+rds-iam, the provider generates a fresh IAM auth token on each ConnectionString(ctx) call. +func NewConnectionStringProviderFromURL(ctx context.Context, rawURL string) (ConnectionStringProvider, error) { + if strings.TrimSpace(rawURL) == "" { + return nil, fmt.Errorf("rawURL cannot be empty") } - pqConnector, err := pq.NewConnector(dsn) + u, err := url.Parse(rawURL) if err != nil { - return nil, fmt.Errorf("create pq connector: %w", err) + return nil, fmt.Errorf("error parsing URL: %w", err) } - return pqConnector.Connect(ctx) + switch u.Scheme { + case "postgres", "postgresql": + return &staticConnectionStringProvider{connectionString: u.String()}, nil + case "postgres+rds-iam": + return newIAMConnectionStringProviderFromURL(ctx, u) + default: + return nil, fmt.Errorf("unsupported URL scheme: %s", u.Scheme) + } } -func (conn *PostgresqlConnector) GetConnectionString(ctx context.Context) (string, error) { - dsn, err := conn.getBaseConnectionString(ctx) +// NewConnectorFromURL constructs a driver.Connector from a URL-form DSN. +// +// Standard Postgres example: +// +// postgres://user:pass@host:5432/dbname +// +// IAM example 1: +// +// postgres+rds-iam://user@host:5432/dbname +// +// IAM example 2 (cross-account): +// +// postgres+rds-iam://user@host:5432/dbname?assume_role_arn=...&assume_role_session_name=... +// +// For postgres+rds-iam, each Connect(ctx) call uses a fresh IAM auth token. +func NewConnectorFromURL(ctx context.Context, rawURL string) (driver.Connector, error) { + provider, err := NewConnectionStringProviderFromURL(ctx, rawURL) if err != nil { - return "", fmt.Errorf("get base connection string: %w", err) - } - if conn.searchPath == "" { - return dsn, nil + return nil, err } + return &postgresqlConnector{connectionStringProvider: provider}, nil +} - // Add search path - u, err := url.Parse(dsn) +// AddSearchPathToURL returns a copy of u with search_path set in the query string. +// It returns an error if search_path is already present. +func AddSearchPathToURL(rawURL string, searchPath string) (string, error) { + if strings.TrimSpace(rawURL) == "" { + return "", fmt.Errorf("rawURL cannot be empty") + } + u, err := url.Parse(rawURL) if err != nil { - return "", fmt.Errorf("parse DSN URL: %w", err) + return "", fmt.Errorf("error parsing URL: %w", err) } - q := u.Query() + + if u == nil { + return "", fmt.Errorf("URL cannot be nil") + } + if searchPath == "" { + uCopy := *u + return uCopy.String(), nil + } + + uCopy := *u + q := uCopy.Query() if v := q.Get("search_path"); v != "" { return "", fmt.Errorf("search_path already set to %q", v) } - q.Set("search_path", conn.searchPath) // url.Values will percent-encode commas as needed - u.RawQuery = q.Encode() - return u.String(), nil + q.Set("search_path", searchPath) + uCopy.RawQuery = q.Encode() + return uCopy.String(), nil } -func (c *PostgresqlConnector) Driver() driver.Driver { - return &pq.Driver{} +type postgresqlConnector struct { + connectionStringProvider ConnectionStringProvider } -type staticConnectionStringProvider struct { - connectionString string +func (c *postgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) { + dsn, err := c.connectionStringProvider.ConnectionString(ctx) + if err != nil { + return nil, fmt.Errorf("error getting connection string from provider: %w", err) + } + pqConnector, err := pq.NewConnector(dsn) + if err != nil { + return nil, fmt.Errorf("error creating pq connector: %w", err) + } + + return pqConnector.Connect(ctx) } -func (p *staticConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) { - return p.connectionString, nil +func (c *postgresqlConnector) Driver() driver.Driver { + return &pq.Driver{} } -func NewPostgresqlConnectorFromConnectionString(connectionString string) *PostgresqlConnector { - return &PostgresqlConnector{ - baseConnectionStringProvider: &staticConnectionStringProvider{connectionString}, +// ConnectDB opens a connection using the connector and verifies it with a ping +func ConnectDB(conn driver.Connector) (*sqlx.DB, error) { + sqlDB := sql.OpenDB(conn) + db := sqlx.NewDb(sqlDB, "postgres") + if err := db.Ping(); err != nil { + db.Close() + return nil, err } + return db, nil } -type IAMAuthConfig struct { - RDSEndpoint string - User string - Database string - - // Optional: cross-account role assumption. - // Set this to a role ARN in the RDS account (Account A) that has rds-db:connect. - AssumeRoleARN string - - // Optional: if your trust policy requires an external ID. - AssumeRoleExternalID string - - // Optional: override the default session name. - AssumeRoleSessionName string +// MustConnectDB is like ConnectDB but panics on error +func MustConnectDB(conn driver.Connector) *sqlx.DB { + db, err := ConnectDB(conn) + if err != nil { + panic(err) + } + return db +} - // Optional: override STS assume role duration. - // If zero, SDK default is used. - AssumeRoleDuration time.Duration +type staticConnectionStringProvider struct { + connectionString string } -type iamAuthConnectionStringProvider struct { - IAMAuthConfig +func (p *staticConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) { + return p.connectionString, nil +} - region string - creds aws.CredentialsProvider +type rdsIAMConnectionStringProvider struct { + RDSEndpoint string + Region string + User string + Database string + CredentialsProvider aws.CredentialsProvider } -func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) { - authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.region, p.User, p.creds) +func (p *rdsIAMConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) { + authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.Region, p.User, p.CredentialsProvider) if err != nil { - return "", fmt.Errorf("building auth token: %w", err) + return "", fmt.Errorf("error building auth token: %w", err) } - log.Printf("Signing RDS IAM token for \n Endpoint: %s \n User: %s \n Database: %s", p.RDSEndpoint, p.User, p.Database) + log.Printf("Signing RDS IAM token for Endpoint: %s User: %s Database: %s", p.RDSEndpoint, p.User, p.Database) dsnURL := &url.URL{ Scheme: "postgresql", @@ -134,9 +188,43 @@ func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Co return dsnURL.String(), nil } -func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) (*PostgresqlConnector, error) { - if cfg.RDSEndpoint == "" || cfg.User == "" || cfg.Database == "" { - return nil, errors.New("RDS endpoint, user, and database are required") +func newIAMConnectionStringProviderFromURL(ctx context.Context, u *url.URL) (ConnectionStringProvider, error) { + user := "" + if u.User != nil { + user = u.User.Username() + if _, hasPw := u.User.Password(); hasPw { + return nil, fmt.Errorf("postgres+rds-iam URL must not include a password") + } + } + if user == "" { + return nil, fmt.Errorf("postgres+rds-iam URL missing username") + } + + host := u.Hostname() + if host == "" { + return nil, fmt.Errorf("postgres+rds-iam URL missing host") + } + + port := u.Port() + if port == "" { + port = defaultPostgresPort + } + + // Match libpq/psql defaulting: if dbname isn't specified, dbname defaults to username. + dbName := strings.TrimPrefix(u.Path, "/") + if dbName == "" { + dbName = user + } + + q := u.Query() + supportedParams := map[string]struct{}{ + "assume_role_arn": {}, + "assume_role_session_name": {}, + } + for k := range q { + if _, ok := supportedParams[k]; !ok { + return nil, fmt.Errorf("postgres+rds-iam URL has unsupported query parameter: %s", k) + } } awsCfg, err := awsconfig.LoadDefaultConfig(ctx) @@ -145,70 +233,29 @@ func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) } if awsCfg.Region == "" { - return nil, errors.New("AWS region is not configured") + return nil, fmt.Errorf("AWS region is not configured") } creds := awsCfg.Credentials - - // Cross-account support: - // If AssumeRoleARN is set, assume a role in the RDS account (Account A) - // using the ECS task role creds from Account B as the source credentials. - if cfg.AssumeRoleARN != "" { - log.Printf("RDS IAM Assuming Role: %s for \n Endpoint: %s \n User: %s \n Database: %s", cfg.AssumeRoleARN, cfg.RDSEndpoint, cfg.User, cfg.Database) + assumeRoleARN := q.Get("assume_role_arn") + if assumeRoleARN != "" { stsClient := sts.NewFromConfig(awsCfg) - - sessionName := cfg.AssumeRoleSessionName + sessionName := q.Get("assume_role_session_name") if sessionName == "" { sessionName = "pgutils-rds-iam" } - - assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, cfg.AssumeRoleARN, func(assumeRoleOpts *stscreds.AssumeRoleOptions) { - assumeRoleOpts.RoleSessionName = sessionName - - if cfg.AssumeRoleExternalID != "" { - assumeRoleOpts.ExternalID = aws.String(cfg.AssumeRoleExternalID) - } - - if cfg.AssumeRoleDuration != 0 { - assumeRoleOpts.Duration = cfg.AssumeRoleDuration - } + log.Printf("RDS IAM Assuming Role: %s with session name: %s for Host: %s User: %s Database: %s", assumeRoleARN, sessionName, host, user, dbName) + assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, assumeRoleARN, func(opts *stscreds.AssumeRoleOptions) { + opts.RoleSessionName = sessionName }) - - // Cache to avoid calling STS too frequently. creds = aws.NewCredentialsCache(assumeProvider) } - return &PostgresqlConnector{ - baseConnectionStringProvider: &iamAuthConnectionStringProvider{ - IAMAuthConfig: *cfg, - region: awsCfg.Region, - creds: creds, - }, + return &rdsIAMConnectionStringProvider{ + Region: awsCfg.Region, + RDSEndpoint: net.JoinHostPort(host, port), + User: user, + Database: dbName, + CredentialsProvider: creds, }, nil } - -// Provides missing sqlx.OpenDB -func OpenDB(conn *PostgresqlConnector) *sqlx.DB { - sqlDB := sql.OpenDB(conn) - return sqlx.NewDb(sqlDB, "postgres") -} - -// ConnectDB opens a connection using the connector and verifies it with a ping -func ConnectDB(conn *PostgresqlConnector) (*sqlx.DB, error) { - db := OpenDB(conn) - if err := db.Ping(); err != nil { - db.Close() - return nil, err - } - return db, nil -} - -// MustConnectDB is like ConnectDB but panics on error -func MustConnectDB(conn *PostgresqlConnector) *sqlx.DB { - db, err := ConnectDB(conn) - if err != nil { - panic(err) - } - return db -} - diff --git a/pgutils/listener.go b/pgutils/listener.go index 958462c..d1a7d06 100644 --- a/pgutils/listener.go +++ b/pgutils/listener.go @@ -69,7 +69,7 @@ func listenerEventToString(t pq.ListenerEventType) string { // The callback is invoked from the listener goroutine; it MUST NOT block // for long periods. If you need to do heavy work, offload it to another // goroutine. -func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string, callback func(*pq.Notification), onClose func()) error { +func Listen(ctx context.Context, provider ConnectionStringProvider, pgChannelName string, callback func(*pq.Notification), onClose func()) error { if callback == nil { return fmt.Errorf("listener callback cannot be nil") } @@ -77,9 +77,9 @@ func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string reconnectEventCh := make(chan struct{}, 1) // We just need a single reconnect event to trigger, so buffer size of 1 makeListener := func() (*pq.Listener, error) { - url, err := conn.GetConnectionString(ctx) + url, err := provider.ConnectionString(ctx) if err != nil { - return nil, fmt.Errorf("get url: %w", err) + return nil, fmt.Errorf("error getting connection string from provider: %w", err) } cb := func(t pq.ListenerEventType, e error) { @@ -174,4 +174,3 @@ func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string return nil } -