Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 40 additions & 5 deletions cmd/proxy-server/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ type runState struct {
authzHooks []rpcauth.RPCAuthzHook
services []func(*grpc.Server)
metricsRecorder metrics.MetricsRecorder
namedCredSources map[string]string // hint name -> mtls loader name
}

type Option interface {
Expand Down Expand Up @@ -308,6 +309,20 @@ func WithOtelTracing(interceptorOpts ...otelgrpc.Option) Option {
})
}

// WithNamedClientCredSource registers an additional client credential source
// that the proxy can use when a client sends a matching force_credential in
// StartStream. hintName is the value clients will send; credSource is the name
// registered with the mtls package for loading client credentials.
func WithNamedClientCredSource(hintName, credSource string) Option {
return optionFunc(func(_ context.Context, r *runState) error {
if r.namedCredSources == nil {
r.namedCredSources = make(map[string]string)
}
r.namedCredSources[hintName] = credSource
return nil
})
}

// Run takes the given context and RunState along with any authz hooks and starts up a sansshell proxy server
// using the flags above to provide credentials. An address hook (based on the remote host) with always be added.
// As this is intended to be called from main() it doesn't return errors and will instead exit on any errors.
Expand Down Expand Up @@ -383,21 +398,23 @@ func Run(ctx context.Context, opts ...Option) {
unaryClient = append(unaryClient, clientAuthz.AuthorizeClient)
streamClient = append(streamClient, clientAuthz.AuthorizeClientStream)
}
dialOpts := []grpc.DialOption{
grpc.WithTransportCredentials(clientCreds),
sharedDialOpts := []grpc.DialOption{
grpc.WithChainUnaryInterceptor(unaryClient...),
grpc.WithChainStreamInterceptor(streamClient...),
// Use 16MB instead of the default 4MB to allow larger responses
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(16 * 1024 * 1024)),
}
if rs.statsClientHandler != nil {
dialOpts = append(dialOpts, grpc.WithStatsHandler(rs.statsClientHandler))
sharedDialOpts = append(sharedDialOpts, grpc.WithStatsHandler(rs.statsClientHandler))
}
targetDialer := server.NewDialer(dialOpts...)
defaultDialOpts := append([]grpc.DialOption{grpc.WithTransportCredentials(clientCreds)}, sharedDialOpts...)
targetDialer := server.NewDialer(defaultDialOpts...)

dialers := buildDialers(ctx, rs, targetDialer, sharedDialOpts)

svcMap := server.LoadGlobalServiceMap()
rs.logger.Info("loaded service map", "serviceMap", svcMap)
server := server.New(targetDialer, authz)
server := server.NewWithDialersAndServiceMap(dialers, authz, svcMap)

// Even though the proxy RPC is streaming we have unary RPCs (logging, reflection) we
// also need to properly auth and log.
Expand Down Expand Up @@ -456,6 +473,24 @@ func Run(ctx context.Context, opts ...Option) {
}
}

// buildDialers constructs the named dialers map from the default dialer and
// any additional credential sources registered via WithNamedClientCredSource.
// If a named source fails to load, it is logged and skipped; the default
// dialer is always present under key "".
func buildDialers(ctx context.Context, rs *runState, defaultDialer server.TargetDialer, sharedDialOpts []grpc.DialOption) map[string]server.TargetDialer {
dialers := map[string]server.TargetDialer{"": defaultDialer}
for hint, src := range rs.namedCredSources {
creds, err := mtls.LoadClientCredentials(ctx, src)
if err != nil {
rs.logger.Error(err, "failed to load named client cred source, skipping", "hint", hint, "source", src)
continue
}
hintDialOpts := append([]grpc.DialOption{grpc.WithTransportCredentials(creds)}, sharedDialOpts...)
dialers[hint] = server.NewDialer(hintDialOpts...)
}
return dialers
}

// extractClientTransportCredentialsFromRunState extracts transport credentials from runState. Will error if both credSource and tlsConfig are specified
func extractClientTransportCredentialsFromRunState(ctx context.Context, rs *runState) (credentials.TransportCredentials, error) {
var creds credentials.TransportCredentials
Expand Down
191 changes: 191 additions & 0 deletions cmd/proxy-server/server/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/* Copyright (c) 2019 Snowflake Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
*/

package server

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"os"
"testing"

"github.com/go-logr/logr"
"google.golang.org/grpc"

"github.com/Snowflake-Labs/sansshell/auth/mtls"
proxyserver "github.com/Snowflake-Labs/sansshell/proxy/server"
)

const (
failLoaderName = "test-buildDialers-fail"
okLoaderName = "test-buildDialers-ok"
)

func TestMain(m *testing.M) {
if err := mtls.Register(failLoaderName, failingLoader{}); err != nil {
fmt.Fprintf(os.Stderr, "mtls.Register(%s): %v\n", failLoaderName, err)
os.Exit(1)
}
if err := mtls.Register(okLoaderName, successLoader{}); err != nil {
fmt.Fprintf(os.Stderr, "mtls.Register(%s): %v\n", okLoaderName, err)
os.Exit(1)
}
os.Exit(m.Run())
}

type fakeDialer struct{}

func (fakeDialer) DialContext(_ context.Context, _ string, _ ...grpc.DialOption) (proxyserver.ClientConnCloser, error) {
return nil, errors.New("not implemented")
}

func TestWithNamedClientCredSourceSingle(t *testing.T) {
rs := &runState{}
opt := WithNamedClientCredSource("pg", "some-loader")
if err := opt.apply(context.Background(), rs); err != nil {
t.Fatalf("apply: %v", err)
}
if got, ok := rs.namedCredSources["pg"]; !ok || got != "some-loader" {
t.Fatalf("expected namedCredSources[\"pg\"] = \"some-loader\", got %q (ok=%v)", got, ok)
}
}

