Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package azureoidc
package azuremanagedidentity

import (
"context"
Expand All @@ -14,17 +14,17 @@ import (
authplugins "github.com/winhowes/AuthTranslator/app/auth"
)

// azureOIDCParams configures the Azure OIDC plugin.
type azureOIDCParams struct {
// azureManagedIdentityParams configures the Azure Managed Identity plugin.
type azureManagedIdentityParams struct {
Resource string `json:"resource"`
ClientID string `json:"client_id"`
Header string `json:"header"`
Prefix string `json:"prefix"`
}

// AzureOIDC obtains an access token from the Azure Instance Metadata Service and
// attaches it to outgoing requests.
type AzureOIDC struct{}
// AzureManagedIdentity obtains an access token from the Azure Instance Metadata
// Service and attaches it to outgoing requests.
type AzureManagedIdentity struct{}

// MetadataHost is the base URL for the Azure metadata service. It can be
// overridden in tests.
Expand All @@ -43,14 +43,16 @@ type cachedToken struct {
exp time.Time
}

func (a *AzureOIDC) Name() string { return "azure_oidc" }
func (a *AzureManagedIdentity) Name() string { return "azure_managed_identity" }

func (a *AzureOIDC) RequiredParams() []string { return []string{"resource"} }
func (a *AzureManagedIdentity) RequiredParams() []string { return []string{"resource"} }

func (a *AzureOIDC) OptionalParams() []string { return []string{"client_id", "header", "prefix"} }
func (a *AzureManagedIdentity) OptionalParams() []string {
return []string{"client_id", "header", "prefix"}
}

func (a *AzureOIDC) ParseParams(m map[string]interface{}) (interface{}, error) {
p, err := authplugins.ParseParams[azureOIDCParams](m)
func (a *AzureManagedIdentity) ParseParams(m map[string]interface{}) (interface{}, error) {
p, err := authplugins.ParseParams[azureManagedIdentityParams](m)
if err != nil {
return nil, err
}
Expand All @@ -66,8 +68,8 @@ func (a *AzureOIDC) ParseParams(m map[string]interface{}) (interface{}, error) {
return p, nil
}

func (a *AzureOIDC) AddAuth(ctx context.Context, r *http.Request, params interface{}) error {
cfg, ok := params.(*azureOIDCParams)
func (a *AzureManagedIdentity) AddAuth(ctx context.Context, r *http.Request, params interface{}) error {
cfg, ok := params.(*azureManagedIdentityParams)
if !ok {
return fmt.Errorf("invalid config")
}
Expand Down Expand Up @@ -155,4 +157,4 @@ func setCachedToken(key, tok string, exp time.Time) {
tokenCache.Unlock()
}

func init() { authplugins.RegisterOutgoing(&AzureOIDC{}) }
func init() { authplugins.RegisterOutgoing(&AzureManagedIdentity{}) }
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package azureoidc
package azuremanagedidentity

import (
"context"
Expand All @@ -16,7 +16,7 @@ func resetCache() {
tokenCache.Unlock()
}

func TestAzureOIDCAddAuth(t *testing.T) {
func TestAzureManagedIdentityAddAuth(t *testing.T) {
resetCache()

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -38,7 +38,7 @@ func TestAzureOIDCAddAuth(t *testing.T) {
HTTPClient = ts.Client()
defer func() { HTTPClient = oldClient }()

p := AzureOIDC{}
p := AzureManagedIdentity{}
cfg, err := p.ParseParams(map[string]interface{}{"resource": "api://res"})
if err != nil {
t.Fatal(err)
Expand All @@ -53,7 +53,7 @@ func TestAzureOIDCAddAuth(t *testing.T) {
}
}

func TestAzureOIDCCustomHeaderAndPrefix(t *testing.T) {
func TestAzureManagedIdentityCustomHeaderAndPrefix(t *testing.T) {
resetCache()

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -69,7 +69,7 @@ func TestAzureOIDCCustomHeaderAndPrefix(t *testing.T) {
HTTPClient = ts.Client()
defer func() { HTTPClient = oldClient }()

p := AzureOIDC{}
p := AzureManagedIdentity{}
cfg, err := p.ParseParams(map[string]interface{}{
"resource": "api://res",
"header": "X-Auth",
Expand All @@ -89,7 +89,7 @@ func TestAzureOIDCCustomHeaderAndPrefix(t *testing.T) {
}
}

func TestAzureOIDCAddAuthFailure(t *testing.T) {
func TestAzureManagedIdentityAddAuthFailure(t *testing.T) {
resetCache()

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -106,7 +106,7 @@ func TestAzureOIDCAddAuthFailure(t *testing.T) {
HTTPClient = ts.Client()
defer func() { HTTPClient = oldClient }()

p := AzureOIDC{}
p := AzureManagedIdentity{}
cfg, err := p.ParseParams(map[string]interface{}{"resource": "api://fail"})
if err != nil {
t.Fatal(err)
Expand All @@ -121,7 +121,7 @@ func TestAzureOIDCAddAuthFailure(t *testing.T) {
}
}

func TestAzureOIDCCache(t *testing.T) {
func TestAzureManagedIdentityCache(t *testing.T) {
resetCache()

var hits int32
Expand All @@ -139,7 +139,7 @@ func TestAzureOIDCCache(t *testing.T) {
HTTPClient = ts.Client()
defer func() { HTTPClient = oldClient }()

p := AzureOIDC{}
p := AzureManagedIdentity{}
cfg, err := p.ParseParams(map[string]interface{}{"resource": "api://res"})
if err != nil {
t.Fatal(err)
Expand All @@ -159,26 +159,26 @@ func TestAzureOIDCCache(t *testing.T) {
}
}

func TestAzureOIDCParseParamsMissingResource(t *testing.T) {
func TestAzureManagedIdentityParseParamsMissingResource(t *testing.T) {
resetCache()

p := AzureOIDC{}
p := AzureManagedIdentity{}
if _, err := p.ParseParams(map[string]interface{}{}); err == nil {
t.Fatal("expected error")
}
}

func TestAzureOIDCAddAuthWrongParams(t *testing.T) {
func TestAzureManagedIdentityAddAuthWrongParams(t *testing.T) {
resetCache()

p := AzureOIDC{}
p := AzureManagedIdentity{}
r := &http.Request{Header: http.Header{}}
if err := p.AddAuth(context.Background(), r, 5); err == nil {
t.Fatal("expected error")
}
}

func TestAzureOIDCUsesExpiresOn(t *testing.T) {
func TestAzureManagedIdentityUsesExpiresOn(t *testing.T) {
resetCache()

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -195,7 +195,7 @@ func TestAzureOIDCUsesExpiresOn(t *testing.T) {
HTTPClient = ts.Client()
defer func() { HTTPClient = oldClient }()

p := AzureOIDC{}
p := AzureManagedIdentity{}
cfg, err := p.ParseParams(map[string]interface{}{"resource": "api://res"})
if err != nil {
t.Fatal(err)
Expand All @@ -216,8 +216,8 @@ func TestAzureOIDCUsesExpiresOn(t *testing.T) {
}
}

func TestAzureOIDCParamLists(t *testing.T) {
p := AzureOIDC{}
func TestAzureManagedIdentityParamLists(t *testing.T) {
p := AzureManagedIdentity{}
if got := p.RequiredParams(); len(got) != 1 || got[0] != "resource" {
t.Fatalf("unexpected required params: %v", got)
}
Expand All @@ -233,8 +233,8 @@ func TestAzureOIDCParamLists(t *testing.T) {
}
}

func TestAzureOIDCParseParamsInvalidType(t *testing.T) {
p := AzureOIDC{}
func TestAzureManagedIdentityParseParamsInvalidType(t *testing.T) {
p := AzureManagedIdentity{}
if _, err := p.ParseParams(map[string]interface{}{"resource": 5}); err == nil {
t.Fatal("expected parse error for invalid type")
}
Expand Down
2 changes: 1 addition & 1 deletion app/auth/plugins/plugins.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package plugins

import (
_ "github.com/winhowes/AuthTranslator/app/auth/plugins/azure_oidc"
_ "github.com/winhowes/AuthTranslator/app/auth/plugins/azure_managed_identity"
_ "github.com/winhowes/AuthTranslator/app/auth/plugins/basic"
_ "github.com/winhowes/AuthTranslator/app/auth/plugins/findreplace"
_ "github.com/winhowes/AuthTranslator/app/auth/plugins/gcp_token"
Expand Down
6 changes: 3 additions & 3 deletions docs/auth-plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ AuthTranslator’s behaviour is extended by **plugins** – small Go packages th
| Outbound | `basic` | Adds HTTP Basic credentials to the upstream request. |
| Outbound | `google_oidc` | Attaches a Google identity token from the metadata service. |
| Outbound | `gcp_token` | Uses a metadata service access token. |
| Outbound | `azure_oidc` | Retrieves an Azure access token from the Instance Metadata Service. |
| Outbound | `azure_managed_identity` | Retrieves an Azure access token from the Instance Metadata Service. |
| Outbound | `hmac_signature` | Computes an HMAC for the request. |
| Outbound | `jwt` | Adds a signed JWT to the request. |
| Outbound | `mtls` | Sends a client certificate and exposes the CN via header. |
Expand Down Expand Up @@ -84,11 +84,11 @@ outgoing_auth:
Replaces every occurrence of the secret referenced by `find_secret` with
the value from `replace_secret` across the URL, headers and body.

### Outbound `azure_oidc`
### Outbound `azure_managed_identity`

```yaml
outgoing_auth:
- type: azure_oidc
- type: azure_managed_identity
params:
resource: api://my-api-app-id
client_id: 00000000-0000-0000-0000-000000000000 # optional
Expand Down
Loading