Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
VERSION=9.0.0
1 change: 1 addition & 0 deletions .vscode/dryrun.log
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
make.exe --dry-run --always-make --keep-going --print-directory
'make.exe' is not recognized as an internal or external command,
operable program or batch file.

2 changes: 1 addition & 1 deletion plugin/metashield/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func getClient(endpoint plugin.ShieldEndpoint) (*shield.Client, error) {
return nil, err
}

ca, err := endpoint.StringValue("core_ca_cert")
ca, err := endpoint.StringValueDefault("core_ca_cert", "")
if err != nil {
return nil, err
}
Expand Down
100 changes: 82 additions & 18 deletions plugin/postgres/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"os/exec"
"regexp"
"strings"

fmt "github.com/jhunt/go-ansi"

Expand Down Expand Up @@ -111,6 +112,14 @@ func main() {
Help: "The absolute path to the bin/ directory that contains the `psql` command.",
Default: "/var/vcap/packages/postgres-9.4/bin",
},
plugin.Field{
Mode: "target",
Name: "pg_skip_permission_check",
Type: "bool",
Title: "Skip permission validation",
Help: "Skip upfront permission checking. WARNING: Use only if you understand the risks. Restore may fail with confusing errors if privileges are insufficient.",
Default: "false",
},
},
}

