Skip to content
Merged
182 changes: 171 additions & 11 deletions durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Optional, Sequence, TypeVar, Union
from typing import Any, List, Optional, Sequence, TypeVar, Union

import grpc
from google.protobuf import wrappers_pb2

from durabletask.entities import EntityInstanceId
from durabletask.entities.entity_metadata import EntityMetadata
Expand Down Expand Up @@ -57,6 +56,39 @@ def raise_if_failed(self):
self.failure_details)


@dataclass
class OrchestrationQuery:
    """Filter criteria used when querying orchestration instances."""

    # Lower bound on instance creation time (None = unbounded).
    created_time_from: Optional[datetime] = None
    # Upper bound on instance creation time (None = unbounded).
    created_time_to: Optional[datetime] = None
    # Restrict results to these runtime statuses (None = any status).
    runtime_status: Optional[List[OrchestrationStatus]] = None
    # Some backends don't respond well with max_instance_count = None, so we use the
    # integer limit for non-paginated results instead.
    max_instance_count: Optional[int] = (1 << 31) - 1
    # When True, serialized inputs and outputs are fetched along with each instance.
    fetch_inputs_and_outputs: bool = False


@dataclass
class EntityQuery:
    """Filter criteria used when querying durable entities."""

    # Only match entities whose instance ID begins with this prefix (None = all).
    instance_id_starts_with: Optional[str] = None
    # Lower bound on last-modified time (None = unbounded).
    last_modified_from: Optional[datetime] = None
    # Upper bound on last-modified time (None = unbounded).
    last_modified_to: Optional[datetime] = None
    # When True, each entity's serialized state is included in the results.
    include_state: bool = True
    # When True, transient entities are included as well.
    include_transient: bool = False
    # Number of results per page (None = backend default).
    page_size: Optional[int] = None


@dataclass
class PurgeInstancesResult:
    """Result reported by a purge-instances operation."""

    # Number of orchestration instances that were deleted.
    deleted_instance_count: int
    # Completion flag reported by the backend response.
    is_complete: bool


@dataclass
class CleanEntityStorageResult:
    """Aggregate counts reported by an entity-storage cleaning operation."""

    # Count of empty entities that were removed.
    empty_entities_removed: int
    # Count of orphaned locks that were released.
    orphaned_locks_released: int


class OrchestrationFailedError(Exception):
def __init__(self, message: str, failure_details: task.FailureDetails):
super().__init__(message)
Expand All @@ -73,6 +105,12 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op

state = res.orchestrationState

new_state = parse_orchestration_state(state)
new_state.instance_id = instance_id # Override instance_id with the one from the request, to match old behavior
return new_state


def parse_orchestration_state(state: pb.OrchestrationState) -> OrchestrationState:
failure_details = None
if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '':
failure_details = task.FailureDetails(
Expand All @@ -81,7 +119,7 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
state.failureDetails.stackTrace.value if not helpers.is_empty(state.failureDetails.stackTrace) else None)

