From 00280ac5a0a307d8d7f5416553f6c4a4ab227b56 Mon Sep 17 00:00:00 2001 From: Marcin Walas Date: Tue, 1 Apr 2025 19:22:42 +0200 Subject: [PATCH] MPA method-wide mode for sansshell --- auth/opa/rpcauth/input.go | 3 + cmd/sanssh/client/client.go | 13 ++-- cmd/sanssh/main.go | 24 ++++---- services/mpa/mpa.pb.go | 49 +++++++++++---- services/mpa/mpa.proto | 6 ++ services/mpa/mpa_grpc.pb.go | 2 +- services/mpa/mpahooks/mpahooks.go | 82 +++++++++++++++----------- services/mpa/mpahooks/mpahooks_test.go | 12 ++-- 8 files changed, 123 insertions(+), 68 deletions(-) diff --git a/auth/opa/rpcauth/input.go b/auth/opa/rpcauth/input.go index bfff6908..58a8bbeb 100644 --- a/auth/opa/rpcauth/input.go +++ b/auth/opa/rpcauth/input.go @@ -56,6 +56,9 @@ type RPCAuthInput struct { // Information about approvers when using multi-party authentication. Approvers []*PrincipalAuthInput `json:"approvers"` + // TODO: commentary. + ApprovedMethodWideMpa bool `json:"approved-method-wide-mpa"` + // Information about the environment in which the policy evaluation is // happening. Environment *EnvironmentInput `json:"environment"` diff --git a/cmd/sanssh/client/client.go b/cmd/sanssh/client/client.go index 8f17e9ed..7cef0b8e 100644 --- a/cmd/sanssh/client/client.go +++ b/cmd/sanssh/client/client.go @@ -23,12 +23,13 @@ import ( "context" "flag" "fmt" - "google.golang.org/grpc/credentials" "io" "os" "path/filepath" "time" + "google.golang.org/grpc/credentials" + "github.com/google/subcommands" "google.golang.org/grpc" @@ -78,6 +79,8 @@ type RunState struct { BatchSize int // If true, add an interceptor that performs the multi-party auth flow EnableMPA bool + // If true, configure MPA interceptor to request approval method-wide. + EnableMethodWideMPA bool // If true, the command is authz dry run and real action should not be executed AuthzDryRun bool @@ -287,8 +290,8 @@ func Run(ctx context.Context, rs RunState) { unaryInterceptors = append(unaryInterceptors, clientAuthz.AuthorizeClient) } if rs.EnableMPA { - unaryInterceptors = append(unaryInterceptors, mpahooks.UnaryClientIntercepter()) - streamInterceptors = append(streamInterceptors, mpahooks.StreamClientIntercepter()) + unaryInterceptors = append(unaryInterceptors, mpahooks.UnaryClientIntercepter(rs.EnableMethodWideMPA)) + streamInterceptors = append(streamInterceptors, mpahooks.StreamClientIntercepter(rs.EnableMethodWideMPA)) } // timeout interceptor should be the last item in ops so that it's executed first. streamInterceptors = append(streamInterceptors, StreamClientTimeoutInterceptor(rs.IdleTimeout)) @@ -369,8 +372,8 @@ func Run(ctx context.Context, rs RunState) { conn.AuthzDryRun = rs.AuthzDryRun if rs.EnableMPA { - conn.UnaryInterceptors = []proxy.UnaryInterceptor{mpahooks.ProxyClientUnaryInterceptor(state)} - conn.StreamInterceptors = []proxy.StreamInterceptor{mpahooks.ProxyClientStreamInterceptor(state)} + conn.UnaryInterceptors = []proxy.UnaryInterceptor{mpahooks.ProxyClientUnaryInterceptor(state, rs.EnableMethodWideMPA)} + conn.StreamInterceptors = []proxy.StreamInterceptor{mpahooks.ProxyClientStreamInterceptor(state, rs.EnableMethodWideMPA)} } state.Conn = conn state.Out = output[start:end] diff --git a/cmd/sanssh/main.go b/cmd/sanssh/main.go index e29be8ee..416fd7b4 100644 --- a/cmd/sanssh/main.go +++ b/cmd/sanssh/main.go @@ -88,6 +88,7 @@ If port is blank the default of %d will be used`, proxyEnv, defaultProxyPort)) prefixHeader = flag.Bool("h", false, "If true prefix each line of output with '-: '") batchSize = flag.Int("batch-size", 0, "If non-zero will perform the proxy->target work in batches of this size (with any remainder done at the end).") mpa = flag.Bool("mpa", false, "Request multi-party approval for commands. This will create an MPA request, wait for approval, and then execute the command.") + methodWideMpa = flag.Bool("method-wide-mpa", false, "Request multi-party approval for entire method, regardless of the payload.") authzDryRun = flag.Bool("authz-dry-run", false, "If true, the client will send a request to the server to check if the user has the permission to run the command. The server will respond with a success or failure message.") // targets will be bound to --targets for sending a single request to N nodes. @@ -189,17 +190,18 @@ func main() { stdr.SetVerbosity(*verbosity) rs := client.RunState{ - Proxy: *proxyAddr, - Targets: *targetsFlag.Target, - Outputs: *outputsFlag.Target, - AuthzDryRun: *authzDryRun, - OutputsDir: *outputsDir, - CredSource: *credSource, - IdleTimeout: *idleTimeout, - ClientPolicy: clientPolicy, - PrefixOutput: *prefixHeader, - BatchSize: *batchSize, - EnableMPA: *mpa, + Proxy: *proxyAddr, + Targets: *targetsFlag.Target, + Outputs: *outputsFlag.Target, + AuthzDryRun: *authzDryRun, + OutputsDir: *outputsDir, + CredSource: *credSource, + IdleTimeout: *idleTimeout, + ClientPolicy: clientPolicy, + PrefixOutput: *prefixHeader, + BatchSize: *batchSize, + EnableMPA: *mpa, + EnableMethodWideMPA: *methodWideMpa, } ctx := logr.NewContext(context.Background(), logger) diff --git a/services/mpa/mpa.pb.go b/services/mpa/mpa.pb.go index 377e7e96..b04fe5ed 100644 --- a/services/mpa/mpa.pb.go +++ b/services/mpa/mpa.pb.go @@ -16,7 +16,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.34.2 -// protoc v5.29.3 +// protoc v5.28.3 // source: mpa.proto package mpa @@ -49,6 +49,9 @@ type Action struct { Method string `protobuf:"bytes,3,opt,name=method,proto3" json:"method,omitempty"` // The request protocol buffer. Message *anypb.Any `protobuf:"bytes,4,opt,name=message,proto3" json:"message,omitempty"` + // Method-wide MPA requested, will ignore `message` and only match request on + // `method`. + MethodWideMpa bool `protobuf:"varint,5,opt,name=method_wide_mpa,json=methodWideMpa,proto3" json:"method_wide_mpa,omitempty"` } func (x *Action) Reset() { @@ -111,6 +114,13 @@ func (x *Action) GetMessage() *anypb.Any { return nil } +func (x *Action) GetMethodWideMpa() bool { + if x != nil { + return x.MethodWideMpa + } + return false +} + type Principal struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -177,6 +187,9 @@ type StoreRequest struct { Method string `protobuf:"bytes,1,opt,name=method,proto3" json:"method,omitempty"` // The request protocol buffer. Message *anypb.Any `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + // Method-wide MPA requested, will ignore `message` and only match request on + // `method`. + MethodWideMpa bool `protobuf:"varint,5,opt,name=method_wide_mpa,json=methodWideMpa,proto3" json:"method_wide_mpa,omitempty"` } func (x *StoreRequest) Reset() { @@ -225,6 +238,13 @@ func (x *StoreRequest) GetMessage() *anypb.Any { return nil } +func (x *StoreRequest) GetMethodWideMpa() bool { + if x != nil { + return x.MethodWideMpa + } + return false +} + type StoreResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -804,7 +824,7 @@ var File_mpa_proto protoreflect.FileDescriptor var file_mpa_proto_rawDesc = []byte{ 0x0a, 0x09, 0x6d, 0x70, 0x61, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x03, 0x4d, 0x70, 0x61, 0x1a, 0x19, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, - 0x66, 0x2f, 0x61, 0x6e, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x8a, 0x01, 0x0a, 0x06, + 0x66, 0x2f, 0x61, 0x6e, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xb2, 0x01, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x24, 0x0a, 0x0d, 0x6a, 0x75, 0x73, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, @@ -813,16 +833,21 @@ var file_mpa_proto_rawDesc = []byte{ 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x2e, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, - 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x33, 0x0a, 0x09, 0x50, 0x72, 0x69, 0x6e, - 0x63, 0x69, 0x70, 0x61, 0x6c, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x22, 0x56, 0x0a, - 0x0c, 0x53, 0x74, 0x6f, 0x72, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, - 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, - 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x2e, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x07, 0x6d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x70, 0x0a, 0x0d, 0x53, 0x74, 0x6f, 0x72, 0x65, 0x52, 0x65, + 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x26, 0x0a, 0x0f, 0x6d, 0x65, 0x74, 0x68, + 0x6f, 0x64, 0x5f, 0x77, 0x69, 0x64, 0x65, 0x5f, 0x6d, 0x70, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x0d, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x57, 0x69, 0x64, 0x65, 0x4d, 0x70, 0x61, + 0x22, 0x33, 0x0a, 0x09, 0x50, 0x72, 0x69, 0x6e, 0x63, 0x69, 0x70, 0x61, 0x6c, 0x12, 0x0e, 0x0a, + 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x16, 0x0a, + 0x06, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x67, + 0x72, 0x6f, 0x75, 0x70, 0x73, 0x22, 0x7e, 0x0a, 0x0c, 0x53, 0x74, 0x6f, 0x72, 0x65, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x2e, 0x0a, + 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x26, 0x0a, + 0x0f, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x5f, 0x77, 0x69, 0x64, 0x65, 0x5f, 0x6d, 0x70, 0x61, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x57, 0x69, + 0x64, 0x65, 0x4d, 0x70, 0x61, 0x22, 0x70, 0x0a, 0x0d, 0x53, 0x74, 0x6f, 0x72, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x23, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x4d, 0x70, 0x61, 0x2e, 0x41, 0x63, 0x74, diff --git a/services/mpa/mpa.proto b/services/mpa/mpa.proto index 2e140be9..e248235a 100644 --- a/services/mpa/mpa.proto +++ b/services/mpa/mpa.proto @@ -61,6 +61,9 @@ message Action { string method = 3; // The request protocol buffer. google.protobuf.Any message = 4; + // Method-wide MPA requested, will ignore `message` and only match request on + // `method`. + bool method_wide_mpa = 5; } message Principal { @@ -76,6 +79,9 @@ message StoreRequest { string method = 1; // The request protocol buffer. google.protobuf.Any message = 2; + // Method-wide MPA requested, will ignore `message` and only match request on + // `method`. + bool method_wide_mpa = 5; } message StoreResponse { diff --git a/services/mpa/mpa_grpc.pb.go b/services/mpa/mpa_grpc.pb.go index 8335bdec..f33c8300 100644 --- a/services/mpa/mpa_grpc.pb.go +++ b/services/mpa/mpa_grpc.pb.go @@ -16,7 +16,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.5.1 -// - protoc v5.29.3 +// - protoc v5.28.3 // source: mpa.proto package mpa diff --git a/services/mpa/mpahooks/mpahooks.go b/services/mpa/mpahooks/mpahooks.go index bc9bf15a..50eecc41 100644 --- a/services/mpa/mpahooks/mpahooks.go +++ b/services/mpa/mpahooks/mpahooks.go @@ -72,20 +72,6 @@ func ActionMatchesInput(ctx context.Context, action *mpa.Action, input *rpcauth. justification = j[0] } - // Transform the rpcauth input into the original proto - mt, err := protoregistry.GlobalTypes.FindMessageByURL(input.MessageType) - if err != nil { - return fmt.Errorf("unable to find proto type: %v", err) - } - m2 := mt.New().Interface() - if err := protojson.Unmarshal([]byte(input.Message), m2); err != nil { - return fmt.Errorf("could not marshal input into %v: %v", input.Message, err) - } - var msg anypb.Any - if err := msg.MarshalFrom(m2); err != nil { - return fmt.Errorf("unable to marshal into anyproto: %v", err) - } - // Prefer using a proxied identity if provided var user string if p := proxiedidentity.FromContext(ctx); p != nil { @@ -100,7 +86,24 @@ func ActionMatchesInput(ctx context.Context, action *mpa.Action, input *rpcauth. User: user, Method: input.Method, Justification: justification, - Message: &msg, + } + + if !action.MethodWideMpa { + // Transform the rpcauth input into the original proto + mt, err := protoregistry.GlobalTypes.FindMessageByURL(input.MessageType) + if err != nil { + return fmt.Errorf("unable to find proto type: %v", err) + } + m2 := mt.New().Interface() + if err := protojson.Unmarshal([]byte(input.Message), m2); err != nil { + return fmt.Errorf("could not marshal input into %v: %v", input.Message, err) + } + var msg anypb.Any + if err := msg.MarshalFrom(m2); err != nil { + return fmt.Errorf("unable to marshal into anyproto: %v", err) + } + + sentAct.Message = &msg } // Make sure to use an any-proto-aware comparison if !cmp.Equal(action, sentAct, protocmp.Transform()) { @@ -109,7 +112,7 @@ func ActionMatchesInput(ctx context.Context, action *mpa.Action, input *rpcauth. return nil } -func createAndBlockOnSingleTargetMPA(ctx context.Context, method string, req any, cc *grpc.ClientConn) (mpaID string, err error) { +func createAndBlockOnSingleTargetMPA(ctx context.Context, method string, req any, cc *grpc.ClientConn, methodWideMpa bool) (mpaID string, err error) { p, ok := req.(proto.Message) if !ok { return "", fmt.Errorf("unable to cast req to proto: %v", req) @@ -121,10 +124,15 @@ func createAndBlockOnSingleTargetMPA(ctx context.Context, method string, req any } mpaClient := mpa.NewMpaClient(cc) - result, err := mpaClient.Store(ctx, &mpa.StoreRequest{ - Method: method, - Message: &msg, - }) + storeReq := &mpa.StoreRequest{ + Method: method, + MethodWideMpa: methodWideMpa, + } + if !methodWideMpa { + storeReq.Message = &msg + } + result, err := mpaClient.Store(ctx, storeReq) + if err != nil { return "", err } @@ -140,7 +148,7 @@ func createAndBlockOnSingleTargetMPA(ctx context.Context, method string, req any } // UnaryClientIntercepter is a grpc.UnaryClientIntercepter that will perform the MPA flow. -func UnaryClientIntercepter() grpc.UnaryClientInterceptor { +func UnaryClientIntercepter(methodWideMpa bool) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { // Our interceptor will run for all gRPC calls, including ones used inside the interceptor. // We need to bail early on MPA-related ones to prevent infinite recursion. @@ -148,7 +156,7 @@ func UnaryClientIntercepter() grpc.UnaryClientInterceptor { return invoker(ctx, method, req, reply, cc, opts...) } - mpaID, err := createAndBlockOnSingleTargetMPA(ctx, method, req, cc) + mpaID, err := createAndBlockOnSingleTargetMPA(ctx, method, req, cc, methodWideMpa) if err != nil { return err } @@ -211,7 +219,7 @@ func (w *delayedStartStream) RecvMsg(m any) error { // StreamClientIntercepter is a grpc.StreamClientInterceptor that will perform // the MPA flow. -func StreamClientIntercepter() grpc.StreamClientInterceptor { +func StreamClientIntercepter(methodWideMpa bool) grpc.StreamClientInterceptor { return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { if method == "/Proxy.Proxy/Proxy" { // No need to intercept proxying, that's handled specially. @@ -220,7 +228,7 @@ func StreamClientIntercepter() grpc.StreamClientInterceptor { return newStreamAfterFirstSend(func(m any) (grpc.ClientStream, error) { // Figure out the MPA request - mpaID, err := createAndBlockOnSingleTargetMPA(ctx, method, m, cc) + mpaID, err := createAndBlockOnSingleTargetMPA(ctx, method, m, cc, methodWideMpa) if err != nil { return nil, err } @@ -233,7 +241,7 @@ func StreamClientIntercepter() grpc.StreamClientInterceptor { } } -func createAndBlockOnProxiedMPA(ctx context.Context, method string, args any, conn *proxy.Conn, state *util.ExecuteState) (mpaID string, err error) { +func createAndBlockOnProxiedMPA(ctx context.Context, method string, args any, conn *proxy.Conn, state *util.ExecuteState, methodWideMpa bool) (mpaID string, err error) { p, ok := args.(proto.Message) if !ok { return "", fmt.Errorf("unable to cast args to proto: %v", args) @@ -243,10 +251,13 @@ func createAndBlockOnProxiedMPA(ctx context.Context, method string, args any, co return "", fmt.Errorf("unable to marshal into anyproto: %v", err) } mpaClient := mpa.NewMpaClientProxy(conn) - ch, err := mpaClient.StoreOneMany(ctx, &mpa.StoreRequest{ - Method: method, - Message: &msg, - }) + storeReq := &mpa.StoreRequest{ + Method: method, + } + if !methodWideMpa { + storeReq.Message = &msg + } + ch, err := mpaClient.StoreOneMany(ctx, storeReq) if err != nil { return "", err } @@ -296,7 +307,7 @@ func createAndBlockOnProxiedMPA(ctx context.Context, method string, args any, co // ProxyClientUnaryInterceptor will perform the MPA flow prior to making the desired RPC // calls through the proxy. -func ProxyClientUnaryInterceptor(state *util.ExecuteState) proxy.UnaryInterceptor { +func ProxyClientUnaryInterceptor(state *util.ExecuteState, methodWideMpa bool) proxy.UnaryInterceptor { return func(ctx context.Context, conn *proxy.Conn, method string, args any, invoker proxy.UnaryInvoker, opts ...grpc.CallOption) (<-chan *proxy.Ret, error) { // Our hook will run for all gRPC calls, including ones used inside the interceptor. // We need to bail early on MPA-related ones to prevent infinite recursion. @@ -304,7 +315,7 @@ func ProxyClientUnaryInterceptor(state *util.ExecuteState) proxy.UnaryIntercepto return invoker(ctx, method, args, opts...) } - mpaID, err := createAndBlockOnProxiedMPA(ctx, method, args, conn, state) + mpaID, err := createAndBlockOnProxiedMPA(ctx, method, args, conn, state, methodWideMpa) if err != nil { return nil, err } @@ -317,11 +328,11 @@ func ProxyClientUnaryInterceptor(state *util.ExecuteState) proxy.UnaryIntercepto // ProxyClientStreamInterceptor will perform the MPA flow prior to making the desired streaming // RPC calls through the proxy. -func ProxyClientStreamInterceptor(state *util.ExecuteState) proxy.StreamInterceptor { +func ProxyClientStreamInterceptor(state *util.ExecuteState, methodWideMpa bool) proxy.StreamInterceptor { return func(ctx context.Context, desc *grpc.StreamDesc, cc *proxy.Conn, method string, streamer proxy.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { return newStreamAfterFirstSend(func(args any) (grpc.ClientStream, error) { // Figure out the MPA request - mpaID, err := createAndBlockOnProxiedMPA(ctx, method, args, cc, state) + mpaID, err := createAndBlockOnProxiedMPA(ctx, method, args, cc, state, methodWideMpa) if err != nil { return nil, err } @@ -363,6 +374,11 @@ func ProxyMPAAuthzHook() rpcauth.RPCAuthzHook { if err := ActionMatchesInput(ctx, resp.Action, input); err != nil { return err } + + if resp.Action.MethodWideMpa { + input.ApprovedMethodWideMpa = true + } + for _, a := range resp.Approver { input.Approvers = append(input.Approvers, &rpcauth.PrincipalAuthInput{ ID: a.Id, diff --git a/services/mpa/mpahooks/mpahooks_test.go b/services/mpa/mpahooks/mpahooks_test.go index f710a5db..7bd4f557 100644 --- a/services/mpa/mpahooks/mpahooks_test.go +++ b/services/mpa/mpahooks/mpahooks_test.go @@ -316,8 +316,8 @@ func TestClientInterceptors(t *testing.T) { // Make our calls conn, err := grpc.DialContext(ctx, srvAddr, grpc.WithTransportCredentials(clientCreds), - grpc.WithChainStreamInterceptor(mpahooks.StreamClientIntercepter()), - grpc.WithChainUnaryInterceptor(mpahooks.UnaryClientIntercepter()), + grpc.WithChainStreamInterceptor(mpahooks.StreamClientIntercepter(false)), + grpc.WithChainUnaryInterceptor(mpahooks.UnaryClientIntercepter(false)), ) if err != nil { t.Error(err) @@ -513,8 +513,8 @@ func TestProxiedClientInterceptors(t *testing.T) { // Make our calls conn, err := proxy.DialContext(ctx, proxyAddr, []string{srvAddr}, grpc.WithTransportCredentials(clientCreds), - grpc.WithChainStreamInterceptor(mpahooks.StreamClientIntercepter()), - grpc.WithChainUnaryInterceptor(mpahooks.UnaryClientIntercepter()), + grpc.WithChainStreamInterceptor(mpahooks.StreamClientIntercepter(false)), + grpc.WithChainUnaryInterceptor(mpahooks.UnaryClientIntercepter(false)), ) if err != nil { t.Error(err) @@ -532,10 +532,10 @@ func TestProxiedClientInterceptors(t *testing.T) { Err: []io.Writer{os.Stderr}, } conn.StreamInterceptors = []proxy.StreamInterceptor{ - mpahooks.ProxyClientStreamInterceptor(state), + mpahooks.ProxyClientStreamInterceptor(state, false), } conn.UnaryInterceptors = []proxy.UnaryInterceptor{ - mpahooks.ProxyClientUnaryInterceptor(state), + mpahooks.ProxyClientUnaryInterceptor(state, false), } if _, err := hc.Ok(ctx, &emptypb.Empty{}); err != nil { t.Error(err)