Skip to content

Commit 92af27e

Browse files
authored
Merge pull request #40 from acroca/workflow-versioning
Workflow versioning
2 parents 6554a4f + c53b244 commit 92af27e

File tree

9 files changed

+615
-260
lines changed

9 files changed

+615
-260
lines changed

durabletask/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class OrchestrationStatus(Enum):
3333
PENDING = pb.ORCHESTRATION_STATUS_PENDING
3434
SUSPENDED = pb.ORCHESTRATION_STATUS_SUSPENDED
3535
CANCELED = pb.ORCHESTRATION_STATUS_CANCELED
36+
STALLED = pb.ORCHESTRATION_STATUS_STALLED
3637

3738
def __str__(self):
3839
return helpers.get_orchestration_status_str(self.value)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
4b86756497d875b97f9a91051781b5711c1e4fa6
1+
889781bbe90e6ec84ebe169978c4f2fd0df74ff0

durabletask/internal/helpers.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,15 @@ def new_complete_orchestration_action(
188188
)
189189

190190

191+
def new_orchestrator_version_not_available_action(
192+
id: int,
193+
) -> pb.OrchestratorAction:
194+
return pb.OrchestratorAction(
195+
id=id,
196+
orchestratorVersionNotAvailable=pb.OrchestratorVersionNotAvailableAction(),
197+
)
198+
199+
191200
def new_create_timer_action(id: int, fire_at: datetime) -> pb.OrchestratorAction:
192201
timestamp = timestamp_pb2.Timestamp()
193202
timestamp.FromDatetime(fire_at)

durabletask/internal/orchestrator_service_pb2.py

Lines changed: 245 additions & 227 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

durabletask/internal/orchestrator_service_pb2.pyi

Lines changed: 99 additions & 18 deletions
Large diffs are not rendered by default.

durabletask/internal/orchestrator_service_pb2_grpc.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,16 @@ def __init__(self, channel):
170170
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RerunWorkflowFromEventRequest.SerializeToString,
171171
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RerunWorkflowFromEventResponse.FromString,
172172
_registered_method=True)
173+
self.ListInstanceIDs = channel.unary_unary(
174+
'/TaskHubSidecarService/ListInstanceIDs',
175+
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIDsRequest.SerializeToString,
176+
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIDsResponse.FromString,
177+
_registered_method=True)
178+
self.GetInstanceHistory = channel.unary_unary(
179+
'/TaskHubSidecarService/GetInstanceHistory',
180+
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceHistoryRequest.SerializeToString,
181+
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceHistoryResponse.FromString,
182+
_registered_method=True)
173183

174184

175185
class TaskHubSidecarServiceServicer(object):
@@ -360,6 +370,18 @@ def RerunWorkflowFromEvent(self, request, context):
360370
context.set_details('Method not implemented!')
361371
raise NotImplementedError('Method not implemented!')
362372

373+
def ListInstanceIDs(self, request, context):
374+
"""Missing associated documentation comment in .proto file."""
375+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
376+
context.set_details('Method not implemented!')
377+
raise NotImplementedError('Method not implemented!')
378+
379+
def GetInstanceHistory(self, request, context):
380+
"""Missing associated documentation comment in .proto file."""
381+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
382+
context.set_details('Method not implemented!')
383+
raise NotImplementedError('Method not implemented!')
384+
363385