return OrchestrationState(
instance_id,
state.instanceId,
state.name,
OrchestrationStatus(state.orchestrationStatus),
state.createdTimestamp.ToDatetime(),
Expand All @@ -93,7 +131,6 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op


class TaskHubGrpcClient:

def __init__(self, *,
host_address: Optional[str] = None,
metadata: Optional[list[tuple[str, str]]] = None,
Expand Down Expand Up @@ -136,7 +173,7 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
req = pb.CreateInstanceRequest(
name=name,
instanceId=instance_id if instance_id else uuid.uuid4().hex,
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
input=helpers.get_string_value(shared.to_json(input) if input is not None else None),
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
version=helpers.get_string_value(version if version else self.default_version),
orchestrationIdReusePolicy=reuse_id_policy,
Expand All @@ -152,6 +189,42 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr
res: pb.GetInstanceResponse = self._stub.GetInstance(req)
return new_orchestration_state(req.instanceId, res)

def get_all_orchestration_states(self,
                                 orchestration_query: Optional[OrchestrationQuery] = None
                                 ) -> List[OrchestrationState]:
    """Query the backend for all orchestration instances matching the filter.

    Follows the backend's continuation tokens page by page until the result
    set is exhausted, guarding against a backend that keeps returning the
    same token (which would otherwise loop forever).
    """
    query = OrchestrationQuery() if orchestration_query is None else orchestration_query

    self._logger.info(f"Querying orchestration instances with query: {query}")

    results: List[OrchestrationState] = []
    token = None

    while True:
        statuses = None
        if query.runtime_status:
            statuses = [s.value for s in query.runtime_status]
        request = pb.QueryInstancesRequest(
            query=pb.InstanceQuery(
                runtimeStatus=statuses,
                createdTimeFrom=helpers.new_timestamp(query.created_time_from) if query.created_time_from else None,
                createdTimeTo=helpers.new_timestamp(query.created_time_to) if query.created_time_to else None,
                maxInstanceCount=query.max_instance_count,
                fetchInputsAndOutputs=query.fetch_inputs_and_outputs,
                continuationToken=token
            )
        )
        response: pb.QueryInstancesResponse = self._stub.QueryInstances(request)
        for entry in response.orchestrationState:
            results.append(parse_orchestration_state(entry))

        next_token = response.continuationToken
        # A missing/empty token or the literal "0" indicates there are no more results.
        if not (next_token and next_token.value and next_token.value != "0"):
            break
        self._logger.info(f"Received continuation token with value {next_token.value}, fetching next list of instances...")
        if token and token.value and token.value == next_token.value:
            self._logger.warning(f"Received the same continuation token value {next_token.value} again, stopping to avoid infinite loop.")
            break
        token = next_token

    return results

def wait_for_orchestration_start(self, instance_id: str, *,
fetch_payloads: bool = False,
timeout: int = 60) -> Optional[OrchestrationState]:
Expand Down Expand Up @@ -199,7 +272,8 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *,
req = pb.RaiseEventRequest(
instanceId=instance_id,
name=event_name,
input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None)
input=helpers.get_string_value(shared.to_json(data) if data is not None else None)
)

self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
self._stub.RaiseEvent(req)
Expand All @@ -209,7 +283,7 @@ def terminate_orchestration(self, instance_id: str, *,
recursive: bool = True):
req = pb.TerminateRequest(
instanceId=instance_id,
output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None,
output=helpers.get_string_value(shared.to_json(output) if output is not None else None),
recursive=recursive)

self._logger.info(f"Terminating instance '{instance_id}'.")
Expand All @@ -225,10 +299,31 @@ def resume_orchestration(self, instance_id: str):
self._logger.info(f"Resuming instance '{instance_id}'.")
self._stub.ResumeInstance(req)

def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
    """Purge the state of a single orchestration instance.

    Args:
        instance_id: The ID of the orchestration instance to purge.
        recursive: When True, sub-orchestrations are purged as well.

    Returns:
        PurgeInstancesResult with the deleted-instance count and completion flag
        reported by the backend.
    """
    req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
    self._logger.info(f"Purging instance '{instance_id}'.")
    resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req)
    return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)

def purge_orchestrations_by(self,
                            created_time_from: Optional[datetime] = None,
                            created_time_to: Optional[datetime] = None,
                            runtime_status: Optional[List[OrchestrationStatus]] = None,
                            recursive: bool = False) -> PurgeInstancesResult:
    """Purge all orchestration instances matching the given filter criteria."""
    status_names = [str(status) for status in runtime_status] if runtime_status else None
    self._logger.info("Purging orchestrations by filter: "
                      f"created_time_from={created_time_from}, "
                      f"created_time_to={created_time_to}, "
                      f"runtime_status={status_names}, "
                      f"recursive={recursive}")

    purge_filter = pb.PurgeInstanceFilter(
        createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
        createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
        runtimeStatus=[status.value for status in runtime_status] if runtime_status else None
    )
    request = pb.PurgeInstancesRequest(purgeInstanceFilter=purge_filter, recursive=recursive)
    response: pb.PurgeInstancesResponse = self._stub.PurgeInstances(request)
    return PurgeInstancesResult(response.deletedInstanceCount, response.isComplete.value)

