Skip to content

Commit ba1ac4f

Browse files
committed
More review feedback
Signed-off-by: Ryan Lettieri <ryanLettieri@microsoft.com>
1 parent f9d55ab commit ba1ac4f

File tree

7 files changed

+228
-259
lines changed

7 files changed

+228
-259
lines changed
Lines changed: 46 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,46 @@
1-
# Copyright (c) Microsoft Corporation.
2-
# Licensed under the MIT License.
3-
4-
from typing import Optional
5-
from durabletask.client import TaskHubGrpcClient
6-
from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager
7-
from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl
8-
9-
# Client class used for Durable Task Scheduler (DTS)
10-
class DurableTaskSchedulerClient(TaskHubGrpcClient):
11-
def __init__(self, *,
12-
host_address: str,
13-
taskhub: str,
14-
secure_channel: Optional[bool] = True,
15-
metadata: Optional[list[tuple[str, str]]] = None,
16-
use_managed_identity: Optional[bool] = False,
17-
client_id: Optional[str] = None):
18-
19-
if taskhub == None:
20-
raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")
21-
22-
# Ensure metadata is a list
23-
metadata = metadata or []
24-
self._metadata = metadata.copy() # Use a copy to avoid modifying original
25-
26-
# Append DurableTask-specific metadata
27-
self._metadata.append(("taskhub", taskhub))
28-
self._metadata.append(("dts", "True"))
29-
self._metadata.append(("use_managed_identity", str(use_managed_identity)))
30-
self._metadata.append(("client_id", str(client_id or "None")))
31-
32-
self._access_token_manager = AccessTokenManager(metadata=self._metadata)
33-
self.__update_metadata_with_token()
34-
self._interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)]
35-
36-
# We pass in None for the metadata so we don't construct an additional interceptor in the parent class
37-
# Since the parent class doesn't use anything metadata for anything else, we can set it as None
38-
super().__init__(
39-
host_address=host_address,
40-
secure_channel=secure_channel,
41-
metadata=None,
42-
interceptors=self._interceptors)
43-
44-
def __update_metadata_with_token(self):
45-
"""
46-
Add or update the `authorization` key in the metadata with the current access token.
47-
"""
48-
token = self._access_token_manager.get_access_token()
49-
50-
# Ensure that self._metadata is initialized
51-
if self._metadata is None:
52-
self._metadata = [] # Initialize it if it's still None
53-
54-
# Check if "authorization" already exists in the metadata
55-
updated = False
56-
for i, (key, _) in enumerate(self._metadata):
57-
if key == "authorization":
58-
self._metadata[i] = ("authorization", token)
59-
updated = True
60-
break
61-
62-
# If not updated, add a new entry
63-
if not updated:
64-
self._metadata.append(("authorization", token))
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
from typing import Optional
5+
from durabletask.client import TaskHubGrpcClient, OrchestrationStatus
6+
from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager
7+
from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl
8+
from azure.identity import DefaultAzureCredential
9+
10+
# Client class used for Durable Task Scheduler (DTS)
11+
class DurableTaskSchedulerClient(TaskHubGrpcClient):
12+
def __init__(self, *,
13+
host_address: str,
14+
taskhub: str,
15+
secure_channel: Optional[bool] = True,
16+
metadata: Optional[list[tuple[str, str]]] = None,
17+
use_managed_identity: Optional[bool] = False,
18+
client_id: Optional[str] = None):
19+
20+
if taskhub == None:
21+
raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")
22+
23+
# Ensure metadata is a list
24+
metadata = metadata or []
25+
self._metadata = metadata.copy() # Use a copy to avoid modifying original
26+
27+
# Append DurableTask-specific metadata
28+
self._metadata.append(("taskhub", taskhub))
29+
self._metadata.append(("dts", "True"))
30+
self._metadata.append(("use_managed_identity", str(use_managed_identity)))
31+
self._metadata.append(("client_id", str(client_id or "None")))
32+
33+
self._access_token_manager = AccessTokenManager(use_managed_identity=use_managed_identity,
34+
client_id=client_id)
35+
token = self._access_token_manager.get_access_token()
36+
self._metadata.append(("authorization", token))
37+
38+
self._interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)]
39+
40+
# We pass in None for the metadata so we don't construct an additional interceptor in the parent class
41+
# Since the parent class doesn't use anything metadata for anything else, we can set it as None
42+
super().__init__(
43+
host_address=host_address,
44+
secure_channel=secure_channel,
45+
metadata=None,
46+
interceptors=self._interceptors)
Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,42 @@
1-
# Copyright (c) Microsoft Corporation.
2-
# Licensed under the MIT License.
3-
4-
from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl
5-
from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager
6-
7-
import grpc
8-
9-
class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl):
10-
"""The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
11-
StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an
12-
interceptor to add additional headers to all calls as needed."""
13-
14-
def __init__(self, metadata: list[tuple[str, str]]):
15-
super().__init__(metadata)
16-
self._token_manager = AccessTokenManager(metadata=self._metadata)
17-
18-
def _intercept_call(
19-
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
20-
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
21-
call details."""
22-
# Refresh the auth token if it is present and needed
23-
if self._metadata is not None:
24-
for i, (key, _) in enumerate(self._metadata):
25-
if key.lower() == "authorization": # Ensure case-insensitive comparison
26-
new_token = self._token_manager.get_access_token() # Get the new token
27-
self._metadata[i] = ("authorization", new_token) # Update the token
28-
29-
return super()._intercept_call(client_call_details)
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl
5+
from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager
6+
7+
import grpc
8+
9+
class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl):
10+
"""The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
11+
StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an
12+
interceptor to add additional headers to all calls as needed."""
13+
14+
def __init__(self, metadata: list[tuple[str, str]]):
15+
super().__init__(metadata)
16+
17+
use_managed_identity = False
18+
client_id = None
19+
20+
# Check what authentication we are using
21+
if metadata:
22+
for key, value in metadata:
23+
if key.lower() == "use_managed_identity":
24+
self.use_managed_identity = value.strip().lower() == "true" # Convert to boolean
25+
elif key.lower() == "client_id":
26+
self.client_id = value
27+
28+
self._token_manager = AccessTokenManager(use_managed_identity=use_managed_identity,
29+
client_id=client_id)
30+
31+
def _intercept_call(
32+
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
33+
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
34+
call details."""
35+
# Refresh the auth token if it is present and needed
36+
if self._metadata is not None:
37+
for i, (key, _) in enumerate(self._metadata):
38+
if key.lower() == "authorization": # Ensure case-insensitive comparison
39+
new_token = self._token_manager.get_access_token() # Get the new token
40+
self._metadata[i] = ("authorization", new_token) # Update the token
41+
42+
return super()._intercept_call(client_call_details)
Lines changed: 50 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,51 @@
1-
# Copyright (c) Microsoft Corporation.
2-
# Licensed under the MIT License.
3-
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
4-
from datetime import datetime, timedelta, timezone
5-
from typing import Optional
6-
import durabletask.internal.shared as shared
7-
8-
# By default, when there's 10minutes left before the token expires, refresh the token
9-
class AccessTokenManager:
10-
def __init__(self, refresh_interval_seconds: int = 600, metadata: Optional[list[tuple[str, str]]] = None):
11-
self.scope = "https://durabletask.io/.default"
12-
self.refresh_interval_seconds = refresh_interval_seconds
13-
self._use_managed_identity = False
14-
self._metadata = metadata
15-
self._client_id = None
16-
self._logger = shared.get_logger("token_manager")
17-
18-
if metadata: # Ensure metadata is not None
19-
for key, value in metadata:
20-
if key == "use_managed_identity":
21-
self._use_managed_identity = value.lower() == "true" # Properly convert string to bool
22-
elif key == "client_id":
23-
self._client_id = value # Directly assign string
24-
25-
# Choose the appropriate credential based on use_managed_identity
26-
if self._use_managed_identity:
27-
if not self._client_id:
28-
self._logger.debug("Using System Assigned Managed Identity for authentication.")
29-
self.credential = ManagedIdentityCredential()
30-
else:
31-
self._logger.debug("Using User Assigned Managed Identity for authentication.")
32-
self.credential = ManagedIdentityCredential(client_id=self._client_id)
33-
else:
34-
self.credential = DefaultAzureCredential()
35-
self._logger.debug("Using Default Azure Credentials for authentication.")
36-
37-
self.token = None
38-
self.expiry_time = None
39-
40-
def get_access_token(self) -> str:
41-
if self.token is None or self.is_token_expired():
42-
self.refresh_token()
43-
return self.token
44-
45-
# Checks if the token is expired, or if it will expire in the next "refresh_interval_seconds" seconds.
46-
# For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes,
47-
# We will grab a new token when there're 30minutes left on the lifespan of the token
48-
def is_token_expired(self) -> bool:
49-
if self.expiry_time is None:
50-
return True
51-
return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self.refresh_interval_seconds))
52-
53-
def refresh_token(self):
54-
new_token = self.credential.get_token(self.scope)
55-
self.token = f"Bearer {new_token.token}"
56-
57-
# Convert UNIX timestamp to timezone-aware datetime
58-
self.expiry_time = datetime.fromtimestamp(new_token.expires_on, tz=timezone.utc)
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
4+
from datetime import datetime, timedelta, timezone
5+
from typing import Optional
6+
import durabletask.internal.shared as shared
7+
8+
# By default, when there's 10minutes left before the token expires, refresh the token
9+
class AccessTokenManager:
10+
def __init__(self, refresh_interval_seconds: int = 600, use_managed_identity: bool = False, client_id: str = None):
11+
self.scope = "https://durabletask.io/.default"
12+
self.refresh_interval_seconds = refresh_interval_seconds
13+
self._use_managed_identity = use_managed_identity
14+
self._client_id = client_id
15+
self._logger = shared.get_logger("token_manager")
16+
17+
# Choose the appropriate credential based on use_managed_identity
18+
if self._use_managed_identity:
19+
if not self._client_id:
20+
self._logger.debug("Using System Assigned Managed Identity for authentication.")
21+
self.credential = ManagedIdentityCredential()
22+
else:
23+
self._logger.debug("Using User Assigned Managed Identity for authentication.")
24+
self.credential = ManagedIdentityCredential(client_id=self._client_id)
25+
else:
26+
self.credential = DefaultAzureCredential()
27+
self._logger.debug("Using Default Azure Credentials for authentication.")
28+
29+
self.token = None
30+
self.expiry_time = None
31+
32+
def get_access_token(self) -> str:
33+
if self.token is None or self.is_token_expired():
34+
self.refresh_token()
35+
return self.token
36+
37+
# Checks if the token is expired, or if it will expire in the next "refresh_interval_seconds" seconds.
38+
# For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes,
39+
# We will grab a new token when there're 30minutes left on the lifespan of the token
40+
def is_token_expired(self) -> bool:
41+
if self.expiry_time is None:
42+
return True
43+
return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self.refresh_interval_seconds))
44+
45+
def refresh_token(self):
46+
new_token = self.credential.get_token(self.scope)
47+
self.token = f"Bearer {new_token.token}"
48+
49+
# Convert UNIX timestamp to timezone-aware datetime
50+
self.expiry_time = datetime.fromtimestamp(new_token.expires_on, tz=timezone.utc)
5951
self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}")

0 commit comments

Comments
 (0)