Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions auth/mtls/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ import (
)

// LoadClientCredentials returns transport credentials for SansShell clients,
// based on the provided `loaderName`
func LoadClientCredentials(ctx context.Context, loaderName string) (credentials.TransportCredentials, error) {
// based on the provided `loaderName`. If serverName is non-empty, it will be
// set on the TLS configuration and persisted across credential refreshes.
func LoadClientCredentials(ctx context.Context, loaderName string, serverName string) (credentials.TransportCredentials, error) {
logger := logr.FromContextOrDiscard(ctx)
recorder := metrics.RecorderFromContextOrNoop(ctx)
mtlsLoader, err := Loader(loaderName)
Expand All @@ -48,6 +49,11 @@ func LoadClientCredentials(ctx context.Context, loaderName string) (credentials.
logger: logger,
recorder: recorder,
}
if serverName != "" {
if err := wrapped.OverrideServerName(serverName); err != nil { //nolint:staticcheck
return nil, fmt.Errorf("could not set server name: %w", err)
}
}
return wrapped, nil
}

Expand All @@ -68,25 +74,30 @@ func internalLoadClientCredentials(ctx context.Context, loaderName string) (cred
return nil, err
}
logger.Info("loaded new client cert", "error", err)
return NewClientCredentials(cert, pool), nil
return NewClientCredentials(cert, pool, ""), nil
}

// NewClientCredentials returns transport credentials for SansShell clients.
func NewClientCredentials(cert tls.Certificate, CAPool *x509.CertPool) credentials.TransportCredentials {
return credentials.NewTLS(&tls.Config{
func NewClientCredentials(cert tls.Certificate, CAPool *x509.CertPool, serverName string) credentials.TransportCredentials {
config := &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: CAPool,
MinVersion: tls.VersionTLS12,
})
}

if serverName != "" {
config.ServerName = serverName
}

return credentials.NewTLS(config)
}

// LoadClientTLS reads the certificates and keys from disk at the supplied paths,
// and assembles them into a set of TransportCredentials for the gRPC client.
func LoadClientTLS(clientCertFile, clientKeyFile string, CAPool *x509.CertPool) (credentials.TransportCredentials, error) {
// Read in client credentials
func LoadClientTLS(clientCertFile, clientKeyFile string, CAPool *x509.CertPool, serverName string) (credentials.TransportCredentials, error) {
cert, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile)
if err != nil {
return nil, fmt.Errorf("could not read client credentials: %w", err)
}
return NewClientCredentials(cert, CAPool), nil
return NewClientCredentials(cert, CAPool, serverName), nil
}
8 changes: 4 additions & 4 deletions auth/mtls/mtls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,12 @@ func TestLoadClientTLS(t *testing.T) {
testutil.FatalOnErr("Failed to load root CA", err, t)

// Make sure this errors if we pass bad data like reversing things.
_, err = LoadClientTLS("testdata/client.key", "testdata/client.pem", CAPool)
_, err = LoadClientTLS("testdata/client.key", "testdata/client.pem", CAPool, "")
t.Log(err)
testutil.FatalOnNoErr("bad TLS client data", err, t)

// Also that it works on correct input.
_, err = LoadClientTLS("testdata/client.pem", "testdata/client.key", CAPool)
_, err = LoadClientTLS("testdata/client.pem", "testdata/client.key", CAPool, "")
testutil.FatalOnErr("tls client data", err, t)
}

