1- # Copyright (c) Microsoft Corporation.
2- # Licensed under the MIT License.
3- from azure .identity import DefaultAzureCredential , ManagedIdentityCredential
4- from datetime import datetime , timedelta , timezone
5- from typing import Optional
6- import durabletask .internal .shared as shared
7-
8- # By default, when there's 10minutes left before the token expires, refresh the token
9- class AccessTokenManager :
10- def __init__ (self , refresh_interval_seconds : int = 600 , metadata : Optional [list [tuple [str , str ]]] = None ):
11- self .scope = "https://durabletask.io/.default"
12- self .refresh_interval_seconds = refresh_interval_seconds
13- self ._use_managed_identity = False
14- self ._metadata = metadata
15- self ._client_id = None
16- self ._logger = shared .get_logger ("token_manager" )
17-
18- if metadata : # Ensure metadata is not None
19- for key , value in metadata :
20- if key == "use_managed_identity" :
21- self ._use_managed_identity = value .lower () == "true" # Properly convert string to bool
22- elif key == "client_id" :
23- self ._client_id = value # Directly assign string
24-
25- # Choose the appropriate credential based on use_managed_identity
26- if self ._use_managed_identity :
27- if not self ._client_id :
28- self ._logger .debug ("Using System Assigned Managed Identity for authentication." )
29- self .credential = ManagedIdentityCredential ()
30- else :
31- self ._logger .debug ("Using User Assigned Managed Identity for authentication." )
32- self .credential = ManagedIdentityCredential (client_id = self ._client_id )
33- else :
34- self .credential = DefaultAzureCredential ()
35- self ._logger .debug ("Using Default Azure Credentials for authentication." )
36-
37- self .token = None
38- self .expiry_time = None
39-
40- def get_access_token (self ) -> str :
41- if self .token is None or self .is_token_expired ():
42- self .refresh_token ()
43- return self .token
44-
45- # Checks if the token is expired, or if it will expire in the next "refresh_interval_seconds" seconds.
46- # For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes,
47- # We will grab a new token when there're 30minutes left on the lifespan of the token
48- def is_token_expired (self ) -> bool :
49- if self .expiry_time is None :
50- return True
51- return datetime .now (timezone .utc ) >= (self .expiry_time - timedelta (seconds = self .refresh_interval_seconds ))
52-
53- def refresh_token (self ):
54- new_token = self .credential .get_token (self .scope )
55- self .token = f"Bearer { new_token .token } "
56-
57- # Convert UNIX timestamp to timezone-aware datetime
58- self .expiry_time = datetime .fromtimestamp (new_token .expires_on , tz = timezone .utc )
1+ # Copyright (c) Microsoft Corporation.
2+ # Licensed under the MIT License.
3+ from azure .identity import DefaultAzureCredential , ManagedIdentityCredential
4+ from datetime import datetime , timedelta , timezone
5+ from typing import Optional
6+ import durabletask .internal .shared as shared
7+
8+ # By default, when there's 10minutes left before the token expires, refresh the token
9+ class AccessTokenManager :
10+ def __init__ (self , refresh_interval_seconds : int = 600 , use_managed_identity : bool = False , client_id : str = None ):
11+ self .scope = "https://durabletask.io/.default"
12+ self .refresh_interval_seconds = refresh_interval_seconds
13+ self ._use_managed_identity = use_managed_identity
14+ self ._client_id = client_id
15+ self ._logger = shared .get_logger ("token_manager" )
16+
17+ # Choose the appropriate credential based on use_managed_identity
18+ if self ._use_managed_identity :
19+ if not self ._client_id :
20+ self ._logger .debug ("Using System Assigned Managed Identity for authentication." )
21+ self .credential = ManagedIdentityCredential ()
22+ else :
23+ self ._logger .debug ("Using User Assigned Managed Identity for authentication." )
24+ self .credential = ManagedIdentityCredential (client_id = self ._client_id )
25+ else :
26+ self .credential = DefaultAzureCredential ()
27+ self ._logger .debug ("Using Default Azure Credentials for authentication." )
28+
29+ self .token = None
30+ self .expiry_time = None
31+
32+ def get_access_token (self ) -> str :
33+ if self .token is None or self .is_token_expired ():
34+ self .refresh_token ()
35+ return self .token
36+
37+ # Checks if the token is expired, or if it will expire in the next "refresh_interval_seconds" seconds.
38+ # For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes,
39+ # We will grab a new token when there're 30minutes left on the lifespan of the token
40+ def is_token_expired (self ) -> bool :
41+ if self .expiry_time is None :
42+ return True
43+ return datetime .now (timezone .utc ) >= (self .expiry_time - timedelta (seconds = self .refresh_interval_seconds ))
44+
45+ def refresh_token (self ):
46+ new_token = self .credential .get_token (self .scope )
47+ self .token = f"Bearer { new_token .token } "
48+
49+ # Convert UNIX timestamp to timezone-aware datetime
50+ self .expiry_time = datetime .fromtimestamp (new_token .expires_on , tz = timezone .utc )
5951 self ._logger .debug (f"Token refreshed. Expires at: { self .expiry_time } " )
0 commit comments