364386
def add_TaskHubSidecarServiceServicer_to_server(servicer, server):
365387
rpc_method_handlers = {
@@ -498,6 +520,16 @@ def add_TaskHubSidecarServiceServicer_to_server(servicer, server):
498520
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RerunWorkflowFromEventRequest.FromString,
499521
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RerunWorkflowFromEventResponse.SerializeToString,
500522
),
523+
'ListInstanceIDs': grpc.unary_unary_rpc_method_handler(
524+
servicer.ListInstanceIDs,
525+
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIDsRequest.FromString,
526+
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIDsResponse.SerializeToString,
527+
),
528+
'GetInstanceHistory': grpc.unary_unary_rpc_method_handler(
529+
servicer.GetInstanceHistory,
530+
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceHistoryRequest.FromString,
531+
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceHistoryResponse.SerializeToString,
532+
),
501533
}
502534
generic_handler = grpc.method_handlers_generic_handler(
503535
'TaskHubSidecarService', rpc_method_handlers)
@@ -1237,3 +1269,57 @@ def RerunWorkflowFromEvent(request,
12371269
timeout,
12381270
metadata,
12391271
_registered_method=True)
1272+
1273+
@staticmethod
1274+
def ListInstanceIDs(request,
1275+
target,
1276+
options=(),
1277+
channel_credentials=None,
1278+
call_credentials=None,
1279+
insecure=False,
1280+
compression=None,
1281+
wait_for_ready=None,
1282+
timeout=None,
1283+
metadata=None):
1284+
return grpc.experimental.unary_unary(
1285+
request,
1286+
target,
1287+
'/TaskHubSidecarService/ListInstanceIDs',
1288+
durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIDsRequest.SerializeToString,
1289+
durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIDsResponse.FromString,
1290+
options,
1291+
channel_credentials,
1292+
insecure,
1293+
call_credentials,
1294+
compression,
1295+
wait_for_ready,
1296+
timeout,
1297+
metadata,
1298+
_registered_method=True)
1299+
1300+
@staticmethod
1301+
def GetInstanceHistory(request,
1302+
target,
1303+
options=(),
1304+
channel_credentials=None,
1305+
call_credentials=None,
1306+
insecure=False,
1307+
compression=None,
1308+
wait_for_ready=None,
1309+
timeout=None,
1310+
metadata=None):
1311+
return grpc.experimental.unary_unary(
1312+
request,
1313+
target,
1314+
'/TaskHubSidecarService/GetInstanceHistory',
1315+
durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceHistoryRequest.SerializeToString,
1316+
durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceHistoryResponse.FromString,
1317+
options,
1318+
channel_credentials,
1319+
insecure,
1320+
call_credentials,
1321+
compression,
1322+
wait_for_ready,
1323+
timeout,
1324+
metadata,
1325+
_registered_method=True)

