Skip to content

Commit 539eb29

Browse files
Copilotberndverst
andcommitted
Enhanced entity implementation based on .NET reference
Co-authored-by: berndverst <4535280+berndverst@users.noreply.github.com>
1 parent 745237f commit 539eb29

File tree

6 files changed

+360
-31
lines changed

6 files changed

+360
-31
lines changed

durabletask/__init__.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,19 @@
44
"""Durable Task SDK for Python"""
55

66
from durabletask.worker import ConcurrencyOptions
7-
from durabletask.task import EntityContext, EntityState, EntityQuery, EntityQueryResult
7+
from durabletask.task import (
8+
EntityContext, EntityState, EntityQuery, EntityQueryResult,
9+
EntityInstanceId, EntityOperationFailedException
10+
)
811

9-
__all__ = ["ConcurrencyOptions", "EntityContext", "EntityState", "EntityQuery", "EntityQueryResult"]
12+
__all__ = [
13+
"ConcurrencyOptions",
14+
"EntityContext",
15+
"EntityState",
16+
"EntityQuery",
17+
"EntityQueryResult",
18+
"EntityInstanceId",
19+
"EntityOperationFailedException"
20+
]
1021

1122
PACKAGE_NAME = "durabletask"

durabletask/client.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -223,15 +223,15 @@ def purge_orchestration(self, instance_id: str, recursive: bool = True):
223223
self._logger.info(f"Purging instance '{instance_id}'.")
224224
self._stub.PurgeInstances(req)
225225

