diff --git a/auth/mtls/client.go b/auth/mtls/client.go index e5d5aade..c339fbae 100644 --- a/auth/mtls/client.go +++ b/auth/mtls/client.go @@ -28,8 +28,9 @@ import ( ) // LoadClientCredentials returns transport credentials for SansShell clients, -// based on the provided `loaderName` -func LoadClientCredentials(ctx context.Context, loaderName string) (credentials.TransportCredentials, error) { +// based on the provided `loaderName`. If serverName is non-empty, it will be +// set on the TLS configuration and persisted across credential refreshes. +func LoadClientCredentials(ctx context.Context, loaderName string, serverName string) (credentials.TransportCredentials, error) { logger := logr.FromContextOrDiscard(ctx) recorder := metrics.RecorderFromContextOrNoop(ctx) mtlsLoader, err := Loader(loaderName) @@ -48,6 +49,11 @@ func LoadClientCredentials(ctx context.Context, loaderName string) (credentials. logger: logger, recorder: recorder, } + if serverName != "" { + if err := wrapped.OverrideServerName(serverName); err != nil { //nolint:staticcheck + return nil, fmt.Errorf("could not set server name: %w", err) + } + } return wrapped, nil } @@ -68,25 +74,30 @@ func internalLoadClientCredentials(ctx context.Context, loaderName string) (cred return nil, err } logger.Info("loaded new client cert", "error", err) - return NewClientCredentials(cert, pool), nil + return NewClientCredentials(cert, pool, ""), nil } // NewClientCredentials returns transport credentials for SansShell clients. -func NewClientCredentials(cert tls.Certificate, CAPool *x509.CertPool) credentials.TransportCredentials { - return credentials.NewTLS(&tls.Config{ +func NewClientCredentials(cert tls.Certificate, CAPool *x509.CertPool, serverName string) credentials.TransportCredentials { + config := &tls.Config{ Certificates: []tls.Certificate{cert}, RootCAs: CAPool, MinVersion: tls.VersionTLS12, - }) + } + + if serverName != "" { + config.ServerName = serverName + } + + return credentials.NewTLS(config) } // LoadClientTLS reads the certificates and keys from disk at the supplied paths, // and assembles them into a set of TransportCredentials for the gRPC client. -func LoadClientTLS(clientCertFile, clientKeyFile string, CAPool *x509.CertPool) (credentials.TransportCredentials, error) { - // Read in client credentials +func LoadClientTLS(clientCertFile, clientKeyFile string, CAPool *x509.CertPool, serverName string) (credentials.TransportCredentials, error) { cert, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile) if err != nil { return nil, fmt.Errorf("could not read client credentials: %w", err) } - return NewClientCredentials(cert, CAPool), nil + return NewClientCredentials(cert, CAPool, serverName), nil } diff --git a/auth/mtls/mtls_test.go b/auth/mtls/mtls_test.go index c42b08ef..ce3a4a25 100644 --- a/auth/mtls/mtls_test.go +++ b/auth/mtls/mtls_test.go @@ -211,12 +211,12 @@ func TestLoadClientTLS(t *testing.T) { testutil.FatalOnErr("Failed to load root CA", err, t) // Make sure this errors if we pass bad data like reversing things. - _, err = LoadClientTLS("testdata/client.key", "testdata/client.pem", CAPool) + _, err = LoadClientTLS("testdata/client.key", "testdata/client.pem", CAPool, "") t.Log(err) testutil.FatalOnNoErr("bad TLS client data", err, t) // Also that it works on correct input. - _, err = LoadClientTLS("testdata/client.pem", "testdata/client.key", CAPool) + _, err = LoadClientTLS("testdata/client.pem", "testdata/client.key", CAPool, "") testutil.FatalOnErr("tls client data", err, t) } @@ -279,7 +279,7 @@ func TestLoadClientServerCredentials(t *testing.T) { err = server.OverrideServerName("server") //nolint:staticcheck testutil.FatalOnErr("OverrideServerName", err, t) } - client, err := LoadClientCredentials(context.Background(), tc.loader) + client, err := LoadClientCredentials(context.Background(), tc.loader, "") testutil.WantErr("client", err, tc.wantErr, t) if !tc.wantErr { err = client.OverrideServerName("server") //nolint:staticcheck @@ -317,7 +317,7 @@ func TestHealthCheck(t *testing.T) { unregisterAll() err = Register("refresh", &simpleLoader{name: "refresh"}) testutil.FatalOnErr("Register", err, t) - creds, err := LoadClientCredentials(ctx, "refresh") + creds, err := LoadClientCredentials(ctx, "refresh", "") testutil.FatalOnErr("Failed to load client cert", err, t) err = creds.OverrideServerName("bufnet") //nolint:staticcheck testutil.FatalOnErr("OverrideServerName", err, t) diff --git a/cmd/proxy-server/server/server.go b/cmd/proxy-server/server/server.go index e9424ba8..2d6c259f 100644 --- a/cmd/proxy-server/server/server.go +++ b/cmd/proxy-server/server/server.go @@ -56,6 +56,7 @@ type runState struct { policy rpcauth.AuthzPolicy clientPolicy rpcauth.AuthzPolicy credSource string + serverName string tlsConfig *tls.Config hostport string debugport string @@ -126,6 +127,14 @@ func WithCredSource(credSource string) Option { }) } +// WithServerName overrides the TLS server name used for client certificate verification. +func WithServerName(serverName string) Option { + return optionFunc(func(_ context.Context, r *runState) error { + r.serverName = serverName + return nil + }) +} + // WithHostport applies the host:port to run the server. func WithHostPort(hostport string) Option { return optionFunc(func(_ context.Context, r *runState) error { @@ -464,7 +473,7 @@ func extractClientTransportCredentialsFromRunState(ctx context.Context, rs *runS return nil, fmt.Errorf("both credSource and tlsConfig are defined for the client") } if rs.credSource != "" { - creds, err = mtls.LoadClientCredentials(ctx, rs.credSource) + creds, err = mtls.LoadClientCredentials(ctx, rs.credSource, rs.serverName) if err != nil { return nil, err } diff --git a/cmd/sanssh/client/client.go b/cmd/sanssh/client/client.go index 3b49b4a9..11507bdd 100644 --- a/cmd/sanssh/client/client.go +++ b/cmd/sanssh/client/client.go @@ -68,6 +68,8 @@ type RunState struct { OutputsDir string // CredSource is a registered credential source with the mtls package. CredSource string + // ServerName overrides the TLS server name used for certificate verification. + ServerName string // IdleTimeout is the time duration to wait before closing an idle connection. // If no messages are sent/received within this timeframe, connection will be terminated. IdleTimeout time.Duration @@ -267,7 +269,7 @@ func Run(ctx context.Context, rs RunState) { os.Exit(1) } } - creds, err := mtls.LoadClientCredentials(ctx, rs.CredSource) + creds, err := mtls.LoadClientCredentials(ctx, rs.CredSource, rs.ServerName) if err != nil { fmt.Fprintf(os.Stderr, "Could not load creds from %s - %v\n", rs.CredSource, err) os.Exit(1) diff --git a/cmd/sanssh/main.go b/cmd/sanssh/main.go index a209174e..74c0d2a5 100644 --- a/cmd/sanssh/main.go +++ b/cmd/sanssh/main.go @@ -90,6 +90,7 @@ If port is blank the default of %d will be used`, proxyEnv, defaultProxyPort)) batchSize = flag.Int("batch-size", 0, "If non-zero will perform the proxy->target work in batches of this size (with any remainder done at the end).") mpa = flag.Bool("mpa", false, "Request multi-party approval for commands. This will create an MPA request, wait for approval, and then execute the command.") authzDryRun = flag.Bool("authz-dry-run", false, "If true, the client will send a request to the server to check if the user has the permission to run the command. The server will respond with a success or failure message.") + serverName = flag.String("server-name", "", "If non-empty, overrides the TLS server name used for certificate verification.") // targets will be bound to --targets for sending a single request to N nodes. targetsFlag util.StringSliceCommaOrWhitespaceFlag @@ -215,6 +216,7 @@ func main() { AuthzDryRun: *authzDryRun, OutputsDir: *outputsDir, CredSource: *credSource, + ServerName: *serverName, IdleTimeout: *idleTimeout, ClientAuthzPolicy: clientPolicy, PrefixOutput: *prefixHeader, diff --git a/services/mpa/mpahooks/mpahooks_test.go b/services/mpa/mpahooks/mpahooks_test.go index f9f49294..4142d772 100644 --- a/services/mpa/mpahooks/mpahooks_test.go +++ b/services/mpa/mpahooks/mpahooks_test.go @@ -258,11 +258,11 @@ func TestClientInterceptors(t *testing.T) { }() defer s.GracefulStop() - clientCreds, err := mtls.LoadClientTLS("../../../auth/mtls/testdata/client.pem", "../../../auth/mtls/testdata/client.key", rot) + clientCreds, err := mtls.LoadClientTLS("../../../auth/mtls/testdata/client.pem", "../../../auth/mtls/testdata/client.key", rot, "") if err != nil { t.Fatal(err) } - approverCreds, err := mtls.LoadClientTLS("../testdata/approver.pem", "../testdata/approver.key", rot) + approverCreds, err := mtls.LoadClientTLS("../testdata/approver.pem", "../testdata/approver.key", rot, "") if err != nil { t.Fatal(err) } @@ -446,7 +446,7 @@ func TestProxiedClientInterceptors(t *testing.T) { if err != nil { t.Fatal(err) } - proxyClientCreds, err := mtls.LoadClientTLS("../testdata/proxy.pem", "../testdata/proxy.key", rot) + proxyClientCreds, err := mtls.LoadClientTLS("../testdata/proxy.pem", "../testdata/proxy.key", rot, "") if err != nil { t.Fatal(err) } @@ -472,11 +472,11 @@ func TestProxiedClientInterceptors(t *testing.T) { }() defer proxySrv.GracefulStop() - clientCreds, err := mtls.LoadClientTLS("../../../auth/mtls/testdata/client.pem", "../../../auth/mtls/testdata/client.key", rot) + clientCreds, err := mtls.LoadClientTLS("../../../auth/mtls/testdata/client.pem", "../../../auth/mtls/testdata/client.key", rot, "") if err != nil { t.Fatal(err) } - approverCreds, err := mtls.LoadClientTLS("../testdata/approver.pem", "../testdata/approver.key", rot) + approverCreds, err := mtls.LoadClientTLS("../testdata/approver.pem", "../testdata/approver.key", rot, "") if err != nil { t.Fatal(err) }