2626TInput = TypeVar ("TInput" )
2727TOutput = TypeVar ("TOutput" )
2828
29+ class VersionNotRegisteredException (Exception ):
30+ pass
2931
3032def _log_all_threads (logger : logging .Logger , context : str = "" ):
3133 """Helper function to log all currently active threads for debugging."""
@@ -88,30 +90,58 @@ def __init__(
8890
8991class _Registry :
9092 orchestrators : dict [str , task .Orchestrator ]
93+ versioned_orchestrators : dict [str , dict [str , task .Orchestrator ]]
94+ latest_versioned_orchestrators_version_name : dict [str , str ]
9195 activities : dict [str , task .Activity ]
9296
9397 def __init__ (self ):
9498 self .orchestrators = {}
99+ self .versioned_orchestrators = {}
100+ self .latest_versioned_orchestrators_version_name = {}
95101 self .activities = {}
96102
97- def add_orchestrator (self , fn : task .Orchestrator ) -> str :
103+ def add_orchestrator (self , fn : task .Orchestrator , version_name : Optional [ str ] = None , is_latest : bool = False ) -> str :
98104 if fn is None :
99105 raise ValueError ("An orchestrator function argument is required." )
100106
101107 name = task .get_name (fn )
102- self .add_named_orchestrator (name , fn )
108+ self .add_named_orchestrator (name , fn , version_name , is_latest )
103109 return name
104110
105- def add_named_orchestrator (self , name : str , fn : task .Orchestrator ) -> None :
111+ def add_named_orchestrator (self , name : str , fn : task .Orchestrator , version_name : Optional [ str ] = None , is_latest : bool = False ) -> None :
106112 if not name :
107113 raise ValueError ("A non-empty orchestrator name is required." )
114+
115+ if version_name is None :
116+ if name in self .orchestrators :
117+ raise ValueError (f"A '{ name } ' orchestrator already exists." )
118+ self .orchestrators [name ] = fn
119+ else :
120+ if name not in self .versioned_orchestrators :
121+ self .versioned_orchestrators [name ] = {}
122+ if version_name in self .versioned_orchestrators [name ]:
123+ raise ValueError (f"The version '{ version_name } ' of '{ name } ' orchestrator already exists." )
124+ self .versioned_orchestrators [name ][version_name ] = fn
125+ if is_latest :
126+ self .latest_versioned_orchestrators_version_name [name ] = version_name
127+
128+ def get_orchestrator (self , name : str , version_name : Optional [str ] = None ) -> Optional [tuple [task .Orchestrator , str ]]:
108129 if name in self .orchestrators :
109- raise ValueError (f"A '{ name } ' orchestrator already exists." )
130+ return self .orchestrators .get (name ), None
131+
132+ if name in self .versioned_orchestrators :
133+ if version_name :
134+ version_to_use = version_name
135+ elif name in self .latest_versioned_orchestrators_version_name :
136+ version_to_use = self .latest_versioned_orchestrators_version_name [name ]
137+ else :
138+ return None , None
110139
111- self .orchestrators [name ] = fn
140+ if version_to_use not in self .versioned_orchestrators [name ]:
141+ raise VersionNotRegisteredException
142+ return self .versioned_orchestrators [name ].get (version_to_use ), version_to_use
112143
113- def get_orchestrator (self , name : str ) -> Optional [task .Orchestrator ]:
114- return self .orchestrators .get (name )
144+ return None , None
115145
116146 def add_activity (self , fn : task .Activity ) -> str :
117147 if fn is None :
@@ -721,11 +751,22 @@ def _execute_orchestrator(
721751 try :
722752 executor = _OrchestrationExecutor (self ._registry , self ._logger )
723753 result = executor .execute (req .instanceId , req .pastEvents , req .newEvents )
754+
755+ version = None
756+ if result .version_name :
757+ version = version or pb .OrchestrationVersion ()
758+ version .name = result .version_name
759+ if result .patches :
760+ version = version or pb .OrchestrationVersion ()
761+ version .patches .extend (result .patches )
762+
763+
724764 res = pb .OrchestratorResponse (
725765 instanceId = req .instanceId ,
726766 actions = result .actions ,
727767 customStatus = ph .get_string_value (result .encoded_custom_status ),
728768 completionToken = completionToken ,
769+ version = version ,
729770 )
730771 except Exception as ex :
731772 self ._logger .exception (
@@ -810,6 +851,11 @@ def __init__(self, instance_id: str):
810851 self ._new_input : Optional [Any ] = None
811852 self ._save_events = False
812853 self ._encoded_custom_status : Optional [str ] = None
854+ self ._orchestrator_version_name : Optional [str ] = None
855+ self ._version_name : Optional [str ] = None
856+ self ._history_patches : dict [str , bool ] = {}
857+ self ._applied_patches : dict [str , bool ] = {}
858+ self ._encountered_patches : list [str ] = []
813859
814860 def run (self , generator : Generator [task .Task , Any , Any ]):
815861 self ._generator = generator
@@ -886,6 +932,14 @@ def set_failed(self, ex: Exception):
886932 )
887933 self ._pending_actions [action .id ] = action
888934
935+
936+ def set_version_not_registered (self ):
937+ self ._pending_actions .clear ()
938+ self ._completion_status = pb .ORCHESTRATION_STATUS_STALLED
939+ action = ph .new_orchestrator_version_not_available_action (self .next_sequence_number ())
940+ self ._pending_actions [action .id ] = action
941+
942+
889943 def set_continued_as_new (self , new_input : Any , save_events : bool ):
890944 if self ._is_complete :
891945 return
@@ -1097,13 +1151,38 @@ def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
10971151 self .set_continued_as_new (new_input , save_events )
10981152
10991153
1154+ def is_patched (self , patch_name : str ) -> bool :
1155+ is_patched = self ._is_patched (patch_name )
1156+ if is_patched :
1157+ self ._encountered_patches .append (patch_name )
1158+ return is_patched
1159+
1160+ def _is_patched (self , patch_name : str ) -> bool :
1161+ if patch_name in self ._applied_patches :
1162+ return self ._applied_patches [patch_name ]
1163+ if patch_name in self ._history_patches :
1164+ self ._applied_patches [patch_name ] = True
1165+ return True
1166+
1167+ if self ._is_replaying :
1168+ self ._applied_patches [patch_name ] = False
1169+ return False
1170+
1171+ self ._applied_patches [patch_name ] = True
1172+ return True
1173+
1174+
11001175class ExecutionResults :
11011176 actions : list [pb .OrchestratorAction ]
11021177 encoded_custom_status : Optional [str ]
1178+ version_name : Optional [str ]
1179+ patches : Optional [list [str ]]
11031180
1104- def __init__ (self , actions : list [pb .OrchestratorAction ], encoded_custom_status : Optional [str ]):
1181+ def __init__ (self , actions : list [pb .OrchestratorAction ], encoded_custom_status : Optional [str ], version_name : Optional [ str ] = None , patches : Optional [ list [ str ]] = None ):
11051182 self .actions = actions
11061183 self .encoded_custom_status = encoded_custom_status
1184+ self .version_name = version_name
1185+ self .patches = patches
11071186
11081187
11091188class _OrchestrationExecutor :
@@ -1146,6 +1225,8 @@ def execute(
11461225 for new_event in new_events :
11471226 self .process_event (ctx , new_event )
11481227
1228+ except VersionNotRegisteredException :
1229+ ctx .set_version_not_registered ()
11491230 except Exception as ex :
11501231 # Unhandled exceptions fail the orchestration
11511232 ctx .set_failed (ex )
@@ -1170,7 +1251,12 @@ def execute(
11701251 self ._logger .debug (
11711252 f"{ instance_id } : Returning { len (actions )} action(s): { _get_action_summary (actions )} "
11721253 )
1173- return ExecutionResults (actions = actions , encoded_custom_status = ctx ._encoded_custom_status )
1254+ return ExecutionResults (
1255+ actions = actions ,
1256+ encoded_custom_status = ctx ._encoded_custom_status ,
1257+ version_name = getattr (ctx , '_version_name' , None ),
1258+ patches = ctx ._encountered_patches
1259+ )
11741260
11751261 def process_event (self , ctx : _RuntimeOrchestrationContext , event : pb .HistoryEvent ) -> None :
11761262 if self ._is_suspended and _is_suspendable (event ):
@@ -1182,19 +1268,33 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
11821268 try :
11831269 if event .HasField ("orchestratorStarted" ):
11841270 ctx .current_utc_datetime = event .timestamp .ToDatetime ()
1271+ if event .orchestratorStarted .version :
1272+ if event .orchestratorStarted .version .name :
1273+ ctx ._orchestrator_version_name = event .orchestratorStarted .version .name
1274+ for patch in event .orchestratorStarted .version .patches :
1275+ ctx ._history_patches [patch ] = True
11851276 elif event .HasField ("executionStarted" ):
11861277 if event .router .targetAppID :
11871278 ctx ._app_id = event .router .targetAppID
11881279 else :
11891280 ctx ._app_id = event .router .sourceAppID
11901281
1282+ version_name = None
1283+ if ctx ._orchestrator_version_name :
1284+ version_name = ctx ._orchestrator_version_name
1285+
1286+
11911287 # TODO: Check if we already started the orchestration
1192- fn = self ._registry .get_orchestrator (event .executionStarted .name )
1288+ fn , version_used = self ._registry .get_orchestrator (event .executionStarted .name , version_name = version_name )
1289+
11931290 if fn is None :
11941291 raise OrchestratorNotRegisteredError (
11951292 f"A '{ event .executionStarted .name } ' orchestrator was not registered."
11961293 )
11971294
1295+ if version_used is not None :
1296+ ctx ._version_name = version_used
1297+
11981298 # deserialize the input, if any
11991299 input = None
12001300 if (
@@ -1461,6 +1561,9 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
14611561 pb .ORCHESTRATION_STATUS_TERMINATED ,
14621562 is_result_encoded = True ,
14631563 )
1564+ elif event .HasField ("executionStalled" ):
1565+ # Nothing to do
1566+ pass
14641567 else :
14651568 eventType = event .WhichOneof ("eventType" )
14661569 raise task .OrchestrationStateError (
0 commit comments