durabletask/task.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,22 @@ def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None:
189189
"""
190190
pass
191191

192+
@abstractmethod
193+
def is_patched(self, patch_name: str) -> bool:
194+
"""Check if the given patch name can be applied to the orchestration.
195+
196+
Parameters
197+
----------
198+
patch_name : str
199+
The name of the patch to check.
200+
201+
Returns
202+
-------
203+
bool
204+
True if the given patch name can be applied to the orchestration, False otherwise.
205+
"""
206+
pass
207+
192208

193209
class FailureDetails:
194210
def __init__(self, message: str, error_type: str, stack_trace: Optional[str]):

durabletask/worker.py

Lines changed: 113 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
TInput = TypeVar("TInput")
2727
TOutput = TypeVar("TOutput")
2828

29+
class VersionNotRegisteredException(Exception):
30+
pass
2931

3032
def _log_all_threads(logger: logging.Logger, context: str = ""):
3133
"""Helper function to log all currently active threads for debugging."""
@@ -88,30 +90,58 @@ def __init__(
8890

8991
class _Registry:
9092
orchestrators: dict[str, task.Orchestrator]
93+
versioned_orchestrators: dict[str, dict[str, task.Orchestrator]]
94+
latest_versioned_orchestrators_version_name: dict[str, str]
9195
activities: dict[str, task.Activity]
9296

9397
def __init__(self):
9498
self.orchestrators = {}
99+
self.versioned_orchestrators = {}
100+
self.latest_versioned_orchestrators_version_name = {}
95101
self.activities = {}
96102

97-
def add_orchestrator(self, fn: task.Orchestrator) -> str:
103+
def add_orchestrator(self, fn: task.Orchestrator, version_name: Optional[str] = None, is_latest: bool = False) -> str:
98104
if fn is None:
99105
raise ValueError("An orchestrator function argument is required.")
100106

101107
name = task.get_name(fn)
102-
self.add_named_orchestrator(name, fn)
108+
self.add_named_orchestrator(name, fn, version_name, is_latest)
103109
return name
104110

105-
def add_named_orchestrator(self, name: str, fn: task.Orchestrator) -> None:
111+
def add_named_orchestrator(self, name: str, fn: task.Orchestrator, version_name: Optional[str] = None, is_latest: bool = False) -> None:
106112
if not name:
107113
raise ValueError("A non-empty orchestrator name is required.")
114+
115+
if version_name is None:
116+
if name in self.orchestrators:
117+
raise ValueError(f"A '{name}' orchestrator already exists.")
118+
self.orchestrators[name] = fn
119+
else:
120+
if name not in self.versioned_orchestrators:
121+
self.versioned_orchestrators[name] = {}
122+
if version_name in self.versioned_orchestrators[name]:
123+
raise ValueError(f"The version '{version_name}' of '{name}' orchestrator already exists.")
124+
self.versioned_orchestrators[name][version_name] = fn
125+
if is_latest:
126+
self.latest_versioned_orchestrators_version_name[name] = version_name
127+
128+
def get_orchestrator(self, name: str, version_name: Optional[str] = None) -> Optional[tuple[task.Orchestrator, str]]:
108129
if name in self.orchestrators:
109-
raise ValueError(f"A '{name}' orchestrator already exists.")
130+
return self.orchestrators.get(name), None
131+
132+
if name in self.versioned_orchestrators:
133+
if version_name:
134+
version_to_use = version_name
135+
elif name in self.latest_versioned_orchestrators_version_name:
136+
version_to_use = self.latest_versioned_orchestrators_version_name[name]
137+
else:
138+
return None, None
110139

111-
self.orchestrators[name] = fn
140+
if version_to_use not in self.versioned_orchestrators[name]:
141+
raise VersionNotRegisteredException
142+
return self.versioned_orchestrators[name].get(version_to_use), version_to_use
112143

113-
def get_orchestrator(self, name: str) -> Optional[task.Orchestrator]:
114-
return self.orchestrators.get(name)
144+
return None, None
115145

116146
def add_activity(self, fn: task.Activity) -> str:
117147
if fn is None:
@@ -721,11 +751,22 @@ def _execute_orchestrator(
721751
try:
722752
executor = _OrchestrationExecutor(self._registry, self._logger)
723753
result = executor.execute(req.instanceId, req.pastEvents, req.newEvents)
754+
755+
version = None
756+
if result.version_name:
757+
version = version or pb.OrchestrationVersion()
758+
version.name = result.version_name
759+
if result.patches:
760+
version = version or pb.OrchestrationVersion()
761+
version.patches.extend(result.patches)
762+
763+
724764
res = pb.OrchestratorResponse(
725765
instanceId=req.instanceId,
726766
actions=result.actions,
727767
customStatus=ph.get_string_value(result.encoded_custom_status),
728768
completionToken=completionToken,
769+
version=version,
729770
)
730771
except Exception as ex:
731772
self._logger.exception(
@@ -810,6 +851,11 @@ def __init__(self, instance_id: str):
810851
self._new_input: Optional[Any] = None
811852
self._save_events = False
812853
self._encoded_custom_status: Optional[str] = None
854+
self._orchestrator_version_name: Optional[str] = None
855+
self._version_name: Optional[str] = None
856+
self._history_patches: dict[str, bool] = {}
857+
self._applied_patches: dict[str, bool] = {}
858+
self._encountered_patches: list[str] = []
813859

814860
def run(self, generator: Generator[task.Task, Any, Any]):
815861
self._generator = generator
@@ -886,6 +932,14 @@ def set_failed(self, ex: Exception):
886932
)
887933
self._pending_actions[action.id] = action
888934

935+
936+
def set_version_not_registered(self):
937+
self._pending_actions.clear()
938+
self._completion_status = pb.ORCHESTRATION_STATUS_STALLED
939+
action = ph.new_orchestrator_version_not_available_action(self.next_sequence_number())
940+
self._pending_actions[action.id] = action
941+
942+
889943
def set_continued_as_new(self, new_input: Any, save_events: bool):
890944
if self._is_complete:
891945
return
@@ -1097,13 +1151,38 @@ def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
10971151
self.set_continued_as_new(new_input, save_events)
10981152

10991153

1154+
def is_patched(self, patch_name: str) -> bool:
1155+
is_patched = self._is_patched(patch_name)
1156+
if is_patched:
1157+
self._encountered_patches.append(patch_name)
1158+
return is_patched
1159+
1160+
def _is_patched(self, patch_name: str) -> bool:
1161+
if patch_name in self._applied_patches:
1162+
return self._applied_patches[patch_name]
1163+
if patch_name in self._history_patches:
1164+
self._applied_patches[patch_name] = True
1165+
return True
1166+
1167+
if self._is_replaying:
1168+
self._applied_patches[patch_name] = False
1169+
return False
1170+
1171+
self._applied_patches[patch_name] = True
1172+
return True
1173+
1174+
11001175
class ExecutionResults:
11011176
actions: list[pb.OrchestratorAction]
11021177
encoded_custom_status: Optional[str]
1178+
version_name: Optional[str]
1179+
patches: Optional[list[str]]
11031180

1104-
def __init__(self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]):
1181+
def __init__(self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str], version_name: Optional[str] = None, patches: Optional[list[str]] = None):
11051182
self.actions = actions
11061183
self.encoded_custom_status = encoded_custom_status
1184+
self.version_name = version_name
1185+
self.patches = patches
11071186

11081187

11091188
class _OrchestrationExecutor:
@@ -1146,6 +1225,8 @@ def execute(
11461225
for new_event in new_events:
11471226
self.process_event(ctx, new_event)
11481227

1228+
except VersionNotRegisteredException:
1229+
ctx.set_version_not_registered()
11491230
except Exception as ex:
11501231
# Unhandled exceptions fail the orchestration
11511232
ctx.set_failed(ex)
@@ -1170,7 +1251,12 @@ def execute(
11701251
self._logger.debug(
11711252
f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}"
11721253
)
1173-
return ExecutionResults(actions=actions, encoded_custom_status=ctx._encoded_custom_status)
1254+
return ExecutionResults(
1255+
actions=actions,
1256+
encoded_custom_status=ctx._encoded_custom_status,
1257+
version_name=getattr(ctx, '_version_name', None),
1258+
patches=ctx._encountered_patches
1259+
)
11741260

11751261
def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent) -> None:
11761262
if self._is_suspended and _is_suspendable(event):
@@ -1182,19 +1268,33 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
11821268
try:
11831269
if event.HasField("orchestratorStarted"):
11841270
ctx.current_utc_datetime = event.timestamp.ToDatetime()
1271+
if event.orchestratorStarted.version:
1272+
if event.orchestratorStarted.version.name:
1273+
ctx._orchestrator_version_name = event.orchestratorStarted.version.name
1274+
for patch in event.orchestratorStarted.version.patches:
1275+
ctx._history_patches[patch] = True
11851276
elif event.HasField("executionStarted"):
11861277
if event.router.targetAppID:
11871278
ctx._app_id = event.router.targetAppID
11881279
else:
11891280
ctx._app_id = event.router.sourceAppID
11901281

1282+
version_name = None
1283+
if ctx._orchestrator_version_name:
1284+
version_name = ctx._orchestrator_version_name
1285+
1286+
11911287
# TODO: Check if we already started the orchestration
1192-
fn = self._registry.get_orchestrator(event.executionStarted.name)
1288+
fn, version_used = self._registry.get_orchestrator(event.executionStarted.name, version_name=version_name)
1289+
11931290
if fn is None:
11941291
raise OrchestratorNotRegisteredError(
11951292
f"A '{event.executionStarted.name}' orchestrator was not registered."
11961293
)
11971294

1295+
if version_used is not None:
1296+
ctx._version_name = version_used
1297+
11981298
# deserialize the input, if any
11991299
input = None
12001300
if (
@@ -1461,6 +1561,9 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
14611561
pb.ORCHESTRATION_STATUS_TERMINATED,
14621562
is_result_encoded=True,
14631563
)
1564+
elif event.HasField("executionStalled"):
1565+
# Nothing to do
1566+
pass
14641567
else:
14651568
eventType = event.WhichOneof("eventType")
14661569
raise task.OrchestrationStateError(

0 commit comments

Comments
 (0)