Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions auth/aws/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ var registryRegex = regexp.MustCompile(registryPattern)
// ParseArtifactRepository implements auth.Provider.
// ParseArtifactRepository returns the ECR region, unless the registry
// is public.ecr.aws, in which case it returns public.ecr.aws.
// When skip validation is enabled and the registry doesn't match,
// it returns the provider name.
func (Provider) ParseArtifactRepository(artifactRepository string) (string, error) {
registry, err := auth.GetRegistryFromArtifactRepository(artifactRepository)
if err != nil {
Expand All @@ -206,6 +208,10 @@ func (Provider) ParseArtifactRepository(artifactRepository string) (string, erro

parts := registryRegex.FindAllStringSubmatch(registry, -1)
if len(parts) < 1 || len(parts[0]) < 3 {
// Skip validation if configured (allows custom registry proxies)
if auth.GetOCISkipRegistryValidation() {
return ProviderName, nil
}
return "", fmt.Errorf("invalid AWS registry: '%s'. must match %s",
registry, registryPattern)
}
Expand All @@ -220,6 +226,14 @@ func getECRRegionFromRegistryInput(registryInput string) string {
// https://docs.aws.amazon.com/AmazonECR/latest/public/public-registry-auth.html#public-registry-auth-token
return "us-east-1"
}
if registryInput == ProviderName {
// When using a proxy with skip validation, fall back to AWS_REGION.
// If AWS_REGION is not set, use us-east-1 as a default.
if region := os.Getenv("AWS_REGION"); region != "" {
return region
}
return "us-east-1"
}
return registryInput
}

Expand Down
33 changes: 33 additions & 0 deletions auth/aws/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,39 @@ func TestProvider_ParseArtifactRepository(t *testing.T) {
}
}

func TestProvider_ParseArtifactRepository_SkipValidation(t *testing.T) {
g := NewWithT(t)

auth.SetOCISkipRegistryValidation(true)
t.Cleanup(func() { auth.SetOCISkipRegistryValidation(false) })

// Test that invalid registries are accepted when skip validation is enabled
for _, tt := range []struct {
name string
artifactRepository string
}{
{
name: "custom proxy",
artifactRepository: "oci-gateway.example.org/oci/charts/",
},
{
name: "non-AWS registry",
artifactRepository: "gcr.io/foo/bar:baz",
},
{
name: "private registry",
artifactRepository: "registry.internal.company.com/images",
},
} {
t.Run(tt.name, func(t *testing.T) {
region, err := aws.Provider{}.ParseArtifactRepository(tt.artifactRepository)

g.Expect(err).NotTo(HaveOccurred())
g.Expect(region).To(Equal("aws"))
})
}
}

func TestProvider_NewRESTConfig(t *testing.T) {
for _, tt := range []struct {
name string
Expand Down
9 changes: 9 additions & 0 deletions auth/azure/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,19 @@ func (Provider) ParseArtifactRepository(artifactRepository string) (string, erro
if strings.HasSuffix(registry, registrySuffix) {
return registry, nil
}
// Skip validation if configured (allows custom registry proxies)
if auth.GetOCISkipRegistryValidation() {
return registry, nil
}
return "", fmt.Errorf("invalid Azure registry: '%s'. must end with %s",
registry, registrySuffix)
}

// Skip validation if configured (allows custom registry proxies)
if auth.GetOCISkipRegistryValidation() {
return registry, nil
}

return "", fmt.Errorf("invalid Azure registry: '%s'. must match %s",
registry, registryPattern)
}
Expand Down
37 changes: 37 additions & 0 deletions auth/azure/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,43 @@ func TestProvider_ParseArtifactRegistry(t *testing.T) {
}
}

func TestProvider_ParseArtifactRegistry_SkipValidation(t *testing.T) {
g := NewWithT(t)

auth.SetOCISkipRegistryValidation(true)
t.Cleanup(func() { auth.SetOCISkipRegistryValidation(false) })

// Test that invalid registries are accepted when skip validation is enabled
for _, tt := range []struct {
name string
artifactRepository string
expectedRegistry string
}{
{
name: "custom proxy",
artifactRepository: "oci-gateway.example.org/oci/charts/",
expectedRegistry: "oci-gateway.example.org",
},
{
name: "non-Azure registry",
artifactRepository: "gcr.io/foo/bar:baz",
expectedRegistry: "gcr.io",
},
{
name: "private registry",
artifactRepository: "registry.internal.company.com/images",
expectedRegistry: "registry.internal.company.com",
},
} {
t.Run(tt.name, func(t *testing.T) {
registry, err := azure.Provider{}.ParseArtifactRepository(tt.artifactRepository)

g.Expect(err).NotTo(HaveOccurred())
g.Expect(registry).To(Equal(tt.expectedRegistry))
})
}
}

func TestProvider_GetAccessTokenOptionsForArtifactRepository(t *testing.T) {
for _, tt := range []struct {
name string
Expand Down
19 changes: 19 additions & 0 deletions auth/controller_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ const (
// service account name to be used when .spec.decryption.serviceAccountName is
// not specified in the object.
ControllerFlagDefaultDecryptionServiceAccount = "default-decryption-service-account"

// ControllerFlagOCISkipRegistryValidation defines the flag for skipping OCI registry
// domain validation for cloud provider authentication. This allows using custom
// registry proxies/gateways with workload identity authentication.
ControllerFlagOCISkipRegistryValidation = "oci-skip-registry-validation"
)

var (
Expand All @@ -48,6 +53,10 @@ var (
// defaultDecryptionServiceAccount stores the default decryption
// service account name.
defaultDecryptionServiceAccount string

// ociSkipRegistryValidation stores whether to skip OCI registry
// domain validation for cloud provider authentication.
ociSkipRegistryValidation bool
)

// ErrDefaultServiceAccountNotFound is returned when a default service account
Expand Down Expand Up @@ -84,6 +93,16 @@ func GetDefaultDecryptionServiceAccount() string {
return defaultDecryptionServiceAccount
}

// SetOCISkipRegistryValidation sets whether to skip OCI registry domain validation.
func SetOCISkipRegistryValidation(skip bool) {
ociSkipRegistryValidation = skip
}

// GetOCISkipRegistryValidation returns whether to skip OCI registry domain validation.
func GetOCISkipRegistryValidation() bool {
return ociSkipRegistryValidation
}

func getDefaultServiceAccount() string {
// Here we can detect a default service account by checking either the default
// service account or the default kubeconfig service account because these two
Expand Down
26 changes: 26 additions & 0 deletions auth/controller_flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,29 @@ func TestGetDefaultDecryptionServiceAccount(t *testing.T) {
g.Expect(auth.GetDefaultDecryptionServiceAccount()).To(Equal(""))
})
}

func TestSetOCISkipRegistryValidation(t *testing.T) {
g := NewWithT(t)

auth.SetOCISkipRegistryValidation(true)
t.Cleanup(func() { auth.SetOCISkipRegistryValidation(false) })

g.Expect(auth.GetOCISkipRegistryValidation()).To(BeTrue())
}

func TestGetOCISkipRegistryValidation(t *testing.T) {
t.Run("returns true when set", func(t *testing.T) {
g := NewWithT(t)

auth.SetOCISkipRegistryValidation(true)
t.Cleanup(func() { auth.SetOCISkipRegistryValidation(false) })

g.Expect(auth.GetOCISkipRegistryValidation()).To(BeTrue())
})

t.Run("returns false when not set", func(t *testing.T) {
g := NewWithT(t)

g.Expect(auth.GetOCISkipRegistryValidation()).To(BeFalse())
})
}
5 changes: 5 additions & 0 deletions auth/gcp/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ func (Provider) ParseArtifactRepository(artifactRepository string) (string, erro
return "", err
}

// Skip registry validation if configured (allows custom registry proxies)
if auth.GetOCISkipRegistryValidation() {
return ProviderName, nil
}

if !registryRegex.MatchString(registry) {
return "", fmt.Errorf("invalid GCP registry: '%s'. must match %s",
registry, registryPattern)
Expand Down
33 changes: 33 additions & 0 deletions auth/gcp/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,39 @@ func TestProvider_ParseArtifactRegistry(t *testing.T) {
}
}

func TestProvider_ParseArtifactRegistry_SkipValidation(t *testing.T) {
g := NewWithT(t)

auth.SetOCISkipRegistryValidation(true)
t.Cleanup(func() { auth.SetOCISkipRegistryValidation(false) })

// Test that invalid registries are accepted when skip validation is enabled
for _, tt := range []struct {
name string
artifactRepository string
}{
{
name: "custom proxy",
artifactRepository: "oci-gateway.example.org/oci/charts/",
},
{
name: "non-GCP registry",
artifactRepository: "012345678901.dkr.ecr.us-east-1.amazonaws.com",
},
{
name: "private registry",
artifactRepository: "registry.internal.company.com/images",
},
} {
t.Run(tt.name, func(t *testing.T) {
cacheKey, err := gcp.Provider{}.ParseArtifactRepository(tt.artifactRepository)

g.Expect(err).NotTo(HaveOccurred())
g.Expect(cacheKey).To(Equal("gcp"))
})
}
}

func TestProvider_NewRESTConfig(t *testing.T) {
for _, tt := range []struct {
name string
Expand Down