Skip to content

Commit 5b9644d

Browse files
committed
Merge branch 'main' into andystaples/add-functions-support
2 parents fde02c5 + 3eaf42c commit 5b9644d

File tree

11 files changed

+271
-72
lines changed

11 files changed

+271
-72
lines changed

durabletask/entities/entity_instance_id.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from typing import Optional
2-
3-
41
class EntityInstanceId:
52
def __init__(self, entity: str, key: str):
63
self.entity = entity
@@ -30,8 +27,13 @@ def parse(entity_id: str) -> "EntityInstanceId":
3027
3128
Returns
3229
-------
33-
Optional[EntityInstanceId]
34-
The parsed EntityInstanceId object, or None if the input is None.
30+
EntityInstanceId
31+
The parsed EntityInstanceId object.
32+
33+
Raises
34+
------
35+
ValueError
36+
If the input string is not in the correct format.
3537
"""
3638
try:
3739
_, entity, key = entity_id.split("@", 2)

durabletask/entities/entity_metadata.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@ def __init__(self,
4444

4545
@staticmethod
4646
def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool):
47-
entity_id = EntityInstanceId.parse(entity_response.entity.instanceId)
48-
if not entity_id:
47+
try:
48+
entity_id = EntityInstanceId.parse(entity_response.entity.instanceId)
49+
except ValueError:
4950
raise ValueError("Invalid entity instance ID in entity response.")
5051
entity_state = None
5152
if includes_state:

durabletask/internal/helpers.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ def new_orchestrator_started_event(timestamp: Optional[datetime] = None) -> pb.H
2020
return pb.HistoryEvent(eventId=-1, timestamp=ts, orchestratorStarted=pb.OrchestratorStartedEvent())
2121

2222

23+
def new_orchestrator_completed_event() -> pb.HistoryEvent:
24+
return pb.HistoryEvent(eventId=-1, timestamp=timestamp_pb2.Timestamp(),
25+
orchestratorCompleted=pb.OrchestratorCompletedEvent())
26+
27+
2328
def new_execution_started_event(name: str, instance_id: str, encoded_input: Optional[str] = None,
2429
tags: Optional[dict[str, str]] = None) -> pb.HistoryEvent:
2530
return pb.HistoryEvent(
@@ -119,6 +124,18 @@ def new_failure_details(ex: Exception) -> pb.TaskFailureDetails:
119124
)
120125

121126

127+
def new_event_sent_event(event_id: int, instance_id: str, input: str):
128+
return pb.HistoryEvent(
129+
eventId=event_id,
130+
timestamp=timestamp_pb2.Timestamp(),
131+
eventSent=pb.EventSentEvent(
132+
name="",
133+
input=get_string_value(input),
134+
instanceId=instance_id
135+
)
136+
)
137+
138+
122139
def new_event_raised_event(name: str, encoded_input: Optional[str] = None) -> pb.HistoryEvent:
123140
return pb.HistoryEvent(
124141
eventId=-1,
@@ -199,8 +216,9 @@ def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str],
199216
def new_call_entity_action(id: int,
200217
parent_instance_id: str,
201218
entity_id: EntityInstanceId,
202-
operation: str, encoded_input: Optional[str],
203-
request_id: str):
219+
operation: str,
220+
encoded_input: Optional[str],
221+
request_id: str) -> pb.OrchestratorAction:
204222
return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationCalled=pb.EntityOperationCalledEvent(
205223
requestId=request_id,
206224
operation=operation,
@@ -216,7 +234,7 @@ def new_signal_entity_action(id: int,
216234
entity_id: EntityInstanceId,
217235
operation: str,
218236
encoded_input: Optional[str],
219-
request_id: str):
237+
request_id: str) -> pb.OrchestratorAction:
220238
return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationSignaled=pb.EntityOperationSignaledEvent(
221239
requestId=request_id,
222240
operation=operation,

durabletask/internal/proto_task_hub_sidecar_service_stub.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33

44
class ProtoTaskHubSidecarServiceStub(Protocol):
5-
"""A stub class roughly matching the TaskHubSidecarServiceStub generated from the .proto file.
6-
Used by Azure Functions during orchestration and entity executions to inject custom behavior,
7-
as no real sidecar stub is available.
5+
"""A stub class matching the TaskHubSidecarServiceStub generated from the .proto file.
6+
Allows the use of TaskHubGrpcWorker methods when a real sidecar stub is not available.
87
"""
98
Hello: Callable[..., Any]
109
StartInstance: Callable[..., Any]

durabletask/py.typed

Whitespace-only changes.

