From 9c46be8c7b379ec7e9ff13252102f0e02300c128 Mon Sep 17 00:00:00 2001 From: Mateusz Jankowski Date: Tue, 3 Mar 2026 17:46:40 +0100 Subject: [PATCH] Implement multi identity caller that looks at the server cert while selecting the client one --- auth/mtls/client.go | 19 +++ auth/mtls/multi_identity.go | 209 ++++++++++++++++++++++++++++++ cmd/proxy-server/server/server.go | 55 ++++---- 3 files changed, 261 insertions(+), 22 deletions(-) create mode 100644 auth/mtls/multi_identity.go diff --git a/auth/mtls/client.go b/auth/mtls/client.go index e5d5aade..f75d67f1 100644 --- a/auth/mtls/client.go +++ b/auth/mtls/client.go @@ -80,6 +80,25 @@ func NewClientCredentials(cert tls.Certificate, CAPool *x509.CertPool) credentia }) } +// LoadClientIdentity loads a client certificate and root CA pool from the +// named CredentialsLoader and returns them as an Identity suitable for use +// with NewMultiIdentityCredentials. +func LoadClientIdentity(ctx context.Context, loaderName string) (Identity, error) { + loader, err := Loader(loaderName) + if err != nil { + return Identity{}, err + } + pool, err := loader.LoadRootCA(ctx) + if err != nil { + return Identity{}, err + } + cert, err := loader.LoadClientCertificate(ctx) + if err != nil { + return Identity{}, err + } + return Identity{Name: loaderName, Cert: cert, RootCAs: pool}, nil +} + // LoadClientTLS reads the certificates and keys from disk at the supplied paths, // and assembles them into a set of TransportCredentials for the gRPC client. func LoadClientTLS(clientCertFile, clientKeyFile string, CAPool *x509.CertPool) (credentials.TransportCredentials, error) { diff --git a/auth/mtls/multi_identity.go b/auth/mtls/multi_identity.go new file mode 100644 index 00000000..1fdcc0f7 --- /dev/null +++ b/auth/mtls/multi_identity.go @@ -0,0 +1,209 @@ +/* Copyright (c) 2019 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +package mtls + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "net" + "sync" + + "google.golang.org/grpc/credentials" +) + +// Identity represents a single client identity (certificate + trusted CA pool) +// that can be presented during a TLS handshake. +type Identity struct { + Name string + Cert tls.Certificate + RootCAs *x509.CertPool +} + +// MultiIdentityCredentials implements credentials.TransportCredentials by +// dynamically selecting a client certificate during the TLS handshake based +// on the server's CertificateRequest.AcceptableCAs. This allows a single +// gRPC dial to present the correct identity without retry or fallback. +// +// All identities are assumed to share a RootCAs pool that already contains +// every CA needed to verify any target server. The pools from all identities +// are merged at construction time and used for standard Go TLS server +// verification. +type MultiIdentityCredentials struct { + mu sync.RWMutex + identities []Identity + loaderNames []string + rootCAs *x509.CertPool + serverName string +} + +// NewMultiIdentityCredentials creates a TransportCredentials that dynamically +// selects among the provided identities during each TLS handshake. +// At least one identity must be provided. When only one identity is given, +// behavior is equivalent to standard single-identity TLS credentials. +func NewMultiIdentityCredentials(identities []Identity) (*MultiIdentityCredentials, error) { + if len(identities) == 0 { + return nil, errors.New("at least one identity is required") + } + names := make([]string, len(identities)) + for i := range identities { + names[i] = identities[i].Name + } + m := &MultiIdentityCredentials{ + identities: make([]Identity, len(identities)), + loaderNames: names, + rootCAs: identities[0].RootCAs, + } + copy(m.identities, identities) + return m, nil +} + +func (m *MultiIdentityCredentials) refreshIfNeeded(ctx context.Context) { + for i, name := range m.loaderNames { + loader, err := Loader(name) + if err != nil { + continue + } + if !loader.CertsRefreshed() { + continue + } + id, err := LoadClientIdentity(ctx, name) + if err != nil { + continue + } + m.mu.Lock() + m.identities[i] = id + m.rootCAs = id.RootCAs + m.mu.Unlock() + } +} + +// ClientHandshake performs the client-side TLS handshake. It uses +// GetClientCertificate to pick the right client certificate based on the +// server's CertificateRequest.AcceptableCAs. Server certificate verification +// uses the merged RootCAs pool via standard Go TLS verification. +func (m *MultiIdentityCredentials) ClientHandshake(ctx context.Context, serverName string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + m.refreshIfNeeded(ctx) + + m.mu.RLock() + overrideName := m.serverName + identities := m.identities + rootCAs := m.rootCAs + m.mu.RUnlock() + + if overrideName != "" { + serverName = overrideName + } + + tlsCfg := &tls.Config{ + ServerName: serverName, + RootCAs: rootCAs, + MinVersion: tls.VersionTLS12, + GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { + idx := selectIdentity(cri, identities) + cert := identities[idx].Cert + return &cert, nil + }, + } + + conn := tls.Client(rawConn, tlsCfg) + if err := conn.HandshakeContext(ctx); err != nil { + conn.Close() + return nil, nil, err + } + + info := credentials.TLSInfo{ + State: conn.ConnectionState(), + CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}, + } + return conn, info, nil +} + +// selectIdentity picks the first identity whose issuing CA appears in the +// server's AcceptableCAs list. Falls back to index 0 if the server sends +// an empty list or no match is found. +func selectIdentity(cri *tls.CertificateRequestInfo, identities []Identity) int { + if len(cri.AcceptableCAs) == 0 || len(identities) <= 1 { + return 0 + } + acceptableSet := make(map[string]struct{}, len(cri.AcceptableCAs)) + for _, ca := range cri.AcceptableCAs { + acceptableSet[string(ca)] = struct{}{} + } + for i := range identities { + issuer := leafIssuer(&identities[i]) + if issuer == nil { + continue + } + if _, ok := acceptableSet[string(issuer)]; ok { + return i + } + } + return 0 +} + +func leafIssuer(id *Identity) []byte { + if id.Cert.Leaf != nil { + return id.Cert.Leaf.RawIssuer + } + if len(id.Cert.Certificate) > 0 { + leaf, err := x509.ParseCertificate(id.Cert.Certificate[0]) + if err == nil { + return leaf.RawIssuer + } + } + return nil +} + +// ServerHandshake is not supported — MultiIdentityCredentials is client-side only. +func (m *MultiIdentityCredentials) ServerHandshake(net.Conn) (net.Conn, credentials.AuthInfo, error) { + return nil, nil, errors.New("MultiIdentityCredentials does not support server-side handshake") +} + +// Info returns protocol info. +func (m *MultiIdentityCredentials) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{ + SecurityProtocol: "tls", + SecurityVersion: "1.2", + ServerName: m.serverName, + } +} + +// Clone returns a deep copy of the credentials. +func (m *MultiIdentityCredentials) Clone() credentials.TransportCredentials { + m.mu.RLock() + defer m.mu.RUnlock() + identitiesCopy := make([]Identity, len(m.identities)) + copy(identitiesCopy, m.identities) + namesCopy := make([]string, len(m.loaderNames)) + copy(namesCopy, m.loaderNames) + return &MultiIdentityCredentials{ + identities: identitiesCopy, + loaderNames: namesCopy, + rootCAs: m.rootCAs, + serverName: m.serverName, + } +} + +// OverrideServerName sets the server name used for TLS verification. +func (m *MultiIdentityCredentials) OverrideServerName(name string) error { //nolint:staticcheck + m.mu.Lock() + defer m.mu.Unlock() + m.serverName = name + return nil +} diff --git a/cmd/proxy-server/server/server.go b/cmd/proxy-server/server/server.go index e9424ba8..a59099a4 100644 --- a/cmd/proxy-server/server/server.go +++ b/cmd/proxy-server/server/server.go @@ -56,6 +56,7 @@ type runState struct { policy rpcauth.AuthzPolicy clientPolicy rpcauth.AuthzPolicy credSource string + clientCredSources []string tlsConfig *tls.Config hostport string debugport string @@ -119,6 +120,8 @@ func WithTlsConfig(tlsConfig *tls.Config) Option { } // WithCredSource applies a registered credential source with the mtls package. +// This is used for the server-side (incoming) identity. For outbound client +// identities, use WithClientCredSource. func WithCredSource(credSource string) Option { return optionFunc(func(_ context.Context, r *runState) error { r.credSource = credSource @@ -126,6 +129,16 @@ func WithCredSource(credSource string) Option { }) } +// WithClientCredSource adds a credential source for outbound (proxy→target) +// connections. Can be called multiple times to register multiple identities. +// If no client cred sources are specified, the server's credSource is used. +func WithClientCredSource(credSource string) Option { + return optionFunc(func(_ context.Context, r *runState) error { + r.clientCredSources = append(r.clientCredSources, credSource) + return nil + }) +} + // WithHostport applies the host:port to run the server. func WithHostPort(hostport string) Option { return optionFunc(func(_ context.Context, r *runState) error { @@ -346,7 +359,7 @@ func Run(ctx context.Context, opts ...Option) { clientCreds, err := extractClientTransportCredentialsFromRunState(ctx, rs) if err != nil { - rs.logger.Error(err, "unable to extract transport credentials from runstate for the client", "credsource", rs.credSource) + rs.logger.Error(err, "unable to extract transport credentials from runstate for the client", "credsources", rs.clientCredSources) os.Exit(1) } @@ -447,7 +460,7 @@ func Run(ctx context.Context, opts ...Option) { g.GracefulStop() }() - rs.logger.Info("initialized proxy service", "credsource", rs.credSource) + rs.logger.Info("initialized proxy service", "credsource", rs.credSource, "clientCredsources", rs.clientCredSources) rs.logger.Info("serving..") if err := g.Serve(lis); err != nil { @@ -458,36 +471,34 @@ func Run(ctx context.Context, opts ...Option) { // extractClientTransportCredentialsFromRunState extracts transport credentials from runState. Will error if both credSource and tlsConfig are specified func extractClientTransportCredentialsFromRunState(ctx context.Context, rs *runState) (credentials.TransportCredentials, error) { - var creds credentials.TransportCredentials - var err error - if rs.credSource != "" && rs.tlsConfig != nil { - return nil, fmt.Errorf("both credSource and tlsConfig are defined for the client") + sources := rs.clientCredSources + if len(sources) == 0 && rs.credSource != "" { + sources = []string{rs.credSource} + } + if len(sources) > 0 && rs.tlsConfig != nil { + return nil, fmt.Errorf("both credSources and tlsConfig are defined for the client") + } + if len(sources) == 0 { + return credentials.NewTLS(rs.tlsConfig), nil } - if rs.credSource != "" { - creds, err = mtls.LoadClientCredentials(ctx, rs.credSource) + identities := make([]mtls.Identity, 0, len(sources)) + for _, src := range sources { + id, err := mtls.LoadClientIdentity(ctx, src) if err != nil { - return nil, err + return nil, fmt.Errorf("loading client identity %q: %w", src, err) } - } else { - creds = credentials.NewTLS(rs.tlsConfig) + identities = append(identities, id) } - return creds, nil + return mtls.NewMultiIdentityCredentials(identities) } // extractServerTransportCredentialsFromRunState extracts transport credentials from runState. Will error if both credSource and tlsConfig are specified func extractServerTransportCredentialsFromRunState(ctx context.Context, rs *runState) (credentials.TransportCredentials, error) { - var creds credentials.TransportCredentials - var err error if rs.credSource != "" && rs.tlsConfig != nil { return nil, fmt.Errorf("both credSource and tlsConfig are defined for the server") } - if rs.credSource != "" { - creds, err = mtls.LoadServerCredentials(ctx, rs.credSource) - if err != nil { - return nil, err - } - } else { - creds = credentials.NewTLS(rs.tlsConfig) + if rs.credSource == "" { + return credentials.NewTLS(rs.tlsConfig), nil } - return creds, nil + return mtls.LoadServerCredentials(ctx, rs.credSource) }