From 91fc4a88316142a40b609de0092cb45376a5925e Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Mon, 3 Nov 2025 14:04:28 +0100 Subject: [PATCH 01/18] feat: Remove refresh_token grant type Signed-off-by: Jorge Turrado --- core/clients/key_flow.go | 74 ++++--------------- core/clients/key_flow_continuous_refresh.go | 2 +- .../key_flow_continuous_refresh_test.go | 41 +++------- core/clients/key_flow_test.go | 66 +---------------- 4 files changed, 27 insertions(+), 156 deletions(-) diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index 589774314..9a1b5d1e8 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -68,11 +68,10 @@ type KeyFlowConfig struct { // TokenResponseBody is the API response // when requesting a new token type TokenResponseBody struct { - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` - RefreshToken string `json:"refresh_token"` - Scope string `json:"scope"` - TokenType string `json:"token_type"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` } // ServiceAccountKeyResponse is the API response @@ -158,9 +157,9 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { return nil } -// SetToken can be used to set an access and refresh token manually in the client. +// SetToken can be used to set an access token manually in the client. // The other fields in the token field are determined by inspecting the token or setting default values. -func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { +func (c *KeyFlow) SetToken(accessToken string) error { // We can safely use ParseUnverified because we are not authenticating the user, // We are parsing the token just to get the expiration time claim parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) @@ -174,11 +173,10 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { c.tokenMutex.Lock() c.token = &TokenResponseBody{ - AccessToken: accessToken, - ExpiresIn: int(exp.Time.Unix()), - Scope: defaultScope, - RefreshToken: refreshToken, - TokenType: defaultTokenType, + AccessToken: accessToken, + ExpiresIn: int(exp.Time.Unix()), + Scope: defaultScope, + TokenType: defaultTokenType, } c.tokenMutex.Unlock() return nil @@ -198,7 +196,7 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) { return c.rt.RoundTrip(req) } -// GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field +// GetAccessToken returns a short-lived access token and saves the access token in the token field func (c *KeyFlow) GetAccessToken() (string, error) { if c.rt == nil { return "", fmt.Errorf("nil http round tripper, please run Init()") @@ -219,7 +217,7 @@ func (c *KeyFlow) GetAccessToken() (string, error) { if !accessTokenExpired { return accessToken, nil } - if err = c.recreateAccessToken(); err != nil { + if err = c.createAccessToken(); err != nil { var oapiErr *oapierror.GenericOpenAPIError if ok := errors.As(err, &oapiErr); ok { reg := regexp.MustCompile("Key with kid .*? was not found") @@ -269,27 +267,6 @@ func (c *KeyFlow) validate() error { // Flow auth functions -// recreateAccessToken is used to create a new access token -// when the existing one isn't valid anymore -func (c *KeyFlow) recreateAccessToken() error { - var refreshToken string - - c.tokenMutex.RLock() - if c.token != nil { - refreshToken = c.token.RefreshToken - } - c.tokenMutex.RUnlock() - - refreshTokenExpired, err := tokenExpired(refreshToken, c.tokenExpirationLeeway) - if err != nil { - return err - } - if !refreshTokenExpired { - return c.createAccessTokenWithRefreshToken() - } - return c.createAccessToken() -} - // createAccessToken creates an access token using self signed JWT func (c *KeyFlow) createAccessToken() (err error) { grant := "urn:ietf:params:oauth:grant-type:jwt-bearer" @@ -310,26 +287,6 @@ func (c *KeyFlow) createAccessToken() (err error) { return c.parseTokenResponse(res) } -// createAccessTokenWithRefreshToken creates an access token using -// an existing pre-validated refresh token -func (c *KeyFlow) createAccessTokenWithRefreshToken() (err error) { - c.tokenMutex.RLock() - refreshToken := c.token.RefreshToken - c.tokenMutex.RUnlock() - - res, err := c.requestToken("refresh_token", refreshToken) - if err != nil { - return err - } - defer func() { - tempErr := res.Body.Close() - if tempErr != nil && err == nil { - err = fmt.Errorf("close request access token with refresh token response: %w", tempErr) - } - }() - return c.parseTokenResponse(res) -} - // generateSelfSignedJWT generates JWT token func (c *KeyFlow) generateSelfSignedJWT() (string, error) { claims := jwt.MapClaims{ @@ -353,11 +310,8 @@ func (c *KeyFlow) generateSelfSignedJWT() (string, error) { func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error) { body := url.Values{} body.Set("grant_type", grant) - if grant == "refresh_token" { - body.Set("refresh_token", assertion) - } else { - body.Set("assertion", assertion) - } + body.Set("assertion", assertion) + payload := strings.NewReader(body.Encode()) req, err := http.NewRequest(http.MethodPost, c.config.TokenUrl, payload) if err != nil { diff --git a/core/clients/key_flow_continuous_refresh.go b/core/clients/key_flow_continuous_refresh.go index f5129aa02..4b971c203 100644 --- a/core/clients/key_flow_continuous_refresh.go +++ b/core/clients/key_flow_continuous_refresh.go @@ -125,7 +125,7 @@ func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Tim // - (false, nil) if not successful but should be retried. // - (_, err) if not successful and shouldn't be retried. func (refresher *continuousTokenRefresher) refreshToken() (bool, error) { - err := refresher.keyFlow.recreateAccessToken() + err := refresher.keyFlow.createAccessToken() if err == nil { return true, nil } diff --git a/core/clients/key_flow_continuous_refresh_test.go b/core/clients/key_flow_continuous_refresh_test.go index 7c7ee9565..983a34f37 100644 --- a/core/clients/key_flow_continuous_refresh_test.go +++ b/core/clients/key_flow_continuous_refresh_test.go @@ -95,15 +95,8 @@ func TestContinuousRefreshToken(t *testing.T) { t.Fatalf("failed to create access token: %v", err) } - refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("test")) - if err != nil { - t.Fatalf("failed to create refresh token: %v", err) - } - numberDoCalls := 0 - mockDo := func(_ *http.Request) (resp *http.Response, err error) { + mockDo := func(r *http.Request) (resp *http.Response, err error) { numberDoCalls++ // count refresh attempts if tt.doError != nil { return nil, tt.doError @@ -115,8 +108,7 @@ func TestContinuousRefreshToken(t *testing.T) { t.Fatalf("Do call: failed to create access token: %v", err) } responseBodyStruct := TokenResponseBody{ - AccessToken: newAccessToken, - RefreshToken: refreshToken, + AccessToken: newAccessToken, } responseBody, err := json.Marshal(responseBodyStruct) if err != nil { @@ -153,7 +145,7 @@ func TestContinuousRefreshToken(t *testing.T) { } // Set the token after initialization - err = keyFlow.SetToken(accessToken, refreshToken) + err = keyFlow.SetToken(accessToken) if err != nil { t.Fatalf("failed to set token: %v", err) } @@ -186,7 +178,7 @@ func TestContinuousRefreshToken(t *testing.T) { } // Tests if -// - continuousRefreshToken() updates access token using the refresh token +// - continuousRefreshToken() updates access token // - The access token can be accessed while continuousRefreshToken() is trying to update it func TestContinuousRefreshTokenConcurrency(t *testing.T) { // The times here are in the order of miliseconds (so they run faster) @@ -234,14 +226,6 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { t.Fatalf("created tokens are equal") } - // The refresh token used to update the access token - refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("test")) - if err != nil { - t.Fatalf("failed to create refresh token: %v", err) - } - ctx := context.Background() ctx, cancel := context.WithCancel(ctx) defer cancel() // This cancels the refresher goroutine @@ -271,8 +255,7 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { t.Fatalf("Do call: failed to create additional access token: %v", err) } responseBodyStruct := TokenResponseBody{ - AccessToken: newAccessToken, - RefreshToken: refreshToken, + AccessToken: newAccessToken, } responseBody, err := json.Marshal(responseBodyStruct) if err != nil { @@ -308,18 +291,12 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { t.Fatalf("Do call: failed to parse body form: %v", err) } reqGrantType := req.Form.Get("grant_type") - if reqGrantType != "refresh_token" { - t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "refresh_token", reqGrantType) + if reqGrantType != "urn:ietf:params:oauth:grant-type:jwt-bearer" { + t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "urn:ietf:params:oauth:grant-type:jwt-bearer", reqGrantType) } - reqRefreshToken := req.Form.Get("refresh_token") - if reqRefreshToken != refreshToken { - t.Fatalf("Do call: failed request to refresh token: call to refresh access token did not have the expected refresh token set") - } - // Return response with accessTokenSecond responseBodyStruct := TokenResponseBody{ - AccessToken: accessTokenSecond, - RefreshToken: refreshToken, + AccessToken: accessTokenSecond, } responseBody, err := json.Marshal(responseBodyStruct) if err != nil { @@ -409,7 +386,7 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { } // Set the token after initialization - err = keyFlow.SetToken(accessTokenFirst, refreshToken) + err = keyFlow.SetToken(accessTokenFirst) if err != nil { t.Fatalf("failed to set token: %v", err) } diff --git a/core/clients/key_flow_test.go b/core/clients/key_flow_test.go index 9803f24ee..a64bee881 100644 --- a/core/clients/key_flow_test.go +++ b/core/clients/key_flow_test.go @@ -130,65 +130,6 @@ func TestKeyFlowInit(t *testing.T) { } } -func TestSetToken(t *testing.T) { - tests := []struct { - name string - tokenInvalid bool - refreshToken string - wantErr bool - }{ - { - name: "ok", - tokenInvalid: false, - refreshToken: "refresh_token", - wantErr: false, - }, - { - name: "invalid_token", - tokenInvalid: true, - refreshToken: "", - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var accessToken string - var err error - - timestamp := time.Now().Add(24 * time.Hour) - if tt.tokenInvalid { - accessToken = "foo" - } else { - accessTokenJWT := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(timestamp)}) - accessToken, err = accessTokenJWT.SignedString(testSigningKey) - if err != nil { - t.Fatalf("get test access token as string: %s", err) - } - } - - keyFlow := &KeyFlow{} - err = keyFlow.SetToken(accessToken, tt.refreshToken) - - if (err != nil) != tt.wantErr { - t.Errorf("KeyFlow.SetToken() error = %v, wantErr %v", err, tt.wantErr) - } - if err == nil { - expectedKeyFlowToken := &TokenResponseBody{ - AccessToken: accessToken, - ExpiresIn: int(timestamp.Unix()), - RefreshToken: tt.refreshToken, - Scope: defaultScope, - TokenType: defaultTokenType, - } - if !cmp.Equal(expectedKeyFlowToken, keyFlow.token) { - t.Errorf("The returned result is wrong. Expected %+v, got %+v", expectedKeyFlowToken, keyFlow.token) - } - } - }) - } -} - func TestTokenExpired(t *testing.T) { tokenExpirationLeeway := 5 * time.Second tests := []struct { @@ -442,10 +383,9 @@ func TestKeyFlow_Do(t *testing.T) { res.Header().Set("Content-Type", "application/json") token := &TokenResponseBody{ - AccessToken: testBearerToken, - ExpiresIn: 2147483647, - RefreshToken: testBearerToken, - TokenType: "Bearer", + AccessToken: testBearerToken, + ExpiresIn: 2147483647, + TokenType: "Bearer", } if err := json.NewEncoder(res.Body).Encode(token); err != nil { From 3a0b673787b0e035a1746d17b3a260506988a9c9 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Thu, 6 Nov 2025 16:26:06 +0100 Subject: [PATCH 02/18] Update changelogs Signed-off-by: Jorge Turrado --- CHANGELOG.md | 3 +++ core/CHANGELOG.md | 4 ++++ core/VERSION | 2 +- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e738723f..e0ddcc5e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ - `certificates`: [v1.2.0](services/certificates/CHANGELOG.md#v120) - **Feature:** Switch from `v2beta` API version to `v2` version. - **Breaking change:** Rename `CreateCertificateResponse` to `GetCertificateResponse` +- `core`: + - [v0.21.0](core/CHANGELOG.md#v0210) + - **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` - `sfs`: - [v0.2.0](services/sfs/CHANGELOG.md) - **Breaking change:** Remove region configuration in `APIClient` diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index 8b1d2fb86..47d06e806 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -1,3 +1,6 @@ +## v0.21.0 +- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` + ## v0.20.1 - **Improvement:** Improve error message when passing a PEM encoded file to as service account key @@ -9,6 +12,7 @@ ## v0.18.0 - **New:** Added duration utils +- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` ## v0.17.3 - **Dependencies:** Bump `github.com/golang-jwt/jwt/v5` from `v5.2.2` to `v5.2.3` diff --git a/core/VERSION b/core/VERSION index 2c80271d5..fcc9d59a4 100644 --- a/core/VERSION +++ b/core/VERSION @@ -1 +1 @@ -v0.20.1 +v0.21.0 \ No newline at end of file From 60a09f799b1362b245b5ab408b03f8b468e74765 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Thu, 13 Nov 2025 10:19:42 +0100 Subject: [PATCH 03/18] update exp time of assertion as access token CAN'T be longer that it Signed-off-by: Jorge Turrado --- core/CHANGELOG.md | 1 - core/clients/key_flow.go | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index 47d06e806..c2719e863 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -12,7 +12,6 @@ ## v0.18.0 - **New:** Added duration utils -- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` ## v0.17.3 - **Dependencies:** Bump `github.com/golang-jwt/jwt/v5` from `v5.2.2` to `v5.2.3` diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index 9a1b5d1e8..cedf5e937 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -295,7 +295,7 @@ func (c *KeyFlow) generateSelfSignedJWT() (string, error) { "jti": uuid.New(), "aud": c.key.Credentials.Aud, "iat": jwt.NewNumericDate(time.Now()), - "exp": jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), + "exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), } token := jwt.NewWithClaims(jwt.SigningMethodRS512, claims) token.Header["kid"] = c.key.Credentials.Kid From f8da09dbabf4bb2cd8aba825f81ef491575b61ac Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Thu, 6 Nov 2025 16:26:06 +0100 Subject: [PATCH 04/18] Update changelogs Signed-off-by: Jorge Turrado --- core/CHANGELOG.md | 1 + core/VERSION | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index c2719e863..47d06e806 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -12,6 +12,7 @@ ## v0.18.0 - **New:** Added duration utils +- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` ## v0.17.3 - **Dependencies:** Bump `github.com/golang-jwt/jwt/v5` from `v5.2.2` to `v5.2.3` diff --git a/core/VERSION b/core/VERSION index fcc9d59a4..759e855fb 100644 --- a/core/VERSION +++ b/core/VERSION @@ -1 +1 @@ -v0.21.0 \ No newline at end of file +v0.21.0 From 85275f2c4e41d41f7e4f8c533ad0d4a9e2d6cd4e Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Wed, 10 Dec 2025 11:20:03 +0100 Subject: [PATCH 05/18] feat: Support Workload Identity Federation flow Signed-off-by: Jorge Turrado --- core/auth/auth.go | 53 +- core/auth/auth_test.go | 101 +++- core/clients/auth_flow.go | 84 +++ core/clients/key_flow.go | 126 +--- core/clients/key_flow_continuous_refresh.go | 39 +- .../key_flow_continuous_refresh_test.go | 414 ++----------- core/clients/key_flow_test.go | 17 +- core/clients/workload_identity_flow.go | 249 ++++++++ core/clients/workload_identity_flow_test.go | 566 ++++++++++++++++++ core/config/config.go | 69 ++- examples/authentication/authentication.go | 59 +- 11 files changed, 1230 insertions(+), 547 deletions(-) create mode 100644 core/clients/auth_flow.go create mode 100644 core/clients/workload_identity_flow.go create mode 100644 core/clients/workload_identity_flow_test.go diff --git a/core/auth/auth.go b/core/auth/auth.go index 568847aea..88f002fe7 100644 --- a/core/auth/auth.go +++ b/core/auth/auth.go @@ -45,6 +45,12 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) { if cfg.CustomAuth != nil { return cfg.CustomAuth, nil + } else if useWorkloadIdentityFederation(cfg) { + wifRoundTripper, err := WorkloadIdentityFederationAuth(cfg) + if err != nil { + return nil, fmt.Errorf("configuring no auth client: %w", err) + } + return wifRoundTripper, nil } else if cfg.NoAuth { noAuthRoundTripper, err := NoAuth(cfg) if err != nil { @@ -84,14 +90,18 @@ func DefaultAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) { cfg = &config.Configuration{} } - // Key flow - rt, err = KeyAuth(cfg) + // WIF flow + rt, err = WorkloadIdentityFederationAuth(cfg) if err != nil { - keyFlowErr := err - // Token flow - rt, err = TokenAuth(cfg) + // Key flow + rt, err = KeyAuth(cfg) if err != nil { - return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err) + keyFlowErr := err + // Token flow + rt, err = TokenAuth(cfg) + if err != nil { + return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err) + } } } return rt, nil @@ -221,6 +231,29 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) { return client, nil } +// WorkloadIdentityFederationAuth configures the wif flow and returns an http.RoundTripper +// that can be used to make authenticated requests using an access token +func WorkloadIdentityFederationAuth(cfg *config.Configuration) (http.RoundTripper, error) { + wifConfig := clients.WorkloadIdentityFederationFlowConfig{ + TokenUrl: cfg.TokenCustomUrl, + BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext, + ClientID: cfg.ServiceAccountEmail, + FederatedTokenFilePath: cfg.WorkloadIdentityFederationFederatedTokenPath, + TokenExpiration: cfg.WorkloadIdentityFederationTokenExpiration, + } + + if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil { + wifConfig.HTTPTransport = cfg.HTTPClient.Transport + } + + client := &clients.WorkloadIdentityFederationFlow{} + if err := client.Init(&wifConfig); err != nil { + return nil, fmt.Errorf("error initializing client: %w", err) + } + + return client, nil +} + // readCredentialsFile reads the credentials file from the specified path and returns Credentials func readCredentialsFile(path string) (*Credentials, error) { if path == "" { @@ -361,3 +394,11 @@ func getServiceAccountKey(cfg *config.Configuration) error { func getPrivateKey(cfg *config.Configuration) error { return getKey(&cfg.PrivateKey, &cfg.PrivateKeyPath, "STACKIT_PRIVATE_KEY_PATH", "STACKIT_PRIVATE_KEY", privateKeyPathCredentialType, privateKeyCredentialType, cfg.CredentialsFilePath) } + +func useWorkloadIdentityFederation(cfg *config.Configuration) bool { + if cfg != nil && cfg.WorkloadIdentityFederation { + return true + } + val, exists := os.LookupEnv(clients.FederatedTokenFileEnv) + return exists && val != "" +} diff --git a/core/auth/auth_test.go b/core/auth/auth_test.go index a7c776946..5e8af7203 100644 --- a/core/auth/auth_test.go +++ b/core/auth/auth_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/stackitcloud/stackit-sdk-go/core/clients" "github.com/stackitcloud/stackit-sdk-go/core/config" @@ -121,6 +122,32 @@ func TestSetupAuth(t *testing.T) { } }() + // create a wif assertion file + wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt") + if errs != nil { + t.Fatalf("Creating temporary file: %s", err) + } + defer func() { + _ = wifAssertionFile.Close() + err := os.Remove(wifAssertionFile.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + Subject: "sub", + }).SignedString([]byte("test")) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + + _, errs = wifAssertionFile.WriteString(string(token)) + if errs != nil { + t.Fatalf("Writing wif assertion to temporary file: %s", err) + } + // create a credentials file with saKey and private key credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt") if errs != nil { @@ -147,25 +174,28 @@ func TestSetupAuth(t *testing.T) { desc string config *config.Configuration setToken bool + setWorkloadIdentity bool setKeys bool setKeyPaths bool setCredentialsFilePathToken bool setCredentialsFilePathKey bool - isValid bool }{ + { + desc: "wif_config", + config: nil, + setWorkloadIdentity: true, + }, { desc: "token_config", config: nil, setToken: true, setCredentialsFilePathToken: false, - isValid: true, }, { desc: "key_config", config: nil, setKeys: true, setCredentialsFilePathToken: false, - isValid: true, }, { desc: "key_config_path", @@ -173,7 +203,6 @@ func TestSetupAuth(t *testing.T) { setKeys: false, setKeyPaths: true, setCredentialsFilePathToken: false, - isValid: true, }, { desc: "key_config_credentials_path", @@ -181,14 +210,12 @@ func TestSetupAuth(t *testing.T) { setKeys: false, setKeyPaths: false, setCredentialsFilePathKey: true, - isValid: true, }, { desc: "valid_path_to_file", config: nil, setToken: false, setCredentialsFilePathToken: true, - isValid: true, }, { desc: "custom_config_token", @@ -197,7 +224,6 @@ func TestSetupAuth(t *testing.T) { }, setToken: false, setCredentialsFilePathToken: false, - isValid: true, }, { desc: "custom_config_path", @@ -206,7 +232,6 @@ func TestSetupAuth(t *testing.T) { }, setToken: false, setCredentialsFilePathToken: false, - isValid: true, }, } { t.Run(test.desc, func(t *testing.T) { @@ -241,19 +266,21 @@ func TestSetupAuth(t *testing.T) { t.Setenv("STACKIT_CREDENTIALS_PATH", "") } + if test.setWorkloadIdentity { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name()) + } else { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "") + } + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email") authRoundTripper, err := SetupAuth(test.config) - if err != nil && test.isValid { + if err != nil { t.Fatalf("Test returned error on valid test case: %v", err) } - if err == nil && !test.isValid { - t.Fatalf("Test didn't return error on invalid test case") - } - - if test.isValid && authRoundTripper == nil { + if authRoundTripper == nil { t.Fatalf("Roundtripper returned is nil for valid test case") } }) @@ -381,6 +408,32 @@ func TestDefaultAuth(t *testing.T) { t.Fatalf("Writing private key to temporary file: %s", err) } + // create a wif assertion file + wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt") + if errs != nil { + t.Fatalf("Creating temporary file: %s", err) + } + defer func() { + _ = wifAssertionFile.Close() + err := os.Remove(wifAssertionFile.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + Subject: "sub", + }).SignedString([]byte("test")) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + + _, errs = wifAssertionFile.WriteString(string(token)) + if errs != nil { + t.Fatalf("Writing wif assertion to temporary file: %s", err) + } + // create a credentials file with saKey and private key credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt") if errs != nil { @@ -409,6 +462,7 @@ func TestDefaultAuth(t *testing.T) { setKeyPaths bool setKeys bool setCredentialsFilePathKey bool + setWorkloadIdentity bool isValid bool expectedFlow string }{ @@ -418,6 +472,14 @@ func TestDefaultAuth(t *testing.T) { isValid: true, expectedFlow: "token", }, + { + desc: "wif_precedes_key_precedes_token", + setToken: true, + setKeyPaths: true, + setWorkloadIdentity: true, + isValid: true, + expectedFlow: "wif", + }, { desc: "key_precedes_token", setToken: true, @@ -475,6 +537,13 @@ func TestDefaultAuth(t *testing.T) { } else { t.Setenv("STACKIT_SERVICE_ACCOUNT_TOKEN", "") } + + if test.setWorkloadIdentity { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name()) + } else { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "") + } + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email") // Get the default authentication client and ensure that it's not nil @@ -501,6 +570,10 @@ func TestDefaultAuth(t *testing.T) { if _, ok := authClient.(*clients.KeyFlow); !ok { t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient)) } + case "wif": + if _, ok := authClient.(*clients.WorkloadIdentityFederationFlow); !ok { + t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient)) + } } } }) diff --git a/core/clients/auth_flow.go b/core/clients/auth_flow.go new file mode 100644 index 000000000..141d75489 --- /dev/null +++ b/core/clients/auth_flow.go @@ -0,0 +1,84 @@ +package clients + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" +) + +const ( + defaultTokenExpirationLeeway = time.Second * 5 +) + +type AuthFlow interface { + RoundTrip(req *http.Request) (*http.Response, error) + GetAccessToken() (string, error) + GetBackgroundTokenRefreshContext() context.Context +} + +// TokenResponseBody is the API response +// when requesting a new token +type TokenResponseBody struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` +} + +func parseTokenResponse(res *http.Response) (*TokenResponseBody, error) { + if res == nil { + return nil, fmt.Errorf("received bad response from API") + } + if res.StatusCode != http.StatusOK { + body, err := io.ReadAll(res.Body) + if err != nil { + // Fail silently, omit body from error + // We're trying to show error details, so it's unnecessary to fail because of this err + body = []byte{} + } + return nil, &oapierror.GenericOpenAPIError{ + StatusCode: res.StatusCode, + Body: body, + } + } + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + + token := &TokenResponseBody{} + err = json.Unmarshal(body, token) + if err != nil { + return nil, fmt.Errorf("unmarshal token response: %w", err) + } + return token, nil +} + +func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) { + if token == "" { + return true, nil + } + + // We can safely use ParseUnverified because we are not authenticating the user at this point. + // We're just checking the expiration time + tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) + if err != nil { + return false, fmt.Errorf("parse token: %w", err) + } + + expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime() + if err != nil { + return false, fmt.Errorf("get expiration timestamp: %w", err) + } + + // Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring + // between retrieving the token and upstream systems validating it. + now := time.Now().Add(tokenExpirationLeeway) + return now.After(expirationTimestampNumeric.Time), nil +} diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index cedf5e937..83c82e778 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -4,11 +4,9 @@ import ( "context" "crypto/rsa" "crypto/x509" - "encoding/json" "encoding/pem" "errors" "fmt" - "io" "net/http" "net/url" "regexp" @@ -30,12 +28,10 @@ const ( ServiceAccountKeyPath = "STACKIT_SERVICE_ACCOUNT_KEY_PATH" PrivateKeyPath = "STACKIT_PRIVATE_KEY_PATH" tokenAPI = "https://service-account.api.stackit.cloud/token" //nolint:gosec // linter false positive - defaultTokenType = "Bearer" - defaultScope = "" - - defaultTokenExpirationLeeway = time.Second * 5 ) +var _ AuthFlow = &KeyFlow{} + // KeyFlow handles auth with SA key type KeyFlow struct { rt http.RoundTripper @@ -65,15 +61,6 @@ type KeyFlowConfig struct { AuthHTTPClient *http.Client } -// TokenResponseBody is the API response -// when requesting a new token -type TokenResponseBody struct { - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` - Scope string `json:"scope"` - TokenType string `json:"token_type"` -} - // ServiceAccountKeyResponse is the API response // when creating a new SA key type ServiceAccountKeyResponse struct { @@ -112,19 +99,6 @@ func (c *KeyFlow) GetServiceAccountEmail() string { return c.key.Credentials.Iss } -// GetToken returns the token field -func (c *KeyFlow) GetToken() TokenResponseBody { - c.tokenMutex.RLock() - defer c.tokenMutex.RUnlock() - - if c.token == nil { - return TokenResponseBody{} - } - // Returned struct is passed by value (because it's a struct) - // So no deepy copy needed - return *c.token -} - func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { // No concurrency at this point, so no mutex check needed c.token = &TokenResponseBody{} @@ -157,31 +131,6 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { return nil } -// SetToken can be used to set an access token manually in the client. -// The other fields in the token field are determined by inspecting the token or setting default values. -func (c *KeyFlow) SetToken(accessToken string) error { - // We can safely use ParseUnverified because we are not authenticating the user, - // We are parsing the token just to get the expiration time claim - parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) - if err != nil { - return fmt.Errorf("parse access token to read expiration time: %w", err) - } - exp, err := parsedAccessToken.Claims.GetExpirationTime() - if err != nil { - return fmt.Errorf("get expiration time from access token: %w", err) - } - - c.tokenMutex.Lock() - c.token = &TokenResponseBody{ - AccessToken: accessToken, - ExpiresIn: int(exp.Time.Unix()), - Scope: defaultScope, - TokenType: defaultTokenType, - } - c.tokenMutex.Unlock() - return nil -} - // Roundtrip performs the request func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) { if c.rt == nil { @@ -201,7 +150,6 @@ func (c *KeyFlow) GetAccessToken() (string, error) { if c.rt == nil { return "", fmt.Errorf("nil http round tripper, please run Init()") } - var accessToken string c.tokenMutex.RLock() @@ -235,6 +183,10 @@ func (c *KeyFlow) GetAccessToken() (string, error) { return accessToken, nil } +func (c *KeyFlow) GetBackgroundTokenRefreshContext() context.Context { + return c.config.BackgroundTokenRefreshContext +} + // validate the client is configured well func (c *KeyFlow) validate() error { if c.config.ServiceAccountKey == nil { @@ -284,7 +236,14 @@ func (c *KeyFlow) createAccessToken() (err error) { err = fmt.Errorf("close request access token response: %w", tempErr) } }() - return c.parseTokenResponse(res) + token, err := parseTokenResponse(res) + if err != nil { + return err + } + c.tokenMutex.Lock() + c.token = token + c.tokenMutex.Unlock() + return nil } // generateSelfSignedJWT generates JWT token @@ -321,60 +280,3 @@ func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error) return c.authClient.Do(req) } - -// parseTokenResponse parses the response from the server -func (c *KeyFlow) parseTokenResponse(res *http.Response) error { - if res == nil { - return fmt.Errorf("received bad response from API") - } - if res.StatusCode != http.StatusOK { - body, err := io.ReadAll(res.Body) - if err != nil { - // Fail silently, omit body from error - // We're trying to show error details, so it's unnecessary to fail because of this err - body = []byte{} - } - return &oapierror.GenericOpenAPIError{ - StatusCode: res.StatusCode, - Body: body, - } - } - body, err := io.ReadAll(res.Body) - if err != nil { - return err - } - - c.tokenMutex.Lock() - c.token = &TokenResponseBody{} - err = json.Unmarshal(body, c.token) - c.tokenMutex.Unlock() - if err != nil { - return fmt.Errorf("unmarshal token response: %w", err) - } - - return nil -} - -func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) { - if token == "" { - return true, nil - } - - // We can safely use ParseUnverified because we are not authenticating the user at this point. - // We're just checking the expiration time - tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) - if err != nil { - return false, fmt.Errorf("parse token: %w", err) - } - - expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime() - if err != nil { - return false, fmt.Errorf("get expiration timestamp: %w", err) - } - - // Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring - // between retrieving the token and upstream systems validating it. - now := time.Now().Add(tokenExpirationLeeway) - - return now.After(expirationTimestampNumeric.Time), nil -} diff --git a/core/clients/key_flow_continuous_refresh.go b/core/clients/key_flow_continuous_refresh.go index 4b971c203..702b3695c 100644 --- a/core/clients/key_flow_continuous_refresh.go +++ b/core/clients/key_flow_continuous_refresh.go @@ -20,9 +20,9 @@ var ( // Continuously refreshes the token of a key flow, retrying if the token API returns 5xx errrors. Writes to stderr when it terminates. // // To terminate this routine, close the context in keyFlow.config.BackgroundTokenRefreshContext. -func continuousRefreshToken(keyflow *KeyFlow) { +func continuousRefreshToken(flow AuthFlow) { refresher := &continuousTokenRefresher{ - keyFlow: keyflow, + flow: flow, timeStartBeforeTokenExpiration: defaultTimeStartBeforeTokenExpiration, timeBetweenContextCheck: defaultTimeBetweenContextCheck, timeBetweenTries: defaultTimeBetweenTries, @@ -32,7 +32,7 @@ func continuousRefreshToken(keyflow *KeyFlow) { } type continuousTokenRefresher struct { - keyFlow *KeyFlow + flow AuthFlow // Token refresh tries start at [Access token expiration timestamp] - [This duration] timeStartBeforeTokenExpiration time.Duration timeBetweenContextCheck time.Duration @@ -46,22 +46,12 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { // Compute timestamp where we'll refresh token // Access token may be empty at this point, we have to check it var startRefreshTimestamp time.Time - var accessToken string - refresher.keyFlow.tokenMutex.RLock() - if refresher.keyFlow.token != nil { - accessToken = refresher.keyFlow.token.AccessToken - } - refresher.keyFlow.tokenMutex.RUnlock() - if accessToken == "" { - startRefreshTimestamp = time.Now() - } else { - expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp() - if err != nil { - return fmt.Errorf("get access token expiration timestamp: %w", err) - } - startRefreshTimestamp = expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration) + expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp() + if err != nil { + return fmt.Errorf("get access token expiration timestamp: %w", err) } + startRefreshTimestamp = expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration) for { err := refresher.waitUntilTimestamp(startRefreshTimestamp) @@ -69,7 +59,7 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { return err } - err = refresher.keyFlow.config.BackgroundTokenRefreshContext.Err() + err = refresher.flow.GetBackgroundTokenRefreshContext().Err() if err != nil { return fmt.Errorf("check context: %w", err) } @@ -92,13 +82,14 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { } func (refresher *continuousTokenRefresher) getAccessTokenExpirationTimestamp() (*time.Time, error) { - refresher.keyFlow.tokenMutex.RLock() - token := refresher.keyFlow.token.AccessToken - refresher.keyFlow.tokenMutex.RUnlock() + accessToken, err := refresher.flow.GetAccessToken() + if err != nil { + return nil, err + } // We can safely use ParseUnverified because we are not doing authentication of any kind // We're just checking the expiration time - tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) + tokenParsed, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) if err != nil { return nil, fmt.Errorf("parse token: %w", err) } @@ -111,7 +102,7 @@ func (refresher *continuousTokenRefresher) getAccessTokenExpirationTimestamp() ( func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Time) error { for time.Now().Before(timestamp) { - err := refresher.keyFlow.config.BackgroundTokenRefreshContext.Err() + err := refresher.flow.GetBackgroundTokenRefreshContext().Err() if err != nil { return fmt.Errorf("check context: %w", err) } @@ -125,7 +116,7 @@ func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Tim // - (false, nil) if not successful but should be retried. // - (_, err) if not successful and shouldn't be retried. func (refresher *continuousTokenRefresher) refreshToken() (bool, error) { - err := refresher.keyFlow.createAccessToken() + _, err := refresher.flow.GetAccessToken() if err == nil { return true, nil } diff --git a/core/clients/key_flow_continuous_refresh_test.go b/core/clients/key_flow_continuous_refresh_test.go index 983a34f37..cfd50e763 100644 --- a/core/clients/key_flow_continuous_refresh_test.go +++ b/core/clients/key_flow_continuous_refresh_test.go @@ -1,18 +1,13 @@ package clients import ( - "bytes" "context" - "encoding/json" "fmt" - "io" "net/http" - "net/url" "testing" "time" "github.com/golang-jwt/jwt/v5" - "github.com/stackitcloud/stackit-sdk-go/core/oapierror" ) @@ -22,9 +17,9 @@ func TestContinuousRefreshToken(t *testing.T) { jwt.TimePrecision = time.Millisecond // Refresher settings - timeStartBeforeTokenExpiration := 500 * time.Millisecond - timeBetweenContextCheck := 10 * time.Millisecond - timeBetweenTries := 100 * time.Millisecond + timeStartBeforeTokenExpiration := 0 * time.Second + timeBetweenContextCheck := 50 * time.Millisecond + timeBetweenTries := 500 * time.Millisecond // All generated acess tokens will have this time to live accessTokensTimeToLive := 1 * time.Second @@ -34,16 +29,20 @@ func TestContinuousRefreshToken(t *testing.T) { contextClosesIn time.Duration doError error expectedNumberDoCalls int - expectedCallRange []int // Optional: for tests that can have variable call counts }{ + { + desc: "update access token never", + contextClosesIn: 900 * time.Millisecond, // Should allow no refresh + expectedNumberDoCalls: 0, + }, { desc: "update access token once", - contextClosesIn: 700 * time.Millisecond, // Should allow one refresh + contextClosesIn: 1900 * time.Millisecond, // Should allow one refresh expectedNumberDoCalls: 1, }, { desc: "update access token twice", - contextClosesIn: 1300 * time.Millisecond, // Should allow two refreshes + contextClosesIn: 2900 * time.Millisecond, // Should allow two refreshes expectedNumberDoCalls: 2, }, { @@ -62,14 +61,14 @@ func TestContinuousRefreshToken(t *testing.T) { expectedNumberDoCalls: 0, }, { - desc: "refresh token fails - non-API error", - contextClosesIn: 700 * time.Millisecond, + desc: "refresh token fails - error", + contextClosesIn: 1900 * time.Millisecond, doError: fmt.Errorf("something went wrong"), expectedNumberDoCalls: 1, }, { desc: "refresh token fails - API non-5xx error", - contextClosesIn: 700 * time.Millisecond, + contextClosesIn: 1900 * time.Millisecond, doError: &oapierror.GenericOpenAPIError{ StatusCode: http.StatusBadRequest, }, @@ -77,84 +76,35 @@ func TestContinuousRefreshToken(t *testing.T) { }, { desc: "refresh token fails - API 5xx error", - contextClosesIn: 800 * time.Millisecond, + contextClosesIn: 2900 * time.Millisecond, doError: &oapierror.GenericOpenAPIError{ StatusCode: http.StatusInternalServerError, }, - expectedNumberDoCalls: 3, - expectedCallRange: []int{3, 4}, // Allow 3 or 4 calls due to timing race condition + expectedNumberDoCalls: 4, }, } for _, tt := range tests { + tt := tt t.Run(tt.desc, func(t *testing.T) { - accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)), - }).SignedString([]byte("test")) + t.Parallel() + accessToken, err := signToken(accessTokensTimeToLive) if err != nil { - t.Fatalf("failed to create access token: %v", err) - } - - numberDoCalls := 0 - mockDo := func(r *http.Request) (resp *http.Response, err error) { - numberDoCalls++ // count refresh attempts - if tt.doError != nil { - return nil, tt.doError - } - newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)), - }).SignedString([]byte("test")) - if err != nil { - t.Fatalf("Do call: failed to create access token: %v", err) - } - responseBodyStruct := TokenResponseBody{ - AccessToken: newAccessToken, - } - responseBody, err := json.Marshal(responseBodyStruct) - if err != nil { - t.Fatalf("Do call: failed to marshal response: %v", err) - } - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - } - return response, nil + t.Fatalf("failed to sign access token: %v", err) } - ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn) defer cancel() - keyFlow := &KeyFlow{} - privateKeyBytes, err := generatePrivateKey() - if err != nil { - t.Fatalf("Error generating private key: %s", err) - } - keyFlowConfig := &KeyFlowConfig{ - ServiceAccountKey: fixtureServiceAccountKey(), - PrivateKey: string(privateKeyBytes), - AuthHTTPClient: &http.Client{ - Transport: mockTransportFn{mockDo}, - }, - HTTPTransport: mockTransportFn{mockDo}, - BackgroundTokenRefreshContext: nil, - } - err = keyFlow.Init(keyFlowConfig) - if err != nil { - t.Fatalf("failed to initialize key flow: %v", err) + authFlow := &fakeAuthFlow{ + backgroundTokenRefreshContext: ctx, + doError: tt.doError, + accessTokensTimeToLive: accessTokensTimeToLive, + accessToken: accessToken, } - // Set the token after initialization - err = keyFlow.SetToken(accessToken) - if err != nil { - t.Fatalf("failed to set token: %v", err) - } - - // Set the context for continuous refresh - keyFlow.config.BackgroundTokenRefreshContext = ctx - refresher := &continuousTokenRefresher{ - keyFlow: keyFlow, + flow: authFlow, timeStartBeforeTokenExpiration: timeStartBeforeTokenExpiration, timeBetweenContextCheck: timeBetweenContextCheck, timeBetweenTries: timeBetweenTries, @@ -164,300 +114,56 @@ func TestContinuousRefreshToken(t *testing.T) { if err == nil { t.Fatalf("routine finished with non-nil error") } - - // Check if we have a range of expected calls (for timing-sensitive tests) - if tt.expectedCallRange != nil { - if !contains(tt.expectedCallRange, numberDoCalls) { - t.Fatalf("expected %v calls to API to refresh token, got %d", tt.expectedCallRange, numberDoCalls) - } - } else if numberDoCalls != tt.expectedNumberDoCalls { + numberDoCalls := authFlow.getTokenCalls() + if numberDoCalls != tt.expectedNumberDoCalls { t.Fatalf("expected %d calls to API to refresh token, got %d", tt.expectedNumberDoCalls, numberDoCalls) } }) } } -// Tests if -// - continuousRefreshToken() updates access token -// - The access token can be accessed while continuousRefreshToken() is trying to update it -func TestContinuousRefreshTokenConcurrency(t *testing.T) { - // The times here are in the order of miliseconds (so they run faster) - // For this to work, we need to increase precision of the expiration timestamps - jwt.TimePrecision = time.Millisecond - - // Test plan: - // 1) continuousRefreshToken() will trigger a token update. It will be blocked in the mockDo() routine (defined below) - // 2) After continuousRefreshToken() is blocked, a request will be made using the key flow. That request should carry the access token (shouldn't be blocked just because continuousRefreshToken() is trying to refresh the token) - // 3) After the request is successful, continuousRefreshToken() will be unblocked - // 4) After waiting a bit, a new request will be made using the key flow. That request should carry the new access token - - // Where we're at in the test plan: - // - Starts at 0 - // - Is set to 1 before continuousRefreshToken() is called - // - Is set to 2 once the continuousRefreshToken() is blocked - // - Is set to 3 once the first request goes through and is checked - // - Is set to 4 after a small wait after continuousRefreshToken() is unblocked - currentTestPhase := 0 - - // Used to signal continuousRefreshToken() has been blocked - chanBlockContinuousRefreshToken := make(chan bool) - - // Used to signal continuousRefreshToken() should be unblocked - chanUnblockContinuousRefreshToken := make(chan bool) - - // The access token at the start - accessTokenFirst, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), - }).SignedString([]byte("token-first")) - if err != nil { - t.Fatalf("failed to create first access token: %v", err) - } - - // The access token that will replace accessTokenFirst - // Has a much longer expiration timestamp - accessTokenSecond, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("token-second")) - if err != nil { - t.Fatalf("failed to create second access token: %v", err) - } - - if accessTokenFirst == accessTokenSecond { - t.Fatalf("created tokens are equal") - } - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() // This cancels the refresher goroutine - - // Extract host from tokenAPI constant for consistency - tokenURL, _ := url.Parse(tokenAPI) - tokenHost := tokenURL.Host - - // The Do() routine, that both the keyFlow and continuousRefreshToken() use to make their requests - // The bools are used to make sure only one request goes through on each test phase - doTestPhase1RequestDone := false - doTestPhase2RequestDone := false - doTestPhase4RequestDone := false - mockDo := func(req *http.Request) (resp *http.Response, err error) { - // Handle auth requests (token refresh) - if req.URL.Host == tokenHost { - switch currentTestPhase { - default: - // After phase 1, allow additional auth requests but don't fail the test - // This handles the continuous nature of the refresh routine - if currentTestPhase > 1 { - // Return a valid response for any additional auth requests - newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("additional-token")) - if err != nil { - t.Fatalf("Do call: failed to create additional access token: %v", err) - } - responseBodyStruct := TokenResponseBody{ - AccessToken: newAccessToken, - } - responseBody, err := json.Marshal(responseBodyStruct) - if err != nil { - t.Fatalf("Do call: failed to marshal additional response: %v", err) - } - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - } - return response, nil - } - t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase) - return nil, nil - case 1: // Call by continuousRefreshToken() - if doTestPhase1RequestDone { - t.Fatalf("Do call: multiple requests during test phase 1") - } - doTestPhase1RequestDone = true - - currentTestPhase = 2 - chanBlockContinuousRefreshToken <- true - - // Wait until continuousRefreshToken() is to be unblocked - <-chanUnblockContinuousRefreshToken - - if currentTestPhase != 3 { - t.Fatalf("Do call: after unlocking refreshToken(), expected test phase to be 3, got %d", currentTestPhase) - } - - // Check required fields are passed - err = req.ParseForm() - if err != nil { - t.Fatalf("Do call: failed to parse body form: %v", err) - } - reqGrantType := req.Form.Get("grant_type") - if reqGrantType != "urn:ietf:params:oauth:grant-type:jwt-bearer" { - t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "urn:ietf:params:oauth:grant-type:jwt-bearer", reqGrantType) - } - // Return response with accessTokenSecond - responseBodyStruct := TokenResponseBody{ - AccessToken: accessTokenSecond, - } - responseBody, err := json.Marshal(responseBodyStruct) - if err != nil { - t.Fatalf("Do call: failed request to refresh token: marshal access token response: %v", err) - } - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - } - return response, nil - } - } - - // Handle regular HTTP requests - switch currentTestPhase { - default: - t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase) - return nil, nil - case 2: // Call by tokenFlow, first request - if doTestPhase2RequestDone { - t.Fatalf("Do call: multiple requests during test phase 2") - } - doTestPhase2RequestDone = true - - // Check host and access token - host := req.Host - expectedHost := "first-request-url.com" - if host != expectedHost { - t.Fatalf("Do call: first request expected to have host %q, found %q", expectedHost, host) - } - authHeader := req.Header.Get("Authorization") - expectedAuthHeader := fmt.Sprintf("Bearer %s", accessTokenFirst) - if authHeader != expectedAuthHeader { - t.Fatalf("Do call: first request didn't carry first access token. Expected: %s, Got: %s", expectedAuthHeader, authHeader) - } - - // Return empty response - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte{})), - } - return response, nil - case 4: // Call by tokenFlow, second request - if doTestPhase4RequestDone { - t.Fatalf("Do call: multiple requests during test phase 4") - } - doTestPhase4RequestDone = true - - // Check host and access token - host := req.Host - expectedHost := "second-request-url.com" - if host != expectedHost { - t.Fatalf("Do call: second request expected to have host %q, found %q", expectedHost, host) - } - authHeader := req.Header.Get("Authorization") - if authHeader != fmt.Sprintf("Bearer %s", accessTokenSecond) { - t.Fatalf("Do call: second request didn't carry second access token") - } - - // Return empty response - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte{})), - } - return response, nil - } - } - - keyFlow := &KeyFlow{} - privateKeyBytes, err := generatePrivateKey() - if err != nil { - t.Fatalf("Error generating private key: %s", err) - } - keyFlowConfig := &KeyFlowConfig{ - ServiceAccountKey: fixtureServiceAccountKey(), - PrivateKey: string(privateKeyBytes), - AuthHTTPClient: &http.Client{ - Transport: mockTransportFn{mockDo}, - }, - HTTPTransport: mockTransportFn{mockDo}, // Use same mock for regular requests - // Don't start continuous refresh automatically - BackgroundTokenRefreshContext: nil, - } - err = keyFlow.Init(keyFlowConfig) - if err != nil { - t.Fatalf("failed to initialize key flow: %v", err) - } - - // Set the token after initialization - err = keyFlow.SetToken(accessTokenFirst) - if err != nil { - t.Fatalf("failed to set token: %v", err) - } - - // Set the context for continuous refresh - keyFlow.config.BackgroundTokenRefreshContext = ctx - - // Create a custom refresher with shorter timing for the test - refresher := &continuousTokenRefresher{ - keyFlow: keyFlow, - timeStartBeforeTokenExpiration: 9 * time.Second, // Start 9 seconds before expiration - timeBetweenContextCheck: 5 * time.Millisecond, - timeBetweenTries: 40 * time.Millisecond, - } - - // TEST START - currentTestPhase = 1 - // Ignore returned error as expected in test - go func() { - _ = refresher.continuousRefreshToken() - }() +func signToken(expiration time.Duration) (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiration)), + }).SignedString([]byte("test")) +} - // Wait until continuousRefreshToken() is blocked - <-chanBlockContinuousRefreshToken +var _ AuthFlow = &fakeAuthFlow{} - if currentTestPhase != 2 { - t.Fatalf("Unexpected test phase %d after continuousRefreshToken() was blocked", currentTestPhase) - } +type fakeAuthFlow struct { + backgroundTokenRefreshContext context.Context + tokenCounter int + doError error + accessTokensTimeToLive time.Duration + accessToken string +} - // Perform first request - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://first-request-url.com", http.NoBody) - if err != nil { - t.Fatalf("Create first request failed: %v", err) - } - resp, err := keyFlow.RoundTrip(req) - if err != nil { - t.Fatalf("Perform first request failed: %v", err) - } - err = resp.Body.Close() +func (f *fakeAuthFlow) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, nil +} +func (f *fakeAuthFlow) GetAccessToken() (string, error) { + expired, err := tokenExpired(f.accessToken, 0) if err != nil { - t.Fatalf("First request body failed to close: %v", err) + return "", err } - - // Unblock continuousRefreshToken() - currentTestPhase = 3 - chanUnblockContinuousRefreshToken <- true - - // Wait for a bit - time.Sleep(10 * time.Millisecond) - currentTestPhase = 4 - - // Perform second request - req, err = http.NewRequestWithContext(ctx, http.MethodGet, "http://second-request-url.com", http.NoBody) - if err != nil { - t.Fatalf("Create second request failed: %v", err) + if !expired { + return f.accessToken, nil } - resp, err = keyFlow.RoundTrip(req) - if err != nil { - t.Fatalf("Second request failed: %v", err) + f.tokenCounter++ + if f.doError != nil { + return "", f.doError } - err = resp.Body.Close() + accessToken, err := signToken(f.accessTokensTimeToLive) if err != nil { - t.Fatalf("Second request body failed to close: %v", err) + return "", f.doError } + f.accessToken = accessToken + return accessToken, nil +} +func (f *fakeAuthFlow) GetBackgroundTokenRefreshContext() context.Context { + return f.backgroundTokenRefreshContext } -func contains(arr []int, val int) bool { - for _, v := range arr { - if v == val { - return true - } - } - return false +func (f *fakeAuthFlow) getTokenCalls() int { + return f.tokenCounter } diff --git a/core/clients/key_flow_test.go b/core/clients/key_flow_test.go index a64bee881..8b8877673 100644 --- a/core/clients/key_flow_test.go +++ b/core/clients/key_flow_test.go @@ -17,15 +17,10 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v5" "github.com/google/go-cmp/cmp" "github.com/google/uuid" ) -var ( - testSigningKey = []byte(`Test`) -) - const testBearerToken = "eyJhbGciOiJub25lIn0.eyJleHAiOjIxNDc0ODM2NDd9." //nolint:gosec // linter false positive func fixtureServiceAccountKey(mods ...func(*ServiceAccountKeyResponse)) *ServiceAccountKeyResponse { @@ -135,25 +130,25 @@ func TestTokenExpired(t *testing.T) { tests := []struct { desc string tokenInvalid bool - tokenExpiresAt time.Time + tokenDuration time.Duration expectedErr bool expectedIsExpired bool }{ { desc: "token valid", - tokenExpiresAt: time.Now().Add(time.Hour), + tokenDuration: time.Hour, expectedErr: false, expectedIsExpired: false, }, { desc: "token expired", - tokenExpiresAt: time.Now().Add(-time.Hour), + tokenDuration: -time.Hour, expectedErr: false, expectedIsExpired: true, }, { desc: "token almost expired", - tokenExpiresAt: time.Now().Add(tokenExpirationLeeway), + tokenDuration: tokenExpirationLeeway, expectedErr: false, expectedIsExpired: true, }, @@ -169,9 +164,7 @@ func TestTokenExpired(t *testing.T) { var err error token := "foo" if !tt.tokenInvalid { - token, err = jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(tt.tokenExpiresAt), - }).SignedString([]byte("test")) + token, err = signToken(tt.tokenDuration) if err != nil { t.Fatalf("failed to create token: %v", err) } diff --git a/core/clients/workload_identity_flow.go b/core/clients/workload_identity_flow.go new file mode 100644 index 000000000..65b6fc461 --- /dev/null +++ b/core/clients/workload_identity_flow.go @@ -0,0 +1,249 @@ +package clients + +import ( + "context" + "fmt" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +const ( + clientIDEnv = "STACKIT_SERVICE_ACCOUNT_EMAIL" + FederatedTokenFileEnv = "STACKIT_FEDERATED_TOKEN_FILE" + wifTokenEndpointEnv = "STACKIT_IDP_ENDPOINT" + wifTokenExpirationEnv = "STACKIT_IDP_EXPIRATION_SECONDS" + + wifClientAssertionType = "urn:schwarz:params:oauth:client-assertion-type:workload-jwt" + wifGrantType = "client_credentials" + defaultWifTokenEndpoint = "https://accounts.stackit.cloud/oauth/v2/token" + defaultFederatedTokenPath = "/var/run/secrets/stackit.cloud/serviceaccount/token" + defaultWifExpirationToken = "1h" +) + +var ( + _ = getEnvOrDefault(wifTokenExpirationEnv, defaultWifExpirationToken) // Not used yet +) + +func getEnvOrDefault(envVar, defaultValue string) string { + if value := os.Getenv(envVar); value != "" { + return value + } + return defaultValue +} + +var _ AuthFlow = &WorkloadIdentityFederationFlow{} + +// WorkloadIdentityFlow handles auth with Workload Identity Federation +type WorkloadIdentityFederationFlow struct { + rt http.RoundTripper + authClient *http.Client + config *WorkloadIdentityFederationFlowConfig + + tokenMutex sync.RWMutex + token *TokenResponseBody + + parser *jwt.Parser + + // If the current access token would expire in less than TokenExpirationLeeway, + // the client will refresh it early to prevent clock skew or other timing issues. + tokenExpirationLeeway time.Duration +} + +// KeyFlowConfig is the flow config +type WorkloadIdentityFederationFlowConfig struct { + TokenUrl string + ClientID string + FederatedTokenFilePath string + TokenExpiration string // Not supported yet + BackgroundTokenRefreshContext context.Context // Functionality is enabled if this isn't nil + HTTPTransport http.RoundTripper + AuthHTTPClient *http.Client +} + +// GetConfig returns the flow configuration +func (c *WorkloadIdentityFederationFlow) GetConfig() WorkloadIdentityFederationFlowConfig { + if c.config == nil { + return WorkloadIdentityFederationFlowConfig{} + } + return *c.config +} + +// GetAccessToken implements AuthFlow. +func (c *WorkloadIdentityFederationFlow) GetAccessToken() (string, error) { + if c.rt == nil { + return "", fmt.Errorf("nil http round tripper, please run Init()") + } + var accessToken string + + c.tokenMutex.RLock() + if c.token != nil { + accessToken = c.token.AccessToken + } + c.tokenMutex.RUnlock() + + accessTokenExpired, err := tokenExpired(accessToken, c.tokenExpirationLeeway) + if err != nil { + return "", fmt.Errorf("check access token is expired: %w", err) + } + if !accessTokenExpired { + return accessToken, nil + } + if err = c.createAccessToken(); err != nil { + return "", fmt.Errorf("get new access token: %w", err) + } + + c.tokenMutex.RLock() + accessToken = c.token.AccessToken + c.tokenMutex.RUnlock() + + return accessToken, nil +} + +// RoundTrip implements the http.RoundTripper interface. +// It gets a token, adds it to the request's authorization header, and performs the request. +func (c *WorkloadIdentityFederationFlow) RoundTrip(req *http.Request) (*http.Response, error) { + if c.rt == nil { + return nil, fmt.Errorf("please run Init()") + } + + accessToken, err := c.GetAccessToken() + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + return c.rt.RoundTrip(req) +} + +// GetBackgroundTokenRefreshContext implements AuthFlow. +func (c *WorkloadIdentityFederationFlow) GetBackgroundTokenRefreshContext() context.Context { + return c.config.BackgroundTokenRefreshContext +} + +func (c *WorkloadIdentityFederationFlow) Init(cfg *WorkloadIdentityFederationFlowConfig) error { + // No concurrency at this point, so no mutex check needed + c.token = &TokenResponseBody{} + c.config = cfg + c.parser = jwt.NewParser() + + if c.config.TokenUrl == "" { + c.config.TokenUrl = getEnvOrDefault(wifTokenEndpointEnv, defaultWifTokenEndpoint) + } + + if c.config.ClientID == "" { + c.config.ClientID = getEnvOrDefault(clientIDEnv, "") + } + + if c.config.FederatedTokenFilePath == "" { + c.config.FederatedTokenFilePath = getEnvOrDefault(FederatedTokenFileEnv, defaultFederatedTokenPath) + } + + c.tokenExpirationLeeway = defaultTokenExpirationLeeway + + if c.rt = cfg.HTTPTransport; c.rt == nil { + c.rt = http.DefaultTransport + } + + if c.authClient = cfg.AuthHTTPClient; cfg.AuthHTTPClient == nil { + c.authClient = &http.Client{ + Transport: c.rt, + Timeout: DefaultClientTimeout, + } + } + + err := c.validate() + if err != nil { + return err + } + + // // Init the token + // _, err = c.GetAccessToken() + // if err != nil { + // return err + // } + + if c.config.BackgroundTokenRefreshContext != nil { + go continuousRefreshToken(c) + } + return nil +} + +// validate the client is configured well +func (c *WorkloadIdentityFederationFlow) validate() error { + if c.config.ClientID == "" { + return fmt.Errorf("client ID cannot be empty") + } + if c.config.TokenUrl == "" { + return fmt.Errorf("token URL cannot be empty") + } + if _, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath); err != nil { + return fmt.Errorf("error reading federated token file - %w", err) + } + if c.tokenExpirationLeeway < 0 { + return fmt.Errorf("token expiration leeway cannot be negative") + } + + return nil +} + +// createAccessToken creates an access token using self signed JWT +func (c *WorkloadIdentityFederationFlow) createAccessToken() (err error) { + clientAssertion, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath) + if err != nil { + return fmt.Errorf("error reading service account assertion - %w", err) + } + + res, err := c.requestToken(c.config.ClientID, clientAssertion) + if err != nil { + return err + } + defer func() { + tempErr := res.Body.Close() + if tempErr != nil && err == nil { + err = fmt.Errorf("close request access token response: %w", tempErr) + } + }() + token, err := parseTokenResponse(res) + if err != nil { + return err + } + c.tokenMutex.Lock() + c.token = token + c.tokenMutex.Unlock() + return nil +} + +func (c *WorkloadIdentityFederationFlow) requestToken(clientID, assertion string) (*http.Response, error) { + body := url.Values{} + body.Set("grant_type", wifGrantType) + body.Set("client_assertion_type", wifClientAssertionType) + body.Set("client_assertion", assertion) + body.Set("client_id", clientID) + + payload := strings.NewReader(body.Encode()) + req, err := http.NewRequest(http.MethodPost, c.config.TokenUrl, payload) + if err != nil { + return nil, err + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + return c.authClient.Do(req) +} + +func (c *WorkloadIdentityFederationFlow) readJWTFromFileSystem(tokenFilePath string) (string, error) { + token, err := os.ReadFile(tokenFilePath) + if err != nil { + return "", err + } + tokenStr := string(token) + _, _, err = c.parser.ParseUnverified(tokenStr, jwt.MapClaims{}) + if err != nil { + return "", err + } + return tokenStr, nil +} diff --git a/core/clients/workload_identity_flow_test.go b/core/clients/workload_identity_flow_test.go new file mode 100644 index 000000000..ef8f7a15f --- /dev/null +++ b/core/clients/workload_identity_flow_test.go @@ -0,0 +1,566 @@ +package clients + +import ( + "encoding/json" + "log" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +func TestWorkloadIdentityFlowInit(t *testing.T) { + tests := []struct { + name string + clientID string + clientIDAsEnv bool + customTokenUrl string + customTokenUrlEnv bool + tokenExpiration string + validAssertion bool + tokenFilePathAsEnv bool + missingTokenFilePath bool + wantErr bool + }{ + { + name: "ok setting all", + clientID: "test@stackit.cloud", + validAssertion: true, + wantErr: false, + }, + { + name: "ok using defaults from envs", + clientID: "test@stackit.cloud", + clientIDAsEnv: true, + tokenFilePathAsEnv: true, + customTokenUrlEnv: true, + validAssertion: true, + wantErr: false, + }, + { + name: "ok using defaults from envs", + clientID: "test@stackit.cloud", + clientIDAsEnv: true, + tokenFilePathAsEnv: true, + customTokenUrlEnv: true, + validAssertion: true, + wantErr: false, + }, + { + name: "missing client id", + validAssertion: true, + wantErr: true, + }, + { + name: "missing assertion", + clientID: "test@stackit.cloud", + missingTokenFilePath: true, + wantErr: true, + }, + { + name: "invalid assertion", + clientID: "test@stackit.cloud", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + flow := &WorkloadIdentityFederationFlow{} + flowConfig := &WorkloadIdentityFederationFlowConfig{} + if tt.customTokenUrl != "" { + if tt.customTokenUrlEnv { + t.Setenv("STACKIT_IDP_ENDPOINT", tt.customTokenUrl) + } else { + flowConfig.TokenUrl = tt.customTokenUrl + } + } + + if tt.clientID != "" { + if tt.clientIDAsEnv { + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", tt.clientID) + } else { + flowConfig.ClientID = tt.clientID + } + } + if tt.tokenExpiration != "" { + flowConfig.TokenExpiration = tt.tokenExpiration + } + + if !tt.missingTokenFilePath { + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer os.Remove(file.Name()) + if tt.validAssertion { + token, err := signTokenWithSubject("subject", time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + } + if tt.tokenFilePathAsEnv { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", file.Name()) + } else { + flowConfig.FederatedTokenFilePath = file.Name() + } + } + + if err := flow.Init(flowConfig); (err != nil) != tt.wantErr { + t.Errorf("KeyFlow.Init() error = %v, wantErr %v", err, tt.wantErr) + } + if flow.config == nil { + t.Error("config is nil") + } + + if flow.config.ClientID != tt.clientID { + t.Errorf("clientID mismatch, want %s, got %s", tt.clientID, flow.config.ClientID) + } + + if tt.customTokenUrl != "" && flow.config.TokenUrl != tt.customTokenUrl { + t.Errorf("tokenUrl mismatch, want %s, got %s", tt.customTokenUrl, flow.config.TokenUrl) + } + + if tt.customTokenUrl == "" && flow.config.TokenUrl != "https://accounts.stackit.cloud/oauth/v2/token" { + t.Errorf("tokenUrl mismatch, want %s, got %s", "https://accounts.stackit.cloud/oauth/v2/token", flow.config.TokenUrl) + } + + if tt.missingTokenFilePath && flow.config.FederatedTokenFilePath != "/var/run/secrets/stackit.cloud/serviceaccount/token" { + t.Errorf("clientID mismatch, want %s, got %s", "/var/run/secrets/stackit.cloud/serviceaccount/token", flow.config.FederatedTokenFilePath) + } + + if !tt.missingTokenFilePath && flow.config.FederatedTokenFilePath == "/var/run/secrets/stackit.cloud/serviceaccount/token" { + t.Errorf("clientID mismatch, want different from %s", flow.config.FederatedTokenFilePath) + } + + if tt.tokenExpiration != "" && flow.config.TokenExpiration != tt.tokenExpiration { + t.Errorf("tokenExpiration mismatch, want %s, got %s", tt.tokenExpiration, flow.config.TokenExpiration) + } + }) + } +} + +func signTokenWithSubject(sub string, expiration time.Duration) (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiration)), + Subject: sub, + }).SignedString([]byte("test")) +} + +func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { + validSub := "valid-sub" + serviceAccountSub := "sa-sub" + tests := []struct { + name string + clientID string + validAssertion bool + wantErr bool + }{ + { + name: "ok setting all", + clientID: "test@stackit.cloud", + validAssertion: true, + wantErr: false, + }, + { + name: "invalid assertion", + clientID: "test@stackit.cloud", + validAssertion: false, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + assertionType := r.PostForm.Get("client_assertion_type") + if assertionType != "urn:schwarz:params:oauth:client-assertion-type:workload-jwt" { + t.Fatalf("invalid assertion type: %s", assertionType) + } + grantType := r.PostForm.Get("grant_type") + if grantType != "client_credentials" { + t.Fatalf("invalid grant type: %s", assertionType) + } + context, _, err := jwt.NewParser().ParseUnverified(r.PostForm.Get("client_assertion"), jwt.MapClaims{}) + if err != nil { + t.Fatalf("failed to validate token: %v", err) + } + sub, err := context.Claims.GetSubject() + if err != nil { + t.Fatalf("failed to validate token sub: %v", err) + } + if sub != validSub { + w.WriteHeader(http.StatusUnauthorized) + return + } + + token, err := signTokenWithSubject(serviceAccountSub, time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + + tokenResponse := &TokenResponseBody{ + AccessToken: token, + ExpiresIn: 60, + TokenType: "Bearer", + } + + payload, err := json.Marshal(tokenResponse) + if err != nil { + t.Fatalf("failed to create token payload: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(payload) + })) + t.Cleanup(authServer.Close) + + protectedResource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + context, _, err := jwt.NewParser().ParseUnverified(strings.Fields(r.Header.Get("Authorization"))[1], jwt.MapClaims{}) + if err != nil { + t.Fatalf("failed to validate token: %v", err) + } + + sub, err := context.Claims.GetSubject() + if err != nil { + t.Fatalf("failed to validate token sub: %v", err) + } + if sub != serviceAccountSub { + t.Fatalf("invalid token on protected resource: %v", err) + } + + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(protectedResource.Close) + + flow := &WorkloadIdentityFederationFlow{} + flowConfig := &WorkloadIdentityFederationFlowConfig{} + flowConfig.TokenUrl = authServer.URL + + flowConfig.ClientID = tt.clientID + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer os.Remove(file.Name()) + flowConfig.FederatedTokenFilePath = file.Name() + + subject := "wrong" + if tt.validAssertion { + subject = validSub + } + token, err := signTokenWithSubject(subject, time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + + if err := flow.Init(flowConfig); err != nil { + t.Errorf("KeyFlow.Init() error = %v", err) + } + if flow.config == nil { + t.Error("config is nil") + } + + client := http.Client{ + Transport: flow, + } + resp, err := client.Get(protectedResource.URL) + if (err != nil || resp.StatusCode != http.StatusOK) && !tt.wantErr { + t.Fatalf("failed request to protected resource: %v", err) + } + }) + } +} + +// func TestRequestToken(t *testing.T) { +// testCases := []struct { +// name string +// grant string +// assertion string +// mockResponse *http.Response +// mockError error +// expectedError error +// }{ +// { +// name: "Success", +// grant: "test_grant", +// assertion: "test_assertion", +// mockResponse: &http.Response{ +// StatusCode: 200, +// Body: io.NopCloser(strings.NewReader(`{"access_token": "test_token"}`)), +// }, +// mockError: nil, +// expectedError: nil, +// }, +// { +// name: "Error", +// grant: "test_grant", +// assertion: "test_assertion", +// mockResponse: nil, +// mockError: fmt.Errorf("request error"), +// expectedError: fmt.Errorf("request error"), +// }, +// } + +// for _, tt := range testCases { +// t.Run(tt.name, func(t *testing.T) { +// keyFlow := &KeyFlow{} +// privateKeyBytes, err := generatePrivateKey() +// if err != nil { +// t.Fatalf("Error generating private key: %s", err) +// } +// keyFlowConfig := &KeyFlowConfig{ +// AuthHTTPClient: &http.Client{ +// Transport: mockTransportFn{func(_ *http.Request) (*http.Response, error) { +// return tt.mockResponse, tt.mockError +// }}, +// }, +// ServiceAccountKey: fixtureServiceAccountKey(), +// PrivateKey: string(privateKeyBytes), +// HTTPTransport: http.DefaultTransport, +// } +// err = keyFlow.Init(keyFlowConfig) +// if err != nil { +// t.Fatalf("failed to initialize key flow: %v", err) +// } + +// res, err := keyFlow.requestToken(tt.grant, tt.assertion) +// defer func() { +// if res != nil { +// tempErr := res.Body.Close() +// if tempErr != nil { +// t.Errorf("closing request token response: %s", tempErr.Error()) +// } +// } +// }() +// if tt.expectedError != nil { +// if err == nil { +// t.Errorf("Expected error '%v' but no error was returned", tt.expectedError) +// } else if errors.Is(err, tt.expectedError) { +// t.Errorf("Error is not correct. Expected %v, got %v", tt.expectedError, err) +// } +// } else { +// if err != nil { +// t.Errorf("Expected no error but error was returned: %v", err) +// } +// if !cmp.Equal(tt.mockResponse, res, cmp.AllowUnexported(strings.Reader{})) { +// t.Errorf("The returned result is wrong. Expected %v, got %v", tt.mockResponse, res) +// } +// } +// }) +// } +// } + +// func TestKeyFlow_Do(t *testing.T) { +// t.Parallel() + +// tests := []struct { +// name string +// handlerFn func(tb testing.TB) http.HandlerFunc +// want int +// wantErr bool +// }{ +// { +// name: "success", +// handlerFn: func(tb testing.TB) http.HandlerFunc { +// tb.Helper() + +// return func(w http.ResponseWriter, r *http.Request) { +// if r.Header.Get("Authorization") != "Bearer "+testBearerToken { +// tb.Errorf("expected Authorization header to be 'Bearer %s', but got %s", testBearerToken, r.Header.Get("Authorization")) +// } + +// w.Header().Set("Content-Type", "application/json") +// w.WriteHeader(http.StatusOK) +// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) +// } +// }, +// want: http.StatusOK, +// wantErr: false, +// }, +// { +// name: "success with code 500", +// handlerFn: func(_ testing.TB) http.HandlerFunc { +// return func(w http.ResponseWriter, _ *http.Request) { +// w.Header().Set("Content-Type", "text/html") +// w.WriteHeader(http.StatusInternalServerError) +// _, _ = fmt.Fprintln(w, `Internal Server Error`) +// } +// }, +// want: http.StatusInternalServerError, +// wantErr: false, +// }, +// { +// name: "success with custom transport", +// handlerFn: func(tb testing.TB) http.HandlerFunc { +// tb.Helper() + +// return func(w http.ResponseWriter, r *http.Request) { +// if r.Header.Get("User-Agent") != "custom_transport" { +// tb.Errorf("expected User-Agent header to be 'custom_transport', but got %s", r.Header.Get("User-Agent")) +// } + +// w.Header().Set("Content-Type", "application/json") +// w.WriteHeader(http.StatusOK) +// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) +// } +// }, +// want: http.StatusOK, +// wantErr: false, +// }, +// { +// name: "fail with custom proxy", +// handlerFn: func(testing.TB) http.HandlerFunc { +// return func(w http.ResponseWriter, _ *http.Request) { +// w.Header().Set("Content-Type", "application/json") +// w.WriteHeader(http.StatusOK) +// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) +// } +// }, +// want: 0, +// wantErr: true, +// }, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// ctx := context.Background() +// ctx, cancel := context.WithCancel(ctx) +// t.Cleanup(cancel) // This cancels the refresher goroutine + +// privateKeyBytes, err := generatePrivateKey() +// if err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// keyFlow := &KeyFlow{} +// keyFlowConfig := &KeyFlowConfig{ +// ServiceAccountKey: fixtureServiceAccountKey(), +// PrivateKey: string(privateKeyBytes), +// BackgroundTokenRefreshContext: ctx, +// HTTPTransport: func() http.RoundTripper { +// switch tt.name { +// case "success with custom transport": +// return mockTransportFn{ +// fn: func(req *http.Request) (*http.Response, error) { +// req.Header.Set("User-Agent", "custom_transport") +// return http.DefaultTransport.RoundTrip(req) +// }, +// } +// case "fail with custom proxy": +// return &http.Transport{ +// Proxy: func(_ *http.Request) (*url.URL, error) { +// return nil, fmt.Errorf("proxy error") +// }, +// } +// default: +// return http.DefaultTransport +// } +// }(), +// AuthHTTPClient: &http.Client{ +// Transport: mockTransportFn{ +// fn: func(_ *http.Request) (*http.Response, error) { +// res := httptest.NewRecorder() +// res.WriteHeader(http.StatusOK) +// res.Header().Set("Content-Type", "application/json") + +// token := &TokenResponseBody{ +// AccessToken: testBearerToken, +// ExpiresIn: 2147483647, +// TokenType: "Bearer", +// } + +// if err := json.NewEncoder(res.Body).Encode(token); err != nil { +// t.Logf("no error is expected, but got %v", err) +// } + +// return res.Result(), nil +// }, +// }, +// }, +// } +// err = keyFlow.Init(keyFlowConfig) +// if err != nil { +// t.Fatalf("failed to initialize key flow: %v", err) +// } + +// go continuousRefreshToken(keyFlow) + +// tokenCtx, tokenCancel := context.WithTimeout(context.Background(), 1*time.Second) + +// token: +// for { +// select { +// case <-tokenCtx.Done(): +// t.Error(tokenCtx.Err()) +// case <-time.After(50 * time.Millisecond): +// keyFlow.tokenMutex.RLock() +// if keyFlow.token != nil { +// keyFlow.tokenMutex.RUnlock() +// tokenCancel() +// break token +// } + +// keyFlow.tokenMutex.RUnlock() +// } +// } + +// server := httptest.NewServer(tt.handlerFn(t)) +// t.Cleanup(server.Close) + +// u, err := url.Parse(server.URL) +// if err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) +// if err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// httpClient := &http.Client{ +// Transport: keyFlow, +// } + +// res, err := httpClient.Do(req) + +// if tt.wantErr { +// if err == nil { +// t.Errorf("error is expected, but got %v", err) +// } +// } else { +// if err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// if res.StatusCode != tt.want { +// t.Errorf("expected status code %d, but got %d", tt.want, res.StatusCode) +// } + +// // Defer discard and close the body +// t.Cleanup(func() { +// if _, err := io.Copy(io.Discard, res.Body); err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// if err := res.Body.Close(); err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } +// }) +// } +// }) +// } +// } + +// type mockTransportFn struct { +// fn func(req *http.Request) (*http.Response, error) +// } + +// func (m mockTransportFn) RoundTrip(req *http.Request) (*http.Response, error) { +// return m.fn(req) +// } diff --git a/core/config/config.go b/core/config/config.go index 93002c02a..ae2d8c498 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -75,26 +75,29 @@ type Middleware func(http.RoundTripper) http.RoundTripper // Configuration stores the configuration of the API client type Configuration struct { - Host string `json:"host,omitempty"` - Scheme string `json:"scheme,omitempty"` - DefaultHeader map[string]string `json:"defaultHeader,omitempty"` - UserAgent string `json:"userAgent,omitempty"` - Debug bool `json:"debug,omitempty"` - NoAuth bool `json:"noAuth,omitempty"` - ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` // Deprecated: ServiceAccountEmail is not required and will be removed after 12th June 2025. - Token string `json:"token,omitempty"` - ServiceAccountKey string `json:"serviceAccountKey,omitempty"` - PrivateKey string `json:"privateKey,omitempty"` - ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` - PrivateKeyPath string `json:"privateKeyPath,omitempty"` - CredentialsFilePath string `json:"credentialsFilePath,omitempty"` - TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` - Region string `json:"region,omitempty"` - CustomAuth http.RoundTripper - Servers ServerConfigurations - OperationServers map[string]ServerConfigurations - HTTPClient *http.Client - Middleware []Middleware + Host string `json:"host,omitempty"` + Scheme string `json:"scheme,omitempty"` + DefaultHeader map[string]string `json:"defaultHeader,omitempty"` + UserAgent string `json:"userAgent,omitempty"` + Debug bool `json:"debug,omitempty"` + NoAuth bool `json:"noAuth,omitempty"` + WorkloadIdentityFederation bool `json:"workloadIdentityFederation,omitempty"` + WorkloadIdentityFederationTokenExpiration string `json:"workloadIdentityFederationTokenExpiration,omitempty"` + WorkloadIdentityFederationFederatedTokenPath string `json:"workloadIdentityFederationFederatedTokenPath,omitempty"` + ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` + Token string `json:"token,omitempty"` + ServiceAccountKey string `json:"serviceAccountKey,omitempty"` + PrivateKey string `json:"privateKey,omitempty"` + ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` + PrivateKeyPath string `json:"privateKeyPath,omitempty"` + CredentialsFilePath string `json:"credentialsFilePath,omitempty"` + TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` + Region string `json:"region,omitempty"` + CustomAuth http.RoundTripper + Servers ServerConfigurations + OperationServers map[string]ServerConfigurations + HTTPClient *http.Client + Middleware []Middleware // If != nil, a goroutine will be launched that will refresh the service account's access token when it's close to being expired. // The goroutine is killed whenever this context is canceled. @@ -176,8 +179,6 @@ func WithTokenEndpoint(url string) ConfigurationOption { } // WithServiceAccountEmail returns a ConfigurationOption that sets the service account email -// -// Deprecated: WithServiceAccountEmail is not required and will be removed after 12th June 2025. func WithServiceAccountEmail(serviceAccountEmail string) ConfigurationOption { return func(config *Configuration) error { config.ServiceAccountEmail = serviceAccountEmail @@ -237,6 +238,30 @@ func WithToken(token string) ConfigurationOption { } } +// WithWorkloadIdentityFederationAuth returns a ConfigurationOption that sets workload identity flow to be used for authentication in API calls +func WithWorkloadIdentityFederationAuth() ConfigurationOption { + return func(config *Configuration) error { + config.WorkloadIdentityFederation = true + return nil + } +} + +// WithWorkloadIdentityFederation returns a ConfigurationOption that sets workload identity flow to be used for authentication in API calls +func WithWorkloadIdentityFederationTokenPath(path string) ConfigurationOption { + return func(config *Configuration) error { + config.WorkloadIdentityFederationFederatedTokenPath = path + return nil + } +} + +// WithWorkloadIdentityFederationTokenExpiration returns a ConfigurationOption that sets the token expiration for workload identity federation flow +func WithWorkloadIdentityFederationTokenExpiration(expiration string) ConfigurationOption { + return func(config *Configuration) error { + config.WorkloadIdentityFederationTokenExpiration = expiration + return nil + } +} + // Deprecated: retry options were removed to reduce complexity of the client. If this functionality is needed, you can provide your own custom HTTP client. This option has no effect, and will be removed in a later update func WithMaxRetries(_ int) ConfigurationOption { return func(_ *Configuration) error { diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index 839999938..b398b19a9 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -35,18 +35,27 @@ func main() { // Create a new API client, that will authenticate using the provided bearer token token := "TOKEN" - _, err = dns.NewAPIClient(config.WithToken(token)) + dnsClient, err := dns.NewAPIClient(config.WithToken(token)) if err != nil { fmt.Fprintf(os.Stderr, "[DNS API] Creating API client: %v\n", err) os.Exit(1) } + // Check that you can make an authenticated request + getZoneResp, err := dnsClient.ListZones(context.Background(), projectId).Execute() + + if err != nil { + fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) + } else { + fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) + } + // Create a new API client, that will authenticate using the key flow // If you created a service account key and provided your own RSA key pair, // you need to add the path to a PEM encoded file including the private key // using config.WithPrivateKeyPath("path/to/private_key.pem") saKeyPath := "/path/to/service_account_key.json" - dnsClient, err := dns.NewAPIClient( + dnsClient, err = dns.NewAPIClient( config.WithServiceAccountKeyPath(saKeyPath), ) if err != nil { @@ -55,7 +64,51 @@ func main() { } // Check that you can make an authenticated request - getZoneResp, err := dnsClient.ListZones(context.Background(), projectId).Execute() + getZoneResp, err = dnsClient.ListZones(context.Background(), projectId).Execute() + + if err != nil { + fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) + } else { + fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) + } + + // Create a new API client, that will authenticate using the wif flow + // You need to create a service account key and configure the federate identity provider, + // then you can init the SDK using default env var + os.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "my-sa@sa-stackit.cloud") + os.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "/path/to/your/federated/token") // Default "/var/run/secrets/stackit.cloud/serviceaccount/token" + os.Setenv("STACKIT_IDP_ENDPOINT", "custom token endpoint") // Default "https://accounts.stackit.cloud/oauth/v2/token" + dnsClient, err = dns.NewAPIClient() + if err != nil { + fmt.Fprintf(os.Stderr, "[DNS API] Creating API client: %v\n", err) + os.Exit(1) + } + + // Check that you can make an authenticated request + getZoneResp, err = dnsClient.ListZones(context.Background(), projectId).Execute() + + if err != nil { + fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) + } else { + fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) + } + + // Create a new API client, that will authenticate using the wif flow + // You need to create a service account key and configure the federate identity provider, + // then you can init the SDK setting fields + dnsClient, err = dns.NewAPIClient( + config.WithWorkloadIdentityFederationAuth(), + config.WithTokenEndpoint("custom token endpoint"), + config.WithWorkloadIdentityFederationTokenPath("/path/to/your/federated/token"), + config.WithServiceAccountEmail("my-sa@sa-stackit.cloud"), + ) + if err != nil { + fmt.Fprintf(os.Stderr, "[DNS API] Creating API client: %v\n", err) + os.Exit(1) + } + + // Check that you can make an authenticated request + getZoneResp, err = dnsClient.ListZones(context.Background(), projectId).Execute() if err != nil { fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) From 711502ec9f89af3b6ab712346d68b896111c61c6 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Wed, 10 Dec 2025 12:16:08 +0100 Subject: [PATCH 06/18] update changelog Signed-off-by: Jorge Turrado --- core/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index 47d06e806..aaaa0636c 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -1,5 +1,6 @@ ## v0.21.0 - **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` +- **Feature:** Support Workload Identity Federation flow ## v0.20.1 - **Improvement:** Improve error message when passing a PEM encoded file to as service account key From 7dd73c1607bf45709f82071842742f0dad19b3ce Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Mon, 15 Dec 2025 18:38:12 +0100 Subject: [PATCH 07/18] apply feedback Signed-off-by: Jorge Turrado --- core/auth/auth.go | 20 ++++------- core/auth/auth_test.go | 17 ++++++++-- core/clients/key_flow.go | 40 ++++++++++++++++++++++ core/clients/key_flow_test.go | 63 +++++++++++++++++++++++++++++++++++ 4 files changed, 124 insertions(+), 16 deletions(-) diff --git a/core/auth/auth.go b/core/auth/auth.go index 88f002fe7..e3b10bc46 100644 --- a/core/auth/auth.go +++ b/core/auth/auth.go @@ -45,18 +45,18 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) { if cfg.CustomAuth != nil { return cfg.CustomAuth, nil - } else if useWorkloadIdentityFederation(cfg) { - wifRoundTripper, err := WorkloadIdentityFederationAuth(cfg) - if err != nil { - return nil, fmt.Errorf("configuring no auth client: %w", err) - } - return wifRoundTripper, nil } else if cfg.NoAuth { noAuthRoundTripper, err := NoAuth(cfg) if err != nil { return nil, fmt.Errorf("configuring no auth client: %w", err) } return noAuthRoundTripper, nil + } else if cfg.WorkloadIdentityFederation { + wifRoundTripper, err := WorkloadIdentityFederationAuth(cfg) + if err != nil { + return nil, fmt.Errorf("configuring no auth client: %w", err) + } + return wifRoundTripper, nil } else if cfg.ServiceAccountKey != "" || cfg.ServiceAccountKeyPath != "" { keyRoundTripper, err := KeyAuth(cfg) if err != nil { @@ -394,11 +394,3 @@ func getServiceAccountKey(cfg *config.Configuration) error { func getPrivateKey(cfg *config.Configuration) error { return getKey(&cfg.PrivateKey, &cfg.PrivateKeyPath, "STACKIT_PRIVATE_KEY_PATH", "STACKIT_PRIVATE_KEY", privateKeyPathCredentialType, privateKeyCredentialType, cfg.CredentialsFilePath) } - -func useWorkloadIdentityFederation(cfg *config.Configuration) bool { - if cfg != nil && cfg.WorkloadIdentityFederation { - return true - } - val, exists := os.LookupEnv(clients.FederatedTokenFileEnv) - return exists && val != "" -} diff --git a/core/auth/auth_test.go b/core/auth/auth_test.go index 5e8af7203..b861bf581 100644 --- a/core/auth/auth_test.go +++ b/core/auth/auth_test.go @@ -179,23 +179,27 @@ func TestSetupAuth(t *testing.T) { setKeyPaths bool setCredentialsFilePathToken bool setCredentialsFilePathKey bool + isValid bool }{ { desc: "wif_config", config: nil, setWorkloadIdentity: true, + isValid: true, }, { desc: "token_config", config: nil, setToken: true, setCredentialsFilePathToken: false, + isValid: true, }, { desc: "key_config", config: nil, setKeys: true, setCredentialsFilePathToken: false, + isValid: true, }, { desc: "key_config_path", @@ -203,6 +207,7 @@ func TestSetupAuth(t *testing.T) { setKeys: false, setKeyPaths: true, setCredentialsFilePathToken: false, + isValid: true, }, { desc: "key_config_credentials_path", @@ -210,12 +215,14 @@ func TestSetupAuth(t *testing.T) { setKeys: false, setKeyPaths: false, setCredentialsFilePathKey: true, + isValid: true, }, { desc: "valid_path_to_file", config: nil, setToken: false, setCredentialsFilePathToken: true, + isValid: true, }, { desc: "custom_config_token", @@ -224,6 +231,7 @@ func TestSetupAuth(t *testing.T) { }, setToken: false, setCredentialsFilePathToken: false, + isValid: true, }, { desc: "custom_config_path", @@ -232,6 +240,7 @@ func TestSetupAuth(t *testing.T) { }, setToken: false, setCredentialsFilePathToken: false, + isValid: true, }, } { t.Run(test.desc, func(t *testing.T) { @@ -276,11 +285,15 @@ func TestSetupAuth(t *testing.T) { authRoundTripper, err := SetupAuth(test.config) - if err != nil { + if err != nil && test.isValid { t.Fatalf("Test returned error on valid test case: %v", err) } - if authRoundTripper == nil { + if err == nil && !test.isValid { + t.Fatalf("Test didn't return error on invalid test case") + } + + if authRoundTripper == nil && test.isValid { t.Fatalf("Roundtripper returned is nil for valid test case") } }) diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index 83c82e778..46b5d91a0 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -99,6 +99,46 @@ func (c *KeyFlow) GetServiceAccountEmail() string { return c.key.Credentials.Iss } +// GetToken returns the token field +// Deprecated: Use GetAccessToken instead +func (c *KeyFlow) GetToken() TokenResponseBody { + c.tokenMutex.RLock() + defer c.tokenMutex.RUnlock() + + if c.token == nil { + return TokenResponseBody{} + } + // Returned struct is passed by value (because it's a struct) + // So no deepy copy needed + return *c.token +} + +// SetToken can be used to set an access and refresh token manually in the client. +// The other fields in the token field are determined by inspecting the token or setting default values. +// Deprecated +func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { + // We can safely use ParseUnverified because we are not authenticating the user, + // We are parsing the token just to get the expiration time claim + parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) + if err != nil { + return fmt.Errorf("parse access token to read expiration time: %w", err) + } + exp, err := parsedAccessToken.Claims.GetExpirationTime() + if err != nil { + return fmt.Errorf("get expiration time from access token: %w", err) + } + + c.tokenMutex.Lock() + c.token = &TokenResponseBody{ + AccessToken: accessToken, + ExpiresIn: int(exp.Time.Unix()), + Scope: "", + TokenType: "Bearer", + } + c.tokenMutex.Unlock() + return nil +} + func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { // No concurrency at this point, so no mutex check needed c.token = &TokenResponseBody{} diff --git a/core/clients/key_flow_test.go b/core/clients/key_flow_test.go index 8b8877673..7c094331e 100644 --- a/core/clients/key_flow_test.go +++ b/core/clients/key_flow_test.go @@ -17,10 +17,15 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v5" "github.com/google/go-cmp/cmp" "github.com/google/uuid" ) +var ( + testSigningKey = []byte(`Test`) +) + const testBearerToken = "eyJhbGciOiJub25lIn0.eyJleHAiOjIxNDc0ODM2NDd9." //nolint:gosec // linter false positive func fixtureServiceAccountKey(mods ...func(*ServiceAccountKeyResponse)) *ServiceAccountKeyResponse { @@ -125,6 +130,64 @@ func TestKeyFlowInit(t *testing.T) { } } +func TestSetToken(t *testing.T) { + tests := []struct { + name string + tokenInvalid bool + refreshToken string + wantErr bool + }{ + { + name: "ok", + tokenInvalid: false, + refreshToken: "refresh_token", + wantErr: false, + }, + { + name: "invalid_token", + tokenInvalid: true, + refreshToken: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var accessToken string + var err error + + timestamp := time.Now().Add(24 * time.Hour) + if tt.tokenInvalid { + accessToken = "foo" + } else { + accessTokenJWT := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(timestamp)}) + accessToken, err = accessTokenJWT.SignedString(testSigningKey) + if err != nil { + t.Fatalf("get test access token as string: %s", err) + } + } + + keyFlow := &KeyFlow{} + err = keyFlow.SetToken(accessToken, tt.refreshToken) + + if (err != nil) != tt.wantErr { + t.Errorf("KeyFlow.SetToken() error = %v, wantErr %v", err, tt.wantErr) + } + if err == nil { + expectedKeyFlowToken := &TokenResponseBody{ + AccessToken: accessToken, + ExpiresIn: int(timestamp.Unix()), + Scope: "", + TokenType: "Bearer", + } + if !cmp.Equal(expectedKeyFlowToken, keyFlow.token) { + t.Errorf("The returned result is wrong. Expected %+v, got %+v", expectedKeyFlowToken, keyFlow.token) + } + } + }) + } +} + func TestTokenExpired(t *testing.T) { tokenExpirationLeeway := 5 * time.Second tests := []struct { From ed9d1b43e2325600da5f0d651c9fca7595dbc992 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Mon, 15 Dec 2025 18:40:51 +0100 Subject: [PATCH 08/18] apply feedback Signed-off-by: Jorge Turrado --- core/clients/key_flow.go | 52 ++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index 46b5d91a0..d18d4f0bf 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -113,32 +113,6 @@ func (c *KeyFlow) GetToken() TokenResponseBody { return *c.token } -// SetToken can be used to set an access and refresh token manually in the client. -// The other fields in the token field are determined by inspecting the token or setting default values. -// Deprecated -func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { - // We can safely use ParseUnverified because we are not authenticating the user, - // We are parsing the token just to get the expiration time claim - parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) - if err != nil { - return fmt.Errorf("parse access token to read expiration time: %w", err) - } - exp, err := parsedAccessToken.Claims.GetExpirationTime() - if err != nil { - return fmt.Errorf("get expiration time from access token: %w", err) - } - - c.tokenMutex.Lock() - c.token = &TokenResponseBody{ - AccessToken: accessToken, - ExpiresIn: int(exp.Time.Unix()), - Scope: "", - TokenType: "Bearer", - } - c.tokenMutex.Unlock() - return nil -} - func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { // No concurrency at this point, so no mutex check needed c.token = &TokenResponseBody{} @@ -171,6 +145,32 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { return nil } +// SetToken can be used to set an access and refresh token manually in the client. +// The other fields in the token field are determined by inspecting the token or setting default values. +// Deprecated +func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { + // We can safely use ParseUnverified because we are not authenticating the user, + // We are parsing the token just to get the expiration time claim + parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) + if err != nil { + return fmt.Errorf("parse access token to read expiration time: %w", err) + } + exp, err := parsedAccessToken.Claims.GetExpirationTime() + if err != nil { + return fmt.Errorf("get expiration time from access token: %w", err) + } + + c.tokenMutex.Lock() + c.token = &TokenResponseBody{ + AccessToken: accessToken, + ExpiresIn: int(exp.Time.Unix()), + Scope: "", + TokenType: "Bearer", + } + c.tokenMutex.Unlock() + return nil +} + // Roundtrip performs the request func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) { if c.rt == nil { From 796fb6df46e7e17a1b9bae5a53657d9e17f05889 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Mon, 15 Dec 2025 19:18:07 +0100 Subject: [PATCH 09/18] apply feedback Signed-off-by: Jorge Turrado --- README.md | 47 +++++++++++++++++++++-- examples/authentication/authentication.go | 21 ---------- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 69d23ae86..b2331eb88 100644 --- a/README.md +++ b/README.md @@ -105,13 +105,20 @@ To authenticate with the SDK, you need a [service account](https://docs.stackit. The SDK supports two authentication methods: -1. **Key Flow** (Recommended) +1. **Workload Identity Federation Flow** (Recommended) + + - Uses OIDC trusted tokens + - Provides best security through short-lived tokens without secrets + +> NOTE: This flow isn't publicly available yet. It'll be public during Q1 2026 + +2. **Key Flow** (Recommended) - Uses RSA key-pair based authentication - Provides better security through short-lived tokens - Supports both STACKIT-generated and custom key pairs -2. **Token Flow** +3. **Token Flow** - Uses long-lived service account tokens - Simpler but less secure @@ -120,10 +127,42 @@ The SDK supports two authentication methods: The SDK searches for credentials in the following order: 1. Explicit configuration in code -2. Environment variables (KEY_PATH for KEY) +2. Environment variables 3. Credentials file (`$HOME/.stackit/credentials.json`) -For each authentication method, the key flow is attempted first, followed by the token flow. +For each authentication method, the try order is: +1. Workload Identity Federation Flow +2. Key Flow +3. Token Flow + +### Using the Workload Identity Fedearion Flow + +1. Create a service account trusted relation in the STACKIT Portal: + + - Navigate to `Service Accounts` → Select account → `Federated Identity Providers` → Add a Federated Identity Provider + - Configure the trusted issuer and the required assertions to trust in. (Link to official docs here after GA) + +2. Configure authentication using any of these methods: + + **A. Code Configuration** + + ```go + // Using wokload identity federation flow + config.WithWorkloadIdentityFederationAuth() + // With the custom path for the external OIDC token + config.WithWorkloadIdentityFederationTokenPath("/path/to/your/federated/token") + // For the service account + config.WithServiceAccountEmail("my-sa@sa-stackit.cloud") + ``` + + **B. Environment Variables** + + ```bash + # With the custom path for the external OIDC token + STACKIT_FEDERATED_TOKEN_FILE=/path/to/your/federated/token + # For the service account + STACKIT_SERVICE_ACCOUNT_EMAIL=my-sa@sa-stackit.cloud + ``` ### Using the Key Flow diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index b398b19a9..8ec2a84db 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -72,27 +72,6 @@ func main() { fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) } - // Create a new API client, that will authenticate using the wif flow - // You need to create a service account key and configure the federate identity provider, - // then you can init the SDK using default env var - os.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "my-sa@sa-stackit.cloud") - os.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "/path/to/your/federated/token") // Default "/var/run/secrets/stackit.cloud/serviceaccount/token" - os.Setenv("STACKIT_IDP_ENDPOINT", "custom token endpoint") // Default "https://accounts.stackit.cloud/oauth/v2/token" - dnsClient, err = dns.NewAPIClient() - if err != nil { - fmt.Fprintf(os.Stderr, "[DNS API] Creating API client: %v\n", err) - os.Exit(1) - } - - // Check that you can make an authenticated request - getZoneResp, err = dnsClient.ListZones(context.Background(), projectId).Execute() - - if err != nil { - fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) - } else { - fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) - } - // Create a new API client, that will authenticate using the wif flow // You need to create a service account key and configure the federate identity provider, // then you can init the SDK setting fields From dc427ec9faa37137fbacc5626e34895261521d78 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Mon, 15 Dec 2025 19:40:02 +0100 Subject: [PATCH 10/18] apply feedback Signed-off-by: Jorge Turrado --- examples/authentication/authentication.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index 8ec2a84db..64758bd87 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -14,7 +14,8 @@ func main() { // When creating a new API client without providing any configuration, it will setup default authentication. // The SDK will search for a valid service account key or token in several locations. - // It will first try to use the key flow, by looking into the variables STACKIT_SERVICE_ACCOUNT_KEY, STACKIT_SERVICE_ACCOUNT_KEY_PATH, + // It will first try to use the workload identity federation flow by looking into the variables STACKIT_FEDERATED_TOKEN_FILE, STACKIT_SERVICE_ACCOUNT_EMAIL and their default values, + // Then, it will try key flow, by looking into the variables STACKIT_SERVICE_ACCOUNT_KEY, STACKIT_SERVICE_ACCOUNT_KEY_PATH, // STACKIT_PRIVATE_KEY and STACKIT_PRIVATE_KEY_PATH. If the keys cannot be retrieved, it will check the credentials file located in STACKIT_CREDENTIALS_PATH, if specified, or in // $HOME/.stackit/credentials.json as a fallback. If the key are found and are valid, the KeyAuth flow is used. // If the key flow cannot be used, it will try to find a token in the STACKIT_SERVICE_ACCOUNT_TOKEN. If not present, it will From 6e3169f613d226c95cc780b26780b7da50837f14 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Sun, 21 Dec 2025 22:41:42 +0100 Subject: [PATCH 11/18] remove docs from PR Signed-off-by: Jorge Turrado --- CHANGELOG.md | 1 + README.md | 49 +++-------------------- core/CHANGELOG.md | 1 - examples/authentication/authentication.go | 43 +++----------------- 4 files changed, 11 insertions(+), 83 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e0ddcc5e7..cda81ed70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - `core`: - [v0.21.0](core/CHANGELOG.md#v0210) - **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` + - **Feature:** Support Workload Identity Federation flow - `sfs`: - [v0.2.0](services/sfs/CHANGELOG.md) - **Breaking change:** Remove region configuration in `APIClient` diff --git a/README.md b/README.md index b2331eb88..9ca8dcace 100644 --- a/README.md +++ b/README.md @@ -105,20 +105,13 @@ To authenticate with the SDK, you need a [service account](https://docs.stackit. The SDK supports two authentication methods: -1. **Workload Identity Federation Flow** (Recommended) - - - Uses OIDC trusted tokens - - Provides best security through short-lived tokens without secrets - -> NOTE: This flow isn't publicly available yet. It'll be public during Q1 2026 - -2. **Key Flow** (Recommended) +1. **Key Flow** (Recommended) - Uses RSA key-pair based authentication - Provides better security through short-lived tokens - Supports both STACKIT-generated and custom key pairs -3. **Token Flow** +2. **Token Flow** - Uses long-lived service account tokens - Simpler but less secure @@ -127,42 +120,10 @@ The SDK supports two authentication methods: The SDK searches for credentials in the following order: 1. Explicit configuration in code -2. Environment variables +2. Environment variables (KEY_PATH for KEY) 3. Credentials file (`$HOME/.stackit/credentials.json`) -For each authentication method, the try order is: -1. Workload Identity Federation Flow -2. Key Flow -3. Token Flow - -### Using the Workload Identity Fedearion Flow - -1. Create a service account trusted relation in the STACKIT Portal: - - - Navigate to `Service Accounts` → Select account → `Federated Identity Providers` → Add a Federated Identity Provider - - Configure the trusted issuer and the required assertions to trust in. (Link to official docs here after GA) - -2. Configure authentication using any of these methods: - - **A. Code Configuration** - - ```go - // Using wokload identity federation flow - config.WithWorkloadIdentityFederationAuth() - // With the custom path for the external OIDC token - config.WithWorkloadIdentityFederationTokenPath("/path/to/your/federated/token") - // For the service account - config.WithServiceAccountEmail("my-sa@sa-stackit.cloud") - ``` - - **B. Environment Variables** - - ```bash - # With the custom path for the external OIDC token - STACKIT_FEDERATED_TOKEN_FILE=/path/to/your/federated/token - # For the service account - STACKIT_SERVICE_ACCOUNT_EMAIL=my-sa@sa-stackit.cloud - ``` +For each authentication method, the key flow is attempted first, followed by the token flow. ### Using the Key Flow @@ -273,4 +234,4 @@ See the [release documentation](./RELEASE.md) for further information. ## License -Apache 2.0 +Apache 2.0 \ No newline at end of file diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index aaaa0636c..1e8466cac 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -13,7 +13,6 @@ ## v0.18.0 - **New:** Added duration utils -- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` ## v0.17.3 - **Dependencies:** Bump `github.com/golang-jwt/jwt/v5` from `v5.2.2` to `v5.2.3` diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index 64758bd87..cb0357b19 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -14,8 +14,7 @@ func main() { // When creating a new API client without providing any configuration, it will setup default authentication. // The SDK will search for a valid service account key or token in several locations. - // It will first try to use the workload identity federation flow by looking into the variables STACKIT_FEDERATED_TOKEN_FILE, STACKIT_SERVICE_ACCOUNT_EMAIL and their default values, - // Then, it will try key flow, by looking into the variables STACKIT_SERVICE_ACCOUNT_KEY, STACKIT_SERVICE_ACCOUNT_KEY_PATH, + // It will first try to use the key flow, by looking into the variables STACKIT_SERVICE_ACCOUNT_KEY, STACKIT_SERVICE_ACCOUNT_KEY_PATH, // STACKIT_PRIVATE_KEY and STACKIT_PRIVATE_KEY_PATH. If the keys cannot be retrieved, it will check the credentials file located in STACKIT_CREDENTIALS_PATH, if specified, or in // $HOME/.stackit/credentials.json as a fallback. If the key are found and are valid, the KeyAuth flow is used. // If the key flow cannot be used, it will try to find a token in the STACKIT_SERVICE_ACCOUNT_TOKEN. If not present, it will @@ -36,27 +35,18 @@ func main() { // Create a new API client, that will authenticate using the provided bearer token token := "TOKEN" - dnsClient, err := dns.NewAPIClient(config.WithToken(token)) + _, err = dns.NewAPIClient(config.WithToken(token)) if err != nil { fmt.Fprintf(os.Stderr, "[DNS API] Creating API client: %v\n", err) os.Exit(1) } - // Check that you can make an authenticated request - getZoneResp, err := dnsClient.ListZones(context.Background(), projectId).Execute() - - if err != nil { - fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) - } else { - fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) - } - // Create a new API client, that will authenticate using the key flow // If you created a service account key and provided your own RSA key pair, // you need to add the path to a PEM encoded file including the private key // using config.WithPrivateKeyPath("path/to/private_key.pem") saKeyPath := "/path/to/service_account_key.json" - dnsClient, err = dns.NewAPIClient( + dnsClient, err := dns.NewAPIClient( config.WithServiceAccountKeyPath(saKeyPath), ) if err != nil { @@ -65,34 +55,11 @@ func main() { } // Check that you can make an authenticated request - getZoneResp, err = dnsClient.ListZones(context.Background(), projectId).Execute() - - if err != nil { - fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) - } else { - fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) - } - - // Create a new API client, that will authenticate using the wif flow - // You need to create a service account key and configure the federate identity provider, - // then you can init the SDK setting fields - dnsClient, err = dns.NewAPIClient( - config.WithWorkloadIdentityFederationAuth(), - config.WithTokenEndpoint("custom token endpoint"), - config.WithWorkloadIdentityFederationTokenPath("/path/to/your/federated/token"), - config.WithServiceAccountEmail("my-sa@sa-stackit.cloud"), - ) - if err != nil { - fmt.Fprintf(os.Stderr, "[DNS API] Creating API client: %v\n", err) - os.Exit(1) - } - - // Check that you can make an authenticated request - getZoneResp, err = dnsClient.ListZones(context.Background(), projectId).Execute() + getZoneResp, err := dnsClient.ListZones(context.Background(), projectId).Execute() if err != nil { fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) } else { fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) } -} +} \ No newline at end of file From 28d4a02cefa715650eb5d2a57e3583c6050b30ca Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Sun, 21 Dec 2025 23:40:49 +0100 Subject: [PATCH 12/18] remove docs from PR Signed-off-by: Jorge Turrado --- examples/authentication/authentication.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index cb0357b19..839999938 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -62,4 +62,4 @@ func main() { } else { fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) } -} \ No newline at end of file +} From 61b2cd4b153cfb4776f852be7d95bfec21ba0a82 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Sun, 21 Dec 2025 23:58:35 +0100 Subject: [PATCH 13/18] remove docs from PR Signed-off-by: Jorge Turrado --- examples/authentication/authentication.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index 839999938..cb0357b19 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -62,4 +62,4 @@ func main() { } else { fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) } -} +} \ No newline at end of file From 76a16106f03ab74a9fb724594672dd1539af8614 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Mon, 22 Dec 2025 00:03:43 +0100 Subject: [PATCH 14/18] remove docs from PR Signed-off-by: Jorge Turrado --- examples/authentication/authentication.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index cb0357b19..839999938 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -62,4 +62,4 @@ func main() { } else { fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) } -} \ No newline at end of file +} From 845a48f0e406108ed69a2ede2dbb548c846ead85 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Tue, 23 Dec 2025 15:28:03 +0100 Subject: [PATCH 15/18] add static token Signed-off-by: Jorge Turrado --- core/auth/auth.go | 5 +- core/clients/workload_identity_flow.go | 27 +- core/clients/workload_identity_flow_test.go | 314 ++------------------ core/config/config.go | 51 ++-- 4 files changed, 63 insertions(+), 334 deletions(-) diff --git a/core/auth/auth.go b/core/auth/auth.go index e3b10bc46..450361c60 100644 --- a/core/auth/auth.go +++ b/core/auth/auth.go @@ -238,8 +238,9 @@ func WorkloadIdentityFederationAuth(cfg *config.Configuration) (http.RoundTrippe TokenUrl: cfg.TokenCustomUrl, BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext, ClientID: cfg.ServiceAccountEmail, - FederatedTokenFilePath: cfg.WorkloadIdentityFederationFederatedTokenPath, - TokenExpiration: cfg.WorkloadIdentityFederationTokenExpiration, + FederatedTokenFilePath: cfg.ServiceAccountFederatedTokenPath, + TokenExpiration: cfg.ServiceAccountFederatedTokenExpiration, + FederatedToken: cfg.ServiceAccountFederatedToken, } if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil { diff --git a/core/clients/workload_identity_flow.go b/core/clients/workload_identity_flow.go index 65b6fc461..0046ec864 100644 --- a/core/clients/workload_identity_flow.go +++ b/core/clients/workload_identity_flow.go @@ -59,6 +59,7 @@ type WorkloadIdentityFederationFlow struct { type WorkloadIdentityFederationFlowConfig struct { TokenUrl string ClientID string + FederatedToken string // Static token string. This is optional, if not set the token will be read from file. FederatedTokenFilePath string TokenExpiration string // Not supported yet BackgroundTokenRefreshContext context.Context // Functionality is enabled if this isn't nil @@ -139,7 +140,7 @@ func (c *WorkloadIdentityFederationFlow) Init(cfg *WorkloadIdentityFederationFlo c.config.ClientID = getEnvOrDefault(clientIDEnv, "") } - if c.config.FederatedTokenFilePath == "" { + if c.config.FederatedToken == "" && c.config.FederatedTokenFilePath == "" { c.config.FederatedTokenFilePath = getEnvOrDefault(FederatedTokenFileEnv, defaultFederatedTokenPath) } @@ -161,12 +162,6 @@ func (c *WorkloadIdentityFederationFlow) Init(cfg *WorkloadIdentityFederationFlo return err } - // // Init the token - // _, err = c.GetAccessToken() - // if err != nil { - // return err - // } - if c.config.BackgroundTokenRefreshContext != nil { go continuousRefreshToken(c) } @@ -181,8 +176,10 @@ func (c *WorkloadIdentityFederationFlow) validate() error { if c.config.TokenUrl == "" { return fmt.Errorf("token URL cannot be empty") } - if _, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath); err != nil { - return fmt.Errorf("error reading federated token file - %w", err) + if c.config.FederatedToken == "" { + if _, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath); err != nil { + return fmt.Errorf("error reading federated token file - %w", err) + } } if c.tokenExpirationLeeway < 0 { return fmt.Errorf("token expiration leeway cannot be negative") @@ -192,10 +189,14 @@ func (c *WorkloadIdentityFederationFlow) validate() error { } // createAccessToken creates an access token using self signed JWT -func (c *WorkloadIdentityFederationFlow) createAccessToken() (err error) { - clientAssertion, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath) - if err != nil { - return fmt.Errorf("error reading service account assertion - %w", err) +func (c *WorkloadIdentityFederationFlow) createAccessToken() error { + clientAssertion := c.config.FederatedToken + if clientAssertion == "" { + var err error + clientAssertion, err = c.readJWTFromFileSystem(c.config.FederatedTokenFilePath) + if err != nil { + return fmt.Errorf("error reading service account assertion - %w", err) + } } res, err := c.requestToken(c.config.ClientID, clientAssertion) diff --git a/core/clients/workload_identity_flow_test.go b/core/clients/workload_identity_flow_test.go index ef8f7a15f..4a9e07161 100644 --- a/core/clients/workload_identity_flow_test.go +++ b/core/clients/workload_identity_flow_test.go @@ -158,6 +158,7 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { name string clientID string validAssertion bool + injectToken bool wantErr bool }{ { @@ -166,6 +167,13 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { validAssertion: true, wantErr: false, }, + { + name: "injected token ok", + clientID: "test@stackit.cloud", + validAssertion: true, + injectToken: true, + wantErr: false, + }, { name: "invalid assertion", clientID: "test@stackit.cloud", @@ -243,12 +251,6 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { flowConfig.TokenUrl = authServer.URL flowConfig.ClientID = tt.clientID - file, err := os.CreateTemp("", "*.token") - if err != nil { - log.Fatal(err) - } - defer os.Remove(file.Name()) - flowConfig.FederatedTokenFilePath = file.Name() subject := "wrong" if tt.validAssertion { @@ -258,7 +260,18 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { if err != nil { t.Fatalf("failed to create token: %v", err) } - os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + + if tt.injectToken { + flowConfig.FederatedToken = token + } else { + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer os.Remove(file.Name()) + flowConfig.FederatedTokenFilePath = file.Name() + os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + } if err := flow.Init(flowConfig); err != nil { t.Errorf("KeyFlow.Init() error = %v", err) @@ -277,290 +290,3 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { }) } } - -// func TestRequestToken(t *testing.T) { -// testCases := []struct { -// name string -// grant string -// assertion string -// mockResponse *http.Response -// mockError error -// expectedError error -// }{ -// { -// name: "Success", -// grant: "test_grant", -// assertion: "test_assertion", -// mockResponse: &http.Response{ -// StatusCode: 200, -// Body: io.NopCloser(strings.NewReader(`{"access_token": "test_token"}`)), -// }, -// mockError: nil, -// expectedError: nil, -// }, -// { -// name: "Error", -// grant: "test_grant", -// assertion: "test_assertion", -// mockResponse: nil, -// mockError: fmt.Errorf("request error"), -// expectedError: fmt.Errorf("request error"), -// }, -// } - -// for _, tt := range testCases { -// t.Run(tt.name, func(t *testing.T) { -// keyFlow := &KeyFlow{} -// privateKeyBytes, err := generatePrivateKey() -// if err != nil { -// t.Fatalf("Error generating private key: %s", err) -// } -// keyFlowConfig := &KeyFlowConfig{ -// AuthHTTPClient: &http.Client{ -// Transport: mockTransportFn{func(_ *http.Request) (*http.Response, error) { -// return tt.mockResponse, tt.mockError -// }}, -// }, -// ServiceAccountKey: fixtureServiceAccountKey(), -// PrivateKey: string(privateKeyBytes), -// HTTPTransport: http.DefaultTransport, -// } -// err = keyFlow.Init(keyFlowConfig) -// if err != nil { -// t.Fatalf("failed to initialize key flow: %v", err) -// } - -// res, err := keyFlow.requestToken(tt.grant, tt.assertion) -// defer func() { -// if res != nil { -// tempErr := res.Body.Close() -// if tempErr != nil { -// t.Errorf("closing request token response: %s", tempErr.Error()) -// } -// } -// }() -// if tt.expectedError != nil { -// if err == nil { -// t.Errorf("Expected error '%v' but no error was returned", tt.expectedError) -// } else if errors.Is(err, tt.expectedError) { -// t.Errorf("Error is not correct. Expected %v, got %v", tt.expectedError, err) -// } -// } else { -// if err != nil { -// t.Errorf("Expected no error but error was returned: %v", err) -// } -// if !cmp.Equal(tt.mockResponse, res, cmp.AllowUnexported(strings.Reader{})) { -// t.Errorf("The returned result is wrong. Expected %v, got %v", tt.mockResponse, res) -// } -// } -// }) -// } -// } - -// func TestKeyFlow_Do(t *testing.T) { -// t.Parallel() - -// tests := []struct { -// name string -// handlerFn func(tb testing.TB) http.HandlerFunc -// want int -// wantErr bool -// }{ -// { -// name: "success", -// handlerFn: func(tb testing.TB) http.HandlerFunc { -// tb.Helper() - -// return func(w http.ResponseWriter, r *http.Request) { -// if r.Header.Get("Authorization") != "Bearer "+testBearerToken { -// tb.Errorf("expected Authorization header to be 'Bearer %s', but got %s", testBearerToken, r.Header.Get("Authorization")) -// } - -// w.Header().Set("Content-Type", "application/json") -// w.WriteHeader(http.StatusOK) -// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) -// } -// }, -// want: http.StatusOK, -// wantErr: false, -// }, -// { -// name: "success with code 500", -// handlerFn: func(_ testing.TB) http.HandlerFunc { -// return func(w http.ResponseWriter, _ *http.Request) { -// w.Header().Set("Content-Type", "text/html") -// w.WriteHeader(http.StatusInternalServerError) -// _, _ = fmt.Fprintln(w, `Internal Server Error`) -// } -// }, -// want: http.StatusInternalServerError, -// wantErr: false, -// }, -// { -// name: "success with custom transport", -// handlerFn: func(tb testing.TB) http.HandlerFunc { -// tb.Helper() - -// return func(w http.ResponseWriter, r *http.Request) { -// if r.Header.Get("User-Agent") != "custom_transport" { -// tb.Errorf("expected User-Agent header to be 'custom_transport', but got %s", r.Header.Get("User-Agent")) -// } - -// w.Header().Set("Content-Type", "application/json") -// w.WriteHeader(http.StatusOK) -// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) -// } -// }, -// want: http.StatusOK, -// wantErr: false, -// }, -// { -// name: "fail with custom proxy", -// handlerFn: func(testing.TB) http.HandlerFunc { -// return func(w http.ResponseWriter, _ *http.Request) { -// w.Header().Set("Content-Type", "application/json") -// w.WriteHeader(http.StatusOK) -// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) -// } -// }, -// want: 0, -// wantErr: true, -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// ctx := context.Background() -// ctx, cancel := context.WithCancel(ctx) -// t.Cleanup(cancel) // This cancels the refresher goroutine - -// privateKeyBytes, err := generatePrivateKey() -// if err != nil { -// t.Errorf("no error is expected, but got %v", err) -// } - -// keyFlow := &KeyFlow{} -// keyFlowConfig := &KeyFlowConfig{ -// ServiceAccountKey: fixtureServiceAccountKey(), -// PrivateKey: string(privateKeyBytes), -// BackgroundTokenRefreshContext: ctx, -// HTTPTransport: func() http.RoundTripper { -// switch tt.name { -// case "success with custom transport": -// return mockTransportFn{ -// fn: func(req *http.Request) (*http.Response, error) { -// req.Header.Set("User-Agent", "custom_transport") -// return http.DefaultTransport.RoundTrip(req) -// }, -// } -// case "fail with custom proxy": -// return &http.Transport{ -// Proxy: func(_ *http.Request) (*url.URL, error) { -// return nil, fmt.Errorf("proxy error") -// }, -// } -// default: -// return http.DefaultTransport -// } -// }(), -// AuthHTTPClient: &http.Client{ -// Transport: mockTransportFn{ -// fn: func(_ *http.Request) (*http.Response, error) { -// res := httptest.NewRecorder() -// res.WriteHeader(http.StatusOK) -// res.Header().Set("Content-Type", "application/json") - -// token := &TokenResponseBody{ -// AccessToken: testBearerToken, -// ExpiresIn: 2147483647, -// TokenType: "Bearer", -// } - -// if err := json.NewEncoder(res.Body).Encode(token); err != nil { -// t.Logf("no error is expected, but got %v", err) -// } - -// return res.Result(), nil -// }, -// }, -// }, -// } -// err = keyFlow.Init(keyFlowConfig) -// if err != nil { -// t.Fatalf("failed to initialize key flow: %v", err) -// } - -// go continuousRefreshToken(keyFlow) - -// tokenCtx, tokenCancel := context.WithTimeout(context.Background(), 1*time.Second) - -// token: -// for { -// select { -// case <-tokenCtx.Done(): -// t.Error(tokenCtx.Err()) -// case <-time.After(50 * time.Millisecond): -// keyFlow.tokenMutex.RLock() -// if keyFlow.token != nil { -// keyFlow.tokenMutex.RUnlock() -// tokenCancel() -// break token -// } - -// keyFlow.tokenMutex.RUnlock() -// } -// } - -// server := httptest.NewServer(tt.handlerFn(t)) -// t.Cleanup(server.Close) - -// u, err := url.Parse(server.URL) -// if err != nil { -// t.Errorf("no error is expected, but got %v", err) -// } - -// req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) -// if err != nil { -// t.Errorf("no error is expected, but got %v", err) -// } - -// httpClient := &http.Client{ -// Transport: keyFlow, -// } - -// res, err := httpClient.Do(req) - -// if tt.wantErr { -// if err == nil { -// t.Errorf("error is expected, but got %v", err) -// } -// } else { -// if err != nil { -// t.Errorf("no error is expected, but got %v", err) -// } - -// if res.StatusCode != tt.want { -// t.Errorf("expected status code %d, but got %d", tt.want, res.StatusCode) -// } - -// // Defer discard and close the body -// t.Cleanup(func() { -// if _, err := io.Copy(io.Discard, res.Body); err != nil { -// t.Errorf("no error is expected, but got %v", err) -// } - -// if err := res.Body.Close(); err != nil { -// t.Errorf("no error is expected, but got %v", err) -// } -// }) -// } -// }) -// } -// } - -// type mockTransportFn struct { -// fn func(req *http.Request) (*http.Response, error) -// } - -// func (m mockTransportFn) RoundTrip(req *http.Request) (*http.Response, error) { -// return m.fn(req) -// } diff --git a/core/config/config.go b/core/config/config.go index ae2d8c498..dd9dd98f4 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -75,29 +75,30 @@ type Middleware func(http.RoundTripper) http.RoundTripper // Configuration stores the configuration of the API client type Configuration struct { - Host string `json:"host,omitempty"` - Scheme string `json:"scheme,omitempty"` - DefaultHeader map[string]string `json:"defaultHeader,omitempty"` - UserAgent string `json:"userAgent,omitempty"` - Debug bool `json:"debug,omitempty"` - NoAuth bool `json:"noAuth,omitempty"` - WorkloadIdentityFederation bool `json:"workloadIdentityFederation,omitempty"` - WorkloadIdentityFederationTokenExpiration string `json:"workloadIdentityFederationTokenExpiration,omitempty"` - WorkloadIdentityFederationFederatedTokenPath string `json:"workloadIdentityFederationFederatedTokenPath,omitempty"` - ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` - Token string `json:"token,omitempty"` - ServiceAccountKey string `json:"serviceAccountKey,omitempty"` - PrivateKey string `json:"privateKey,omitempty"` - ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` - PrivateKeyPath string `json:"privateKeyPath,omitempty"` - CredentialsFilePath string `json:"credentialsFilePath,omitempty"` - TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` - Region string `json:"region,omitempty"` - CustomAuth http.RoundTripper - Servers ServerConfigurations - OperationServers map[string]ServerConfigurations - HTTPClient *http.Client - Middleware []Middleware + Host string `json:"host,omitempty"` + Scheme string `json:"scheme,omitempty"` + DefaultHeader map[string]string `json:"defaultHeader,omitempty"` + UserAgent string `json:"userAgent,omitempty"` + Debug bool `json:"debug,omitempty"` + NoAuth bool `json:"noAuth,omitempty"` + WorkloadIdentityFederation bool `json:"workloadIdentityFederation,omitempty"` + ServiceAccountFederatedTokenExpiration string `json:"serviceAccountFederatedTokenExpiration,omitempty"` + ServiceAccountFederatedToken string `json:"serviceAccountFederatedToken,omitempty"` + ServiceAccountFederatedTokenPath string `json:"serviceAccountFederatedTokenPath,omitempty"` + ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` + Token string `json:"token,omitempty"` + ServiceAccountKey string `json:"serviceAccountKey,omitempty"` + PrivateKey string `json:"privateKey,omitempty"` + ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` + PrivateKeyPath string `json:"privateKeyPath,omitempty"` + CredentialsFilePath string `json:"credentialsFilePath,omitempty"` + TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` + Region string `json:"region,omitempty"` + CustomAuth http.RoundTripper + Servers ServerConfigurations + OperationServers map[string]ServerConfigurations + HTTPClient *http.Client + Middleware []Middleware // If != nil, a goroutine will be launched that will refresh the service account's access token when it's close to being expired. // The goroutine is killed whenever this context is canceled. @@ -249,7 +250,7 @@ func WithWorkloadIdentityFederationAuth() ConfigurationOption { // WithWorkloadIdentityFederation returns a ConfigurationOption that sets workload identity flow to be used for authentication in API calls func WithWorkloadIdentityFederationTokenPath(path string) ConfigurationOption { return func(config *Configuration) error { - config.WorkloadIdentityFederationFederatedTokenPath = path + config.ServiceAccountFederatedTokenPath = path return nil } } @@ -257,7 +258,7 @@ func WithWorkloadIdentityFederationTokenPath(path string) ConfigurationOption { // WithWorkloadIdentityFederationTokenExpiration returns a ConfigurationOption that sets the token expiration for workload identity federation flow func WithWorkloadIdentityFederationTokenExpiration(expiration string) ConfigurationOption { return func(config *Configuration) error { - config.WorkloadIdentityFederationTokenExpiration = expiration + config.ServiceAccountFederatedTokenExpiration = expiration return nil } } From f540b0b9abd6995ce9b7fd4c55672a2b946c3237 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Wed, 7 Jan 2026 15:39:52 +0100 Subject: [PATCH 16/18] fix linting issues Signed-off-by: Jorge Turrado --- core/clients/key_flow.go | 4 +- .../key_flow_continuous_refresh_test.go | 2 +- core/clients/workload_identity_flow.go | 10 ++--- core/clients/workload_identity_flow_test.go | 37 ++++++++++++++++--- 4 files changed, 39 insertions(+), 14 deletions(-) diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index d18d4f0bf..6de8f4009 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -147,8 +147,8 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { // SetToken can be used to set an access and refresh token manually in the client. // The other fields in the token field are determined by inspecting the token or setting default values. -// Deprecated -func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { +// Deprecated This method will be removed in future versions. Access tokens are now automatically managed by the client. +func (c *KeyFlow) SetToken(accessToken, _ string) error { // We can safely use ParseUnverified because we are not authenticating the user, // We are parsing the token just to get the expiration time claim parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) diff --git a/core/clients/key_flow_continuous_refresh_test.go b/core/clients/key_flow_continuous_refresh_test.go index cfd50e763..8dce296a8 100644 --- a/core/clients/key_flow_continuous_refresh_test.go +++ b/core/clients/key_flow_continuous_refresh_test.go @@ -138,7 +138,7 @@ type fakeAuthFlow struct { accessToken string } -func (f *fakeAuthFlow) RoundTrip(req *http.Request) (*http.Response, error) { +func (f *fakeAuthFlow) RoundTrip(_ *http.Request) (*http.Response, error) { return nil, nil } func (f *fakeAuthFlow) GetAccessToken() (string, error) { diff --git a/core/clients/workload_identity_flow.go b/core/clients/workload_identity_flow.go index 0046ec864..8dc13c3f1 100644 --- a/core/clients/workload_identity_flow.go +++ b/core/clients/workload_identity_flow.go @@ -15,14 +15,14 @@ import ( const ( clientIDEnv = "STACKIT_SERVICE_ACCOUNT_EMAIL" - FederatedTokenFileEnv = "STACKIT_FEDERATED_TOKEN_FILE" - wifTokenEndpointEnv = "STACKIT_IDP_ENDPOINT" - wifTokenExpirationEnv = "STACKIT_IDP_EXPIRATION_SECONDS" + FederatedTokenFileEnv = "STACKIT_FEDERATED_TOKEN_FILE" //nolint:gosec // This is not a secret, just the env variable name + wifTokenEndpointEnv = "STACKIT_IDP_ENDPOINT" //nolint:gosec // This is not a secret, just the env variable name + wifTokenExpirationEnv = "STACKIT_IDP_EXPIRATION_SECONDS" //nolint:gosec // This is not a secret, just the env variable name wifClientAssertionType = "urn:schwarz:params:oauth:client-assertion-type:workload-jwt" wifGrantType = "client_credentials" - defaultWifTokenEndpoint = "https://accounts.stackit.cloud/oauth/v2/token" - defaultFederatedTokenPath = "/var/run/secrets/stackit.cloud/serviceaccount/token" + defaultWifTokenEndpoint = "https://accounts.stackit.cloud/oauth/v2/token" //nolint:gosec // This is not a secret, just the public endpoint for default value + defaultFederatedTokenPath = "/var/run/secrets/stackit.cloud/serviceaccount/token" //nolint:gosec // This is not a secret, just the default path for workload identity token defaultWifExpirationToken = "1h" ) diff --git a/core/clients/workload_identity_flow_test.go b/core/clients/workload_identity_flow_test.go index 4a9e07161..4078fa96f 100644 --- a/core/clients/workload_identity_flow_test.go +++ b/core/clients/workload_identity_flow_test.go @@ -95,13 +95,21 @@ func TestWorkloadIdentityFlowInit(t *testing.T) { if err != nil { log.Fatal(err) } - defer os.Remove(file.Name()) + defer func() { + err := os.Remove(file.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() if tt.validAssertion { token, err := signTokenWithSubject("subject", time.Minute) if err != nil { t.Fatalf("failed to create token: %v", err) } - os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + err = os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + if err != nil { + t.Fatalf("writing temporary file: %s", err) + } } if tt.tokenFilePathAsEnv { t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", file.Name()) @@ -184,7 +192,10 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + err := r.ParseForm() + if err != nil { + t.Fatalf("failed to parse form: %v", err) + } assertionType := r.PostForm.Get("client_assertion_type") if assertionType != "urn:schwarz:params:oauth:client-assertion-type:workload-jwt" { t.Fatalf("invalid assertion type: %s", assertionType) @@ -224,7 +235,10 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(payload) + _, err = w.Write(payload) + if err != nil { + t.Fatalf("writing response: %s", err) + } })) t.Cleanup(authServer.Close) @@ -268,9 +282,17 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { if err != nil { log.Fatal(err) } - defer os.Remove(file.Name()) + defer func() { + err := os.Remove(file.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() flowConfig.FederatedTokenFilePath = file.Name() - os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + err = os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + if err != nil { + t.Fatalf("writing temporary file: %s", err) + } } if err := flow.Init(flowConfig); err != nil { @@ -287,6 +309,9 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { if (err != nil || resp.StatusCode != http.StatusOK) && !tt.wantErr { t.Fatalf("failed request to protected resource: %v", err) } + if err := resp.Body.Close(); err != nil { + t.Errorf("resp.Body.Close() error = %v", err) + } }) } } From 675a8b56c2721d9e601409e33e50fe694ee6fd35 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Wed, 7 Jan 2026 16:16:58 +0100 Subject: [PATCH 17/18] fix panic Signed-off-by: Jorge Turrado --- core/clients/workload_identity_flow_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/clients/workload_identity_flow_test.go b/core/clients/workload_identity_flow_test.go index 4078fa96f..539553dea 100644 --- a/core/clients/workload_identity_flow_test.go +++ b/core/clients/workload_identity_flow_test.go @@ -309,8 +309,10 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { if (err != nil || resp.StatusCode != http.StatusOK) && !tt.wantErr { t.Fatalf("failed request to protected resource: %v", err) } - if err := resp.Body.Close(); err != nil { - t.Errorf("resp.Body.Close() error = %v", err) + if resp != nil && resp.Body != nil { + if err := resp.Body.Close(); err != nil { + t.Errorf("resp.Body.Close() error = %v", err) + } } }) } From 5e41dc32d226b57d262ea79a55055262ce7f8747 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Wed, 7 Jan 2026 18:29:57 +0100 Subject: [PATCH 18/18] replace wif assertion options with a func Signed-off-by: Jorge Turrado --- core/auth/auth.go | 3 +- core/clients/workload_identity_flow.go | 44 ++++---------- core/clients/workload_identity_flow_test.go | 27 ++++----- core/config/config.go | 64 ++++++++++++++------- core/utils/filesystem.go | 24 ++++++++ core/utils/filesystem_test.go | 62 ++++++++++++++++++++ 6 files changed, 150 insertions(+), 74 deletions(-) create mode 100644 core/utils/filesystem.go create mode 100644 core/utils/filesystem_test.go diff --git a/core/auth/auth.go b/core/auth/auth.go index 450361c60..b393afbb7 100644 --- a/core/auth/auth.go +++ b/core/auth/auth.go @@ -238,9 +238,8 @@ func WorkloadIdentityFederationAuth(cfg *config.Configuration) (http.RoundTrippe TokenUrl: cfg.TokenCustomUrl, BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext, ClientID: cfg.ServiceAccountEmail, - FederatedTokenFilePath: cfg.ServiceAccountFederatedTokenPath, TokenExpiration: cfg.ServiceAccountFederatedTokenExpiration, - FederatedToken: cfg.ServiceAccountFederatedToken, + FederatedTokenFunction: cfg.ServiceAccountFederatedTokenFunc, } if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil { diff --git a/core/clients/workload_identity_flow.go b/core/clients/workload_identity_flow.go index 8dc13c3f1..2491cb775 100644 --- a/core/clients/workload_identity_flow.go +++ b/core/clients/workload_identity_flow.go @@ -10,7 +10,7 @@ import ( "sync" "time" - "github.com/golang-jwt/jwt/v5" + "github.com/stackitcloud/stackit-sdk-go/core/utils" ) const ( @@ -48,8 +48,6 @@ type WorkloadIdentityFederationFlow struct { tokenMutex sync.RWMutex token *TokenResponseBody - parser *jwt.Parser - // If the current access token would expire in less than TokenExpirationLeeway, // the client will refresh it early to prevent clock skew or other timing issues. tokenExpirationLeeway time.Duration @@ -59,12 +57,11 @@ type WorkloadIdentityFederationFlow struct { type WorkloadIdentityFederationFlowConfig struct { TokenUrl string ClientID string - FederatedToken string // Static token string. This is optional, if not set the token will be read from file. - FederatedTokenFilePath string TokenExpiration string // Not supported yet BackgroundTokenRefreshContext context.Context // Functionality is enabled if this isn't nil HTTPTransport http.RoundTripper AuthHTTPClient *http.Client + FederatedTokenFunction func() (string, error) // Function to get the federated token } // GetConfig returns the flow configuration @@ -130,7 +127,6 @@ func (c *WorkloadIdentityFederationFlow) Init(cfg *WorkloadIdentityFederationFlo // No concurrency at this point, so no mutex check needed c.token = &TokenResponseBody{} c.config = cfg - c.parser = jwt.NewParser() if c.config.TokenUrl == "" { c.config.TokenUrl = getEnvOrDefault(wifTokenEndpointEnv, defaultWifTokenEndpoint) @@ -140,8 +136,10 @@ func (c *WorkloadIdentityFederationFlow) Init(cfg *WorkloadIdentityFederationFlo c.config.ClientID = getEnvOrDefault(clientIDEnv, "") } - if c.config.FederatedToken == "" && c.config.FederatedTokenFilePath == "" { - c.config.FederatedTokenFilePath = getEnvOrDefault(FederatedTokenFileEnv, defaultFederatedTokenPath) + if c.config.FederatedTokenFunction == nil { + c.config.FederatedTokenFunction = func() (string, error) { + return utils.ReadJWTFromFileSystem(getEnvOrDefault(FederatedTokenFileEnv, defaultFederatedTokenPath)) + } } c.tokenExpirationLeeway = defaultTokenExpirationLeeway @@ -176,10 +174,8 @@ func (c *WorkloadIdentityFederationFlow) validate() error { if c.config.TokenUrl == "" { return fmt.Errorf("token URL cannot be empty") } - if c.config.FederatedToken == "" { - if _, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath); err != nil { - return fmt.Errorf("error reading federated token file - %w", err) - } + if _, err := c.config.FederatedTokenFunction(); err != nil { + return fmt.Errorf("error reading federated token file - %w", err) } if c.tokenExpirationLeeway < 0 { return fmt.Errorf("token expiration leeway cannot be negative") @@ -190,15 +186,10 @@ func (c *WorkloadIdentityFederationFlow) validate() error { // createAccessToken creates an access token using self signed JWT func (c *WorkloadIdentityFederationFlow) createAccessToken() error { - clientAssertion := c.config.FederatedToken - if clientAssertion == "" { - var err error - clientAssertion, err = c.readJWTFromFileSystem(c.config.FederatedTokenFilePath) - if err != nil { - return fmt.Errorf("error reading service account assertion - %w", err) - } + clientAssertion, err := c.config.FederatedTokenFunction() + if err != nil { + return err } - res, err := c.requestToken(c.config.ClientID, clientAssertion) if err != nil { return err @@ -235,16 +226,3 @@ func (c *WorkloadIdentityFederationFlow) requestToken(clientID, assertion string return c.authClient.Do(req) } - -func (c *WorkloadIdentityFederationFlow) readJWTFromFileSystem(tokenFilePath string) (string, error) { - token, err := os.ReadFile(tokenFilePath) - if err != nil { - return "", err - } - tokenStr := string(token) - _, _, err = c.parser.ParseUnverified(tokenStr, jwt.MapClaims{}) - if err != nil { - return "", err - } - return tokenStr, nil -} diff --git a/core/clients/workload_identity_flow_test.go b/core/clients/workload_identity_flow_test.go index 539553dea..d6d43f8e6 100644 --- a/core/clients/workload_identity_flow_test.go +++ b/core/clients/workload_identity_flow_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "github.com/stackitcloud/stackit-sdk-go/core/utils" ) func TestWorkloadIdentityFlowInit(t *testing.T) { @@ -55,12 +56,6 @@ func TestWorkloadIdentityFlowInit(t *testing.T) { validAssertion: true, wantErr: true, }, - { - name: "missing assertion", - clientID: "test@stackit.cloud", - missingTokenFilePath: true, - wantErr: true, - }, { name: "invalid assertion", clientID: "test@stackit.cloud", @@ -114,7 +109,9 @@ func TestWorkloadIdentityFlowInit(t *testing.T) { if tt.tokenFilePathAsEnv { t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", file.Name()) } else { - flowConfig.FederatedTokenFilePath = file.Name() + flowConfig.FederatedTokenFunction = func() (string, error) { + return utils.ReadJWTFromFileSystem(file.Name()) + } } } @@ -137,14 +134,6 @@ func TestWorkloadIdentityFlowInit(t *testing.T) { t.Errorf("tokenUrl mismatch, want %s, got %s", "https://accounts.stackit.cloud/oauth/v2/token", flow.config.TokenUrl) } - if tt.missingTokenFilePath && flow.config.FederatedTokenFilePath != "/var/run/secrets/stackit.cloud/serviceaccount/token" { - t.Errorf("clientID mismatch, want %s, got %s", "/var/run/secrets/stackit.cloud/serviceaccount/token", flow.config.FederatedTokenFilePath) - } - - if !tt.missingTokenFilePath && flow.config.FederatedTokenFilePath == "/var/run/secrets/stackit.cloud/serviceaccount/token" { - t.Errorf("clientID mismatch, want different from %s", flow.config.FederatedTokenFilePath) - } - if tt.tokenExpiration != "" && flow.config.TokenExpiration != tt.tokenExpiration { t.Errorf("tokenExpiration mismatch, want %s, got %s", tt.tokenExpiration, flow.config.TokenExpiration) } @@ -276,7 +265,9 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { } if tt.injectToken { - flowConfig.FederatedToken = token + flowConfig.FederatedTokenFunction = func() (string, error) { + return token, nil + } } else { file, err := os.CreateTemp("", "*.token") if err != nil { @@ -288,7 +279,9 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { t.Fatalf("Removing temporary file: %s", err) } }() - flowConfig.FederatedTokenFilePath = file.Name() + flowConfig.FederatedTokenFunction = func() (string, error) { + return utils.ReadJWTFromFileSystem(file.Name()) + } err = os.WriteFile(file.Name(), []byte(token), os.ModeAppend) if err != nil { t.Fatalf("writing temporary file: %s", err) diff --git a/core/config/config.go b/core/config/config.go index dd9dd98f4..d0d66893f 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -11,6 +11,7 @@ import ( "time" "github.com/stackitcloud/stackit-sdk-go/core/clients" + "github.com/stackitcloud/stackit-sdk-go/core/utils" ) const ( @@ -75,25 +76,24 @@ type Middleware func(http.RoundTripper) http.RoundTripper // Configuration stores the configuration of the API client type Configuration struct { - Host string `json:"host,omitempty"` - Scheme string `json:"scheme,omitempty"` - DefaultHeader map[string]string `json:"defaultHeader,omitempty"` - UserAgent string `json:"userAgent,omitempty"` - Debug bool `json:"debug,omitempty"` - NoAuth bool `json:"noAuth,omitempty"` - WorkloadIdentityFederation bool `json:"workloadIdentityFederation,omitempty"` - ServiceAccountFederatedTokenExpiration string `json:"serviceAccountFederatedTokenExpiration,omitempty"` - ServiceAccountFederatedToken string `json:"serviceAccountFederatedToken,omitempty"` - ServiceAccountFederatedTokenPath string `json:"serviceAccountFederatedTokenPath,omitempty"` - ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` - Token string `json:"token,omitempty"` - ServiceAccountKey string `json:"serviceAccountKey,omitempty"` - PrivateKey string `json:"privateKey,omitempty"` - ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` - PrivateKeyPath string `json:"privateKeyPath,omitempty"` - CredentialsFilePath string `json:"credentialsFilePath,omitempty"` - TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` - Region string `json:"region,omitempty"` + Host string `json:"host,omitempty"` + Scheme string `json:"scheme,omitempty"` + DefaultHeader map[string]string `json:"defaultHeader,omitempty"` + UserAgent string `json:"userAgent,omitempty"` + Debug bool `json:"debug,omitempty"` + NoAuth bool `json:"noAuth,omitempty"` + WorkloadIdentityFederation bool `json:"workloadIdentityFederation,omitempty"` + ServiceAccountFederatedTokenExpiration string `json:"serviceAccountFederatedTokenExpiration,omitempty"` + ServiceAccountFederatedTokenFunc func() (string, error) `json:"serviceAccountFederatedTokenFunc,omitempty"` + ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` + Token string `json:"token,omitempty"` + ServiceAccountKey string `json:"serviceAccountKey,omitempty"` + PrivateKey string `json:"privateKey,omitempty"` + ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` + PrivateKeyPath string `json:"privateKeyPath,omitempty"` + CredentialsFilePath string `json:"credentialsFilePath,omitempty"` + TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` + Region string `json:"region,omitempty"` CustomAuth http.RoundTripper Servers ServerConfigurations OperationServers map[string]ServerConfigurations @@ -247,10 +247,30 @@ func WithWorkloadIdentityFederationAuth() ConfigurationOption { } } -// WithWorkloadIdentityFederation returns a ConfigurationOption that sets workload identity flow to be used for authentication in API calls -func WithWorkloadIdentityFederationTokenPath(path string) ConfigurationOption { +// WithWorkloadIdentityFederationFunc returns a ConfigurationOption that sets the function to get the federated token for workload identity federation flow +func WithWorkloadIdentityFederationFunc(function func() (string, error)) ConfigurationOption { return func(config *Configuration) error { - config.ServiceAccountFederatedTokenPath = path + config.ServiceAccountFederatedTokenFunc = function + return nil + } +} + +// WithWorkloadIdentityFederationPath returns a ConfigurationOption that sets the custom path to the federated token file for workload identity federation flow +func WithWorkloadIdentityFederationPath(path string) ConfigurationOption { + return func(config *Configuration) error { + config.ServiceAccountFederatedTokenFunc = func() (string, error) { + return utils.ReadJWTFromFileSystem(path) + } + return nil + } +} + +// WithWorkloadIdentityFederationFunc returns a ConfigurationOption that sets the id token for workload identity federation flow +func WithWorkloadIdentityFederationToken(token string) ConfigurationOption { + return func(config *Configuration) error { + config.ServiceAccountFederatedTokenFunc = func() (string, error) { + return token, nil + } return nil } } diff --git a/core/utils/filesystem.go b/core/utils/filesystem.go new file mode 100644 index 000000000..44de8326b --- /dev/null +++ b/core/utils/filesystem.go @@ -0,0 +1,24 @@ +package utils + +import ( + "os" + + "github.com/golang-jwt/jwt/v5" +) + +var ( + parser *jwt.Parser = jwt.NewParser() +) + +func ReadJWTFromFileSystem(tokenFilePath string) (string, error) { + token, err := os.ReadFile(tokenFilePath) + if err != nil { + return "", err + } + tokenStr := string(token) + _, _, err = parser.ParseUnverified(tokenStr, jwt.MapClaims{}) + if err != nil { + return "", err + } + return tokenStr, nil +} diff --git a/core/utils/filesystem_test.go b/core/utils/filesystem_test.go new file mode 100644 index 000000000..e98f30bdc --- /dev/null +++ b/core/utils/filesystem_test.go @@ -0,0 +1,62 @@ +package utils + +import ( + "log" + "os" + "testing" +) + +func TestReadJWTFromFileSystem(t *testing.T) { + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer func() { + err := os.Remove(file.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + + token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.KMUFsIDTnFmyG3nMiGM6H9FNFUROf3wh7SmqJp-QV30" // nolint:gosec // This is a fake token for testing purposes only + err = os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + if err != nil { + t.Fatalf("Writing temporary file: %s", err) + } + + _, err = ReadJWTFromFileSystem(file.Name()) + if err != nil { + t.Fatalf("Reading JWT from file system: %s", err) + } +} + +func TestReadRandomContentFromFileSystem(t *testing.T) { + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer func() { + err := os.Remove(file.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + + token := "invalid random content" + err = os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + if err != nil { + t.Fatalf("Writing temporary file: %s", err) + } + + _, err = ReadJWTFromFileSystem(file.Name()) + if err == nil { + t.Fatalf("Reading JWT from file system must fail") + } +} + +func TestReadMissingFileFromFileSystem(t *testing.T) { + _, err := ReadJWTFromFileSystem("/path/to/nonexistent/file.token") + if err == nil { + t.Fatalf("Reading JWT from file system must fail") + } +}