durabletask/task.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
139139
pass
140140

141141
@abstractmethod
142-
def call_entity(self, entity: EntityInstanceId,
142+
def call_entity(self,
143+
entity: EntityInstanceId,
143144
operation: str,
144145
input: Optional[TInput] = None) -> Task:
145146
"""Schedule entity function for execution.
@@ -264,8 +265,8 @@ def new_uuid(self) -> str:
264265
265266
The default implementation of this method creates a name-based UUID
266267
using the algorithm from RFC 4122 §4.3. The name input used to generate
267-
this value is a combination of the orchestration instance ID and an
268-
internally managed sequence number.
268+
this value is a combination of the orchestration instance ID, the current UTC datetime,
269+
and an internally managed counter.
269270
270271
Returns
271272
-------

durabletask/worker.py

Lines changed: 65 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from threading import Event, Thread
1313
from types import GeneratorType
1414
from enum import Enum
15-
from typing import Any, Generator, Optional, Sequence, TypeVar, Union
15+
from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union
1616
import uuid
1717
from packaging.version import InvalidVersion, parse
1818

@@ -25,6 +25,7 @@
2525
from durabletask.internal.helpers import new_timestamp
2626
from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext
2727
from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext
28+
from durabletask.internal.proto_task_hub_sidecar_service_stub import ProtoTaskHubSidecarServiceStub
2829
import durabletask.internal.helpers as ph
2930
import durabletask.internal.exceptions as pe
3031
import durabletask.internal.orchestrator_service_pb2 as pb
@@ -680,7 +681,7 @@ def _execute_orchestrator(
680681
def _cancel_orchestrator(
681682
self,
682683
req: pb.OrchestratorRequest,
683-
stub: stubs.TaskHubSidecarServiceStub,
684+
stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub],
684685
completionToken,
685686
):
686687
stub.AbandonTaskOrchestratorWorkItem(
@@ -693,7 +694,7 @@ def _cancel_orchestrator(
693694
def _execute_activity(
694695
self,
695696
req: pb.ActivityRequest,
696-
stub: stubs.TaskHubSidecarServiceStub,
697+
stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub],
697698
completionToken,
698699
):
699700
instance_id = req.orchestrationInstance.instanceId
@@ -726,7 +727,7 @@ def _execute_activity(
726727
def _cancel_activity(
727728
self,
728729
req: pb.ActivityRequest,
729-
stub: stubs.TaskHubSidecarServiceStub,
730+
stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub],
730731
completionToken,
731732
):
732733
stub.AbandonTaskActivityWorkItem(
@@ -754,9 +755,10 @@ def _execute_entity_batch(
754755
for operation in req.operations:
755756
start_time = datetime.now(timezone.utc)
756757
executor = _EntityExecutor(self._registry, self._logger)
757-
entity_instance_id = EntityInstanceId.parse(instance_id)
758-
if not entity_instance_id:
759-
raise RuntimeError(f"Invalid entity instance ID '{operation.requestId}' in entity operation request.")
758+
try:
759+
entity_instance_id = EntityInstanceId.parse(instance_id)
760+
except ValueError:
761+
raise RuntimeError(f"Invalid entity instance ID '{instance_id}' in entity operation request.")
760762

761763
operation_result = None
762764

@@ -808,7 +810,7 @@ def _execute_entity_batch(
808810
def _cancel_entity_batch(
809811
self,
810812
req: Union[pb.EntityBatchRequest, pb.EntityRequest],
811-
stub: stubs.TaskHubSidecarServiceStub,
813+
stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub],
812814
completionToken,
813815
):
814816
stub.AbandonTaskEntityWorkItem(
@@ -831,9 +833,8 @@ def __init__(self, instance_id: str, registry: _Registry):
831833
self._pending_actions: dict[int, pb.OrchestratorAction] = {}
832834
self._pending_tasks: dict[int, task.CompletableTask] = {}
833835
# Maps entity ID to task ID
834-
self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int, Optional[str]]] = {}
835-
# Maps criticalSectionId to task ID
836-
self._entity_lock_id_map: dict[str, int] = {}
836+
self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
837+
self._entity_lock_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
837838
self._sequence_number = 0
838839
self._new_uuid_counter = 0
839840
self._current_utc_datetime = datetime(1000, 1, 1)
@@ -1170,12 +1171,7 @@ def call_entity_function_helper(
11701171
raise RuntimeError(error_message)
11711172

11721173
encoded_input = shared.to_json(input) if input is not None else None
1173-
action = ph.new_call_entity_action(id,
1174-
self.instance_id,
1175-
entity_id,
1176-
operation,
1177-
encoded_input,
1178-
self.new_uuid())
1174+
action = ph.new_call_entity_action(id, self.instance_id, entity_id, operation, encoded_input, self.new_uuid())
11791175
self._pending_actions[id] = action
11801176

11811177
fn_task = task.CompletableTask()
@@ -1262,14 +1258,14 @@ def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
12621258
self.set_continued_as_new(new_input, save_events)
12631259

12641260
def new_uuid(self) -> str:
1265-
URL_NAMESPACE: str = "9e952958-5e33-4daf-827f-2fa12937b875"
1261+
NAMESPACE_UUID: str = "9e952958-5e33-4daf-827f-2fa12937b875"
12661262

12671263
uuid_name_value = \
12681264
f"{self._instance_id}" \
12691265
f"_{self.current_utc_datetime.strftime(DATETIME_STRING_FORMAT)}" \
12701266
f"_{self._new_uuid_counter}"
12711267
self._new_uuid_counter += 1
1272-
namespace_uuid = uuid.uuid5(uuid.NAMESPACE_OID, URL_NAMESPACE)
1268+
namespace_uuid = uuid.uuid5(uuid.NAMESPACE_OID, NAMESPACE_UUID)
12731269
return str(uuid.uuid5(namespace_uuid, uuid_name_value))
12741270

12751271

@@ -1612,32 +1608,11 @@ def process_event(
16121608
raise TypeError("Unexpected sub-orchestration task type")
16131609
elif event.HasField("eventRaised"):
16141610
if event.eventRaised.name in ctx._entity_task_id_map:
1615-
# This eventRaised represents the result of an entity operation after being translated to the old
1616-
# entity protocol by the Durable WebJobs extension
1617-
entity_id, task_id, action_type = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None, None))
1618-
if entity_id is None:
1619-
raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'")
1620-
if task_id is None:
1621-
raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'")
1622-
entity_task = ctx._pending_tasks.pop(task_id, None)
1623-
if not entity_task:
1624-
raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'")
1625-
result = None
1626-
if not ph.is_empty(event.eventRaised.input):
1627-
# TODO: Investigate why the event result is wrapped in a dict with "result" key
1628-
result = shared.from_json(event.eventRaised.input.value)["result"]
1629-
if action_type == "entityOperationCalled":
1630-
ctx._entity_context.recover_lock_after_call(entity_id)
1631-
entity_task.complete(result)
1632-
ctx.resume()
1633-
elif action_type == "entityLockRequested":
1634-
ctx._entity_context.complete_acquire(event.eventRaised.name)
1635-
entity_task.complete(EntityLock(ctx))
1636-
ctx.resume()
1637-
else:
1638-
raise RuntimeError(f"Unknown action type '{action_type}' for entity-related eventRaised "
1639-
f"with ID '{event.eventId}'")
1640-
1611+
entity_id, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None))
1612+
self._handle_entity_event_raised(ctx, event, entity_id, task_id, False)
1613+
elif event.eventRaised.name in ctx._entity_lock_task_id_map:
1614+
entity_id, task_id = ctx._entity_lock_task_id_map.get(event.eventRaised.name, (None, None))
1615+
self._handle_entity_event_raised(ctx, event, entity_id, task_id, True)
16411616
else:
16421617
# event names are case-insensitive
16431618
event_name = event.eventRaised.name.casefold()
@@ -1705,8 +1680,9 @@ def process_event(
17051680
raise _get_wrong_action_type_error(
17061681
entity_call_id, expected_method_name, action
17071682
)
1708-
entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value)
1709-
if not entity_id:
1683+
try:
1684+
entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value)
1685+
except ValueError:
17101686
raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'")
17111687
ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id, None)
17121688
elif event.HasField("entityOperationSignaled"):
@@ -1802,15 +1778,11 @@ def process_event(
18021778
action = ctx._pending_actions.pop(event.eventId, None)
18031779
if action and action.HasField("sendEntityMessage"):
18041780
if action.sendEntityMessage.HasField("entityOperationCalled"):
1805-
action_type = "entityOperationCalled"
1781+
entity_id, event_id = self._parse_entity_event_sent_input(event)
1782+
ctx._entity_task_id_map[event_id] = (entity_id, event.eventId)
18061783
elif action.sendEntityMessage.HasField("entityLockRequested"):
1807-
action_type = "entityLockRequested"
1808-
else:
1809-
return
1810-
1811-
entity_id = EntityInstanceId.parse(event.eventSent.instanceId)
1812-
event_id = json.loads(event.eventSent.input.value)["id"]
1813-
ctx._entity_task_id_map[event_id] = (entity_id, event.eventId, action_type)
1784+
entity_id, event_id = self._parse_entity_event_sent_input(event)
1785+
ctx._entity_lock_task_id_map[event_id] = (entity_id, event.eventId)
18141786
else:
18151787
eventType = event.WhichOneof("eventType")
18161788
raise task.OrchestrationStateError(
@@ -1820,6 +1792,44 @@ def process_event(
18201792
# The orchestrator generator function completed
18211793
ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED)
18221794

1795+
def _parse_entity_event_sent_input(self, event: pb.HistoryEvent) -> Tuple[EntityInstanceId, str]:
1796+
try:
1797+
entity_id = EntityInstanceId.parse(event.eventSent.instanceId)
1798+
except ValueError:
1799+
raise RuntimeError(f"Could not parse entity ID from instanceId '{event.eventSent.instanceId}'")
1800+
try:
1801+
event_id = json.loads(event.eventSent.input.value)["id"]
1802+
except (json.JSONDecodeError, KeyError, TypeError) as ex:
1803+
raise RuntimeError(f"Could not parse event ID from eventSent input '{event.eventSent.input.value}'") from ex
1804+
return entity_id, event_id
1805+
1806+
def _handle_entity_event_raised(self,
1807+
ctx: _RuntimeOrchestrationContext,
1808+
event: pb.HistoryEvent,
1809+
entity_id: Optional[EntityInstanceId],
1810+
task_id: Optional[int],
1811+
is_lock_event: bool):
1812+
# This eventRaised represents the result of an entity operation after being translated to the old
1813+
# entity protocol by the Durable WebJobs extension
1814+
if entity_id is None:
1815+
raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'")
1816+
if task_id is None:
1817+
raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'")
1818+
entity_task = ctx._pending_tasks.pop(task_id, None)
1819+
if not entity_task:
1820+
raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'")
1821+
result = None
1822+
if not ph.is_empty(event.eventRaised.input):
1823+
# TODO: Investigate why the event result is wrapped in a dict with "result" key
1824+
result = shared.from_json(event.eventRaised.input.value)["result"]
1825+
if is_lock_event:
1826+
ctx._entity_context.complete_acquire(event.eventRaised.name)
1827+
entity_task.complete(EntityLock(ctx))
1828+
else:
1829+
ctx._entity_context.recover_lock_after_call(entity_id)
1830+
entity_task.complete(result)
1831+
ctx.resume()
1832+
18231833
def evaluate_orchestration_versioning(self, versioning: Optional[VersioningOptions], orchestration_version: Optional[str]) -> Optional[pb.TaskFailureDetails]:
18241834
if versioning is None:
18251835
return None

tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import threading
77
from datetime import timedelta
8+
import uuid
89

910
import pytest
1011

@@ -532,3 +533,39 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
532533
assert state.serialized_input is None
533534
assert state.serialized_output is None
534535
assert state.serialized_custom_status == "\"foobaz\""
536+
537+
538+
def test_new_uuid():
539+
def noop(_: task.ActivityContext, _1):
540+
pass
541+
542+
def empty_orchestrator(ctx: task.OrchestrationContext, _):
543+
# Assert that two new_uuid calls return different values
544+
results = [ctx.new_uuid(), ctx.new_uuid()]
545+
yield ctx.call_activity("noop")
546+
# Assert that new_uuid still returns a unique value after replay
547+
results.append(ctx.new_uuid())
548+
return results
549+
550+
# Start a worker, which will connect to the sidecar in a background thread
551+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
552+
taskhub=taskhub_name, token_credential=None) as w:
553+
w.add_orchestrator(empty_orchestrator)
554+
w.add_activity(noop)
555+
w.start()
556+
557+
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
558+
taskhub=taskhub_name, token_credential=None)
559+
id = c.schedule_new_orchestration(empty_orchestrator)
560+
state = c.wait_for_orchestration_completion(id, timeout=30)
561+
562+
assert state is not None
563+
assert state.name == task.get_name(empty_orchestrator)
564+
assert state.instance_id == id
565+
assert state.failure_details is None
566+
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
567+
results = json.loads(state.serialized_output or "\"\"")
568+
assert isinstance(results, list) and len(results) == 3
569+
assert uuid.UUID(results[0]) != uuid.UUID(results[1])
570+
assert uuid.UUID(results[0]) != uuid.UUID(results[2])
571+
assert uuid.UUID(results[1]) != uuid.UUID(results[2])

0 commit comments

Comments
 (0)