def signal_entity(self,
entity_instance_id: EntityInstanceId,
Expand All @@ -237,7 +332,7 @@ def signal_entity(self,
req = pb.SignalEntityRequest(
instanceId=str(entity_instance_id),
name=operation_name,
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None,
input=helpers.get_string_value(shared.to_json(input) if input is not None else None),
requestId=str(uuid.uuid4()),
scheduledTime=None,
parentTraceContext=None,
Expand All @@ -256,4 +351,69 @@ def get_entity(self,
if not res.exists:
return None

return EntityMetadata.from_entity_response(res, include_state)
return EntityMetadata.from_entity_metadata(res.entity, include_state)

def get_all_entities(self,
                     entity_query: Optional[EntityQuery] = None) -> List[EntityMetadata]:
    """Query the backend for all durable entities matching the filter.

    Follows continuation tokens across result pages until exhausted, with a
    guard against a backend that keeps returning the same token.
    """
    query = EntityQuery() if entity_query is None else entity_query

    self._logger.info(f"Retrieving entities by filter: {query}")

    found: List[EntityMetadata] = []
    token = None

    while True:
        request = pb.QueryEntitiesRequest(
            query=pb.EntityQuery(
                instanceIdStartsWith=helpers.get_string_value(query.instance_id_starts_with),
                lastModifiedFrom=helpers.new_timestamp(query.last_modified_from) if query.last_modified_from else None,
                lastModifiedTo=helpers.new_timestamp(query.last_modified_to) if query.last_modified_to else None,
                includeState=query.include_state,
                includeTransient=query.include_transient,
                pageSize=helpers.get_int_value(query.page_size),
                continuationToken=token
            )
        )
        response: pb.QueryEntitiesResponse = self._stub.QueryEntities(request)
        with_state = request.query.includeState
        for entity in response.entities:
            found.append(EntityMetadata.from_entity_metadata(entity, with_state))

        next_token = response.continuationToken
        # A missing/empty token or the literal "0" indicates there are no more results.
        if not (next_token and next_token.value and next_token.value != "0"):
            break
        self._logger.info(f"Received continuation token with value {next_token.value}, fetching next page of entities...")
        if token and token.value and token.value == next_token.value:
            self._logger.warning(f"Received the same continuation token value {next_token.value} again, stopping to avoid infinite loop.")
            break
        token = next_token
    return found

def clean_entity_storage(self,
                         remove_empty_entities: bool = True,
                         release_orphaned_locks: bool = True
                         ) -> CleanEntityStorageResult:
    """Clean entity storage, accumulating the removal/release counts across pages.

    Iterates the backend's continuation tokens until there is nothing left to
    clean, with a guard against a backend that repeats the same token.
    """
    self._logger.info("Cleaning entity storage")

    removed_total = 0
    released_total = 0
    token = None

    while True:
        request = pb.CleanEntityStorageRequest(
            removeEmptyEntities=remove_empty_entities,
            releaseOrphanedLocks=release_orphaned_locks,
            continuationToken=token
        )
        response: pb.CleanEntityStorageResponse = self._stub.CleanEntityStorage(request)
        removed_total += response.emptyEntitiesRemoved
        released_total += response.orphanedLocksReleased

        next_token = response.continuationToken
        # A missing/empty token or the literal "0" indicates there is nothing left to clean.
        if not (next_token and next_token.value and next_token.value != "0"):
            break
        self._logger.info(f"Received continuation token with value {next_token.value}, cleaning next page...")
        if token and token.value and token.value == next_token.value:
            self._logger.warning(f"Received the same continuation token value {next_token.value} again, stopping to avoid infinite loop.")
            break
        token = next_token

    return CleanEntityStorageResult(removed_total, released_total)
14 changes: 9 additions & 5 deletions durabletask/entities/entity_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,22 @@ def __init__(self,

@staticmethod
def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool):
    """Build an EntityMetadata from a GetEntityResponse by unwrapping its entity field."""
    entity = entity_response.entity
    return EntityMetadata.from_entity_metadata(entity, includes_state)

@staticmethod
def from_entity_metadata(entity: pb.EntityMetadata, includes_state: bool):
    """Build an EntityMetadata instance from a protobuf EntityMetadata message.

    Args:
        entity: The protobuf entity metadata message.
        includes_state: Whether the entity's serialized state should be read
            from the message.

    Raises:
        ValueError: If the entity's instance ID cannot be parsed.
    """
    try:
        entity_id = EntityInstanceId.parse(entity.instanceId)
    except ValueError:
        raise ValueError("Invalid entity instance ID in entity response.")
    # Only read serializedState when the caller asked for it; otherwise leave None.
    entity_state = None
    if includes_state:
        entity_state = entity.serializedState.value
    return EntityMetadata(
        id=entity_id,
        last_modified=entity.lastModifiedTime.ToDatetime(timezone.utc),
        backlog_queue_size=entity.backlogQueueSize,
        locked_by=entity.lockedBy.value,
        includes_state=includes_state,
        state=entity_state
    )
Expand Down
7 changes: 7 additions & 0 deletions durabletask/internal/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ def get_string_value(val: Optional[str]) -> Optional[wrappers_pb2.StringValue]:
return wrappers_pb2.StringValue(value=val)


def get_int_value(val: Optional[int]) -> Optional[wrappers_pb2.Int32Value]:
    """Wrap an int in a protobuf Int32Value wrapper, passing None through unchanged."""
    if val is None:
        return None
    return wrappers_pb2.Int32Value(value=val)


def get_string_value_or_empty(val: Optional[str]) -> wrappers_pb2.StringValue:
if val is None:
return wrappers_pb2.StringValue(value="")
Expand Down
Loading
Loading