diff --git a/server/apiserver/argoserver.go b/server/apiserver/argoserver.go index bcd25186a8c5..5e36889627d4 100644 --- a/server/apiserver/argoserver.go +++ b/server/apiserver/argoserver.go @@ -3,11 +3,14 @@ package apiserver import ( "context" "crypto/tls" + "errors" "fmt" "net" "net/http" "os" + "slices" "strings" + "sync" "time" "github.com/gorilla/handlers" @@ -15,11 +18,15 @@ import ( grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "github.com/grpc-ecosystem/grpc-gateway/runtime" "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/sethvargo/go-limiter" + "github.com/sethvargo/go-limiter/httplimit" + "github.com/sethvargo/go-limiter/memorystore" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" @@ -60,16 +67,14 @@ import ( grpcutil "github.com/argoproj/argo-workflows/v3/util/grpc" "github.com/argoproj/argo-workflows/v3/util/instanceid" "github.com/argoproj/argo-workflows/v3/util/json" + k8sutil "github.com/argoproj/argo-workflows/v3/util/k8s" "github.com/argoproj/argo-workflows/v3/util/logging" rbacutil "github.com/argoproj/argo-workflows/v3/util/rbac" "github.com/argoproj/argo-workflows/v3/util/sqldb" "github.com/argoproj/argo-workflows/v3/workflow/artifactrepositories" + "github.com/argoproj/argo-workflows/v3/workflow/artifacts/plugin" "github.com/argoproj/argo-workflows/v3/workflow/events" "github.com/argoproj/argo-workflows/v3/workflow/hydrator" - - "github.com/sethvargo/go-limiter" - "github.com/sethvargo/go-limiter/httplimit" - "github.com/sethvargo/go-limiter/memorystore" ) var MaxGRPCMessageSize int @@ -209,7 +214,19 @@ func (as *argoServer) Run(ctx context.Context, port int, browserOpenFunc func(st if err != nil { log.WithFatal().Error(ctx, err.Error()) } + + // Validate artifact driver images against server pod images + if err := as.validateArtifactDriverImages(ctx, config); err != nil { + log.WithFatal().WithError(err).Error(ctx, "failed to validate artifact driver images") + } + + // Validate artifact driver connections + if err := as.validateArtifactDriverConnections(ctx, config); err != nil { + log.WithFatal().WithError(err).Error(ctx, "failed to validate artifact driver connections") + } + log.WithFields(argo.GetVersion().Fields()).WithField("instanceID", config.InstanceID).Info(ctx, "Starting Argo Server") + instanceIDService := instanceid.NewService(config.InstanceID) offloadRepo := persist.ExplosiveOffloadNodeStatusRepo wfArchive := persist.NullWorkflowArchive @@ -443,6 +460,111 @@ func (as *argoServer) newHTTPServer(ctx context.Context, port int, artifactServe return handler } +// validateArtifactDriverConnections validates that all configured artifact drivers can be connected to +func (as *argoServer) validateArtifactDriverConnections(ctx context.Context, cfg *config.Config) error { + log := logging.RequireLoggerFromContext(ctx) + if len(cfg.ArtifactDrivers) == 0 { + log.Info(ctx, "No artifact drivers configured, skipping connection validation") + return nil + } + + log.Info(ctx, "Validating artifact driver connections") + + var wg sync.WaitGroup + errorChannel := make(chan error, len(cfg.ArtifactDrivers)) + + // Validate each driver connection in parallel + for _, driver := range cfg.ArtifactDrivers { + wg.Add(1) + go func(driver config.ArtifactDriver) { + defer wg.Done() + + // Create a new driver connection + pluginDriver, err := plugin.NewDriver(ctx, driver.Name, driver.Name.SocketPath(), 5) // replace with driver.ConnectionTimeoutSeconds once we have it + if err != nil { + errorChannel <- fmt.Errorf("failed to connect to artifact driver %s: %w", driver.Name, err) + return + } + + // Close the connection after validation + defer func() { + if closeErr := pluginDriver.Close(); closeErr != nil { + log.WithError(closeErr).WithField("driver", driver.Name).Warn(ctx, "Failed to close connection to artifact driver") + } + }() + + log.WithField("driver", driver.Name).Info(ctx, "Successfully validated connection to artifact driver") + }(driver) + } + + // Wait for all validations to complete + wg.Wait() + close(errorChannel) + + // Collect any errors + var connectionErrors []string + for err := range errorChannel { + connectionErrors = append(connectionErrors, err.Error()) + } + + if len(connectionErrors) > 0 { + errorMsg := fmt.Sprintf("Artifact driver connection validation failed: %v", connectionErrors) + log.WithField("errors", connectionErrors).Error(ctx, errorMsg) + return errors.New(errorMsg) + } + + log.WithField("driverCount", len(cfg.ArtifactDrivers)).Info(ctx, "Artifact driver connection validation passed: All configured artifact drivers are accessible") + return nil +} + +// validateArtifactDriverImages validates that the artifact driver images are present in the server pod +func (as *argoServer) validateArtifactDriverImages(ctx context.Context, cfg *config.Config) error { + log := logging.RequireLoggerFromContext(ctx) + if len(cfg.ArtifactDrivers) == 0 { + log.Info(ctx, "No artifact drivers configured, skipping validation") + return nil + } + + log.Info(ctx, "Validating artifact driver images against server pod") + + // Get the current pod name using the standard Argo pattern + podName, err := k8sutil.GetCurrentPodName(ctx, as.clients.Kubernetes, as.namespace, "app=argo-server") + if err != nil { + log.WithError(err).Warn(ctx, "Failed to get current pod name, cannot validate artifact driver images") + return nil + } + + log.WithField("podName", podName).Debug(ctx, "Found argo-server pod for validation") + + // Get the current pod to check the available images + pod, err := as.clients.Kubernetes.CoreV1().Pods(as.namespace).Get(ctx, podName, metav1.GetOptions{}) + if err != nil { + log.WithError(err).WithField("podName", podName).Warn(ctx, "Failed to get current pod, cannot validate artifact driver images") + return nil + } + + // Get missing images + images := make([]string, 0, len(cfg.ArtifactDrivers)) + for _, driver := range cfg.ArtifactDrivers { + images = append(images, driver.Image) + } + + for _, container := range pod.Spec.Containers { + images = slices.DeleteFunc(images, func(image string) bool { + return image == container.Image + }) + } + + if len(images) > 0 { + errorMsg := fmt.Sprintf("Artifact driver validation failed: The following artifact driver images are not present in the server pod: %v. Please ensure all artifact driver images are included in the argo-server pod.", images) + log.Error(ctx, errorMsg) + return errors.New(errorMsg) + } + + log.WithField("driverCount", len(cfg.ArtifactDrivers)).Info(ctx, "Artifact driver validation passed: All configured artifact driver images are present in the server pod") + return nil +} + type registerFunc func(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) error // mustRegisterGWHandler is a convenience function to register a gateway handler diff --git a/server/apiserver/argoserver_test.go b/server/apiserver/argoserver_test.go new file mode 100644 index 000000000000..3a57c63ffa01 --- /dev/null +++ b/server/apiserver/argoserver_test.go @@ -0,0 +1,233 @@ +package apiserver + +import ( + "testing" + + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" + + "github.com/argoproj/argo-workflows/v3/config" + "github.com/argoproj/argo-workflows/v3/server/types" + "github.com/argoproj/argo-workflows/v3/util/logging" + "github.com/argoproj/argo-workflows/v3/workflow/common" +) + +func TestValidateArtifactDriverImages(t *testing.T) { + tests := []struct { + name string + config *config.Config + pod *corev1.Pod + expectedError bool + expectedErrMsg string + }{ + { + name: "No artifact drivers configured - should skip validation", + config: &config.Config{ + ArtifactDrivers: []config.ArtifactDriver{}, + }, + expectedError: false, + }, + { + name: "All artifact driver images present in pod - should pass", + config: &config.Config{ + ArtifactDrivers: []config.ArtifactDriver{ + { + Name: "my-driver", + Image: "my-driver:latest", + }, + { + Name: "another-driver", + Image: "another-driver:v1.0", + }, + }, + }, + pod: &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod", + Namespace: "argo", + Labels: map[string]string{ + "app": "argo-server", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "argo-server", + Image: "quay.io/argoproj/argocli:latest", + }, + { + Name: "my-driver", + Image: "my-driver:latest", + }, + { + Name: "another-driver", + Image: "another-driver:v1.0", + }, + }, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodRunning, + }, + }, + expectedError: false, + }, + { + name: "Missing artifact driver image in pod - should fail", + config: &config.Config{ + ArtifactDrivers: []config.ArtifactDriver{ + { + Name: "my-driver", + Image: "my-driver:latest", + }, + { + Name: "missing-driver", + Image: "missing-driver:v1.0", + }, + }, + }, + pod: &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod", + Namespace: "argo", + Labels: map[string]string{ + "app": "argo-server", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "argo-server", + Image: "quay.io/argoproj/argocli:latest", + }, + { + Name: "my-driver", + Image: "my-driver:latest", + }, + }, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodRunning, + }, + }, + expectedError: true, + expectedErrMsg: "Artifact driver validation failed: The following artifact driver images are not present in the server pod: [missing-driver:v1.0]", + }, + { + name: "Artifact driver image in regular container - should pass", + config: &config.Config{ + ArtifactDrivers: []config.ArtifactDriver{ + { + Name: "sidecar-driver", + Image: "sidecar-driver:latest", + }, + }, + }, + pod: &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod", + Namespace: "argo", + Labels: map[string]string{ + "app": "argo-server", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "argo-server", + Image: "quay.io/argoproj/argocli:latest", + }, + { + Name: "sidecar-driver", + Image: "sidecar-driver:latest", + }, + }, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodRunning, + }, + }, + expectedError: false, + }, + { + name: "Test fallback to label selector when ARGO_POD_NAME not set", + config: &config.Config{ + ArtifactDrivers: []config.ArtifactDriver{ + { + Name: "fallback-driver", + Image: "fallback-driver:latest", + }, + }, + }, + pod: &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "fallback-test-pod", + Namespace: "argo", + Labels: map[string]string{ + "app": "argo-server", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "argo-server", + Image: "quay.io/argoproj/argocli:latest", + }, + { + Name: "fallback-driver", + Image: "fallback-driver:latest", + }, + }, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodRunning, + }, + }, + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create fake Kubernetes client + fakeClient := fake.NewSimpleClientset() + + // Create the argoServer instance + as := &argoServer{ + clients: &types.Clients{ + Kubernetes: fakeClient, + }, + namespace: "argo", + } + + // Set up the test data + ctx := logging.TestContext(t.Context()) + if tt.pod != nil { + _, err := fakeClient.CoreV1().Pods("argo").Create(ctx, tt.pod, metav1.CreateOptions{}) + require.NoError(t, err) + } + + // Set ARGO_POD_NAME environment variable for most tests, except the fallback test + if tt.name != "Test fallback to label selector when ARGO_POD_NAME not set" { + t.Setenv(common.EnvVarPodName, "test-pod") + } else { + // For the fallback test, ensure the environment variable is not set + t.Setenv(common.EnvVarPodName, "") + } + + // Run the validation with proper logging context + err := as.validateArtifactDriverImages(ctx, tt.config) + + // Check results + if tt.expectedError { + require.Error(t, err) + if tt.expectedErrMsg != "" { + require.Contains(t, err.Error(), tt.expectedErrMsg) + } + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/util/k8s/pod.go b/util/k8s/pod.go new file mode 100644 index 000000000000..7ccce91a70c1 --- /dev/null +++ b/util/k8s/pod.go @@ -0,0 +1,46 @@ +package k8s + +import ( + "context" + "fmt" + "os" + + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + + "github.com/argoproj/argo-workflows/v3/workflow/common" +) + +// GetCurrentPodName returns the name of the current pod using the standard +// Argo Workflows pattern. It first tries to get the pod name from the +// ARGO_POD_NAME environment variable (set via Downward API), and falls back +// to using the Kubernetes client to find the pod by label selector. +func GetCurrentPodName(ctx context.Context, client kubernetes.Interface, namespace, labelSelector string) (string, error) { + // First try the standard Argo environment variable + if podName := os.Getenv(common.EnvVarPodName); podName != "" { + return podName, nil + } + + // Fallback: use Kubernetes client to find pod by label selector + podList, err := client.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{ + LabelSelector: labelSelector, + }) + if err != nil { + return "", fmt.Errorf("failed to list pods with selector %s: %w", labelSelector, err) + } + + if len(podList.Items) == 0 { + return "", fmt.Errorf("no pods found with selector: %s", labelSelector) + } + + // Find the first running pod + for _, pod := range podList.Items { + if pod.Status.Phase == v1.PodRunning { + return pod.Name, nil + } + } + + // If no running pods, return the first pod found + return podList.Items[0].Name, nil +} diff --git a/util/k8s/pod_test.go b/util/k8s/pod_test.go new file mode 100644 index 000000000000..1c4a34a66627 --- /dev/null +++ b/util/k8s/pod_test.go @@ -0,0 +1,87 @@ +package k8s + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" + + "github.com/argoproj/argo-workflows/v3/util/logging" + "github.com/argoproj/argo-workflows/v3/workflow/common" +) + +func TestGetCurrentPodName(t *testing.T) { + ctx := logging.TestContext(t.Context()) + + t.Run("Returns pod name from environment variable", func(t *testing.T) { + t.Setenv(common.EnvVarPodName, "test-pod-from-env") + + client := fake.NewSimpleClientset() + podName, err := GetCurrentPodName(ctx, client, "test-namespace", "app=test") + + require.NoError(t, err) + assert.Equal(t, "test-pod-from-env", podName) + }) + + t.Run("Falls back to Kubernetes client when env var not set", func(t *testing.T) { + // Ensure env var is not set + t.Setenv(common.EnvVarPodName, "") + + // Create a fake pod + pod := &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod-from-k8s", + Namespace: "test-namespace", + Labels: map[string]string{ + "app": "test", + }, + }, + Status: v1.PodStatus{ + Phase: v1.PodRunning, + }, + } + + client := fake.NewSimpleClientset(pod) + podName, err := GetCurrentPodName(ctx, client, "test-namespace", "app=test") + + require.NoError(t, err) + assert.Equal(t, "test-pod-from-k8s", podName) + }) + + t.Run("Returns first pod when no running pods found", func(t *testing.T) { + t.Setenv(common.EnvVarPodName, "") + + // Create a fake pod that's not running + pod := &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod-pending", + Namespace: "test-namespace", + Labels: map[string]string{ + "app": "test", + }, + }, + Status: v1.PodStatus{ + Phase: v1.PodPending, + }, + } + + client := fake.NewSimpleClientset(pod) + podName, err := GetCurrentPodName(ctx, client, "test-namespace", "app=test") + + require.NoError(t, err) + assert.Equal(t, "test-pod-pending", podName) + }) + + t.Run("Returns error when no pods found", func(t *testing.T) { + t.Setenv(common.EnvVarPodName, "") + + client := fake.NewSimpleClientset() + _, err := GetCurrentPodName(ctx, client, "test-namespace", "app=nonexistent") + + require.Error(t, err) + assert.Contains(t, err.Error(), "no pods found with selector") + }) +}