func TestWithNamedClientCredSourceMultiple(t *testing.T) {
rs := &runState{}
for _, pair := range []struct{ hint, src string }{
{"pg", "loader-a"},
{"redis", "loader-b"},
} {
if err := WithNamedClientCredSource(pair.hint, pair.src).apply(context.Background(), rs); err != nil {
t.Fatalf("apply(%q): %v", pair.hint, err)
}
}
if len(rs.namedCredSources) != 2 {
t.Fatalf("expected 2 entries, got %d", len(rs.namedCredSources))
}
if rs.namedCredSources["pg"] != "loader-a" {
t.Fatalf("pg: got %q", rs.namedCredSources["pg"])
}
if rs.namedCredSources["redis"] != "loader-b" {
t.Fatalf("redis: got %q", rs.namedCredSources["redis"])
}
}

func TestWithNamedClientCredSourceOverwrite(t *testing.T) {
rs := &runState{}
if err := WithNamedClientCredSource("pg", "old").apply(context.Background(), rs); err != nil {
t.Fatal(err)
}
if err := WithNamedClientCredSource("pg", "new").apply(context.Background(), rs); err != nil {
t.Fatal(err)
}
if rs.namedCredSources["pg"] != "new" {
t.Fatalf("expected overwrite to \"new\", got %q", rs.namedCredSources["pg"])
}
}

// --- buildDialers tests ---

type failingLoader struct{}

func (failingLoader) LoadClientCA(context.Context) (*x509.CertPool, error) {
return nil, errors.New("no CA")
}
func (failingLoader) LoadRootCA(context.Context) (*x509.CertPool, error) {
return nil, errors.New("no root CA")
}
func (failingLoader) LoadClientCertificate(context.Context) (tls.Certificate, error) {
return tls.Certificate{}, errors.New("no client cert")
}
func (failingLoader) LoadServerCertificate(context.Context) (tls.Certificate, error) {
return tls.Certificate{}, errors.New("no server cert")
}
func (failingLoader) CertsRefreshed() bool { return false }
func (failingLoader) GetClientCertInfo(context.Context, string) (*mtls.ClientCertInfo, error) {
return nil, nil
}

type successLoader struct{}

func (successLoader) LoadClientCA(context.Context) (*x509.CertPool, error) {
return x509.NewCertPool(), nil
}
func (successLoader) LoadRootCA(context.Context) (*x509.CertPool, error) {
return x509.NewCertPool(), nil
}
func (successLoader) LoadClientCertificate(context.Context) (tls.Certificate, error) {
return tls.Certificate{}, nil
}
func (successLoader) LoadServerCertificate(context.Context) (tls.Certificate, error) {
return tls.Certificate{}, nil
}
func (successLoader) CertsRefreshed() bool { return false }
func (successLoader) GetClientCertInfo(context.Context, string) (*mtls.ClientCertInfo, error) {
return nil, nil
}

func TestBuildDialersDefaultOnly(t *testing.T) {
rs := &runState{logger: logr.Discard()}
dialers := buildDialers(context.Background(), rs, fakeDialer{}, nil)
if _, ok := dialers[""]; !ok {
t.Fatal("expected default dialer under key \"\"")
}
if len(dialers) != 1 {
t.Fatalf("expected 1 dialer, got %d", len(dialers))
}
}

func TestBuildDialersSkipsFailedCredSources(t *testing.T) {
rs := &runState{
logger: logr.Discard(),
namedCredSources: map[string]string{"bad-hint": failLoaderName},
}
dialers := buildDialers(context.Background(), rs, fakeDialer{}, nil)

if _, ok := dialers[""]; !ok {
t.Fatal("expected default dialer under key \"\"")
}
if _, ok := dialers["bad-hint"]; ok {
t.Fatal("expected failed hint to be absent from dialers map")
}
if len(dialers) != 1 {
t.Fatalf("expected 1 dialer (default only), got %d", len(dialers))
}
}

func TestBuildDialersRegistersSuccessfulHint(t *testing.T) {
rs := &runState{
logger: logr.Discard(),
namedCredSources: map[string]string{"good-hint": okLoaderName},
}
shared := []grpc.DialOption{
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(4 * 1024 * 1024)),
}
dialers := buildDialers(context.Background(), rs, fakeDialer{}, shared)

if _, ok := dialers[""]; !ok {
t.Fatal("expected default dialer under key \"\"")
}
if _, ok := dialers["good-hint"]; !ok {
t.Fatal("expected successful hint to be present in dialers map")
}
if len(dialers) != 2 {
t.Fatalf("expected 2 dialers, got %d", len(dialers))
}
}
5 changes: 5 additions & 0 deletions cmd/sanssh/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ type RunState struct {
EnableMPA bool
// If true, the command is authz dry run and real action should not be executed
AuthzDryRun bool
// ForceCredential is passed to the proxy to force a specific client
// credential when dialing targets. The proxy will fail with an error if
// the requested credential is not configured. Empty means default.
ForceCredential string

// Interspectors for unary calls to the connection to the proxy
ClientUnaryInterceptors []proxy.UnaryInterceptor
Expand Down Expand Up @@ -376,6 +380,7 @@ func Run(ctx context.Context, rs RunState) {
}

conn.AuthzDryRun = rs.AuthzDryRun
conn.ForceCredential = rs.ForceCredential

if rs.EnableMPA {
conn.UnaryInterceptors = []proxy.UnaryInterceptor{mpahooks.ProxyClientUnaryInterceptor(state)}
Expand Down
Loading
Loading