Expand All @@ -120,15 +129,16 @@ func main() {
type PostgresPlugin plugin.PluginInfo

type PostgresConnectionInfo struct {
Host string
Port string
User string
Password string
Bin string
ReplicaHost string
ReplicaPort string
Database string
Options string
Host string
Port string
User string
Password string
Bin string
ReplicaHost string
ReplicaPort string
Database string
Options string
SkipPermissionCheck bool
}

func (p PostgresPlugin) Meta() plugin.PluginInfo {
Expand Down Expand Up @@ -259,6 +269,15 @@ func (p PostgresPlugin) Restore(endpoint plugin.ShieldEndpoint) error {

setupEnvironmentVariables(pg)

// First, check if we have permission issues before starting the restore
if !pg.SkipPermissionCheck {
if err := checkRestorePermissions(pg); err != nil {
return err
}
} else {
plugin.DEBUG("Skipping permission check as requested")
}

cmd := exec.Command(fmt.Sprintf("%s/psql", pg.Bin), "-d", "postgres")
plugin.DEBUG("Exec: %s/psql -d postgres", pg.Bin)
plugin.DEBUG("Redirecting stdout and stderr to stderr")
Expand Down Expand Up @@ -316,6 +335,44 @@ func (p PostgresPlugin) Restore(endpoint plugin.ShieldEndpoint) error {
return <-scanErr
}

// checkRestorePermissions performs upfront permission checks before starting restore
func checkRestorePermissions(pg *PostgresConnectionInfo) error {
plugin.DEBUG("Checking restore permissions...")

// Create a temporary connection to check permissions
// Check if user is superuser or has specific database privileges
cmd := exec.Command(fmt.Sprintf("%s/psql", pg.Bin), "-d", "postgres", "-t", "-A", "-c",
"SELECT CASE WHEN "+
"(SELECT COALESCE(usesuper, false) FROM pg_user WHERE usename = current_user) OR "+
"pg_has_role(current_user, 'rds_superuser', 'MEMBER') OR "+
"(pg_has_role(current_user, 'pg_database_owner', 'MEMBER') AND has_database_privilege(current_user, 'postgres', 'CREATE')) "+
"THEN 'SUFFICIENT' ELSE 'INSUFFICIENT' END;")

cmd.Env = os.Environ()
cmd.Env = append(cmd.Env,
fmt.Sprintf("PGUSER=%s", pg.User),
fmt.Sprintf("PGPASSWORD=%s", pg.Password),
fmt.Sprintf("PGHOST=%s", pg.Host),
fmt.Sprintf("PGPORT=%s", pg.Port),
)

output, err := cmd.Output()
if err != nil {
plugin.DEBUG("Failed to check permissions: %s", err)
return fmt.Errorf("postgres: failed to verify user privileges: %s", err)
}

result := strings.TrimSpace(string(output))
plugin.DEBUG("Permission check result: '%s'", result)

if result != "SUFFICIENT" {
return fmt.Errorf("postgres: insufficient privileges for restore operation. User '%s' needs superuser privileges or database creation rights to safely restore databases", pg.User)
}

plugin.DEBUG("User has sufficient privileges for restore")
return nil
}

func (p PostgresPlugin) Store(endpoint plugin.ShieldEndpoint) (string, int64, error) {
return "", 0, plugin.UNIMPLEMENTED
}
Expand Down Expand Up @@ -392,15 +449,22 @@ func pgConnectionInfo(endpoint plugin.ShieldEndpoint) (*PostgresConnectionInfo,
}
plugin.DEBUG("PGBINDIR: '%s'", bin)

skipCheck, err := endpoint.BooleanValueDefault("pg_skip_permission_check", false)
if err != nil {
return nil, err
}
plugin.DEBUG("PG_SKIP_PERMISSION_CHECK: %t", skipCheck)

return &PostgresConnectionInfo{
Host: host,
Port: port,
User: user,
Password: password,
ReplicaHost: replicahost,
ReplicaPort: replicaport,
Bin: bin,
Database: database,
Options: options,
Host: host,
Port: port,
User: user,
Password: password,
ReplicaHost: replicahost,
ReplicaPort: replicaport,
Bin: bin,
Database: database,
Options: options,
SkipPermissionCheck: skipCheck,
}, nil
}
96 changes: 70 additions & 26 deletions plugin/s3/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package main

import (
"encoding/json"
"errors"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
Expand All @@ -27,6 +25,9 @@ const (
DefaultSkipSSLValidation = false
DefaultUseInstanceProfiles = false
credentialsEndpoint = "http://169.254.169.254/latest/meta-data/iam/security-credentials"
// IMDSv2 endpoints
imdsTokenEndpoint = "http://169.254.169.254/latest/api/token"
imdsTokenTTL = "21600" // 6 hours in seconds
)

func validSigVersion(v string) bool {
Expand Down Expand Up @@ -63,7 +64,7 @@ func validBucketName(v string) bool {
}

func clientUsesPathBuckets(err error) bool {
return !strings.Contains(err.Error(), "301 response missing Location header")
return !(strings.Contains(err.Error(), "301 response missing Location header") || strings.Contains(err.Error(), "Please send all future requests to this endpoint"))
}

func main() {
Expand Down Expand Up @@ -636,37 +637,80 @@ func (e s3Endpoint) Connect() (*s3.Client, error) {
}

func getInstanceProfileCredentials() (instanceProfileCredentials, error) {
response, connectErr := http.Get(fmt.Sprintf("%s/", credentialsEndpoint))
if connectErr != nil {
return instanceProfileCredentials{}, connectErr
} else if response.StatusCode != 200 {
return instanceProfileCredentials{}, errors.New(fmt.Sprintf("Connection request to %s/ failed with code %d", credentialsEndpoint, response.StatusCode))
var creds instanceProfileCredentials

// Step 1: Get IMDSv2 token
tokenReq, err := http.NewRequest("PUT", imdsTokenEndpoint, nil)
if err != nil {
return creds, fmt.Errorf("failed to create token request: %v", err)
}
tokenReq.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", imdsTokenTTL)

body, readErr := ioutil.ReadAll(response.Body)
if readErr != nil {
return instanceProfileCredentials{}, readErr
client := &http.Client{Timeout: 10 * time.Second}
tokenResp, err := client.Do(tokenReq)
if err != nil {
return creds, fmt.Errorf("failed to get IMDSv2 token: %v", err)
}
role := string(body)
response.Body.Close()
defer tokenResp.Body.Close()

var creds instanceProfileCredentials
response, connectErr = http.Get(fmt.Sprintf("%s/%s", credentialsEndpoint, role))
if connectErr != nil {
return instanceProfileCredentials{}, connectErr
} else if response.StatusCode != 200 {
return instanceProfileCredentials{}, errors.New(fmt.Sprintf("Connection request to %s/%s failed with code %d", credentialsEndpoint, role, response.StatusCode))
if tokenResp.StatusCode != 200 {
return creds, fmt.Errorf("failed to get IMDSv2 token, status: %d", tokenResp.StatusCode)
}
defer response.Body.Close()

body, readErr = ioutil.ReadAll(response.Body)
if readErr != nil {
return instanceProfileCredentials{}, readErr
tokenBytes, err := io.ReadAll(tokenResp.Body)
if err != nil {
return creds, fmt.Errorf("failed to read IMDSv2 token: %v", err)
}
token := strings.TrimSpace(string(tokenBytes))

unmarshallErr := json.Unmarshal(body, &creds)
if unmarshallErr != nil {
return instanceProfileCredentials{}, unmarshallErr
// Step 2: Get IAM role name using IMDSv2 token
roleReq, err := http.NewRequest("GET", credentialsEndpoint, nil)
if err != nil {
return creds, fmt.Errorf("failed to create role request: %v", err)
}
roleReq.Header.Set("X-aws-ec2-metadata-token", token)

roleResp, err := client.Do(roleReq)
if err != nil {
return creds, fmt.Errorf("failed to get IAM role: %v", err)
}
defer roleResp.Body.Close()

if roleResp.StatusCode != 200 {
return creds, fmt.Errorf("failed to get IAM role, status: %d", roleResp.StatusCode)
}

roleBytes, err := io.ReadAll(roleResp.Body)
if err != nil {
return creds, fmt.Errorf("failed to read IAM role: %v", err)
}
roleName := strings.TrimSpace(string(roleBytes))

// Step 3: Get credentials using the role name and IMDSv2 token
credReq, err := http.NewRequest("GET", credentialsEndpoint+"/"+roleName, nil)
if err != nil {
return creds, fmt.Errorf("failed to create credentials request: %v", err)
}
credReq.Header.Set("X-aws-ec2-metadata-token", token)

credResp, err := client.Do(credReq)
if err != nil {
return creds, fmt.Errorf("failed to get credentials: %v", err)
}
defer credResp.Body.Close()

if credResp.StatusCode != 200 {
return creds, fmt.Errorf("failed to get credentials, status: %d", credResp.StatusCode)
}

credBytes, err := io.ReadAll(credResp.Body)
if err != nil {
return creds, fmt.Errorf("failed to read credentials: %v", err)
}

err = json.Unmarshal(credBytes, &creds)
if err != nil {
return creds, fmt.Errorf("failed to parse credentials JSON: %v", err)
}

return creds, nil
Expand Down
Loading