From 95890d8fe5f4cf7b342f63f0ec6f886f2ddeff51 Mon Sep 17 00:00:00 2001 From: Mateusz Jankowski Date: Fri, 27 Feb 2026 11:17:19 +0100 Subject: [PATCH] Add possibility to specify additional credential source in sanssh to select client cert for mTLS at Proxy level --- cmd/proxy-server/server/server.go | 45 +++++- cmd/proxy-server/server/server_test.go | 191 +++++++++++++++++++++++++ cmd/sanssh/client/client.go | 5 + proxy/proxy.pb.go | 97 +++++++------ proxy/proxy.proto | 6 + proxy/proxy/proxy.go | 14 +- proxy/server/server.go | 16 ++- proxy/server/target.go | 24 +++- proxy/server/target_test.go | 29 ++-- 9 files changed, 357 insertions(+), 70 deletions(-) create mode 100644 cmd/proxy-server/server/server_test.go diff --git a/cmd/proxy-server/server/server.go b/cmd/proxy-server/server/server.go index e9424ba8..0357786b 100644 --- a/cmd/proxy-server/server/server.go +++ b/cmd/proxy-server/server/server.go @@ -73,6 +73,7 @@ type runState struct { authzHooks []rpcauth.RPCAuthzHook services []func(*grpc.Server) metricsRecorder metrics.MetricsRecorder + namedCredSources map[string]string // hint name -> mtls loader name } type Option interface { @@ -308,6 +309,20 @@ func WithOtelTracing(interceptorOpts ...otelgrpc.Option) Option { }) } +// WithNamedClientCredSource registers an additional client credential source +// that the proxy can use when a client sends a matching force_credential in +// StartStream. hintName is the value clients will send; credSource is the name +// registered with the mtls package for loading client credentials. +func WithNamedClientCredSource(hintName, credSource string) Option { + return optionFunc(func(_ context.Context, r *runState) error { + if r.namedCredSources == nil { + r.namedCredSources = make(map[string]string) + } + r.namedCredSources[hintName] = credSource + return nil + }) +} + // Run takes the given context and RunState along with any authz hooks and starts up a sansshell proxy server // using the flags above to provide credentials. An address hook (based on the remote host) with always be added. // As this is intended to be called from main() it doesn't return errors and will instead exit on any errors. @@ -383,21 +398,23 @@ func Run(ctx context.Context, opts ...Option) { unaryClient = append(unaryClient, clientAuthz.AuthorizeClient) streamClient = append(streamClient, clientAuthz.AuthorizeClientStream) } - dialOpts := []grpc.DialOption{ - grpc.WithTransportCredentials(clientCreds), + sharedDialOpts := []grpc.DialOption{ grpc.WithChainUnaryInterceptor(unaryClient...), grpc.WithChainStreamInterceptor(streamClient...), // Use 16MB instead of the default 4MB to allow larger responses grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(16 * 1024 * 1024)), } if rs.statsClientHandler != nil { - dialOpts = append(dialOpts, grpc.WithStatsHandler(rs.statsClientHandler)) + sharedDialOpts = append(sharedDialOpts, grpc.WithStatsHandler(rs.statsClientHandler)) } - targetDialer := server.NewDialer(dialOpts...) + defaultDialOpts := append([]grpc.DialOption{grpc.WithTransportCredentials(clientCreds)}, sharedDialOpts...) + targetDialer := server.NewDialer(defaultDialOpts...) + + dialers := buildDialers(ctx, rs, targetDialer, sharedDialOpts) svcMap := server.LoadGlobalServiceMap() rs.logger.Info("loaded service map", "serviceMap", svcMap) - server := server.New(targetDialer, authz) + server := server.NewWithDialersAndServiceMap(dialers, authz, svcMap) // Even though the proxy RPC is streaming we have unary RPCs (logging, reflection) we // also need to properly auth and log. @@ -456,6 +473,24 @@ func Run(ctx context.Context, opts ...Option) { } } +// buildDialers constructs the named dialers map from the default dialer and +// any additional credential sources registered via WithNamedClientCredSource. +// If a named source fails to load, it is logged and skipped; the default +// dialer is always present under key "". +func buildDialers(ctx context.Context, rs *runState, defaultDialer server.TargetDialer, sharedDialOpts []grpc.DialOption) map[string]server.TargetDialer { + dialers := map[string]server.TargetDialer{"": defaultDialer} + for hint, src := range rs.namedCredSources { + creds, err := mtls.LoadClientCredentials(ctx, src) + if err != nil { + rs.logger.Error(err, "failed to load named client cred source, skipping", "hint", hint, "source", src) + continue + } + hintDialOpts := append([]grpc.DialOption{grpc.WithTransportCredentials(creds)}, sharedDialOpts...) + dialers[hint] = server.NewDialer(hintDialOpts...) + } + return dialers +} + // extractClientTransportCredentialsFromRunState extracts transport credentials from runState. Will error if both credSource and tlsConfig are specified func extractClientTransportCredentialsFromRunState(ctx context.Context, rs *runState) (credentials.TransportCredentials, error) { var creds credentials.TransportCredentials diff --git a/cmd/proxy-server/server/server_test.go b/cmd/proxy-server/server/server_test.go new file mode 100644 index 00000000..484f8f92 --- /dev/null +++ b/cmd/proxy-server/server/server_test.go @@ -0,0 +1,191 @@ +/* Copyright (c) 2019 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +package server + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "os" + "testing" + + "github.com/go-logr/logr" + "google.golang.org/grpc" + + "github.com/Snowflake-Labs/sansshell/auth/mtls" + proxyserver "github.com/Snowflake-Labs/sansshell/proxy/server" +) + +const ( + failLoaderName = "test-buildDialers-fail" + okLoaderName = "test-buildDialers-ok" +) + +func TestMain(m *testing.M) { + if err := mtls.Register(failLoaderName, failingLoader{}); err != nil { + fmt.Fprintf(os.Stderr, "mtls.Register(%s): %v\n", failLoaderName, err) + os.Exit(1) + } + if err := mtls.Register(okLoaderName, successLoader{}); err != nil { + fmt.Fprintf(os.Stderr, "mtls.Register(%s): %v\n", okLoaderName, err) + os.Exit(1) + } + os.Exit(m.Run()) +} + +type fakeDialer struct{} + +func (fakeDialer) DialContext(_ context.Context, _ string, _ ...grpc.DialOption) (proxyserver.ClientConnCloser, error) { + return nil, errors.New("not implemented") +} + +func TestWithNamedClientCredSourceSingle(t *testing.T) { + rs := &runState{} + opt := WithNamedClientCredSource("pg", "some-loader") + if err := opt.apply(context.Background(), rs); err != nil { + t.Fatalf("apply: %v", err) + } + if got, ok := rs.namedCredSources["pg"]; !ok || got != "some-loader" { + t.Fatalf("expected namedCredSources[\"pg\"] = \"some-loader\", got %q (ok=%v)", got, ok) + } +} + +func TestWithNamedClientCredSourceMultiple(t *testing.T) { + rs := &runState{} + for _, pair := range []struct{ hint, src string }{ + {"pg", "loader-a"}, + {"redis", "loader-b"}, + } { + if err := WithNamedClientCredSource(pair.hint, pair.src).apply(context.Background(), rs); err != nil { + t.Fatalf("apply(%q): %v", pair.hint, err) + } + } + if len(rs.namedCredSources) != 2 { + t.Fatalf("expected 2 entries, got %d", len(rs.namedCredSources)) + } + if rs.namedCredSources["pg"] != "loader-a" { + t.Fatalf("pg: got %q", rs.namedCredSources["pg"]) + } + if rs.namedCredSources["redis"] != "loader-b" { + t.Fatalf("redis: got %q", rs.namedCredSources["redis"]) + } +} + +func TestWithNamedClientCredSourceOverwrite(t *testing.T) { + rs := &runState{} + if err := WithNamedClientCredSource("pg", "old").apply(context.Background(), rs); err != nil { + t.Fatal(err) + } + if err := WithNamedClientCredSource("pg", "new").apply(context.Background(), rs); err != nil { + t.Fatal(err) + } + if rs.namedCredSources["pg"] != "new" { + t.Fatalf("expected overwrite to \"new\", got %q", rs.namedCredSources["pg"]) + } +} + +// --- buildDialers tests --- + +type failingLoader struct{} + +func (failingLoader) LoadClientCA(context.Context) (*x509.CertPool, error) { + return nil, errors.New("no CA") +} +func (failingLoader) LoadRootCA(context.Context) (*x509.CertPool, error) { + return nil, errors.New("no root CA") +} +func (failingLoader) LoadClientCertificate(context.Context) (tls.Certificate, error) { + return tls.Certificate{}, errors.New("no client cert") +} +func (failingLoader) LoadServerCertificate(context.Context) (tls.Certificate, error) { + return tls.Certificate{}, errors.New("no server cert") +} +func (failingLoader) CertsRefreshed() bool { return false } +func (failingLoader) GetClientCertInfo(context.Context, string) (*mtls.ClientCertInfo, error) { + return nil, nil +} + +type successLoader struct{} + +func (successLoader) LoadClientCA(context.Context) (*x509.CertPool, error) { + return x509.NewCertPool(), nil +} +func (successLoader) LoadRootCA(context.Context) (*x509.CertPool, error) { + return x509.NewCertPool(), nil +} +func (successLoader) LoadClientCertificate(context.Context) (tls.Certificate, error) { + return tls.Certificate{}, nil +} +func (successLoader) LoadServerCertificate(context.Context) (tls.Certificate, error) { + return tls.Certificate{}, nil +} +func (successLoader) CertsRefreshed() bool { return false } +func (successLoader) GetClientCertInfo(context.Context, string) (*mtls.ClientCertInfo, error) { + return nil, nil +} + +func TestBuildDialersDefaultOnly(t *testing.T) { + rs := &runState{logger: logr.Discard()} + dialers := buildDialers(context.Background(), rs, fakeDialer{}, nil) + if _, ok := dialers[""]; !ok { + t.Fatal("expected default dialer under key \"\"") + } + if len(dialers) != 1 { + t.Fatalf("expected 1 dialer, got %d", len(dialers)) + } +} + +func TestBuildDialersSkipsFailedCredSources(t *testing.T) { + rs := &runState{ + logger: logr.Discard(), + namedCredSources: map[string]string{"bad-hint": failLoaderName}, + } + dialers := buildDialers(context.Background(), rs, fakeDialer{}, nil) + + if _, ok := dialers[""]; !ok { + t.Fatal("expected default dialer under key \"\"") + } + if _, ok := dialers["bad-hint"]; ok { + t.Fatal("expected failed hint to be absent from dialers map") + } + if len(dialers) != 1 { + t.Fatalf("expected 1 dialer (default only), got %d", len(dialers)) + } +} + +func TestBuildDialersRegistersSuccessfulHint(t *testing.T) { + rs := &runState{ + logger: logr.Discard(), + namedCredSources: map[string]string{"good-hint": okLoaderName}, + } + shared := []grpc.DialOption{ + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(4 * 1024 * 1024)), + } + dialers := buildDialers(context.Background(), rs, fakeDialer{}, shared) + + if _, ok := dialers[""]; !ok { + t.Fatal("expected default dialer under key \"\"") + } + if _, ok := dialers["good-hint"]; !ok { + t.Fatal("expected successful hint to be present in dialers map") + } + if len(dialers) != 2 { + t.Fatalf("expected 2 dialers, got %d", len(dialers)) + } +} diff --git a/cmd/sanssh/client/client.go b/cmd/sanssh/client/client.go index 3b49b4a9..2421d188 100644 --- a/cmd/sanssh/client/client.go +++ b/cmd/sanssh/client/client.go @@ -82,6 +82,10 @@ type RunState struct { EnableMPA bool // If true, the command is authz dry run and real action should not be executed AuthzDryRun bool + // ForceCredential is passed to the proxy to force a specific client + // credential when dialing targets. The proxy will fail with an error if + // the requested credential is not configured. Empty means default. + ForceCredential string // Interspectors for unary calls to the connection to the proxy ClientUnaryInterceptors []proxy.UnaryInterceptor @@ -376,6 +380,7 @@ func Run(ctx context.Context, rs RunState) { } conn.AuthzDryRun = rs.AuthzDryRun + conn.ForceCredential = rs.ForceCredential if rs.EnableMPA { conn.UnaryInterceptors = []proxy.UnaryInterceptor{mpahooks.ProxyClientUnaryInterceptor(state)} diff --git a/proxy/proxy.pb.go b/proxy/proxy.pb.go index 1dec35dc..90b546ae 100644 --- a/proxy/proxy.pb.go +++ b/proxy/proxy.pb.go @@ -276,6 +276,11 @@ type StartStream struct { DialTimeout *durationpb.Duration `protobuf:"bytes,4,opt,name=dial_timeout,json=dialTimeout,proto3" json:"dial_timeout,omitempty"` // Perform authz dry run instead actual execution. AuthzDryRun bool `protobuf:"varint,5,opt,name=authz_dry_run,json=authzDryRun,proto3" json:"authz_dry_run,omitempty"` + // Optional. Forces the proxy to use a specific client credential when + // dialing this target. If empty or unset, the proxy uses its default + // credential. The proxy will reject unrecognized non-empty values with + // InvalidArgument rather than falling back to the default. + ForceCredential string `protobuf:"bytes,6,opt,name=force_credential,json=forceCredential,proto3" json:"force_credential,omitempty"` } func (x *StartStream) Reset() { @@ -345,6 +350,13 @@ func (x *StartStream) GetAuthzDryRun() bool { return false } +func (x *StartStream) GetForceCredential() string { + if x != nil { + return x.ForceCredential + } + return "" +} + type StartStreamReply struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -778,7 +790,7 @@ var file_proxy_proto_rawDesc = []byte{ 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x48, 0x00, 0x52, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x42, 0x07, 0x0a, 0x05, 0x72, 0x65, 0x70, 0x6c, 0x79, 0x22, - 0xbe, 0x01, 0x0a, 0x0b, 0x53, 0x74, 0x61, 0x72, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, + 0xe9, 0x01, 0x0a, 0x0b, 0x53, 0x74, 0x61, 0x72, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x16, 0x0a, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x12, 0x1f, 0x0a, 0x0b, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6d, 0x65, @@ -790,47 +802,50 @@ var file_proxy_proto_rawDesc = []byte{ 0x0b, 0x64, 0x69, 0x61, 0x6c, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x22, 0x0a, 0x0d, 0x61, 0x75, 0x74, 0x68, 0x7a, 0x5f, 0x64, 0x72, 0x79, 0x5f, 0x72, 0x75, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x61, 0x75, 0x74, 0x68, 0x7a, 0x44, 0x72, 0x79, 0x52, 0x75, 0x6e, - 0x22, 0x9c, 0x01, 0x0a, 0x10, 0x53, 0x74, 0x61, 0x72, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, - 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x16, 0x0a, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x12, 0x14, 0x0a, - 0x05, 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x6e, 0x6f, - 0x6e, 0x63, 0x65, 0x12, 0x1d, 0x0a, 0x09, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x04, 0x48, 0x00, 0x52, 0x08, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, - 0x49, 0x64, 0x12, 0x32, 0x0a, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x73, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, - 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x48, 0x00, 0x52, 0x0b, 0x65, 0x72, 0x72, 0x6f, 0x72, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x42, 0x07, 0x0a, 0x05, 0x72, 0x65, 0x70, 0x6c, 0x79, 0x22, - 0x2c, 0x0a, 0x0b, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x12, 0x1d, - 0x0a, 0x0a, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, - 0x28, 0x04, 0x52, 0x09, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x49, 0x64, 0x73, 0x22, 0x2d, 0x0a, - 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x12, 0x1d, 0x0a, - 0x0a, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x04, 0x52, 0x09, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x49, 0x64, 0x73, 0x22, 0x5b, 0x0a, 0x0a, - 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, - 0x72, 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x04, 0x52, 0x09, - 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x49, 0x64, 0x73, 0x12, 0x2e, 0x0a, 0x07, 0x70, 0x61, 0x79, - 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, - 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, - 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x65, 0x72, - 0x76, 0x65, 0x72, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x72, 0x65, + 0x12, 0x29, 0x0a, 0x10, 0x66, 0x6f, 0x72, 0x63, 0x65, 0x5f, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, + 0x74, 0x69, 0x61, 0x6c, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x63, + 0x65, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x22, 0x9c, 0x01, 0x0a, 0x10, + 0x53, 0x74, 0x61, 0x72, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, 0x70, 0x6c, 0x79, + 0x12, 0x16, 0x0a, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x6f, 0x6e, 0x63, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x12, 0x1d, + 0x0a, 0x09, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x04, 0x48, 0x00, 0x52, 0x08, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x49, 0x64, 0x12, 0x32, 0x0a, + 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x48, 0x00, 0x52, 0x0b, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x42, 0x07, 0x0a, 0x05, 0x72, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x2c, 0x0a, 0x0b, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x72, + 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x04, 0x52, 0x09, 0x73, + 0x74, 0x72, 0x65, 0x61, 0x6d, 0x49, 0x64, 0x73, 0x22, 0x2d, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x04, 0x52, 0x09, 0x73, 0x74, - 0x72, 0x65, 0x61, 0x6d, 0x49, 0x64, 0x73, 0x12, 0x25, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x66, - 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x07, - 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x2e, 0x0a, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, - 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x07, 0x64, - 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x32, 0x3e, 0x0a, 0x05, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x12, - 0x35, 0x0a, 0x05, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x12, 0x13, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, - 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x11, 0x2e, - 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x52, 0x65, 0x70, 0x6c, 0x79, - 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x2b, 0x5a, 0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, - 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x53, 0x6e, 0x6f, 0x77, 0x66, 0x6c, 0x61, 0x6b, 0x65, 0x2d, 0x4c, - 0x61, 0x62, 0x73, 0x2f, 0x73, 0x61, 0x6e, 0x73, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x2f, 0x70, 0x72, - 0x6f, 0x78, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x72, 0x65, 0x61, 0x6d, 0x49, 0x64, 0x73, 0x22, 0x5b, 0x0a, 0x0a, 0x53, 0x74, 0x72, 0x65, 0x61, + 0x6d, 0x44, 0x61, 0x74, 0x61, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, + 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x04, 0x52, 0x09, 0x73, 0x74, 0x72, 0x65, 0x61, + 0x6d, 0x49, 0x64, 0x73, 0x12, 0x2e, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x07, 0x70, 0x61, 0x79, + 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6c, + 0x6f, 0x73, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x04, 0x52, 0x09, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x49, + 0x64, 0x73, 0x12, 0x25, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x66, 0x0a, 0x06, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x12, 0x2e, 0x0a, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x03, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, + 0x73, 0x32, 0x3e, 0x0a, 0x05, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x12, 0x35, 0x0a, 0x05, 0x50, 0x72, + 0x6f, 0x78, 0x79, 0x12, 0x13, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x50, 0x72, 0x6f, 0x78, + 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x11, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, + 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x28, 0x01, 0x30, + 0x01, 0x42, 0x2b, 0x5a, 0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, + 0x53, 0x6e, 0x6f, 0x77, 0x66, 0x6c, 0x61, 0x6b, 0x65, 0x2d, 0x4c, 0x61, 0x62, 0x73, 0x2f, 0x73, + 0x61, 0x6e, 0x73, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/proxy/proxy.proto b/proxy/proxy.proto index 470ba0b3..6189c7b7 100644 --- a/proxy/proxy.proto +++ b/proxy/proxy.proto @@ -87,6 +87,12 @@ message StartStream { // Perform authz dry run instead actual execution. bool authz_dry_run = 5; + + // Optional. Forces the proxy to use a specific client credential when + // dialing this target. If empty or unset, the proxy uses its default + // credential. The proxy will reject unrecognized non-empty values with + // InvalidArgument rather than falling back to the default. + string force_credential = 6; } message StartStreamReply { diff --git a/proxy/proxy/proxy.go b/proxy/proxy/proxy.go index 57c69141..9d141265 100644 --- a/proxy/proxy/proxy.go +++ b/proxy/proxy/proxy.go @@ -79,6 +79,11 @@ type Conn struct { // Perform authz dry run instead of actual execution AuthzDryRun bool + // ForceCredential is passed in each StartStream to tell the proxy which + // client credential to use when dialing the target. The proxy will fail + // if the requested credential is not configured. Empty means default. + ForceCredential string + // UnaryInterceptors allow intercepting Invoke and InvokeOneMany calls // that go through a proxy. // It is unsafe to modify Intercepters while calls are in progress. @@ -455,10 +460,11 @@ func (p *Conn) createStreams(ctx context.Context, method string) (proxypb.Proxy_ req := &proxypb.ProxyRequest{ Request: &proxypb.ProxyRequest_StartStream{ StartStream: &proxypb.StartStream{ - Target: t, - MethodName: method, - Nonce: uint32(i), - AuthzDryRun: p.AuthzDryRun, + Target: t, + MethodName: method, + Nonce: uint32(i), + AuthzDryRun: p.AuthzDryRun, + ForceCredential: p.ForceCredential, }, }, } diff --git a/proxy/server/server.go b/proxy/server/server.go index f9346648..296c562c 100644 --- a/proxy/server/server.go +++ b/proxy/server/server.go @@ -77,8 +77,9 @@ type Server struct { // A map of /Package.Service/Method => ServiceMethod serviceMap map[string]*ServiceMethod - // A dialer for making proxy -> target connections - dialer TargetDialer + // Named dialers for making proxy -> target connections. + // Key "" is the default dialer used when no force_credential is specified. + dialers map[string]TargetDialer // A policy authorizer, for authorizing proxy -> target requests authorizer rpcauth.RPCAuthorizer @@ -104,9 +105,16 @@ func New(dialer TargetDialer, authorizer rpcauth.RPCAuthorizer) *Server { // The supplied authorizer is used to authorize requests made // to targets. func NewWithServiceMap(dialer TargetDialer, authorizer rpcauth.RPCAuthorizer, serviceMap map[string]*ServiceMethod) *Server { + return NewWithDialersAndServiceMap(map[string]TargetDialer{"": dialer}, authorizer, serviceMap) +} + +// NewWithDialers creates a new Server with named dialers for credential-hint-based +// dialer selection and the global service map. The dialers map must contain a "" +// key for the default dialer. +func NewWithDialersAndServiceMap(dialers map[string]TargetDialer, authorizer rpcauth.RPCAuthorizer, serviceMap map[string]*ServiceMethod) *Server { return &Server{ serviceMap: serviceMap, - dialer: dialer, + dialers: dialers, authorizer: authorizer, } } @@ -122,7 +130,7 @@ func (s *Server) Proxy(stream pb.Proxy_ProxyServer) error { // create a new TargetStreamSet to manage the target streams // associated with this proxy connection - streamSet := NewTargetStreamSet(s.serviceMap, s.dialer, s.authorizer) + streamSet := NewTargetStreamSet(s.serviceMap, s.dialers, s.authorizer) // A single go-routine for handling all sends to the reply // channel diff --git a/proxy/server/target.go b/proxy/server/target.go index 5a25f798..7823df07 100644 --- a/proxy/server/target.go +++ b/proxy/server/target.go @@ -401,8 +401,10 @@ type TargetStreamSet struct { // A service method map used to resolve incoming stream requests to service methods serviceMethods map[string]*ServiceMethod - // A TargetDialer for initiating target connections - targetDialer TargetDialer + // Named dialers for initiating target connections. + // Key "" is the default dialer used when no force_credential is specified. + // Non-empty values that are not present in this map cause an error. + dialers map[string]TargetDialer // [rpcauthz.rpcAuthorizerImpl], for authorizing requests sent to targets. authorizer rpcauth.RPCAuthorizer @@ -421,11 +423,12 @@ type TargetStreamSet struct { noncePairs map[string]bool } -// NewTargetStreamSet creates a TargetStreamSet which manages a set of related TargetStreams -func NewTargetStreamSet(serviceMethods map[string]*ServiceMethod, dialer TargetDialer, authorizer rpcauth.RPCAuthorizer) *TargetStreamSet { +// NewTargetStreamSet creates a TargetStreamSet which manages a set of related TargetStreams. +// The dialers map must contain a "" key for the default dialer. +func NewTargetStreamSet(serviceMethods map[string]*ServiceMethod, dialers map[string]TargetDialer, authorizer rpcauth.RPCAuthorizer) *TargetStreamSet { return &TargetStreamSet{ serviceMethods: serviceMethods, - targetDialer: dialer, + dialers: dialers, authorizer: authorizer, streams: make(map[uint64]*TargetStream), closedStreams: make(map[uint64]bool), @@ -481,12 +484,21 @@ func (t *TargetStreamSet) Add(ctx context.Context, req *pb.StartStream, replyCha sendReply(reply) return nil } + hint := req.GetForceCredential() + dialer, ok := t.dialers[hint] + if !ok { + reply.GetStartStreamReply().Reply = &pb.StartStreamReply_ErrorStatus{ + ErrorStatus: convertStatus(status.Newf(codes.InvalidArgument, "unknown credential %q: not configured on this proxy", hint)), + } + sendReply(reply) + return nil + } var dialTimeout *time.Duration if req.DialTimeout != nil { d := req.DialTimeout.AsDuration() dialTimeout = &d } - stream, err := NewTargetStream(ctx, req.GetTarget(), t.targetDialer, dialTimeout, serviceMethod, t.authorizer, req.GetAuthzDryRun()) + stream, err := NewTargetStream(ctx, req.GetTarget(), dialer, dialTimeout, serviceMethod, t.authorizer, req.GetAuthzDryRun()) if err != nil { reply.GetStartStreamReply().Reply = &pb.StartStreamReply_ErrorStatus{ ErrorStatus: convertStatus(status.New(codes.Internal, err.Error())), diff --git a/proxy/server/target_test.go b/proxy/server/target_test.go index a5f0eb1e..fe981124 100644 --- a/proxy/server/target_test.go +++ b/proxy/server/target_test.go @@ -42,7 +42,7 @@ func (e dialErrTargetDialer) DialContext(ctx context.Context, target string, dia func TestEmptyStreamSet(t *testing.T) { ctx := context.Background() errDialer := dialErrTargetDialer(codes.Unimplemented) - ss := NewTargetStreamSet(map[string]*ServiceMethod{}, errDialer, nil) + ss := NewTargetStreamSet(map[string]*ServiceMethod{}, map[string]TargetDialer{"": errDialer}, nil) // wait does not block when no work is being done finishedWait := make(chan struct{}) @@ -81,13 +81,14 @@ func TestEmptyStreamSet(t *testing.T) { func TestStreamSetAddErrors(t *testing.T) { errDialer := dialErrTargetDialer(codes.Unimplemented) serviceMap := LoadGlobalServiceMap() - ss := NewTargetStreamSet(serviceMap, errDialer, nil) + ss := NewTargetStreamSet(serviceMap, map[string]TargetDialer{"": errDialer}, nil) for _, tc := range []struct { - name string - method string - nonce uint32 - errCode codes.Code + name string + method string + nonce uint32 + forceCredential string + errCode codes.Code }{ { name: "dial failure no error", @@ -100,6 +101,13 @@ func TestStreamSetAddErrors(t *testing.T) { method: "/Nosuch.Method/Foo", errCode: codes.InvalidArgument, }, + { + name: "unknown credential hint", + nonce: 3, + method: "/Testdata.TestService/TestUnary", + forceCredential: "nonexistent", + errCode: codes.InvalidArgument, + }, } { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -107,9 +115,10 @@ func TestStreamSetAddErrors(t *testing.T) { replyChan := make(chan *pb.ProxyReply, 1) req := &pb.StartStream{ - Target: "nosuchhost:000", - Nonce: tc.nonce, - MethodName: tc.method, + Target: "nosuchhost:000", + Nonce: tc.nonce, + MethodName: tc.method, + ForceCredential: tc.forceCredential, } err := ss.Add(context.Background(), req, replyChan, nil /*doneChan should not be called*/) testutil.FatalOnErr(fmt.Sprintf("StartStream(+%v)", req), err, t) @@ -166,7 +175,7 @@ func TestTargetStreamAddNonBlocking(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() serviceMap := LoadGlobalServiceMap() - ss := NewTargetStreamSet(serviceMap, blockingClientDialer{}, nil) + ss := NewTargetStreamSet(serviceMap, map[string]TargetDialer{"": blockingClientDialer{}}, nil) replyChan := make(chan *pb.ProxyReply, 1) doneChan := make(chan struct{}) req := &pb.StartStream{