diff --git a/task/activity.go b/task/activity.go index 56bd2f9..47eba9a 100644 --- a/task/activity.go +++ b/task/activity.go @@ -107,12 +107,14 @@ type ActivityContext interface { GetTaskID() int32 GetTaskExecutionID() string Context() context.Context + GetTraceContext() *protos.TraceContext } type activityContext struct { TaskID int32 TaskExecutionID string Name string + TraceContext *protos.TraceContext rawInput []byte ctx context.Context @@ -122,10 +124,12 @@ type activityContext struct { type Activity func(ctx ActivityContext) (any, error) func newTaskActivityContext(ctx context.Context, taskID int32, ts *protos.TaskScheduledEvent) *activityContext { + return &activityContext{ TaskID: taskID, TaskExecutionID: ts.TaskExecutionId, Name: ts.Name, + TraceContext: ts.ParentTraceContext, rawInput: []byte(ts.Input.GetValue()), ctx: ctx, } @@ -147,3 +151,7 @@ func (actx *activityContext) GetTaskID() int32 { func (actx *activityContext) GetTaskExecutionID() string { return actx.TaskExecutionID } + +func (actx *activityContext) GetTraceContext() *protos.TraceContext { + return actx.TraceContext +} \ No newline at end of file diff --git a/tests/orchestrations_test.go b/tests/orchestrations_test.go index 2403f2e..eb544d4 100644 --- a/tests/orchestrations_test.go +++ b/tests/orchestrations_test.go @@ -14,6 +14,8 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" "github.com/dapr/durabletask-go/api" "github.com/dapr/durabletask-go/api/protos" @@ -21,7 +23,6 @@ import ( "github.com/dapr/durabletask-go/backend/sqlite" "github.com/dapr/durabletask-go/task" "github.com/dapr/durabletask-go/tests/utils" - "go.opentelemetry.io/otel" ) var tracer = otel.Tracer("orchestration-test") @@ -231,6 +232,7 @@ func Test_SingleActivity_TaskSpan(t *testing.T) { if err := ctx.GetInput(&name); err != nil { return nil, err } + ctx.GetTraceContext() _, childSpan := tracer.Start(ctx.Context(), "activityChild") childSpan.End() return fmt.Sprintf("Hello, %s!", name), nil @@ -1544,6 +1546,72 @@ func Test_TaskExecutionId(t *testing.T) { }) } +func Test_ActivityTraceContext(t *testing.T) { + t.Run("TraceContext is propagated in the activity context", func(t *testing.T) { + // Registration + r := task.NewTaskRegistry() + require.NoError(t, r.AddOrchestratorN("TraceContextOrchestration", func(ctx *task.OrchestrationContext) (any, error) { + if err := ctx.CallActivity("ActivityWithContext", task.WithActivityRetryPolicy(&task.RetryPolicy{ + MaxAttempts: 3, + InitialRetryInterval: 10 * time.Millisecond, + })).Await(nil); err != nil { + return nil, err + } + return nil, nil + })) + + traceParentMap := make(map[string]string) + var executionId string + require.NoError(t, r.AddActivityN("ActivityWithContext", func(ctx task.ActivityContext) (any, error) { + executionId = ctx.GetTaskExecutionID() + tp := ctx.GetTraceContext().GetTraceParent() + traceParentMap[executionId] = tp + + // Create a new context + newCtx := context.Background() + + // Create a TextMapCarrier with the traceparent + carrier := propagation.MapCarrier{} + carrier.Set("traceparent", tp) + + // Use the TraceContext propagator to extract the trace context + propagator := propagation.TraceContext{} + newCtx = propagator.Extract(newCtx, carrier) + + _, childSpan := tracer.Start(context.Background(), "ActivityWith1Context") + childSpan.End() + return nil, nil + })) + + // Initialization + ctx := context.Background() + exporter := utils.InitTracing() + + client, worker := initTaskHubWorker(ctx, r) + defer worker.Shutdown(ctx) + + // Run the orchestration + id, err := client.ScheduleNewOrchestration(ctx, "TraceContextOrchestration") + require.NoError(t, err) + + metadata, err := client.WaitForOrchestrationCompletion(ctx, id) + require.NoError(t, err) + + assert.Equal(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_COMPLETED, metadata.RuntimeStatus) + assert.NotEmpty(t, executionId) + assert.NotEmpty(t, traceParentMap[executionId]) + + // Validate the exported OTel traces include patch spans + spans := exporter.GetSpans().Snapshots() + utils.AssertSpanSequence(t, spans, + utils.AssertOrchestratorCreated("TraceContextOrchestration", id), + utils.AssertSpan("ActivityWith1Context"), + utils.AssertActivity("ActivityWithContext", id, 0), + utils.AssertOrchestratorExecuted("TraceContextOrchestration", id, "COMPLETED"), + ) + }) +} + func Test_OrchestrationPatching_DefaultToPatched(t *testing.T) { // Registration r := task.NewTaskRegistry()