Skip to content

Commit b97434d

Browse files
committed
Address various entity-related bugs
1 parent 424ffa9 commit b97434d

File tree

10 files changed

+312
-19
lines changed

10 files changed

+312
-19
lines changed

durabletask/entities/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from durabletask.entities.entity_lock import EntityLock
99
from durabletask.entities.entity_context import EntityContext
1010
from durabletask.entities.entity_metadata import EntityMetadata
11+
from durabletask.entities.entity_operation_failed_exception import EntityOperationFailedException
1112

12-
__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata"]
13+
__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata",
14+
"EntityOperationFailedException"]
1315

1416
PACKAGE_NAME = "durabletask.entities"

durabletask/entities/entity_instance_id.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
class EntityInstanceId:
22
def __init__(self, entity: str, key: str):
3-
self.entity = entity
3+
self.entity = entity.lower()
44
self.key = key
55

66
def __str__(self) -> str:
@@ -36,7 +36,13 @@ def parse(entity_id: str) -> "EntityInstanceId":
3636
If the input string is not in the correct format.
3737
"""
3838
try:
39+
if not entity_id.startswith("@"):
40+
raise ValueError("Entity ID must start with '@'.")
3941
_, entity, key = entity_id.split("@", 2)
42+
if not entity or not key:
43+
raise ValueError("Entity name and key cannot be empty.")
44+
if "@" in key:
45+
raise ValueError("Entity instance ID string should not contain more than two '@' symbols.")
4046
return EntityInstanceId(entity=entity, key=key)
4147
except ValueError as ex:
4248
raise ValueError(f"Invalid entity ID format: {entity_id}", ex)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from durabletask.internal.orchestrator_service_pb2 import TaskFailureDetails
2+
from durabletask.entities.entity_instance_id import EntityInstanceId
3+
4+
5+
class EntityOperationFailedException(Exception):
6+
"""Exception raised when an operation on an Entity Function fails."""
7+
8+
def __init__(self, entity_instance_id: EntityInstanceId, operation_name: str, failure_details: TaskFailureDetails) -> None:
9+
super().__init__()
10+
self.entity_instance_id = entity_instance_id
11+
self.operation_name = operation_name
12+
self.failure_details = failure_details
13+
14+
def __str__(self) -> str:
15+
return f"Operation '{self.operation_name}' on entity '{self.entity_instance_id}' failed with error: {self.failure_details.errorMessage}"
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import Any
2+
3+
4+
class JsonEncodeOutputException(Exception):
5+
"""Custom exception type used to indicate that an orchestration result could not be JSON-encoded."""
6+
7+
def __init__(self, problem_object: Any):
8+
super().__init__()
9+
self.problem_object = problem_object
10+
11+
def __str__(self) -> str:
12+
return f"The orchestration result could not be encoded. Object details: {self.problem_object}"

durabletask/worker.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
import grpc
2020
from google.protobuf import empty_pb2
2121

22+
from durabletask.entities.entity_operation_failed_exception import EntityOperationFailedException
2223
from durabletask.internal import helpers
2324
from durabletask.internal.entity_state_shim import StateShim
2425
from durabletask.internal.helpers import new_timestamp
2526
from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext
27+
from durabletask.internal.json_encode_output_exception import JsonEncodeOutputException
2628
from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext
2729
from durabletask.internal.proto_task_hub_sidecar_service_stub import ProtoTaskHubSidecarServiceStub
2830
import durabletask.internal.helpers as ph
@@ -141,14 +143,12 @@ class _Registry:
141143
orchestrators: dict[str, task.Orchestrator]
142144
activities: dict[str, task.Activity]
143145
entities: dict[str, task.Entity]
144-
entity_instances: dict[str, DurableEntity]
145146
versioning: Optional[VersioningOptions] = None
146147

147148
def __init__(self):
148149
self.orchestrators = {}
149150
self.activities = {}
150151
self.entities = {}
151-
self.entity_instances = {}
152152

153153
def add_orchestrator(self, fn: task.Orchestrator[TInput, TOutput]) -> str:
154154
if fn is None:
@@ -201,6 +201,7 @@ def add_entity(self, fn: task.Entity, name: Optional[str] = None) -> str:
201201
def add_named_entity(self, name: str, fn: task.Entity) -> None:
202202
if not name:
203203
raise ValueError("A non-empty entity name is required.")
204+
name = name.lower()
204205
if name in self.entities:
205206
raise ValueError(f"A '{name}' entity already exists.")
206207

@@ -829,7 +830,7 @@ def __init__(self, instance_id: str, registry: _Registry):
829830
self._pending_actions: dict[int, pb.OrchestratorAction] = {}
830831
self._pending_tasks: dict[int, task.CompletableTask] = {}
831832
# Maps entity ID to task ID
832-
self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
833+
self._entity_task_id_map: dict[str, tuple[EntityInstanceId, str, int]] = {}
833834
self._entity_lock_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
834835
# Maps criticalSectionId to task ID
835836
self._entity_lock_id_map: dict[str, int] = {}
@@ -902,7 +903,10 @@ def set_complete(
902903
self._result = result
903904
result_json: Optional[str] = None
904905
if result is not None:
905-
result_json = result if is_result_encoded else shared.to_json(result)
906+
try:
907+
result_json = result if is_result_encoded else shared.to_json(result)
908+
except TypeError:
909+
result_json = shared.to_json(str(JsonEncodeOutputException(result)))
906910
action = ph.new_complete_orchestration_action(
907911
self.next_sequence_number(), status, result_json
908912
)
@@ -1606,7 +1610,7 @@ def process_event(
16061610
raise TypeError("Unexpected sub-orchestration task type")
16071611
elif event.HasField("eventRaised"):
16081612
if event.eventRaised.name in ctx._entity_task_id_map:
1609-
entity_id, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None))
1613+
entity_id, operation, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None, None))
16101614
self._handle_entity_event_raised(ctx, event, entity_id, task_id, False)
16111615
elif event.eventRaised.name in ctx._entity_lock_task_id_map:
16121616
entity_id, task_id = ctx._entity_lock_task_id_map.get(event.eventRaised.name, (None, None))
@@ -1680,9 +1684,10 @@ def process_event(
16801684
)
16811685
try:
16821686
entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value)
1687+
operation = event.entityOperationCalled.operation
16831688
except ValueError:
16841689
raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'")
1685-
ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id)
1690+
ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, operation, entity_call_id)
16861691
elif event.HasField("entityOperationSignaled"):
16871692
# This history event confirms that the entity signal was successfully scheduled.
16881693
# Remove the entityOperationSignaled event from the pending action list so we don't schedule it
@@ -1743,7 +1748,7 @@ def process_event(
17431748
ctx.resume()
17441749
elif event.HasField("entityOperationCompleted"):
17451750
request_id = event.entityOperationCompleted.requestId
1746-
entity_id, task_id = ctx._entity_task_id_map.pop(request_id, (None, None))
1751+
entity_id, operation, task_id = ctx._entity_task_id_map.pop(request_id, (None, None, None))
17471752
if not entity_id:
17481753
raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'")
17491754
if not task_id:
@@ -1762,10 +1767,29 @@ def process_event(
17621767
entity_task.complete(result)
17631768
ctx.resume()
17641769
elif event.HasField("entityOperationFailed"):
1765-
if not ctx.is_replaying:
1766-
self._logger.info(f"{ctx.instance_id}: Entity operation failed.")
1767-
self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}")
1768-
pass
1770+
request_id = event.entityOperationFailed.requestId
1771+
entity_id, operation, task_id = ctx._entity_task_id_map.pop(request_id, (None, None, None))
1772+
if not entity_id:
1773+
raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'")
1774+
if operation is None:
1775+
raise RuntimeError(f"Could not parse operation name from request ID '{request_id}'")
1776+
if not task_id:
1777+
raise RuntimeError(f"Could not find matching task ID for entity operation with request ID '{request_id}'")
1778+
entity_task = ctx._pending_tasks.pop(task_id, None)
1779+
if not entity_task:
1780+
if not ctx.is_replaying:
1781+
self._logger.warning(
1782+
f"{ctx.instance_id}: Ignoring unexpected entityOperationCompleted event with request ID = {request_id}."
1783+
)
1784+
return
1785+
failure = EntityOperationFailedException(
1786+
entity_id,
1787+
operation,
1788+
event.entityOperationFailed.failureDetails
1789+
)
1790+
ctx._entity_context.recover_lock_after_call(entity_id)
1791+
entity_task.fail(str(failure), failure)
1792+
ctx.resume()
17691793
elif event.HasField("orchestratorCompleted"):
17701794
# Added in Functions only (for some reason) and does not affect orchestrator flow
17711795
pass
@@ -1777,7 +1801,7 @@ def process_event(
17771801
if action and action.HasField("sendEntityMessage"):
17781802
if action.sendEntityMessage.HasField("entityOperationCalled"):
17791803
entity_id, event_id = self._parse_entity_event_sent_input(event)
1780-
ctx._entity_task_id_map[event_id] = (entity_id, event.eventId)
1804+
ctx._entity_task_id_map[event_id] = (entity_id, event.entityOperationCalled.operation, event.eventId)
17811805
elif action.sendEntityMessage.HasField("entityLockRequested"):
17821806
entity_id, event_id = self._parse_entity_event_sent_input(event)
17831807
ctx._entity_lock_task_id_map[event_id] = (entity_id, event.eventId)
@@ -1936,11 +1960,7 @@ def execute(
19361960
ctx = EntityContext(orchestration_id, operation, state, entity_id)
19371961

19381962
if isinstance(fn, type) and issubclass(fn, DurableEntity):
1939-
if self._registry.entity_instances.get(str(entity_id), None):
1940-
entity_instance = self._registry.entity_instances[str(entity_id)]
1941-
else:
1942-
entity_instance = fn()
1943-
self._registry.entity_instances[str(entity_id)] = entity_instance
1963+
entity_instance = fn()
19441964
if not hasattr(entity_instance, operation):
19451965
raise AttributeError(f"Entity '{entity_id}' does not have operation '{operation}'")
19461966
method = getattr(entity_instance, operation)

tests/durabletask-azuremanaged/entities/__init__.py

Whitespace-only changes.

tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py renamed to tests/durabletask-azuremanaged/entities/test_dts_class_based_entities_e2e.py

File renamed without changes.

0 commit comments

Comments
 (0)