From 2dc20a550a927a39f4d447dd07e976839c9b25b2 Mon Sep 17 00:00:00 2001 From: Michal Pryc Date: Wed, 17 Dec 2025 14:25:09 +0100 Subject: [PATCH 1/3] Add timeout option and env var to the OADP CLI tool Prevents OADP CLI operation hanging indefinitely. Introduced default timeout of 10min which can be controlled via OADP_CLI_HTTP_TIMEOUT variable. Co-Authored-By: Claude Signed-off-by: Michal Pryc --- cmd/shared/download.go | 51 ++++- cmd/shared/download_test.go | 426 ++++++++++++++++++++++++++++++++++++ 2 files changed, 475 insertions(+), 2 deletions(-) create mode 100644 cmd/shared/download_test.go diff --git a/cmd/shared/download.go b/cmd/shared/download.go index 22a1c8c..32c709c 100644 --- a/cmd/shared/download.go +++ b/cmd/shared/download.go @@ -21,7 +21,9 @@ import ( "context" "fmt" "io" + "log" "net/http" + "os" "strings" "time" @@ -31,6 +33,35 @@ import ( kbclient "sigs.k8s.io/controller-runtime/pkg/client" ) +// DefaultHTTPTimeout is the default timeout for HTTP requests when downloading content from object storage. +// This prevents the CLI from hanging indefinitely if the connection stalls. +const DefaultHTTPTimeout = 10 * time.Minute + +// HTTPTimeoutEnvVar is the environment variable name that can be used to override the default HTTP timeout. +// Example: OADP_CLI_HTTP_TIMEOUT=30m kubectl oadp nonadmin backup logs my-backup +const HTTPTimeoutEnvVar = "OADP_CLI_HTTP_TIMEOUT" + +// getHTTPTimeout returns the HTTP timeout to use for download operations. +// It checks for an environment variable override first, then falls back to the default. +func getHTTPTimeout() time.Duration { + if envTimeout := os.Getenv(HTTPTimeoutEnvVar); envTimeout != "" { + if parsed, err := time.ParseDuration(envTimeout); err == nil { + log.Printf("Using custom HTTP timeout from %s: %v", HTTPTimeoutEnvVar, parsed) + return parsed + } + log.Printf("Warning: Invalid duration in %s=%q, using default %v", HTTPTimeoutEnvVar, envTimeout, DefaultHTTPTimeout) + } + return DefaultHTTPTimeout +} + +// httpClientWithTimeout returns an HTTP client with a configured timeout. +// Using a custom client instead of http.DefaultClient ensures downloads don't hang indefinitely. +func httpClientWithTimeout(timeout time.Duration) *http.Client { + return &http.Client{ + Timeout: timeout, + } +} + // DownloadRequestOptions holds configuration for creating and processing NonAdminDownloadRequests type DownloadRequestOptions struct { // BackupName is the name of the backup to download data for @@ -134,8 +165,16 @@ func waitForDownloadURL(ctx context.Context, kbClient kbclient.Client, req *nacv // DownloadContent fetches content from a signed URL and returns it as a string. // It handles both gzipped and non-gzipped content automatically. +// Uses DefaultHTTPTimeout (or OADP_CLI_HTTP_TIMEOUT env var) to prevent hanging indefinitely. func DownloadContent(url string) (string, error) { - resp, err := http.Get(url) + return DownloadContentWithTimeout(url, getHTTPTimeout()) +} + +// DownloadContentWithTimeout fetches content from a signed URL with a specified timeout. +// It handles both gzipped and non-gzipped content automatically. +func DownloadContentWithTimeout(url string, timeout time.Duration) (string, error) { + client := httpClientWithTimeout(timeout) + resp, err := client.Get(url) if err != nil { return "", fmt.Errorf("failed to download content from URL %q: %w", url, err) } @@ -168,8 +207,16 @@ func DownloadContent(url string) (string, error) { // StreamDownloadContent fetches content from a signed URL and streams it to the provided writer. // This is useful for large files like logs that should be streamed rather than loaded into memory. +// Uses DefaultHTTPTimeout (or OADP_CLI_HTTP_TIMEOUT env var) to prevent hanging indefinitely. func StreamDownloadContent(url string, writer io.Writer) error { - resp, err := http.Get(url) + return StreamDownloadContentWithTimeout(url, writer, getHTTPTimeout()) +} + +// StreamDownloadContentWithTimeout fetches content from a signed URL with a specified timeout +// and streams it to the provided writer. +func StreamDownloadContentWithTimeout(url string, writer io.Writer, timeout time.Duration) error { + client := httpClientWithTimeout(timeout) + resp, err := client.Get(url) if err != nil { return fmt.Errorf("failed to download content from URL %q: %w", url, err) } diff --git a/cmd/shared/download_test.go b/cmd/shared/download_test.go new file mode 100644 index 0000000..dd41acf --- /dev/null +++ b/cmd/shared/download_test.go @@ -0,0 +1,426 @@ +/* +Copyright 2025 The OADP CLI Contributors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package shared + +import ( + "bytes" + "compress/gzip" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" +) + +// TestDefaultHTTPTimeout verifies the default timeout constant +func TestDefaultHTTPTimeout(t *testing.T) { + expected := 10 * time.Minute + if DefaultHTTPTimeout != expected { + t.Errorf("DefaultHTTPTimeout = %v, want %v", DefaultHTTPTimeout, expected) + } +} + +// TestHTTPTimeoutEnvVar verifies the environment variable name constant +func TestHTTPTimeoutEnvVar(t *testing.T) { + expected := "OADP_CLI_HTTP_TIMEOUT" + if HTTPTimeoutEnvVar != expected { + t.Errorf("HTTPTimeoutEnvVar = %q, want %q", HTTPTimeoutEnvVar, expected) + } +} + +// TestGetHTTPTimeout tests the getHTTPTimeout function +func TestGetHTTPTimeout(t *testing.T) { + tests := []struct { + name string + envValue string + want time.Duration + }{ + { + name: "no env var set returns default", + envValue: "", + want: DefaultHTTPTimeout, + }, + { + name: "valid duration in minutes", + envValue: "30m", + want: 30 * time.Minute, + }, + { + name: "valid duration in seconds", + envValue: "120s", + want: 120 * time.Second, + }, + { + name: "valid duration in hours", + envValue: "1h", + want: 1 * time.Hour, + }, + { + name: "valid complex duration", + envValue: "1h30m", + want: 90 * time.Minute, + }, + { + name: "invalid duration falls back to default", + envValue: "invalid", + want: DefaultHTTPTimeout, + }, + { + name: "empty string returns default", + envValue: "", + want: DefaultHTTPTimeout, + }, + { + name: "numeric only (no unit) falls back to default", + envValue: "30", + want: DefaultHTTPTimeout, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save and restore original env var + originalValue := os.Getenv(HTTPTimeoutEnvVar) + defer os.Setenv(HTTPTimeoutEnvVar, originalValue) + + if tt.envValue != "" { + os.Setenv(HTTPTimeoutEnvVar, tt.envValue) + } else { + os.Unsetenv(HTTPTimeoutEnvVar) + } + + got := getHTTPTimeout() + if got != tt.want { + t.Errorf("getHTTPTimeout() = %v, want %v", got, tt.want) + } + }) + } +} + +// TestHttpClientWithTimeout verifies that the HTTP client is created with the correct timeout +func TestHttpClientWithTimeout(t *testing.T) { + tests := []struct { + name string + timeout time.Duration + }{ + { + name: "1 minute timeout", + timeout: 1 * time.Minute, + }, + { + name: "30 second timeout", + timeout: 30 * time.Second, + }, + { + name: "default timeout", + timeout: DefaultHTTPTimeout, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := httpClientWithTimeout(tt.timeout) + if client == nil { + t.Fatal("httpClientWithTimeout returned nil") + } + if client.Timeout != tt.timeout { + t.Errorf("client.Timeout = %v, want %v", client.Timeout, tt.timeout) + } + }) + } +} + +// TestDownloadContentWithTimeout tests downloading content with explicit timeout +func TestDownloadContentWithTimeout(t *testing.T) { + tests := []struct { + name string + serverResponse string + serverStatus int + contentType string + gzipped bool + timeout time.Duration + wantContent string + wantErr bool + errContains string + }{ + { + name: "successful plain text download", + serverResponse: "Hello, World!", + serverStatus: http.StatusOK, + contentType: "text/plain", + gzipped: false, + timeout: 5 * time.Second, + wantContent: "Hello, World!", + wantErr: false, + }, + { + name: "successful gzipped download", + serverResponse: "Gzipped content here", + serverStatus: http.StatusOK, + contentType: "application/gzip", + gzipped: true, + timeout: 5 * time.Second, + wantContent: "Gzipped content here", + wantErr: false, + }, + { + name: "server returns 404", + serverResponse: "Not Found", + serverStatus: http.StatusNotFound, + contentType: "text/plain", + gzipped: false, + timeout: 5 * time.Second, + wantErr: true, + errContains: "404", + }, + { + name: "server returns 500", + serverResponse: "Internal Server Error", + serverStatus: http.StatusInternalServerError, + contentType: "text/plain", + gzipped: false, + timeout: 5 * time.Second, + wantErr: true, + errContains: "500", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", tt.contentType) + if tt.gzipped { + w.Header().Set("Content-Encoding", "gzip") + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write([]byte(tt.serverResponse)) + gz.Close() + w.WriteHeader(tt.serverStatus) + _, _ = w.Write(buf.Bytes()) + } else { + w.WriteHeader(tt.serverStatus) + _, _ = w.Write([]byte(tt.serverResponse)) + } + })) + defer server.Close() + + content, err := DownloadContentWithTimeout(server.URL, tt.timeout) + + if tt.wantErr { + if err == nil { + t.Errorf("DownloadContentWithTimeout() expected error, got nil") + } else if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("DownloadContentWithTimeout() error = %v, want error containing %q", err, tt.errContains) + } + return + } + + if err != nil { + t.Errorf("DownloadContentWithTimeout() unexpected error: %v", err) + return + } + + if content != tt.wantContent { + t.Errorf("DownloadContentWithTimeout() = %q, want %q", content, tt.wantContent) + } + }) + } +} + +// TestDownloadContent tests that DownloadContent uses the default timeout mechanism +func TestDownloadContent(t *testing.T) { + // Save and restore original env var + originalValue := os.Getenv(HTTPTimeoutEnvVar) + defer os.Setenv(HTTPTimeoutEnvVar, originalValue) + os.Unsetenv(HTTPTimeoutEnvVar) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("test content")) + })) + defer server.Close() + + content, err := DownloadContent(server.URL) + if err != nil { + t.Errorf("DownloadContent() unexpected error: %v", err) + return + } + + if content != "test content" { + t.Errorf("DownloadContent() = %q, want %q", content, "test content") + } +} + +// TestStreamDownloadContentWithTimeout tests streaming content with explicit timeout +func TestStreamDownloadContentWithTimeout(t *testing.T) { + tests := []struct { + name string + serverResponse string + serverStatus int + contentType string + gzipped bool + timeout time.Duration + wantContent string + wantErr bool + errContains string + }{ + { + name: "successful plain text stream", + serverResponse: "Streaming content", + serverStatus: http.StatusOK, + contentType: "text/plain", + gzipped: false, + timeout: 5 * time.Second, + wantContent: "Streaming content", + wantErr: false, + }, + { + name: "successful gzipped stream", + serverResponse: "Gzipped streaming content", + serverStatus: http.StatusOK, + contentType: "application/gzip", + gzipped: true, + timeout: 5 * time.Second, + wantContent: "Gzipped streaming content", + wantErr: false, + }, + { + name: "server returns 403", + serverResponse: "Forbidden", + serverStatus: http.StatusForbidden, + contentType: "text/plain", + gzipped: false, + timeout: 5 * time.Second, + wantErr: true, + errContains: "403", + }, + { + name: "large content stream", + serverResponse: strings.Repeat("Large content block. ", 1000), + serverStatus: http.StatusOK, + contentType: "text/plain", + gzipped: false, + timeout: 5 * time.Second, + wantContent: strings.Repeat("Large content block. ", 1000), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", tt.contentType) + if tt.gzipped { + w.Header().Set("Content-Encoding", "gzip") + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write([]byte(tt.serverResponse)) + gz.Close() + w.WriteHeader(tt.serverStatus) + _, _ = w.Write(buf.Bytes()) + } else { + w.WriteHeader(tt.serverStatus) + _, _ = w.Write([]byte(tt.serverResponse)) + } + })) + defer server.Close() + + var buf bytes.Buffer + err := StreamDownloadContentWithTimeout(server.URL, &buf, tt.timeout) + + if tt.wantErr { + if err == nil { + t.Errorf("StreamDownloadContentWithTimeout() expected error, got nil") + } else if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("StreamDownloadContentWithTimeout() error = %v, want error containing %q", err, tt.errContains) + } + return + } + + if err != nil { + t.Errorf("StreamDownloadContentWithTimeout() unexpected error: %v", err) + return + } + + if buf.String() != tt.wantContent { + t.Errorf("StreamDownloadContentWithTimeout() = %q, want %q", buf.String(), tt.wantContent) + } + }) + } +} + +// TestStreamDownloadContent tests that StreamDownloadContent uses the default timeout mechanism +func TestStreamDownloadContent(t *testing.T) { + // Save and restore original env var + originalValue := os.Getenv(HTTPTimeoutEnvVar) + defer os.Setenv(HTTPTimeoutEnvVar, originalValue) + os.Unsetenv(HTTPTimeoutEnvVar) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("streamed test content")) + })) + defer server.Close() + + var buf bytes.Buffer + err := StreamDownloadContent(server.URL, &buf) + if err != nil { + t.Errorf("StreamDownloadContent() unexpected error: %v", err) + return + } + + if buf.String() != "streamed test content" { + t.Errorf("StreamDownloadContent() = %q, want %q", buf.String(), "streamed test content") + } +} + +// TestDownloadContentWithTimeout_InvalidURL tests handling of invalid URLs +func TestDownloadContentWithTimeout_InvalidURL(t *testing.T) { + _, err := DownloadContentWithTimeout("http://invalid-url-that-does-not-exist.local:12345", 1*time.Second) + if err == nil { + t.Error("DownloadContentWithTimeout() expected error for invalid URL, got nil") + } +} + +// TestStreamDownloadContentWithTimeout_InvalidURL tests handling of invalid URLs in streaming +func TestStreamDownloadContentWithTimeout_InvalidURL(t *testing.T) { + var buf bytes.Buffer + err := StreamDownloadContentWithTimeout("http://invalid-url-that-does-not-exist.local:12345", &buf, 1*time.Second) + if err == nil { + t.Error("StreamDownloadContentWithTimeout() expected error for invalid URL, got nil") + } +} + +// TestGetHTTPTimeoutWithEnvVar tests that the env var override works correctly +func TestGetHTTPTimeoutWithEnvVar(t *testing.T) { + // Save and restore original env var + originalValue := os.Getenv(HTTPTimeoutEnvVar) + defer os.Setenv(HTTPTimeoutEnvVar, originalValue) + + // Set custom timeout + os.Setenv(HTTPTimeoutEnvVar, "5m") + + timeout := getHTTPTimeout() + expected := 5 * time.Minute + + if timeout != expected { + t.Errorf("getHTTPTimeout() with env var = %v, want %v", timeout, expected) + } +} From 1bf72146c28f779a643471765ee2a78b7c8d0df0 Mon Sep 17 00:00:00 2001 From: Michal Pryc Date: Wed, 14 Jan 2026 15:34:56 +0100 Subject: [PATCH 2/3] Add --request-timeout flag with TCP dial support for all commands This commit adds a --request-timeout flag (following kubectl conventions) to all OADP CLI commands, with full support for cluster unreachability: - Add --request-timeout flag to nonadmin backup logs and describe commands (takes precedence over OADP_CLI_REQUEST_TIMEOUT env var) - Add timeoutFactory wrapper that applies dial timeout to all Velero commands by overriding all client-creating methods (KubeClient, DynamicClient, DiscoveryClient, KubebuilderClient, KubebuilderWatchClient) - Add renameTimeoutFlag to rename Velero's --timeout to --request-timeout for consistent kubectl-style CLI experience - Set custom net.Dialer with timeout to handle TCP dial timeouts - Add context-based timeout handling with proper cancellation detection - Add FormatDownloadRequestTimeoutError for helpful timeout diagnostics - Add tests for timeout functionality (global timeout, config application, dialer timeout behavior) The timeout now applies to both HTTP requests and TCP connection attempts across all commands, ensuring the CLI times out quickly when the cluster is unreachable instead of waiting for the default ~30s TCP timeout. Co-Authored-By: Claude Opus 4.5 Signed-off-by: Michal Pryc --- cmd/non-admin/backup/describe.go | 67 ++++++++--- cmd/non-admin/backup/logs.go | 57 +++++++-- cmd/root.go | 195 ++++++++++++++++++++++++++++++- cmd/root_test.go | 143 +++++++++++++++++++++++ cmd/shared/client.go | 44 +++++++ cmd/shared/download.go | 75 ++++++++++-- cmd/shared/download_test.go | 162 ++++++++++++++++++++++--- 7 files changed, 684 insertions(+), 59 deletions(-) diff --git a/cmd/non-admin/backup/describe.go b/cmd/non-admin/backup/describe.go index 1215cd7..c772e4c 100644 --- a/cmd/non-admin/backup/describe.go +++ b/cmd/non-admin/backup/describe.go @@ -18,6 +18,8 @@ import ( ) func NewDescribeCommand(f client.Factory, use string) *cobra.Command { + var requestTimeout time.Duration + c := &cobra.Command{ Use: use + " NAME", Short: "Describe a non-admin backup", @@ -25,17 +27,25 @@ func NewDescribeCommand(f client.Factory, use string) *cobra.Command { RunE: func(cmd *cobra.Command, args []string) error { backupName := args[0] + // Get effective timeout (flag takes precedence over env var) + effectiveTimeout := shared.GetHTTPTimeoutWithOverride(requestTimeout) + + // Create context with the effective timeout + ctx, cancel := context.WithTimeout(context.Background(), effectiveTimeout) + defer cancel() + // Get the current namespace from kubectl context userNamespace, err := shared.GetCurrentNamespace() if err != nil { return fmt.Errorf("failed to determine current namespace: %w", err) } - // Create client with required scheme types + // Create client with required scheme types and timeout kbClient, err := shared.NewClientWithScheme(f, shared.ClientOptions{ IncludeNonAdminTypes: true, IncludeVeleroTypes: true, IncludeCoreTypes: true, + Timeout: effectiveTimeout, }) if err != nil { return err @@ -43,10 +53,17 @@ func NewDescribeCommand(f client.Factory, use string) *cobra.Command { // Get the specific backup var nab nacv1alpha1.NonAdminBackup - if err := kbClient.Get(context.Background(), kbclient.ObjectKey{ + if err := kbClient.Get(ctx, kbclient.ObjectKey{ Namespace: userNamespace, Name: backupName, }, &nab); err != nil { + // Check for context cancellation + if ctx.Err() == context.DeadlineExceeded { + return fmt.Errorf("timed out after %v getting NonAdminBackup %q", effectiveTimeout, backupName) + } + if ctx.Err() == context.Canceled { + return fmt.Errorf("operation cancelled: %w", ctx.Err()) + } return fmt.Errorf("NonAdminBackup %q not found in namespace %q: %w", backupName, userNamespace, err) } @@ -55,9 +72,12 @@ func NewDescribeCommand(f client.Factory, use string) *cobra.Command { return nil }, - Example: ` kubectl oadp nonadmin backup describe my-backup`, + Example: ` kubectl oadp nonadmin backup describe my-backup + kubectl oadp nonadmin backup describe my-backup --request-timeout=30m`, } + c.Flags().DurationVar(&requestTimeout, "request-timeout", 0, fmt.Sprintf("The length of time to wait before giving up on a single server request (e.g., 30s, 5m, 1h). Overrides %s env var. Default: %v", shared.TimeoutEnvVar, shared.DefaultHTTPTimeout)) + output.BindFlags(c.Flags()) output.ClearOutputFlagDefault(c) @@ -349,9 +369,16 @@ func colorizePhase(phase string) string { } // NonAdminDescribeBackup mirrors Velero's output.DescribeBackup functionality -// but works within non-admin RBAC boundaries using NonAdminDownloadRequest -func NonAdminDescribeBackup(cmd *cobra.Command, kbClient kbclient.Client, nab *nacv1alpha1.NonAdminBackup, userNamespace string) error { - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) +// but works within non-admin RBAC boundaries using NonAdminDownloadRequest. +// The timeout parameter controls how long to wait for download requests to complete. +// If timeout is 0, DefaultOperationTimeout is used. +func NonAdminDescribeBackup(cmd *cobra.Command, kbClient kbclient.Client, nab *nacv1alpha1.NonAdminBackup, userNamespace string, timeout time.Duration) error { + // Use provided timeout or fall back to default + if timeout == 0 { + timeout = shared.DefaultOperationTimeout + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() // Print basic backup information @@ -401,9 +428,10 @@ func NonAdminDescribeBackup(cmd *cobra.Command, kbClient kbclient.Client, nab *n // Get backup results using NonAdminDownloadRequest (most important data) if results, err := shared.ProcessDownloadRequest(ctx, kbClient, shared.DownloadRequestOptions{ - BackupName: veleroBackupName, - DataType: "BackupResults", - Namespace: userNamespace, + BackupName: veleroBackupName, + DataType: "BackupResults", + Namespace: userNamespace, + HTTPTimeout: timeout, }); err == nil { fmt.Fprintf(cmd.OutOrStdout(), "\nBackup Results:\n") fmt.Fprintf(cmd.OutOrStdout(), "%s", indent(results, " ")) @@ -411,9 +439,10 @@ func NonAdminDescribeBackup(cmd *cobra.Command, kbClient kbclient.Client, nab *n // Get backup details using NonAdminDownloadRequest for BackupResourceList if resourceList, err := shared.ProcessDownloadRequest(ctx, kbClient, shared.DownloadRequestOptions{ - BackupName: veleroBackupName, - DataType: "BackupResourceList", - Namespace: userNamespace, + BackupName: veleroBackupName, + DataType: "BackupResourceList", + Namespace: userNamespace, + HTTPTimeout: timeout, }); err == nil { fmt.Fprintf(cmd.OutOrStdout(), "\nBackup Resource List:\n") fmt.Fprintf(cmd.OutOrStdout(), "%s", indent(resourceList, " ")) @@ -421,9 +450,10 @@ func NonAdminDescribeBackup(cmd *cobra.Command, kbClient kbclient.Client, nab *n // Get backup volume info using NonAdminDownloadRequest if volumeInfo, err := shared.ProcessDownloadRequest(ctx, kbClient, shared.DownloadRequestOptions{ - BackupName: veleroBackupName, - DataType: "BackupVolumeInfos", - Namespace: userNamespace, + BackupName: veleroBackupName, + DataType: "BackupVolumeInfos", + Namespace: userNamespace, + HTTPTimeout: timeout, }); err == nil { fmt.Fprintf(cmd.OutOrStdout(), "\nBackup Volume Info:\n") fmt.Fprintf(cmd.OutOrStdout(), "%s", indent(volumeInfo, " ")) @@ -431,9 +461,10 @@ func NonAdminDescribeBackup(cmd *cobra.Command, kbClient kbclient.Client, nab *n // Get backup item operations using NonAdminDownloadRequest if itemOps, err := shared.ProcessDownloadRequest(ctx, kbClient, shared.DownloadRequestOptions{ - BackupName: veleroBackupName, - DataType: "BackupItemOperations", - Namespace: userNamespace, + BackupName: veleroBackupName, + DataType: "BackupItemOperations", + Namespace: userNamespace, + HTTPTimeout: timeout, }); err == nil { fmt.Fprintf(cmd.OutOrStdout(), "\nBackup Item Operations:\n") fmt.Fprintf(cmd.OutOrStdout(), "%s", indent(itemOps, " ")) diff --git a/cmd/non-admin/backup/logs.go b/cmd/non-admin/backup/logs.go index ee8f555..850699e 100644 --- a/cmd/non-admin/backup/logs.go +++ b/cmd/non-admin/backup/logs.go @@ -19,6 +19,7 @@ limitations under the License. import ( "context" "fmt" + "net" "time" "github.com/migtools/oadp-cli/cmd/shared" @@ -31,12 +32,18 @@ import ( ) func NewLogsCommand(f client.Factory, use string) *cobra.Command { - return &cobra.Command{ + var requestTimeout time.Duration + + c := &cobra.Command{ Use: use + " NAME", Short: "Show logs for a non-admin backup", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + // Get effective timeout (flag takes precedence over env var) + effectiveTimeout := shared.GetHTTPTimeoutWithOverride(requestTimeout) + + // Create context with the effective timeout for the entire operation + ctx, cancel := context.WithTimeout(context.Background(), effectiveTimeout) defer cancel() // Get the current namespace from kubectl context @@ -59,6 +66,18 @@ func NewLogsCommand(f client.Factory, use string) *cobra.Command { if err != nil { return fmt.Errorf("failed to get rest config: %w", err) } + // Set timeout on REST config to prevent hanging when cluster is unreachable + restConfig.Timeout = effectiveTimeout + + // Set a custom dial function with timeout to ensure TCP connection attempts + // also respect the timeout (the default TCP dial timeout is ~30s) + dialer := &net.Dialer{ + Timeout: effectiveTimeout, + } + restConfig.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { + return dialer.DialContext(ctx, network, address) + } + kbClient, err := kbclient.New(restConfig, kbclient.Options{Scheme: scheme}) if err != nil { return fmt.Errorf("failed to create controller-runtime client: %w", err) @@ -97,26 +116,34 @@ func NewLogsCommand(f client.Factory, use string) *cobra.Command { _ = kbClient.Delete(deleteCtx, req) }() - fmt.Fprintf(cmd.OutOrStdout(), "Waiting for backup logs to be processed...\n") + fmt.Fprintf(cmd.OutOrStdout(), "Waiting for backup logs to be processed (timeout: %v)...\n", effectiveTimeout) - // Wait for the download request to be processed using shared utility - // Note: We create a custom waiting implementation here to provide user feedback - timeout := time.After(120 * time.Second) - tick := time.Tick(2 * time.Second) + // Wait for the download request to be processed + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() var signedURL string Loop: for { select { - case <-timeout: - return fmt.Errorf("timed out waiting for NonAdminDownloadRequest to be processed") - case <-tick: + case <-ctx.Done(): + // Check if context was cancelled due to timeout or other reason + if ctx.Err() == context.DeadlineExceeded { + return shared.FormatDownloadRequestTimeoutError(kbClient, req, effectiveTimeout) + } + // Context cancelled for other reason (e.g., user interruption) + return fmt.Errorf("operation cancelled: %w", ctx.Err()) + case <-ticker.C: fmt.Fprintf(cmd.OutOrStdout(), ".") var updated nacv1alpha1.NonAdminDownloadRequest if err := kbClient.Get(ctx, kbclient.ObjectKey{ Namespace: req.Namespace, Name: req.Name, }, &updated); err != nil { + // If context expired during Get, handle it in next iteration + if ctx.Err() != nil { + continue + } return fmt.Errorf("failed to get NonAdminDownloadRequest: %w", err) } @@ -141,12 +168,18 @@ func NewLogsCommand(f client.Factory, use string) *cobra.Command { } // Use the shared StreamDownloadContent function to download and stream logs - if err := shared.StreamDownloadContent(signedURL, cmd.OutOrStdout()); err != nil { + // Note: We use the same effective timeout for the HTTP download + if err := shared.StreamDownloadContentWithTimeout(signedURL, cmd.OutOrStdout(), effectiveTimeout); err != nil { return fmt.Errorf("failed to download and stream logs: %w", err) } return nil }, - Example: ` kubectl oadp nonadmin backup logs my-backup`, + Example: ` kubectl oadp nonadmin backup logs my-backup + kubectl oadp nonadmin backup logs my-backup --request-timeout=30m`, } + + c.Flags().DurationVar(&requestTimeout, "request-timeout", 0, fmt.Sprintf("The length of time to wait before giving up on a single server request (e.g., 30s, 5m, 1h). Overrides %s env var. Default: %v", shared.TimeoutEnvVar, shared.DefaultHTTPTimeout)) + + return c } diff --git a/cmd/root.go b/cmd/root.go index 7acc04b..c639384 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -17,18 +17,29 @@ limitations under the License. package cmd import ( + "context" "flag" "fmt" "io" + "net" "os" "regexp" "strings" + "sync" + "time" "github.com/fatih/color" "github.com/migtools/oadp-cli/cmd/nabsl-request" nonadmin "github.com/migtools/oadp-cli/cmd/non-admin" "github.com/spf13/cobra" + velerov1 "github.com/vmware-tanzu/velero/pkg/apis/velero/v1" clientcmd "github.com/vmware-tanzu/velero/pkg/client" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/discovery" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + kbclient "sigs.k8s.io/controller-runtime/pkg/client" "github.com/vmware-tanzu/velero/pkg/cmd/cli/backup" "github.com/vmware-tanzu/velero/pkg/cmd/cli/backuplocation" @@ -53,6 +64,111 @@ import ( "sigs.k8s.io/kustomize/cmd/config/completion" ) +// globalRequestTimeout holds the request timeout value set by --request-timeout flag. +// This is used by the timeoutFactory wrapper to apply dial timeout to all clients. +var ( + globalRequestTimeout time.Duration + globalTimeoutMu sync.RWMutex +) + +// setGlobalRequestTimeout sets the global request timeout value. +func setGlobalRequestTimeout(timeout time.Duration) { + globalTimeoutMu.Lock() + defer globalTimeoutMu.Unlock() + globalRequestTimeout = timeout +} + +// getGlobalRequestTimeout gets the global request timeout value. +func getGlobalRequestTimeout() time.Duration { + globalTimeoutMu.RLock() + defer globalTimeoutMu.RUnlock() + return globalRequestTimeout +} + +// timeoutFactory wraps a Velero client.Factory to apply dial timeout to REST configs. +type timeoutFactory struct { + clientcmd.Factory +} + +// applyTimeoutToConfig applies the global request timeout to a REST config. +func applyTimeoutToConfig(config *rest.Config) { + timeout := getGlobalRequestTimeout() + if timeout > 0 { + config.Timeout = timeout + + // Set custom dial function with timeout for TCP connections + dialer := &net.Dialer{ + Timeout: timeout, + } + config.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { + return dialer.DialContext(ctx, network, address) + } + } +} + +// ClientConfig returns a REST config with dial timeout applied. +func (f *timeoutFactory) ClientConfig() (*rest.Config, error) { + config, err := f.Factory.ClientConfig() + if err != nil { + return nil, err + } + applyTimeoutToConfig(config) + return config, nil +} + +// KubeClient returns a Kubernetes client with dial timeout applied. +func (f *timeoutFactory) KubeClient() (kubernetes.Interface, error) { + config, err := f.ClientConfig() + if err != nil { + return nil, err + } + return kubernetes.NewForConfig(config) +} + +// DynamicClient returns a Kubernetes dynamic client with dial timeout applied. +func (f *timeoutFactory) DynamicClient() (dynamic.Interface, error) { + config, err := f.ClientConfig() + if err != nil { + return nil, err + } + return dynamic.NewForConfig(config) +} + +// DiscoveryClient returns a Kubernetes discovery client with dial timeout applied. +func (f *timeoutFactory) DiscoveryClient() (discovery.AggregatedDiscoveryInterface, error) { + config, err := f.ClientConfig() + if err != nil { + return nil, err + } + return discovery.NewDiscoveryClientForConfig(config) +} + +// KubebuilderClient returns a controller-runtime client with dial timeout applied. +func (f *timeoutFactory) KubebuilderClient() (kbclient.Client, error) { + config, err := f.ClientConfig() + if err != nil { + return nil, err + } + scheme := runtime.NewScheme() + if err := velerov1.AddToScheme(scheme); err != nil { + return nil, err + } + return kbclient.New(config, kbclient.Options{Scheme: scheme}) +} + +// KubebuilderWatchClient returns a controller-runtime client with watch capability and dial timeout applied. +func (f *timeoutFactory) KubebuilderWatchClient() (kbclient.WithWatch, error) { + config, err := f.ClientConfig() + if err != nil { + return nil, err + } + scheme := runtime.NewScheme() + if err := velerov1.AddToScheme(scheme); err != nil { + return nil, err + } + return kbclient.NewWithWatch(config, kbclient.Options{Scheme: scheme}) +} + // veleroCommandPattern matches "velero" when used as a CLI command. // It matches "velero" followed by common command patterns, including two-word commands // like "backup create", "restore get", etc. @@ -123,6 +239,76 @@ func replaceVeleroWithOADP(cmd *cobra.Command) *cobra.Command { return cmd } +// renameTimeoutFlag renames --timeout flag to --request-timeout for kubectl consistency. +// This applies to all commands recursively to ensure a consistent CLI experience. +func renameTimeoutFlag(cmd *cobra.Command) { + // Check if this command has a --timeout flag + timeoutFlag := cmd.Flags().Lookup("timeout") + if timeoutFlag != nil { + // Get the current value and usage + usage := timeoutFlag.Usage + defValue := timeoutFlag.DefValue + + // Parse the default value as duration + var defaultDuration time.Duration + if defValue != "" && defValue != "0s" { + if parsed, err := time.ParseDuration(defValue); err == nil { + defaultDuration = parsed + } + } + + // Create a variable to hold the value + var requestTimeout time.Duration + + // If there's a shorthand, we need to handle it + shorthand := timeoutFlag.Shorthand + + // Hide the old flag instead of removing it (to avoid breaking existing scripts) + timeoutFlag.Hidden = true + + // Add the new --request-timeout flag + if shorthand != "" { + cmd.Flags().DurationVarP(&requestTimeout, "request-timeout", shorthand, defaultDuration, usage) + } else { + cmd.Flags().DurationVar(&requestTimeout, "request-timeout", defaultDuration, usage) + } + + // Link the flags so setting one affects the other and set global timeout + cmd.PreRunE = wrapPreRunE(cmd.PreRunE, func(c *cobra.Command, args []string) error { + // If request-timeout was set, copy its value to the timeout flag and set global + if c.Flags().Changed("request-timeout") { + rtFlag := c.Flags().Lookup("request-timeout") + if rtFlag != nil { + // Set the global timeout for the timeoutFactory wrapper + if parsed, err := time.ParseDuration(rtFlag.Value.String()); err == nil { + setGlobalRequestTimeout(parsed) + } + return c.Flags().Set("timeout", rtFlag.Value.String()) + } + } + return nil + }) + } + + // Recursively process all child commands + for _, child := range cmd.Commands() { + renameTimeoutFlag(child) + } +} + +// wrapPreRunE wraps an existing PreRunE function with additional logic +func wrapPreRunE(existing func(*cobra.Command, []string) error, additional func(*cobra.Command, []string) error) func(*cobra.Command, []string) error { + return func(cmd *cobra.Command, args []string) error { + if err := additional(cmd, args); err != nil { + return err + } + if existing != nil { + return existing(cmd, args) + } + return nil + } +} + // NewVeleroRootCommand returns a root command with all Velero CLI subcommands attached. func NewVeleroRootCommand(baseName string) *cobra.Command { @@ -158,7 +344,9 @@ func NewVeleroRootCommand(baseName string) *cobra.Command { // Create Velero client factory for regular Velero commands // This factory is used to create clients for interacting with Velero resources. - f := clientcmd.NewFactory(baseName, config) + // We wrap it with timeoutFactory to apply dial timeout from --request-timeout flag. + baseFactory := clientcmd.NewFactory(baseName, config) + f := &timeoutFactory{Factory: baseFactory} c.AddCommand( backup.NewCommand(f), @@ -191,6 +379,11 @@ func NewVeleroRootCommand(baseName string) *cobra.Command { replaceVeleroWithOADP(cmd) } + // Rename --timeout flags to --request-timeout for kubectl consistency + for _, cmd := range c.Commands() { + renameTimeoutFlag(cmd) + } + klog.InitFlags(flag.CommandLine) c.PersistentFlags().AddGoFlagSet(flag.CommandLine) return c diff --git a/cmd/root_test.go b/cmd/root_test.go index 0a72483..b2778fc 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -18,14 +18,18 @@ package cmd import ( "bytes" + "context" "fmt" "io" + "net" "os" "strings" "testing" + "time" "github.com/migtools/oadp-cli/internal/testutil" "github.com/spf13/cobra" + "k8s.io/client-go/rest" ) // TestRootCommand tests the root command functionality @@ -491,3 +495,142 @@ func TestReplaceVeleroWithOADP_RunOutputPreservesProperNouns(t *testing.T) { }) } } + +// TestGlobalRequestTimeout tests the thread-safe global timeout get/set functions +func TestGlobalRequestTimeout(t *testing.T) { + // Reset to zero at start + setGlobalRequestTimeout(0) + + // Test initial value is zero + if got := getGlobalRequestTimeout(); got != 0 { + t.Errorf("Expected initial timeout to be 0, got %v", got) + } + + // Test setting a value + expected := 5 * time.Second + setGlobalRequestTimeout(expected) + if got := getGlobalRequestTimeout(); got != expected { + t.Errorf("Expected timeout to be %v, got %v", expected, got) + } + + // Test setting another value + expected = 30 * time.Second + setGlobalRequestTimeout(expected) + if got := getGlobalRequestTimeout(); got != expected { + t.Errorf("Expected timeout to be %v, got %v", expected, got) + } + + // Reset after test + setGlobalRequestTimeout(0) +} + +// TestApplyTimeoutToConfig tests that applyTimeoutToConfig correctly sets timeout on REST config +func TestApplyTimeoutToConfig(t *testing.T) { + tests := []struct { + name string + globalTimeout time.Duration + expectTimeout bool + expectDialer bool + }{ + { + name: "zero timeout does not modify config", + globalTimeout: 0, + expectTimeout: false, + expectDialer: false, + }, + { + name: "positive timeout sets config timeout and dialer", + globalTimeout: 10 * time.Second, + expectTimeout: true, + expectDialer: true, + }, + { + name: "1 second timeout", + globalTimeout: 1 * time.Second, + expectTimeout: true, + expectDialer: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set global timeout + setGlobalRequestTimeout(tt.globalTimeout) + defer setGlobalRequestTimeout(0) + + // Create a config + config := &rest.Config{ + Host: "https://test-cluster:6443", + } + + // Apply timeout + applyTimeoutToConfig(config) + + // Check timeout + if tt.expectTimeout { + if config.Timeout != tt.globalTimeout { + t.Errorf("Expected config.Timeout to be %v, got %v", tt.globalTimeout, config.Timeout) + } + } else { + if config.Timeout != 0 { + t.Errorf("Expected config.Timeout to be 0, got %v", config.Timeout) + } + } + + // Check dialer + if tt.expectDialer { + if config.Dial == nil { + t.Error("Expected config.Dial to be set, but it was nil") + } + } else { + if config.Dial != nil { + t.Error("Expected config.Dial to be nil, but it was set") + } + } + }) + } +} + +// TestApplyTimeoutToConfig_DialerTimeout tests that the custom dialer respects the timeout +func TestApplyTimeoutToConfig_DialerTimeout(t *testing.T) { + // Set a very short timeout + timeout := 100 * time.Millisecond + setGlobalRequestTimeout(timeout) + defer setGlobalRequestTimeout(0) + + config := &rest.Config{ + Host: "https://test-cluster:6443", + } + + applyTimeoutToConfig(config) + + if config.Dial == nil { + t.Fatal("Expected config.Dial to be set") + } + + // Test that the dialer times out quickly when connecting to a non-routable address + // 10.255.255.1 is a non-routable IP that should cause a timeout + ctx := context.Background() + start := time.Now() + _, err := config.Dial(ctx, "tcp", "10.255.255.1:6443") + elapsed := time.Since(start) + + // Should get a timeout error + if err == nil { + t.Error("Expected dial to fail with timeout, but it succeeded") + } + + // Check it's a timeout error + if netErr, ok := err.(net.Error); ok { + if !netErr.Timeout() { + t.Errorf("Expected timeout error, got: %v", err) + } + } + + // Should complete within a reasonable time of the timeout + // Allow some margin for test execution overhead + maxExpected := timeout + 500*time.Millisecond + if elapsed > maxExpected { + t.Errorf("Dial took too long: %v (expected ~%v)", elapsed, timeout) + } +} diff --git a/cmd/shared/client.go b/cmd/shared/client.go index f0f7990..1de83a6 100644 --- a/cmd/shared/client.go +++ b/cmd/shared/client.go @@ -17,7 +17,10 @@ limitations under the License. package shared import ( + "context" "fmt" + "net" + "time" nacv1alpha1 "github.com/migtools/oadp-non-admin/api/v1alpha1" velerov1 "github.com/vmware-tanzu/velero/pkg/apis/velero/v1" @@ -36,10 +39,51 @@ type ClientOptions struct { IncludeVeleroTypes bool // IncludeCoreTypes adds Kubernetes core types to the scheme IncludeCoreTypes bool + // Timeout sets a timeout on the REST client configuration. + // This prevents the client from hanging indefinitely when the cluster is unreachable. + // If zero, no timeout is set. + Timeout time.Duration } // NewClientWithScheme creates a controller-runtime client with the specified scheme types func NewClientWithScheme(f client.Factory, opts ClientOptions) (kbclient.WithWatch, error) { + // If a timeout is specified, we need to create the client manually with the timeout + // applied to the REST config. Otherwise, use the factory's default method. + if opts.Timeout > 0 { + // Get REST config from factory + restConfig, err := f.ClientConfig() + if err != nil { + return nil, fmt.Errorf("failed to get rest config: %w", err) + } + + // Set timeout on REST config to prevent hanging when cluster is unreachable + restConfig.Timeout = opts.Timeout + + // Set a custom dial function with timeout to ensure TCP connection attempts + // also respect the timeout (the default TCP dial timeout is ~30s) + dialer := &net.Dialer{ + Timeout: opts.Timeout, + } + restConfig.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { + return dialer.DialContext(ctx, network, address) + } + + // Create scheme with required types + scheme, err := NewSchemeWithTypes(opts) + if err != nil { + return nil, err + } + + // Create client with the timeout-configured REST config + kbClient, err := kbclient.NewWithWatch(restConfig, kbclient.Options{Scheme: scheme}) + if err != nil { + return nil, fmt.Errorf("failed to create controller-runtime client: %w", err) + } + + return kbClient, nil + } + + // No timeout specified, use factory's default method kbClient, err := f.KubebuilderWatchClient() if err != nil { return nil, fmt.Errorf("failed to create controller-runtime client: %w", err) diff --git a/cmd/shared/download.go b/cmd/shared/download.go index 32c709c..af6b8b5 100644 --- a/cmd/shared/download.go +++ b/cmd/shared/download.go @@ -37,20 +37,35 @@ import ( // This prevents the CLI from hanging indefinitely if the connection stalls. const DefaultHTTPTimeout = 10 * time.Minute -// HTTPTimeoutEnvVar is the environment variable name that can be used to override the default HTTP timeout. -// Example: OADP_CLI_HTTP_TIMEOUT=30m kubectl oadp nonadmin backup logs my-backup -const HTTPTimeoutEnvVar = "OADP_CLI_HTTP_TIMEOUT" +// TimeoutEnvVar is the environment variable name that can be used to override the default timeout. +// Example: OADP_CLI_REQUEST_TIMEOUT=30m kubectl oadp nonadmin backup logs my-backup +const TimeoutEnvVar = "OADP_CLI_REQUEST_TIMEOUT" // getHTTPTimeout returns the HTTP timeout to use for download operations. // It checks for an environment variable override first, then falls back to the default. func getHTTPTimeout() time.Duration { - if envTimeout := os.Getenv(HTTPTimeoutEnvVar); envTimeout != "" { + return GetHTTPTimeoutWithOverride(0) +} + +// GetHTTPTimeoutWithOverride returns the HTTP timeout to use for download operations. +// Priority order: override parameter (if > 0) > environment variable > default. +// This allows CLI flags to take precedence over environment variables. +func GetHTTPTimeoutWithOverride(override time.Duration) time.Duration { + // If an explicit override is provided (e.g., from --timeout flag), use it + if override > 0 { + log.Printf("Using HTTP timeout from command-line flag: %v", override) + return override + } + + // Check for environment variable + if envTimeout := os.Getenv(TimeoutEnvVar); envTimeout != "" { if parsed, err := time.ParseDuration(envTimeout); err == nil { - log.Printf("Using custom HTTP timeout from %s: %v", HTTPTimeoutEnvVar, parsed) + log.Printf("Using custom HTTP timeout from %s: %v", TimeoutEnvVar, parsed) return parsed } - log.Printf("Warning: Invalid duration in %s=%q, using default %v", HTTPTimeoutEnvVar, envTimeout, DefaultHTTPTimeout) + log.Printf("Warning: Invalid duration in %s=%q, using default %v", TimeoutEnvVar, envTimeout, DefaultHTTPTimeout) } + return DefaultHTTPTimeout } @@ -74,6 +89,9 @@ type DownloadRequestOptions struct { Timeout time.Duration // PollInterval is how often to check the status of the download request PollInterval time.Duration + // HTTPTimeout is the timeout for downloading content from the signed URL. + // If zero, uses the default timeout (env var or DefaultHTTPTimeout). + HTTPTimeout time.Duration } // ProcessDownloadRequest creates a NonAdminDownloadRequest, waits for it to be processed, @@ -119,8 +137,9 @@ func ProcessDownloadRequest(ctx context.Context, kbClient kbclient.Client, opts return "", err } - // Download and return the content - return DownloadContent(signedURL) + // Download and return the content using the specified HTTP timeout + httpTimeout := GetHTTPTimeoutWithOverride(opts.HTTPTimeout) + return DownloadContentWithTimeout(signedURL, httpTimeout) } // waitForDownloadURL waits for a NonAdminDownloadRequest to be processed and returns the signed URL @@ -165,7 +184,7 @@ func waitForDownloadURL(ctx context.Context, kbClient kbclient.Client, req *nacv // DownloadContent fetches content from a signed URL and returns it as a string. // It handles both gzipped and non-gzipped content automatically. -// Uses DefaultHTTPTimeout (or OADP_CLI_HTTP_TIMEOUT env var) to prevent hanging indefinitely. +// Uses DefaultHTTPTimeout (or OADP_CLI_REQUEST_TIMEOUT env var) to prevent hanging indefinitely. func DownloadContent(url string) (string, error) { return DownloadContentWithTimeout(url, getHTTPTimeout()) } @@ -207,7 +226,7 @@ func DownloadContentWithTimeout(url string, timeout time.Duration) (string, erro // StreamDownloadContent fetches content from a signed URL and streams it to the provided writer. // This is useful for large files like logs that should be streamed rather than loaded into memory. -// Uses DefaultHTTPTimeout (or OADP_CLI_HTTP_TIMEOUT env var) to prevent hanging indefinitely. +// Uses DefaultHTTPTimeout (or OADP_CLI_REQUEST_TIMEOUT env var) to prevent hanging indefinitely. func StreamDownloadContent(url string, writer io.Writer) error { return StreamDownloadContentWithTimeout(url, writer, getHTTPTimeout()) } @@ -245,3 +264,39 @@ func StreamDownloadContentWithTimeout(url string, writer io.Writer, timeout time return nil } + +// DefaultOperationTimeout is the default timeout for waiting for download requests to be processed. +const DefaultOperationTimeout = 5 * time.Minute + +// defaultStatusCheckTimeout is the timeout for checking status when formatting timeout errors. +const defaultStatusCheckTimeout = 5 * time.Second + +// FormatDownloadRequestTimeoutError creates a helpful error message when a download request times out. +// It attempts to fetch the current status of the request to provide diagnostic information. +func FormatDownloadRequestTimeoutError(kbClient kbclient.Client, req *nacv1alpha1.NonAdminDownloadRequest, timeout time.Duration) error { + // If client is available, try to get the current status for better diagnostics + if kbClient != nil { + // Use a fresh context to check final status since the original context is expired + statusCtx, cancel := context.WithTimeout(context.Background(), defaultStatusCheckTimeout) + defer cancel() + + var updated nacv1alpha1.NonAdminDownloadRequest + if err := kbClient.Get(statusCtx, kbclient.ObjectKey{ + Namespace: req.Namespace, + Name: req.Name, + }, &updated); err == nil { + // Format status conditions for helpful error message + var statusInfo string + if len(updated.Status.Conditions) > 0 { + var conditions []string + for _, c := range updated.Status.Conditions { + conditions = append(conditions, fmt.Sprintf("%s=%s (reason: %s)", c.Type, c.Status, c.Reason)) + } + statusInfo = fmt.Sprintf(" Current status: %s.", strings.Join(conditions, ", ")) + } + return fmt.Errorf("timed out after %v waiting for NonAdminDownloadRequest %q to be processed.%s", timeout, req.Name, statusInfo) + } + } + + return fmt.Errorf("timed out after %v waiting for NonAdminDownloadRequest %q to be processed", timeout, req.Name) +} diff --git a/cmd/shared/download_test.go b/cmd/shared/download_test.go index dd41acf..01c825a 100644 --- a/cmd/shared/download_test.go +++ b/cmd/shared/download_test.go @@ -25,6 +25,8 @@ import ( "strings" "testing" "time" + + nacv1alpha1 "github.com/migtools/oadp-non-admin/api/v1alpha1" ) // TestDefaultHTTPTimeout verifies the default timeout constant @@ -35,11 +37,11 @@ func TestDefaultHTTPTimeout(t *testing.T) { } } -// TestHTTPTimeoutEnvVar verifies the environment variable name constant -func TestHTTPTimeoutEnvVar(t *testing.T) { - expected := "OADP_CLI_HTTP_TIMEOUT" - if HTTPTimeoutEnvVar != expected { - t.Errorf("HTTPTimeoutEnvVar = %q, want %q", HTTPTimeoutEnvVar, expected) +// TestTimeoutEnvVar verifies the environment variable name constant +func TestTimeoutEnvVar(t *testing.T) { + expected := "OADP_CLI_REQUEST_TIMEOUT" + if TimeoutEnvVar != expected { + t.Errorf("TimeoutEnvVar = %q, want %q", TimeoutEnvVar, expected) } } @@ -95,13 +97,13 @@ func TestGetHTTPTimeout(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Save and restore original env var - originalValue := os.Getenv(HTTPTimeoutEnvVar) - defer os.Setenv(HTTPTimeoutEnvVar, originalValue) + originalValue := os.Getenv(TimeoutEnvVar) + defer os.Setenv(TimeoutEnvVar, originalValue) if tt.envValue != "" { - os.Setenv(HTTPTimeoutEnvVar, tt.envValue) + os.Setenv(TimeoutEnvVar, tt.envValue) } else { - os.Unsetenv(HTTPTimeoutEnvVar) + os.Unsetenv(TimeoutEnvVar) } got := getHTTPTimeout() @@ -246,9 +248,9 @@ func TestDownloadContentWithTimeout(t *testing.T) { // TestDownloadContent tests that DownloadContent uses the default timeout mechanism func TestDownloadContent(t *testing.T) { // Save and restore original env var - originalValue := os.Getenv(HTTPTimeoutEnvVar) - defer os.Setenv(HTTPTimeoutEnvVar, originalValue) - os.Unsetenv(HTTPTimeoutEnvVar) + originalValue := os.Getenv(TimeoutEnvVar) + defer os.Setenv(TimeoutEnvVar, originalValue) + os.Unsetenv(TimeoutEnvVar) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -369,9 +371,9 @@ func TestStreamDownloadContentWithTimeout(t *testing.T) { // TestStreamDownloadContent tests that StreamDownloadContent uses the default timeout mechanism func TestStreamDownloadContent(t *testing.T) { // Save and restore original env var - originalValue := os.Getenv(HTTPTimeoutEnvVar) - defer os.Setenv(HTTPTimeoutEnvVar, originalValue) - os.Unsetenv(HTTPTimeoutEnvVar) + originalValue := os.Getenv(TimeoutEnvVar) + defer os.Setenv(TimeoutEnvVar, originalValue) + os.Unsetenv(TimeoutEnvVar) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -411,11 +413,11 @@ func TestStreamDownloadContentWithTimeout_InvalidURL(t *testing.T) { // TestGetHTTPTimeoutWithEnvVar tests that the env var override works correctly func TestGetHTTPTimeoutWithEnvVar(t *testing.T) { // Save and restore original env var - originalValue := os.Getenv(HTTPTimeoutEnvVar) - defer os.Setenv(HTTPTimeoutEnvVar, originalValue) + originalValue := os.Getenv(TimeoutEnvVar) + defer os.Setenv(TimeoutEnvVar, originalValue) // Set custom timeout - os.Setenv(HTTPTimeoutEnvVar, "5m") + os.Setenv(TimeoutEnvVar, "5m") timeout := getHTTPTimeout() expected := 5 * time.Minute @@ -424,3 +426,127 @@ func TestGetHTTPTimeoutWithEnvVar(t *testing.T) { t.Errorf("getHTTPTimeout() with env var = %v, want %v", timeout, expected) } } + +// TestGetHTTPTimeoutWithOverride tests the priority order: override > env var > default +func TestGetHTTPTimeoutWithOverride(t *testing.T) { + tests := []struct { + name string + override time.Duration + envValue string + want time.Duration + }{ + { + name: "override takes precedence over env var", + override: 15 * time.Minute, + envValue: "30m", + want: 15 * time.Minute, + }, + { + name: "override takes precedence over default when no env var", + override: 20 * time.Minute, + envValue: "", + want: 20 * time.Minute, + }, + { + name: "zero override falls back to env var", + override: 0, + envValue: "25m", + want: 25 * time.Minute, + }, + { + name: "zero override and no env var falls back to default", + override: 0, + envValue: "", + want: DefaultHTTPTimeout, + }, + { + name: "zero override with invalid env var falls back to default", + override: 0, + envValue: "invalid", + want: DefaultHTTPTimeout, + }, + { + name: "small override value is respected", + override: 30 * time.Second, + envValue: "10m", + want: 30 * time.Second, + }, + { + name: "large override value is respected", + override: 2 * time.Hour, + envValue: "5m", + want: 2 * time.Hour, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save and restore original env var + originalValue := os.Getenv(TimeoutEnvVar) + defer os.Setenv(TimeoutEnvVar, originalValue) + + if tt.envValue != "" { + os.Setenv(TimeoutEnvVar, tt.envValue) + } else { + os.Unsetenv(TimeoutEnvVar) + } + + got := GetHTTPTimeoutWithOverride(tt.override) + if got != tt.want { + t.Errorf("GetHTTPTimeoutWithOverride(%v) = %v, want %v", tt.override, got, tt.want) + } + }) + } +} + +// TestDefaultOperationTimeout verifies the default operation timeout constant +func TestDefaultOperationTimeout(t *testing.T) { + expected := 5 * time.Minute + if DefaultOperationTimeout != expected { + t.Errorf("DefaultOperationTimeout = %v, want %v", DefaultOperationTimeout, expected) + } +} + +// TestFormatDownloadRequestTimeoutError_NilClient tests error formatting when client is nil or request fails +func TestFormatDownloadRequestTimeoutError_BasicMessage(t *testing.T) { + // Test that the function returns a properly formatted error message + // even when we can't fetch the status (simulated by passing nil client) + timeout := 5 * time.Minute + + // Create a mock request + req := &nacv1alpha1.NonAdminDownloadRequest{} + req.Name = "test-backup-logs-abc123" + req.Namespace = "test-namespace" + + // With a nil client, the Get will fail, so we'll get the basic error message + err := FormatDownloadRequestTimeoutError(nil, req, timeout) + + // Should contain timeout duration and request name + if err == nil { + t.Fatal("expected error, got nil") + } + + errStr := err.Error() + if !strings.Contains(errStr, "5m0s") { + t.Errorf("error should contain timeout duration '5m0s', got: %s", errStr) + } + if !strings.Contains(errStr, "test-backup-logs-abc123") { + t.Errorf("error should contain request name, got: %s", errStr) + } + if !strings.Contains(errStr, "timed out") { + t.Errorf("error should contain 'timed out', got: %s", errStr) + } +} + +// TestGetHTTPTimeoutWithOverride_ZeroReturnsDefault verifies that zero override with no env var returns default +func TestGetHTTPTimeoutWithOverride_ZeroReturnsDefault(t *testing.T) { + // Save and restore original env var + originalValue := os.Getenv(TimeoutEnvVar) + defer os.Setenv(TimeoutEnvVar, originalValue) + os.Unsetenv(TimeoutEnvVar) + + got := GetHTTPTimeoutWithOverride(0) + if got != DefaultHTTPTimeout { + t.Errorf("GetHTTPTimeoutWithOverride(0) without env var = %v, want %v", got, DefaultHTTPTimeout) + } +} From 24f50bab2dd505cc2bd66d0faa3f5ca0a169f1de Mon Sep 17 00:00:00 2001 From: Michal Pryc Date: Tue, 27 Jan 2026 15:58:56 +0100 Subject: [PATCH 3/3] Apply suggestion from @kaovilai Co-authored-by: Tiger Kaovilai --- cmd/shared/download.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/shared/download.go b/cmd/shared/download.go index af6b8b5..c17de81 100644 --- a/cmd/shared/download.go +++ b/cmd/shared/download.go @@ -294,7 +294,7 @@ func FormatDownloadRequestTimeoutError(kbClient kbclient.Client, req *nacv1alpha } statusInfo = fmt.Sprintf(" Current status: %s.", strings.Join(conditions, ", ")) } - return fmt.Errorf("timed out after %v waiting for NonAdminDownloadRequest %q to be processed.%s", timeout, req.Name, statusInfo) + return fmt.Errorf("timed out after %v waiting for NonAdminDownloadRequest %q to be processed. statusInfo: %s", timeout, req.Name, statusInfo) } }