Expand Down Expand Up @@ -279,7 +279,7 @@ func TestLoadClientServerCredentials(t *testing.T) {
err = server.OverrideServerName("server") //nolint:staticcheck
testutil.FatalOnErr("OverrideServerName", err, t)
}
client, err := LoadClientCredentials(context.Background(), tc.loader)
client, err := LoadClientCredentials(context.Background(), tc.loader, "")
testutil.WantErr("client", err, tc.wantErr, t)
if !tc.wantErr {
err = client.OverrideServerName("server") //nolint:staticcheck
Expand Down Expand Up @@ -317,7 +317,7 @@ func TestHealthCheck(t *testing.T) {
unregisterAll()
err = Register("refresh", &simpleLoader{name: "refresh"})
testutil.FatalOnErr("Register", err, t)
creds, err := LoadClientCredentials(ctx, "refresh")
creds, err := LoadClientCredentials(ctx, "refresh", "")
testutil.FatalOnErr("Failed to load client cert", err, t)
err = creds.OverrideServerName("bufnet") //nolint:staticcheck
testutil.FatalOnErr("OverrideServerName", err, t)
Expand Down
11 changes: 10 additions & 1 deletion cmd/proxy-server/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type runState struct {
policy rpcauth.AuthzPolicy
clientPolicy rpcauth.AuthzPolicy
credSource string
serverName string
tlsConfig *tls.Config
hostport string
debugport string
Expand Down Expand Up @@ -126,6 +127,14 @@ func WithCredSource(credSource string) Option {
})
}

// WithServerName overrides the TLS server name used for client certificate verification.
func WithServerName(serverName string) Option {
return optionFunc(func(_ context.Context, r *runState) error {
r.serverName = serverName
return nil
})
}

// WithHostport applies the host:port to run the server.
func WithHostPort(hostport string) Option {
return optionFunc(func(_ context.Context, r *runState) error {
Expand Down Expand Up @@ -464,7 +473,7 @@ func extractClientTransportCredentialsFromRunState(ctx context.Context, rs *runS
return nil, fmt.Errorf("both credSource and tlsConfig are defined for the client")
}
if rs.credSource != "" {
creds, err = mtls.LoadClientCredentials(ctx, rs.credSource)
creds, err = mtls.LoadClientCredentials(ctx, rs.credSource, rs.serverName)
if err != nil {
return nil, err
}
Expand Down
4 changes: 3 additions & 1 deletion cmd/sanssh/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ type RunState struct {
OutputsDir string
// CredSource is a registered credential source with the mtls package.
CredSource string
// ServerName overrides the TLS server name used for certificate verification.
ServerName string
// IdleTimeout is the time duration to wait before closing an idle connection.
// If no messages are sent/received within this timeframe, connection will be terminated.
IdleTimeout time.Duration
Expand Down Expand Up @@ -267,7 +269,7 @@ func Run(ctx context.Context, rs RunState) {
os.Exit(1)
}
}
creds, err := mtls.LoadClientCredentials(ctx, rs.CredSource)
creds, err := mtls.LoadClientCredentials(ctx, rs.CredSource, rs.ServerName)
if err != nil {
fmt.Fprintf(os.Stderr, "Could not load creds from %s - %v\n", rs.CredSource, err)
os.Exit(1)
Expand Down
2 changes: 2 additions & 0 deletions cmd/sanssh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ If port is blank the default of %d will be used`, proxyEnv, defaultProxyPort))
batchSize = flag.Int("batch-size", 0, "If non-zero will perform the proxy->target work in batches of this size (with any remainder done at the end).")
mpa = flag.Bool("mpa", false, "Request multi-party approval for commands. This will create an MPA request, wait for approval, and then execute the command.")
authzDryRun = flag.Bool("authz-dry-run", false, "If true, the client will send a request to the server to check if the user has the permission to run the command. The server will respond with a success or failure message.")
serverName = flag.String("server-name", "", "If non-empty, overrides the TLS server name used for certificate verification.")

// targets will be bound to --targets for sending a single request to N nodes.
targetsFlag util.StringSliceCommaOrWhitespaceFlag
Expand Down Expand Up @@ -215,6 +216,7 @@ func main() {
AuthzDryRun: *authzDryRun,
OutputsDir: *outputsDir,
CredSource: *credSource,
ServerName: *serverName,
IdleTimeout: *idleTimeout,
ClientAuthzPolicy: clientPolicy,
PrefixOutput: *prefixHeader,
Expand Down
10 changes: 5 additions & 5 deletions services/mpa/mpahooks/mpahooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,11 @@ func TestClientInterceptors(t *testing.T) {
}()
defer s.GracefulStop()

clientCreds, err := mtls.LoadClientTLS("../../../auth/mtls/testdata/client.pem", "../../../auth/mtls/testdata/client.key", rot)
clientCreds, err := mtls.LoadClientTLS("../../../auth/mtls/testdata/client.pem", "../../../auth/mtls/testdata/client.key", rot, "")
if err != nil {
t.Fatal(err)
}
approverCreds, err := mtls.LoadClientTLS("../testdata/approver.pem", "../testdata/approver.key", rot)
approverCreds, err := mtls.LoadClientTLS("../testdata/approver.pem", "../testdata/approver.key", rot, "")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -446,7 +446,7 @@ func TestProxiedClientInterceptors(t *testing.T) {
if err != nil {
t.Fatal(err)
}
proxyClientCreds, err := mtls.LoadClientTLS("../testdata/proxy.pem", "../testdata/proxy.key", rot)
proxyClientCreds, err := mtls.LoadClientTLS("../testdata/proxy.pem", "../testdata/proxy.key", rot, "")
if err != nil {
t.Fatal(err)
}
Expand All @@ -472,11 +472,11 @@ func TestProxiedClientInterceptors(t *testing.T) {
}()
defer proxySrv.GracefulStop()

clientCreds, err := mtls.LoadClientTLS("../../../auth/mtls/testdata/client.pem", "../../../auth/mtls/testdata/client.key", rot)
clientCreds, err := mtls.LoadClientTLS("../../../auth/mtls/testdata/client.pem", "../../../auth/mtls/testdata/client.key", rot, "")
if err != nil {
t.Fatal(err)
}
approverCreds, err := mtls.LoadClientTLS("../testdata/approver.pem", "../testdata/approver.key", rot)
approverCreds, err := mtls.LoadClientTLS("../testdata/approver.pem", "../testdata/approver.key", rot, "")
if err != nil {
t.Fatal(err)
}
Expand Down
Loading