1212from threading import Event , Thread
1313from types import GeneratorType
1414from enum import Enum
15- from typing import Any , Generator , Optional , Sequence , TypeVar , Union
15+ from typing import Any , Generator , Optional , Sequence , Tuple , TypeVar , Union
1616import uuid
1717from packaging .version import InvalidVersion , parse
1818
2525from durabletask .internal .helpers import new_timestamp
2626from durabletask .entities import DurableEntity , EntityLock , EntityInstanceId , EntityContext
2727from durabletask .internal .orchestration_entity_context import OrchestrationEntityContext
28+ from durabletask .internal .proto_task_hub_sidecar_service_stub import ProtoTaskHubSidecarServiceStub
2829import durabletask .internal .helpers as ph
2930import durabletask .internal .exceptions as pe
3031import durabletask .internal .orchestrator_service_pb2 as pb
@@ -680,7 +681,7 @@ def _execute_orchestrator(
680681 def _cancel_orchestrator (
681682 self ,
682683 req : pb .OrchestratorRequest ,
683- stub : stubs .TaskHubSidecarServiceStub ,
684+ stub : Union [ stubs .TaskHubSidecarServiceStub , ProtoTaskHubSidecarServiceStub ] ,
684685 completionToken ,
685686 ):
686687 stub .AbandonTaskOrchestratorWorkItem (
@@ -693,7 +694,7 @@ def _cancel_orchestrator(
693694 def _execute_activity (
694695 self ,
695696 req : pb .ActivityRequest ,
696- stub : stubs .TaskHubSidecarServiceStub ,
697+ stub : Union [ stubs .TaskHubSidecarServiceStub , ProtoTaskHubSidecarServiceStub ] ,
697698 completionToken ,
698699 ):
699700 instance_id = req .orchestrationInstance .instanceId
@@ -726,7 +727,7 @@ def _execute_activity(
726727 def _cancel_activity (
727728 self ,
728729 req : pb .ActivityRequest ,
729- stub : stubs .TaskHubSidecarServiceStub ,
730+ stub : Union [ stubs .TaskHubSidecarServiceStub , ProtoTaskHubSidecarServiceStub ] ,
730731 completionToken ,
731732 ):
732733 stub .AbandonTaskActivityWorkItem (
@@ -754,9 +755,10 @@ def _execute_entity_batch(
754755 for operation in req .operations :
755756 start_time = datetime .now (timezone .utc )
756757 executor = _EntityExecutor (self ._registry , self ._logger )
757- entity_instance_id = EntityInstanceId .parse (instance_id )
758- if not entity_instance_id :
759- raise RuntimeError (f"Invalid entity instance ID '{ operation .requestId } ' in entity operation request." )
758+ try :
759+ entity_instance_id = EntityInstanceId .parse (instance_id )
760+ except ValueError :
761+ raise RuntimeError (f"Invalid entity instance ID '{ instance_id } ' in entity operation request." )
760762
761763 operation_result = None
762764
@@ -808,7 +810,7 @@ def _execute_entity_batch(
808810 def _cancel_entity_batch (
809811 self ,
810812 req : Union [pb .EntityBatchRequest , pb .EntityRequest ],
811- stub : stubs .TaskHubSidecarServiceStub ,
813+ stub : Union [ stubs .TaskHubSidecarServiceStub , ProtoTaskHubSidecarServiceStub ] ,
812814 completionToken ,
813815 ):
814816 stub .AbandonTaskEntityWorkItem (
@@ -831,9 +833,8 @@ def __init__(self, instance_id: str, registry: _Registry):
831833 self ._pending_actions : dict [int , pb .OrchestratorAction ] = {}
832834 self ._pending_tasks : dict [int , task .CompletableTask ] = {}
833835 # Maps entity ID to task ID
834- self ._entity_task_id_map : dict [str , tuple [EntityInstanceId , int , Optional [str ]]] = {}
835- # Maps criticalSectionId to task ID
836- self ._entity_lock_id_map : dict [str , int ] = {}
836+ self ._entity_task_id_map : dict [str , tuple [EntityInstanceId , int ]] = {}
837+ self ._entity_lock_task_id_map : dict [str , tuple [EntityInstanceId , int ]] = {}
837838 self ._sequence_number = 0
838839 self ._new_uuid_counter = 0
839840 self ._current_utc_datetime = datetime (1000 , 1 , 1 )
@@ -1170,12 +1171,7 @@ def call_entity_function_helper(
11701171 raise RuntimeError (error_message )
11711172
11721173 encoded_input = shared .to_json (input ) if input is not None else None
1173- action = ph .new_call_entity_action (id ,
1174- self .instance_id ,
1175- entity_id ,
1176- operation ,
1177- encoded_input ,
1178- self .new_uuid ())
1174+ action = ph .new_call_entity_action (id , self .instance_id , entity_id , operation , encoded_input , self .new_uuid ())
11791175 self ._pending_actions [id ] = action
11801176
11811177 fn_task = task .CompletableTask ()
@@ -1262,14 +1258,14 @@ def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
12621258 self .set_continued_as_new (new_input , save_events )
12631259
12641260 def new_uuid (self ) -> str :
1265- URL_NAMESPACE : str = "9e952958-5e33-4daf-827f-2fa12937b875"
1261+ NAMESPACE_UUID : str = "9e952958-5e33-4daf-827f-2fa12937b875"
12661262
12671263 uuid_name_value = \
12681264 f"{ self ._instance_id } " \
12691265 f"_{ self .current_utc_datetime .strftime (DATETIME_STRING_FORMAT )} " \
12701266 f"_{ self ._new_uuid_counter } "
12711267 self ._new_uuid_counter += 1
1272- namespace_uuid = uuid .uuid5 (uuid .NAMESPACE_OID , URL_NAMESPACE )
1268+ namespace_uuid = uuid .uuid5 (uuid .NAMESPACE_OID , NAMESPACE_UUID )
12731269 return str (uuid .uuid5 (namespace_uuid , uuid_name_value ))
12741270
12751271
@@ -1612,32 +1608,11 @@ def process_event(
16121608 raise TypeError ("Unexpected sub-orchestration task type" )
16131609 elif event .HasField ("eventRaised" ):
16141610 if event .eventRaised .name in ctx ._entity_task_id_map :
1615- # This eventRaised represents the result of an entity operation after being translated to the old
1616- # entity protocol by the Durable WebJobs extension
1617- entity_id , task_id , action_type = ctx ._entity_task_id_map .get (event .eventRaised .name , (None , None , None ))
1618- if entity_id is None :
1619- raise RuntimeError (f"Could not retrieve entity ID for entity-related eventRaised with ID '{ event .eventId } '" )
1620- if task_id is None :
1621- raise RuntimeError (f"Could not retrieve task ID for entity-related eventRaised with ID '{ event .eventId } '" )
1622- entity_task = ctx ._pending_tasks .pop (task_id , None )
1623- if not entity_task :
1624- raise RuntimeError (f"Could not retrieve entity task for entity-related eventRaised with ID '{ event .eventId } '" )
1625- result = None
1626- if not ph .is_empty (event .eventRaised .input ):
1627- # TODO: Investigate why the event result is wrapped in a dict with "result" key
1628- result = shared .from_json (event .eventRaised .input .value )["result" ]
1629- if action_type == "entityOperationCalled" :
1630- ctx ._entity_context .recover_lock_after_call (entity_id )
1631- entity_task .complete (result )
1632- ctx .resume ()
1633- elif action_type == "entityLockRequested" :
1634- ctx ._entity_context .complete_acquire (event .eventRaised .name )
1635- entity_task .complete (EntityLock (ctx ))
1636- ctx .resume ()
1637- else :
1638- raise RuntimeError (f"Unknown action type '{ action_type } ' for entity-related eventRaised "
1639- f"with ID '{ event .eventId } '" )
1640-
1611+ entity_id , task_id = ctx ._entity_task_id_map .get (event .eventRaised .name , (None , None ))
1612+ self ._handle_entity_event_raised (ctx , event , entity_id , task_id , False )
1613+ elif event .eventRaised .name in ctx ._entity_lock_task_id_map :
1614+ entity_id , task_id = ctx ._entity_lock_task_id_map .get (event .eventRaised .name , (None , None ))
1615+ self ._handle_entity_event_raised (ctx , event , entity_id , task_id , True )
16411616 else :
16421617 # event names are case-insensitive
16431618 event_name = event .eventRaised .name .casefold ()
@@ -1705,8 +1680,9 @@ def process_event(
17051680 raise _get_wrong_action_type_error (
17061681 entity_call_id , expected_method_name , action
17071682 )
1708- entity_id = EntityInstanceId .parse (event .entityOperationCalled .targetInstanceId .value )
1709- if not entity_id :
1683+ try :
1684+ entity_id = EntityInstanceId .parse (event .entityOperationCalled .targetInstanceId .value )
1685+ except ValueError :
17101686 raise RuntimeError (f"Could not parse entity ID from targetInstanceId '{ event .entityOperationCalled .targetInstanceId .value } '" )
17111687 ctx ._entity_task_id_map [event .entityOperationCalled .requestId ] = (entity_id , entity_call_id , None )
17121688 elif event .HasField ("entityOperationSignaled" ):
@@ -1802,15 +1778,11 @@ def process_event(
18021778 action = ctx ._pending_actions .pop (event .eventId , None )
18031779 if action and action .HasField ("sendEntityMessage" ):
18041780 if action .sendEntityMessage .HasField ("entityOperationCalled" ):
1805- action_type = "entityOperationCalled"
1781+ entity_id , event_id = self ._parse_entity_event_sent_input (event )
1782+ ctx ._entity_task_id_map [event_id ] = (entity_id , event .eventId )
18061783 elif action .sendEntityMessage .HasField ("entityLockRequested" ):
1807- action_type = "entityLockRequested"
1808- else :
1809- return
1810-
1811- entity_id = EntityInstanceId .parse (event .eventSent .instanceId )
1812- event_id = json .loads (event .eventSent .input .value )["id" ]
1813- ctx ._entity_task_id_map [event_id ] = (entity_id , event .eventId , action_type )
1784+ entity_id , event_id = self ._parse_entity_event_sent_input (event )
1785+ ctx ._entity_lock_task_id_map [event_id ] = (entity_id , event .eventId )
18141786 else :
18151787 eventType = event .WhichOneof ("eventType" )
18161788 raise task .OrchestrationStateError (
@@ -1820,6 +1792,44 @@ def process_event(
18201792 # The orchestrator generator function completed
18211793 ctx .set_complete (generatorStopped .value , pb .ORCHESTRATION_STATUS_COMPLETED )
18221794
1795+ def _parse_entity_event_sent_input (self , event : pb .HistoryEvent ) -> Tuple [EntityInstanceId , str ]:
1796+ try :
1797+ entity_id = EntityInstanceId .parse (event .eventSent .instanceId )
1798+ except ValueError :
1799+ raise RuntimeError (f"Could not parse entity ID from instanceId '{ event .eventSent .instanceId } '" )
1800+ try :
1801+ event_id = json .loads (event .eventSent .input .value )["id" ]
1802+ except (json .JSONDecodeError , KeyError , TypeError ) as ex :
1803+ raise RuntimeError (f"Could not parse event ID from eventSent input '{ event .eventSent .input .value } '" ) from ex
1804+ return entity_id , event_id
1805+
1806+ def _handle_entity_event_raised (self ,
1807+ ctx : _RuntimeOrchestrationContext ,
1808+ event : pb .HistoryEvent ,
1809+ entity_id : Optional [EntityInstanceId ],
1810+ task_id : Optional [int ],
1811+ is_lock_event : bool ):
1812+ # This eventRaised represents the result of an entity operation after being translated to the old
1813+ # entity protocol by the Durable WebJobs extension
1814+ if entity_id is None :
1815+ raise RuntimeError (f"Could not retrieve entity ID for entity-related eventRaised with ID '{ event .eventId } '" )
1816+ if task_id is None :
1817+ raise RuntimeError (f"Could not retrieve task ID for entity-related eventRaised with ID '{ event .eventId } '" )
1818+ entity_task = ctx ._pending_tasks .pop (task_id , None )
1819+ if not entity_task :
1820+ raise RuntimeError (f"Could not retrieve entity task for entity-related eventRaised with ID '{ event .eventId } '" )
1821+ result = None
1822+ if not ph .is_empty (event .eventRaised .input ):
1823+ # TODO: Investigate why the event result is wrapped in a dict with "result" key
1824+ result = shared .from_json (event .eventRaised .input .value )["result" ]
1825+ if is_lock_event :
1826+ ctx ._entity_context .complete_acquire (event .eventRaised .name )
1827+ entity_task .complete (EntityLock (ctx ))
1828+ else :
1829+ ctx ._entity_context .recover_lock_after_call (entity_id )
1830+ entity_task .complete (result )
1831+ ctx .resume ()
1832+
18231833 def evaluate_orchestration_versioning (self , versioning : Optional [VersioningOptions ], orchestration_version : Optional [str ]) -> Optional [pb .TaskFailureDetails ]:
18241834 if versioning is None :
18251835 return None
0 commit comments