226-
def signal_entity(self, entity_id: str, operation_name: str, *,
226+
def signal_entity(self, entity_id: Union[str, 'task.EntityInstanceId'], operation_name: str, *,
227227
input: Optional[Any] = None,
228228
request_id: Optional[str] = None,
229229
scheduled_time: Optional[datetime] = None):
230230
"""Signal an entity with an operation.
231231
232232
Parameters
233233
----------
234-
entity_id : str
234+
entity_id : Union[str, task.EntityInstanceId]
235235
The ID of the entity to signal.
236236
operation_name : str
237237
The name of the operation to perform.
@@ -242,22 +242,24 @@ def signal_entity(self, entity_id: str, operation_name: str, *,
242242
scheduled_time : Optional[datetime]
243243
The time to schedule the operation. If not provided, the operation is scheduled immediately.
244244
"""
245+
entity_id_str = str(entity_id) if hasattr(entity_id, '__str__') else entity_id
246+
245247
req = pb.SignalEntityRequest(
246-
instanceId=entity_id,
248+
instanceId=entity_id_str,
247249
name=operation_name,
248250
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
249251
requestId=request_id if request_id else uuid.uuid4().hex,
250252
scheduledTime=helpers.new_timestamp(scheduled_time) if scheduled_time else None)
251253

252-
self._logger.info(f"Signaling entity '{entity_id}' with operation '{operation_name}'.")
254+
self._logger.info(f"Signaling entity '{entity_id_str}' with operation '{operation_name}'.")
253255
self._stub.SignalEntity(req)
254256

255-
def get_entity(self, entity_id: str, *, include_state: bool = True) -> Optional[task.EntityState]:
257+
def get_entity(self, entity_id: Union[str, 'task.EntityInstanceId'], *, include_state: bool = True) -> Optional[task.EntityState]:
256258
"""Get the state of an entity.
257259
258260
Parameters
259261
----------
260-
entity_id : str
262+
entity_id : Union[str, task.EntityInstanceId]
261263
The ID of the entity to query.
262264
include_state : bool
263265
Whether to include the entity's state in the response.
@@ -267,7 +269,9 @@ def get_entity(self, entity_id: str, *, include_state: bool = True) -> Optional[
267269
Optional[EntityState]
268270
The entity state if it exists, None otherwise.
269271
"""
270-
req = pb.GetEntityRequest(instanceId=entity_id, includeState=include_state)
272+
entity_id_str = str(entity_id) if hasattr(entity_id, '__str__') else entity_id
273+
274+
req = pb.GetEntityRequest(instanceId=entity_id_str, includeState=include_state)
271275
res: pb.GetEntityResponse = self._stub.GetEntity(req)
272276

273277
if not res.exists:

durabletask/task.py

Lines changed: 110 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import math
8+
import uuid
89
from abc import ABC, abstractmethod
910
from datetime import datetime, timedelta
1011
from typing import Any, Callable, Generator, Generic, Optional, TypeVar, Union
@@ -178,13 +179,13 @@ def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None:
178179
pass
179180

180181
@abstractmethod
181-
def signal_entity(self, entity_id: str, operation_name: str, *,
182+
def signal_entity(self, entity_id: Union[str, 'EntityInstanceId'], operation_name: str, *,
182183
input: Optional[Any] = None) -> Task:
183184
"""Signal an entity with an operation.
184185
185186
Parameters
186187
----------
187-
entity_id : str
188+
entity_id : Union[str, EntityInstanceId]
188189
The ID of the entity to signal.
189190
operation_name : str
190191
The name of the operation to perform.
@@ -199,14 +200,14 @@ def signal_entity(self, entity_id: str, operation_name: str, *,
199200
pass
200201

201202
@abstractmethod
202-
def call_entity(self, entity_id: str, operation_name: str, *,
203+
def call_entity(self, entity_id: Union[str, 'EntityInstanceId'], operation_name: str, *,
203204
input: Optional[TInput] = None,
204205
retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]:
205206
"""Call an entity operation and wait for the result.
206207
207208
Parameters
208209
----------
209-
entity_id : str
210+
entity_id : Union[str, EntityInstanceId]
210211
The ID of the entity to call.
211212
operation_name : str
212213
The name of the operation to perform.
@@ -513,12 +514,48 @@ def task_id(self) -> int:
513514
return self._task_id
514515

515516

517+
@dataclass
518+
class EntityInstanceId:
519+
"""Represents the ID of a durable entity instance."""
520+
name: str
521+
key: str
522+
523+
def __str__(self) -> str:
524+
"""Return the string representation in the format: name@key"""
525+
return f"{self.name}@{self.key}"
526+
527+
@classmethod
528+
def from_string(cls, instance_id: str) -> 'EntityInstanceId':
529+
"""Parse an entity instance ID from string format (name@key)."""
530+
if '@' not in instance_id:
531+
raise ValueError(f"Invalid entity instance ID format: {instance_id}. Expected format: name@key")
532+
533+
parts = instance_id.split('@', 1)
534+
if len(parts) != 2 or not parts[0] or not parts[1]:
535+
raise ValueError(f"Invalid entity instance ID format: {instance_id}. Expected format: name@key")
536+
537+
return cls(name=parts[0], key=parts[1])
538+
539+
540+
class EntityOperationFailedException(Exception):
541+
"""Exception raised when an entity operation fails."""
542+
543+
def __init__(self, entity_id: EntityInstanceId, operation_name: str, failure_details: FailureDetails):
544+
self.entity_id = entity_id
545+
self.operation_name = operation_name
546+
self.failure_details = failure_details
547+
super().__init__(f"Operation '{operation_name}' on entity '{entity_id}' failed: {failure_details.message}")
548+
549+
516550
class EntityContext:
551+
"""Context for entity operations, providing access to state and scheduling capabilities."""
552+
517553
def __init__(self, instance_id: str, operation_name: str, is_new_entity: bool = False):
518554
self._instance_id = instance_id
519555
self._operation_name = operation_name
520556
self._is_new_entity = is_new_entity
521557
self._state: Optional[Any] = None
558+
self._entity_instance_id = EntityInstanceId.from_string(instance_id)
522559

523560
@property
524561
def instance_id(self) -> str:
@@ -531,6 +568,17 @@ def instance_id(self) -> str:
531568
"""
532569
return self._instance_id
533570

571+
@property
572+
def entity_id(self) -> EntityInstanceId:
573+
"""Get the structured entity instance ID.
574+
575+
Returns
576+
-------
577+
EntityInstanceId
578+
The structured entity instance ID.
579+
"""
580+
return self._entity_instance_id
581+
534582
@property
535583
def operation_name(self) -> str:
536584
"""Get the name of the operation being performed on the entity.
@@ -578,6 +626,64 @@ def set_state(self, state: Any) -> None:
578626
"""
579627
self._state = state
580628

629+
def signal_entity(self, entity_id: Union[str, EntityInstanceId], operation_name: str, *,
630+
input: Optional[Any] = None) -> None:
631+
"""Signal another entity with an operation (fire-and-forget).
632+
633+
Parameters
634+
----------
635+
entity_id : Union[str, EntityInstanceId]
636+
The ID of the entity to signal.
637+
operation_name : str
638+
The name of the operation to perform.
639+
input : Optional[Any]
640+
The JSON-serializable input to pass to the entity operation.
641+
"""
642+
# Store the signal for later processing during entity execution
643+
if not hasattr(self, '_signals'):
644+
self._signals = []
645+
646+
entity_id_str = str(entity_id) if isinstance(entity_id, EntityInstanceId) else entity_id
647+
self._signals.append({
648+
'entity_id': entity_id_str,
649+
'operation_name': operation_name,
650+
'input': input
651+
})
652+
653+
def start_new_orchestration(self, orchestrator: Union[Orchestrator[TInput, TOutput], str], *,
654+
input: Optional[TInput] = None,
655+
instance_id: Optional[str] = None) -> str:
656+
"""Start a new orchestration from within an entity operation.
657+
658+
Parameters
659+
----------
660+
orchestrator : Union[Orchestrator[TInput, TOutput], str]
661+
The orchestrator function or name to start.
662+
input : Optional[TInput]
663+
The JSON-serializable input to pass to the orchestration.
664+
instance_id : Optional[str]
665+
The instance ID for the new orchestration. If not provided, a random UUID will be used.
666+
667+
Returns
668+
-------
669+
str
670+
The instance ID of the new orchestration.
671+
"""
672+
# Store the orchestration start request for later processing
673+
if not hasattr(self, '_orchestrations'):
674+
self._orchestrations = []
675+
676+
orchestrator_name = orchestrator if isinstance(orchestrator, str) else get_name(orchestrator)
677+
new_instance_id = instance_id or str(uuid.uuid4())
678+
679+
self._orchestrations.append({
680+
'name': orchestrator_name,
681+
'input': input,
682+
'instance_id': new_instance_id
683+
})
684+
685+
return new_instance_id
686+
581687

582688
# Orchestrators are generators that yield tasks and receive/return any type
583689
Orchestrator = Callable[[OrchestrationContext, TInput], Union[Generator[Task, Any, Any], TOutput]]

durabletask/worker.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -928,12 +928,14 @@ def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
928928

929929
self.set_continued_as_new(new_input, save_events)
930930

931-
def signal_entity(self, entity_id: str, operation_name: str, *,
931+
def signal_entity(self, entity_id: Union[str, task.EntityInstanceId], operation_name: str, *,
932932
input: Optional[Any] = None) -> task.Task:
933933
# Create a signal entity action
934+
entity_id_str = str(entity_id) if hasattr(entity_id, '__str__') else entity_id
935+
934936
action = pb.OrchestratorAction()
935937
action.sendEntitySignal.CopyFrom(pb.SendSignalAction(
936-
instanceId=entity_id,
938+
instanceId=entity_id_str,
937939
name=operation_name,
938940
input=ph.get_string_value(shared.to_json(input)) if input is not None else None
939941
))
@@ -951,7 +953,7 @@ def signal_entity(self, entity_id: str, operation_name: str, *,
951953

952954
return signal_task
953955

954-
def call_entity(self, entity_id: str, operation_name: str, *,
956+
def call_entity(self, entity_id: Union[str, task.EntityInstanceId], operation_name: str, *,
955957
input: Optional[Any] = None,
956958
retry_policy: Optional[task.RetryPolicy] = None) -> task.Task:
957959
# For now, entity calls are not directly supported in orchestrations
@@ -1405,6 +1407,28 @@ def execute(self, req: pb.EntityBatchRequest) -> pb.EntityBatchResult:
14051407
# Update state for next operation
14061408
current_state = ctx.get_state()
14071409

1410+
# Process entity signals from context
1411+
if hasattr(ctx, '_signals'):
1412+
for signal in ctx._signals:
1413+
signal_action = pb.OrchestratorAction()
1414+
signal_action.sendEntitySignal.CopyFrom(pb.SendSignalAction(
1415+
instanceId=signal['entity_id'],
1416+
name=signal['operation_name'],
1417+
input=ph.get_string_value(shared.to_json(signal['input'])) if signal['input'] is not None else None
1418+
))
1419+
actions.append(signal_action)
1420+
1421+
# Process orchestration starts from context
1422+
if hasattr(ctx, '_orchestrations'):
1423+
for orch in ctx._orchestrations:
1424+
orch_action = pb.OrchestratorAction()
1425+
orch_action.callOrchestrator.CopyFrom(pb.CallOrchestratorAction(
1426+
name=orch['name'],
1427+
instanceId=orch['instance_id'],
1428+
input=ph.get_string_value(shared.to_json(orch['input'])) if orch['input'] is not None else None
1429+
))
1430+
actions.append(orch_action)
1431+
14081432
# Create operation result
14091433
result = pb.OperationResult()
14101434
if operation_output is not None:

0 commit comments

Comments
 (0)