diff --git a/temporalio/activity.py b/temporalio/activity.py index ff46bdea8..6926161a5 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -104,16 +104,25 @@ class Info: heartbeat_details: Sequence[Any] heartbeat_timeout: timedelta | None is_local: bool + namespace: str + """Namespace the activity is running in.""" schedule_to_close_timeout: timedelta | None scheduled_time: datetime start_to_close_timeout: timedelta | None started_time: datetime task_queue: str task_token: bytes - workflow_id: str - workflow_namespace: str - workflow_run_id: str - workflow_type: str + workflow_id: str | None + """ID of the workflow. None if the activity was not started by a workflow.""" + workflow_namespace: str | None + """Namespace of the workflow. None if the activity was not started by a workflow. + + .. deprecated:: + Use :py:attr:`namespace` instead. + """ + workflow_run_id: str | None + """Run ID of the workflow. None if the activity was not started by a workflow.""" + workflow_type: str | None + """Type of the workflow. None if the activity was not started by a workflow.""" priority: temporalio.common.Priority retry_policy: temporalio.common.RetryPolicy | None """The retry policy of this activity. @@ -122,6 +131,14 @@ class Info: If the value is None, it means the server didn't send information about retry policy (e.g. due to old server version), but it may still be defined server-side.""" + activity_run_id: str | None = None + """Run ID of this activity. None if the activity was started by a workflow.""" + + @property + def in_workflow(self) -> bool: + """Whether this activity was started by a workflow.""" + return self.workflow_id is not None + # TODO(cretz): Consider putting identity on here for "worker_id" for logger? @@ -129,7 +146,7 @@ def _logger_details(self) -> Mapping[str, Any]: "activity_id": self.activity_id, "activity_type": self.activity_type, "attempt": self.attempt, - "namespace": self.workflow_namespace, + "namespace": self.namespace, "task_queue": self.task_queue, "workflow_id": self.workflow_id, "workflow_run_id": self.workflow_run_id, @@ -238,7 +255,7 @@ def metric_meter(self) -> temporalio.common.MetricMeter: info = self.info() self._metric_meter = self.runtime_metric_meter.with_additional_attributes( { - "namespace": info.workflow_namespace, + "namespace": info.namespace, "task_queue": info.task_queue, "activity_type": info.activity_type, } @@ -577,6 +594,20 @@ def must_from_callable(fn: Callable) -> _Definition: f"Activity {fn_name} missing attributes, was it decorated with @activity.defn?"
) + @classmethod + def get_name_and_result_type( + cls, name_or_run_fn: str | Callable[..., Any] + ) -> tuple[str, type | None]: + """Resolve a string name or decorated activity callable into its activity name and known return type, if any.""" + if isinstance(name_or_run_fn, str): + return name_or_run_fn, None + elif callable(name_or_run_fn): + defn = cls.must_from_callable(name_or_run_fn) + if not defn.name: + raise ValueError(f"Activity {name_or_run_fn} definition has no name") + return defn.name, defn.ret_type + else: + raise TypeError("Activity must be a string or callable") # type:ignore[reportUnreachable] + @staticmethod + def _apply_to_callable( + fn: Callable, diff --git a/temporalio/client.py b/temporalio/client.py index b4d5af0fa..8dcd34b62 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -39,6 +39,8 @@ from google.protobuf.internal.containers import MessageMap from typing_extensions import Required, Self, TypedDict +import temporalio.activity +import temporalio.api.activity.v1 import temporalio.api.common.v1 import temporalio.api.enums.v1 import temporalio.api.errordetails.v1 @@ -60,6 +62,7 @@ import temporalio.workflow from temporalio.activity import ActivityCancellationDetails from temporalio.converter import ( + ActivitySerializationContext, DataConverter, SerializationContext, WithSerializationContext, @@ -79,6 +82,10 @@ from .common import HeaderCodecBehavior from .types import ( AnyType, + CallableAsyncNoParam, + CallableAsyncSingleParam, + CallableSyncNoParam, + CallableSyncSingleParam, LocalReturnType, MethodAsyncNoParam, MethodAsyncSingleParam, @@ -1266,1436 +1273,3180 @@ async def count_workflows( ) ) + # async no-param @overload - def get_async_activity_handle( - self, *, workflow_id: str, run_id: str | None, activity_id: str - ) -> AsyncActivityHandle: - pass + async def start_activity( + self, + activity: CallableAsyncNoParam[ReturnType], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ...
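The optional ``workflow_*`` fields and the new ``in_workflow`` property in the activity.py hunk above let activity code detect whether it was started by a workflow or directly by a client. A minimal sketch, assuming only the ``Info`` fields introduced in this diff; the activity name is hypothetical:

    from temporalio import activity

    @activity.defn
    async def describe_context() -> str:
        info = activity.info()
        # in_workflow is True exactly when workflow_id is set (see the property above)
        if info.in_workflow:
            return f"started by workflow {info.workflow_id} in namespace {info.namespace}"
        # Client-started activities have no workflow_* values; namespace is still set
        return f"client-started activity {info.activity_id} in namespace {info.namespace}"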
+ # sync no-param @overload - def get_async_activity_handle(self, *, task_token: bytes) -> AsyncActivityHandle: - pass + async def start_activity( + self, + activity: CallableSyncNoParam[ReturnType], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... - def get_async_activity_handle( + # async single-param + @overload + async def start_activity( self, + activity: CallableAsyncSingleParam[ParamType, ReturnType], + arg: ParamType, *, - workflow_id: str | None = None, - run_id: str | None = None, - activity_id: str | None = None, - task_token: bytes | None = None, - ) -> AsyncActivityHandle: - """Get an async activity handle. + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... - Either the workflow_id, run_id, and activity_id can be provided, or a - singular task_token can be provided. + # sync single-param + @overload + async def start_activity( + self, + activity: CallableSyncSingleParam[ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... - Args: - workflow_id: Workflow ID for the activity. Cannot be set if - task_token is set. - run_id: Run ID for the activity. Cannot be set if task_token is set. 
- activity_id: ID for the activity. Cannot be set if task_token is - set. - task_token: Task token for the activity. Cannot be set if any of the - id parameters are set. + # async multi-param + @overload + async def start_activity( + self, + activity: Callable[..., Awaitable[ReturnType]], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... - Returns: - A handle that can be used for completion or heartbeat. - """ - if task_token is not None: - if workflow_id is not None or run_id is not None or activity_id is not None: - raise ValueError("Task token cannot be present with other IDs") - return AsyncActivityHandle(self, task_token) - elif workflow_id is not None: - if activity_id is None: - raise ValueError( - "Workflow ID, run ID, and activity ID must all be given together" - ) - return AsyncActivityHandle( - self, - AsyncActivityIDReference( - workflow_id=workflow_id, run_id=run_id, activity_id=activity_id - ), - ) - raise ValueError("Task token or workflow/run/activity ID must be present") + # sync multi-param + @overload + async def start_activity( + self, + activity: Callable[..., ReturnType], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
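Taken together, the overloads above cover the two call shapes: one positional ``arg`` for single-parameter activities, and ``args=[...]`` for everything else. A minimal sketch of both forms, inside an async function with a connected ``Client`` named ``client``; the activities ``greet`` and ``combine`` are hypothetical, and one of ``schedule_to_close_timeout`` or ``start_to_close_timeout`` must be supplied:

    from datetime import timedelta

    # Single-parameter activity: the one argument is passed positionally
    handle = await client.start_activity(
        greet,
        "world",
        id="greet-1",
        task_queue="my-task-queue",
        start_to_close_timeout=timedelta(seconds=30),
    )

    # Multi-parameter activity: arguments go through the args sequence instead
    handle = await client.start_activity(
        combine,
        args=["a", "b"],
        id="combine-1",
        task_queue="my-task-queue",
        schedule_to_close_timeout=timedelta(minutes=5),
    )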
- async def create_schedule( + # string name + @overload + async def start_activity( self, + activity: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], id: str, - schedule: Schedule, + task_queue: str, + result_type: type | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[Any]: ... + + async def start_activity( + self, + activity: ( + str | Callable[..., Awaitable[ReturnType]] | Callable[..., ReturnType] + ), + arg: Any = temporalio.common._arg_unset, *, - trigger_immediately: bool = False, - backfill: Sequence[ScheduleBackfill] = [], - memo: Mapping[str, Any] | None = None, - search_attributes: None - | ( - temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes - ) = None, + args: Sequence[Any] = [], + id: str, + task_queue: str, + result_type: type | None = None, + # Either schedule_to_close_timeout or start_to_close_timeout must be present + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> ScheduleHandle: - """Create a schedule and return its handle. + ) -> ActivityHandle[ReturnType]: + """Start an activity and return its handle. + + .. warning:: + This API is experimental. Args: - id: Unique identifier of the schedule. - schedule: Schedule to create. - trigger_immediately: If true, trigger one action immediately when - creating the schedule. - backfill: Set of time periods to take actions on as if that time - passed right now. - memo: Memo for the schedule. Memo for a scheduled workflow is part - of the schedule action. - search_attributes: Search attributes for the schedule. Search - attributes for a scheduled workflow are part of the scheduled - action. The dictionary form of this is DEPRECATED, use - :py:class:`temporalio.common.TypedSearchAttributes`. - rpc_metadata: Headers used on the RPC call. Keys here override - client-level RPC metadata keys. + activity: String name or callable activity function to execute. + arg: Single argument to the activity. + args: Multiple arguments to the activity. Cannot be set if arg is. + id: Unique identifier for the activity. Required. 
+ task_queue: Task queue to send the activity to. + result_type: For activities referenced by string name, the type to deserialize the result into. + schedule_to_close_timeout: Total time allowed for the activity from schedule to completion. + schedule_to_start_timeout: Time the activity is allowed to wait in the task queue before a worker picks it up. + start_to_close_timeout: Time allowed for a single execution attempt. + heartbeat_timeout: Maximum permitted time between heartbeats before the activity is considered to have timed out. + id_reuse_policy: How to handle reusing activity IDs from closed activities. + Default is ALLOW_DUPLICATE. + id_conflict_policy: How to handle activity ID conflicts with running activities. + Default is FAIL. + retry_policy: Retry policy for the activity. + search_attributes: Search attributes for the activity. + summary: A single-line fixed summary for this activity that may appear + in the UI/CLI. This can be in single-line Temporal markdown format. + priority: Priority of the activity execution. + rpc_metadata: Headers used on the RPC call. rpc_timeout: Optional RPC deadline to set for the RPC call. Returns: - A handle to the created schedule. - - Raises: - ScheduleAlreadyRunningError: If a schedule with this ID is already - running. + A handle to the started activity. """ - temporalio.common._warn_on_deprecated_search_attributes(search_attributes) - return await self._impl.create_schedule( - CreateScheduleInput( + name, result_type_from_type_annotation = ( + temporalio.activity._Definition.get_name_and_result_type(activity) + ) + return await self._impl.start_activity( + StartActivityInput( + activity_type=name, + args=temporalio.common._arg_or_args(arg, args), id=id, - schedule=schedule, - trigger_immediately=trigger_immediately, - backfill=backfill, - memo=memo, + task_queue=task_queue, + result_type=result_type or result_type_from_type_annotation, + schedule_to_close_timeout=schedule_to_close_timeout, + schedule_to_start_timeout=schedule_to_start_timeout, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, search_attributes=search_attributes, + summary=summary, + headers={}, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, + priority=priority, ) ) - def get_schedule_handle(self, id: str) -> ScheduleHandle: - """Get a schedule handle for the given ID.""" - return ScheduleHandle(self, id) - - async def list_schedules( + # async no-param + @overload + async def execute_activity( self, - query: str | None = None, + activity: CallableAsyncNoParam[ReturnType], *, - page_size: int = 1000, - next_page_token: bytes | None = None, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> ScheduleAsyncIterator: - """List schedules.
- - This does not make a request until the first iteration is attempted. - Therefore any errors will not occur until then. - - Note, this list is eventually consistent. Therefore if a schedule is - added or deleted, it may not be available in the list immediately. - - Args: - page_size: Maximum number of results for each page. - query: A Temporal visibility list filter. See Temporal documentation - concerning visibility list filters including behavior when left - unset. - next_page_token: A previously obtained next page token if doing - pagination. Usually not needed as the iterator automatically - starts from the beginning. - rpc_metadata: Headers used on each RPC call. Keys here override - client-level RPC metadata keys. - rpc_timeout: Optional RPC deadline to set for each RPC call. - - Returns: - An async iterator that can be used with ``async for``. - """ - return self._impl.list_schedules( - ListSchedulesInput( - page_size=page_size, - next_page_token=next_page_token, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - query=query, - ) - ) + ) -> ReturnType: ... - async def update_worker_build_id_compatibility( + # sync no-param + @overload + async def execute_activity( self, + activity: CallableSyncNoParam[ReturnType], + *, + id: str, task_queue: str, - operation: BuildIdOp, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> None: - """Used to add new Build IDs or otherwise update the relative compatibility of Build Ids as - defined on a specific task queue for the Worker Versioning feature. - - For more on this feature, see https://docs.temporal.io/workers#worker-versioning - - .. warning:: - This API is experimental + ) -> ReturnType: ... - Args: - task_queue: The task queue to target. - operation: The operation to perform. - rpc_metadata: Headers used on each RPC call. Keys here override - client-level RPC metadata keys. - rpc_timeout: Optional RPC deadline to set for each RPC call. 
- """ - return await self._impl.update_worker_build_id_compatibility( - UpdateWorkerBuildIdCompatibilityInput( - task_queue, - operation, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - ) - ) + # async single-param + @overload + async def execute_activity( + self, + activity: CallableAsyncSingleParam[ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... - async def get_worker_build_id_compatibility( + # sync single-param + @overload + async def execute_activity( self, + activity: CallableSyncSingleParam[ParamType, ReturnType], + arg: ParamType, + *, + id: str, task_queue: str, - max_sets: int | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> WorkerBuildIdVersionSets: - """Get the Build ID compatibility sets for a specific task queue. + ) -> ReturnType: ... - For more on this feature, see https://docs.temporal.io/workers#worker-versioning - - .. warning:: - This API is experimental + # async multi-param + @overload + async def execute_activity( + self, + activity: Callable[..., Awaitable[ReturnType]], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... - Args: - task_queue: The task queue to target. - max_sets: The maximum number of sets to return. 
If not specified, all sets will be - returned. - rpc_metadata: Headers used on each RPC call. Keys here override - client-level RPC metadata keys. - rpc_timeout: Optional RPC deadline to set for each RPC call. - """ - return await self._impl.get_worker_build_id_compatibility( - GetWorkerBuildIdCompatibilityInput( - task_queue, - max_sets, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - ) - ) + # sync multi-param + @overload + async def execute_activity( + self, + activity: Callable[..., ReturnType], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... - async def get_worker_task_reachability( + # string name + @overload + async def execute_activity( self, - build_ids: Sequence[str], - task_queues: Sequence[str] = [], - reachability_type: TaskReachabilityType | None = None, + activity: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + result_type: type | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> WorkerTaskReachability: - """Determine if some Build IDs for certain Task Queues could have tasks dispatched to them. + ) -> Any: ... 
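When the activity is referenced by string name, the client has no type annotation to consult, so ``result_type`` drives result deserialization, per the docstring below. A sketch, assuming an activity registered elsewhere under the name ``process-order``; ``order_id`` and the ``OrderResult`` dataclass are hypothetical:

    from datetime import timedelta

    result = await client.execute_activity(
        "process-order",
        order_id,
        id="process-order-1",
        task_queue="orders",
        result_type=OrderResult,  # deserialize the result payload into this type
        start_to_close_timeout=timedelta(minutes=2),
    )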
- For more on this feature, see https://docs.temporal.io/workers#worker-versioning + async def execute_activity( + self, + activity: ( + str | Callable[..., Awaitable[ReturnType]] | Callable[..., ReturnType] + ), + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + result_type: type | None = None, + # Either schedule_to_close_timeout or start_to_close_timeout must be present + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: + """Start an activity, wait for it to complete, and return its result. .. warning:: - This API is experimental + This API is experimental. - Args: - build_ids: The Build IDs to query the reachability of. At least one must be specified. - task_queues: Task Queues to restrict the query to. If not specified, all Task Queues - will be searched. When requesting a large number of task queues or all task queues - associated with the given Build IDs in a namespace, all Task Queues will be listed - in the response but some of them may not contain reachability information due to a - server enforced limit. When reaching the limit, task queues that reachability - information could not be retrieved for will be marked with a ``NotFetched`` entry in - {@link BuildIdReachability.taskQueueReachability}. The caller may issue another call - to get the reachability for those task queues. - reachability_type: The kind of reachability this request is concerned with. - rpc_metadata: Headers used on each RPC call. Keys here override - client-level RPC metadata keys. - rpc_timeout: Optional RPC deadline to set for each RPC call. + This is a convenience method that combines :py:meth:`start_activity` and + :py:meth:`ActivityHandle.result`. + + Returns: + The result of the activity. + + Raises: + ActivityFailureError: If the activity completed with a failure. 
""" - return await self._impl.get_worker_task_reachability( - GetWorkerTaskReachabilityInput( - build_ids, - task_queues, - reachability_type, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - ) + handle: ActivityHandle[ReturnType] = await self.start_activity( + cast(Any, activity), + arg, + args=args, + id=id, + task_queue=task_queue, + result_type=result_type, + schedule_to_close_timeout=schedule_to_close_timeout, + schedule_to_start_timeout=schedule_to_start_timeout, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + search_attributes=search_attributes, + summary=summary, + priority=priority, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, ) + return await handle.result() + # async no-param + @overload + async def start_activity_class( + self, + activity: type[CallableAsyncNoParam[ReturnType]], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... -class ClientConfig(TypedDict, total=False): - """TypedDict of config originally passed to :py:meth:`Client`.""" + # sync no-param + @overload + async def start_activity_class( + self, + activity: type[CallableSyncNoParam[ReturnType]], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
- service_client: Required[temporalio.service.ServiceClient] - namespace: Required[str] - data_converter: Required[temporalio.converter.DataConverter] - plugins: Required[Sequence[Plugin]] - interceptors: Required[Sequence[Interceptor]] - default_workflow_query_reject_condition: Required[ - temporalio.common.QueryRejectCondition | None - ] - header_codec_behavior: Required[HeaderCodecBehavior] + # async single-param + @overload + async def start_activity_class( + self, + activity: type[CallableAsyncSingleParam[ParamType, ReturnType]], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... + # sync single-param + @overload + async def start_activity_class( + self, + activity: type[CallableSyncSingleParam[ParamType, ReturnType]], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... -class WorkflowHistoryEventFilterType(IntEnum): - """Type of history events to get for a workflow. + # async multi-param + @overload + async def start_activity_class( + self, + activity: type[Callable[..., Awaitable[ReturnType]]], # type: ignore[reportInvalidTypeForm] + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
- See :py:class:`temporalio.api.enums.v1.HistoryEventFilterType`. - """ + # sync multi-param + @overload + async def start_activity_class( + self, + activity: type[Callable[..., ReturnType]], # type: ignore[reportInvalidTypeForm] + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... - ALL_EVENT = int( - temporalio.api.enums.v1.HistoryEventFilterType.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT - ) - CLOSE_EVENT = int( - temporalio.api.enums.v1.HistoryEventFilterType.HISTORY_EVENT_FILTER_TYPE_CLOSE_EVENT - ) + async def start_activity_class( + self, + activity: type[Callable], # type: ignore[reportInvalidTypeForm] + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + result_type: type | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[Any]: + """Start an activity from a callable class. + .. warning:: + This API is experimental. -class WorkflowHandle(Generic[SelfType, ReturnType]): - """Handle for interacting with a workflow. + See :py:meth:`start_activity` for parameter and return details. + """ + return await self.start_activity( + cast(Any, activity), + arg, + args=args, + id=id, + task_queue=task_queue, + result_type=result_type, + schedule_to_close_timeout=schedule_to_close_timeout, + schedule_to_start_timeout=schedule_to_start_timeout, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + search_attributes=search_attributes, + summary=summary, + priority=priority, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) - This is usually created via :py:meth:`Client.get_workflow_handle` or - returned from :py:meth:`Client.start_workflow`. 
- """ + # async no-param + @overload + async def execute_activity_class( + self, + activity: type[CallableAsyncNoParam[ReturnType]], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... - def __init__( + # sync no-param + @overload + async def execute_activity_class( self, - client: Client, + activity: type[CallableSyncNoParam[ReturnType]], + *, id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... 
+ + # async single-param + @overload + async def execute_activity_class( + self, + activity: type[CallableAsyncSingleParam[ParamType, ReturnType]], + arg: ParamType, *, - run_id: str | None = None, - result_run_id: str | None = None, - first_execution_run_id: str | None = None, - result_type: type | None = None, - start_workflow_response: None - | ( - temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse - | temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse - ) = None, - ) -> None: - """Create workflow handle.""" - self._client = client - self._id = id - self._run_id = run_id - self._result_run_id = result_run_id - self._first_execution_run_id = first_execution_run_id - self._result_type = result_type - self._start_workflow_response = start_workflow_response - self.__temporal_eagerly_started = False + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... - @functools.cached_property - def _data_converter(self) -> temporalio.converter.DataConverter: - return self._client.data_converter.with_context( - temporalio.converter.WorkflowSerializationContext( - namespace=self._client.namespace, workflow_id=self._id - ) - ) + # sync single-param + @overload + async def execute_activity_class( + self, + activity: type[CallableSyncSingleParam[ParamType, ReturnType]], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... 
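Every overload threads through the same retry and timeout options. A sketch combining them, using the existing ``temporalio.common.RetryPolicy``; the activity is hypothetical:

    from datetime import timedelta

    from temporalio.common import RetryPolicy

    result = await client.execute_activity(
        greet,
        "world",
        id="greet-3",
        task_queue="my-task-queue",
        start_to_close_timeout=timedelta(seconds=30),
        heartbeat_timeout=timedelta(seconds=10),  # activity must heartbeat at least this often
        retry_policy=RetryPolicy(
            initial_interval=timedelta(seconds=1),
            maximum_attempts=3,
        ),
    )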
- @property - def id(self) -> str: - """ID for the workflow.""" - return self._id + # async multi-param + @overload + async def execute_activity_class( + self, + activity: type[Callable[..., Awaitable[ReturnType]]], # type: ignore[reportInvalidTypeForm] + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... - @property - def run_id(self) -> str | None: - """If present, run ID used to ensure that requested operations apply - to this exact run. + # sync multi-param + @overload + async def execute_activity_class( + self, + activity: type[Callable[..., ReturnType]], # type: ignore[reportInvalidTypeForm] + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... - This is only created via :py:meth:`Client.get_workflow_handle`. - :py:meth:`Client.start_workflow` will not set this value. + async def execute_activity_class( + self, + activity: type[Callable], # type: ignore[reportInvalidTypeForm] + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + result_type: type | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> Any: + """Start an activity from a callable class and wait for completion. - This cannot be mutated. 
If a different run ID is needed, - :py:meth:`Client.get_workflow_handle` must be used instead. + .. warning:: + This API is experimental. + + This is a convenience method that combines :py:meth:`start_activity_class` and :py:meth:`ActivityHandle.result`. """ - return self._run_id + return await self.execute_activity( + cast(Any, activity), + arg, + args=args, + id=id, + task_queue=task_queue, + result_type=result_type, + schedule_to_close_timeout=schedule_to_close_timeout, + schedule_to_start_timeout=schedule_to_start_timeout, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + search_attributes=search_attributes, + summary=summary, + priority=priority, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) - @property - def result_run_id(self) -> str | None: - """Run ID used for :py:meth:`result` calls if present to ensure result - is for a workflow starting from this run. + # async no-param + @overload + async def start_activity_method( + self, + activity: MethodAsyncNoParam[SelfType, ReturnType], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... - When this handle is created via :py:meth:`Client.get_workflow_handle`, - this is the same as run_id. When this handle is created via - :py:meth:`Client.start_workflow`, this value will be the resulting run - ID. + # async single-param + @overload + async def start_activity_method( + self, + activity: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... - This cannot be mutated.
- """ - return self._result_run_id + # async multi-param + @overload + async def start_activity_method( + self, + activity: Callable[ + Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType] + ], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... - @property - def first_execution_run_id(self) -> str | None: - """Run ID used to ensure requested operations apply to a workflow ID - started with this run ID. + # sync multi-param + @overload + async def start_activity_method( + self, + activity: Callable[Concatenate[SelfType, MultiParamSpec], ReturnType], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... - This can be set when using :py:meth:`Client.get_workflow_handle`. When - :py:meth:`Client.start_workflow` is called without a start signal, this - is set to the resulting run. + async def start_activity_method( + self, + activity: Callable, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + result_type: type | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[Any]: + """Start an activity from a method. - This cannot be mutated. 
If a different first execution run ID is needed, - :py:meth:`Client.get_workflow_handle` must be used instead. + .. warning:: + This API is experimental. + + See :py:meth:`start_activity` for parameter and return details. """ - return self._first_execution_run_id + return await self.start_activity( + cast(Any, activity), + arg, + args=args, + id=id, + task_queue=task_queue, + result_type=result_type, + schedule_to_close_timeout=schedule_to_close_timeout, + schedule_to_start_timeout=schedule_to_start_timeout, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + search_attributes=search_attributes, + summary=summary, + priority=priority, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) - async def result( + # async no-param + @overload + async def execute_activity_method( self, + activity: MethodAsyncNoParam[SelfType, ReturnType], *, - follow_runs: bool = True, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> ReturnType: - """Wait for result of the workflow. + ) -> ReturnType: ... - This will use :py:attr:`result_run_id` if present to base the result on. - To use another run ID, a new handle must be created via - :py:meth:`Client.get_workflow_handle`. + # async single-param + @overload + async def execute_activity_method( + self, + activity: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... - Args: - follow_runs: If true (default), workflow runs will be continually - fetched, until the most recent one is found. If false, return - the result from the first run targeted by the request if that run - ends in a result, otherwise raise an exception. - rpc_metadata: Headers used on the RPC call. Keys here override - client-level RPC metadata keys. - rpc_timeout: Optional RPC deadline to set for each RPC call. 
Note, - this is the timeout for each history RPC call not this overall - function. + # async multi-param + @overload + async def execute_activity_method( + self, + activity: Callable[ + Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType] + ], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... - Returns: - Result of the workflow after being converted by the data converter. + # sync multi-param + @overload + async def execute_activity_method( + self, + activity: Callable[Concatenate[SelfType, MultiParamSpec], ReturnType], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... - Raises: - WorkflowFailureError: Workflow failed, was cancelled, was - terminated, or timed out. Use the - :py:attr:`WorkflowFailureError.cause` to see the underlying - reason. - Exception: Other possible failures during result fetching. 
- """ - # We have to maintain our own run ID because it can change if we follow - # executions - hist_run_id = self._result_run_id - while True: - async for event in self._fetch_history_events_for_run( - hist_run_id, - wait_new_event=True, - event_filter_type=WorkflowHistoryEventFilterType.CLOSE_EVENT, - skip_archival=True, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - ): - if event.HasField("workflow_execution_completed_event_attributes"): - complete_attr = event.workflow_execution_completed_event_attributes - # Follow execution - if follow_runs and complete_attr.new_execution_run_id: - hist_run_id = complete_attr.new_execution_run_id - break - # Ignoring anything after the first response like TypeScript - type_hints = [self._result_type] if self._result_type else None - results = await self._data_converter.decode_wrapper( - complete_attr.result, - type_hints, - ) - if not results: - return cast(ReturnType, None) - elif len(results) > 1: - warnings.warn(f"Expected single result, got {len(results)}") - return cast(ReturnType, results[0]) - elif event.HasField("workflow_execution_failed_event_attributes"): - fail_attr = event.workflow_execution_failed_event_attributes - # Follow execution - if follow_runs and fail_attr.new_execution_run_id: - hist_run_id = fail_attr.new_execution_run_id - break - raise WorkflowFailureError( - cause=await self._data_converter.decode_failure( - fail_attr.failure - ), - ) - elif event.HasField("workflow_execution_canceled_event_attributes"): - cancel_attr = event.workflow_execution_canceled_event_attributes - raise WorkflowFailureError( - cause=temporalio.exceptions.CancelledError( - "Workflow cancelled", - *( - await self._data_converter.decode_wrapper( - cancel_attr.details - ) - ), - ) - ) - elif event.HasField("workflow_execution_terminated_event_attributes"): - term_attr = event.workflow_execution_terminated_event_attributes - raise WorkflowFailureError( - cause=temporalio.exceptions.TerminatedError( - term_attr.reason or "Workflow terminated", - *( - await self._data_converter.decode_wrapper( - term_attr.details - ) - ), - ), - ) - elif event.HasField("workflow_execution_timed_out_event_attributes"): - time_attr = event.workflow_execution_timed_out_event_attributes - # Follow execution - if follow_runs and time_attr.new_execution_run_id: - hist_run_id = time_attr.new_execution_run_id - break - raise WorkflowFailureError( - cause=temporalio.exceptions.TimeoutError( - "Workflow timed out", - type=temporalio.exceptions.TimeoutType.START_TO_CLOSE, - last_heartbeat_details=[], - ), - ) - elif event.HasField( - "workflow_execution_continued_as_new_event_attributes" - ): - cont_attr = ( - event.workflow_execution_continued_as_new_event_attributes - ) - if not cont_attr.new_execution_run_id: - raise RuntimeError( - "Unexpectedly missing new run ID from continue as new" - ) - # Follow execution - if follow_runs: - hist_run_id = cont_attr.new_execution_run_id - break - raise WorkflowContinuedAsNewError(cont_attr.new_execution_run_id) - # This is reached on break which means that there's a different run - # ID if we're following. If there's not, it's an error because no - # event was given (should never happen). 
- if hist_run_id is None:
- raise RuntimeError("No completion event found")
-
- async def cancel(
+ async def execute_activity_method(
 self,
+ activity: Callable,
+ arg: Any = temporalio.common._arg_unset,
 *,
+ args: Sequence[Any] = [],
+ id: str,
+ task_queue: str,
+ result_type: type | None = None,
+ schedule_to_close_timeout: timedelta | None = None,
+ schedule_to_start_timeout: timedelta | None = None,
+ start_to_close_timeout: timedelta | None = None,
+ heartbeat_timeout: timedelta | None = None,
+ id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE,
+ id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL,
+ retry_policy: temporalio.common.RetryPolicy | None = None,
+ search_attributes: temporalio.common.TypedSearchAttributes | None = None,
+ summary: str | None = None,
+ priority: temporalio.common.Priority = temporalio.common.Priority.default,
 rpc_metadata: Mapping[str, str | bytes] = {},
 rpc_timeout: timedelta | None = None,
- ) -> None:
- """Cancel the workflow.
+ ) -> Any:
+ """Start an activity from a method and wait for completion.
- This will issue a cancellation for :py:attr:`run_id` if present. This
- call will make sure to use the run chain starting from
- :py:attr:`first_execution_run_id` if present. To create handles with
- these values, use :py:meth:`Client.get_workflow_handle`.
+ .. warning::
+ This API is experimental.
+
+ This is a shortcut for ``await (await client.start_activity_method(...)).result()``.
+ """
+ return await self.execute_activity(
+ cast(Any, activity),
+ arg,
+ args=args,
+ id=id,
+ task_queue=task_queue,
+ result_type=result_type,
+ schedule_to_close_timeout=schedule_to_close_timeout,
+ schedule_to_start_timeout=schedule_to_start_timeout,
+ start_to_close_timeout=start_to_close_timeout,
+ heartbeat_timeout=heartbeat_timeout,
+ id_reuse_policy=id_reuse_policy,
+ id_conflict_policy=id_conflict_policy,
+ retry_policy=retry_policy,
+ search_attributes=search_attributes,
+ summary=summary,
+ priority=priority,
+ rpc_metadata=rpc_metadata,
+ rpc_timeout=rpc_timeout,
+ )
+
+ def list_activities(
+ self,
+ query: str,
+ *,
+ limit: int | None = None,
+ page_size: int = 1000,
+ next_page_token: bytes | None = None,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> ActivityExecutionAsyncIterator:
+ """List activities not started by a workflow.

 .. warning::
- Handles created as a result of :py:meth:`Client.start_workflow` with
- a start signal will cancel the latest workflow with the same
- workflow ID even if it is unrelated to the started workflow.
+ This API is experimental.
+
+ This does not make a request until the first iteration is attempted.
+ Therefore any errors will not occur until then.

 Args:
- rpc_metadata: Headers used on the RPC call. Keys here override
+ query: A Temporal visibility list filter for activities. Required.
+ limit: Maximum number of activities to return. If unset, all
+ activities are returned. Only applies if using the
+ returned :py:class:`ActivityExecutionAsyncIterator`
+ as an async iterator.
+ page_size: Maximum number of results for each page.
+ next_page_token: A previously obtained next page token if doing
+ pagination. Usually not needed as the iterator automatically
+ starts from the beginning.
+ rpc_metadata: Headers used on each RPC call. Keys here override
 client-level RPC metadata keys.
- rpc_timeout: Optional RPC deadline to set for the RPC call.
+ rpc_timeout: Optional RPC deadline to set for each RPC call. - Raises: - RPCError: Workflow could not be cancelled. + Returns: + An async iterator that can be used with ``async for``. """ - await self._client._impl.cancel_workflow( - CancelWorkflowInput( - id=self._id, - run_id=self._run_id, - first_execution_run_id=self._first_execution_run_id, + return self._impl.list_activities( + ListActivitiesInput( + query=query, + page_size=page_size, + next_page_token=next_page_token, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, + limit=limit, ) ) - async def describe( + async def count_activities( self, + query: str | None = None, *, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> WorkflowExecutionDescription: - """Get workflow details. - - This will get details for :py:attr:`run_id` if present. To use a - different run ID, create a new handle with via - :py:meth:`Client.get_workflow_handle`. + ) -> ActivityExecutionCount: + """Count activities not started by a workflow. .. warning:: - Handles created as a result of :py:meth:`Client.start_workflow` will - describe the latest workflow with the same workflow ID even if it is - unrelated to the started workflow. + This API is experimental. Args: + query: A Temporal visibility filter for activities. rpc_metadata: Headers used on the RPC call. Keys here override client-level RPC metadata keys. rpc_timeout: Optional RPC deadline to set for the RPC call. Returns: - Workflow details. - - Raises: - RPCError: Workflow details could not be fetched. + Count of activities. """ - return await self._client._impl.describe_workflow( - DescribeWorkflowInput( - id=self._id, - run_id=self._run_id, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, + return await self._impl.count_activities( + CountActivitiesInput( + query=query, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout ) ) - async def fetch_history( + @overload + def get_activity_handle( self, + activity_id: str, *, - event_filter_type: WorkflowHistoryEventFilterType = WorkflowHistoryEventFilterType.ALL_EVENT, - skip_archival: bool = False, - rpc_metadata: Mapping[str, str | bytes] = {}, - rpc_timeout: timedelta | None = None, - ) -> WorkflowHistory: - """Get workflow history. + run_id: str | None = None, + ) -> ActivityHandle[Any]: ... - This is a shortcut for :py:meth:`fetch_history_events` that just fetches - all events. - """ - return WorkflowHistory( - workflow_id=self.id, - events=[ - v - async for v in self.fetch_history_events( - event_filter_type=event_filter_type, - skip_archival=skip_archival, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - ) - ], - ) + @overload + def get_activity_handle( + self, + activity_id: str, + *, + run_id: str | None = None, + result_type: type[ReturnType], + ) -> ActivityHandle[ReturnType]: ... - def fetch_history_events( + def get_activity_handle( self, + activity_id: str, *, - page_size: int | None = None, - next_page_token: bytes | None = None, - wait_new_event: bool = False, - event_filter_type: WorkflowHistoryEventFilterType = WorkflowHistoryEventFilterType.ALL_EVENT, - skip_archival: bool = False, - rpc_metadata: Mapping[str, str | bytes] = {}, - rpc_timeout: timedelta | None = None, - ) -> WorkflowHistoryEventAsyncIterator: - """Get workflow history events as an async iterator. + run_id: str | None = None, + result_type: type | None = None, + ) -> ActivityHandle[Any]: + """Get a handle to an existing activity, as the caller of that activity. 
- This does not make a request until the first iteration is attempted. - Therefore any errors will not occur until then. + The activity must not have been started by a workflow. + + .. warning:: + This API is experimental. + + To get a handle to an activity execution that you control for manual completion and + heartbeating, see :py:meth:`Client.get_async_activity_handle`. Args: - page_size: Maximum amount to fetch per request if any maximum. - next_page_token: A specific page token to fetch. - wait_new_event: Whether the event fetching request will wait for new - events or just return right away. - event_filter_type: Which events to obtain. - skip_archival: Whether to skip archival. - rpc_metadata: Headers used on each RPC call. Keys here override - client-level RPC metadata keys. - rpc_timeout: Optional RPC deadline to set for each RPC call. + activity_id: The activity ID. + run_id: The activity run ID. If not provided, targets the latest run. + result_type: The result type to deserialize into. Returns: - An async iterator that doesn't begin fetching until iterated on. + A handle to the activity. """ - return self._fetch_history_events_for_run( - self._run_id, - page_size=page_size, - next_page_token=next_page_token, - wait_new_event=wait_new_event, - event_filter_type=event_filter_type, - skip_archival=skip_archival, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, + return ActivityHandle( + self, + activity_id, + run_id=run_id, + result_type=result_type, ) - def _fetch_history_events_for_run( - self, - run_id: str | None, - *, - page_size: int | None = None, - next_page_token: bytes | None = None, - wait_new_event: bool = False, - event_filter_type: WorkflowHistoryEventFilterType = WorkflowHistoryEventFilterType.ALL_EVENT, - skip_archival: bool = False, - rpc_metadata: Mapping[str, str | bytes] = {}, - rpc_timeout: timedelta | None = None, - ) -> WorkflowHistoryEventAsyncIterator: - return self._client._impl.fetch_workflow_history_events( - FetchWorkflowHistoryEventsInput( - id=self._id, - run_id=run_id, - page_size=page_size, - next_page_token=next_page_token, - wait_new_event=wait_new_event, - event_filter_type=event_filter_type, - skip_archival=skip_archival, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - ) - ) - - # Overload for no-param query @overload - async def query( - self, - query: MethodSyncOrAsyncNoParam[SelfType, LocalReturnType], - *, - reject_condition: temporalio.common.QueryRejectCondition | None = None, - rpc_metadata: Mapping[str, str | bytes] = {}, - rpc_timeout: timedelta | None = None, - ) -> LocalReturnType: ... + def get_async_activity_handle( + self, *, activity_id: str, run_id: str | None = None + ) -> AsyncActivityHandle: + pass - # Overload for single-param query @overload - async def query( - self, - query: MethodSyncOrAsyncSingleParam[SelfType, ParamType, LocalReturnType], - arg: ParamType, - *, - reject_condition: temporalio.common.QueryRejectCondition | None = None, - rpc_metadata: Mapping[str, str | bytes] = {}, - rpc_timeout: timedelta | None = None, - ) -> LocalReturnType: ... 
+ def get_async_activity_handle(
+ self, *, workflow_id: str, run_id: str | None, activity_id: str
+ ) -> AsyncActivityHandle:
+ pass

- # Overload for multi-param query
 @overload
- async def query(
- self,
- query: Callable[
- Concatenate[SelfType, MultiParamSpec],
- Awaitable[LocalReturnType] | LocalReturnType,
- ],
- *,
- args: Sequence[Any],
- reject_condition: temporalio.common.QueryRejectCondition | None = None,
- rpc_metadata: Mapping[str, str | bytes] = {},
- rpc_timeout: timedelta | None = None,
- ) -> LocalReturnType: ...
+ def get_async_activity_handle(self, *, task_token: bytes) -> AsyncActivityHandle:
+ pass

- # Overload for string-name query
- @overload
- async def query(
+ def get_async_activity_handle(
 self,
- query: str,
- arg: Any = temporalio.common._arg_unset,
 *,
- args: Sequence[Any] = [],
- result_type: type | None = None,
- reject_condition: temporalio.common.QueryRejectCondition | None = None,
- rpc_metadata: Mapping[str, str | bytes] = {},
- rpc_timeout: timedelta | None = None,
- ) -> Any: ...
+ workflow_id: str | None = None,
+ run_id: str | None = None,
+ activity_id: str | None = None,
+ task_token: bytes | None = None,
+ ) -> AsyncActivityHandle:
+ """Get a handle to an activity execution that you control, for manual
+ completion and heartbeating.

- async def query(
+ To get a handle to an activity execution as the caller of that activity,
+ see :py:meth:`Client.get_activity_handle`.
+
+ This method may be used to get a handle to an activity started by a
+ client, or an activity started by a workflow.
+
+ To get a handle to an activity started by a workflow, use one of the
+ following two calls:
+
+ - Supply ``workflow_id``, ``run_id``, and ``activity_id``.
+ - Supply the activity ``task_token`` alone.
+
+ To get a handle to an activity not started by a workflow, supply
+ ``activity_id`` and ``run_id``.
+
+ Args:
+ workflow_id: Workflow ID for the activity, or None if not a workflow
+ activity. Cannot be set if task_token is set.
+ run_id: Run ID for the activity or workflow. Cannot be set if
+ task_token is set.
+ activity_id: ID for the activity. Cannot be set if task_token is
+ set.
+ task_token: Task token for the activity. Cannot be set with other
+ fields.
+
+ Returns:
+ A handle that can be used for completion or heartbeating.
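+
+ For example, a minimal sketch of manual completion, where
+ ``captured_token`` is a hypothetical task token previously captured
+ inside the running activity via :py:func:`temporalio.activity.info`:
+
+ .. code-block:: python
+
+ # Target the activity by its task token, then heartbeat and complete
+ handle = client.get_async_activity_handle(task_token=captured_token)
+ await handle.heartbeat("making progress")
+ await handle.complete("result-value")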
+ """ + if task_token is not None: + if workflow_id is not None or run_id is not None or activity_id is not None: + raise ValueError("Task token cannot be present with other IDs") + return AsyncActivityHandle(self, task_token) + elif workflow_id is not None: + if activity_id is None: + raise ValueError( + "Workflow ID, run ID, and activity ID must all be given together" + ) + return AsyncActivityHandle( + self, + AsyncActivityIDReference( + workflow_id=workflow_id, run_id=run_id, activity_id=activity_id + ), + ) + elif activity_id is not None: + return AsyncActivityHandle( + self, + AsyncActivityIDReference( + activity_id=activity_id, + run_id=run_id, + workflow_id=None, + ), + ) + raise ValueError( + "Require task token, or workflow_id & run_id & activity_id, or activity_id & run_id" + ) + + async def create_schedule( self, - query: str | Callable, - arg: Any = temporalio.common._arg_unset, + id: str, + schedule: Schedule, *, - args: Sequence[Any] = [], - result_type: type | None = None, - reject_condition: temporalio.common.QueryRejectCondition | None = None, + trigger_immediately: bool = False, + backfill: Sequence[ScheduleBackfill] = [], + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> Any: - """Query the workflow. - - This will query for :py:attr:`run_id` if present. To use a different - run ID, create a new handle with - :py:meth:`Client.get_workflow_handle`. - - .. warning:: - Handles created as a result of :py:meth:`Client.start_workflow` will - query the latest workflow with the same workflow ID even if it is - unrelated to the started workflow. + ) -> ScheduleHandle: + """Create a schedule and return its handle. Args: - query: Query function or name on the workflow. - arg: Single argument to the query. - args: Multiple arguments to the query. Cannot be set if arg is. - result_type: For string queries, this can set the specific result - type hint to deserialize into. - reject_condition: Condition for rejecting the query. If unset/None, - defaults to the client's default (which is defaulted to None). + id: Unique identifier of the schedule. + schedule: Schedule to create. + trigger_immediately: If true, trigger one action immediately when + creating the schedule. + backfill: Set of time periods to take actions on as if that time + passed right now. + memo: Memo for the schedule. Memo for a scheduled workflow is part + of the schedule action. + search_attributes: Search attributes for the schedule. Search + attributes for a scheduled workflow are part of the scheduled + action. The dictionary form of this is DEPRECATED, use + :py:class:`temporalio.common.TypedSearchAttributes`. rpc_metadata: Headers used on the RPC call. Keys here override client-level RPC metadata keys. rpc_timeout: Optional RPC deadline to set for the RPC call. Returns: - Result of the query. + A handle to the created schedule. Raises: - WorkflowQueryRejectedError: A query reject condition was satisfied. - RPCError: Workflow details could not be fetched. + ScheduleAlreadyRunningError: If a schedule with this ID is already + running. """ - query_name: str - ret_type = result_type - if callable(query): - defn = temporalio.workflow._QueryDefinition.from_fn(query) - if not defn: - raise RuntimeError( - f"Query definition not found on {query.__qualname__}, " - "is it decorated with @workflow.query?" 
- ) - elif not defn.name: - raise RuntimeError("Cannot invoke dynamic query definition") - # TODO(cretz): Check count/type of args at runtime? - query_name = defn.name - ret_type = defn.ret_type - else: - query_name = str(query) - - return await self._client._impl.query_workflow( - QueryWorkflowInput( - id=self._id, - run_id=self._run_id, - query=query_name, - args=temporalio.common._arg_or_args(arg, args), - reject_condition=reject_condition - or self._client._config["default_workflow_query_reject_condition"], - headers={}, - ret_type=ret_type, + temporalio.common._warn_on_deprecated_search_attributes(search_attributes) + return await self._impl.create_schedule( + CreateScheduleInput( + id=id, + schedule=schedule, + trigger_immediately=trigger_immediately, + backfill=backfill, + memo=memo, + search_attributes=search_attributes, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, ) ) - # Overload for no-param signal - @overload - async def signal( - self, - signal: MethodSyncOrAsyncNoParam[SelfType, None], - *, - rpc_metadata: Mapping[str, str | bytes] = {}, - rpc_timeout: timedelta | None = None, - ) -> None: ... + def get_schedule_handle(self, id: str) -> ScheduleHandle: + """Get a schedule handle for the given ID.""" + return ScheduleHandle(self, id) - # Overload for single-param signal - @overload - async def signal( + async def list_schedules( self, - signal: MethodSyncOrAsyncSingleParam[SelfType, ParamType, None], - arg: ParamType, + query: str | None = None, *, + page_size: int = 1000, + next_page_token: bytes | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> None: ... + ) -> ScheduleAsyncIterator: + """List schedules. - # Overload for multi-param signal - @overload - async def signal( - self, - signal: Callable[Concatenate[SelfType, MultiParamSpec], Awaitable[None] | None], - *, - args: Sequence[Any], - rpc_metadata: Mapping[str, str | bytes] = {}, - rpc_timeout: timedelta | None = None, - ) -> None: ... + This does not make a request until the first iteration is attempted. + Therefore any errors will not occur until then. - # Overload for string-name signal - @overload - async def signal( - self, - signal: str, - arg: Any = temporalio.common._arg_unset, - *, - args: Sequence[Any] = [], - rpc_metadata: Mapping[str, str | bytes] = {}, - rpc_timeout: timedelta | None = None, - ) -> None: ... - - async def signal( - self, - signal: str | Callable, - arg: Any = temporalio.common._arg_unset, - *, - args: Sequence[Any] = [], - rpc_metadata: Mapping[str, str | bytes] = {}, - rpc_timeout: timedelta | None = None, - ) -> None: - """Send a signal to the workflow. - - This will signal for :py:attr:`run_id` if present. To use a different - run ID, create a new handle with via - :py:meth:`Client.get_workflow_handle`. - - .. warning:: - Handles created as a result of :py:meth:`Client.start_workflow` will - signal the latest workflow with the same workflow ID even if it is - unrelated to the started workflow. + Note, this list is eventually consistent. Therefore if a schedule is + added or deleted, it may not be available in the list immediately. Args: - signal: Signal function or name on the workflow. - arg: Single argument to the signal. - args: Multiple arguments to the signal. Cannot be set if arg is. - rpc_metadata: Headers used on the RPC call. Keys here override + page_size: Maximum number of results for each page. + query: A Temporal visibility list filter. 
See Temporal documentation + concerning visibility list filters including behavior when left + unset. + next_page_token: A previously obtained next page token if doing + pagination. Usually not needed as the iterator automatically + starts from the beginning. + rpc_metadata: Headers used on each RPC call. Keys here override client-level RPC metadata keys. - rpc_timeout: Optional RPC deadline to set for the RPC call. + rpc_timeout: Optional RPC deadline to set for each RPC call. - Raises: - RPCError: Workflow could not be signalled. + Returns: + An async iterator that can be used with ``async for``. """ - await self._client._impl.signal_workflow( - SignalWorkflowInput( - id=self._id, - run_id=self._run_id, - signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str( - signal - ), - args=temporalio.common._arg_or_args(arg, args), - headers={}, + return self._impl.list_schedules( + ListSchedulesInput( + page_size=page_size, + next_page_token=next_page_token, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, + query=query, ) ) - async def terminate( + async def update_worker_build_id_compatibility( self, - *args: Any, - reason: str | None = None, + task_queue: str, + operation: BuildIdOp, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, ) -> None: - """Terminate the workflow. + """Used to add new Build IDs or otherwise update the relative compatibility of Build Ids as + defined on a specific task queue for the Worker Versioning feature. - This will issue a termination for :py:attr:`run_id` if present. This - call will make sure to use the run chain starting from - :py:attr:`first_execution_run_id` if present. To create handles with - these values, use :py:meth:`Client.get_workflow_handle`. + For more on this feature, see https://docs.temporal.io/workers#worker-versioning .. warning:: - Handles created as a result of :py:meth:`Client.start_workflow` with - a start signal will terminate the latest workflow with the same - workflow ID even if it is unrelated to the started workflow. + This API is experimental Args: - args: Details to store on the termination. - reason: Reason for the termination. - rpc_metadata: Headers used on the RPC call. Keys here override + task_queue: The task queue to target. + operation: The operation to perform. + rpc_metadata: Headers used on each RPC call. Keys here override client-level RPC metadata keys. - rpc_timeout: Optional RPC deadline to set for the RPC call. - - Raises: - RPCError: Workflow could not be terminated. + rpc_timeout: Optional RPC deadline to set for each RPC call. """ - await self._client._impl.terminate_workflow( - TerminateWorkflowInput( - id=self._id, - run_id=self._run_id, - args=args, - reason=reason, - first_execution_run_id=self._first_execution_run_id, + return await self._impl.update_worker_build_id_compatibility( + UpdateWorkerBuildIdCompatibilityInput( + task_queue, + operation, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, ) ) - # Overload for no-param update - @overload - async def execute_update( + async def get_worker_build_id_compatibility( self, - update: temporalio.workflow.UpdateMethodMultiParam[[SelfType], LocalReturnType], - *, - id: str | None = None, + task_queue: str, + max_sets: int | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> LocalReturnType: ... + ) -> WorkerBuildIdVersionSets: + """Get the Build ID compatibility sets for a specific task queue. 
- # Overload for single-param update - @overload - async def execute_update( - self, - update: temporalio.workflow.UpdateMethodMultiParam[ - [SelfType, ParamType], LocalReturnType - ], - arg: ParamType, - *, - id: str | None = None, - rpc_metadata: Mapping[str, str | bytes] = {}, - rpc_timeout: timedelta | None = None, - ) -> LocalReturnType: ... + For more on this feature, see https://docs.temporal.io/workers#worker-versioning - # Overload for multi-param update - @overload - async def execute_update( - self, - update: temporalio.workflow.UpdateMethodMultiParam[ - MultiParamSpec, LocalReturnType - ], - *, - args: MultiParamSpec.args, # type: ignore - id: str | None = None, - rpc_metadata: Mapping[str, str | bytes] = {}, - rpc_timeout: timedelta | None = None, - ) -> LocalReturnType: ... + .. warning:: + This API is experimental - # Overload for string-name update - @overload - async def execute_update( - self, - update: str, - arg: Any = temporalio.common._arg_unset, - *, - args: Sequence[Any] = [], - id: str | None = None, - result_type: type | None = None, - rpc_metadata: Mapping[str, str | bytes] = {}, - rpc_timeout: timedelta | None = None, - ) -> Any: ... + Args: + task_queue: The task queue to target. + max_sets: The maximum number of sets to return. If not specified, all sets will be + returned. + rpc_metadata: Headers used on each RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for each RPC call. + """ + return await self._impl.get_worker_build_id_compatibility( + GetWorkerBuildIdCompatibilityInput( + task_queue, + max_sets, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + ) - async def execute_update( + async def get_worker_task_reachability( self, - update: str | Callable, - arg: Any = temporalio.common._arg_unset, - *, - args: Sequence[Any] = [], - id: str | None = None, - result_type: type | None = None, + build_ids: Sequence[str], + task_queues: Sequence[str] = [], + reachability_type: TaskReachabilityType | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> Any: - """Send an update request to the workflow and wait for it to complete. + ) -> WorkerTaskReachability: + """Determine if some Build IDs for certain Task Queues could have tasks dispatched to them. - This will target the workflow with :py:attr:`run_id` if present. To use a - different run ID, create a new handle with via :py:meth:`Client.get_workflow_handle`. + For more on this feature, see https://docs.temporal.io/workers#worker-versioning + + .. warning:: + This API is experimental Args: - update: Update function or name on the workflow. - arg: Single argument to the update. - args: Multiple arguments to the update. Cannot be set if arg is. - id: ID of the update. If not set, the default is a new UUID. - result_type: For string updates, this can set the specific result - type hint to deserialize into. - rpc_metadata: Headers used on the RPC call. Keys here override + build_ids: The Build IDs to query the reachability of. At least one must be specified. + task_queues: Task Queues to restrict the query to. If not specified, all Task Queues + will be searched. When requesting a large number of task queues or all task queues + associated with the given Build IDs in a namespace, all Task Queues will be listed + in the response but some of them may not contain reachability information due to a + server enforced limit. 
When reaching the limit, task queues that reachability
+ information could not be retrieved for will be marked with a ``NotFetched`` entry in
+ :py:attr:`BuildIdReachability.task_queue_reachability`. The caller may issue another call
+ to get the reachability for those task queues.
+ reachability_type: The kind of reachability this request is concerned with.
+ rpc_metadata: Headers used on each RPC call. Keys here override
 client-level RPC metadata keys.
- rpc_timeout: Optional RPC deadline to set for the RPC call.
-
- Raises:
- WorkflowUpdateFailedError: If the update failed.
- WorkflowUpdateRPCTimeoutOrCancelledError: This update call timed out
- or was cancelled. This doesn't mean the update itself was timed
- out or cancelled.
- RPCError: There was some issue sending the update to the workflow.
+ rpc_timeout: Optional RPC deadline to set for each RPC call.
 """
- handle = await self._start_update(
- update,
- arg,
- args=args,
- wait_for_stage=WorkflowUpdateStage.COMPLETED,
- id=id,
- result_type=result_type,
- rpc_metadata=rpc_metadata,
- rpc_timeout=rpc_timeout,
+ return await self._impl.get_worker_task_reachability(
+ GetWorkerTaskReachabilityInput(
+ build_ids,
+ task_queues,
+ reachability_type,
+ rpc_metadata=rpc_metadata,
+ rpc_timeout=rpc_timeout,
+ )
 )
- return await handle.result()

- # Overload for no-param start update
- @overload
- async def start_update(
- self,
- update: temporalio.workflow.UpdateMethodMultiParam[[SelfType], LocalReturnType],
- *,
- wait_for_stage: WorkflowUpdateStage,
- id: str | None = None,
- rpc_metadata: Mapping[str, str | bytes] = {},
- rpc_timeout: timedelta | None = None,
- ) -> WorkflowUpdateHandle[LocalReturnType]: ...

- # Overload for single-param start update
- @overload
- async def start_update(
- self,
- update: temporalio.workflow.UpdateMethodMultiParam[
- [SelfType, ParamType], LocalReturnType
- ],
- arg: ParamType,
- *,
- wait_for_stage: WorkflowUpdateStage,
- id: str | None = None,
- rpc_metadata: Mapping[str, str | bytes] = {},
- rpc_timeout: timedelta | None = None,
- ) -> WorkflowUpdateHandle[LocalReturnType]: ...

- # Overload for multi-param start update
- @overload
- async def start_update(
- self,
- update: temporalio.workflow.UpdateMethodMultiParam[
- MultiParamSpec, LocalReturnType
- ],
- *,
- args: MultiParamSpec.args,  # type: ignore
- wait_for_stage: WorkflowUpdateStage,
- id: str | None = None,
- rpc_metadata: Mapping[str, str | bytes] = {},
- rpc_timeout: timedelta | None = None,
- ) -> WorkflowUpdateHandle[LocalReturnType]: ...

- # Overload for string-name start update
- @overload
- async def start_update(
- self,
- update: str,
- arg: Any = temporalio.common._arg_unset,
- *,
- wait_for_stage: WorkflowUpdateStage,
- args: Sequence[Any] = [],
- id: str | None = None,
- result_type: type | None = None,
- rpc_metadata: Mapping[str, str | bytes] = {},
- rpc_timeout: timedelta | None = None,
- ) -> WorkflowUpdateHandle[Any]: ...
- async def start_update( - self, - update: str | Callable, - arg: Any = temporalio.common._arg_unset, - *, - wait_for_stage: WorkflowUpdateStage, - args: Sequence[Any] = [], - id: str | None = None, - result_type: type | None = None, - rpc_metadata: Mapping[str, str | bytes] = {}, - rpc_timeout: timedelta | None = None, - ) -> WorkflowUpdateHandle[Any]: - """Send an update request to the workflow and return a handle to it. +class WorkflowHistoryEventFilterType(IntEnum): + """Type of history events to get for a workflow. - This will target the workflow with :py:attr:`run_id` if present. To use a - different run ID, create a new handle with via :py:meth:`Client.get_workflow_handle`. + See :py:class:`temporalio.api.enums.v1.HistoryEventFilterType`. + """ - Args: - update: Update function or name on the workflow. arg: Single argument to the - update. - wait_for_stage: Required stage to wait until returning: either ACCEPTED or - COMPLETED. ADMITTED is not currently supported. See - https://docs.temporal.io/workflows#update for more details. - args: Multiple arguments to the update. Cannot be set if arg is. - id: ID of the update. If not set, the default is a new UUID. - result_type: For string updates, this can set the specific result - type hint to deserialize into. - rpc_metadata: Headers used on the RPC call. Keys here override - client-level RPC metadata keys. - rpc_timeout: Optional RPC deadline to set for the RPC call. + ALL_EVENT = int( + temporalio.api.enums.v1.HistoryEventFilterType.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT + ) + CLOSE_EVENT = int( + temporalio.api.enums.v1.HistoryEventFilterType.HISTORY_EVENT_FILTER_TYPE_CLOSE_EVENT + ) - Raises: - WorkflowUpdateRPCTimeoutOrCancelledError: This update call timed out - or was cancelled. This doesn't mean the update itself was timed out or - cancelled. - RPCError: There was some issue sending the update to the workflow. - """ - return await self._start_update( - update, - arg, - wait_for_stage=wait_for_stage, - args=args, - id=id, - result_type=result_type, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - ) - async def _start_update( +class WorkflowHandle(Generic[SelfType, ReturnType]): + """Handle for interacting with a workflow. + + This is usually created via :py:meth:`Client.get_workflow_handle` or + returned from :py:meth:`Client.start_workflow`. 
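+
+ For example, a minimal usage sketch (the workflow ID here is
+ illustrative only):
+
+ .. code-block:: python
+
+ # Obtain a handle by workflow ID, then wait for the workflow result
+ handle = client.get_workflow_handle("my-workflow-id")
+ result = await handle.result()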
+ """ + + def __init__( self, - update: str | Callable, - arg: Any = temporalio.common._arg_unset, + client: Client, + id: str, *, - wait_for_stage: WorkflowUpdateStage, - args: Sequence[Any] = [], - id: str | None = None, + run_id: str | None = None, + result_run_id: str | None = None, + first_execution_run_id: str | None = None, result_type: type | None = None, - rpc_metadata: Mapping[str, str | bytes] = {}, - rpc_timeout: timedelta | None = None, - ) -> WorkflowUpdateHandle[Any]: - if wait_for_stage == WorkflowUpdateStage.ADMITTED: - raise ValueError("ADMITTED wait stage not supported") - - update_name, result_type_from_type_hint = ( - temporalio.workflow._UpdateDefinition.get_name_and_result_type(update) - ) + start_workflow_response: None + | ( + temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse + | temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse + ) = None, + ) -> None: + """Create workflow handle.""" + self._client = client + self._id = id + self._run_id = run_id + self._result_run_id = result_run_id + self._first_execution_run_id = first_execution_run_id + self._result_type = result_type + self._start_workflow_response = start_workflow_response + self.__temporal_eagerly_started = False - return await self._client._impl.start_workflow_update( - StartWorkflowUpdateInput( - id=self._id, - run_id=self._run_id, - first_execution_run_id=self.first_execution_run_id, - update_id=id, - update=update_name, - args=temporalio.common._arg_or_args(arg, args), - headers={}, - ret_type=result_type or result_type_from_type_hint, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - wait_for_stage=wait_for_stage, + @functools.cached_property + def _data_converter(self) -> temporalio.converter.DataConverter: + return self._client.data_converter.with_context( + temporalio.converter.WorkflowSerializationContext( + namespace=self._client.namespace, workflow_id=self._id ) ) - def get_update_handle( - self, - id: str, - *, - workflow_run_id: str | None = None, - result_type: type | None = None, - ) -> WorkflowUpdateHandle[Any]: - """Get a handle for an update. The handle can be used to wait on the - update result. + @property + def id(self) -> str: + """ID of the workflow.""" + return self._id - Users may prefer the more typesafe :py:meth:`get_update_handle_for` - which accepts an update definition. + @property + def run_id(self) -> str | None: + """If present, run ID used to ensure that requested operations apply + to this exact run. - Args: - id: Update ID to get a handle to. - workflow_run_id: Run ID to tie the handle to. If this is not set, - the :py:attr:`run_id` will be used. - result_type: The result type to deserialize into if known. + This is only created via :py:meth:`Client.get_workflow_handle`. + :py:meth:`Client.start_workflow` will not set this value. - Returns: - The update handle. + This cannot be mutated. If a different run ID is needed, + :py:meth:`Client.get_workflow_handle` must be used instead. """ - return WorkflowUpdateHandle( - self._client, - id, - self._id, - workflow_run_id=workflow_run_id or self._run_id, - result_type=result_type, - ) - - def get_update_handle_for( - self, - update: temporalio.workflow.UpdateMethodMultiParam[Any, LocalReturnType], - id: str, - *, - workflow_run_id: str | None = None, - ) -> WorkflowUpdateHandle[LocalReturnType]: - """Get a typed handle for an update. The handle can be used to wait on - the update result. + return self._run_id - This is the same as :py:meth:`get_update_handle` but typed. 
+ @property + def result_run_id(self) -> str | None: + """Run ID used for :py:meth:`result` calls if present to ensure result + is for a workflow starting from this run. - Args: - update: The update method to use for typing the handle. - id: Update ID to get a handle to. - workflow_run_id: Run ID to tie the handle to. If this is not set, - the :py:attr:`run_id` will be used. + When this handle is created via :py:meth:`Client.get_workflow_handle`, + this is the same as run_id. When this handle is created via + :py:meth:`Client.start_workflow`, this value will be the resulting run + ID. - Returns: - The update handle. + This cannot be mutated. If a different run ID is needed, + :py:meth:`Client.get_workflow_handle` must be used instead. """ - return self.get_update_handle( - id, workflow_run_id=workflow_run_id, result_type=update._defn.ret_type - ) + return self._result_run_id + @property + def first_execution_run_id(self) -> str | None: + """Run ID used to ensure requested operations apply to a workflow ID + started with this run ID. -class WithStartWorkflowOperation(Generic[SelfType, ReturnType]): - """Defines a start-workflow operation used by update-with-start requests. + This can be set when using :py:meth:`Client.get_workflow_handle`. When + :py:meth:`Client.start_workflow` is called without a start signal, this + is set to the resulting run. - Update-With-Start allows you to send an update to a workflow, while starting the - workflow if necessary. - """ + This cannot be mutated. If a different first execution run ID is needed, + :py:meth:`Client.get_workflow_handle` must be used instead. + """ + return self._first_execution_run_id - # Overload for no-param workflow, with_start - @overload - def __init__( + async def result( self, - workflow: MethodAsyncNoParam[SelfType, ReturnType], *, - id: str, - task_queue: str, - id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, - execution_timeout: timedelta | None = None, - run_timeout: timedelta | None = None, - task_timeout: timedelta | None = None, - id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, - retry_policy: temporalio.common.RetryPolicy | None = None, - cron_schedule: str = "", - memo: Mapping[str, Any] | None = None, - search_attributes: None - | ( - temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes - ) = None, - static_summary: str | None = None, - static_details: str | None = None, - start_delay: timedelta | None = None, + follow_runs: bool = True, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - priority: temporalio.common.Priority = temporalio.common.Priority.default, - versioning_override: temporalio.common.VersioningOverride | None = None, - ) -> None: ... + ) -> ReturnType: + """Wait for result of the workflow. - # Overload for single-param workflow, with_start - @overload - def __init__( + This will use :py:attr:`result_run_id` if present to base the result on. + To use another run ID, a new handle must be created via + :py:meth:`Client.get_workflow_handle`. + + Args: + follow_runs: If true (default), workflow runs will be continually + fetched, until the most recent one is found. If false, return + the result from the first run targeted by the request if that run + ends in a result, otherwise raise an exception. + rpc_metadata: Headers used on the RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for each RPC call. 
Note, + this is the timeout for each history RPC call not this overall + function. + + Returns: + Result of the workflow after being converted by the data converter. + + Raises: + WorkflowFailureError: Workflow failed, was cancelled, was + terminated, or timed out. Use the + :py:attr:`WorkflowFailureError.cause` to see the underlying + reason. + Exception: Other possible failures during result fetching. + """ + # We have to maintain our own run ID because it can change if we follow + # executions + hist_run_id = self._result_run_id + while True: + async for event in self._fetch_history_events_for_run( + hist_run_id, + wait_new_event=True, + event_filter_type=WorkflowHistoryEventFilterType.CLOSE_EVENT, + skip_archival=True, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ): + if event.HasField("workflow_execution_completed_event_attributes"): + complete_attr = event.workflow_execution_completed_event_attributes + # Follow execution + if follow_runs and complete_attr.new_execution_run_id: + hist_run_id = complete_attr.new_execution_run_id + break + # Ignoring anything after the first response like TypeScript + type_hints = [self._result_type] if self._result_type else None + results = await self._data_converter.decode_wrapper( + complete_attr.result, + type_hints, + ) + if not results: + return cast(ReturnType, None) + elif len(results) > 1: + warnings.warn(f"Expected single result, got {len(results)}") + return cast(ReturnType, results[0]) + elif event.HasField("workflow_execution_failed_event_attributes"): + fail_attr = event.workflow_execution_failed_event_attributes + # Follow execution + if follow_runs and fail_attr.new_execution_run_id: + hist_run_id = fail_attr.new_execution_run_id + break + raise WorkflowFailureError( + cause=await self._data_converter.decode_failure( + fail_attr.failure + ), + ) + elif event.HasField("workflow_execution_canceled_event_attributes"): + cancel_attr = event.workflow_execution_canceled_event_attributes + raise WorkflowFailureError( + cause=temporalio.exceptions.CancelledError( + "Workflow cancelled", + *( + await self._data_converter.decode_wrapper( + cancel_attr.details + ) + ), + ) + ) + elif event.HasField("workflow_execution_terminated_event_attributes"): + term_attr = event.workflow_execution_terminated_event_attributes + raise WorkflowFailureError( + cause=temporalio.exceptions.TerminatedError( + term_attr.reason or "Workflow terminated", + *( + await self._data_converter.decode_wrapper( + term_attr.details + ) + ), + ), + ) + elif event.HasField("workflow_execution_timed_out_event_attributes"): + time_attr = event.workflow_execution_timed_out_event_attributes + # Follow execution + if follow_runs and time_attr.new_execution_run_id: + hist_run_id = time_attr.new_execution_run_id + break + raise WorkflowFailureError( + cause=temporalio.exceptions.TimeoutError( + "Workflow timed out", + type=temporalio.exceptions.TimeoutType.START_TO_CLOSE, + last_heartbeat_details=[], + ), + ) + elif event.HasField( + "workflow_execution_continued_as_new_event_attributes" + ): + cont_attr = ( + event.workflow_execution_continued_as_new_event_attributes + ) + if not cont_attr.new_execution_run_id: + raise RuntimeError( + "Unexpectedly missing new run ID from continue as new" + ) + # Follow execution + if follow_runs: + hist_run_id = cont_attr.new_execution_run_id + break + raise WorkflowContinuedAsNewError(cont_attr.new_execution_run_id) + # This is reached on break which means that there's a different run + # ID if we're following. 
If there's not, it's an error because no + # event was given (should never happen). + if hist_run_id is None: + raise RuntimeError("No completion event found") + + async def cancel( self, - workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], - arg: ParamType, *, - id: str, - task_queue: str, - id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, - execution_timeout: timedelta | None = None, - run_timeout: timedelta | None = None, - task_timeout: timedelta | None = None, - id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, - retry_policy: temporalio.common.RetryPolicy | None = None, - cron_schedule: str = "", - memo: Mapping[str, Any] | None = None, - search_attributes: None - | ( - temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes - ) = None, - static_summary: str | None = None, - static_details: str | None = None, - start_delay: timedelta | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - priority: temporalio.common.Priority = temporalio.common.Priority.default, - versioning_override: temporalio.common.VersioningOverride | None = None, - ) -> None: ... + ) -> None: + """Cancel the workflow. - # Overload for multi-param workflow, with_start - @overload - def __init__( + This will issue a cancellation for :py:attr:`run_id` if present. This + call will make sure to use the run chain starting from + :py:attr:`first_execution_run_id` if present. To create handles with + these values, use :py:meth:`Client.get_workflow_handle`. + + .. warning:: + Handles created as a result of :py:meth:`Client.start_workflow` with + a start signal will cancel the latest workflow with the same + workflow ID even if it is unrelated to the started workflow. + + Args: + rpc_metadata: Headers used on the RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for the RPC call. + + Raises: + RPCError: Workflow could not be cancelled. + """ + await self._client._impl.cancel_workflow( + CancelWorkflowInput( + id=self._id, + run_id=self._run_id, + first_execution_run_id=self._first_execution_run_id, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + ) + + async def describe( self, - workflow: Callable[ - Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType] - ], *, - args: Sequence[Any], - id: str, - task_queue: str, - id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, - execution_timeout: timedelta | None = None, - run_timeout: timedelta | None = None, - task_timeout: timedelta | None = None, - id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, - retry_policy: temporalio.common.RetryPolicy | None = None, - cron_schedule: str = "", - memo: Mapping[str, Any] | None = None, - search_attributes: None - | ( - temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes - ) = None, - static_summary: str | None = None, - static_details: str | None = None, - start_delay: timedelta | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - priority: temporalio.common.Priority = temporalio.common.Priority.default, - versioning_override: temporalio.common.VersioningOverride | None = None, - ) -> None: ... + ) -> WorkflowExecutionDescription: + """Get workflow details. 
- # Overload for string-name workflow, with_start - @overload - def __init__( + This will get details for :py:attr:`run_id` if present. To use a + different run ID, create a new handle with via + :py:meth:`Client.get_workflow_handle`. + + .. warning:: + Handles created as a result of :py:meth:`Client.start_workflow` will + describe the latest workflow with the same workflow ID even if it is + unrelated to the started workflow. + + Args: + rpc_metadata: Headers used on the RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for the RPC call. + + Returns: + Workflow details. + + Raises: + RPCError: Workflow details could not be fetched. + """ + return await self._client._impl.describe_workflow( + DescribeWorkflowInput( + id=self._id, + run_id=self._run_id, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + ) + + async def fetch_history( self, - workflow: str, - arg: Any = temporalio.common._arg_unset, *, - args: Sequence[Any] = [], - id: str, - task_queue: str, - id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, - result_type: type | None = None, - execution_timeout: timedelta | None = None, - run_timeout: timedelta | None = None, - task_timeout: timedelta | None = None, - id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, - retry_policy: temporalio.common.RetryPolicy | None = None, - cron_schedule: str = "", - memo: Mapping[str, Any] | None = None, - search_attributes: None - | ( - temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes - ) = None, - static_summary: str | None = None, - static_details: str | None = None, - start_delay: timedelta | None = None, + event_filter_type: WorkflowHistoryEventFilterType = WorkflowHistoryEventFilterType.ALL_EVENT, + skip_archival: bool = False, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> WorkflowHistory: + """Get workflow history. + + This is a shortcut for :py:meth:`fetch_history_events` that just fetches + all events. + """ + return WorkflowHistory( + workflow_id=self.id, + events=[ + v + async for v in self.fetch_history_events( + event_filter_type=event_filter_type, + skip_archival=skip_archival, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + ], + ) + + def fetch_history_events( + self, + *, + page_size: int | None = None, + next_page_token: bytes | None = None, + wait_new_event: bool = False, + event_filter_type: WorkflowHistoryEventFilterType = WorkflowHistoryEventFilterType.ALL_EVENT, + skip_archival: bool = False, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> WorkflowHistoryEventAsyncIterator: + """Get workflow history events as an async iterator. + + This does not make a request until the first iteration is attempted. + Therefore any errors will not occur until then. + + Args: + page_size: Maximum amount to fetch per request if any maximum. + next_page_token: A specific page token to fetch. + wait_new_event: Whether the event fetching request will wait for new + events or just return right away. + event_filter_type: Which events to obtain. + skip_archival: Whether to skip archival. + rpc_metadata: Headers used on each RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for each RPC call. + + Returns: + An async iterator that doesn't begin fetching until iterated on. 
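+
+ For example, a sketch that prints the type of each event as it is
+ fetched (``WhichOneof`` is the standard protobuf accessor, and
+ ``"attributes"`` is the oneof group on
+ :py:class:`temporalio.api.history.v1.HistoryEvent`):
+
+ .. code-block:: python
+
+ # Iterate events lazily; each iteration may trigger an RPC page fetch
+ async for event in handle.fetch_history_events():
+ print(event.WhichOneof("attributes"))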
+ """ + return self._fetch_history_events_for_run( + self._run_id, + page_size=page_size, + next_page_token=next_page_token, + wait_new_event=wait_new_event, + event_filter_type=event_filter_type, + skip_archival=skip_archival, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + + def _fetch_history_events_for_run( + self, + run_id: str | None, + *, + page_size: int | None = None, + next_page_token: bytes | None = None, + wait_new_event: bool = False, + event_filter_type: WorkflowHistoryEventFilterType = WorkflowHistoryEventFilterType.ALL_EVENT, + skip_archival: bool = False, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> WorkflowHistoryEventAsyncIterator: + return self._client._impl.fetch_workflow_history_events( + FetchWorkflowHistoryEventsInput( + id=self._id, + run_id=run_id, + page_size=page_size, + next_page_token=next_page_token, + wait_new_event=wait_new_event, + event_filter_type=event_filter_type, + skip_archival=skip_archival, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + ) + + # Overload for no-param query + @overload + async def query( + self, + query: MethodSyncOrAsyncNoParam[SelfType, LocalReturnType], + *, + reject_condition: temporalio.common.QueryRejectCondition | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> LocalReturnType: ... + + # Overload for single-param query + @overload + async def query( + self, + query: MethodSyncOrAsyncSingleParam[SelfType, ParamType, LocalReturnType], + arg: ParamType, + *, + reject_condition: temporalio.common.QueryRejectCondition | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> LocalReturnType: ... + + # Overload for multi-param query + @overload + async def query( + self, + query: Callable[ + Concatenate[SelfType, MultiParamSpec], + Awaitable[LocalReturnType] | LocalReturnType, + ], + *, + args: Sequence[Any], + reject_condition: temporalio.common.QueryRejectCondition | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> LocalReturnType: ... + + # Overload for string-name query + @overload + async def query( + self, + query: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + result_type: type | None = None, + reject_condition: temporalio.common.QueryRejectCondition | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> Any: ... + + async def query( + self, + query: str | Callable, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + result_type: type | None = None, + reject_condition: temporalio.common.QueryRejectCondition | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> Any: + """Query the workflow. + + This will query for :py:attr:`run_id` if present. To use a different + run ID, create a new handle with + :py:meth:`Client.get_workflow_handle`. + + .. warning:: + Handles created as a result of :py:meth:`Client.start_workflow` will + query the latest workflow with the same workflow ID even if it is + unrelated to the started workflow. + + Args: + query: Query function or name on the workflow. + arg: Single argument to the query. + args: Multiple arguments to the query. Cannot be set if arg is. + result_type: For string queries, this can set the specific result + type hint to deserialize into. 
+ reject_condition: Condition for rejecting the query. If unset/None,
+ defaults to the client's default (which itself defaults to None).
+ rpc_metadata: Headers used on the RPC call. Keys here override
+ client-level RPC metadata keys.
+ rpc_timeout: Optional RPC deadline to set for the RPC call.
+
+ Returns:
+ Result of the query.
+
+ Raises:
+ WorkflowQueryRejectedError: A query reject condition was satisfied.
+ RPCError: Workflow could not be queried.
+ """
+ query_name: str
+ ret_type = result_type
+ if callable(query):
+ defn = temporalio.workflow._QueryDefinition.from_fn(query)
+ if not defn:
+ raise RuntimeError(
+ f"Query definition not found on {query.__qualname__}, "
+ "is it decorated with @workflow.query?"
+ )
+ elif not defn.name:
+ raise RuntimeError("Cannot invoke dynamic query definition")
+ # TODO(cretz): Check count/type of args at runtime?
+ query_name = defn.name
+ ret_type = defn.ret_type
+ else:
+ query_name = str(query)
+
+ return await self._client._impl.query_workflow(
+ QueryWorkflowInput(
+ id=self._id,
+ run_id=self._run_id,
+ query=query_name,
+ args=temporalio.common._arg_or_args(arg, args),
+ reject_condition=reject_condition
+ or self._client._config["default_workflow_query_reject_condition"],
+ headers={},
+ ret_type=ret_type,
+ rpc_metadata=rpc_metadata,
+ rpc_timeout=rpc_timeout,
+ )
+ )
+
+ # Overload for no-param signal
+ @overload
+ async def signal(
+ self,
+ signal: MethodSyncOrAsyncNoParam[SelfType, None],
+ *,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> None: ...
+
+ # Overload for single-param signal
+ @overload
+ async def signal(
+ self,
+ signal: MethodSyncOrAsyncSingleParam[SelfType, ParamType, None],
+ arg: ParamType,
+ *,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> None: ...
+
+ # Overload for multi-param signal
+ @overload
+ async def signal(
+ self,
+ signal: Callable[Concatenate[SelfType, MultiParamSpec], Awaitable[None] | None],
+ *,
+ args: Sequence[Any],
 rpc_metadata: Mapping[str, str | bytes] = {},
 rpc_timeout: timedelta | None = None,
- priority: temporalio.common.Priority = temporalio.common.Priority.default,
- versioning_override: temporalio.common.VersioningOverride | None = None,
 ) -> None: ...

- def __init__(
- self,
- workflow: str | Callable[..., Awaitable[Any]],
- arg: Any = temporalio.common._arg_unset,
- *,
- args: Sequence[Any] = [],
- id: str,
- task_queue: str,
- id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy,
- result_type: type | None = None,
- execution_timeout: timedelta | None = None,
- run_timeout: timedelta | None = None,
- task_timeout: timedelta | None = None,
- id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE,
- retry_policy: temporalio.common.RetryPolicy | None = None,
- cron_schedule: str = "",
- memo: Mapping[str, Any] | None = None,
- search_attributes: None
- | (
- temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes
- ) = None,
- static_summary: str | None = None,
- static_details: str | None = None,
- start_delay: timedelta | None = None,
- rpc_metadata: Mapping[str, str | bytes] = {},
- rpc_timeout: timedelta | None = None,
- priority: temporalio.common.Priority = temporalio.common.Priority.default,
- versioning_override: temporalio.common.VersioningOverride | None = None,
- stack_level: int = 2,
- ) -> None:
- """Create a WithStartWorkflowOperation.
+ # Overload for string-name signal
+ @overload
+ async def signal(
+ self,
+ signal: str,
+ arg: Any = temporalio.common._arg_unset,
+ *,
+ args: Sequence[Any] = [],
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> None: ...
+
+ async def signal(
+ self,
+ signal: str | Callable,
+ arg: Any = temporalio.common._arg_unset,
+ *,
+ args: Sequence[Any] = [],
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> None:
+ """Send a signal to the workflow.
+
+ This will signal for :py:attr:`run_id` if present. To use a different
+ run ID, create a new handle via
+ :py:meth:`Client.get_workflow_handle`.
+
+ .. warning::
+ Handles created as a result of :py:meth:`Client.start_workflow` will
+ signal the latest workflow with the same workflow ID even if it is
+ unrelated to the started workflow.
+
+ Args:
+ signal: Signal function or name on the workflow.
+ arg: Single argument to the signal.
+ args: Multiple arguments to the signal. Cannot be set if arg is.
+ rpc_metadata: Headers used on the RPC call. Keys here override
+ client-level RPC metadata keys.
+ rpc_timeout: Optional RPC deadline to set for the RPC call.
+
+ Raises:
+ RPCError: Workflow could not be signalled.
+ """
+ await self._client._impl.signal_workflow(
+ SignalWorkflowInput(
+ id=self._id,
+ run_id=self._run_id,
+ signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str(
+ signal
+ ),
+ args=temporalio.common._arg_or_args(arg, args),
+ headers={},
+ rpc_metadata=rpc_metadata,
+ rpc_timeout=rpc_timeout,
+ )
+ )
+
+ async def terminate(
+ self,
+ *args: Any,
+ reason: str | None = None,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> None:
+ """Terminate the workflow.
+
+ This will issue a termination for :py:attr:`run_id` if present. This
+ call will make sure to use the run chain starting from
+ :py:attr:`first_execution_run_id` if present. To create handles with
+ these values, use :py:meth:`Client.get_workflow_handle`.
+
+ .. warning::
+ Handles created as a result of :py:meth:`Client.start_workflow` with
+ a start signal will terminate the latest workflow with the same
+ workflow ID even if it is unrelated to the started workflow.
+
+ Args:
+ args: Details to store on the termination.
+ reason: Reason for the termination.
+ rpc_metadata: Headers used on the RPC call. Keys here override
+ client-level RPC metadata keys.
+ rpc_timeout: Optional RPC deadline to set for the RPC call.
+
+ Raises:
+ RPCError: Workflow could not be terminated.
+ """
+ await self._client._impl.terminate_workflow(
+ TerminateWorkflowInput(
+ id=self._id,
+ run_id=self._run_id,
+ args=args,
+ reason=reason,
+ first_execution_run_id=self._first_execution_run_id,
+ rpc_metadata=rpc_metadata,
+ rpc_timeout=rpc_timeout,
+ )
+ )
+
+ # Overload for no-param update
+ @overload
+ async def execute_update(
+ self,
+ update: temporalio.workflow.UpdateMethodMultiParam[[SelfType], LocalReturnType],
+ *,
+ id: str | None = None,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> LocalReturnType: ...
+
+ # Overload for single-param update
+ @overload
+ async def execute_update(
+ self,
+ update: temporalio.workflow.UpdateMethodMultiParam[
+ [SelfType, ParamType], LocalReturnType
+ ],
+ arg: ParamType,
+ *,
+ id: str | None = None,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> LocalReturnType: ...
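The signal and terminate handle methods above are typically driven from a plain client. A minimal usage sketch, assuming an already-connected `Client`; the workflow ID and the `submit_name` signal name are illustrative assumptions, not part of this change:

```python
# Usage sketch only: "my-workflow-id" and "submit_name" are assumed examples.
from temporalio.client import Client


async def notify_then_stop(client: Client) -> None:
    handle = client.get_workflow_handle("my-workflow-id")
    # String-name signal with a single argument
    await handle.signal("submit_name", "Temporal")
    # Terminate the run chain, recording a reason server-side
    await handle.terminate(reason="operator requested shutdown")
```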
+
+ # Overload for multi-param update
+ @overload
+ async def execute_update(
+ self,
+ update: temporalio.workflow.UpdateMethodMultiParam[
+ MultiParamSpec, LocalReturnType
+ ],
+ *,
+ args: MultiParamSpec.args,  # type: ignore
+ id: str | None = None,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> LocalReturnType: ...
+
+ # Overload for string-name update
+ @overload
+ async def execute_update(
+ self,
+ update: str,
+ arg: Any = temporalio.common._arg_unset,
+ *,
+ args: Sequence[Any] = [],
+ id: str | None = None,
+ result_type: type | None = None,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> Any: ...
+
+ async def execute_update(
+ self,
+ update: str | Callable,
+ arg: Any = temporalio.common._arg_unset,
+ *,
+ args: Sequence[Any] = [],
+ id: str | None = None,
+ result_type: type | None = None,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> Any:
+ """Send an update request to the workflow and wait for it to complete.
+
+ This will target the workflow with :py:attr:`run_id` if present. To use a
+ different run ID, create a new handle via :py:meth:`Client.get_workflow_handle`.
+
+ Args:
+ update: Update function or name on the workflow.
+ arg: Single argument to the update.
+ args: Multiple arguments to the update. Cannot be set if arg is.
+ id: ID of the update. If not set, the default is a new UUID.
+ result_type: For string updates, this can set the specific result
+ type hint to deserialize into.
+ rpc_metadata: Headers used on the RPC call. Keys here override
+ client-level RPC metadata keys.
+ rpc_timeout: Optional RPC deadline to set for the RPC call.
+
+ Raises:
+ WorkflowUpdateFailedError: If the update failed.
+ WorkflowUpdateRPCTimeoutOrCancelledError: This update call timed out
+ or was cancelled. This doesn't mean the update itself timed
+ out or was cancelled.
+ RPCError: There was some issue sending the update to the workflow.
+ """
+ handle = await self._start_update(
+ update,
+ arg,
+ args=args,
+ wait_for_stage=WorkflowUpdateStage.COMPLETED,
+ id=id,
+ result_type=result_type,
+ rpc_metadata=rpc_metadata,
+ rpc_timeout=rpc_timeout,
+ )
+ return await handle.result()
+
+ # Overload for no-param start update
+ @overload
+ async def start_update(
+ self,
+ update: temporalio.workflow.UpdateMethodMultiParam[[SelfType], LocalReturnType],
+ *,
+ wait_for_stage: WorkflowUpdateStage,
+ id: str | None = None,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> WorkflowUpdateHandle[LocalReturnType]: ...
+
+ # Overload for single-param start update
+ @overload
+ async def start_update(
+ self,
+ update: temporalio.workflow.UpdateMethodMultiParam[
+ [SelfType, ParamType], LocalReturnType
+ ],
+ arg: ParamType,
+ *,
+ wait_for_stage: WorkflowUpdateStage,
+ id: str | None = None,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> WorkflowUpdateHandle[LocalReturnType]: ...
+
+ # Overload for multi-param start update
+ @overload
+ async def start_update(
+ self,
+ update: temporalio.workflow.UpdateMethodMultiParam[
+ MultiParamSpec, LocalReturnType
+ ],
+ *,
+ args: MultiParamSpec.args,  # type: ignore
+ wait_for_stage: WorkflowUpdateStage,
+ id: str | None = None,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> WorkflowUpdateHandle[LocalReturnType]: ...
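To contrast the blocking and handle-returning update paths defined above, a brief sketch; the `approve` update name and request IDs are assumed examples:

```python
# Usage sketch only: "approve" is an assumed update name.
from temporalio.client import Client, WorkflowUpdateStage


async def send_updates(client: Client) -> None:
    handle = client.get_workflow_handle("my-workflow-id")
    # Blocks until the update completes and returns its result
    approved = await handle.execute_update("approve", "request-1")
    # Returns once the update is ACCEPTED; the result can be awaited later
    update_handle = await handle.start_update(
        "approve", "request-2", wait_for_stage=WorkflowUpdateStage.ACCEPTED
    )
    approved_later = await update_handle.result()
```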
+
+ # Overload for string-name start update
+ @overload
+ async def start_update(
+ self,
+ update: str,
+ arg: Any = temporalio.common._arg_unset,
+ *,
+ wait_for_stage: WorkflowUpdateStage,
+ args: Sequence[Any] = [],
+ id: str | None = None,
+ result_type: type | None = None,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> WorkflowUpdateHandle[Any]: ...
+
+ async def start_update(
+ self,
+ update: str | Callable,
+ arg: Any = temporalio.common._arg_unset,
+ *,
+ wait_for_stage: WorkflowUpdateStage,
+ args: Sequence[Any] = [],
+ id: str | None = None,
+ result_type: type | None = None,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> WorkflowUpdateHandle[Any]:
+ """Send an update request to the workflow and return a handle to it.
+
+ This will target the workflow with :py:attr:`run_id` if present. To use a
+ different run ID, create a new handle via :py:meth:`Client.get_workflow_handle`.
+
+ Args:
+ update: Update function or name on the workflow.
+ arg: Single argument to the update.
+ wait_for_stage: Required stage to wait until returning: either ACCEPTED or
+ COMPLETED. ADMITTED is not currently supported. See
+ https://docs.temporal.io/workflows#update for more details.
+ args: Multiple arguments to the update. Cannot be set if arg is.
+ id: ID of the update. If not set, the default is a new UUID.
+ result_type: For string updates, this can set the specific result
+ type hint to deserialize into.
+ rpc_metadata: Headers used on the RPC call. Keys here override
+ client-level RPC metadata keys.
+ rpc_timeout: Optional RPC deadline to set for the RPC call.
+
+ Raises:
+ WorkflowUpdateRPCTimeoutOrCancelledError: This update call timed out
+ or was cancelled. This doesn't mean the update itself timed out
+ or was cancelled.
+ RPCError: There was some issue sending the update to the workflow.
+ """
+ return await self._start_update(
+ update,
+ arg,
+ wait_for_stage=wait_for_stage,
+ args=args,
+ id=id,
+ result_type=result_type,
+ rpc_metadata=rpc_metadata,
+ rpc_timeout=rpc_timeout,
+ )
+
+ async def _start_update(
+ self,
+ update: str | Callable,
+ arg: Any = temporalio.common._arg_unset,
+ *,
+ wait_for_stage: WorkflowUpdateStage,
+ args: Sequence[Any] = [],
+ id: str | None = None,
+ result_type: type | None = None,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> WorkflowUpdateHandle[Any]:
+ if wait_for_stage == WorkflowUpdateStage.ADMITTED:
+ raise ValueError("ADMITTED wait stage not supported")
+
+ update_name, result_type_from_type_hint = (
+ temporalio.workflow._UpdateDefinition.get_name_and_result_type(update)
+ )
+
+ return await self._client._impl.start_workflow_update(
+ StartWorkflowUpdateInput(
+ id=self._id,
+ run_id=self._run_id,
+ first_execution_run_id=self.first_execution_run_id,
+ update_id=id,
+ update=update_name,
+ args=temporalio.common._arg_or_args(arg, args),
+ headers={},
+ ret_type=result_type or result_type_from_type_hint,
+ rpc_metadata=rpc_metadata,
+ rpc_timeout=rpc_timeout,
+ wait_for_stage=wait_for_stage,
+ )
+ )
+
+ def get_update_handle(
+ self,
+ id: str,
+ *,
+ workflow_run_id: str | None = None,
+ result_type: type | None = None,
+ ) -> WorkflowUpdateHandle[Any]:
+ """Get a handle for an update. The handle can be used to wait on the
+ update result.
+
+ Users may prefer the more typesafe :py:meth:`get_update_handle_for`
+ which accepts an update definition.
+
+ Args:
+ id: Update ID to get a handle to.
+ workflow_run_id: Run ID to tie the handle to. If this is not set, + the :py:attr:`run_id` will be used. + result_type: The result type to deserialize into if known. + + Returns: + The update handle. + """ + return WorkflowUpdateHandle( + self._client, + id, + self._id, + workflow_run_id=workflow_run_id or self._run_id, + result_type=result_type, + ) + + def get_update_handle_for( + self, + update: temporalio.workflow.UpdateMethodMultiParam[Any, LocalReturnType], + id: str, + *, + workflow_run_id: str | None = None, + ) -> WorkflowUpdateHandle[LocalReturnType]: + """Get a typed handle for an update. The handle can be used to wait on + the update result. + + This is the same as :py:meth:`get_update_handle` but typed. + + Args: + update: The update method to use for typing the handle. + id: Update ID to get a handle to. + workflow_run_id: Run ID to tie the handle to. If this is not set, + the :py:attr:`run_id` will be used. + + Returns: + The update handle. + """ + return self.get_update_handle( + id, workflow_run_id=workflow_run_id, result_type=update._defn.ret_type + ) + + +class WithStartWorkflowOperation(Generic[SelfType, ReturnType]): + """Defines a start-workflow operation used by update-with-start requests. + + Update-With-Start allows you to send an update to a workflow, while starting the + workflow if necessary. + """ + + # Overload for no-param workflow, with_start + @overload + def __init__( + self, + workflow: MethodAsyncNoParam[SelfType, ReturnType], + *, + id: str, + task_queue: str, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> None: ... 
+ + # Overload for single-param workflow, with_start + @overload + def __init__( + self, + workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> None: ... + + # Overload for multi-param workflow, with_start + @overload + def __init__( + self, + workflow: Callable[ + Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType] + ], + *, + args: Sequence[Any], + id: str, + task_queue: str, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> None: ... + + # Overload for string-name workflow, with_start + @overload + def __init__( + self, + workflow: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, + result_type: type | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> None: ... 
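Before the concrete `__init__` below, a sketch of how a `WithStartWorkflowOperation` feeds an update-with-start call. The workflow name, IDs, update name, and the `Client.execute_update_with_start_workflow` method are assumptions for illustration, not definitions from this diff:

```python
# Usage sketch only: names and IDs are assumed examples.
from temporalio.client import Client, WithStartWorkflowOperation
from temporalio.common import WorkflowIDConflictPolicy


async def update_with_start(client: Client) -> None:
    start_op = WithStartWorkflowOperation(
        "GreetingWorkflow",  # string-name form; a decorated run method also works
        id="my-workflow-id",
        task_queue="my-task-queue",
        id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
    )
    # Sends the "approve" update, starting the workflow first if needed
    result = await client.execute_update_with_start_workflow(
        "approve", "request-1", start_workflow_operation=start_op
    )
    # Resolves once the workflow is known to be running
    wf_handle = await start_op.workflow_handle()
```

Setting `id_conflict_policy` to `USE_EXISTING` is what lets the update attach to an already-running workflow instead of failing the start.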
+ + def __init__( + self, + workflow: str | Callable[..., Awaitable[Any]], + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, + result_type: type | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + stack_level: int = 2, + ) -> None: + """Create a WithStartWorkflowOperation. + + See :py:meth:`temporalio.client.Client.start_workflow` for documentation of the + arguments. + """ + temporalio.common._warn_on_deprecated_search_attributes( + search_attributes, stack_level=stack_level + ) + name, result_type_from_run_fn = ( + temporalio.workflow._Definition.get_name_and_result_type(workflow) + ) + if id_conflict_policy == temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED: + raise ValueError("WorkflowIDConflictPolicy is required") + + self._start_workflow_input = UpdateWithStartStartWorkflowInput( + workflow=name, + args=temporalio.common._arg_or_args(arg, args), + id=id, + task_queue=task_queue, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + headers={}, + ret_type=result_type or result_type_from_run_fn, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + priority=priority, + versioning_override=versioning_override, + ) + self._workflow_handle: Future[WorkflowHandle[SelfType, ReturnType]] = Future() + self._used = False + + async def workflow_handle(self) -> WorkflowHandle[SelfType, ReturnType]: + """Wait until workflow is running and return a WorkflowHandle.""" + return await self._workflow_handle + + +class ActivityExecutionAsyncIterator: + """Asynchronous iterator for activity execution values. + + You should typically use ``async for`` on this iterator and not call any of its methods. + + .. warning:: + This API is experimental. + """ + + def __init__( + self, + client: Client, + input: ListActivitiesInput, + ) -> None: + """Create an asynchronous iterator for the given input. + + Users should not create this directly, but rather use + :py:meth:`Client.list_activities`. 
+ """ + self._client = client + self._input = input + self._next_page_token = input.next_page_token + self._current_page: Sequence[ActivityExecution] | None = None + self._current_page_index = 0 + self._limit = input.limit + self._yielded = 0 + + @property + def current_page_index(self) -> int: + """Index of the entry in the current page that will be returned from + the next :py:meth:`__anext__` call. + """ + return self._current_page_index + + @property + def current_page(self) -> Sequence[ActivityExecution] | None: + """Current page, if it has been fetched yet.""" + return self._current_page + + @property + def next_page_token(self) -> bytes | None: + """Token for the next page request if any.""" + return self._next_page_token + + async def fetch_next_page(self, *, page_size: int | None = None) -> None: + """Fetch the next page of results. + + Args: + page_size: Override the page size this iterator was originally + created with. + """ + page_size = page_size or self._input.page_size + if self._limit is not None and self._limit - self._yielded < page_size: + page_size = self._limit - self._yielded + + resp = await self._client.workflow_service.list_activity_executions( + temporalio.api.workflowservice.v1.ListActivityExecutionsRequest( + namespace=self._client.namespace, + page_size=page_size, + next_page_token=self._next_page_token or b"", + query=self._input.query or "", + ), + retry=True, + metadata=self._input.rpc_metadata, + timeout=self._input.rpc_timeout, + ) + + self._current_page = [ + ActivityExecution._from_raw_info(v, self._client.namespace) + for v in resp.executions + ] + self._current_page_index = 0 + self._next_page_token = resp.next_page_token or None + + def __aiter__(self) -> ActivityExecutionAsyncIterator: + """Return self as the iterator.""" + return self + + async def __anext__(self) -> ActivityExecution: + """Get the next execution on this iterator, fetching next page if + necessary. + """ + if self._limit is not None and self._yielded >= self._limit: + raise StopAsyncIteration + while True: + # No page? fetch and continue + if self._current_page is None: + await self.fetch_next_page() + continue + # No more left in page? + if self._current_page_index >= len(self._current_page): + # If there is a next page token, try to get another page and try + # again + if self._next_page_token is not None: + await self.fetch_next_page() + continue + # No more pages means we're done + raise StopAsyncIteration + # Get current, increment page index, and return + ret = self._current_page[self._current_page_index] + self._current_page_index += 1 + self._yielded += 1 + return ret + + +@dataclass(frozen=True) +class ActivityExecution: + """Info for an activity execution not started by a workflow, from list response. + + .. warning:: + This API is experimental. 
+ """ + + activity_id: str + """Activity ID.""" + + activity_run_id: str | None + """Run ID of the activity.""" + + activity_type: str + """Type name of the activity.""" + + close_time: datetime | None + """Time the activity reached a terminal status, if closed.""" + + execution_duration: timedelta | None + """Duration from scheduled to close time, only populated if closed.""" + + namespace: str + """Namespace of the activity (copied from calling client).""" + + raw_info: ( + temporalio.api.activity.v1.ActivityExecutionListInfo + | temporalio.api.activity.v1.ActivityExecutionInfo + ) + """Underlying protobuf info.""" + + scheduled_time: datetime + """Time the activity was originally scheduled.""" + + state_transition_count: int | None + """Number of state transitions, if available.""" + + status: ActivityExecutionStatus + """Current status of the activity.""" + + task_queue: str + """Task queue the activity was scheduled on.""" + + typed_search_attributes: temporalio.common.TypedSearchAttributes + """Current set of search attributes if any.""" + + @classmethod + def _from_raw_info( + cls, info: temporalio.api.activity.v1.ActivityExecutionListInfo, namespace: str + ) -> Self: + """Create from raw proto activity list info.""" + return cls( + activity_id=info.activity_id, + activity_run_id=info.run_id or None, + activity_type=( + info.activity_type.name if info.HasField("activity_type") else "" + ), + close_time=( + info.close_time.ToDatetime().replace(tzinfo=timezone.utc) + if info.HasField("close_time") + else None + ), + execution_duration=( + info.execution_duration.ToTimedelta() + if info.HasField("execution_duration") + else None + ), + namespace=namespace, + raw_info=info, + scheduled_time=( + info.schedule_time.ToDatetime().replace(tzinfo=timezone.utc) + if info.HasField("schedule_time") + else datetime.min + ), + state_transition_count=( + info.state_transition_count if info.state_transition_count else None + ), + status=( + ActivityExecutionStatus(info.status) + if info.status + else ActivityExecutionStatus.UNSPECIFIED + ), + task_queue=info.task_queue, + typed_search_attributes=temporalio.converter.decode_typed_search_attributes( + info.search_attributes + ), + ) + + +@dataclass(frozen=True) +class ActivityExecutionDescription(ActivityExecution): + """Detailed information about an activity execution not started by a workflow. + + .. warning:: + This API is experimental. 
+ """ + + attempt: int + """Current attempt number.""" + + canceled_reason: str | None + """Reason for cancellation, if cancel was requested.""" + + current_retry_interval: timedelta | None + """Time until the next retry, if applicable.""" + + eager_execution_requested: bool + """Whether eager execution was requested for this activity.""" + + expiration_time: datetime + """Scheduled time plus schedule_to_close_timeout.""" + + last_attempt_complete_time: datetime | None + """Time when the last attempt completed.""" + + last_failure: Exception | None + """Failure from the last failed attempt, if any.""" + + last_heartbeat_time: datetime | None + """Time of the last heartbeat.""" + + last_started_time: datetime | None + """Time the last attempt was started.""" + + last_worker_identity: str + """Identity of the last worker that processed the activity.""" + + next_attempt_schedule_time: datetime | None + """Time when the next attempt will be scheduled.""" + + paused: bool + """Whether the activity is paused.""" + + raw_heartbeat_details: Sequence[temporalio.api.common.v1.Payload] + """Details from the last heartbeat.""" + + retry_policy: temporalio.common.RetryPolicy | None + """Retry policy for the activity.""" + + run_state: PendingActivityState | None + """More detailed breakdown if status is RUNNING.""" + + long_poll_token: bytes | None + """Token for follow-on long-poll requests. None if the activity is complete.""" + + @classmethod + async def _from_execution_info( + cls, + info: temporalio.api.activity.v1.ActivityExecutionInfo, + long_poll_token: bytes | None, + namespace: str, + data_converter: temporalio.converter.DataConverter, + ) -> Self: + """Create from raw proto activity execution info.""" + # Decode heartbeat details if present + decoded_heartbeat_details: Sequence[temporalio.api.common.v1.Payload] = ( + info.heartbeat_details.payloads + ) + if decoded_heartbeat_details and data_converter.payload_codec: + decoded_heartbeat_details = await data_converter.payload_codec.decode( + decoded_heartbeat_details + ) + + return cls( + activity_id=info.activity_id, + activity_run_id=info.run_id or None, + activity_type=( + info.activity_type.name if info.HasField("activity_type") else "" + ), + attempt=info.attempt, + canceled_reason=info.canceled_reason or None, + close_time=( + info.close_time.ToDatetime(tzinfo=timezone.utc) + if info.HasField("close_time") + else None + ), + current_retry_interval=( + info.current_retry_interval.ToTimedelta() + if info.HasField("current_retry_interval") + else None + ), + eager_execution_requested=getattr(info, "eager_execution_requested", False), + execution_duration=( + info.execution_duration.ToTimedelta() + if info.HasField("execution_duration") + else None + ), + expiration_time=( + info.expiration_time.ToDatetime(tzinfo=timezone.utc) + if info.HasField("expiration_time") + else datetime.min + ), + last_attempt_complete_time=( + info.last_attempt_complete_time.ToDatetime(tzinfo=timezone.utc) + if info.HasField("last_attempt_complete_time") + else None + ), + last_failure=( + cast( + Exception | None, + await data_converter.decode_failure(info.last_failure), + ) + if info.HasField("last_failure") + else None + ), + last_heartbeat_time=( + info.last_heartbeat_time.ToDatetime(tzinfo=timezone.utc) + if info.HasField("last_heartbeat_time") + else None + ), + last_started_time=( + info.last_started_time.ToDatetime(tzinfo=timezone.utc) + if info.HasField("last_started_time") + else None + ), + last_worker_identity=info.last_worker_identity, + 
long_poll_token=long_poll_token or None, + namespace=namespace, + next_attempt_schedule_time=( + info.next_attempt_schedule_time.ToDatetime(tzinfo=timezone.utc) + if info.HasField("next_attempt_schedule_time") + else None + ), + paused=getattr(info, "paused", False), + raw_heartbeat_details=decoded_heartbeat_details, + raw_info=info, + retry_policy=temporalio.common.RetryPolicy.from_proto(info.retry_policy) + if info.HasField("retry_policy") + else None, + run_state=( + PendingActivityState(info.run_state) if info.run_state else None + ), + scheduled_time=(info.schedule_time.ToDatetime(tzinfo=timezone.utc)), + state_transition_count=( + info.state_transition_count if info.state_transition_count else None + ), + status=( + ActivityExecutionStatus(info.status) + if info.status + else ActivityExecutionStatus.UNSPECIFIED + ), + task_queue=info.task_queue, + typed_search_attributes=temporalio.converter.decode_typed_search_attributes( + info.search_attributes + ), + ) + + +class ActivityExecutionStatus(IntEnum): + """Status of an activity execution. + + .. warning:: + This API is experimental. + + See :py:class:`temporalio.api.enums.v1.ActivityExecutionStatus`. + """ + + UNSPECIFIED = int( + temporalio.api.enums.v1.ActivityExecutionStatus.ACTIVITY_EXECUTION_STATUS_UNSPECIFIED + ) + RUNNING = int( + temporalio.api.enums.v1.ActivityExecutionStatus.ACTIVITY_EXECUTION_STATUS_RUNNING + ) + COMPLETED = int( + temporalio.api.enums.v1.ActivityExecutionStatus.ACTIVITY_EXECUTION_STATUS_COMPLETED + ) + FAILED = int( + temporalio.api.enums.v1.ActivityExecutionStatus.ACTIVITY_EXECUTION_STATUS_FAILED + ) + CANCELED = int( + temporalio.api.enums.v1.ActivityExecutionStatus.ACTIVITY_EXECUTION_STATUS_CANCELED + ) + TERMINATED = int( + temporalio.api.enums.v1.ActivityExecutionStatus.ACTIVITY_EXECUTION_STATUS_TERMINATED + ) + TIMED_OUT = int( + temporalio.api.enums.v1.ActivityExecutionStatus.ACTIVITY_EXECUTION_STATUS_TIMED_OUT + ) + + +class PendingActivityState(IntEnum): + """Detailed state of an activity execution that is in ACTIVITY_EXECUTION_STATUS_RUNNING. + + .. warning:: + This API is experimental. + + See :py:class:`temporalio.api.enums.v1.PendingActivityState`. + """ + + UNSPECIFIED = int( + temporalio.api.enums.v1.PendingActivityState.PENDING_ACTIVITY_STATE_UNSPECIFIED + ) + SCHEDULED = int( + temporalio.api.enums.v1.PendingActivityState.PENDING_ACTIVITY_STATE_SCHEDULED + ) + STARTED = int( + temporalio.api.enums.v1.PendingActivityState.PENDING_ACTIVITY_STATE_STARTED + ) + CANCEL_REQUESTED = int( + temporalio.api.enums.v1.PendingActivityState.PENDING_ACTIVITY_STATE_CANCEL_REQUESTED + ) + PAUSED = int( + temporalio.api.enums.v1.PendingActivityState.PENDING_ACTIVITY_STATE_PAUSED + ) + PAUSE_REQUESTED = int( + temporalio.api.enums.v1.PendingActivityState.PENDING_ACTIVITY_STATE_PAUSE_REQUESTED + ) + + +@dataclass(frozen=True) +class ActivityExecutionCount: + """Representation of a count from a count activities call. + + .. warning:: + This API is experimental. + """ + + count: int + """Total count matching the filter, if any.""" - See :py:meth:`temporalio.client.Client.start_workflow` for documentation of the - arguments. 
- """ - temporalio.common._warn_on_deprecated_search_attributes( - search_attributes, stack_level=stack_level - ) - name, result_type_from_run_fn = ( - temporalio.workflow._Definition.get_name_and_result_type(workflow) - ) - if id_conflict_policy == temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED: - raise ValueError("WorkflowIDConflictPolicy is required") + groups: Sequence[ActivityExecutionCountAggregationGroup] + """Aggregation groups if requested.""" - self._start_workflow_input = UpdateWithStartStartWorkflowInput( - workflow=name, - args=temporalio.common._arg_or_args(arg, args), - id=id, - task_queue=task_queue, - execution_timeout=execution_timeout, - run_timeout=run_timeout, - task_timeout=task_timeout, - id_reuse_policy=id_reuse_policy, - id_conflict_policy=id_conflict_policy, - retry_policy=retry_policy, - cron_schedule=cron_schedule, - memo=memo, - search_attributes=search_attributes, - static_summary=static_summary, - static_details=static_details, - start_delay=start_delay, - headers={}, - ret_type=result_type or result_type_from_run_fn, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - priority=priority, - versioning_override=versioning_override, + @staticmethod + def _from_raw( + resp: temporalio.api.workflowservice.v1.CountActivityExecutionsResponse, + ) -> ActivityExecutionCount: + """Create from raw proto response.""" + return ActivityExecutionCount( + count=resp.count, + groups=[ + ActivityExecutionCountAggregationGroup._from_raw(g) for g in resp.groups + ], ) - self._workflow_handle: Future[WorkflowHandle[SelfType, ReturnType]] = Future() - self._used = False - async def workflow_handle(self) -> WorkflowHandle[SelfType, ReturnType]: - """Wait until workflow is running and return a WorkflowHandle.""" - return await self._workflow_handle + +@dataclass(frozen=True) +class ActivityExecutionCountAggregationGroup: + """A single aggregation group from a count activities call. + + .. warning:: + This API is experimental. + """ + + count: int + """Count for this group.""" + + group_values: Sequence[temporalio.common.SearchAttributeValue] + """Values that define this group.""" + + @staticmethod + def _from_raw( + raw: temporalio.api.workflowservice.v1.CountActivityExecutionsResponse.AggregationGroup, + ) -> ActivityExecutionCountAggregationGroup: + return ActivityExecutionCountAggregationGroup( + count=raw.count, + group_values=[ + temporalio.converter._decode_search_attribute_value(v) + for v in raw.group_values + ], + ) @dataclass(frozen=True) class AsyncActivityIDReference: """Reference to an async activity by its qualified ID.""" - workflow_id: str + workflow_id: str | None run_id: str | None activity_id: str @@ -2787,55 +4538,291 @@ async def fail( last_heartbeat_details=last_heartbeat_details, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, - data_converter_override=self._data_converter_override, - ), + data_converter_override=self._data_converter_override, + ), + ) + + async def report_cancellation( + self, + *details: Any, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> None: + """Report the activity as cancelled. + + Args: + details: Cancellation details. + rpc_metadata: Headers used on the RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for the RPC call. 
+ """ + await self._client._impl.report_cancellation_async_activity( + ReportCancellationAsyncActivityInput( + id_or_token=self._id_or_token, + details=details, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + data_converter_override=self._data_converter_override, + ), + ) + + def with_context(self, context: SerializationContext) -> Self: + """Create a new AsyncActivityHandle with a different serialization context. + + Payloads received by the activity will be decoded and deserialized using a data converter + with :py:class:`ActivitySerializationContext` set as context. If you are using a custom data + converter that makes use of this context then you can use this method to supply matching + context data to the data converter used to serialize and encode the outbound payloads. + """ + data_converter = self._client.data_converter.with_context(context) + if data_converter is self._client.data_converter: + return self + cls = type(self) + if cls.__init__ is not AsyncActivityHandle.__init__: + raise TypeError( + "If you have subclassed AsyncActivityHandle and overridden the __init__ method " + "then you must override with_context to return an instance of your class." + ) + return cls( + self._client, + self._id_or_token, + data_converter, + ) + + +class ActivityHandle(Generic[ReturnType]): + """Handle representing an activity execution not started by a workflow. + + .. warning:: + This API is experimental. + """ + + def __init__( + self, + client: Client, + id: str, + *, + run_id: str | None = None, + result_type: type | None = None, + ) -> None: + """Create activity handle.""" + self._client = client + self._id = id + self._run_id = run_id + self._result_type = result_type + self._known_outcome: ( + temporalio.api.activity.v1.ActivityExecutionOutcome | None + ) = None + + @functools.cached_property + def _data_converter(self) -> temporalio.converter.DataConverter: + return self._client.data_converter.with_context( + ActivitySerializationContext( + namespace=self._client.namespace, + activity_id=self._id, + activity_type=None, + activity_task_queue=None, + is_local=False, + workflow_id=None, + workflow_type=None, + ) + ) + + @property + def id(self) -> str: + """ID of the activity.""" + return self._id + + @property + def run_id(self) -> str | None: + """Run ID of the activity.""" + return self._run_id + + async def result( + self, + *, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: + """Wait for result of the activity. + + .. warning:: + This API is experimental. + + The result may already be known if this method has been called before, + in which case no network call is made. Otherwise the result will be + polled for until it is available. + + Args: + rpc_metadata: Headers used on the RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for each RPC call. Note: + this is the timeout for each RPC call while polling, not a + timeout for the function as a whole. If an individual RPC + times out, it will be retried until the result is available. + + Returns: + The result of the activity. + + Raises: + ActivityFailureError: If the activity completed with a failure. + RPCError: Activity result could not be fetched for some reason. 
+ """ + await self._poll_until_outcome( + rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout + ) + + # Convert outcome to failure or value + assert self._known_outcome + if self._known_outcome.HasField("failure"): + raise ActivityFailureError( + cause=await self._data_converter.decode_failure( + self._known_outcome.failure + ), + ) + if not self._known_outcome.result.payloads: + return None # type: ignore + type_hints = [self._result_type] if self._result_type else None + results = await self._data_converter.decode( + self._known_outcome.result.payloads, type_hints + ) + if not results: + return None # type: ignore + elif len(results) > 1: + warnings.warn(f"Expected single activity result, got {len(results)}") + return results[0] + + async def _poll_until_outcome( + self, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> None: + """Poll for activity result until it's available.""" + if self._known_outcome: + return + + req = temporalio.api.workflowservice.v1.PollActivityExecutionRequest( + namespace=self._client.namespace, + activity_id=self._id, + run_id=self._run_id or "", + ) + + # Continue polling as long as we have no outcome + while True: + try: + res = await self._client.workflow_service.poll_activity_execution( + req, + retry=True, + metadata=rpc_metadata, + timeout=rpc_timeout, + ) + if res.HasField("outcome"): + self._known_outcome = res.outcome + return + except RPCError as err: + if err.status == RPCStatusCode.DEADLINE_EXCEEDED: + # Deadline exceeded is expected with long polling; retry + continue + elif err.status == RPCStatusCode.CANCELLED: + raise asyncio.CancelledError() from err + else: + raise + except asyncio.CancelledError: + raise + + async def cancel( + self, + *, + reason: str | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> None: + """Request cancellation of the activity. + + .. warning:: + This API is experimental. + + Requesting cancellation of an activity does not automatically transition the activity to + canceled status. If the activity is heartbeating, a :py:class:`exceptions.CancelledError` + exception will be raised when receiving the heartbeat response; if the activity allows this + exception to bubble out, the activity will transition to canceled status. If the activity it + is not heartbeating, this method will have no effect on activity status. + + Args: + reason: Reason for the cancellation. Recorded and available via describe. + rpc_metadata: Headers used on the RPC call. + rpc_timeout: Optional RPC deadline to set for the RPC call. + """ + await self._client._impl.cancel_activity( + CancelActivityInput( + activity_id=self._id, + activity_run_id=self._run_id, + reason=reason, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) ) - async def report_cancellation( + async def terminate( self, - *details: Any, + *, + reason: str | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, ) -> None: - """Report the activity as cancelled. + """Terminate the activity execution immediately. + + .. warning:: + This API is experimental. + + Termination does not reach the worker and the activity code cannot react to it. + A terminated activity may have a running attempt and will be requested to be + canceled by the server when it heartbeats. Args: - details: Cancellation details. - rpc_metadata: Headers used on the RPC call. Keys here override - client-level RPC metadata keys. 
+ reason: Reason for the termination.
+ rpc_metadata: Headers used on the RPC call.
 rpc_timeout: Optional RPC deadline to set for the RPC call.
 """
- await self._client._impl.report_cancellation_async_activity(
- ReportCancellationAsyncActivityInput(
- id_or_token=self._id_or_token,
- details=details,
+ await self._client._impl.terminate_activity(
+ TerminateActivityInput(
+ activity_id=self._id,
+ activity_run_id=self._run_id,
+ reason=reason,
 rpc_metadata=rpc_metadata,
 rpc_timeout=rpc_timeout,
- data_converter_override=self._data_converter_override,
- ),
+ )
 )

- def with_context(self, context: SerializationContext) -> Self:
- """Create a new AsyncActivityHandle with a different serialization context.
+ async def describe(
+ self,
+ *,
+ long_poll_token: bytes | None = None,
+ rpc_metadata: Mapping[str, str | bytes] = {},
+ rpc_timeout: timedelta | None = None,
+ ) -> ActivityExecutionDescription:
+ """Describe the activity execution.

- Payloads received by the activity will be decoded and deserialized using a data converter
- with :py:class:`ActivitySerializationContext` set as context. If you are using a custom data
- converter that makes use of this context then you can use this method to supply matching
- context data to the data converter used to serialize and encode the outbound payloads.
+ .. warning::
+ This API is experimental.
+
+ Args:
+ long_poll_token: Token from a previous describe response. If provided,
+ the request will long-poll until the activity state changes.
+ rpc_metadata: Headers used on the RPC call.
+ rpc_timeout: Optional RPC deadline to set for the RPC call.
+
+ Returns:
+ Activity execution description.
 """
- data_converter = self._client.data_converter.with_context(context)
- if data_converter is self._client.data_converter:
- return self
- cls = type(self)
- if cls.__init__ is not AsyncActivityHandle.__init__:
- raise TypeError(
- "If you have subclassed AsyncActivityHandle and overridden the __init__ method "
- "then you must override with_context to return an instance of your class."
+ return await self._client._impl.describe_activity(
+ DescribeActivityInput(
+ activity_id=self._id,
+ activity_run_id=self._run_id,
+ long_poll_token=long_poll_token,
+ rpc_metadata=rpc_metadata,
+ rpc_timeout=rpc_timeout,
 )
- return cls(
- self._client,
- self._id_or_token,
- data_converter,
 )

@@ -3011,7 +4998,7 @@ async def memo_value(
 key: Key to get memo value for.
 default: Default to use if key is not present. If unset, a
 :py:class:`KeyError` is raised when the key does not exist.
 type_hint: Type hint to use when converting.

 Returns:
 Memo value, converted with the type hint if present.
@@ -4555,7 +6542,7 @@ async def memo_value(
 key: Key to get memo value for.
 default: Default to use if key is not present. If unset, a
 :py:class:`KeyError` is raised when the key does not exist.
 type_hint: Type hint to use when converting.

 Returns:
 Memo value, converted with the type hint if present.
@@ -4804,7 +6791,7 @@ async def memo_value(
 key: Key to get memo value for.
 default: Default to use if key is not present. If unset, a
 :py:class:`KeyError` is raised when the key does not exist.
 type_hint: Type hint to use when converting.

 Returns:
 Memo value, converted with the type hint if present.
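Putting the new `ActivityHandle` pieces together, a sketch of the client-started activity lifecycle. The activity definition and all IDs are assumed examples, and the exact `Client.start_activity` / `Client.list_activities` signatures are assumptions based on the overloads and inputs in this diff:

```python
# Usage sketch only: my_activity and all IDs are assumed examples.
from datetime import timedelta

from temporalio import activity
from temporalio.client import Client


@activity.defn
async def my_activity(name: str) -> str:
    return f"Hello, {name}!"


async def run_standalone_activity(client: Client) -> None:
    handle = await client.start_activity(
        my_activity,
        "Temporal",
        id="my-activity-id",
        task_queue="my-task-queue",
        start_to_close_timeout=timedelta(minutes=5),
    )
    # Describe once, then long-poll until the state changes
    desc = await handle.describe()
    desc = await handle.describe(long_poll_token=desc.long_poll_token)
    # Await the outcome; raises ActivityFailureError on failure
    result = await handle.result()
    # Visibility listing (query syntax assumed to mirror workflow listing)
    async for execution in client.list_activities('ActivityType="my_activity"'):
        print(execution.activity_id, execution.status)
```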
@@ -5263,6 +7250,25 @@ def __init__(self) -> None: super().__init__("Timeout or cancellation waiting for update") +class ActivityFailureError(temporalio.exceptions.TemporalError): + """Error that occurs when an activity is unsuccessful. + + .. warning:: + This API is experimental. + """ + + def __init__(self, *, cause: BaseException) -> None: + """Create activity failure error.""" + super().__init__("Activity execution failed") + self.__cause__ = cause + + @property + def cause(self) -> BaseException: + """Cause of the activity failure.""" + assert self.__cause__ + return self.__cause__ + + class AsyncActivityCancelledError(temporalio.exceptions.TemporalError): """Error that occurs when async activity attempted heartbeat but was cancelled.""" @@ -5417,6 +7423,108 @@ class TerminateWorkflowInput: rpc_timeout: timedelta | None +@dataclass +class StartActivityInput: + """Input for :py:meth:`OutboundInterceptor.start_activity`. + + .. warning:: + This API is experimental. + """ + + activity_type: str + args: Sequence[Any] + id: str + task_queue: str + result_type: type | None + schedule_to_close_timeout: timedelta | None + start_to_close_timeout: timedelta | None + schedule_to_start_timeout: timedelta | None + heartbeat_timeout: timedelta | None + id_reuse_policy: temporalio.common.ActivityIDReusePolicy + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy + retry_policy: temporalio.common.RetryPolicy | None + priority: temporalio.common.Priority + search_attributes: temporalio.common.TypedSearchAttributes | None + summary: str | None + headers: Mapping[str, temporalio.api.common.v1.Payload] + rpc_metadata: Mapping[str, str | bytes] + rpc_timeout: timedelta | None + + +@dataclass +class CancelActivityInput: + """Input for :py:meth:`OutboundInterceptor.cancel_activity`. + + .. warning:: + This API is experimental. + """ + + activity_id: str + activity_run_id: str | None + reason: str | None + rpc_metadata: Mapping[str, str | bytes] + rpc_timeout: timedelta | None + + +@dataclass +class TerminateActivityInput: + """Input for :py:meth:`OutboundInterceptor.terminate_activity`. + + .. warning:: + This API is experimental. + """ + + activity_id: str + activity_run_id: str | None + reason: str | None + rpc_metadata: Mapping[str, str | bytes] + rpc_timeout: timedelta | None + + +@dataclass +class DescribeActivityInput: + """Input for :py:meth:`OutboundInterceptor.describe_activity`. + + .. warning:: + This API is experimental. + """ + + activity_id: str + activity_run_id: str | None + long_poll_token: bytes | None + rpc_metadata: Mapping[str, str | bytes] + rpc_timeout: timedelta | None + + +@dataclass +class ListActivitiesInput: + """Input for :py:meth:`OutboundInterceptor.list_activities`. + + .. warning:: + This API is experimental. + """ + + query: str | None + page_size: int + next_page_token: bytes | None + rpc_metadata: Mapping[str, str | bytes] + rpc_timeout: timedelta | None + limit: int | None + + +@dataclass +class CountActivitiesInput: + """Input for :py:meth:`OutboundInterceptor.count_activities`. + + .. warning:: + This API is experimental. 
+ """ + + query: str | None + rpc_metadata: Mapping[str, str | bytes] + rpc_timeout: timedelta | None + + @dataclass class StartWorkflowUpdateInput: """Input for :py:meth:`OutboundInterceptor.start_workflow_update`.""" @@ -5751,6 +7859,62 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: """Called for every :py:meth:`WorkflowHandle.terminate` call.""" await self.next.terminate_workflow(input) + ### Activity calls + + async def start_activity(self, input: StartActivityInput) -> ActivityHandle[Any]: + """Called for every :py:meth:`Client.start_activity` call. + + .. warning:: + This API is experimental. + """ + return await self.next.start_activity(input) + + async def cancel_activity(self, input: CancelActivityInput) -> None: + """Called for every :py:meth:`ActivityHandle.cancel` call. + + .. warning:: + This API is experimental. + """ + await self.next.cancel_activity(input) + + async def terminate_activity(self, input: TerminateActivityInput) -> None: + """Called for every :py:meth:`ActivityHandle.terminate` call. + + .. warning:: + This API is experimental. + """ + await self.next.terminate_activity(input) + + async def describe_activity( + self, input: DescribeActivityInput + ) -> ActivityExecutionDescription: + """Called for every :py:meth:`ActivityHandle.describe` call. + + .. warning:: + This API is experimental. + """ + return await self.next.describe_activity(input) + + def list_activities( + self, input: ListActivitiesInput + ) -> ActivityExecutionAsyncIterator: + """Called for every :py:meth:`Client.list_activities` call. + + .. warning:: + This API is experimental. + """ + return self.next.list_activities(input) + + async def count_activities( + self, input: CountActivitiesInput + ) -> ActivityExecutionCount: + """Called for every :py:meth:`Client.count_activities` call. + + .. warning:: + This API is experimental. 
+ """ + return await self.next.count_activities(input) + async def start_workflow_update( self, input: StartWorkflowUpdateInput ) -> WorkflowUpdateHandle[Any]: @@ -6202,6 +8366,186 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout ) + async def start_activity(self, input: StartActivityInput) -> ActivityHandle[Any]: + """Start an activity and return a handle to it.""" + if not (input.start_to_close_timeout or input.schedule_to_close_timeout): + raise ValueError( + "Activity must have start_to_close_timeout or schedule_to_close_timeout" + ) + req = await self._build_start_activity_execution_request(input) + + resp: temporalio.api.workflowservice.v1.StartActivityExecutionResponse + try: + resp = await self._client.workflow_service.start_activity_execution( + req, + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + except RPCError as err: + # If the status is ALREADY_EXISTS and the details can be extracted + # as already started, use a different exception + if err.status == RPCStatusCode.ALREADY_EXISTS and err.grpc_status.details: + details = temporalio.api.errordetails.v1.ActivityExecutionAlreadyStartedFailure() + if err.grpc_status.details[0].Unpack(details): + raise temporalio.exceptions.ActivityAlreadyStartedError( + input.id, input.activity_type, run_id=details.run_id + ) + raise + return ActivityHandle( + self._client, + input.id, + run_id=resp.run_id, + result_type=input.result_type, + ) + + async def _build_start_activity_execution_request( + self, input: StartActivityInput + ) -> temporalio.api.workflowservice.v1.StartActivityExecutionRequest: + """Build StartActivityExecutionRequest from input.""" + data_converter = self._client.data_converter.with_context( + ActivitySerializationContext( + namespace=self._client.namespace, + activity_id=input.id, + activity_type=input.activity_type, + activity_task_queue=input.task_queue, + is_local=False, + workflow_id=None, + workflow_type=None, + ) + ) + + req = temporalio.api.workflowservice.v1.StartActivityExecutionRequest( + namespace=self._client.namespace, + identity=self._client.identity, + activity_id=input.id, + activity_type=temporalio.api.common.v1.ActivityType( + name=input.activity_type + ), + task_queue=temporalio.api.taskqueue.v1.TaskQueue(name=input.task_queue), + id_reuse_policy=cast( + "temporalio.api.enums.v1.ActivityIdReusePolicy.ValueType", + int(input.id_reuse_policy), + ), + id_conflict_policy=cast( + "temporalio.api.enums.v1.ActivityIdConflictPolicy.ValueType", + int(input.id_conflict_policy), + ), + ) + + if input.schedule_to_close_timeout is not None: + req.schedule_to_close_timeout.FromTimedelta(input.schedule_to_close_timeout) + if input.start_to_close_timeout is not None: + req.start_to_close_timeout.FromTimedelta(input.start_to_close_timeout) + if input.schedule_to_start_timeout is not None: + req.schedule_to_start_timeout.FromTimedelta(input.schedule_to_start_timeout) + if input.heartbeat_timeout is not None: + req.heartbeat_timeout.FromTimedelta(input.heartbeat_timeout) + if input.retry_policy is not None: + input.retry_policy.apply_to_proto(req.retry_policy) + + # Set input payloads + if input.args: + req.input.payloads.extend(await data_converter.encode(input.args)) + + # Set search attributes + if input.search_attributes is not None: + temporalio.converter.encode_search_attributes( + input.search_attributes, req.search_attributes + ) + + # Set user metadata + metadata = await 
_encode_user_metadata(data_converter, input.summary, None) + if metadata is not None: + req.user_metadata.CopyFrom(metadata) + + # Set headers + if input.headers: + await self._apply_headers(input.headers, req.header.fields) + + # Set priority + req.priority.CopyFrom(input.priority._to_proto()) + + return req + + async def cancel_activity(self, input: CancelActivityInput) -> None: + """Cancel an activity.""" + await self._client.workflow_service.request_cancel_activity_execution( + temporalio.api.workflowservice.v1.RequestCancelActivityExecutionRequest( + namespace=self._client.namespace, + activity_id=input.activity_id, + run_id=input.activity_run_id or "", + identity=self._client.identity, + request_id=str(uuid.uuid4()), + reason=input.reason or "", + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + + async def terminate_activity(self, input: TerminateActivityInput) -> None: + """Terminate an activity.""" + await self._client.workflow_service.terminate_activity_execution( + temporalio.api.workflowservice.v1.TerminateActivityExecutionRequest( + namespace=self._client.namespace, + activity_id=input.activity_id, + run_id=input.activity_run_id or "", + reason=input.reason or "", + identity=self._client.identity, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + + async def describe_activity( + self, input: DescribeActivityInput + ) -> ActivityExecutionDescription: + """Describe an activity.""" + resp = await self._client.workflow_service.describe_activity_execution( + temporalio.api.workflowservice.v1.DescribeActivityExecutionRequest( + namespace=self._client.namespace, + activity_id=input.activity_id, + run_id=input.activity_run_id or "", + long_poll_token=input.long_poll_token or b"", + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + return await ActivityExecutionDescription._from_execution_info( + info=resp.info, + long_poll_token=resp.long_poll_token or None, + namespace=self._client.namespace, + data_converter=self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.activity_id, # Using activity_id as workflow_id for activities not started by a workflow + ) + ), + ) + + def list_activities( + self, input: ListActivitiesInput + ) -> ActivityExecutionAsyncIterator: + return ActivityExecutionAsyncIterator(self._client, input) + + async def count_activities( + self, input: CountActivitiesInput + ) -> ActivityExecutionCount: + return ActivityExecutionCount._from_raw( + await self._client.workflow_service.count_activity_executions( + temporalio.api.workflowservice.v1.CountActivityExecutionsRequest( + namespace=self._client.namespace, + query=input.query or "", + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + ) + async def start_workflow_update( self, input: StartWorkflowUpdateInput ) -> WorkflowUpdateHandle[Any]: @@ -6445,7 +8789,7 @@ async def heartbeat_async_activity( if isinstance(input.id_or_token, AsyncActivityIDReference): resp_by_id = await self._client.workflow_service.record_activity_task_heartbeat_by_id( temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatByIdRequest( - workflow_id=input.id_or_token.workflow_id, + workflow_id=input.id_or_token.workflow_id or "", run_id=input.id_or_token.run_id or "", activity_id=input.id_or_token.activity_id, namespace=self._client.namespace, @@ -6500,7 +8844,7 @@ async def complete_async_activity(self, input: 
CompleteAsyncActivityInput) -> No if isinstance(input.id_or_token, AsyncActivityIDReference): await self._client.workflow_service.respond_activity_task_completed_by_id( temporalio.api.workflowservice.v1.RespondActivityTaskCompletedByIdRequest( - workflow_id=input.id_or_token.workflow_id, + workflow_id=input.id_or_token.workflow_id or "", run_id=input.id_or_token.run_id or "", activity_id=input.id_or_token.activity_id, namespace=self._client.namespace, @@ -6537,7 +8881,7 @@ async def fail_async_activity(self, input: FailAsyncActivityInput) -> None: if isinstance(input.id_or_token, AsyncActivityIDReference): await self._client.workflow_service.respond_activity_task_failed_by_id( temporalio.api.workflowservice.v1.RespondActivityTaskFailedByIdRequest( - workflow_id=input.id_or_token.workflow_id, + workflow_id=input.id_or_token.workflow_id or "", run_id=input.id_or_token.run_id or "", activity_id=input.id_or_token.activity_id, namespace=self._client.namespace, @@ -6575,7 +8919,7 @@ async def report_cancellation_async_activity( if isinstance(input.id_or_token, AsyncActivityIDReference): await self._client.workflow_service.respond_activity_task_canceled_by_id( temporalio.api.workflowservice.v1.RespondActivityTaskCanceledByIdRequest( - workflow_id=input.id_or_token.workflow_id, + workflow_id=input.id_or_token.workflow_id or "", run_id=input.id_or_token.run_id or "", activity_id=input.id_or_token.activity_id, namespace=self._client.namespace, diff --git a/temporalio/common.py b/temporalio/common.py index b6dd67a4e..1b3cf0afe 100644 --- a/temporalio/common.py +++ b/temporalio/common.py @@ -146,6 +146,49 @@ class WorkflowIDConflictPolicy(IntEnum): ) +class ActivityIDReusePolicy(IntEnum): + """How already-closed activity IDs are handled on start. + + .. warning:: + This API is experimental. + + See :py:class:`temporalio.api.enums.v1.ActivityIdReusePolicy`. + """ + + UNSPECIFIED = int( + temporalio.api.enums.v1.ActivityIdReusePolicy.ACTIVITY_ID_REUSE_POLICY_UNSPECIFIED + ) + ALLOW_DUPLICATE = int( + temporalio.api.enums.v1.ActivityIdReusePolicy.ACTIVITY_ID_REUSE_POLICY_ALLOW_DUPLICATE + ) + ALLOW_DUPLICATE_FAILED_ONLY = int( + temporalio.api.enums.v1.ActivityIdReusePolicy.ACTIVITY_ID_REUSE_POLICY_ALLOW_DUPLICATE_FAILED_ONLY + ) + REJECT_DUPLICATE = int( + temporalio.api.enums.v1.ActivityIdReusePolicy.ACTIVITY_ID_REUSE_POLICY_REJECT_DUPLICATE + ) + + +class ActivityIDConflictPolicy(IntEnum): + """How already-running activity IDs are handled on start. + + .. warning:: + This API is experimental. + + See :py:class:`temporalio.api.enums.v1.ActivityIdConflictPolicy`. + """ + + UNSPECIFIED = int( + temporalio.api.enums.v1.ActivityIdConflictPolicy.ACTIVITY_ID_CONFLICT_POLICY_UNSPECIFIED + ) + FAIL = int( + temporalio.api.enums.v1.ActivityIdConflictPolicy.ACTIVITY_ID_CONFLICT_POLICY_FAIL + ) + USE_EXISTING = int( + temporalio.api.enums.v1.ActivityIdConflictPolicy.ACTIVITY_ID_CONFLICT_POLICY_USE_EXISTING + ) + + class QueryRejectCondition(IntEnum): """Whether a query should be rejected in certain conditions. 
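The two activity ID policies above mirror their workflow counterparts. As a minimal usage sketch (assumptions: a server reachable at the default local address with the experimental activity APIs enabled, a stand-in `increment` activity, and illustrative ID/task-queue/timeout values):

```python
import asyncio
from datetime import timedelta

from temporalio import activity
from temporalio.client import Client
from temporalio.common import ActivityIDConflictPolicy, ActivityIDReusePolicy


@activity.defn
async def increment(x: int) -> int:
    return x + 1


async def main() -> None:
    client = await Client.connect("localhost:7233")
    # USE_EXISTING deduplicates starts: a second start with the same ID
    # returns a handle to the already-running activity instead of failing.
    # ALLOW_DUPLICATE_FAILED_ONLY permits re-running a closed ID only if
    # the previous run did not succeed.
    handle = await client.start_activity(
        increment,
        1,
        id="my-activity-id",
        task_queue="my-task-queue",
        schedule_to_close_timeout=timedelta(minutes=1),
        id_conflict_policy=ActivityIDConflictPolicy.USE_EXISTING,
        id_reuse_policy=ActivityIDReusePolicy.ALLOW_DUPLICATE_FAILED_ONLY,
    )
    print(await handle.result())


asyncio.run(main())
```

A worker with `increment` registered on `my-task-queue` must be polling for `result()` to resolve; the `test_id_conflict_policy_use_existing` test later in this diff exercises the same-run behavior.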
diff --git a/temporalio/contrib/openai_agents/_mcp.py b/temporalio/contrib/openai_agents/_mcp.py index c9d1f87ea..8d6a9464a 100644 --- a/temporalio/contrib/openai_agents/_mcp.py +++ b/temporalio/contrib/openai_agents/_mcp.py @@ -445,7 +445,7 @@ def name(self) -> str: def _get_activities(self) -> Sequence[Callable]: def _server_id(): - return self.name + "@" + activity.info().workflow_run_id + return self.name + "@" + (activity.info().workflow_run_id or "") @activity.defn(name=self.name + "-list-tools") async def list_tools() -> list[MCPTool]: @@ -491,7 +491,7 @@ async def connect( ) -> None: heartbeat_task = asyncio.create_task(heartbeat_every(30)) - server_id = self.name + "@" + activity.info().workflow_run_id + server_id = self.name + "@" + (activity.info().workflow_run_id or "") if server_id in self._servers: raise ApplicationError( "Cannot connect to an already running server. Use a distinct name if running multiple servers in one workflow." diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index ef1e52bb2..3e08ea68d 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -355,14 +355,15 @@ async def execute_activity( self, input: temporalio.worker.ExecuteActivityInput ) -> Any: info = temporalio.activity.info() + attributes: dict[str, str] = {"temporalActivityID": info.activity_id} + if info.workflow_id: + attributes["temporalWorkflowID"] = info.workflow_id + if info.workflow_run_id: + attributes["temporalRunID"] = info.workflow_run_id with self.root._start_as_current_span( f"RunActivity:{info.activity_type}", context=self.root._context_from_headers(input.headers), - attributes={ - "temporalWorkflowID": info.workflow_id, - "temporalRunID": info.workflow_run_id, - "temporalActivityID": info.activity_id, - }, + attributes=attributes, kind=opentelemetry.trace.SpanKind.SERVER, ): return await super().execute_activity(input) diff --git a/temporalio/converter.py b/temporalio/converter.py index 3849a47f4..c066dcc7b 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -82,15 +82,7 @@ class SerializationContext(ABC): @dataclass(frozen=True) -class BaseWorkflowSerializationContext(SerializationContext): - """Base serialization context shared by workflow and activity serialization contexts.""" - - namespace: str - workflow_id: str - - -@dataclass(frozen=True) -class WorkflowSerializationContext(BaseWorkflowSerializationContext): +class WorkflowSerializationContext(SerializationContext): """Serialization context for workflows. See :py:class:`SerializationContext` for more details. @@ -103,30 +95,61 @@ class WorkflowSerializationContext(BaseWorkflowSerializationContext): when the workflow is created by the schedule. """ - pass + namespace: str + """Namespace.""" + + workflow_id: str | None + """Workflow ID.""" @dataclass(frozen=True) -class ActivitySerializationContext(BaseWorkflowSerializationContext): +class ActivitySerializationContext(SerializationContext): """Serialization context for activities. See :py:class:`SerializationContext` for more details. Attributes: namespace: Workflow/activity namespace. - workflow_id: Workflow ID. Note, when creating/describing schedules, + activity_id: Activity ID. Optional if this is an activity started from a workflow. + activity_type: Activity type. + activity_task_queue: Activity task queue. + workflow_id: Workflow ID. Only set if this is an activity started from a workflow. 
            Note, when creating/describing schedules,
            this may be the workflow ID prefix as configured, not the final workflow ID
            when the workflow is created by the schedule.
-        workflow_type: Workflow Type.
-        activity_type: Activity Type.
-        activity_task_queue: Activity task queue.
-        is_local: Whether the activity is a local activity.
+        workflow_type: Workflow type. Only set if this is an activity started from a workflow.
+        is_local: Whether the activity is a local activity. False if the activity was not started by a workflow.
     """

-    workflow_type: str
-    activity_type: str
-    activity_task_queue: str
+    namespace: str
+    """Namespace."""
+
+    activity_id: str | None
+    """Activity ID. Optional if this is an activity started from a workflow."""
+
+    activity_type: str | None
+    """Activity type.
+
+    .. deprecated::
+        This value may not be set in some bidirectional situations; it should
+        not be relied on.
+    """
+
+    activity_task_queue: str | None
+    """Activity task queue.
+
+    .. deprecated::
+        This value may not be set in some bidirectional situations; it should
+        not be relied on.
+    """
+
+    workflow_id: str | None
+    """Workflow ID if this is an activity started from a workflow."""
+
+    workflow_type: str | None
+    """Workflow type if this is an activity started from a workflow."""
+
     is_local: bool
+    """Whether the activity is a local activity started from a workflow."""


 class WithSerializationContext(ABC):
diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py
index f8f8ca20c..8f0423153 100644
--- a/temporalio/exceptions.py
+++ b/temporalio/exceptions.py
@@ -70,6 +70,26 @@ def __init__(
         self.run_id = run_id


+class ActivityAlreadyStartedError(FailureError):
+    """Thrown by a client when an activity execution has already started.
+
+    Attributes:
+        activity_id: ID of the already-started activity.
+        activity_type: Activity type name of the already-started activity.
+        run_id: Run ID of the already-started activity if this was raised by the
+            client.
+    """
+
+    def __init__(
+        self, activity_id: str, activity_type: str, *, run_id: str | None = None
+    ) -> None:
+        """Initialize an activity already started error."""
+        super().__init__("Activity execution already started")
+        self.activity_id = activity_id
+        self.activity_type = activity_type
+        self.run_id = run_id
+
+
 class ApplicationErrorCategory(IntEnum):
     """Severity category for your application error.
Maps to corresponding client-side logging/metrics behaviors""" diff --git a/temporalio/testing/_activity.py b/temporalio/testing/_activity.py index 0098a91e1..ae7d2a38b 100644 --- a/temporalio/testing/_activity.py +++ b/temporalio/testing/_activity.py @@ -31,6 +31,7 @@ heartbeat_details=[], heartbeat_timeout=None, is_local=False, + namespace="default", schedule_to_close_timeout=timedelta(seconds=1), scheduled_time=_utc_zero, start_to_close_timeout=timedelta(seconds=1), @@ -43,6 +44,7 @@ workflow_type="test", priority=temporalio.common.Priority.default, retry_policy=None, + activity_run_id=None, ) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 93249fad5..6368999ec 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -250,10 +250,11 @@ async def _heartbeat_async( data_converter = self._data_converter if activity.info: context = temporalio.converter.ActivitySerializationContext( - namespace=activity.info.workflow_namespace, + namespace=activity.info.namespace, workflow_id=activity.info.workflow_id, workflow_type=activity.info.workflow_type, activity_type=activity.info.activity_type, + activity_id=activity.info.activity_id, activity_task_queue=self._task_queue, is_local=activity.info.is_local, ) @@ -302,10 +303,11 @@ async def _handle_start_activity_task( ) # Create serialization context for the activity context = temporalio.converter.ActivitySerializationContext( - namespace=start.workflow_namespace, + namespace=start.workflow_namespace or self._client.namespace, workflow_id=start.workflow_execution.workflow_id, workflow_type=start.workflow_type, activity_type=start.activity_type, + activity_id=start.activity_id, activity_task_queue=self._task_queue, is_local=start.is_local, ) @@ -545,6 +547,7 @@ async def _execute_activity( ) from err # Build info + started_by_workflow = bool(start.workflow_execution.workflow_id) info = temporalio.activity.Info( activity_id=start.activity_id, activity_type=start.activity_type, @@ -557,6 +560,7 @@ async def _execute_activity( if start.HasField("heartbeat_timeout") else None, is_local=start.is_local, + namespace=start.workflow_namespace or self._client.namespace, schedule_to_close_timeout=_proto_to_non_zero_timedelta( start.schedule_to_close_timeout ) @@ -571,14 +575,17 @@ async def _execute_activity( started_time=_proto_to_datetime(start.started_time), task_queue=self._task_queue, task_token=task_token, - workflow_id=start.workflow_execution.workflow_id, - workflow_namespace=start.workflow_namespace, - workflow_run_id=start.workflow_execution.run_id, - workflow_type=start.workflow_type, + workflow_id=start.workflow_execution.workflow_id or None, + workflow_namespace=start.workflow_namespace or None, + workflow_run_id=start.workflow_execution.run_id or None, + workflow_type=start.workflow_type or None, priority=temporalio.common.Priority._from_proto(start.priority), retry_policy=temporalio.common.RetryPolicy.from_proto(start.retry_policy) if start.HasField("retry_policy") else None, + activity_run_id=getattr(start, "run_id", None) + if not started_by_workflow + else None, ) if self._encode_headers and data_converter.payload_codec is not None: diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 10fd594fd..a8785d6fb 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -790,6 +790,7 @@ def _apply_resolve_activity( workflow_id=self._info.workflow_id, workflow_type=self._info.workflow_type, 
                activity_type=handle._input.activity,
+                activity_id=handle._input.activity_id,
                activity_task_queue=(
                    handle._input.task_queue or self._info.task_queue
                    if isinstance(handle._input, StartActivityInput)
@@ -2123,6 +2124,7 @@ def get_serialization_context(
                workflow_id=self._info.workflow_id,
                workflow_type=self._info.workflow_type,
                activity_type=activity_handle._input.activity,
+                activity_id=activity_handle._input.activity_id,
                activity_task_queue=(
                    activity_handle._input.task_queue
                    if isinstance(activity_handle._input, StartActivityInput)
@@ -2918,6 +2920,7 @@ def __init__(
                workflow_id=self._instance._info.workflow_id,
                workflow_type=self._instance._info.workflow_type,
                activity_type=self._input.activity,
+                activity_id=self._input.activity_id,
                activity_task_queue=(
                    self._input.task_queue or self._instance._info.task_queue
                    if isinstance(self._input, StartActivityInput)
diff --git a/tests/conftest.py b/tests/conftest.py
index f5935613e..f30a66cd6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -119,10 +119,16 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]:
                    "frontend.activityAPIsEnabled=true",
                    "--dynamic-config-value",
                    "component.nexusoperations.recordCancelRequestCompletionEvents=true",
+                    "--dynamic-config-value",
+                    "activity.enableStandalone=true",
+                    "--dynamic-config-value",
+                    "history.enableChasm=true",
+                    "--dynamic-config-value",
+                    "history.enableTransitionHistory=true",
                    "--http-port",
                    str(http_port),
                ],
                dev_server_download_version=DEV_SERVER_DOWNLOAD_VERSION,
            )
            # TODO(nexus-preview): expose this in a more principled way
            env._http_port = http_port  # type: ignore
diff --git a/tests/test_activity.py b/tests/test_activity.py
new file mode 100644
index 000000000..c353d0c40
--- /dev/null
+++ b/tests/test_activity.py
@@ -0,0 +1,1121 @@
+import asyncio
+import uuid
+from dataclasses import dataclass
+from datetime import timedelta
+
+import pytest
+
+from temporalio import activity, workflow
+from temporalio.client import (
+    ActivityExecutionCount,
+    ActivityExecutionCountAggregationGroup,
+    ActivityExecutionDescription,
+    ActivityExecutionStatus,
+    ActivityFailureError,
+    ActivityHandle,
+    CancelActivityInput,
+    Client,
+    CountActivitiesInput,
+    DescribeActivityInput,
+    Interceptor,
+    ListActivitiesInput,
+    OutboundInterceptor,
+    PendingActivityState,
+    StartActivityInput,
+    TerminateActivityInput,
+)
+from temporalio.exceptions import ApplicationError, CancelledError
+from temporalio.service import RPCError, RPCStatusCode
+from temporalio.worker import Worker
+from tests.helpers import assert_eq_eventually
+
+
+@activity.defn
+async def increment(input: int) -> int:
+    return input + 1
+
+
+# Activity classes for testing start_activity_class / execute_activity_class
+@activity.defn
+class IncrementClass:
+    """Async callable class activity with a parameter."""
+
+    async def __call__(self, x: int) -> int:
+        return x + 1
+
+
+@activity.defn
+class NoParamClass:
+    """Async callable class activity with no parameters."""
+
+    async def __call__(self) -> str:
+        return "no-param-result"
+
+
+@activity.defn
+class SyncIncrementClass:
+    """Sync callable class activity with a parameter."""
+
+    def __call__(self, x: int) -> int:
+        return x + 1
+
+
+# Activity holder for testing start_activity_method / execute_activity_method
+class ActivityHolder:
+    """Class holding activity methods."""
+
+    @activity.defn
+    async def async_increment(self, x: int) -> int:
+        return x + 1
+
+    @activity.defn
+    async def
async_no_param(self) -> str: + return "async-method-result" + + @activity.defn + def sync_increment(self, x: int) -> int: + return x + 1 + + +class TestDescribe: + @pytest.fixture + async def activity_handle(self, client: Client): + id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + yield await client.start_activity( + increment, + args=(42,), + id=id, + task_queue=task_queue, + schedule_to_close_timeout=timedelta(hours=1), + ) + + async def test_describe(self, client: Client, activity_handle: ActivityHandle): + desc = await activity_handle.describe() + # From ActivityExecution (base class) + assert desc.activity_id == activity_handle.id + assert desc.activity_run_id == activity_handle.run_id + assert desc.activity_type == "increment" + assert desc.close_time is None # not closed yet + assert desc.execution_duration is None # not closed yet + assert desc.namespace == client.namespace + assert desc.raw_info is not None + assert desc.scheduled_time is not None + assert len(desc.typed_search_attributes) == 0 + assert desc.state_transition_count is not None + assert desc.status == ActivityExecutionStatus.RUNNING + assert desc.task_queue + # From ActivityExecutionDescription + assert desc.attempt == 1 + assert desc.canceled_reason is None + assert desc.current_retry_interval is None + assert desc.eager_execution_requested is False + assert desc.expiration_time is not None + assert desc.raw_heartbeat_details == [] + assert desc.run_state == PendingActivityState.SCHEDULED + assert desc.last_attempt_complete_time is None + assert desc.last_failure is None + assert desc.last_heartbeat_time is None + assert desc.last_started_time is None + assert desc.last_worker_identity == "" + assert desc.long_poll_token is not None + assert desc.next_attempt_schedule_time is None + assert desc.paused is False + assert desc.retry_policy is not None + + async def test_describe_long_poll(self, activity_handle: ActivityHandle): + desc1 = await activity_handle.describe() + assert desc1.long_poll_token + desc2_task = asyncio.create_task( + activity_handle.describe(long_poll_token=desc1.long_poll_token) + ) + # Worker poll causes a transition to Started which notifies the waiting long-poll. 
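+        # (Passing a long_poll_token makes describe() block until the activity's
+        # next state transition rather than returning the current state immediately.)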
+ async with Worker( + activity_handle._client, + task_queue=desc1.task_queue, + activities=[increment], + ): + desc2 = await desc2_task + assert desc2.state_transition_count and desc1.state_transition_count + assert desc2.state_transition_count > desc1.state_transition_count + + +class ActivityTracingInterceptor(Interceptor): + """Test interceptor that tracks all activity interceptor calls.""" + + def __init__(self) -> None: + super().__init__() + self.start_activity_calls: list[StartActivityInput] = [] + self.describe_activity_calls: list[DescribeActivityInput] = [] + self.cancel_activity_calls: list[CancelActivityInput] = [] + self.terminate_activity_calls: list[TerminateActivityInput] = [] + self.list_activities_calls: list[ListActivitiesInput] = [] + self.count_activities_calls: list[CountActivitiesInput] = [] + + def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor: + return ActivityTracingOutboundInterceptor(self, next) + + +class ActivityTracingOutboundInterceptor(OutboundInterceptor): + def __init__( + self, + parent: ActivityTracingInterceptor, + next: OutboundInterceptor, + ) -> None: + super().__init__(next) + self._parent = parent + + async def start_activity(self, input: StartActivityInput): + assert isinstance(input, StartActivityInput) + self._parent.start_activity_calls.append(input) + return await super().start_activity(input) + + async def describe_activity(self, input: DescribeActivityInput): + assert isinstance(input, DescribeActivityInput) + self._parent.describe_activity_calls.append(input) + return await super().describe_activity(input) + + async def cancel_activity(self, input: CancelActivityInput): + assert isinstance(input, CancelActivityInput) + self._parent.cancel_activity_calls.append(input) + return await super().cancel_activity(input) + + async def terminate_activity(self, input: TerminateActivityInput): + assert isinstance(input, TerminateActivityInput) + self._parent.terminate_activity_calls.append(input) + return await super().terminate_activity(input) + + def list_activities(self, input: ListActivitiesInput): + assert isinstance(input, ListActivitiesInput) + self._parent.list_activities_calls.append(input) + return super().list_activities(input) + + async def count_activities(self, input: CountActivitiesInput): + assert isinstance(input, CountActivitiesInput) + self._parent.count_activities_calls.append(input) + return await super().count_activities(input) + + +async def test_start_activity_calls_interceptor(client: Client): + """Client.start_activity() should call the start_activity interceptor.""" + interceptor = ActivityTracingInterceptor() + intercepted_client = Client( + service_client=client.service_client, + namespace=client.namespace, + interceptors=[interceptor], + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + await intercepted_client.start_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + assert len(interceptor.start_activity_calls) == 1 + call = interceptor.start_activity_calls[0] + assert call.id == activity_id + assert call.task_queue == task_queue + assert call.activity_type == "increment" + + +async def test_describe_activity_calls_interceptor(client: Client): + """ActivityHandle.describe() should call the describe_activity interceptor.""" + interceptor = ActivityTracingInterceptor() + intercepted_client = Client( + service_client=client.service_client, + namespace=client.namespace, + 
interceptors=[interceptor], + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + activity_handle = await intercepted_client.start_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + desc = await activity_handle.describe() + assert isinstance(desc, ActivityExecutionDescription) + + assert len(interceptor.describe_activity_calls) == 1 + call = interceptor.describe_activity_calls[0] + assert call.activity_id == activity_id + + +async def test_cancel_activity_calls_interceptor(client: Client): + """ActivityHandle.cancel() should call the cancel_activity interceptor.""" + interceptor = ActivityTracingInterceptor() + intercepted_client = Client( + service_client=client.service_client, + namespace=client.namespace, + interceptors=[interceptor], + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + activity_handle = await intercepted_client.start_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + await activity_handle.cancel(reason="test cancellation") + + assert len(interceptor.cancel_activity_calls) == 1 + call = interceptor.cancel_activity_calls[0] + assert call.activity_id == activity_id + assert call.reason == "test cancellation" + + +async def test_terminate_activity_calls_interceptor(client: Client): + """ActivityHandle.terminate() should call the terminate_activity interceptor.""" + interceptor = ActivityTracingInterceptor() + intercepted_client = Client( + service_client=client.service_client, + namespace=client.namespace, + interceptors=[interceptor], + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + activity_handle = await intercepted_client.start_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + await activity_handle.terminate(reason="test termination") + + assert len(interceptor.terminate_activity_calls) == 1 + call = interceptor.terminate_activity_calls[0] + assert call.activity_id == activity_id + assert call.reason == "test termination" + + +async def test_list_activities_calls_interceptor(client: Client): + """Client.list_activities() should call the list_activities interceptor.""" + interceptor = ActivityTracingInterceptor() + intercepted_client = Client( + service_client=client.service_client, + namespace=client.namespace, + interceptors=[interceptor], + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + await intercepted_client.start_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + query = f'ActivityId = "{activity_id}"' + async for _ in intercepted_client.list_activities(query): + pass + + assert len(interceptor.list_activities_calls) >= 1 + call = interceptor.list_activities_calls[0] + assert call.query == query + + +async def test_count_activities_calls_interceptor(client: Client): + """Client.count_activities() should call the count_activities interceptor.""" + interceptor = ActivityTracingInterceptor() + intercepted_client = Client( + service_client=client.service_client, + namespace=client.namespace, + interceptors=[interceptor], + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + await intercepted_client.start_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + 
start_to_close_timeout=timedelta(seconds=5), + ) + + query = f'ActivityId = "{activity_id}"' + count = await intercepted_client.count_activities(query) + assert isinstance(count, ActivityExecutionCount) + + assert len(interceptor.count_activities_calls) == 1 + call = interceptor.count_activities_calls[0] + assert call.query == query + + +async def test_get_result(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + activity_handle = await client.start_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + result_via_execute_activity = client.execute_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[increment], + ): + assert await activity_handle.result() == 2 + assert await result_via_execute_activity == 2 + + +async def test_get_activity_handle(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + activity_handle = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + handle_by_id = client.get_activity_handle(activity_id) + assert handle_by_id.id == activity_id + assert handle_by_id.run_id is None + + handle_by_id_and_run_id = client.get_activity_handle( + activity_id, + run_id=activity_handle.run_id, + ) + assert handle_by_id_and_run_id.id == activity_id + assert handle_by_id_and_run_id.run_id == activity_handle.run_id + + handle_with_result_type = client.get_activity_handle( + activity_id, + run_id=activity_handle.run_id, + result_type=int, + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[increment], + ): + assert await handle_by_id.result() == 2 + assert await handle_by_id_and_run_id.result() == 2 + assert await handle_with_result_type.result() == 2 + + +async def test_list_activities(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + executions = [ + e async for e in client.list_activities(f'ActivityId = "{activity_id}"') + ] + assert len(executions) == 1 + execution = executions[0] + assert execution.activity_id == activity_id + assert execution.activity_type == "increment" + assert execution.task_queue == task_queue + assert execution.status == ActivityExecutionStatus.RUNNING + assert execution.state_transition_count is None # Not set until activity completes + + +async def test_count_activities(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async def fetch_count(): + return await client.count_activities(f'ActivityId = "{activity_id}"') + + await assert_eq_eventually( + ActivityExecutionCount(count=1, groups=[]), + fetch_count, + ) + + +async def test_count_activities_group_by(client: Client): + from temporalio.client import ActivityExecutionCount + + task_queue = str(uuid.uuid4()) + activity_ids = [] + + for _ in range(3): + activity_id = str(uuid.uuid4()) + activity_ids.append(activity_id) + await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + 
schedule_to_close_timeout=timedelta(seconds=60), + ) + + ids_filter = " OR ".join([f'ActivityId = "{aid}"' for aid in activity_ids]) + + async def fetch_count() -> ActivityExecutionCount: + return await client.count_activities(f"({ids_filter}) GROUP BY ExecutionStatus") + + await assert_eq_eventually( + ActivityExecutionCount( + count=3, + groups=[ + ActivityExecutionCountAggregationGroup( + count=3, group_values=["Running"] + ), + ], + ), + fetch_count, + ) + + +@dataclass +class ActivityInput: + event_workflow_id: str + wait_for_activity_start_workflow_id: str | None = None + + +@activity.defn +async def async_activity(input: ActivityInput) -> int: + # Notify test that the activity has started and is ready to be completed manually + await ( + activity.client() + .get_workflow_handle(input.event_workflow_id) + .signal(EventWorkflow.set) + ) + activity.raise_complete_async() + + +async def test_manual_completion(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + event_workflow_id = str(uuid.uuid4()) + + activity_handle = await client.start_activity( + async_activity, + ActivityInput(event_workflow_id=event_workflow_id), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[async_activity], + workflows=[EventWorkflow], + ): + # Wait for activity to start + await client.execute_workflow( + EventWorkflow.wait, + id=event_workflow_id, + task_queue=task_queue, + ) + # Complete activity manually + async_activity_handle = client.get_async_activity_handle( + activity_id=activity_id, + run_id=activity_handle.run_id, + ) + await async_activity_handle.complete(7) + assert await activity_handle.result() == 7 + + desc = await activity_handle.describe() + assert desc.status == ActivityExecutionStatus.COMPLETED + + +async def test_manual_cancellation(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + event_workflow_id = str(uuid.uuid4()) + + activity_handle = await client.start_activity( + async_activity, + ActivityInput(event_workflow_id=event_workflow_id), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[async_activity], + workflows=[EventWorkflow], + ): + # Wait for activity to start + await client.execute_workflow( + EventWorkflow.wait, + id=event_workflow_id, + task_queue=task_queue, + ) + async_activity_handle = client.get_async_activity_handle( + activity_id=activity_id, + run_id=activity_handle.run_id, + ) + + # report_cancellation fails if activity is not in CANCELLATION_REQUESTED state + with pytest.raises(RPCError) as err: + await async_activity_handle.report_cancellation("Test cancellation") + assert err.value.status == RPCStatusCode.FAILED_PRECONDITION + assert "invalid transition from Started" in str(err.value) + + # Request cancellation to transition activity to CANCELLATION_REQUESTED state + await activity_handle.cancel() + + # Now report_cancellation succeeds + await async_activity_handle.report_cancellation("Test cancellation") + + with pytest.raises(ActivityFailureError) as exc_info: + await activity_handle.result() + assert isinstance(exc_info.value.cause, CancelledError) + assert list(exc_info.value.cause.details) == ["Test cancellation"] + + desc = await activity_handle.describe() + assert desc.status == ActivityExecutionStatus.CANCELED + + +async def test_manual_failure(client: Client): + 
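    """Manually fail an async-completing activity via the async activity handle."""
+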
activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + event_workflow_id = str(uuid.uuid4()) + + activity_handle = await client.start_activity( + async_activity, + ActivityInput(event_workflow_id=event_workflow_id), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + async with Worker( + client, + task_queue=task_queue, + activities=[async_activity], + workflows=[EventWorkflow], + ): + await client.execute_workflow( + EventWorkflow.wait, + id=event_workflow_id, + task_queue=task_queue, + ) + async_activity_handle = client.get_async_activity_handle( + activity_id=activity_id, + run_id=activity_handle.run_id, + ) + await async_activity_handle.fail( + ApplicationError("Test failure", non_retryable=True) + ) + with pytest.raises(ActivityFailureError) as err: + await activity_handle.result() + assert isinstance(err.value.cause, ApplicationError) + assert str(err.value.cause) == "Test failure" + + desc = await activity_handle.describe() + assert desc.status == ActivityExecutionStatus.FAILED + + +@activity.defn +async def activity_for_testing_heartbeat(input: ActivityInput) -> str: + info = activity.info() + if info.attempt == 1: + # Signal that activity has started (only on first attempt) + if input.wait_for_activity_start_workflow_id: + await ( + activity.client() + .get_workflow_handle( + workflow_id=input.wait_for_activity_start_workflow_id, + ) + .signal(EventWorkflow.set) + ) + wait_for_heartbeat_wf_handle = await activity.client().start_workflow( + EventWorkflow.wait, + id=input.event_workflow_id, + task_queue=activity.info().task_queue, + ) + # Wait for test to notify that it has sent heartbeat + await wait_for_heartbeat_wf_handle.result() + raise Exception("Intentional error to force retry") + elif info.attempt == 2: + [heartbeat_data] = info.heartbeat_details + assert isinstance(heartbeat_data, str) + return heartbeat_data + else: + raise AssertionError(f"Unexpected attempt number: {info.attempt}") + + +async def test_manual_heartbeat(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + event_workflow_id = str(uuid.uuid4()) + wait_for_activity_start_workflow_id = str(uuid.uuid4()) + + activity_handle = await client.start_activity( + activity_for_testing_heartbeat, + ActivityInput( + event_workflow_id=event_workflow_id, + wait_for_activity_start_workflow_id=wait_for_activity_start_workflow_id, + ), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + wait_for_activity_start_wf_handle = await client.start_workflow( + EventWorkflow.wait, + id=wait_for_activity_start_workflow_id, + task_queue=task_queue, + ) + async with Worker( + client, + task_queue=task_queue, + activities=[activity_for_testing_heartbeat], + workflows=[EventWorkflow], + ): + async_activity_handle = client.get_async_activity_handle( + activity_id=activity_id, + run_id=activity_handle.run_id, + ) + await wait_for_activity_start_wf_handle.result() + await async_activity_handle.heartbeat("Test heartbeat details") + await client.get_workflow_handle( + workflow_id=event_workflow_id, + ).signal(EventWorkflow.set) + assert await activity_handle.result() == "Test heartbeat details" + + +async def test_id_conflict_policy_fail(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + from temporalio.common import ActivityIDConflictPolicy + from temporalio.exceptions import ActivityAlreadyStartedError + + await client.start_activity( + increment, + 1, + id=activity_id, + 
task_queue=task_queue, + schedule_to_close_timeout=timedelta(seconds=60), + id_conflict_policy=ActivityIDConflictPolicy.FAIL, + ) + + with pytest.raises(ActivityAlreadyStartedError) as err: + await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + schedule_to_close_timeout=timedelta(seconds=60), + id_conflict_policy=ActivityIDConflictPolicy.FAIL, + ) + assert err.value.activity_id == activity_id + + +async def test_id_conflict_policy_use_existing(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + from temporalio.common import ActivityIDConflictPolicy + + handle1 = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + schedule_to_close_timeout=timedelta(seconds=60), + id_conflict_policy=ActivityIDConflictPolicy.USE_EXISTING, + ) + + handle2 = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + schedule_to_close_timeout=timedelta(seconds=60), + id_conflict_policy=ActivityIDConflictPolicy.USE_EXISTING, + ) + + assert handle1.id == handle2.id + assert handle1.run_id == handle2.run_id + + +async def test_id_reuse_policy_reject_duplicate(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + from temporalio.common import ActivityIDReusePolicy + from temporalio.exceptions import ActivityAlreadyStartedError + + handle = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + id_reuse_policy=ActivityIDReusePolicy.REJECT_DUPLICATE, + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[increment], + ): + await handle.result() + + with pytest.raises(ActivityAlreadyStartedError) as err: + await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + id_reuse_policy=ActivityIDReusePolicy.REJECT_DUPLICATE, + ) + assert err.value.activity_id == activity_id + + +async def test_id_reuse_policy_allow_duplicate(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + from temporalio.common import ActivityIDReusePolicy + + handle1 = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + id_reuse_policy=ActivityIDReusePolicy.ALLOW_DUPLICATE, + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[increment], + ): + await handle1.result() + + handle2 = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + id_reuse_policy=ActivityIDReusePolicy.ALLOW_DUPLICATE, + ) + + assert handle1.id == handle2.id + assert handle1.run_id != handle2.run_id + + +async def test_search_attributes(client: Client): + from temporalio.common import ( + SearchAttributeKey, + SearchAttributePair, + TypedSearchAttributes, + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + temporal_change_version_key = SearchAttributeKey.for_keyword_list( + "TemporalChangeVersion" + ) + + handle = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + schedule_to_close_timeout=timedelta(seconds=60), + search_attributes=TypedSearchAttributes( + [SearchAttributePair(temporal_change_version_key, ["test-1", "test-2"])] + ), + ) + + desc = await handle.describe() + assert desc.typed_search_attributes[temporal_change_version_key] == 
[ + "test-1", + "test-2", + ] + + +async def test_retry_policy(client: Client): + from temporalio.common import RetryPolicy + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + handle = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + retry_policy=RetryPolicy( + initial_interval=timedelta(seconds=1), + maximum_interval=timedelta(seconds=10), + backoff_coefficient=2.0, + maximum_attempts=3, + ), + ) + + desc = await handle.describe() + assert desc.retry_policy is not None + assert desc.retry_policy.initial_interval == timedelta(seconds=1) + assert desc.retry_policy.maximum_interval == timedelta(seconds=10) + assert desc.retry_policy.backoff_coefficient == 2.0 + assert desc.retry_policy.maximum_attempts == 3 + + +async def test_terminate(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + event_workflow_id = str(uuid.uuid4()) + + activity_handle = await client.start_activity( + async_activity, + args=(ActivityInput(event_workflow_id=event_workflow_id),), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[async_activity], + workflows=[EventWorkflow], + ): + await client.execute_workflow( + EventWorkflow.wait, + id=event_workflow_id, + task_queue=task_queue, + ) + + await activity_handle.terminate(reason="Test termination") + + with pytest.raises(ActivityFailureError): + await activity_handle.result() + + desc = await activity_handle.describe() + assert desc.status == ActivityExecutionStatus.TERMINATED + + +# Tests for start_activity_class / execute_activity_class + + +async def test_start_activity_class_async(client: Client): + """Test start_activity_class with an async callable class.""" + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + handle = await client.start_activity_class( + IncrementClass, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[IncrementClass()], + ): + result = await handle.result() + assert result == 2 + + +async def test_execute_activity_class_async(client: Client): + """Test execute_activity_class with an async callable class.""" + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + async with Worker( + client, + task_queue=task_queue, + activities=[IncrementClass()], + ): + result = await client.execute_activity_class( + IncrementClass, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + assert result == 2 + + +async def test_start_activity_class_no_param(client: Client): + """Test start_activity_class with a no-param callable class.""" + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + handle = await client.start_activity_class( + NoParamClass, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[NoParamClass()], + ): + result = await handle.result() + assert result == "no-param-result" + + +async def test_start_activity_class_sync(client: Client): + """Test start_activity_class with a sync callable class.""" + import concurrent.futures + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + handle = await client.start_activity_class( + SyncIncrementClass, + 1, 
+ id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + with concurrent.futures.ThreadPoolExecutor() as executor: + async with Worker( + client, + task_queue=task_queue, + activities=[SyncIncrementClass()], + activity_executor=executor, + ): + result = await handle.result() + assert result == 2 + + +# Tests for start_activity_method / execute_activity_method + + +async def test_start_activity_method_async(client: Client): + """Test start_activity_method with an async method.""" + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + holder = ActivityHolder() + handle = await client.start_activity_method( + ActivityHolder.async_increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[holder.async_increment], + ): + result = await handle.result() + assert result == 2 + + +async def test_execute_activity_method_async(client: Client): + """Test execute_activity_method with an async method.""" + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + holder = ActivityHolder() + async with Worker( + client, + task_queue=task_queue, + activities=[holder.async_increment], + ): + result = await client.execute_activity_method( + ActivityHolder.async_increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + assert result == 2 + + +async def test_start_activity_method_no_param(client: Client): + """Test start_activity_method with a no-param method.""" + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + holder = ActivityHolder() + handle = await client.start_activity_method( + ActivityHolder.async_no_param, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[holder.async_no_param], + ): + result = await handle.result() + assert result == "async-method-result" + + +# Utilities + + +@workflow.defn +class EventWorkflow: + """ + A workflow version of asyncio.Event() + """ + + def __init__(self) -> None: + self.signal_received = asyncio.Event() + + @workflow.run + async def wait(self) -> None: + await self.signal_received.wait() + + @workflow.signal + def set(self) -> None: + self.signal_received.set() diff --git a/tests/test_activity_type_errors.py b/tests/test_activity_type_errors.py new file mode 100644 index 000000000..fadf7d14b --- /dev/null +++ b/tests/test_activity_type_errors.py @@ -0,0 +1,491 @@ +""" +This file exists to test for type-checker false positives and false negatives +for the activity client API. + +It doesn't contain any test functions - it uses the machinery in test_type_errors.py +to verify that pyright produces the expected errors. 
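+
+Each "# assert-type-error-pyright: '...'" comment asserts that pyright reports an
+error matching the quoted text on the line that follows it.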
+""" + +from datetime import timedelta +from unittest.mock import Mock + +from temporalio import activity +from temporalio.client import ActivityHandle, Client +from temporalio.service import ServiceClient + + +@activity.defn +async def increment(x: int) -> int: + return x + 1 + + +@activity.defn +async def greet(name: str) -> str: + return f"Hello, {name}" + + +@activity.defn +async def no_return(_: int) -> None: + pass + + +@activity.defn +async def no_param_async() -> str: + return "done" + + +@activity.defn +def increment_sync(x: int) -> int: + return x + 1 + + +@activity.defn +def no_param_sync() -> str: + return "done" + + +@activity.defn +class IncrementClass: + """Async activity defined as a callable class.""" + + async def __call__(self, x: int) -> int: + return x + 1 + + +@activity.defn +class NoParamClass: + """Async activity class with no parameters.""" + + async def __call__(self) -> str: + return "done" + + +@activity.defn +class SyncIncrementClass: + """Sync activity defined as a callable class.""" + + def __call__(self, x: int) -> int: + return x + 1 + + +@activity.defn +class SyncNoParamClass: + """Sync activity class with no parameters.""" + + def __call__(self) -> str: + return "done" + + +class ActivityHolder: + """Class holding activity methods.""" + + @activity.defn + async def increment_method(self, x: int) -> int: + return x + 1 + + @activity.defn + async def no_param_method(self) -> str: + return "done" + + +async def _test_start_activity_typed_callable_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[int] = await client.start_activity( + increment, + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + _result: int = await _handle.result() + + +async def _test_execute_activity_typed_callable_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: int = await client.execute_activity( + increment, + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_positional_arg_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[int] = await client.start_activity( + increment, + 1, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_positional_arg_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: int = await client.execute_activity( + increment, + 1, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_string_name_with_result_type() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle = await client.start_activity( + "increment", + args=[1], + id="activity-id", + task_queue="tq", + result_type=int, + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_no_param_async_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[str] = await client.start_activity( + no_param_async, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def 
_test_execute_activity_no_param_async_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: str = await client.execute_activity( + no_param_async, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_no_param_sync_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[str] = await client.start_activity( + no_param_sync, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_no_param_sync_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: str = await client.execute_activity( + no_param_sync, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_wrong_arg_type() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[int] = await client.start_activity( + increment, + # assert-type-error-pyright: 'cannot be assigned to parameter' + "wrong type", # type: ignore + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_wrong_arg_type() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: int = await client.execute_activity( + increment, + # assert-type-error-pyright: 'cannot be assigned to parameter' + "wrong type", # type: ignore + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_wrong_result_type_assignment() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + handle = await client.start_activity( + increment, + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + # assert-type-error-pyright: 'Type "int" is not assignable to declared type "str"' + _wrong: str = await handle.result() # type: ignore + + +async def _test_execute_activity_wrong_result_type_assignment() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + # assert-type-error-pyright: 'Type "int" is not assignable to declared type "str"' + _wrong: str = await client.execute_activity( # type: ignore + increment, # type: ignore[arg-type] + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_missing_required_params() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + # assert-type-error-pyright: 'No overloads for "start_activity" match' + await client.start_activity( # type: ignore + increment, + args=[1], + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + # assert-type-error-pyright: 'No overloads for "start_activity" match' + await client.start_activity( # type: ignore + increment, + args=[1], + id="activity-id", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_activity_handle_typed_correctly() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + handle_int: ActivityHandle[int] = await 
client.start_activity( + increment, + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + _int_result: int = await handle_int.result() + + handle_str: ActivityHandle[str] = await client.start_activity( + greet, + args=["world"], + id="activity-id-2", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + _str_result: str = await handle_str.result() + + handle_none: ActivityHandle[None] = await client.start_activity( + no_return, + args=[1], + id="activity-id-3", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + _none_result: None = await handle_none.result() # type: ignore[func-returns-value] + + +async def _test_activity_handle_wrong_type_parameter() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + # assert-type-error-pyright: 'Type "ActivityHandle\[int\]" is not assignable to declared type "ActivityHandle\[str\]"' + _handle: ActivityHandle[str] = await client.start_activity( # type: ignore + increment, # type: ignore[arg-type] + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_sync_activity() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[int] = await client.start_activity( + increment_sync, + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_sync_activity() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: int = await client.execute_activity( + increment_sync, + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_sync_no_param() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[str] = await client.start_activity( + no_param_sync, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +# Tests for start_activity_class and execute_activity_class +# Note: Type inference for callable classes is limited; use args= form + + +async def _test_start_activity_class_single_param() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[int] = await client.start_activity_class( + IncrementClass, + 1, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_class_single_param() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: int = await client.execute_activity_class( + IncrementClass, + 1, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_class_no_param() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[str] = await client.start_activity_class( + NoParamClass, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_class_no_param() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: str = await client.execute_activity_class( + 
+        NoParamClass,
+        id="activity-id",
+        task_queue="tq",
+        start_to_close_timeout=timedelta(seconds=5),
+    )
+
+
+# Tests for sync callable classes
+
+
+async def _test_start_activity_class_sync_single_param() -> None:  # type:ignore[reportUnusedFunction]
+    client = Client(service_client=Mock(spec=ServiceClient))
+
+    _handle: ActivityHandle[int] = await client.start_activity_class(
+        SyncIncrementClass,
+        1,
+        id="activity-id",
+        task_queue="tq",
+        start_to_close_timeout=timedelta(seconds=5),
+    )
+
+
+async def _test_execute_activity_class_sync_single_param() -> None:  # type:ignore[reportUnusedFunction]
+    client = Client(service_client=Mock(spec=ServiceClient))
+
+    _result: int = await client.execute_activity_class(
+        SyncIncrementClass,
+        1,
+        id="activity-id",
+        task_queue="tq",
+        start_to_close_timeout=timedelta(seconds=5),
+    )
+
+
+async def _test_start_activity_class_sync_no_param() -> None:  # type:ignore[reportUnusedFunction]
+    client = Client(service_client=Mock(spec=ServiceClient))
+
+    _handle: ActivityHandle[str] = await client.start_activity_class(
+        SyncNoParamClass,
+        id="activity-id",
+        task_queue="tq",
+        start_to_close_timeout=timedelta(seconds=5),
+    )
+
+
+# Tests for start_activity_method and execute_activity_method
+# Note: The _method variants work best with unbound methods (class references).
+# For bound methods accessed via instance, use start_activity directly.
+
+
+async def _test_start_activity_method_unbound() -> None:  # type:ignore[reportUnusedFunction]
+    client = Client(service_client=Mock(spec=ServiceClient))
+
+    # Using unbound method reference
+    _handle: ActivityHandle[int] = await client.start_activity_method(
+        ActivityHolder.increment_method,
+        args=[1],
+        id="activity-id",
+        task_queue="tq",
+        start_to_close_timeout=timedelta(seconds=5),
+    )
+
+
+async def _test_execute_activity_method_unbound() -> None:  # type:ignore[reportUnusedFunction]
+    client = Client(service_client=Mock(spec=ServiceClient))
+
+    # Using unbound method reference
+    _result: int = await client.execute_activity_method(
+        ActivityHolder.increment_method,
+        args=[1],
+        id="activity-id",
+        task_queue="tq",
+        start_to_close_timeout=timedelta(seconds=5),
+    )
+
+
+async def _test_start_activity_method_no_param_unbound() -> None:  # type:ignore[reportUnusedFunction]
+    client = Client(service_client=Mock(spec=ServiceClient))
+
+    # Using unbound method reference
+    _handle: ActivityHandle[str] = await client.start_activity_method(
+        ActivityHolder.no_param_method,
+        id="activity-id",
+        task_queue="tq",
+        start_to_close_timeout=timedelta(seconds=5),
+    )
+
+
+async def _test_execute_activity_method_no_param_unbound() -> None:  # type:ignore[reportUnusedFunction]
+    client = Client(service_client=Mock(spec=ServiceClient))
+
+    # Using unbound method reference
+    _result: str = await client.execute_activity_method(
+        ActivityHolder.no_param_method,
+        id="activity-id",
+        task_queue="tq",
+        start_to_close_timeout=timedelta(seconds=5),
+    )
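
The note above points bound methods at plain start_activity. A minimal sketch of that call shape, reusing ActivityHolder from these tests; it assumes the single-param positional overload accepts any matching callable, so treat it as illustrative rather than part of this diff:

    # Illustrative sketch (not part of the diff): bound method via start_activity.
    holder = ActivityHolder()
    _handle: ActivityHandle[int] = await client.start_activity(
        holder.increment_method,  # bound method; behaves as a one-arg callable
        1,
        id="activity-id",
        task_queue="tq",
        start_to_close_timeout=timedelta(seconds=5),
    )
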
diff --git a/tests/test_serialization_context.py b/tests/test_serialization_context.py
index 4e217861b..c346c46bb 100644
--- a/tests/test_serialization_context.py
+++ b/tests/test_serialization_context.py
@@ -180,6 +180,7 @@ async def run(self, data: TraceData) -> TraceData:
             data,
             start_to_close_timeout=timedelta(seconds=10),
             heartbeat_timeout=timedelta(seconds=2),
+            activity_id="activity-id",
         )
         data = await workflow.execute_child_workflow(
             EchoWorkflow.run, data, id=f"{workflow.info().workflow_id}_child"
         )
@@ -232,6 +233,7 @@ async def test_payload_conversion_calls_follow_expected_sequence_and_contexts(
         workflow_id=workflow_id,
         workflow_type=PayloadConversionWorkflow.__name__,
         activity_type=passthrough_activity.__name__,
+        activity_id="activity-id",
         activity_task_queue=task_queue,
         is_local=False,
     )
@@ -329,6 +331,7 @@ async def run(self) -> TraceData:
                 initial_interval=timedelta(milliseconds=100),
                 maximum_attempts=2,
             ),
+            activity_id="activity-id",
         )
@@ -371,6 +374,7 @@ async def test_heartbeat_details_payload_conversion(client: Client):
         workflow_id=workflow_id,
         workflow_type=HeartbeatDetailsSerializationContextTestWorkflow.__name__,
         activity_type=activity_with_heartbeat_details.__name__,
+        activity_id="activity-id",
         activity_task_queue=task_queue,
         is_local=False,
     )
@@ -420,6 +424,7 @@ async def run(self, data: TraceData) -> TraceData:
             local_activity,
             data,
             start_to_close_timeout=timedelta(seconds=10),
+            activity_id="activity-id",
         )
@@ -460,6 +465,7 @@ async def test_local_activity_payload_conversion(client: Client):
         workflow_id=workflow_id,
         workflow_type=LocalActivityWorkflow.__name__,
         activity_type=local_activity.__name__,
+        activity_id="activity-id",
         activity_task_queue=task_queue,
         is_local=True,
     )
@@ -505,7 +511,7 @@ async def test_local_activity_payload_conversion(client: Client):
 
 @workflow.defn
-class EventWorkflow:
+class WaitForSignalWorkflow:
     # Like a global asyncio.Event()
 
     def __init__(self) -> None:
@@ -522,10 +528,11 @@ def signal(self) -> None:
 
 @activity.defn
 async def async_activity() -> TraceData:
+    # Notify test that the activity has started and is ready to be completed manually
     await (
         activity.client()
         .get_workflow_handle("activity-started-wf-id")
-        .signal(EventWorkflow.signal)
+        .signal(WaitForSignalWorkflow.signal)
     )
     activity.raise_complete_async()
@@ -559,7 +566,7 @@ async def test_async_activity_completion_payload_conversion(
         task_queue=task_queue,
         workflows=[
             AsyncActivityCompletionSerializationContextTestWorkflow,
-            EventWorkflow,
+            WaitForSignalWorkflow,
         ],
         activities=[async_activity],
         workflow_runner=UnsandboxedWorkflowRunner(),  # so that we can use isinstance
@@ -573,12 +580,13 @@
         workflow_id=workflow_id,
         workflow_type=AsyncActivityCompletionSerializationContextTestWorkflow.__name__,
         activity_type=async_activity.__name__,
+        activity_id="async-activity-id",
         activity_task_queue=task_queue,
         is_local=False,
     )
 
     act_started_wf_handle = await client.start_workflow(
-        EventWorkflow.run,
+        WaitForSignalWorkflow.run,
         id="activity-started-wf-id",
         task_queue=task_queue,
     )
@@ -645,6 +653,7 @@ def test_subclassed_async_activity_handle(client: Client):
         workflow_id="workflow-id",
         workflow_type="workflow-type",
         activity_type="activity-type",
+        activity_id="activity-id",
         activity_task_queue="activity-task-queue",
         is_local=False,
     )
@@ -1059,11 +1068,12 @@ async def run(self) -> Never:
             failing_activity,
             start_to_close_timeout=timedelta(seconds=10),
             retry_policy=RetryPolicy(maximum_attempts=1),
+            activity_id="activity-id",
         )
         raise Exception("Unreachable")
 
 
-test_traces: dict[str, list[TraceItem]] = defaultdict(list)
+test_traces: dict[str | None, list[TraceItem]] = defaultdict(list)
 
 
 class FailureConverterWithContext(DefaultFailureConverter, WithSerializationContext):
@@ -1155,6 +1165,7 @@ async def test_failure_converter_with_context(client: Client):
         workflow_id=workflow_id,
         workflow_type=FailureConverterTestWorkflow.__name__,
         activity_type=failing_activity.__name__,
+        activity_id="activity-id",
         activity_task_queue=task_queue,
         is_local=False,
     )
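
For orientation on the async-completion hunks above: async_activity signals WaitForSignalWorkflow and then calls activity.raise_complete_async(), so the test has to finish the activity itself through an async activity handle. A minimal sketch of that completion step, assuming the ids used in the test above and a TraceData result; illustrative only, not part of this diff:

    # Illustrative sketch (not part of the diff): manually complete the activity
    # that raised complete-async, addressed by the ids from the test above.
    handle = client.get_async_activity_handle(
        workflow_id=workflow_id,
        run_id=None,  # None targets the latest run
        activity_id="async-activity-id",
    )
    await handle.complete(TraceData())
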
@@ -1323,6 +1334,7 @@ async def run(self, data: str) -> str:
             codec_test_local_activity,
             data,
             start_to_close_timeout=timedelta(seconds=10),
+            activity_id="activity-id",
         )
@@ -1361,6 +1373,7 @@ async def test_local_activity_codec_with_context(client: Client):
         workflow_id=workflow_id,
         workflow_type=LocalActivityCodecTestWorkflow.__name__,
         activity_type=codec_test_local_activity.__name__,
+        activity_id="activity-id",
         activity_task_queue=task_queue,
         is_local=True,
     )
@@ -1594,6 +1607,7 @@ async def run(self, _data: str) -> str:
                 payload_encryption_activity,
                 "outbound",
                 start_to_close_timeout=timedelta(seconds=10),
+                activity_id="activity-id",
             ),
             workflow.execute_child_workflow(
                 PayloadEncryptionChildWorkflow.run,
diff --git a/tests/worker/test_activity.py b/tests/worker/test_activity.py
index e66a42dc0..f811fd1b5 100644
--- a/tests/worker/test_activity.py
+++ b/tests/worker/test_activity.py
@@ -1259,6 +1259,9 @@ def async_handle(self, client: Client, use_task_token: bool) -> AsyncActivityHandle:
         assert self._info
         if use_task_token:
             return client.get_async_activity_handle(task_token=self._info.task_token)
+        assert (
+            self._info.workflow_id
+        )  # These tests are for workflow-triggered activities
         return client.get_async_activity_handle(
             workflow_id=self._info.workflow_id,
             run_id=self._info.workflow_run_id,
@@ -1739,8 +1742,8 @@ async def wait_cancel() -> str:
         req = temporalio.api.workflowservice.v1.ResetActivityRequest(
             namespace=client.namespace,
             execution=temporalio.api.common.v1.WorkflowExecution(
-                workflow_id=activity.info().workflow_id,
-                run_id=activity.info().workflow_run_id,
+                workflow_id=activity.info().workflow_id or "",
+                run_id=activity.info().workflow_run_id or "",
             ),
             id=activity.info().activity_id,
         )
@@ -1759,8 +1762,8 @@ def sync_wait_cancel() -> str:
         req = temporalio.api.workflowservice.v1.ResetActivityRequest(
             namespace=client.namespace,
             execution=temporalio.api.common.v1.WorkflowExecution(
-                workflow_id=activity.info().workflow_id,
-                run_id=activity.info().workflow_run_id,
+                workflow_id=activity.info().workflow_id or "",
+                run_id=activity.info().workflow_run_id or "",
            ),
             id=activity.info().activity_id,
         )
@@ -1811,8 +1814,8 @@ async def wait_cancel() -> str:
         req = temporalio.api.workflowservice.v1.ResetActivityRequest(
             namespace=client.namespace,
             execution=temporalio.api.common.v1.WorkflowExecution(
-                workflow_id=activity.info().workflow_id,
-                run_id=activity.info().workflow_run_id,
+                workflow_id=activity.info().workflow_id or "",
+                run_id=activity.info().workflow_run_id or "",
             ),
             id=activity.info().activity_id,
         )
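
The empty-string fallbacks in the three hunks above are needed because workflow_id and workflow_run_id on the activity info are now Optional: an activity started directly by a client has no parent workflow. Inside an activity that must be workflow-triggered, a short guard narrows the types for the checker; a sketch using the new in_workflow property, with a hypothetical activity name for illustration:

    # Illustrative sketch (not part of the diff): guard a workflow-only activity.
    @activity.defn
    async def workflow_only_activity() -> str:
        info = activity.info()
        if not info.in_workflow:  # new property: True when workflow_id is set
            raise RuntimeError("activity must be started by a workflow")
        assert info.workflow_id is not None  # narrows str | None to str
        return info.workflow_id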