diff --git a/CHANGELOG.md b/CHANGELOG.md index 3cc5e03..d22b052 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ADDED - Added `durabletask.testing` module with `InMemoryOrchestrationBackend` for testing orchestrations without a sidecar process +- Improved distributed tracing support with full span coverage for orchestrations, activities, sub-orchestrations, timers, and events FIXED: diff --git a/durabletask/client.py b/durabletask/client.py index e00ba99..f830f02 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -16,6 +16,7 @@ import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared +import durabletask.internal.tracing as tracing from durabletask import task from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl @@ -169,20 +170,26 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu version: Optional[str] = None) -> str: name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator) + resolved_instance_id = instance_id if instance_id else uuid.uuid4().hex + resolved_version = version if version else self.default_version + + with tracing.start_create_orchestration_span( + name, resolved_instance_id, version=resolved_version, + ): + req = pb.CreateInstanceRequest( + name=name, + instanceId=resolved_instance_id, + input=helpers.get_string_value(shared.to_json(input) if input is not None else None), + scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None, + version=helpers.get_string_value(resolved_version), + orchestrationIdReusePolicy=reuse_id_policy, + tags=tags, + parentTraceContext=tracing.get_current_trace_context(), + ) - req = pb.CreateInstanceRequest( - name=name, - instanceId=instance_id if instance_id else uuid.uuid4().hex, - 
input=helpers.get_string_value(shared.to_json(input) if input is not None else None), - scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None, - version=helpers.get_string_value(version if version else self.default_version), - orchestrationIdReusePolicy=reuse_id_policy, - tags=tags - ) - - self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.") - res: pb.CreateInstanceResponse = self._stub.StartInstance(req) - return res.instanceId + self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.") + res: pb.CreateInstanceResponse = self._stub.StartInstance(req) + return res.instanceId def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) @@ -269,14 +276,15 @@ def wait_for_orchestration_completion(self, instance_id: str, *, def raise_orchestration_event(self, instance_id: str, event_name: str, *, data: Optional[Any] = None): - req = pb.RaiseEventRequest( - instanceId=instance_id, - name=event_name, - input=helpers.get_string_value(shared.to_json(data) if data is not None else None) - ) + with tracing.start_raise_event_span(event_name, instance_id): + req = pb.RaiseEventRequest( + instanceId=instance_id, + name=event_name, + input=helpers.get_string_value(shared.to_json(data) if data is not None else None) + ) - self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") - self._stub.RaiseEvent(req) + self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") + self._stub.RaiseEvent(req) def terminate_orchestration(self, instance_id: str, *, output: Optional[Any] = None, @@ -355,7 +363,7 @@ def signal_entity(self, input=helpers.get_string_value(shared.to_json(input) if input is not None else None), requestId=str(uuid.uuid4()), scheduledTime=None, - parentTraceContext=None, + 
parentTraceContext=tracing.get_current_trace_context(), requestTime=helpers.new_timestamp(datetime.now(timezone.utc)) ) self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 4720046..7b31095 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -223,11 +223,13 @@ def new_create_timer_action(id: int, fire_at: datetime) -> pb.OrchestratorAction def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str], - tags: Optional[dict[str, str]]) -> pb.OrchestratorAction: + tags: Optional[dict[str, str]], + parent_trace_context: Optional[pb.TraceContext] = None) -> pb.OrchestratorAction: return pb.OrchestratorAction(id=id, scheduleTask=pb.ScheduleTaskAction( name=name, input=get_string_value(encoded_input), - tags=tags + tags=tags, + parentTraceContext=parent_trace_context, )) @@ -302,12 +304,14 @@ def new_create_sub_orchestration_action( name: str, instance_id: Optional[str], encoded_input: Optional[str], - version: Optional[str]) -> pb.OrchestratorAction: + version: Optional[str], + parent_trace_context: Optional[pb.TraceContext] = None) -> pb.OrchestratorAction: return pb.OrchestratorAction(id=id, createSubOrchestration=pb.CreateSubOrchestrationAction( name=name, instanceId=instance_id, input=get_string_value(encoded_input), - version=get_string_value(version) + version=get_string_value(version), + parentTraceContext=parent_trace_context, )) diff --git a/durabletask/internal/tracing.py b/durabletask/internal/tracing.py new file mode 100644 index 0000000..6e7ea19 --- /dev/null +++ b/durabletask/internal/tracing.py @@ -0,0 +1,573 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""OpenTelemetry distributed tracing utilities for the Durable Task SDK. 
+ +This module provides helpers for propagating W3C Trace Context between +orchestrations, activities, sub-orchestrations, and entities via the +``TraceContext`` protobuf message carried over gRPC. + +OpenTelemetry is an **optional** dependency. When the ``opentelemetry-api`` +package is not installed every helper gracefully degrades to a no-op so +that the rest of the SDK continues to work without any tracing overhead. +""" + +from __future__ import annotations + +import logging +import time +from contextlib import contextmanager +from datetime import datetime +from typing import Any, Optional + +from google.protobuf import timestamp_pb2, wrappers_pb2 + +import durabletask.internal.orchestrator_service_pb2 as pb + +logger = logging.getLogger("durabletask-tracing") + +# --------------------------------------------------------------------------- +# Lazy / optional OpenTelemetry imports +# --------------------------------------------------------------------------- +try: + from opentelemetry import context as otel_context + from opentelemetry import trace + from opentelemetry.trace import ( + SpanKind, # type: ignore[no-redef] + StatusCode, # type: ignore[no-redef] + ) + from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator, + ) + + _OTEL_AVAILABLE = True +except ImportError: # pragma: no cover + _OTEL_AVAILABLE = False + # Provide stub for SpanKind so callers can reference tracing.SpanKind + # without guarding every reference with OTEL_AVAILABLE checks. + + class SpanKind: # type: ignore[no-redef] + INTERNAL: Any = None + CLIENT: Any = None + SERVER: Any = None + PRODUCER: Any = None + CONSUMER: Any = None + + class StatusCode: # type: ignore[no-redef] + OK: Any = None + ERROR: Any = None + UNSET: Any = None + +# Re-export so callers can check without importing opentelemetry themselves. +OTEL_AVAILABLE = _OTEL_AVAILABLE + +# The instrumentation scope name used when creating spans. 
+_TRACER_NAME = "durabletask"
+
+
+# ---------------------------------------------------------------------------
+# Span attribute keys (mirrors Schema.cs from .NET SDK)
+# ---------------------------------------------------------------------------
+
+ATTR_TASK_TYPE = "durabletask.type"
+ATTR_TASK_NAME = "durabletask.task.name"
+ATTR_TASK_VERSION = "durabletask.task.version"
+ATTR_TASK_INSTANCE_ID = "durabletask.task.instance_id"
+ATTR_TASK_EXECUTION_ID = "durabletask.task.execution_id"
+ATTR_TASK_STATUS = "durabletask.task.status"
+ATTR_TASK_TASK_ID = "durabletask.task.task_id"
+ATTR_EVENT_TARGET_INSTANCE_ID = "durabletask.event.target_instance_id"
+ATTR_FIRE_AT = "durabletask.fire_at"
+
+
+# ---------------------------------------------------------------------------
+# Span name helpers (mirrors TraceActivityConstants / TraceHelper naming)
+# ---------------------------------------------------------------------------
+
+def create_span_name(
+    span_type: str, task_name: str, version: Optional[str] = None,
+) -> str:
+    """Build a span name with optional version suffix.
+
+    Examples::
+
+        create_span_name("orchestration", "MyOrch") -> "orchestration:MyOrch"
+        create_span_name("activity", "Say", "1.0") -> "activity:Say@(1.0)"
+    """
+    if version:
+        return f"{span_type}:{task_name}@({version})"
+    return f"{span_type}:{task_name}"
+
+
+def create_timer_span_name(orchestration_name: str) -> str:
+    """Build a timer span name: ``orchestration:<orchestration_name>:timer``."""
+    return f"orchestration:{orchestration_name}:timer"
+
+
+# ---------------------------------------------------------------------------
+# Public helpers – extracting / injecting trace context
+# ---------------------------------------------------------------------------
+
+
+def _trace_context_from_carrier(carrier: dict[str, str]) -> Optional[pb.TraceContext]:
+    """Build a ``TraceContext`` protobuf from a W3C propagation carrier.
+
+    Returns ``None`` when the carrier does not contain a valid
+    ``traceparent`` header.
+ """ + traceparent = carrier.get("traceparent") + if not traceparent: + return None + + tracestate = carrier.get("tracestate") + # Format: 00--- + parts = traceparent.split("-") + span_id = parts[2] if len(parts) >= 4 else "" + + return pb.TraceContext( + traceParent=traceparent, + spanID=span_id, + traceState=wrappers_pb2.StringValue(value=tracestate) + if tracestate else None, + ) + + +def get_current_trace_context() -> Optional[pb.TraceContext]: + """Capture the current OpenTelemetry span context as a protobuf ``TraceContext``. + + Returns ``None`` when OpenTelemetry is not installed or there is no + active span. + """ + if not _OTEL_AVAILABLE: + return None + + propagator = TraceContextTextMapPropagator() + carrier: dict[str, str] = {} + propagator.inject(carrier) + return _trace_context_from_carrier(carrier) + + +def extract_trace_context(proto_ctx: Optional[pb.TraceContext]) -> Optional[Any]: + """Convert a protobuf ``TraceContext`` into an OpenTelemetry ``Context``. + + Returns ``None`` when OpenTelemetry is not installed or the supplied + context is empty / ``None``. + """ + if not _OTEL_AVAILABLE or proto_ctx is None: + return None + + traceparent = proto_ctx.traceParent + if not traceparent: + return None + + carrier: dict[str, str] = {"traceparent": traceparent} + if proto_ctx.HasField("traceState") and proto_ctx.traceState.value: + carrier["tracestate"] = proto_ctx.traceState.value + + propagator = TraceContextTextMapPropagator() + ctx = propagator.extract(carrier) + return ctx + + +@contextmanager +def start_span( + name: str, + trace_context: Optional[pb.TraceContext] = None, + kind: Any = None, + attributes: Optional[dict[str, str]] = None, +): + """Context manager that starts an OpenTelemetry span linked to a parent trace context. + + If OpenTelemetry is not installed, the block executes without tracing. + + Parameters + ---------- + name: + Human-readable span name (e.g. ``"activity:say_hello"``). 
+ trace_context: + The protobuf ``TraceContext`` received from the sidecar. When + provided the new span will be created as a **child** of this + context. + kind: + The ``SpanKind`` for the new span. Defaults to ``SpanKind.INTERNAL``. + attributes: + Optional dictionary of span attributes. + """ + if not _OTEL_AVAILABLE: + yield None + return + + parent_ctx = extract_trace_context(trace_context) + + if kind is None: + kind = SpanKind.INTERNAL + + tracer = trace.get_tracer(_TRACER_NAME) + + if parent_ctx is not None: + token = otel_context.attach(parent_ctx) + try: + with tracer.start_as_current_span( + name, kind=kind, attributes=attributes + ) as span: + yield span + finally: + otel_context.detach(token) + else: + with tracer.start_as_current_span( + name, kind=kind, attributes=attributes + ) as span: + yield span + + +def set_span_error(span: Any, ex: Exception) -> None: + """Record an exception on the given span (if tracing is available).""" + if not _OTEL_AVAILABLE or span is None: + return + span.set_status(StatusCode.ERROR, str(ex)) + span.record_exception(ex) + + +# --------------------------------------------------------------------------- +# Orchestration-level span helpers +# --------------------------------------------------------------------------- + +def start_orchestration_span( + name: str, + instance_id: str, + parent_trace_context: Optional[pb.TraceContext] = None, + orchestration_trace_context: Optional[pb.OrchestrationTraceContext] = None, + version: Optional[str] = None, +) -> tuple[Any, Any, Optional[str], Optional[int]]: + """Start a Server span for an orchestration execution. + + Returns a tuple ``(span, token, span_id, start_time_ns)`` where + *token* is the OTel context token(s) that must be detached later, and + *span_id* / *start_time_ns* are the values to feed back to the sidecar + on the first execution. + + If OpenTelemetry is not available every element of the tuple is ``None``. 
+ """ + if not _OTEL_AVAILABLE: + return None, None, None, None + + span_name = create_span_name("orchestration", name, version) + + attrs: dict[str, str] = { + ATTR_TASK_TYPE: "orchestration", + ATTR_TASK_NAME: name, + ATTR_TASK_INSTANCE_ID: instance_id, + } + if version: + attrs[ATTR_TASK_VERSION] = version + + tracer = trace.get_tracer(_TRACER_NAME) + parent_ctx = extract_trace_context(parent_trace_context) + + # Determine start time: prefer the value persisted in the + # OrchestrationTraceContext (replay / cross-worker), otherwise + # capture "now" so the value can be fed back to the sidecar. + start_time_ns: Optional[int] = None + if orchestration_trace_context is not None and orchestration_trace_context.HasField("spanStartTime"): + start_time_ns = orchestration_trace_context.spanStartTime.ToNanoseconds() + else: + start_time_ns = time.time_ns() + + token = None + if parent_ctx is not None: + token = otel_context.attach(parent_ctx) + + span = tracer.start_span( + span_name, + kind=SpanKind.SERVER, + attributes=attrs, + start_time=start_time_ns, + ) + + # Make this span the current span + ctx_with_span = trace.set_span_in_context(span) + span_token = otel_context.attach(ctx_with_span) + + # Extract the span ID and start time to return to sidecar + span_ctx = span.get_span_context() + span_id_hex = format(span_ctx.span_id, '016x') + + return span, (token, span_token), span_id_hex, start_time_ns + + +def reattach_orchestration_span(span: Any) -> Any: + """Re-attach a saved orchestration span as the current span. + + Returns the context token that must be detached later. + Returns ``None`` when OTel is not available or *span* is ``None``. + """ + if not _OTEL_AVAILABLE or span is None: + return None + + ctx_with_span = trace.set_span_in_context(span) + return otel_context.attach(ctx_with_span) + + +def detach_orchestration_tokens(tokens: Any) -> None: + """Detach context tokens without ending the span. 
+ + Use this on intermediate dispatches where the orchestration is not + yet complete so the span is kept alive for subsequent dispatches. + """ + if tokens is None: + return + parent_token, span_token = tokens + if span_token is not None: + otel_context.detach(span_token) + if parent_token is not None: + otel_context.detach(parent_token) + + +def end_orchestration_span( + span: Any, + tokens: Any, + is_complete: bool, + is_failed: bool, + failure_details: Any = None, +) -> None: + """End the orchestration Server span, setting status and detaching context.""" + if not _OTEL_AVAILABLE or span is None: + return + + if is_complete: + if is_failed: + msg = "" + if failure_details is not None: + msg = ( + str(failure_details.errorMessage) + if hasattr(failure_details, 'errorMessage') + else str(failure_details) + ) + span.set_status(StatusCode.ERROR, msg) + span.set_attribute(ATTR_TASK_STATUS, "Failed") + else: + span.set_attribute(ATTR_TASK_STATUS, "Completed") + + span.end() + + detach_orchestration_tokens(tokens) + + +# --------------------------------------------------------------------------- +# CLIENT span helpers (create / end) +# --------------------------------------------------------------------------- + + +def create_client_span_context( + task_type: str, + name: str, + instance_id: str, + task_id: Optional[int] = None, + version: Optional[str] = None, +) -> Optional[tuple[pb.TraceContext, Any]]: + """Create a CLIENT span and return its trace context for propagation. + + The span is **not** ended here — the caller must keep a reference + and call :func:`end_client_span` when the downstream task completes + so the CLIENT span captures the full scheduling-to-completion duration. + + Returns a ``(TraceContext, span)`` tuple, or ``None`` when + OpenTelemetry is not installed. 
+ """ + if not _OTEL_AVAILABLE: + return None + + span_name = create_span_name(task_type, name, version) + attrs: dict[str, str] = { + ATTR_TASK_TYPE: task_type, + ATTR_TASK_NAME: name, + ATTR_TASK_INSTANCE_ID: instance_id, + } + if task_id is not None: + attrs[ATTR_TASK_TASK_ID] = str(task_id) + if version: + attrs[ATTR_TASK_VERSION] = version + + tracer = trace.get_tracer(_TRACER_NAME) + span = tracer.start_span( + span_name, + kind=SpanKind.CLIENT, + attributes=attrs, + ) + + # Capture the trace context with this CLIENT span as the current span, + # so that the downstream SERVER span is parented by this CLIENT span. + ctx = trace.set_span_in_context(span) + propagator = TraceContextTextMapPropagator() + carrier: dict[str, str] = {} + propagator.inject(carrier, context=ctx) + + trace_ctx = _trace_context_from_carrier(carrier) + if trace_ctx is None: + span.end() + return None + + return trace_ctx, span + + +def end_client_span( + span, + is_error: bool = False, + error_message: Optional[str] = None, +) -> None: + """End a CLIENT span previously created by :func:`create_client_span_context`. + + If *is_error* is ``True`` the span status is set to ERROR before closing. + """ + if span is None or not _OTEL_AVAILABLE: + return + if is_error: + span.set_status(StatusCode.ERROR, error_message or "") + span.end() + + +def emit_timer_span( + orchestration_name: str, + instance_id: str, + timer_id: int, + fire_at: datetime, + scheduled_time_ns: Optional[int] = None, +) -> None: + """Emit an Internal span for a timer (emit-and-close pattern). + + When *scheduled_time_ns* is provided the span start time is backdated + to when the timer was originally created, so the span duration covers + the full wait period. 
+ """ + if not _OTEL_AVAILABLE: + return + + span_name = create_timer_span_name(orchestration_name) + attrs: dict[str, str] = { + ATTR_TASK_TYPE: "timer", + ATTR_TASK_NAME: orchestration_name, + ATTR_TASK_INSTANCE_ID: instance_id, + ATTR_TASK_TASK_ID: str(timer_id), + ATTR_FIRE_AT: fire_at.isoformat(), + } + + tracer = trace.get_tracer(_TRACER_NAME) + span = tracer.start_span( + span_name, + kind=SpanKind.INTERNAL, + attributes=attrs, + start_time=scheduled_time_ns, + ) + span.end() + + +def emit_event_raised_span( + event_name: str, + instance_id: str, + target_instance_id: Optional[str] = None, +) -> None: + """Emit a Producer span for an event raised from the orchestration.""" + if not _OTEL_AVAILABLE: + return + + span_name = create_span_name("orchestration_event", event_name) + attrs: dict[str, str] = { + ATTR_TASK_TYPE: "event", + ATTR_TASK_NAME: event_name, + ATTR_TASK_INSTANCE_ID: instance_id, + } + if target_instance_id: + attrs[ATTR_EVENT_TARGET_INSTANCE_ID] = target_instance_id + + tracer = trace.get_tracer(_TRACER_NAME) + span = tracer.start_span( + span_name, + kind=SpanKind.PRODUCER, + attributes=attrs, + ) + span.end() + + +# --------------------------------------------------------------------------- +# Client-side Producer span helpers +# --------------------------------------------------------------------------- + +@contextmanager +def start_create_orchestration_span( + name: str, + instance_id: str, + version: Optional[str] = None, +): + """Context manager for a Producer span when scheduling a new orchestration. + + Yields the span; caller should capture the trace context after entering + the span context so it can be injected into the gRPC request. 
+ """ + if not _OTEL_AVAILABLE: + yield None + return + + span_name = create_span_name("create_orchestration", name, version) + attrs: dict[str, str] = { + ATTR_TASK_TYPE: "orchestration", + ATTR_TASK_NAME: name, + ATTR_TASK_INSTANCE_ID: instance_id, + } + if version: + attrs[ATTR_TASK_VERSION] = version + + tracer = trace.get_tracer(_TRACER_NAME) + with tracer.start_as_current_span( + span_name, + kind=SpanKind.PRODUCER, + attributes=attrs, + ) as span: + yield span + + +@contextmanager +def start_raise_event_span( + event_name: str, + target_instance_id: str, +): + """Context manager for a Producer span when raising an event from the client.""" + if not _OTEL_AVAILABLE: + yield None + return + + span_name = create_span_name("orchestration_event", event_name) + attrs: dict[str, str] = { + ATTR_TASK_TYPE: "event", + ATTR_TASK_NAME: event_name, + ATTR_EVENT_TARGET_INSTANCE_ID: target_instance_id, + } + + tracer = trace.get_tracer(_TRACER_NAME) + with tracer.start_as_current_span( + span_name, + kind=SpanKind.PRODUCER, + attributes=attrs, + ) as span: + yield span + + +def build_orchestration_trace_context( + span_id: Optional[str], + start_time_ns: Optional[int], +) -> Optional[pb.OrchestrationTraceContext]: + """Build an ``OrchestrationTraceContext`` protobuf to return to the sidecar. + + This preserves the span ID and start time across replays. 
+ """ + if span_id is None: + return None + + ctx = pb.OrchestrationTraceContext() + ctx.spanID.CopyFrom(wrappers_pb2.StringValue(value=span_id)) + + if start_time_ns is not None: + ts = timestamp_pb2.Timestamp() + ts.FromNanoseconds(start_time_ns) + ctx.spanStartTime.CopyFrom(ts) + + return ctx diff --git a/durabletask/worker.py b/durabletask/worker.py index 442165d..e7ab69e 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -32,6 +32,7 @@ import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared +import durabletask.internal.tracing as tracing from durabletask import task from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl @@ -329,6 +330,14 @@ def __init__( self._shutdown = Event() self._is_running = False self._secure_channel = secure_channel + # Persist orchestration spans across replay dispatches so only one + # span is exported per orchestration (on completion). + # Key: instance_id, Value: (span, orchestrationTraceContext) + self._orchestration_spans: dict[str, tuple] = {} + # Persist activity / sub-orchestration CLIENT spans across dispatches + # so the span covers scheduling-to-completion duration. + # Key: instance_id -> (task_id -> span) + self._pending_client_spans: dict[str, dict[int, Any]] = {} # Use provided concurrency options or create default ones self._concurrency_options = ( @@ -627,24 +636,108 @@ def stop(self): self._logger.info("Worker shutdown completed") self._is_running = False + def _end_remaining_client_spans(self, instance_id: str) -> None: + """End and discard any CLIENT spans still pending for *instance_id*. + + Called when the orchestration completes, fails, or is abandoned so + that in-flight CLIENT spans are properly closed and exported. 
+ """ + spans = self._pending_client_spans.pop(instance_id, None) + if spans: + for span in spans.values(): + tracing.end_client_span(span) + def _execute_orchestrator( self, req: pb.OrchestratorRequest, stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub], completionToken, ): + instance_id = req.instanceId + + # Extract parent trace context from executionStarted event + parent_trace_ctx = None + orchestration_name = "" + for e in list(req.pastEvents) + list(req.newEvents): + if e.HasField("executionStarted"): + orchestration_name = e.executionStarted.name + if e.executionStarted.HasField("parentTraceContext"): + parent_trace_ctx = e.executionStarted.parentTraceContext + break + + # Reuse the orchestration span from a previous dispatch if available, + # so a single span covers the entire orchestration lifetime. + saved = self._orchestration_spans.get(instance_id) + if saved is not None: + span, orch_trace_ctx = saved + reattach_token = tracing.reattach_orchestration_span(span) + tokens = (None, reattach_token) + span_id = None # already captured on first dispatch + else: + # First dispatch for this instance — create a new span + span, tokens, span_id, start_time_ns = tracing.start_orchestration_span( + orchestration_name, + instance_id, + parent_trace_context=parent_trace_ctx, + orchestration_trace_context=( + req.orchestrationTraceContext + if req.HasField("orchestrationTraceContext") else None + ), + ) + orch_trace_ctx = tracing.build_orchestration_trace_context( + span_id, start_time_ns, + ) + try: - executor = _OrchestrationExecutor(self._registry, self._logger) - result = executor.execute(req.instanceId, req.pastEvents, req.newEvents) + instance_client_spans = self._pending_client_spans.setdefault( + instance_id, {}) + executor = _OrchestrationExecutor( + self._registry, self._logger, + pending_client_spans=instance_client_spans) + result = executor.execute(instance_id, req.pastEvents, req.newEvents) + + # Determine completion status 
for span + is_complete = False + is_failed = False + failure_details = None + for action in result.actions: + if action.HasField("completeOrchestration"): + is_complete = True + orch_status = action.completeOrchestration.orchestrationStatus + if orch_status == pb.ORCHESTRATION_STATUS_FAILED: + is_failed = True + failure_details = action.completeOrchestration.failureDetails + + if is_complete: + # Orchestration finished — end and export the span + tracing.end_orchestration_span( + span, tokens, True, is_failed, failure_details, + ) + self._orchestration_spans.pop(instance_id, None) + self._end_remaining_client_spans(instance_id) + else: + # Intermediate dispatch — keep the span alive for later, + # but detach context tokens for this call. + if span is not None: + self._orchestration_spans[instance_id] = (span, orch_trace_ctx) + tracing.detach_orchestration_tokens(tokens) + res = pb.OrchestratorResponse( - instanceId=req.instanceId, + instanceId=instance_id, actions=result.actions, customStatus=ph.get_string_value(result.encoded_custom_status), completionToken=completionToken, + orchestrationTraceContext=( + orch_trace_ctx if orch_trace_ctx + else req.orchestrationTraceContext + ), ) except pe.AbandonOrchestrationError: + tracing.end_orchestration_span(span, tokens, False, False) + self._orchestration_spans.pop(instance_id, None) + self._end_remaining_client_spans(instance_id) self._logger.info( - f"Abandoning orchestration. InstanceId = '{req.instanceId}'. Completion token = '{completionToken}'" + f"Abandoning orchestration. InstanceId = '{instance_id}'. 
Completion token = '{completionToken}'" ) stub.AbandonTaskOrchestratorWorkItem( pb.AbandonOrchestrationTaskRequest( @@ -653,8 +746,12 @@ def _execute_orchestrator( ) return except Exception as ex: + tracing.set_span_error(span, ex) + tracing.end_orchestration_span(span, tokens, True, True, ex) + self._orchestration_spans.pop(instance_id, None) + self._end_remaining_client_spans(instance_id) self._logger.exception( - f"An error occurred while trying to execute instance '{req.instanceId}': {ex}" + f"An error occurred while trying to execute instance '{instance_id}': {ex}" ) failure_details = ph.new_failure_details(ex) actions = [ @@ -663,7 +760,7 @@ def _execute_orchestrator( ) ] res = pb.OrchestratorResponse( - instanceId=req.instanceId, + instanceId=instance_id, actions=actions, completionToken=completionToken, ) @@ -697,9 +794,24 @@ def _execute_activity( instance_id = req.orchestrationInstance.instanceId try: executor = _ActivityExecutor(self._registry, self._logger) - result = executor.execute( - instance_id, req.name, req.taskId, req.input.value - ) + with tracing.start_span( + tracing.create_span_name("activity", req.name), + trace_context=req.parentTraceContext, + kind=tracing.SpanKind.SERVER, + attributes={ + tracing.ATTR_TASK_TYPE: "activity", + tracing.ATTR_TASK_INSTANCE_ID: instance_id, + tracing.ATTR_TASK_NAME: req.name, + tracing.ATTR_TASK_TASK_ID: str(req.taskId), + }, + ) as span: + try: + result = executor.execute( + instance_id, req.name, req.taskId, req.input.value + ) + except Exception as ex: + tracing.set_span_error(span, ex) + raise res = pb.ActivityResponse( instanceId=instance_id, taskId=req.taskId, @@ -759,30 +871,45 @@ def _execute_entity_batch( operation_result = None - try: - entity_result = executor.execute( - instance_id, entity_instance_id, operation.operation, entity_state, operation.input.value - ) - - entity_result = ph.get_string_value_or_empty(entity_result) - operation_result = 
pb.OperationResult(success=pb.OperationResultSuccess( - result=entity_result, - startTimeUtc=new_timestamp(start_time), - endTimeUtc=new_timestamp(datetime.now(timezone.utc)) - )) - results.append(operation_result) - - entity_state.commit() - except Exception as ex: - self._logger.exception(ex) - operation_result = pb.OperationResult(failure=pb.OperationResultFailure( - failureDetails=ph.new_failure_details(ex), - startTimeUtc=new_timestamp(start_time), - endTimeUtc=new_timestamp(datetime.now(timezone.utc)) - )) - results.append(operation_result) + # Get the trace context for this operation, if available + op_trace_ctx = operation.traceContext if operation.HasField("traceContext") else None + + with tracing.start_span( + tracing.create_span_name("entity", f"{entity_instance_id.entity}:{operation.operation}"), + trace_context=op_trace_ctx, + kind=tracing.SpanKind.SERVER, + attributes={ + tracing.ATTR_TASK_TYPE: "entity", + tracing.ATTR_TASK_INSTANCE_ID: instance_id, + tracing.ATTR_TASK_NAME: entity_instance_id.entity, + "durabletask.entity.operation": operation.operation, + }, + ) as span: + try: + entity_result = executor.execute( + instance_id, entity_instance_id, operation.operation, entity_state, operation.input.value + ) - entity_state.rollback() + entity_result = ph.get_string_value_or_empty(entity_result) + operation_result = pb.OperationResult(success=pb.OperationResultSuccess( + result=entity_result, + startTimeUtc=new_timestamp(start_time), + endTimeUtc=new_timestamp(datetime.now(timezone.utc)) + )) + results.append(operation_result) + + entity_state.commit() + except Exception as ex: + tracing.set_span_error(span, ex) + self._logger.exception(ex) + operation_result = pb.OperationResult(failure=pb.OperationResultFailure( + failureDetails=ph.new_failure_details(ex), + startTimeUtc=new_timestamp(start_time), + endTimeUtc=new_timestamp(datetime.now(timezone.utc)) + )) + results.append(operation_result) + + entity_state.rollback() batch_result = 
pb.EntityBatchResult( results=results, @@ -847,6 +974,9 @@ def __init__(self, instance_id: str, registry: _Registry): self._new_input: Optional[Any] = None self._save_events = False self._encoded_custom_status: Optional[str] = None + self._parent_trace_context: Optional[pb.TraceContext] = None + # Shared dict for activity/sub-orch CLIENT span lifecycle + self._pending_client_spans: dict[int, Any] = {} def run(self, generator: Generator[task.Task, Any, Any]): self._generator = generator @@ -1136,15 +1266,37 @@ def call_activity_function_helper( if isinstance(activity_function, str) else task.get_name(activity_function) ) - action = ph.new_schedule_task_action(id, name, encoded_input, tags) + # Create a CLIENT span for the activity and propagate its trace + # context so the activity SERVER span is parented by it. + parent_ctx = self._parent_trace_context + if not self._is_replaying: + client_result = tracing.create_client_span_context( + "activity", name, self.instance_id, id) + if client_result: + parent_ctx, client_span = client_result + self._pending_client_spans[id] = client_span + action = ph.new_schedule_task_action( + id, name, encoded_input, tags, + parent_trace_context=parent_ctx) else: if instance_id is None: # Create a deteministic instance ID based on the parent instance ID instance_id = f"{self.instance_id}:{id:04x}" if not isinstance(activity_function, str): raise ValueError("Orchestrator function name must be a string") + # Create a CLIENT span for the sub-orchestration and propagate + # its trace context so the sub-orch SERVER span is parented by it. 
+ parent_ctx = self._parent_trace_context + if not self._is_replaying: + client_result = tracing.create_client_span_context( + "orchestration", activity_function, instance_id, id, + version=version) + if client_result: + parent_ctx, client_span = client_result + self._pending_client_spans[id] = client_span action = ph.new_create_sub_orchestration_action( - id, activity_function, instance_id, encoded_input, version + id, activity_function, instance_id, encoded_input, version, + parent_trace_context=parent_ctx ) self._pending_actions[id] = action @@ -1288,11 +1440,54 @@ def __init__( class _OrchestrationExecutor: _generator: Optional[task.Orchestrator] = None - def __init__(self, registry: _Registry, logger: logging.Logger): + def __init__( + self, + registry: _Registry, + logger: logging.Logger, + pending_client_spans: Optional[dict[int, Any]] = None, + ): self._registry = registry self._logger = logger self._is_suspended = False self._suspended_events: list[pb.HistoryEvent] = [] + # Maps timer_id -> (fire_at, created_time_ns) + self._timer_fire_at: dict[int, tuple[datetime, Optional[int]]] = {} + # Shared dict for CLIENT span lifecycle (from the worker) + self._pending_client_spans: dict[int, Any] = ( + pending_client_spans if pending_client_spans is not None else {}) + # True when the worker provides a persistent span dict. + # Fallback CLIENT spans are only emitted in this mode + # (bare executors in tests don't have prior-dispatch state). + self._has_worker_span_context = pending_client_spans is not None + # Track task_id -> (task_type, name, instance_id) for fallback + # CLIENT span emission when the scheduling worker differs from + # the completion worker (distributed environment). 
+ self._scheduled_task_info: dict[int, tuple[str, str, str]] = {} + + def _emit_fallback_client_span( + self, + task_id: int, + is_error: bool = False, + error_message: Optional[str] = None, + ) -> None: + """Emit an instant CLIENT span for a task whose scheduling dispatch + was handled by a different worker (distributed environment). + + This is a no-op when the executor was created without worker span + context or when no scheduling info is available for *task_id*. + """ + if not self._has_worker_span_context: + return + info = self._scheduled_task_info.get(task_id) + if not info: + return + task_type, task_name, inst_id = info + result = tracing.create_client_span_context( + task_type, task_name, inst_id, task_id) + if result: + _, span = result + tracing.end_client_span(span, is_error=is_error, + error_message=error_message) def execute( self, @@ -1304,6 +1499,7 @@ def execute( orchestration_started_events = [e for e in old_events if e.HasField("executionStarted")] if len(orchestration_started_events) >= 1: orchestration_name = orchestration_started_events[0].executionStarted.name + self._orchestration_name = orchestration_name self._logger.debug( f"{instance_id}: Beginning replay for orchestrator {orchestration_name}..." 
@@ -1315,6 +1511,7 @@ def execute( ) ctx = _RuntimeOrchestrationContext(instance_id, self._registry) + ctx._pending_client_spans = self._pending_client_spans try: # Rebuild local state by replaying old history into the orchestrator function self._logger.debug( @@ -1397,6 +1594,10 @@ def process_event( if event.executionStarted.version: ctx._version = event.executionStarted.version.value + # Store the parent trace context for propagation to child tasks + if event.executionStarted.HasField("parentTraceContext"): + ctx._parent_trace_context = event.executionStarted.parentTraceContext + if self._registry.versioning: version_failure = self.evaluate_orchestration_versioning( self._registry.versioning, @@ -1440,6 +1641,13 @@ def process_event( raise _get_wrong_action_type_error( timer_id, expected_method_name, action ) + # Track timer fire_at and creation timestamp for span emission + if action.createTimer.HasField("fireAt"): + created_ns = (event.timestamp.ToNanoseconds() + if event.HasField("timestamp") else None) + self._timer_fire_at[timer_id] = ( + action.createTimer.fireAt.ToDatetime(), created_ns, + ) elif event.HasField("timerFired"): timer_id = event.timerFired.timerId timer_task = ctx._pending_tasks.pop(timer_id, None) @@ -1450,6 +1658,16 @@ def process_event( f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}." 
) return + # Emit timer span with backdated start time (skip during replay) + if not ctx.is_replaying: + timer_info = self._timer_fire_at.get(timer_id) + if timer_info is not None: + fire_at, created_ns = timer_info + tracing.emit_timer_span( + self._orchestration_name, ctx.instance_id, + timer_id, fire_at, + scheduled_time_ns=created_ns, + ) timer_task.complete(None) if timer_task._retryable_parent is not None: activity_action = timer_task._retryable_parent._action @@ -1493,9 +1711,19 @@ def process_event( expected_task_name=event.taskScheduled.name, actual_task_name=action.scheduleTask.name, ) + # Track task name for fallback CLIENT span in distributed case + self._scheduled_task_info[task_id] = ( + "activity", event.taskScheduled.name, ctx.instance_id) elif event.HasField("taskCompleted"): # This history event contains the result of a completed activity task. task_id = event.taskCompleted.taskScheduledId + # End the CLIENT span for this activity (spans the full duration) + if not ctx.is_replaying: + client_span = self._pending_client_spans.pop(task_id, None) + if client_span is not None: + tracing.end_client_span(client_span) + else: + self._emit_fallback_client_span(task_id) activity_task = ctx._pending_tasks.pop(task_id, None) if not activity_task: # TODO: Should this be an error? When would it ever happen? 
@@ -1511,6 +1739,20 @@ def process_event( ctx.resume() elif event.HasField("taskFailed"): task_id = event.taskFailed.taskScheduledId + # End the CLIENT span with error status + if not ctx.is_replaying: + client_span = self._pending_client_spans.pop(task_id, None) + err_msg = ( + event.taskFailed.failureDetails.errorMessage + if event.taskFailed.HasField("failureDetails") + else None + ) + if client_span is not None: + tracing.end_client_span( + client_span, is_error=True, error_message=err_msg) + else: + self._emit_fallback_client_span( + task_id, is_error=True, error_message=err_msg) activity_task = ctx._pending_tasks.pop(task_id, None) if not activity_task: # TODO: Should this be an error? When would it ever happen? @@ -1563,8 +1805,21 @@ def process_event( expected_task_name=event.subOrchestrationInstanceCreated.name, actual_task_name=action.createSubOrchestration.name, ) + # Track task name for fallback CLIENT span in distributed case + self._scheduled_task_info[task_id] = ( + "orchestration", + event.subOrchestrationInstanceCreated.name, + event.subOrchestrationInstanceCreated.instanceId, + ) elif event.HasField("subOrchestrationInstanceCompleted"): task_id = event.subOrchestrationInstanceCompleted.taskScheduledId + # End the CLIENT span for this sub-orchestration + if not ctx.is_replaying: + client_span = self._pending_client_spans.pop(task_id, None) + if client_span is not None: + tracing.end_client_span(client_span) + else: + self._emit_fallback_client_span(task_id) sub_orch_task = ctx._pending_tasks.pop(task_id, None) if not sub_orch_task: # TODO: Should this be an error? When would it ever happen? 
@@ -1583,6 +1838,20 @@ def process_event( elif event.HasField("subOrchestrationInstanceFailed"): failedEvent = event.subOrchestrationInstanceFailed task_id = failedEvent.taskScheduledId + # End the CLIENT span with error status + if not ctx.is_replaying: + client_span = self._pending_client_spans.pop(task_id, None) + err_msg = ( + failedEvent.failureDetails.errorMessage + if failedEvent.HasField("failureDetails") + else None + ) + if client_span is not None: + tracing.end_client_span( + client_span, is_error=True, error_message=err_msg) + else: + self._emit_fallback_client_span( + task_id, is_error=True, error_message=err_msg) sub_orch_task = ctx._pending_tasks.pop(task_id, None) if not sub_orch_task: # TODO: Should this be an error? When would it ever happen? diff --git a/pyproject.toml b/pyproject.toml index be5d8dd..a28c73c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,12 @@ dependencies = [ "packaging" ] +[project.optional-dependencies] +opentelemetry = [ + "opentelemetry-api>=1.0.0", + "opentelemetry-sdk>=1.0.0" +] + [project.urls] repository = "https://github.com/microsoft/durabletask-python" changelog = "https://github.com/microsoft/durabletask-python/blob/main/CHANGELOG.md" diff --git a/requirements.txt b/requirements.txt index f32d350..85ba9a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,6 @@ pytest pytest-cov azure-identity asyncio -packaging \ No newline at end of file +packaging +opentelemetry-api +opentelemetry-sdk \ No newline at end of file diff --git a/tests/durabletask/test_tracing.py b/tests/durabletask/test_tracing.py new file mode 100644 index 0000000..d17d240 --- /dev/null +++ b/tests/durabletask/test_tracing.py @@ -0,0 +1,1399 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Tests for distributed tracing utilities and integration.""" + +import json +import logging +from datetime import datetime, timezone +from typing import Any +from unittest.mock import patch + +import pytest +from google.protobuf import wrappers_pb2 + +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import StatusCode + +import durabletask.internal.helpers as helpers +import durabletask.internal.orchestrator_service_pb2 as pb +import durabletask.internal.tracing as tracing +from durabletask import task, worker + +logging.basicConfig( + format='%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.DEBUG) +TEST_LOGGER = logging.getLogger("tests") +TEST_INSTANCE_ID = 'abc123' + + +# Module-level setup: create a TracerProvider with an InMemorySpanExporter once. +# Newer OpenTelemetry versions only allow set_tracer_provider to be called once. 
+_EXPORTER = InMemorySpanExporter() +_PROVIDER = TracerProvider() +_PROVIDER.add_span_processor(SimpleSpanProcessor(_EXPORTER)) +trace.set_tracer_provider(_PROVIDER) + + +@pytest.fixture(autouse=True) +def otel_setup(): + """Clear the in-memory exporter before each test.""" + _EXPORTER.clear() + yield _EXPORTER + + +# --------------------------------------------------------------------------- +# Tests for tracing utility functions +# --------------------------------------------------------------------------- + + +class TestGetCurrentTraceContext: + """Tests for tracing.get_current_trace_context().""" + + def test_returns_none_when_no_active_span(self, otel_setup): + """When there is no active span, should return None.""" + result = tracing.get_current_trace_context() + assert result is None + + def test_returns_trace_context_with_active_span(self, otel_setup): + """When there is an active span, should return a populated TraceContext.""" + tracer = trace.get_tracer("test") + with tracer.start_as_current_span("test-span"): + result = tracing.get_current_trace_context() + + assert result is not None + assert isinstance(result, pb.TraceContext) + assert result.traceParent != "" + assert result.spanID != "" + # traceparent format: 00--- + parts = result.traceParent.split("-") + assert len(parts) == 4 + assert parts[0] == "00" + assert len(parts[1]) == 32 # trace ID + assert len(parts[2]) == 16 # span ID + assert result.spanID == parts[2] + + +class TestExtractTraceContext: + """Tests for tracing.extract_trace_context().""" + + def test_returns_none_for_none_input(self): + result = tracing.extract_trace_context(None) + assert result is None + + def test_returns_none_for_empty_traceparent(self): + proto_ctx = pb.TraceContext(traceParent="", spanID="") + result = tracing.extract_trace_context(proto_ctx) + assert result is None + + def test_extracts_valid_context(self, otel_setup): + """Should extract a valid OTel context from a protobuf TraceContext.""" + traceparent = 
"00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + proto_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + otel_ctx = tracing.extract_trace_context(proto_ctx) + assert otel_ctx is not None + + def test_extracts_context_with_tracestate(self, otel_setup): + """Should extract context including tracestate.""" + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + tracestate_val = "congo=t61rcWkgMzE" + proto_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + traceState=wrappers_pb2.StringValue(value=tracestate_val), + ) + otel_ctx = tracing.extract_trace_context(proto_ctx) + assert otel_ctx is not None + + +class TestStartSpan: + """Tests for tracing.start_span().""" + + def test_creates_span_without_parent(self, otel_setup: InMemorySpanExporter): + """Should create a span even without a parent trace context.""" + with tracing.start_span("test-span") as span: + assert span is not None + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].name == "test-span" + + def test_creates_span_with_attributes(self, otel_setup: InMemorySpanExporter): + """Should create a span with custom attributes.""" + attrs = {"key1": "value1", "key2": "value2"} + with tracing.start_span("test-span", attributes=attrs) as span: + assert span is not None + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].attributes is not None + assert spans[0].attributes["key1"] == "value1" + assert spans[0].attributes["key2"] == "value2" + + def test_creates_child_span_from_trace_context(self, otel_setup: InMemorySpanExporter): + """Should create a child span linked to the parent trace context.""" + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + proto_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + with tracing.start_span("child-span", trace_context=proto_ctx) as span: + assert span is 
not None + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + child_span = spans[0] + assert child_span.name == "child-span" + # The child span's trace ID should match the parent's + assert child_span.context is not None + assert child_span.context.trace_id == int("0af7651916cd43dd8448eb211c80319c", 16) + + +class TestSetSpanError: + """Tests for tracing.set_span_error().""" + + def test_records_error_on_span(self, otel_setup: InMemorySpanExporter): + """Should record error status and exception on the span.""" + with tracing.start_span("error-span") as span: + ex = ValueError("something went wrong") + tracing.set_span_error(span, ex) + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == StatusCode.ERROR + assert spans[0].status.description is not None + assert "something went wrong" in spans[0].status.description + + def test_noop_with_none_span(self): + """Should not raise when span is None.""" + tracing.set_span_error(None, ValueError("test")) + + +# --------------------------------------------------------------------------- +# Tests for client-side trace context injection +# --------------------------------------------------------------------------- + + +class TestClientTraceContextInjection: + """Tests that the client methods inject trace context.""" + + def test_schedule_new_orchestration_includes_trace_context(self, otel_setup): + """schedule_new_orchestration should set parentTraceContext from current span.""" + tracer = trace.get_tracer("test") + with tracer.start_as_current_span("client-span"): + ctx = tracing.get_current_trace_context() + + assert ctx is not None + assert ctx.traceParent != "" + assert ctx.spanID != "" + + +# --------------------------------------------------------------------------- +# Tests for activity execution with tracing +# --------------------------------------------------------------------------- + + +class TestActivityExecutionTracing: + """Tests that 
activity execution creates spans from parent trace context.""" + + def test_activity_executes_within_span(self, otel_setup: InMemorySpanExporter): + """Activity execution should create a span when parentTraceContext is provided.""" + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + parent_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + + def test_activity(ctx: task.ActivityContext, input: Any): + return "hello" + + registry = worker._Registry() + name = registry.add_activity(test_activity) + executor = worker._ActivityExecutor(registry, TEST_LOGGER) + + with tracing.start_span( + f"activity:{name}", + trace_context=parent_ctx, + attributes={"durabletask.task.instance_id": TEST_INSTANCE_ID, + "durabletask.task.name": name, + "durabletask.task.task_id": "42"}, + ): + result = executor.execute(TEST_INSTANCE_ID, name, 42, None) + + assert result == json.dumps("hello") + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].name == f"activity:{name}" + assert spans[0].attributes is not None + assert spans[0].attributes["durabletask.task.instance_id"] == TEST_INSTANCE_ID + + def test_activity_error_sets_span_error(self, otel_setup: InMemorySpanExporter): + """Activity execution errors should be recorded on the span.""" + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + parent_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + + def failing_activity(ctx: task.ActivityContext, input: Any): + raise ValueError("Activity failed!") + + registry = worker._Registry() + name = registry.add_activity(failing_activity) + executor = worker._ActivityExecutor(registry, TEST_LOGGER) + + with pytest.raises(ValueError, match="Activity failed!"): + with tracing.start_span( + f"activity:{name}", + trace_context=parent_ctx, + ) as span: + try: + executor.execute(TEST_INSTANCE_ID, name, 42, None) + except Exception as ex: + 
tracing.set_span_error(span, ex) + raise + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == StatusCode.ERROR + + +# --------------------------------------------------------------------------- +# Tests for orchestration trace context propagation +# --------------------------------------------------------------------------- + + +class TestOrchestrationTraceContextPropagation: + """Tests that orchestration actions include trace context.""" + + def test_schedule_task_action_includes_trace_context(self): + """new_schedule_task_action should include parentTraceContext when provided.""" + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + parent_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + action = helpers.new_schedule_task_action( + 0, "my_activity", None, None, + parent_trace_context=parent_ctx + ) + assert action.scheduleTask.parentTraceContext.traceParent == traceparent + + def test_schedule_task_action_without_trace_context(self): + """new_schedule_task_action should work without trace context.""" + action = helpers.new_schedule_task_action(0, "my_activity", None, None) + # parentTraceContext should not be set (default empty) + assert action.scheduleTask.parentTraceContext.traceParent == "" + + def test_create_sub_orchestration_action_includes_trace_context(self): + """new_create_sub_orchestration_action should include parentTraceContext.""" + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + parent_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + action = helpers.new_create_sub_orchestration_action( + 0, "sub_orch", "inst1", None, None, + parent_trace_context=parent_ctx + ) + assert action.createSubOrchestration.parentTraceContext.traceParent == traceparent + + def test_create_sub_orchestration_action_without_trace_context(self): + """new_create_sub_orchestration_action should work 
without trace context.""" + action = helpers.new_create_sub_orchestration_action( + 0, "sub_orch", "inst1", None, None + ) + assert action.createSubOrchestration.parentTraceContext.traceParent == "" + + +class TestOrchestrationExecutorStoresTraceContext: + """Tests that the orchestration executor extracts and stores trace context from events.""" + + def test_execution_started_stores_parent_trace_context(self): + """process_event should store parentTraceContext from executionStarted.""" + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + parent_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + + def simple_orchestrator(ctx: task.OrchestrationContext, _): + return "done" + + registry = worker._Registry() + registry.add_orchestrator(simple_orchestrator) + + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry) + assert ctx._parent_trace_context is None + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + + # Create an executionStarted event with parentTraceContext + event = pb.HistoryEvent( + eventId=-1, + executionStarted=pb.ExecutionStartedEvent( + name="simple_orchestrator", + orchestrationInstance=pb.OrchestrationInstance(instanceId=TEST_INSTANCE_ID), + parentTraceContext=parent_ctx, + ) + ) + + executor.process_event(ctx, event) + assert ctx._parent_trace_context is not None + assert ctx._parent_trace_context.traceParent == traceparent + + def test_execution_started_without_trace_context(self): + """process_event should leave parentTraceContext as None when not provided.""" + def simple_orchestrator(ctx: task.OrchestrationContext, _): + return "done" + + registry = worker._Registry() + registry.add_orchestrator(simple_orchestrator) + + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + + event = pb.HistoryEvent( + eventId=-1, + executionStarted=pb.ExecutionStartedEvent( + 
name="simple_orchestrator", + orchestrationInstance=pb.OrchestrationInstance(instanceId=TEST_INSTANCE_ID), + ) + ) + + executor.process_event(ctx, event) + assert ctx._parent_trace_context is None + + +class TestOtelNotAvailable: + """Tests that tracing functions gracefully degrade when OTel is unavailable.""" + + def test_get_current_trace_context_without_otel(self): + """get_current_trace_context returns None when OTel is not available.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + result = tracing.get_current_trace_context() + assert result is None + + def test_extract_trace_context_without_otel(self): + """extract_trace_context returns None when OTel is not available.""" + proto_ctx = pb.TraceContext(traceParent="00-abc-def-01", spanID="def") + with patch.object(tracing, '_OTEL_AVAILABLE', False): + result = tracing.extract_trace_context(proto_ctx) + assert result is None + + def test_start_span_without_otel(self): + """start_span should yield None when OTel is not available.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + with tracing.start_span("test") as span: + assert span is None + + def test_set_span_error_without_otel(self): + """set_span_error should be a no-op when OTel is not available.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + tracing.set_span_error(None, ValueError("test")) # should not raise + + def test_start_orchestration_span_without_otel(self): + """start_orchestration_span returns all-None tuple when OTel unavailable.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + span, tokens, span_id, start_time = tracing.start_orchestration_span( + "test_orch", "inst1", + ) + assert span is None + assert tokens is None + assert span_id is None + assert start_time is None + + def test_end_orchestration_span_without_otel(self): + """end_orchestration_span is a no-op when OTel is unavailable.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + tracing.end_orchestration_span(None, None, True, 
False) + + def test_emit_timer_span_without_otel(self): + """emit_timer_span is a no-op when OTel is unavailable.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + tracing.emit_timer_span("orch", "inst1", 1, datetime.now(timezone.utc)) + + def test_start_create_orchestration_span_without_otel(self): + """start_create_orchestration_span yields None when OTel unavailable.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + with tracing.start_create_orchestration_span("orch", "inst1") as span: + assert span is None + + def test_start_raise_event_span_without_otel(self): + """start_raise_event_span yields None when OTel unavailable.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + with tracing.start_raise_event_span("evt", "inst1") as span: + assert span is None + + def test_create_client_span_context_without_otel(self): + """create_client_span_context returns None when OTel is unavailable.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + result = tracing.create_client_span_context("activity", "Act", "inst1") + assert result is None + + def test_end_client_span_without_otel(self): + """end_client_span is a no-op when OTel is unavailable.""" + with patch.object(tracing, '_OTEL_AVAILABLE', False): + tracing.end_client_span(None) # should not raise + + +# --------------------------------------------------------------------------- +# Tests for span naming helpers +# --------------------------------------------------------------------------- + + +class TestSpanNaming: + """Tests for create_span_name and create_timer_span_name.""" + + def test_create_span_name_without_version(self): + assert tracing.create_span_name("orchestration", "MyOrch") == "orchestration:MyOrch" + + def test_create_span_name_with_version(self): + assert tracing.create_span_name("activity", "Say", "1.0") == "activity:Say@(1.0)" + + def test_create_timer_span_name(self): + assert tracing.create_timer_span_name("MyOrch") == "orchestration:MyOrch:timer" + + +# 
--------------------------------------------------------------------------- +# Tests for schema attribute constants +# --------------------------------------------------------------------------- + + +class TestSchemaConstants: + """Tests that schema constants match expected names.""" + + def test_attribute_keys_defined(self): + assert tracing.ATTR_TASK_TYPE == "durabletask.type" + assert tracing.ATTR_TASK_NAME == "durabletask.task.name" + assert tracing.ATTR_TASK_VERSION == "durabletask.task.version" + assert tracing.ATTR_TASK_INSTANCE_ID == "durabletask.task.instance_id" + assert tracing.ATTR_TASK_EXECUTION_ID == "durabletask.task.execution_id" + assert tracing.ATTR_TASK_STATUS == "durabletask.task.status" + assert tracing.ATTR_TASK_TASK_ID == "durabletask.task.task_id" + assert tracing.ATTR_EVENT_TARGET_INSTANCE_ID == "durabletask.event.target_instance_id" + assert tracing.ATTR_FIRE_AT == "durabletask.fire_at" + + +# --------------------------------------------------------------------------- +# Tests for Producer / Client / Server span creation +# --------------------------------------------------------------------------- + + +class TestCreateOrchestrationSpan: + """Tests for start_create_orchestration_span (Producer span).""" + + def test_creates_producer_span(self, otel_setup: InMemorySpanExporter): + """Should create a Producer span for create_orchestration.""" + with tracing.start_create_orchestration_span("MyOrch", "inst-123") as span: + assert span is not None + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + s = spans[0] + assert s.name == "create_orchestration:MyOrch" + assert s.kind == trace.SpanKind.PRODUCER + assert s.attributes is not None + assert s.attributes[tracing.ATTR_TASK_TYPE] == "orchestration" + assert s.attributes[tracing.ATTR_TASK_NAME] == "MyOrch" + assert s.attributes[tracing.ATTR_TASK_INSTANCE_ID] == "inst-123" + + def test_creates_producer_span_with_version(self, otel_setup: InMemorySpanExporter): + with 
tracing.start_create_orchestration_span("MyOrch", "inst-123", version="2.0"): + pass + + spans = otel_setup.get_finished_spans() + assert spans[0].name == "create_orchestration:MyOrch@(2.0)" + assert spans[0].attributes is not None + assert spans[0].attributes[tracing.ATTR_TASK_VERSION] == "2.0" + + def test_trace_context_injected_inside_producer_span(self, otel_setup: InMemorySpanExporter): + """Inside the producer span, get_current_trace_context should capture producer span ctx.""" + with tracing.start_create_orchestration_span("Orch", "inst"): + ctx = tracing.get_current_trace_context() + assert ctx is not None + assert ctx.traceParent != "" + + +class TestRaiseEventSpan: + """Tests for start_raise_event_span (Producer span).""" + + def test_creates_producer_span(self, otel_setup: InMemorySpanExporter): + with tracing.start_raise_event_span("MyEvent", "inst-456") as span: + assert span is not None + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + s = spans[0] + assert s.name == "orchestration_event:MyEvent" + assert s.kind == trace.SpanKind.PRODUCER + assert s.attributes is not None + assert s.attributes[tracing.ATTR_TASK_TYPE] == "event" + assert s.attributes[tracing.ATTR_TASK_NAME] == "MyEvent" + assert s.attributes[tracing.ATTR_EVENT_TARGET_INSTANCE_ID] == "inst-456" + + +class TestOrchestrationServerSpan: + """Tests for start_orchestration_span and end_orchestration_span.""" + + def test_creates_server_span(self, otel_setup: InMemorySpanExporter): + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + parent_ctx = pb.TraceContext( + traceParent=traceparent, + spanID="b7ad6b7169203331", + ) + span, tokens, span_id, start_time_ns = tracing.start_orchestration_span( + "MyOrch", "inst-100", parent_trace_context=parent_ctx, + ) + assert span is not None + assert span_id is not None + assert len(span_id) == 16 + + tracing.end_orchestration_span(span, tokens, True, False) + + spans = otel_setup.get_finished_spans() + assert 
len(spans) == 1 + s = spans[0] + assert s.name == "orchestration:MyOrch" + assert s.kind == trace.SpanKind.SERVER + assert s.attributes is not None + assert s.attributes[tracing.ATTR_TASK_TYPE] == "orchestration" + assert s.attributes[tracing.ATTR_TASK_NAME] == "MyOrch" + assert s.attributes[tracing.ATTR_TASK_INSTANCE_ID] == "inst-100" + assert s.attributes[tracing.ATTR_TASK_STATUS] == "Completed" + + def test_start_time_always_captured(self, otel_setup: InMemorySpanExporter): + """On first execution (no orchestration_trace_context), start_time_ns + should still be non-None so it can be persisted for cross-worker replay.""" + span, tokens, span_id, start_time_ns = tracing.start_orchestration_span( + "MyOrch", "inst-first", + ) + assert start_time_ns is not None + assert start_time_ns > 0 + tracing.end_orchestration_span(span, tokens, True, False) + + def test_server_span_failure(self, otel_setup: InMemorySpanExporter): + span, tokens, span_id, _ = tracing.start_orchestration_span( + "FailOrch", "inst-200", + ) + tracing.end_orchestration_span(span, tokens, True, True, "boom") + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == StatusCode.ERROR + assert spans[0].attributes is not None + assert spans[0].attributes[tracing.ATTR_TASK_STATUS] == "Failed" + + def test_server_span_not_complete(self, otel_setup: InMemorySpanExporter): + """Span without completion should not set status attribute.""" + span, tokens, _, _ = tracing.start_orchestration_span("PendingOrch", "inst-300") + tracing.end_orchestration_span(span, tokens, False, False) + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].attributes is not None + assert tracing.ATTR_TASK_STATUS not in spans[0].attributes + + +class TestCreateClientSpanContext: + """Tests for create_client_span_context.""" + + def test_creates_client_span_with_trace_context(self, otel_setup: InMemorySpanExporter): + """Should return a 
(TraceContext, span) tuple with correct attributes.""" + result = tracing.create_client_span_context( + "activity", "SayHello", "inst-1", task_id=42) + assert result is not None + trace_ctx, span = result + + assert trace_ctx.traceParent != "" + assert trace_ctx.spanID != "" + # Span should NOT be finished yet + assert len(otel_setup.get_finished_spans()) == 0 + + # End it and verify attributes + span.end() + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + s = spans[0] + assert s.kind == trace.SpanKind.CLIENT + assert s.name == "activity:SayHello" + assert s.attributes is not None + assert s.attributes[tracing.ATTR_TASK_TYPE] == "activity" + assert s.attributes[tracing.ATTR_TASK_NAME] == "SayHello" + assert s.attributes[tracing.ATTR_TASK_INSTANCE_ID] == "inst-1" + assert s.attributes[tracing.ATTR_TASK_TASK_ID] == "42" + + def test_includes_version_attribute(self, otel_setup: InMemorySpanExporter): + result = tracing.create_client_span_context( + "activity", "Act", "inst-1", version="2.0") + assert result is not None + _, span = result + span.end() + spans = otel_setup.get_finished_spans() + assert spans[0].name == "activity:Act@(2.0)" + assert spans[0].attributes is not None + assert spans[0].attributes[tracing.ATTR_TASK_VERSION] == "2.0" + + def test_trace_context_span_id_matches_span(self, otel_setup: InMemorySpanExporter): + """The TraceContext spanID should match the CLIENT span's span ID.""" + result = tracing.create_client_span_context( + "orchestration", "SubOrch", "inst-1") + assert result is not None + trace_ctx, span = result + span.end() + spans = otel_setup.get_finished_spans() + span_ctx = spans[0].get_span_context() + assert span_ctx is not None + client_span_id = format(span_ctx.span_id, '016x') + assert trace_ctx.spanID == client_span_id + + +class TestEndClientSpan: + """Tests for end_client_span.""" + + def test_ends_span(self, otel_setup: InMemorySpanExporter): + """end_client_span should close the span and export it.""" + 
result = tracing.create_client_span_context( + "activity", "Act", "inst-1") + assert result is not None + _, span = result + tracing.end_client_span(span) + assert len(otel_setup.get_finished_spans()) == 1 + + def test_ends_span_with_error(self, otel_setup: InMemorySpanExporter): + result = tracing.create_client_span_context( + "activity", "Act", "inst-1") + assert result is not None + _, span = result + tracing.end_client_span(span, is_error=True, error_message="boom") + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == StatusCode.ERROR + assert spans[0].status.description is not None + assert "boom" in spans[0].status.description + + def test_noop_with_none_span(self): + """Should not raise when span is None.""" + tracing.end_client_span(None) # no-op + + +class TestEmitTimerSpan: + """Tests for emit_timer_span.""" + + def test_emits_internal_span(self, otel_setup: InMemorySpanExporter): + fire_at = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + tracing.emit_timer_span("MyOrch", "inst-1", 5, fire_at) + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + s = spans[0] + assert s.name == "orchestration:MyOrch:timer" + assert s.kind == trace.SpanKind.INTERNAL + assert s.attributes is not None + assert s.attributes[tracing.ATTR_TASK_TYPE] == "timer" + assert s.attributes[tracing.ATTR_FIRE_AT] == fire_at.isoformat() + assert s.attributes[tracing.ATTR_TASK_TASK_ID] == "5" + + def test_backdated_start_time(self, otel_setup: InMemorySpanExporter): + """Timer span should cover the full wait period when scheduled_time_ns is set.""" + fire_at = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + created_ns = 1704067200_000_000_000 # 2024-01-01T00:00:00Z + tracing.emit_timer_span( + "MyOrch", "inst-1", 5, fire_at, scheduled_time_ns=created_ns, + ) + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].start_time == created_ns + assert spans[0].end_time is not None + 
assert spans[0].start_time is not None + assert spans[0].end_time > spans[0].start_time + + +class TestEmitEventRaisedSpan: + """Tests for emit_event_raised_span.""" + + def test_emits_producer_span(self, otel_setup: InMemorySpanExporter): + tracing.emit_event_raised_span("approval", "inst-1", target_instance_id="inst-2") + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + s = spans[0] + assert s.name == "orchestration_event:approval" + assert s.kind == trace.SpanKind.PRODUCER + assert s.attributes is not None + assert s.attributes[tracing.ATTR_TASK_TYPE] == "event" + assert s.attributes[tracing.ATTR_EVENT_TARGET_INSTANCE_ID] == "inst-2" + + def test_emits_span_without_target(self, otel_setup: InMemorySpanExporter): + tracing.emit_event_raised_span("approval", "inst-1") + + spans = otel_setup.get_finished_spans() + assert len(spans) == 1 + assert spans[0].attributes is not None + assert tracing.ATTR_EVENT_TARGET_INSTANCE_ID not in spans[0].attributes + + +# --------------------------------------------------------------------------- +# Tests for build_orchestration_trace_context +# --------------------------------------------------------------------------- + + +class TestBuildOrchestrationTraceContext: + """Tests for build_orchestration_trace_context.""" + + def test_returns_none_when_span_id_none(self): + result = tracing.build_orchestration_trace_context(None, None) + assert result is None + + def test_builds_context_with_span_id(self): + result = tracing.build_orchestration_trace_context("abc123def456", None) + assert result is not None + assert result.spanID.value == "abc123def456" + + def test_builds_context_with_start_time(self): + start_time_ns = 1704067200000000000 # 2024-01-01T00:00:00Z + result = tracing.build_orchestration_trace_context("abc123", start_time_ns) + assert result is not None + assert result.spanStartTime.seconds == 1704067200 + assert result.spanStartTime.nanos == 0 + + +class TestReplayDoesNotEmitSpans: + """Tests that 
replayed (old) events do NOT re-emit client spans for + activities, sub-orchestrations, or timers. Client spans for activities + and sub-orchestrations are now emitted at action-creation time (inside + call_activity / call_sub_orchestrator). During a replay dispatch all + generator calls happen inside old_events processing (is_replaying=True), + so no CLIENT spans are produced — they were already emitted in prior + dispatches.""" + + def _get_client_spans(self, exporter): + """Return non-Server spans (Client/Internal schedule/timer spans).""" + return [ + s for s in exporter.get_finished_spans() + if s.kind != trace.SpanKind.SERVER + ] + + def test_replayed_activity_completion_no_span(self, otel_setup): + """During a replay dispatch, no CLIENT spans are emitted for + activities — both old and new completions. The CLIENT span for + activity 2 was emitted in a prior dispatch when call_activity() + was first called with is_replaying=False.""" + def dummy_activity(ctx, _): + pass + + def orchestrator(ctx: task.OrchestrationContext, _): + r1 = yield ctx.call_activity(dummy_activity, input=1) + r2 = yield ctx.call_activity(dummy_activity, input=2) + return [r1, r2] + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + registry.add_activity(dummy_activity) + + # First activity scheduled + completed in old_events (replay) + old_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)), + helpers.new_task_completed_event(1, json.dumps(10)), + ] + # Second activity scheduled in replay, completed as new event + new_events = [ + helpers.new_task_scheduled_event(2, task.get_name(dummy_activity)), + helpers.new_task_completed_event(2, json.dumps(20)), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor.execute(TEST_INSTANCE_ID, old_events, new_events) + + client_spans = 
self._get_client_spans(otel_setup) + # No CLIENT spans during replay — they were emitted in prior dispatches + assert len(client_spans) == 0 + + def test_replayed_activity_failure_no_span(self, otel_setup): + """During a replay dispatch, no CLIENT spans are emitted for + failed activities.""" + def failing_activity(ctx, _): + raise ValueError("boom") + + def orchestrator(ctx: task.OrchestrationContext, _): + try: + yield ctx.call_activity(failing_activity, input=1) + except task.TaskFailedError: + pass + result = yield ctx.call_activity(failing_activity, input=2) + return result + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + registry.add_activity(failing_activity) + + ex = Exception("boom") + old_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, task.get_name(failing_activity)), + helpers.new_task_failed_event(1, ex), + ] + new_events = [ + helpers.new_task_scheduled_event(2, task.get_name(failing_activity)), + helpers.new_task_completed_event(2, json.dumps("ok")), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor.execute(TEST_INSTANCE_ID, old_events, new_events) + + client_spans = self._get_client_spans(otel_setup) + # No CLIENT spans during replay + assert len(client_spans) == 0 + + def test_replayed_timer_no_span(self, otel_setup): + """A timer that fired during replay should not emit a timer span.""" + from datetime import timedelta + + def orchestrator(ctx: task.OrchestrationContext, _): + t1 = ctx.current_utc_datetime + timedelta(seconds=1) + yield ctx.create_timer(t1) + t2 = ctx.current_utc_datetime + timedelta(seconds=2) + yield ctx.create_timer(t2) + return "done" + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + start_time = datetime(2020, 1, 1, 12, 0, 0) + fire_at_1 = start_time + timedelta(seconds=1) + fire_at_2 = 
start_time + timedelta(seconds=2) + + # First timer created, fired, and second timer created all in old events + old_events = [ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_timer_created_event(1, fire_at_1), + helpers.new_timer_fired_event(1, fire_at_1), + helpers.new_timer_created_event(2, fire_at_2), + ] + # Only the second timer firing is a new event + new_events = [ + helpers.new_timer_fired_event(2, fire_at_2), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor.execute(TEST_INSTANCE_ID, old_events, new_events) + + client_spans = self._get_client_spans(otel_setup) + # Only the second timer (new event) should produce a span + assert len(client_spans) == 1 + assert "timer" in client_spans[0].name.lower() + + def test_replayed_sub_orchestration_completion_no_span(self, otel_setup): + """During a replay dispatch, no CLIENT spans are emitted for + sub-orchestrations.""" + def sub_orch(ctx: task.OrchestrationContext, _): + return "sub_done" + + def orchestrator(ctx: task.OrchestrationContext, _): + r1 = yield ctx.call_sub_orchestrator(sub_orch) + r2 = yield ctx.call_sub_orchestrator(sub_orch) + return [r1, r2] + + registry = worker._Registry() + sub_name = registry.add_orchestrator(sub_orch) + orch_name = registry.add_orchestrator(orchestrator) + + old_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(orch_name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_sub_orchestration_created_event(1, sub_name, "sub-1", encoded_input=None), + helpers.new_sub_orchestration_completed_event(1, encoded_output=json.dumps("r1")), + helpers.new_sub_orchestration_created_event(2, sub_name, "sub-2", encoded_input=None), + ] + new_events = [ + helpers.new_sub_orchestration_completed_event(2, encoded_output=json.dumps("r2")), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + 
executor.execute(TEST_INSTANCE_ID, old_events, new_events) + + client_spans = self._get_client_spans(otel_setup) + # No CLIENT spans during replay + assert len(client_spans) == 0 + + def test_replayed_sub_orchestration_failure_no_span(self, otel_setup): + """During a replay dispatch, no CLIENT spans are emitted for + failed sub-orchestrations.""" + def sub_orch(ctx: task.OrchestrationContext, _): + raise ValueError("sub failed") + + def orchestrator(ctx: task.OrchestrationContext, _): + try: + yield ctx.call_sub_orchestrator(sub_orch) + except task.TaskFailedError: + pass + result = yield ctx.call_sub_orchestrator(sub_orch) + return result + + registry = worker._Registry() + sub_name = registry.add_orchestrator(sub_orch) + orch_name = registry.add_orchestrator(orchestrator) + + ex = Exception("sub failed") + old_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(orch_name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_sub_orchestration_created_event(1, sub_name, "sub-1", encoded_input=None), + helpers.new_sub_orchestration_failed_event(1, ex), + helpers.new_sub_orchestration_created_event(2, sub_name, "sub-2", encoded_input=None), + ] + new_events = [ + helpers.new_sub_orchestration_completed_event(2, encoded_output=json.dumps("ok")), + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor.execute(TEST_INSTANCE_ID, old_events, new_events) + + client_spans = self._get_client_spans(otel_setup) + # No CLIENT spans during replay + assert len(client_spans) == 0 + + +class TestOrchestrationSpanLifecycle: + """Tests that the orchestration SERVER span is persisted across + intermediate dispatches and only exported on orchestration completion.""" + + def _get_orch_server_spans(self, exporter): + """Return orchestration SERVER spans from the exporter.""" + return [ + s for s in exporter.get_finished_spans() + if s.kind == trace.SpanKind.SERVER + ] + + def _make_worker_with_registry(self, registry): + 
"""Create a TaskHubGrpcWorker with a pre-populated registry.""" + from unittest.mock import MagicMock + w = worker.TaskHubGrpcWorker(host_address="localhost:4001") + w._registry = registry + return w, MagicMock() + + def test_intermediate_dispatch_does_not_export_span(self, otel_setup): + """An intermediate dispatch (no completeOrchestration) should NOT + export an orchestration SERVER span.""" + from datetime import timedelta + + def orchestrator(ctx: task.OrchestrationContext, _): + due = ctx.current_utc_datetime + timedelta(seconds=1) + yield ctx.create_timer(due) + return "done" + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + w, stub = self._make_worker_with_registry(registry) + + start_time = datetime(2020, 1, 1, 12, 0, 0) + req = pb.OrchestratorRequest( + instanceId=TEST_INSTANCE_ID, + newEvents=[ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event( + name, TEST_INSTANCE_ID, encoded_input=None), + ], + ) + w._execute_orchestrator(req, stub, "token1") + + # Nothing exported yet — span is kept alive + assert len(self._get_orch_server_spans(otel_setup)) == 0 + assert TEST_INSTANCE_ID in w._orchestration_spans + + def test_final_dispatch_exports_single_span(self, otel_setup): + """Across multiple dispatches, only one orchestration span should + be exported, and only when the orchestration completes.""" + from datetime import timedelta + + def orchestrator(ctx: task.OrchestrationContext, _): + due = ctx.current_utc_datetime + timedelta(seconds=2) + yield ctx.create_timer(due) + results = yield task.when_all([ + ctx.call_activity(dummy_activity, input=i) + for i in range(3) + ]) + return results + + def dummy_activity(ctx, _): + pass + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + registry.add_activity(dummy_activity) + w, stub = self._make_worker_with_registry(registry) + + start_time = datetime(2020, 1, 1, 12, 0, 0) + fire_at = start_time + 
timedelta(seconds=2) + activity_name = task.get_name(dummy_activity) + + # Dispatch 1: start + w._execute_orchestrator(pb.OrchestratorRequest( + instanceId=TEST_INSTANCE_ID, + newEvents=[ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event( + name, TEST_INSTANCE_ID, encoded_input=None), + ], + ), stub, "t1") + assert len(self._get_orch_server_spans(otel_setup)) == 0 + + # Dispatch 2: timer fires + w._execute_orchestrator(pb.OrchestratorRequest( + instanceId=TEST_INSTANCE_ID, + pastEvents=[ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event( + name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_timer_created_event(1, fire_at), + ], + newEvents=[ + helpers.new_timer_fired_event(1, fire_at), + ], + ), stub, "t2") + assert len(self._get_orch_server_spans(otel_setup)) == 0 + + # Dispatch 3: activities complete + w._execute_orchestrator(pb.OrchestratorRequest( + instanceId=TEST_INSTANCE_ID, + pastEvents=[ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event( + name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_timer_created_event(1, fire_at), + helpers.new_timer_fired_event(1, fire_at), + helpers.new_task_scheduled_event(2, activity_name), + helpers.new_task_scheduled_event(3, activity_name), + helpers.new_task_scheduled_event(4, activity_name), + ], + newEvents=[ + helpers.new_task_completed_event(2, json.dumps("r1")), + helpers.new_task_completed_event(3, json.dumps("r2")), + helpers.new_task_completed_event(4, json.dumps("r3")), + ], + ), stub, "t3") + + # Exactly one orchestration span exported + orch_spans = self._get_orch_server_spans(otel_setup) + assert len(orch_spans) == 1 + assert "orchestration" in orch_spans[0].name + assert TEST_INSTANCE_ID not in w._orchestration_spans + + def test_span_id_consistent_across_dispatches(self, otel_setup): + """The same span object (same span_id) is reused across dispatches.""" + from datetime 
import timedelta + + def orchestrator(ctx: task.OrchestrationContext, _): + due = ctx.current_utc_datetime + timedelta(seconds=1) + yield ctx.create_timer(due) + return "done" + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + w, stub = self._make_worker_with_registry(registry) + + start_time = datetime(2020, 1, 1, 12, 0, 0) + fire_at = start_time + timedelta(seconds=1) + + # Dispatch 1 + w._execute_orchestrator(pb.OrchestratorRequest( + instanceId=TEST_INSTANCE_ID, + newEvents=[ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event( + name, TEST_INSTANCE_ID, encoded_input=None), + ], + ), stub, "t1") + span_id_1 = w._orchestration_spans[TEST_INSTANCE_ID][0] \ + .get_span_context().span_id + + # Dispatch 2 (final) + w._execute_orchestrator(pb.OrchestratorRequest( + instanceId=TEST_INSTANCE_ID, + pastEvents=[ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event( + name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_timer_created_event(1, fire_at), + ], + newEvents=[ + helpers.new_timer_fired_event(1, fire_at), + ], + ), stub, "t2") + + orch_spans = self._get_orch_server_spans(otel_setup) + assert len(orch_spans) == 1 + assert orch_spans[0].get_span_context().span_id == span_id_1 + + def test_error_cleans_up_saved_span(self, otel_setup): + """When an orchestration raises an unhandled error, the span is + exported with ERROR status and cleaned up from the saved dict.""" + + def orchestrator(ctx: task.OrchestrationContext, _): + raise ValueError("orchestration error") + + registry = worker._Registry() + registry.add_orchestrator(orchestrator) + w, stub = self._make_worker_with_registry(registry) + + name = task.get_name(orchestrator) + req = pb.OrchestratorRequest( + instanceId=TEST_INSTANCE_ID, + newEvents=[ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event( + name, TEST_INSTANCE_ID, encoded_input=None), + ], + ) + 
w._execute_orchestrator(req, stub, "token1") + + orch_spans = self._get_orch_server_spans(otel_setup) + assert len(orch_spans) == 1 + assert orch_spans[0].status.status_code == StatusCode.ERROR + assert TEST_INSTANCE_ID not in w._orchestration_spans + + def test_separate_instances_get_separate_spans(self, otel_setup): + """Two different orchestration instances should get independent + spans that can be persisted and exported independently.""" + from datetime import timedelta + + def orchestrator(ctx: task.OrchestrationContext, _): + due = ctx.current_utc_datetime + timedelta(seconds=1) + yield ctx.create_timer(due) + return "done" + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + w, stub = self._make_worker_with_registry(registry) + + start_time = datetime(2020, 1, 1, 12, 0, 0) + fire_at = start_time + timedelta(seconds=1) + instance_a = "inst-a" + instance_b = "inst-b" + + # Start both instances + for iid in (instance_a, instance_b): + w._execute_orchestrator(pb.OrchestratorRequest( + instanceId=iid, + newEvents=[ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event( + name, iid, encoded_input=None), + ], + ), stub, f"t-{iid}") + + assert len(self._get_orch_server_spans(otel_setup)) == 0 + assert instance_a in w._orchestration_spans + assert instance_b in w._orchestration_spans + + # Complete only instance A + w._execute_orchestrator(pb.OrchestratorRequest( + instanceId=instance_a, + pastEvents=[ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event( + name, instance_a, encoded_input=None), + helpers.new_timer_created_event(1, fire_at), + ], + newEvents=[ + helpers.new_timer_fired_event(1, fire_at), + ], + ), stub, "t-a-2") + + # Only instance A's span is exported + assert len(self._get_orch_server_spans(otel_setup)) == 1 + assert instance_a not in w._orchestration_spans + assert instance_b in w._orchestration_spans + + # Complete instance B + 
w._execute_orchestrator(pb.OrchestratorRequest( + instanceId=instance_b, + pastEvents=[ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event( + name, instance_b, encoded_input=None), + helpers.new_timer_created_event(1, fire_at), + ], + newEvents=[ + helpers.new_timer_fired_event(1, fire_at), + ], + ), stub, "t-b-2") + + assert len(self._get_orch_server_spans(otel_setup)) == 2 + assert instance_b not in w._orchestration_spans + + def test_initial_dispatch_creates_activity_client_spans(self, otel_setup): + """On the first dispatch, a CLIENT span is created for the scheduled + activity but it is NOT yet finished — it stays open until the + activity completes in a subsequent dispatch.""" + + def dummy_activity(ctx, _): + pass + + def orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_activity(dummy_activity, input="hello") + return "done" + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + registry.add_activity(dummy_activity) + w, stub = self._make_worker_with_registry(registry) + + start_time = datetime(2020, 1, 1, 12, 0, 0) + + # First dispatch — generator runs with is_replaying=False + w._execute_orchestrator(pb.OrchestratorRequest( + instanceId=TEST_INSTANCE_ID, + newEvents=[ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event( + name, TEST_INSTANCE_ID, encoded_input=None), + ], + ), stub, "t1") + + # The CLIENT span should NOT be finished yet (it's still open) + client_spans = [ + s for s in otel_setup.get_finished_spans() + if s.kind == trace.SpanKind.CLIENT + ] + assert len(client_spans) == 0 + + # But it should be stored in the worker's pending dict + instance_spans = w._pending_client_spans.get(TEST_INSTANCE_ID, {}) + assert len(instance_spans) == 1 + + def test_activity_client_span_has_duration(self, otel_setup): + """The CLIENT span should cover the full scheduling-to-completion + duration. 
After a completion dispatch, the span is finished and + its parentTraceContext.spanID matches the exported CLIENT span.""" + + def dummy_activity(ctx, _): + pass + + def orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_activity(dummy_activity, input="hello") + return "done" + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + registry.add_activity(dummy_activity) + w, stub = self._make_worker_with_registry(registry) + + schedule_time = datetime(2020, 1, 1, 12, 0, 0) + complete_time = datetime(2020, 1, 1, 12, 0, 5) + + # Dispatch 1: schedule the activity + w._execute_orchestrator(pb.OrchestratorRequest( + instanceId=TEST_INSTANCE_ID, + newEvents=[ + helpers.new_orchestrator_started_event(schedule_time), + helpers.new_execution_started_event( + name, TEST_INSTANCE_ID, encoded_input=None), + ], + ), stub, "t1") + + # Capture the parentTraceContext from the action + call_args = stub.CompleteOrchestratorTask.call_args + res = call_args[0][0] + schedule_actions = [ + a for a in res.actions + if a.HasField("scheduleTask") + ] + assert len(schedule_actions) == 1 + ptc = schedule_actions[0].scheduleTask.parentTraceContext + assert ptc.traceParent != "" + + # Dispatch 2: activity completes + activity_name = task.get_name(dummy_activity) + w._execute_orchestrator(pb.OrchestratorRequest( + instanceId=TEST_INSTANCE_ID, + pastEvents=[ + helpers.new_orchestrator_started_event(schedule_time), + helpers.new_execution_started_event( + name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, activity_name), + ], + newEvents=[ + helpers.new_orchestrator_started_event(complete_time), + helpers.new_task_completed_event(1, '"world"'), + ], + ), stub, "t2") + + # Now the CLIENT span should be finished and exported + client_spans = [ + s for s in otel_setup.get_finished_spans() + if s.kind == trace.SpanKind.CLIENT + ] + assert len(client_spans) == 1 + assert "activity" in client_spans[0].name + + # The 
parentTraceContext spanID should match the CLIENT span + client_span_id = format( + client_spans[0].get_span_context().span_id, '016x') + assert ptc.spanID == client_span_id + + # The span should have real duration (start != end) + assert client_spans[0].start_time < client_spans[0].end_time + + # Pending dict should be cleaned up + instance_spans = w._pending_client_spans.get( + TEST_INSTANCE_ID, {}) + assert len(instance_spans) == 0 + + def test_distributed_worker_fallback_client_span(self, otel_setup): + """When a different worker handles the completion dispatch (no + in-memory CLIENT span), a fallback instant CLIENT span is emitted + so the trace still contains the CLIENT->SERVER link.""" + + def dummy_activity(ctx, _): + pass + + def orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_activity(dummy_activity, input="hello") + return "done" + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + registry.add_activity(dummy_activity) + + # Simulate a DIFFERENT worker handling the completion dispatch: + # The pending_client_spans dict is empty (no span from dispatch 1). 
+ w, stub = self._make_worker_with_registry(registry) + activity_name = task.get_name(dummy_activity) + + schedule_time = datetime(2020, 1, 1, 12, 0, 0) + complete_time = datetime(2020, 1, 1, 12, 0, 5) + + # Completion dispatch with no prior in-memory state + w._execute_orchestrator(pb.OrchestratorRequest( + instanceId=TEST_INSTANCE_ID, + pastEvents=[ + helpers.new_orchestrator_started_event(schedule_time), + helpers.new_execution_started_event( + name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, activity_name), + ], + newEvents=[ + helpers.new_orchestrator_started_event(complete_time), + helpers.new_task_completed_event(1, json.dumps("world")), + ], + ), stub, "t1") + + # A fallback CLIENT span should have been emitted + client_spans = [ + s for s in otel_setup.get_finished_spans() + if s.kind == trace.SpanKind.CLIENT + ] + assert len(client_spans) == 1 + assert "activity" in client_spans[0].name