diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e3f698a..3950250 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -82,7 +82,7 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v9 with: - version: v2.7.2 + # Version omitted to use action's default (tracks Go compatibility) args: --timeout=5m # Posts a sticky coverage comment to PRs (updates in place, details collapsed) diff --git a/.gitignore b/.gitignore index 0d0be21..e3f62a1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ bin/ +dist/ +completions/ .DS_Store *.log nohup.out diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 01fba21..4bf52d1 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -6,6 +6,11 @@ before: hooks: - go mod tidy - go mod download + - mkdir -p completions + - sh -c "go run ./cmd/armis-cli completion bash > completions/armis-cli.bash" + - sh -c "go run ./cmd/armis-cli completion zsh > completions/_armis-cli" + - sh -c "go run ./cmd/armis-cli completion fish > completions/armis-cli.fish" + - sh -c "go run ./cmd/armis-cli completion powershell > completions/armis-cli.ps1" builds: - id: armis-cli @@ -42,6 +47,7 @@ archives: - LICENSE* - README.md - docs/**/* + - completions/* checksum: name_template: "armis-cli-checksums.txt" @@ -89,29 +95,29 @@ release: mode: replace header: | ## Armis CLI {{ .Tag }} - + Enterprise-grade CLI tool for static application security scanning. - + ### Installation - + **Quick Install Script:** ```bash curl -sSL https://raw.githubusercontent.com/ArmisSecurity/armis-cli/main/scripts/install.sh | bash ``` - + **Windows (PowerShell):** ```powershell irm https://raw.githubusercontent.com/ArmisSecurity/armis-cli/main/scripts/install.ps1 | iex ``` - + **Go Install:** ```bash go install github.com/ArmisSecurity/armis-cli/cmd/armis-cli@latest ``` - + **Manual Download:** Download the appropriate binary for your platform below. - + ### Verification All binaries are signed with cosign. To verify: ```bash @@ -123,9 +129,9 @@ release: ``` footer: | --- - + **Full Changelog**: https://github.com/ArmisSecurity/armis-cli/compare/{{ .PreviousTag }}...{{ .Tag }} - + For issues or questions, visit: https://github.com/ArmisSecurity/armis-cli/issues snapshot: diff --git a/cmd/armis-cli/main.go b/cmd/armis-cli/main.go index e2a4c1a..675421c 100644 --- a/cmd/armis-cli/main.go +++ b/cmd/armis-cli/main.go @@ -2,8 +2,10 @@ package main import ( + "errors" "os" + "github.com/ArmisSecurity/armis-cli/internal/cli" "github.com/ArmisSecurity/armis-cli/internal/cmd" ) @@ -15,7 +17,16 @@ var ( func main() { cmd.SetVersion(version, commit, date) + // Initialize colors early with auto-detection as fallback. + // This handles cases where PersistentPreRunE doesn't fire (e.g., flag parsing errors). + // The actual --color flag value will override this in PersistentPreRunE. + cli.InitColors(cli.ColorModeAuto) if err := cmd.Execute(); err != nil { + // Handle user cancellation (Ctrl+C) cleanly without printing error + if errors.Is(err, cmd.ErrScanCancelled) { + os.Exit(cmd.ExitCodeCancelled) + } + cli.PrintError(err.Error()) os.Exit(1) } } diff --git a/go.mod b/go.mod index 829ea46..de9b5bb 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/mattn/go-runewidth v0.0.19 github.com/schollz/progressbar/v3 v3.19.0 github.com/spf13/cobra v1.10.2 + golang.org/x/term v0.38.0 ) require ( @@ -24,6 +25,5 @@ require ( github.com/spf13/pflag v1.0.10 // indirect golang.org/x/net v0.48.0 // indirect golang.org/x/sys v0.39.0 // indirect - golang.org/x/term v0.38.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect ) diff --git a/internal/api/client.go b/internal/api/client.go index 1f49ffd..f980a07 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -179,6 +179,10 @@ type IngestOptions struct { GenerateVEX bool } +// StatusCallback is called on each poll with the current scan status. +// It allows callers to react to status changes (e.g., updating a spinner). +type StatusCallback func(status model.IngestStatusData) + // StartIngest uploads an artifact for scanning and returns the scan ID. func (c *Client) StartIngest(ctx context.Context, opts IngestOptions) (string, error) { // Validate upload size for defense-in-depth @@ -264,8 +268,12 @@ func (c *Client) StartIngest(ctx context.Context, opts IngestOptions) (string, e elapsed, formatBytes(opts.Size), resp.Status, strings.TrimSpace(string(bodyBytes))) } + bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, MaxAPIResponseSize)) + if err != nil { + return "", fmt.Errorf("failed to read response body: %w", err) + } var result model.IngestUploadResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := json.Unmarshal(bodyBytes, &result); err != nil { return "", fmt.Errorf("failed to decode response: %w", err) } @@ -296,12 +304,16 @@ func (c *Client) GetIngestStatus(ctx context.Context, tenantID, scanID string) ( defer resp.Body.Close() //nolint:errcheck // response body read-only if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) + bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, MaxAPIResponseSize)) return nil, fmt.Errorf("get status failed with status %d: %s", resp.StatusCode, string(bodyBytes)) } + bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, MaxAPIResponseSize)) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } var result model.IngestStatusResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := json.Unmarshal(bodyBytes, &result); err != nil { return nil, fmt.Errorf("failed to decode response: %w", err) } @@ -309,7 +321,8 @@ func (c *Client) GetIngestStatus(ctx context.Context, tenantID, scanID string) ( } // WaitForIngest polls until the ingestion is complete or times out. -func (c *Client) WaitForIngest(ctx context.Context, tenantID, scanID string, pollInterval time.Duration, timeout time.Duration) (*model.IngestStatusData, error) { +// If onStatus is non-nil, it is called on each poll with the current status. +func (c *Client) WaitForIngest(ctx context.Context, tenantID, scanID string, pollInterval time.Duration, timeout time.Duration, onStatus StatusCallback) (*model.IngestStatusData, error) { if timeout <= 0 { timeout = 60 * time.Minute } @@ -338,6 +351,11 @@ func (c *Client) WaitForIngest(ctx context.Context, tenantID, scanID string, pol } status := statusResp.Data[0] + + if onStatus != nil { + onStatus(status) + } + statusUpper := strings.ToUpper(status.ScanStatus) if statusUpper == "COMPLETED" || statusUpper == "FAILED" { @@ -379,11 +397,11 @@ func (c *Client) FetchNormalizedResults(ctx context.Context, tenantID, scanID st defer resp.Body.Close() //nolint:errcheck // response body read-only if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) + bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, MaxAPIResponseSize)) return nil, fmt.Errorf("fetch results failed with status %d: %s", resp.StatusCode, string(bodyBytes)) } - bodyBytes, err := io.ReadAll(resp.Body) + bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, MaxAPIResponseSize)) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } @@ -444,12 +462,16 @@ func (c *Client) GetScanResult(ctx context.Context, scanID string) (*model.ScanR defer resp.Body.Close() //nolint:errcheck // response body read-only if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) + bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, MaxAPIResponseSize)) return nil, fmt.Errorf("get scan failed with status %d: %s", resp.StatusCode, string(bodyBytes)) } + bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, MaxAPIResponseSize)) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } var result model.ScanResult - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := json.Unmarshal(bodyBytes, &result); err != nil { return nil, fmt.Errorf("failed to decode response: %w", err) } diff --git a/internal/api/client_test.go b/internal/api/client_test.go index 38b6104..96f9655 100644 --- a/internal/api/client_test.go +++ b/internal/api/client_test.go @@ -934,7 +934,7 @@ func TestClient_WaitForIngest(t *testing.T) { t.Fatalf("NewClient failed: %v", err) } - result, err := client.WaitForIngest(context.Background(), "tenant-456", "scan-123", 10*time.Millisecond, 5*time.Second) + result, err := client.WaitForIngest(context.Background(), "tenant-456", "scan-123", 10*time.Millisecond, 5*time.Second, nil) if err != nil { t.Fatalf("WaitForIngest failed: %v", err) @@ -969,7 +969,7 @@ func TestClient_WaitForIngest(t *testing.T) { t.Fatalf("NewClient failed: %v", err) } - _, err = client.WaitForIngest(context.Background(), "tenant-456", "scan-123", 10*time.Millisecond, 5*time.Second) + _, err = client.WaitForIngest(context.Background(), "tenant-456", "scan-123", 10*time.Millisecond, 5*time.Second, nil) if err == nil { t.Fatal("Expected error for FAILED status") @@ -1000,7 +1000,7 @@ func TestClient_WaitForIngest(t *testing.T) { t.Fatalf("NewClient failed: %v", err) } - result, err := client.WaitForIngest(context.Background(), "tenant-456", "scan-123", 10*time.Millisecond, 5*time.Second) + result, err := client.WaitForIngest(context.Background(), "tenant-456", "scan-123", 10*time.Millisecond, 5*time.Second, nil) if err != nil { t.Fatalf("WaitForIngest failed: %v", err) @@ -1030,7 +1030,7 @@ func TestClient_WaitForIngest(t *testing.T) { t.Fatalf("NewClient failed: %v", err) } - _, err = client.WaitForIngest(context.Background(), "tenant-456", "scan-123", 10*time.Millisecond, 50*time.Millisecond) + _, err = client.WaitForIngest(context.Background(), "tenant-456", "scan-123", 10*time.Millisecond, 50*time.Millisecond, nil) if err == nil { t.Fatal("Expected timeout error") @@ -1067,7 +1067,7 @@ func TestClient_WaitForIngest(t *testing.T) { cancel() }() - _, err = client.WaitForIngest(ctx, "tenant-456", "scan-123", 10*time.Millisecond, 5*time.Second) + _, err = client.WaitForIngest(ctx, "tenant-456", "scan-123", 10*time.Millisecond, 5*time.Second, nil) if err == nil { t.Fatal("Expected context cancellation error") @@ -1088,7 +1088,7 @@ func TestClient_WaitForIngest(t *testing.T) { t.Fatalf("NewClient failed: %v", err) } - _, err = client.WaitForIngest(context.Background(), "tenant-456", "scan-123", 10*time.Millisecond, 5*time.Second) + _, err = client.WaitForIngest(context.Background(), "tenant-456", "scan-123", 10*time.Millisecond, 5*time.Second, nil) if err == nil { t.Fatal("Expected error for empty status data") @@ -1109,7 +1109,7 @@ func TestClient_WaitForIngest(t *testing.T) { t.Fatalf("NewClient failed: %v", err) } - _, err = client.WaitForIngest(context.Background(), "tenant-456", "scan-123", 10*time.Millisecond, 100*time.Millisecond) + _, err = client.WaitForIngest(context.Background(), "tenant-456", "scan-123", 10*time.Millisecond, 100*time.Millisecond, nil) if err == nil { t.Fatal("Expected error for failed status check") @@ -1136,7 +1136,7 @@ func TestClient_WaitForIngest(t *testing.T) { t.Fatalf("NewClient failed: %v", err) } - result, err := client.WaitForIngest(context.Background(), "tenant-456", "scan-123", 10*time.Millisecond, 5*time.Second) + result, err := client.WaitForIngest(context.Background(), "tenant-456", "scan-123", 10*time.Millisecond, 5*time.Second, nil) if err != nil { t.Fatalf("WaitForIngest failed: %v", err) @@ -1145,6 +1145,54 @@ func TestClient_WaitForIngest(t *testing.T) { t.Fatal("Expected non-nil result") } }) + + t.Run("invokes status callback on each poll", func(t *testing.T) { + callCount := 0 + var receivedStatuses []string + server := testutil.NewTestServer(t, func(w http.ResponseWriter, _ *http.Request) { + callCount++ + var status string + switch { + case callCount <= 1: + status = "QUEUED" + case callCount <= 2: + status = "PROCESSING" + default: + status = testStatusCompleted + } + response := model.IngestStatusResponse{ + Data: []model.IngestStatusData{ + {ScanID: "scan-123", ScanStatus: status, TenantID: "tenant-456"}, + }, + } + testutil.JSONResponse(t, w, http.StatusOK, response) + }) + + httpClient := httpclient.NewClient(httpclient.Config{Timeout: 5 * time.Second}) + client, err := NewClient(server.URL, testutil.NewTestAuthProvider("token123"), false, 0, WithHTTPClient(httpClient)) + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + result, err := client.WaitForIngest(context.Background(), "tenant-456", "scan-123", + 10*time.Millisecond, 5*time.Second, + func(status model.IngestStatusData) { + receivedStatuses = append(receivedStatuses, status.ScanStatus) + }) + + if err != nil { + t.Fatalf("WaitForIngest failed: %v", err) + } + if result.ScanStatus != testStatusCompleted { + t.Errorf("Expected status %s, got %s", testStatusCompleted, result.ScanStatus) + } + if len(receivedStatuses) < 3 { + t.Errorf("Expected at least 3 callback invocations, got %d", len(receivedStatuses)) + } + if len(receivedStatuses) > 0 && receivedStatuses[0] != "QUEUED" { + t.Errorf("Expected first status QUEUED, got %s", receivedStatuses[0]) + } + }) } func TestClient_WaitForScan(t *testing.T) { diff --git a/internal/cli/color.go b/internal/cli/color.go new file mode 100644 index 0000000..ec98657 --- /dev/null +++ b/internal/cli/color.go @@ -0,0 +1,105 @@ +// Package cli provides CLI utilities including colored output with TTY detection. +package cli + +import ( + "fmt" + "os" + "strings" + + "golang.org/x/term" +) + +// ColorMode represents the color output strategy. +type ColorMode string + +const ( + ColorModeAuto ColorMode = "auto" + ColorModeAlways ColorMode = "always" + ColorModeNever ColorMode = "never" +) + +// ANSI color codes +var ( + colorReset = "\033[0m" + colorRed = "\033[31m" + colorYellow = "\033[93m" + colorBold = "\033[1m" +) + +// colorsEnabled tracks whether colors are currently active. +var colorsEnabled = true + +// InitColors resolves the final color state based on the --color flag value, +// the NO_COLOR env var, and TTY detection. This should be called after flag parsing. +// +// Precedence: +// 1. --color=always -> colors ON (overrides everything, including NO_COLOR) +// 2. --color=never -> colors OFF +// 3. NO_COLOR env -> colors OFF (takes precedence over auto) +// 4. TERM=dumb -> colors OFF +// 5. --color=auto -> detect TTY on stderr +func InitColors(mode ColorMode) { + switch mode { + case ColorModeAlways: + enableColors() + case ColorModeNever: + disableColors() + case ColorModeAuto: + if os.Getenv("NO_COLOR") != "" { + disableColors() + return + } + if strings.Contains(strings.ToLower(os.Getenv("TERM")), "dumb") { + disableColors() + return + } + if !term.IsTerminal(int(os.Stderr.Fd())) { + disableColors() + return + } + enableColors() + } +} + +// ColorsEnabled returns whether colors are currently enabled. +func ColorsEnabled() bool { + return colorsEnabled +} + +func enableColors() { + colorsEnabled = true + colorReset = "\033[0m" + colorRed = "\033[31m" + colorYellow = "\033[93m" + colorBold = "\033[1m" +} + +func disableColors() { + colorsEnabled = false + colorReset = "" + colorRed = "" + colorYellow = "" + colorBold = "" +} + +// PrintError writes a colored error message to stderr. +// Format: "Error: \n" +func PrintError(msg string) { + fmt.Fprintf(os.Stderr, "%s%sError:%s %s\n", colorBold, colorRed, colorReset, msg) +} + +// PrintErrorf is like PrintError but with fmt.Sprintf formatting. +func PrintErrorf(format string, args ...interface{}) { + PrintError(fmt.Sprintf(format, args...)) +} + +// PrintWarning writes a colored warning message to stderr. +// Format: "Warning: \n" +func PrintWarning(msg string) { + fmt.Fprintf(os.Stderr, "%s%sWarning:%s %s\n", colorBold, colorYellow, colorReset, msg) +} + +// PrintWarningf is like PrintWarning but with fmt.Sprintf formatting. +func PrintWarningf(format string, args ...interface{}) { + PrintWarning(fmt.Sprintf(format, args...)) +} diff --git a/internal/cli/color_test.go b/internal/cli/color_test.go new file mode 100644 index 0000000..a527fe7 --- /dev/null +++ b/internal/cli/color_test.go @@ -0,0 +1,132 @@ +package cli + +import ( + "bytes" + "os" + "testing" +) + +func TestInitColors_Never(t *testing.T) { + InitColors(ColorModeNever) + if ColorsEnabled() { + t.Error("expected colors to be disabled with ColorModeNever") + } + if colorRed != "" || colorYellow != "" || colorBold != "" || colorReset != "" { + t.Error("expected all color codes to be empty strings") + } +} + +func TestInitColors_Always(t *testing.T) { + // Set NO_COLOR to verify that 'always' overrides it + t.Setenv("NO_COLOR", "1") + + InitColors(ColorModeAlways) + if !ColorsEnabled() { + t.Error("expected colors to be enabled with ColorModeAlways even when NO_COLOR is set") + } + if colorRed == "" { + t.Error("expected colorRed to have ANSI code") + } +} + +func TestInitColors_Auto_NoColor(t *testing.T) { + t.Setenv("NO_COLOR", "1") + + InitColors(ColorModeAuto) + if ColorsEnabled() { + t.Error("expected colors to be disabled when NO_COLOR is set in auto mode") + } +} + +func TestInitColors_Auto_DumbTerm(t *testing.T) { + t.Setenv("TERM", "dumb") + + InitColors(ColorModeAuto) + if ColorsEnabled() { + t.Error("expected colors to be disabled when TERM=dumb in auto mode") + } +} + +// captureStderr captures stderr output during function execution +func captureStderr(t *testing.T, fn func()) string { + t.Helper() + oldStderr := os.Stderr + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("failed to create pipe: %v", err) + } + os.Stderr = w + + fn() + + if err := w.Close(); err != nil { + t.Errorf("failed to close pipe writer: %v", err) + } + var buf bytes.Buffer + if _, err := buf.ReadFrom(r); err != nil { + t.Errorf("failed to read from pipe: %v", err) + } + if err := r.Close(); err != nil { + t.Errorf("failed to close pipe reader: %v", err) + } + os.Stderr = oldStderr + return buf.String() +} + +func TestPrintError_WithColors(t *testing.T) { + InitColors(ColorModeAlways) + + output := captureStderr(t, func() { + PrintError("test error message") + }) + + if !bytes.Contains([]byte(output), []byte("Error:")) { + t.Errorf("expected output to contain 'Error:', got: %s", output) + } + if !bytes.Contains([]byte(output), []byte("test error message")) { + t.Errorf("expected output to contain message, got: %s", output) + } + // Check for ANSI codes + if !bytes.Contains([]byte(output), []byte("\033[")) { + t.Error("expected output to contain ANSI escape codes when colors enabled") + } +} + +func TestPrintError_WithoutColors(t *testing.T) { + InitColors(ColorModeNever) + + output := captureStderr(t, func() { + PrintError("test error message") + }) + + if bytes.Contains([]byte(output), []byte("\033[")) { + t.Error("expected output to NOT contain ANSI escape codes when colors disabled") + } + if output != "Error: test error message\n" { + t.Errorf("expected plain 'Error: test error message\\n', got: %q", output) + } +} + +func TestPrintWarning_WithColors(t *testing.T) { + InitColors(ColorModeAlways) + + output := captureStderr(t, func() { + PrintWarning("test warning") + }) + + if !bytes.Contains([]byte(output), []byte("Warning:")) { + t.Errorf("expected output to contain 'Warning:', got: %s", output) + } +} + +func TestPrintWarningf(t *testing.T) { + InitColors(ColorModeNever) + + output := captureStderr(t, func() { + PrintWarningf("file %s not found", "test.txt") + }) + + if output != "Warning: file test.txt not found\n" { + t.Errorf("expected formatted warning, got: %q", output) + } +} diff --git a/internal/cmd/context.go b/internal/cmd/context.go index 1e57aed..75871d6 100644 --- a/internal/cmd/context.go +++ b/internal/cmd/context.go @@ -7,8 +7,18 @@ import ( "os" "os/signal" "syscall" + + "github.com/ArmisSecurity/armis-cli/internal/cli" ) +// ErrScanCancelled is a sentinel error indicating the scan was cancelled by the user. +// This is treated as a clean termination (warning is already printed, no additional +// error message needed) and should result in exit code 130 (128 + SIGINT). +var ErrScanCancelled = errors.New("scan cancelled by user") + +// ExitCodeCancelled is the standard Unix exit code for SIGINT (128 + 2). +const ExitCodeCancelled = 130 + // NewSignalContext creates a context that is cancelled when SIGINT or SIGTERM // is received. The returned cancel function should be called to release resources. func NewSignalContext() (context.Context, context.CancelFunc) { @@ -16,12 +26,16 @@ func NewSignalContext() (context.Context, context.CancelFunc) { } // handleScanError prints a cancellation message if the error indicates cancellation -// and returns a wrapped scan error. The ctx parameter is accepted for API consistency -// and future extensibility (e.g., logging or metrics). +// and returns an appropriate error. For context.Canceled, it prints a warning and +// returns ErrScanCancelled (which main.go handles specially without printing). +// For other errors, it returns a wrapped scan error. +// The ctx parameter is accepted for API consistency and future extensibility. func handleScanError(ctx context.Context, err error) error { _ = ctx // unused but kept for API consistency if errors.Is(err, context.Canceled) { - fmt.Fprintln(os.Stderr, "\nScan cancelled") + _, _ = fmt.Fprintln(os.Stderr, "") // newline before warning; ignore write errors + cli.PrintWarning("Scan cancelled") + return ErrScanCancelled } return fmt.Errorf("scan failed: %w", err) } diff --git a/internal/cmd/context_test.go b/internal/cmd/context_test.go index edf2d7f..5f8f56a 100644 --- a/internal/cmd/context_test.go +++ b/internal/cmd/context_test.go @@ -69,10 +69,13 @@ func TestHandleScanError(t *testing.T) { if _, err := io.Copy(&buf, r); err != nil { t.Fatalf("failed to copy stderr output: %v", err) } + if err := r.Close(); err != nil { + t.Fatalf("failed to close pipe reader: %v", err) + } return buf.String() } - t.Run("prints cancellation message when error contains context.Canceled", func(t *testing.T) { + t.Run("returns ErrScanCancelled for context.Canceled", func(t *testing.T) { ctx := context.Background() cancelErr := fmt.Errorf("operation failed: %w", context.Canceled) @@ -85,12 +88,8 @@ func TestHandleScanError(t *testing.T) { t.Errorf("expected stderr to contain 'Scan cancelled', got: %q", output) } - if !errors.Is(resultErr, context.Canceled) { - t.Errorf("expected wrapped error to contain context.Canceled") - } - - if !strings.Contains(resultErr.Error(), "scan failed") { - t.Errorf("expected error message to contain 'scan failed', got: %q", resultErr.Error()) + if !errors.Is(resultErr, ErrScanCancelled) { + t.Errorf("expected ErrScanCancelled, got: %v", resultErr) } }) diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 4758c2a..588e890 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -2,10 +2,15 @@ package cmd import ( + "context" "fmt" "os" "github.com/ArmisSecurity/armis-cli/internal/auth" + "github.com/ArmisSecurity/armis-cli/internal/cli" + "github.com/ArmisSecurity/armis-cli/internal/output" + "github.com/ArmisSecurity/armis-cli/internal/progress" + "github.com/ArmisSecurity/armis-cli/internal/update" "github.com/spf13/cobra" ) @@ -15,15 +20,17 @@ const ( ) var ( - token string - useDev bool - format string - noProgress bool - failOn []string - exitCode int - tenantID string - pageLimit int - debug bool + token string + useDev bool + format string + noProgress bool + failOn []string + exitCode int + tenantID string + pageLimit int + debug bool + noUpdateCheck bool + colorFlag string // JWT authentication clientID string @@ -33,14 +40,47 @@ var ( version = "dev" commit = "none" date = "unknown" + + // updateResultCh receives version check results from background goroutine. + updateResultCh <-chan *update.CheckResult ) var rootCmd = &cobra.Command{ - Use: "armis-cli", - Short: "Armis Security Scanner CLI", - Long: `Enterprise-grade CLI for static application security scanning integrated with Armis Cloud.`, - Version: version, - SilenceUsage: true, + Use: "armis-cli", + Short: "Armis Security Scanner CLI", + Long: `Enterprise-grade CLI for static application security scanning integrated with Armis Cloud.`, + Version: version, + SilenceUsage: true, + SilenceErrors: true, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + // Initialize colors based on --color flag + mode := cli.ColorMode(colorFlag) + switch mode { + case cli.ColorModeAuto, cli.ColorModeAlways, cli.ColorModeNever: + // valid + default: + return fmt.Errorf("invalid --color value %q: must be auto, always, or never", colorFlag) + } + cli.InitColors(mode) + output.SyncColors() + + // Skip update check if: + // - explicitly disabled via flag or env var + // - running in CI + // - version is "dev" (development build) + // - running meta-commands (help, completion, shell completion) + isCompletionCmd := cmd.Name() == "completion" || + (cmd.Parent() != nil && cmd.Parent().Name() == "completion") + if noUpdateCheck || os.Getenv("ARMIS_NO_UPDATE_CHECK") != "" || + progress.IsCI() || version == "dev" || + cmd.Name() == "help" || cmd.Name() == "__complete" || isCompletionCmd { + return nil + } + + checker := update.NewChecker(version) + updateResultCh = checker.CheckInBackground(context.Background()) + return nil + }, } // SetVersion sets the version information for the CLI. @@ -74,6 +114,27 @@ func init() { rootCmd.PersistentFlags().IntVar(&exitCode, "exit-code", 1, "Exit code to return when build fails") rootCmd.PersistentFlags().IntVar(&pageLimit, "page-limit", getEnvOrDefaultInt("ARMIS_PAGE_LIMIT", 500), "Results page size for pagination (range: 1-1000)") rootCmd.PersistentFlags().BoolVar(&debug, "debug", false, "Enable debug mode to print detailed API responses") + rootCmd.PersistentFlags().BoolVar(&noUpdateCheck, "no-update-check", false, "Disable automatic update checking (env: ARMIS_NO_UPDATE_CHECK)") + rootCmd.PersistentFlags().StringVar(&colorFlag, "color", "auto", "Control colored output: auto, always, never") +} + +// PrintUpdateNotification prints a version update notification if one is available. +// This should be called before any os.Exit() call to ensure the notification is displayed. +func PrintUpdateNotification() { + if updateResultCh == nil { + return + } + + // Non-blocking read: only show if result is already available + select { + case result, ok := <-updateResultCh: + if ok && result != nil { + msg := update.FormatNotification(result.CurrentVersion, result.LatestVersion) + fmt.Fprint(os.Stderr, msg) + } + default: + // Check hasn't completed yet -- silently skip + } } func getEnvOrDefault(key, defaultValue string) string { diff --git a/internal/cmd/scan_image.go b/internal/cmd/scan_image.go index 618d9d0..92012ea 100644 --- a/internal/cmd/scan_image.go +++ b/internal/cmd/scan_image.go @@ -22,7 +22,11 @@ var scanImageCmd = &cobra.Command{ Use: "image [image-name]", Short: "Scan a container image", Long: `Scan a local or remote container image for security vulnerabilities.`, - Args: cobra.MaximumNArgs(1), + Example: ` $ armis-cli scan image nginx:latest + $ armis-cli scan image myapp:v1.0 --format json + $ armis-cli scan image --tarball ./image.tar + $ armis-cli scan image alpine:3.18 --sbom --vex --fail-on HIGH,CRITICAL`, + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { if tarballPath == "" && len(args) == 0 { return fmt.Errorf("either provide an image name or use --tarball flag") @@ -105,6 +109,7 @@ var scanImageCmd = &cobra.Command{ return fmt.Errorf("failed to format output: %w", err) } + PrintUpdateNotification() output.ExitIfNeeded(result, failOnSeverities, exitCode) return nil }, diff --git a/internal/cmd/scan_repo.go b/internal/cmd/scan_repo.go index 4e78aa0..06d674d 100644 --- a/internal/cmd/scan_repo.go +++ b/internal/cmd/scan_repo.go @@ -17,7 +17,11 @@ var scanRepoCmd = &cobra.Command{ Use: "repo [path]", Short: "Scan a local repository", Long: `Scan a local repository for security vulnerabilities, secrets, and license risks.`, - Args: cobra.ExactArgs(1), + Example: ` $ armis-cli scan repo . + $ armis-cli scan repo . --format json + $ armis-cli scan repo . --format sarif --fail-on HIGH,CRITICAL + $ armis-cli scan repo . --sbom --sbom-output sbom.json`, + Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { repoPath := args[0] @@ -100,6 +104,7 @@ var scanRepoCmd = &cobra.Command{ return fmt.Errorf("failed to format output: %w", err) } + PrintUpdateNotification() output.ExitIfNeeded(result, failOnSeverities, exitCode) return nil }, diff --git a/internal/output/human.go b/internal/output/human.go index 0e7b7b4..c625ea5 100644 --- a/internal/output/human.go +++ b/internal/output/human.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/ArmisSecurity/armis-cli/internal/cli" "github.com/ArmisSecurity/armis-cli/internal/model" "github.com/ArmisSecurity/armis-cli/internal/util" "github.com/mattn/go-runewidth" @@ -216,13 +217,29 @@ var ( colorUnderline = "\033[4m" //nolint:unused // reserved for future formatting options ) -func init() { - // Support NO_COLOR standard (https://no-color.org/) and dumb terminals - if os.Getenv("NO_COLOR") != "" || strings.Contains(strings.ToLower(os.Getenv("TERM")), "dumb") { +// SyncColors synchronizes the output package's color variables with the +// centralized color state from internal/cli. Must be called after cli.InitColors(). +func SyncColors() { + if cli.ColorsEnabled() { + enableColors() + } else { disableColors() } } +func enableColors() { + colorReset = "\033[0m" + colorRed = "\033[31m" + colorGreen = "\033[32m" + colorOrange = "\033[33m" + colorYellow = "\033[93m" + colorBlue = "\033[34m" + colorGray = "\033[90m" + colorBgRed = "\033[101m" + colorBold = "\033[1m" + colorUnderline = "\033[4m" +} + func disableColors() { colorReset = "" colorRed = "" diff --git a/internal/output/output.go b/internal/output/output.go index d5600c3..2e2a04e 100644 --- a/internal/output/output.go +++ b/internal/output/output.go @@ -71,6 +71,7 @@ func ExitIfNeeded(result *model.ScanResult, failOnSeverities []string, exitCode // Flush stdout to ensure all output is written before exit if err := stdoutSyncer(); err != nil { // Log flush failure to stderr (stdout may be broken) + // Using stderrWriter (instead of cli.PrintWarning) for testability _, _ = fmt.Fprintf(stderrWriter, "Warning: failed to flush stdout before exit: %v\n", err) } osExit(exitCode) diff --git a/internal/output/output_test.go b/internal/output/output_test.go index 42d0a84..356ef2c 100644 --- a/internal/output/output_test.go +++ b/internal/output/output_test.go @@ -5,6 +5,7 @@ import ( "errors" "testing" + "github.com/ArmisSecurity/armis-cli/internal/cli" "github.com/ArmisSecurity/armis-cli/internal/model" ) @@ -291,3 +292,91 @@ func TestExitIfNeeded_StdoutSyncError(t *testing.T) { t.Errorf("stderr output should contain error message, got: %s", stderrOutput) } } + +// TestSyncColors_Enabled verifies that SyncColors enables color codes when cli colors are enabled. +func TestSyncColors_Enabled(t *testing.T) { + // Ensure colors start disabled + disableColors() + + // Enable colors via cli package + cli.InitColors(cli.ColorModeAlways) + + // Sync should enable output package colors + SyncColors() + + // Verify color codes are set + if colorRed != "\033[31m" { + t.Errorf("expected colorRed to be '\\033[31m', got %q", colorRed) + } + if colorReset != "\033[0m" { + t.Errorf("expected colorReset to be '\\033[0m', got %q", colorReset) + } + if colorBold != "\033[1m" { + t.Errorf("expected colorBold to be '\\033[1m', got %q", colorBold) + } +} + +// TestSyncColors_Disabled verifies that SyncColors disables color codes when cli colors are disabled. +func TestSyncColors_Disabled(t *testing.T) { + // Ensure colors start enabled + enableColors() + + // Disable colors via cli package + cli.InitColors(cli.ColorModeNever) + + // Sync should disable output package colors + SyncColors() + + // Verify color codes are empty + if colorRed != "" { + t.Errorf("expected colorRed to be empty, got %q", colorRed) + } + if colorReset != "" { + t.Errorf("expected colorReset to be empty, got %q", colorReset) + } + if colorBold != "" { + t.Errorf("expected colorBold to be empty, got %q", colorBold) + } +} + +// TestEnableColors verifies that enableColors sets all color codes to their ANSI values. +func TestEnableColors(t *testing.T) { + // Start with colors disabled + disableColors() + + // Enable colors + enableColors() + + // Check all color codes + expectedColors := map[string]string{ + "colorReset": "\033[0m", + "colorRed": "\033[31m", + "colorGreen": "\033[32m", + "colorOrange": "\033[33m", + "colorYellow": "\033[93m", + "colorBlue": "\033[34m", + "colorGray": "\033[90m", + "colorBgRed": "\033[101m", + "colorBold": "\033[1m", + "colorUnderline": "\033[4m", + } + + actualColors := map[string]string{ + "colorReset": colorReset, + "colorRed": colorRed, + "colorGreen": colorGreen, + "colorOrange": colorOrange, + "colorYellow": colorYellow, + "colorBlue": colorBlue, + "colorGray": colorGray, + "colorBgRed": colorBgRed, + "colorBold": colorBold, + "colorUnderline": colorUnderline, + } + + for name, expected := range expectedColors { + if actual := actualColors[name]; actual != expected { + t.Errorf("%s: expected %q, got %q", name, expected, actual) + } + } +} diff --git a/internal/output/sarif.go b/internal/output/sarif.go index 760a34e..0b09b98 100644 --- a/internal/output/sarif.go +++ b/internal/output/sarif.go @@ -4,11 +4,11 @@ import ( "encoding/json" "fmt" "io" - "os" "regexp" "strconv" "strings" + "github.com/ArmisSecurity/armis-cli/internal/cli" "github.com/ArmisSecurity/armis-cli/internal/model" "github.com/ArmisSecurity/armis-cli/internal/util" ) @@ -319,7 +319,7 @@ func convertToSarifResults(findings []model.Finding, ruleIndexMap map[string]int // Sanitize file path to prevent path traversal in SARIF output sanitizedFile, err := util.SanitizePath(finding.File) if err != nil { - fmt.Fprintf(os.Stderr, "Warning: could not sanitize file path for finding %s: %v\n", finding.ID, err) + cli.PrintWarningf("could not sanitize file path for finding %s: %v", finding.ID, err) // Use finding ID to ensure unique placeholder paths in SARIF output sanitizedFile = fmt.Sprintf("unknown-%s", finding.ID) } diff --git a/internal/scan/image/helpers_test.go b/internal/scan/image/helpers_test.go index 5f5ac64..1239213 100644 --- a/internal/scan/image/helpers_test.go +++ b/internal/scan/image/helpers_test.go @@ -7,33 +7,8 @@ import ( "github.com/ArmisSecurity/armis-cli/internal/model" ) -func TestMapSeverity(t *testing.T) { - tests := []struct { - input string - expected model.Severity - }{ - {"CRITICAL", model.SeverityCritical}, - {"critical", model.SeverityCritical}, - {"HIGH", model.SeverityHigh}, - {"high", model.SeverityHigh}, - {"MEDIUM", model.SeverityMedium}, - {"medium", model.SeverityMedium}, - {"LOW", model.SeverityLow}, - {"low", model.SeverityLow}, - {"INFO", model.SeverityInfo}, - {"UNKNOWN", model.SeverityInfo}, - {"", model.SeverityInfo}, - } - - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - result := mapSeverity(tt.input) - if result != tt.expected { - t.Errorf("mapSeverity(%q) = %v, want %v", tt.input, result, tt.expected) - } - }) - } -} +// Note: mapSeverity, formatElapsed, and formatScanStatus are now in the shared +// scan package and tested in internal/scan/status_test.go func TestCleanDescription(t *testing.T) { tests := []struct { @@ -263,54 +238,6 @@ func TestIsEmptyFinding(t *testing.T) { } } -func TestFormatElapsed(t *testing.T) { - tests := []struct { - name string - duration time.Duration - expected string - }{ - { - name: "zero duration", - duration: 0, - expected: "0s", - }, - { - name: "seconds only", - duration: 45 * time.Second, - expected: "45s", - }, - { - name: "one minute", - duration: 60 * time.Second, - expected: "1m 0s", - }, - { - name: "minutes and seconds", - duration: 125 * time.Second, - expected: "2m 5s", - }, - { - name: "many minutes", - duration: 10*time.Minute + 30*time.Second, - expected: "10m 30s", - }, - { - name: "rounds to nearest second", - duration: 45*time.Second + 600*time.Millisecond, - expected: "46s", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := formatElapsed(tt.duration) - if result != tt.expected { - t.Errorf("formatElapsed(%v) = %q, want %q", tt.duration, result, tt.expected) - } - }) - } -} - func TestNewScanner(t *testing.T) { scanner := NewScanner(nil, true, "tenant-123", 100, false, 5*time.Minute, true) diff --git a/internal/scan/image/image.go b/internal/scan/image/image.go index 483a32e..2d2ce3f 100644 --- a/internal/scan/image/image.go +++ b/internal/scan/image/image.go @@ -12,6 +12,7 @@ import ( "time" "github.com/ArmisSecurity/armis-cli/internal/api" + "github.com/ArmisSecurity/armis-cli/internal/cli" "github.com/ArmisSecurity/armis-cli/internal/model" "github.com/ArmisSecurity/armis-cli/internal/progress" "github.com/ArmisSecurity/armis-cli/internal/scan" @@ -144,18 +145,21 @@ func (s *Scanner) ScanTarball(ctx context.Context, tarballPath string) (*model.S uploadSpinner.Stop() fmt.Fprintf(os.Stderr, "Scan initiated with ID: %s\n\n", scanID) - spinner := progress.NewSpinnerWithContext(ctx, "Waiting for scan to complete...", s.noProgress) + spinner := progress.NewSpinnerWithContext(ctx, "Scanning image for vulnerabilities...", s.noProgress) spinner.Start() defer spinner.Stop() - _, err = s.client.WaitForIngest(ctx, s.tenantID, scanID, s.pollInterval, s.timeout) + _, err = s.client.WaitForIngest(ctx, s.tenantID, scanID, s.pollInterval, s.timeout, + func(status model.IngestStatusData) { + spinner.Update(scan.FormatScanStatus(status.ScanStatus, "Scanning image for vulnerabilities...")) + }) elapsed := spinner.GetElapsed() if err != nil { return nil, fmt.Errorf("failed to wait for scan: %w", err) } spinner.Stop() - fmt.Fprintf(os.Stderr, "Scan completed in %s. Fetching results...\n", formatElapsed(elapsed)) + fmt.Fprintf(os.Stderr, "Scan completed in %s. Fetching results...\n", scan.FormatElapsed(elapsed)) findings, err := s.client.FetchAllNormalizedResults(ctx, s.tenantID, scanID, s.pageLimit) if err != nil { @@ -172,7 +176,7 @@ func (s *Scanner) ScanTarball(ctx context.Context, tarballPath string) (*model.S downloader := scan.NewSBOMVEXDownloader(s.client, s.tenantID, s.sbomVEXOpts) if err := downloader.Download(ctx, scanID, artifactName); err != nil { // Log warning but don't fail the scan - fmt.Fprintf(os.Stderr, "Warning: %v\n", err) + cli.PrintWarningf("%v", err) } } @@ -313,7 +317,7 @@ func convertNormalizedFindings(normalizedFindings []model.NormalizedFinding, deb finding := model.Finding{ ID: nf.NormalizedTask.FindingID, - Severity: mapSeverity(nf.NormalizedRemediation.ToolSeverity), + Severity: scan.MapSeverity(nf.NormalizedRemediation.ToolSeverity), Description: nf.NormalizedRemediation.Description, CVEs: nf.NormalizedRemediation.VulnerabilityTypeMetadata.CVEs, CWEs: nf.NormalizedRemediation.VulnerabilityTypeMetadata.CWEs, @@ -444,28 +448,3 @@ func isEmptyFinding(nf model.NormalizedFinding) bool { return !hasDescription && !hasCVEsOrCWEs && !hasCategory } - -func mapSeverity(toolSeverity string) model.Severity { - switch strings.ToUpper(toolSeverity) { - case "CRITICAL": - return model.SeverityCritical - case "HIGH": - return model.SeverityHigh - case "MEDIUM": - return model.SeverityMedium - case "LOW": - return model.SeverityLow - default: - return model.SeverityInfo - } -} - -func formatElapsed(d time.Duration) string { - d = d.Round(time.Second) - minutes := int(d.Minutes()) - seconds := int(d.Seconds()) % 60 - if minutes > 0 { - return fmt.Sprintf("%dm %ds", minutes, seconds) - } - return fmt.Sprintf("%ds", seconds) -} diff --git a/internal/scan/repo/helpers_test.go b/internal/scan/repo/helpers_test.go index 57501eb..be266df 100644 --- a/internal/scan/repo/helpers_test.go +++ b/internal/scan/repo/helpers_test.go @@ -123,33 +123,7 @@ func TestIsTestFile(t *testing.T) { } } -func TestMapSeverity(t *testing.T) { - tests := []struct { - input string - expected model.Severity - }{ - {"CRITICAL", model.SeverityCritical}, - {"critical", model.SeverityCritical}, - {"HIGH", model.SeverityHigh}, - {"high", model.SeverityHigh}, - {"MEDIUM", model.SeverityMedium}, - {"medium", model.SeverityMedium}, - {"LOW", model.SeverityLow}, - {"low", model.SeverityLow}, - {"INFO", model.SeverityInfo}, - {"UNKNOWN", model.SeverityInfo}, - {"", model.SeverityInfo}, - } - - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - result := mapSeverity(tt.input) - if result != tt.expected { - t.Errorf("mapSeverity(%q) = %v, want %v", tt.input, result, tt.expected) - } - }) - } -} +// Note: mapSeverity is now in the shared scan package and tested in internal/scan/status_test.go func TestCleanDescription(t *testing.T) { tests := []struct { diff --git a/internal/scan/repo/repo.go b/internal/scan/repo/repo.go index 463f688..a6b5cd4 100644 --- a/internal/scan/repo/repo.go +++ b/internal/scan/repo/repo.go @@ -14,6 +14,7 @@ import ( "time" "github.com/ArmisSecurity/armis-cli/internal/api" + "github.com/ArmisSecurity/armis-cli/internal/cli" "github.com/ArmisSecurity/armis-cli/internal/model" "github.com/ArmisSecurity/armis-cli/internal/progress" "github.com/ArmisSecurity/armis-cli/internal/scan" @@ -106,7 +107,7 @@ func (s *Scanner) Scan(ctx context.Context, path string) (*model.ScanResult, err // Targeted file scanning mode - scan only specified files existing, warnings := s.includeFiles.ValidateExistence() for _, w := range warnings { - fmt.Fprintf(os.Stderr, "Warning: %s\n", w) + cli.PrintWarning(w) } if len(existing) == 0 { @@ -195,14 +196,17 @@ func (s *Scanner) Scan(ctx context.Context, path string) (*model.ScanResult, err analysisSpinner.Start() defer analysisSpinner.Stop() - _, err = s.client.WaitForIngest(ctx, s.tenantID, scanID, s.pollInterval, s.timeout) + _, err = s.client.WaitForIngest(ctx, s.tenantID, scanID, s.pollInterval, s.timeout, + func(status model.IngestStatusData) { + analysisSpinner.Update(scan.FormatScanStatus(status.ScanStatus, "Analyzing code for vulnerabilities...")) + }) elapsed := analysisSpinner.GetElapsed() analysisSpinner.Stop() if err != nil { return nil, fmt.Errorf("failed to wait for scan: %w", err) } - fmt.Fprintf(os.Stderr, "Analysis completed in %s\n\n", formatElapsed(elapsed)) + fmt.Fprintf(os.Stderr, "Analysis completed in %s\n\n", scan.FormatElapsed(elapsed)) fetchSpinner := progress.NewSpinnerWithContext(ctx, "Fetching scan results...", s.noProgress) fetchSpinner.Start() @@ -220,7 +224,7 @@ func (s *Scanner) Scan(ctx context.Context, path string) (*model.ScanResult, err downloader := scan.NewSBOMVEXDownloader(s.client, s.tenantID, s.sbomVEXOpts) if err := downloader.Download(ctx, scanID, filepath.Base(absPath)); err != nil { // Log warning but don't fail the scan - fmt.Fprintf(os.Stderr, "Warning: %v\n", err) + cli.PrintWarningf("%v", err) } } @@ -270,7 +274,7 @@ func (s *Scanner) tarGzDirectory(sourcePath string, writer io.Writer, ignoreMatc // Skip symlinks to avoid security risks (symlinks pointing outside repo) // and potential issues (broken symlinks, loops) if info.Mode()&os.ModeSymlink != 0 { - fmt.Fprintf(os.Stderr, "Warning: skipping symlink %s\n", relPath) + cli.PrintWarningf("skipping symlink %s", relPath) return nil } @@ -347,14 +351,14 @@ func (s *Scanner) tarGzFiles(repoRoot string, files []string, writer io.Writer) // Defense-in-depth: verify path is within repo root if !isPathContained(repoRoot, absPath) { - fmt.Fprintf(os.Stderr, "Warning: skipping path outside repository: %s\n", relPath) + cli.PrintWarningf("skipping path outside repository: %s", relPath) continue } info, err := os.Stat(absPath) if err != nil { // Skip files that don't exist (may have been deleted) - fmt.Fprintf(os.Stderr, "Warning: skipping %s: %v\n", relPath, err) + cli.PrintWarningf("skipping %s: %v", relPath, err) continue } @@ -365,7 +369,7 @@ func (s *Scanner) tarGzFiles(repoRoot string, files []string, writer io.Writer) // Skip symlinks for security if info.Mode()&os.ModeSymlink != 0 { - fmt.Fprintf(os.Stderr, "Warning: skipping symlink %s\n", relPath) + cli.PrintWarningf("skipping symlink %s", relPath) continue } @@ -599,7 +603,7 @@ func convertNormalizedFindings(normalizedFindings []model.NormalizedFinding, deb finding := model.Finding{ ID: nf.NormalizedTask.FindingID, - Severity: mapSeverity(nf.NormalizedRemediation.ToolSeverity), + Severity: scan.MapSeverity(nf.NormalizedRemediation.ToolSeverity), Description: nf.NormalizedRemediation.Description, CVEs: nf.NormalizedRemediation.VulnerabilityTypeMetadata.CVEs, CWEs: nf.NormalizedRemediation.VulnerabilityTypeMetadata.CWEs, @@ -785,28 +789,3 @@ func isEmptyFinding(nf model.NormalizedFinding) bool { return !hasDescription && !hasCVEsOrCWEs && !hasCategory } - -func mapSeverity(toolSeverity string) model.Severity { - switch strings.ToUpper(toolSeverity) { - case "CRITICAL": - return model.SeverityCritical - case "HIGH": - return model.SeverityHigh - case "MEDIUM": - return model.SeverityMedium - case "LOW": - return model.SeverityLow - default: - return model.SeverityInfo - } -} - -func formatElapsed(d time.Duration) string { - d = d.Round(time.Second) - minutes := int(d.Minutes()) - seconds := int(d.Seconds()) % 60 - if minutes > 0 { - return fmt.Sprintf("%dm %ds", minutes, seconds) - } - return fmt.Sprintf("%ds", seconds) -} diff --git a/internal/scan/repo/repo_test.go b/internal/scan/repo/repo_test.go index 7aa414c..22bdd1e 100644 --- a/internal/scan/repo/repo_test.go +++ b/internal/scan/repo/repo_test.go @@ -503,53 +503,8 @@ func TestConvertNormalizedFindings(t *testing.T) { }) } -func TestFormatElapsed(t *testing.T) { - tests := []struct { - name string - duration time.Duration - expected string - }{ - { - name: "zero duration", - duration: 0, - expected: "0s", - }, - { - name: "seconds only", - duration: 45 * time.Second, - expected: "45s", - }, - { - name: "one minute", - duration: 60 * time.Second, - expected: "1m 0s", - }, - { - name: "minutes and seconds", - duration: 125 * time.Second, - expected: "2m 5s", - }, - { - name: "many minutes", - duration: 10*time.Minute + 30*time.Second, - expected: "10m 30s", - }, - { - name: "rounds to nearest second", - duration: 45*time.Second + 600*time.Millisecond, - expected: "46s", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := formatElapsed(tt.duration) - if result != tt.expected { - t.Errorf("formatElapsed(%v) = %q, want %q", tt.duration, result, tt.expected) - } - }) - } -} +// Note: formatElapsed and formatScanStatus are now in the shared scan package +// and tested in internal/scan/status_test.go // mockFileInfo implements os.FileInfo for testing type mockFileInfo struct { diff --git a/internal/scan/sbom_vex.go b/internal/scan/sbom_vex.go index 7730dd7..9a7c88c 100644 --- a/internal/scan/sbom_vex.go +++ b/internal/scan/sbom_vex.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/ArmisSecurity/armis-cli/internal/api" + "github.com/ArmisSecurity/armis-cli/internal/cli" "github.com/ArmisSecurity/armis-cli/internal/util" ) @@ -70,10 +71,10 @@ func (d *SBOMVEXDownloader) Download(ctx context.Context, scanID, artifactName s outputPath = filepath.Join(".armis", sanitizedName+"-sbom.json") } if err := d.downloadAndSave(ctx, sbomURL, outputPath, "SBOM"); err != nil { - _, _ = fmt.Fprintf(os.Stderr, "Warning: %v\n", err) + cli.PrintWarningf("%v", err) } } else { - _, _ = fmt.Fprintf(os.Stderr, "Warning: SBOM was requested but not available in results\n") + cli.PrintWarning("SBOM was requested but not available in results") } } @@ -86,10 +87,10 @@ func (d *SBOMVEXDownloader) Download(ctx context.Context, scanID, artifactName s outputPath = filepath.Join(".armis", sanitizedName+"-vex.json") } if err := d.downloadAndSave(ctx, vexURL, outputPath, "VEX"); err != nil { - _, _ = fmt.Fprintf(os.Stderr, "Warning: %v\n", err) + cli.PrintWarningf("%v", err) } } else { - _, _ = fmt.Fprintf(os.Stderr, "Warning: VEX was requested but not available in results\n") + cli.PrintWarning("VEX was requested but not available in results") } } diff --git a/internal/scan/status.go b/internal/scan/status.go new file mode 100644 index 0000000..8353936 --- /dev/null +++ b/internal/scan/status.go @@ -0,0 +1,61 @@ +// Package scan provides shared utilities for scanning operations. +package scan + +import ( + "fmt" + "strings" + "time" + + "github.com/ArmisSecurity/armis-cli/internal/model" +) + +// FormatScanStatus returns a human-readable message for the current scan phase. +// The inProgressMsg parameter customizes the message for the IN_PROGRESS state, +// allowing different scan types (repo, image) to show context-specific messages. +// Status values are from the ArtifactScanStatus enum in the Project-Moose API. +func FormatScanStatus(scanStatus, inProgressMsg string) string { + switch strings.ToUpper(scanStatus) { + case "INITIATED": + return "Scan initiated, preparing analysis..." + case "IN_PROGRESS": + return inProgressMsg + case "COMPLETED": + return "Scan completed, preparing results..." + case "FAILED": + return "Scan encountered an error" + case "STOPPED": + return "Scan was stopped" + default: + return fmt.Sprintf("Scanning... [%s]", strings.ToUpper(scanStatus)) + } +} + +// FormatElapsed formats a duration as a human-readable time string. +// Examples: "45s", "2m 30s" +func FormatElapsed(d time.Duration) string { + d = d.Round(time.Second) + minutes := int(d.Minutes()) + seconds := int(d.Seconds()) % 60 + if minutes > 0 { + return fmt.Sprintf("%dm %ds", minutes, seconds) + } + return fmt.Sprintf("%ds", seconds) +} + +// MapSeverity converts a string severity level to the model.Severity type. +// Recognized values (case-insensitive): CRITICAL, HIGH, MEDIUM, LOW. +// Unrecognized values default to Info severity. +func MapSeverity(toolSeverity string) model.Severity { + switch strings.ToUpper(toolSeverity) { + case "CRITICAL": + return model.SeverityCritical + case "HIGH": + return model.SeverityHigh + case "MEDIUM": + return model.SeverityMedium + case "LOW": + return model.SeverityLow + default: + return model.SeverityInfo + } +} diff --git a/internal/scan/status_test.go b/internal/scan/status_test.go new file mode 100644 index 0000000..fb827cd --- /dev/null +++ b/internal/scan/status_test.go @@ -0,0 +1,88 @@ +package scan + +import ( + "testing" + "time" + + "github.com/ArmisSecurity/armis-cli/internal/model" +) + +func TestFormatScanStatus(t *testing.T) { + tests := []struct { + name string + status string + inProgressMsg string + want string + }{ + {"initiated", "INITIATED", "Scanning...", "Scan initiated, preparing analysis..."}, + {"initiated lowercase", "initiated", "Scanning...", "Scan initiated, preparing analysis..."}, + {"in_progress", "IN_PROGRESS", "Analyzing code...", "Analyzing code..."}, + {"in_progress lowercase", "in_progress", "Scanning image...", "Scanning image..."}, + {"completed", "COMPLETED", "Scanning...", "Scan completed, preparing results..."}, + {"failed", "FAILED", "Scanning...", "Scan encountered an error"}, + {"stopped", "STOPPED", "Scanning...", "Scan was stopped"}, + {"unknown", "UNKNOWN", "Scanning...", "Scanning... [UNKNOWN]"}, + {"empty", "", "Scanning...", "Scanning... []"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FormatScanStatus(tt.status, tt.inProgressMsg) + if got != tt.want { + t.Errorf("FormatScanStatus(%q, %q) = %q, want %q", tt.status, tt.inProgressMsg, got, tt.want) + } + }) + } +} + +func TestFormatElapsed(t *testing.T) { + tests := []struct { + name string + duration time.Duration + want string + }{ + {"zero", 0, "0s"}, + {"seconds only", 45 * time.Second, "45s"}, + {"exactly one minute", 60 * time.Second, "1m 0s"}, + {"minutes and seconds", 2*time.Minute + 30*time.Second, "2m 30s"}, + {"rounds to nearest second", 45*time.Second + 500*time.Millisecond, "46s"}, + {"rounds down", 45*time.Second + 400*time.Millisecond, "45s"}, + {"large duration", 10*time.Minute + 5*time.Second, "10m 5s"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FormatElapsed(tt.duration) + if got != tt.want { + t.Errorf("FormatElapsed(%v) = %q, want %q", tt.duration, got, tt.want) + } + }) + } +} + +func TestMapSeverity(t *testing.T) { + tests := []struct { + name string + severity string + want model.Severity + }{ + {"critical", "CRITICAL", model.SeverityCritical}, + {"critical lowercase", "critical", model.SeverityCritical}, + {"high", "HIGH", model.SeverityHigh}, + {"high mixed case", "High", model.SeverityHigh}, + {"medium", "MEDIUM", model.SeverityMedium}, + {"low", "LOW", model.SeverityLow}, + {"unknown", "UNKNOWN", model.SeverityInfo}, + {"empty", "", model.SeverityInfo}, + {"info", "INFO", model.SeverityInfo}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MapSeverity(tt.severity) + if got != tt.want { + t.Errorf("MapSeverity(%q) = %v, want %v", tt.severity, got, tt.want) + } + }) + } +} diff --git a/internal/update/update.go b/internal/update/update.go new file mode 100644 index 0000000..39f4df0 --- /dev/null +++ b/internal/update/update.go @@ -0,0 +1,312 @@ +// Package update provides version update checking for the CLI. +package update + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" + + "github.com/ArmisSecurity/armis-cli/internal/util" +) + +const ( + // githubReleasesURL is the GitHub API endpoint for the latest release. + githubReleasesURL = "https://api.github.com/repos/ArmisSecurity/armis-cli/releases/latest" + + // cacheTTL is how long to cache the version check result. + cacheTTL = 24 * time.Hour + + // checkTimeout is the maximum time for a version check. + checkTimeout = 10 * time.Second + + // cacheFileName is the name of the cache file. + cacheFileName = "update-check.json" + + // cacheDirName is the subdirectory name for cache files. + cacheDirName = "armis-cli" +) + +// CheckResult holds the result of a version check. +type CheckResult struct { + LatestVersion string + CurrentVersion string +} + +// cacheFile is the on-disk JSON structure for persisting check results. +type cacheFile struct { + LatestVersion string `json:"latest_version"` + CheckedAt time.Time `json:"checked_at"` +} + +// githubRelease is the minimal structure from the GitHub releases API. +type githubRelease struct { + TagName string `json:"tag_name"` +} + +// Checker performs version update checks. +type Checker struct { + currentVersion string + githubAPIURL string + cacheTTL time.Duration + cacheDir string // for testing; empty means use os.UserCacheDir() + httpClient *http.Client +} + +// NewChecker creates a version update checker. +// currentVersion should be the semver version (e.g., "1.0.7"). +func NewChecker(currentVersion string) *Checker { + return &Checker{ + currentVersion: currentVersion, + githubAPIURL: githubReleasesURL, + cacheTTL: cacheTTL, + httpClient: &http.Client{ + Timeout: checkTimeout, + }, + } +} + +// CheckInBackground starts a non-blocking version check. +// Returns a channel that will receive at most one *CheckResult. +// The channel is closed when the check completes (or is skipped). +// If the result is nil, no update notification should be shown. +func (c *Checker) CheckInBackground(ctx context.Context) <-chan *CheckResult { + ch := make(chan *CheckResult, 1) + + // Use a short-lived context so the background check does not hold + // the process open. + checkCtx, cancel := context.WithTimeout(ctx, checkTimeout) + + go func() { + defer cancel() + defer close(ch) + result := c.check(checkCtx) + if result != nil { + ch <- result + } + }() + + return ch +} + +// check performs the actual version check (blocking). +func (c *Checker) check(ctx context.Context) *CheckResult { + // Try reading cache first + cached := c.readCache() + if cached != nil && time.Since(cached.CheckedAt) < c.cacheTTL { + // Cache is fresh -- use cached version + if IsNewer(c.currentVersion, cached.LatestVersion) { + return &CheckResult{ + LatestVersion: cached.LatestVersion, + CurrentVersion: c.currentVersion, + } + } + return nil // no update needed + } + + // Fetch from GitHub + latest, err := c.fetchLatestVersion(ctx) + if err != nil { + return nil // silently fail + } + + // Don't cache empty tags - retry on next check + if latest == "" { + return nil + } + + // Write to cache (best-effort) + c.writeCache(&cacheFile{ + LatestVersion: latest, + CheckedAt: time.Now(), + }) + + // Compare + if IsNewer(c.currentVersion, latest) { + return &CheckResult{ + LatestVersion: latest, + CurrentVersion: c.currentVersion, + } + } + + return nil +} + +// fetchLatestVersion queries the GitHub releases API. +func (c *Checker) fetchLatestVersion(ctx context.Context) (string, error) { + req, err := http.NewRequestWithContext(ctx, "GET", c.githubAPIURL, nil) + if err != nil { + return "", err + } + req.Header.Set("Accept", "application/vnd.github.v3+json") + req.Header.Set("User-Agent", "armis-cli-update-check") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("GitHub API returned %d", resp.StatusCode) + } + + // Limit body size to prevent memory issues + body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + if err != nil { + return "", err + } + + var release githubRelease + if err := json.Unmarshal(body, &release); err != nil { + return "", err + } + + return release.TagName, nil +} + +// getCacheFilePath returns the path to the cache file. +func (c *Checker) getCacheFilePath() string { + if c.cacheDir != "" { + // Validate cacheDir to prevent path traversal (CWE-73) + sanitized, err := util.SanitizePath(c.cacheDir) + if err != nil { + return "" // invalid cacheDir, disable caching + } + return filepath.Join(sanitized, cacheFileName) + } + cacheDir, err := os.UserCacheDir() + if err != nil { + return "" // no caching possible + } + return filepath.Join(cacheDir, cacheDirName, cacheFileName) +} + +// readCache attempts to read a cached check result. +// Returns nil if cache is missing or corrupt. +func (c *Checker) readCache() *cacheFile { + path := c.getCacheFilePath() + if path == "" { + return nil + } + // Validate the final path before reading to prevent path traversal (CWE-73) + sanitizedPath, err := util.SanitizePath(path) + if err != nil { + return nil + } + data, err := os.ReadFile(sanitizedPath) //nolint:gosec // path validated by SanitizePath + if err != nil { + return nil + } + var cache cacheFile + if err := json.Unmarshal(data, &cache); err != nil { + return nil + } + return &cache +} + +// writeCache persists a check result to disk. +// Errors are silently ignored. +func (c *Checker) writeCache(result *cacheFile) { + path := c.getCacheFilePath() + if path == "" { + return + } + // Validate the final path before writing to prevent path traversal (CWE-73) + sanitizedPath, err := util.SanitizePath(path) + if err != nil { + return + } + dir := filepath.Dir(sanitizedPath) + if err := os.MkdirAll(dir, 0o700); err != nil { + return + } + data, err := json.Marshal(result) + if err != nil { + return + } + _ = os.WriteFile(sanitizedPath, data, 0o600) //nolint:gosec // path validated by SanitizePath +} + +// IsNewer returns true if latest is a newer version than current. +// Versions may optionally have a "v" prefix. +func IsNewer(current, latest string) bool { + current = strings.TrimPrefix(current, "v") + latest = strings.TrimPrefix(latest, "v") + + curParts := parseVersion(current) + latParts := parseVersion(latest) + + if curParts == nil || latParts == nil { + return false + } + + for i := 0; i < 3; i++ { + if latParts[i] > curParts[i] { + return true + } + if latParts[i] < curParts[i] { + return false + } + } + return false +} + +// parseVersion returns [major, minor, patch] or nil if invalid. +func parseVersion(v string) []int { + // Strip any pre-release suffix (e.g., "-rc1") + if idx := strings.IndexByte(v, '-'); idx >= 0 { + v = v[:idx] + } + parts := strings.SplitN(v, ".", 3) + if len(parts) != 3 { + return nil + } + result := make([]int, 3) + for i, p := range parts { + n, err := strconv.Atoi(p) + if err != nil || n < 0 { + return nil + } + result[i] = n + } + return result +} + +// FormatNotification builds the user-facing notification string. +func FormatNotification(current, latest string) string { + current = strings.TrimPrefix(current, "v") + latest = strings.TrimPrefix(latest, "v") + + updateCmd := getUpdateCommand() + + msg := fmt.Sprintf( + "\nNote: armis-cli v%s is available (you have v%s)\n", + latest, current, + ) + if updateCmd != "" { + msg += fmt.Sprintf(" Run '%s' to update\n", updateCmd) + } + return msg +} + +// getUpdateCommand returns the appropriate update command for the current OS. +func getUpdateCommand() string { + switch runtime.GOOS { + case "darwin": + return "brew upgrade armis-cli" + case "linux": + return "curl -sSL https://raw.githubusercontent.com/ArmisSecurity/armis-cli/main/scripts/install.sh | bash" + case "windows": + return "irm https://raw.githubusercontent.com/ArmisSecurity/armis-cli/main/scripts/install.ps1 | iex" + default: + return "" + } +} diff --git a/internal/update/update_test.go b/internal/update/update_test.go new file mode 100644 index 0000000..d645e80 --- /dev/null +++ b/internal/update/update_test.go @@ -0,0 +1,470 @@ +package update + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +const testLatestVersion = "v1.2.0" + +func TestIsNewer(t *testing.T) { + tests := []struct { + name string + current string + latest string + expected bool + }{ + {"newer major", "1.0.0", "2.0.0", true}, + {"newer minor", "1.0.0", "1.1.0", true}, + {"newer patch", "1.0.0", "1.0.1", true}, + {"same version", "1.0.7", "1.0.7", false}, + {"older major", "2.0.0", "1.0.0", false}, + {"older minor", "1.2.0", "1.1.0", false}, + {"older patch", "1.0.2", "1.0.1", false}, + {"with v prefix current", "v1.0.0", "1.1.0", true}, + {"with v prefix latest", "1.0.0", "v1.1.0", true}, + {"with v prefix both", "v1.0.0", "v1.1.0", true}, + {"pre-release stripped", "1.0.0", "1.1.0-rc1", true}, + {"pre-release current", "1.0.0-rc1", "1.0.0", false}, + {"dev current", "dev", "1.0.0", false}, + {"invalid current", "not-a-version", "1.0.0", false}, + {"invalid latest", "1.0.0", "not-a-version", false}, + {"empty current", "", "1.0.0", false}, + {"empty latest", "1.0.0", "", false}, + {"empty both", "", "", false}, + {"two part version", "1.0", "1.0.1", false}, + {"four part version", "1.0.0.0", "1.0.1", false}, + {"negative numbers", "1.0.0", "-1.0.0", false}, + {"large numbers", "1.0.7", "1.0.100", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsNewer(tt.current, tt.latest) + if result != tt.expected { + t.Errorf("IsNewer(%q, %q) = %v, want %v", + tt.current, tt.latest, result, tt.expected) + } + }) + } +} + +func TestParseVersion(t *testing.T) { + tests := []struct { + name string + version string + expected []int + }{ + {"valid version", "1.2.3", []int{1, 2, 3}}, + {"with pre-release", "1.2.3-rc1", []int{1, 2, 3}}, + {"zeros", "0.0.0", []int{0, 0, 0}}, + {"large numbers", "10.20.30", []int{10, 20, 30}}, + {"two parts", "1.2", nil}, + {"one part", "1", nil}, + {"empty", "", nil}, + {"non-numeric", "a.b.c", nil}, + {"mixed", "1.b.3", nil}, + {"negative", "-1.0.0", nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseVersion(tt.version) + if tt.expected == nil { + if result != nil { + t.Errorf("parseVersion(%q) = %v, want nil", tt.version, result) + } + return + } + if result == nil { + t.Errorf("parseVersion(%q) = nil, want %v", tt.version, tt.expected) + return + } + for i := 0; i < 3; i++ { + if result[i] != tt.expected[i] { + t.Errorf("parseVersion(%q)[%d] = %d, want %d", + tt.version, i, result[i], tt.expected[i]) + } + } + }) + } +} + +func TestFormatNotification(t *testing.T) { + result := FormatNotification("1.0.0", "1.1.0") + + if result == "" { + t.Error("FormatNotification returned empty string") + } + + // Check that it contains the version numbers + if !strings.Contains(result, "v1.1.0") { + t.Errorf("notification should contain latest version, got: %s", result) + } + if !strings.Contains(result, "v1.0.0") { + t.Errorf("notification should contain current version, got: %s", result) + } + + // Check with v prefix + result = FormatNotification("v1.0.0", "v1.1.0") + if !strings.Contains(result, "v1.1.0") { + t.Errorf("notification should normalize version, got: %s", result) + } +} + +func TestChecker_FetchLatestVersion(t *testing.T) { + tests := []struct { + name string + responseCode int + responseBody string + expectedTag string + expectError bool + }{ + { + name: "success", + responseCode: http.StatusOK, + responseBody: `{"tag_name": "` + testLatestVersion + `"}`, + expectedTag: testLatestVersion, + expectError: false, + }, + { + name: "not found", + responseCode: http.StatusNotFound, + responseBody: `{"message": "Not Found"}`, + expectedTag: "", + expectError: true, + }, + { + name: "rate limited", + responseCode: http.StatusForbidden, + responseBody: `{"message": "API rate limit exceeded"}`, + expectedTag: "", + expectError: true, + }, + { + name: "invalid json", + responseCode: http.StatusOK, + responseBody: `not json`, + expectedTag: "", + expectError: true, + }, + { + name: "empty tag", + responseCode: http.StatusOK, + responseBody: `{"tag_name": ""}`, + expectedTag: "", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify headers + if r.Header.Get("Accept") != "application/vnd.github.v3+json" { + t.Error("missing Accept header") + } + if r.Header.Get("User-Agent") != "armis-cli-update-check" { + t.Error("missing User-Agent header") + } + + w.WriteHeader(tt.responseCode) + _, _ = w.Write([]byte(tt.responseBody)) + })) + defer server.Close() + + checker := NewChecker("1.0.0") + checker.githubAPIURL = server.URL + + tag, err := checker.fetchLatestVersion(context.Background()) + + if tt.expectError { + if err == nil { + t.Error("expected error, got nil") + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if tag != tt.expectedTag { + t.Errorf("tag = %q, want %q", tag, tt.expectedTag) + } + } + }) + } +} + +func TestChecker_CacheReadWrite(t *testing.T) { + cacheDir := t.TempDir() + + checker := NewChecker("1.0.0") + checker.cacheDir = cacheDir + + // Initially no cache + cached := checker.readCache() + if cached != nil { + t.Error("expected nil cache initially") + } + + // Write cache + now := time.Now() + checker.writeCache(&cacheFile{ + LatestVersion: "v1.2.0", + CheckedAt: now, + }) + + // Read it back + cached = checker.readCache() + if cached == nil { + t.Fatal("expected non-nil cache after write") + return + } + if cached.LatestVersion != testLatestVersion { + t.Errorf("LatestVersion = %q, want %q", cached.LatestVersion, "v1.2.0") + } + + // Verify file was created + path := filepath.Join(cacheDir, cacheFileName) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Error("cache file was not created") + } +} + +func TestChecker_CacheExpiry(t *testing.T) { + cacheDir := t.TempDir() + + // Create a mock server that returns v2.0.0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"tag_name": "v2.0.0"}`)) + })) + defer server.Close() + + checker := NewChecker("1.0.0") + checker.cacheDir = cacheDir + checker.githubAPIURL = server.URL + checker.cacheTTL = time.Hour // 1 hour TTL + + // Write an old cache entry + oldTime := time.Now().Add(-2 * time.Hour) // 2 hours ago + checker.writeCache(&cacheFile{ + LatestVersion: "v1.5.0", + CheckedAt: oldTime, + }) + + // Check should fetch fresh because cache is expired + result := checker.check(context.Background()) + if result == nil { + t.Fatal("expected non-nil result") + return + } + if result.LatestVersion != "v2.0.0" { + t.Errorf("LatestVersion = %q, want %q (should have fetched fresh)", result.LatestVersion, "v2.0.0") + } +} + +func TestChecker_CacheFresh(t *testing.T) { + cacheDir := t.TempDir() + + // Create a mock server that should NOT be called + serverCalled := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + serverCalled = true + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"tag_name": "v2.0.0"}`)) + })) + defer server.Close() + + checker := NewChecker("1.0.0") + checker.cacheDir = cacheDir + checker.githubAPIURL = server.URL + checker.cacheTTL = time.Hour + + // Write a fresh cache entry + checker.writeCache(&cacheFile{ + LatestVersion: "v1.5.0", + CheckedAt: time.Now(), // fresh + }) + + // Check should use cache + result := checker.check(context.Background()) + if result == nil { + t.Fatal("expected non-nil result") + return + } + if result.LatestVersion != "v1.5.0" { + t.Errorf("LatestVersion = %q, want %q (should have used cache)", result.LatestVersion, "v1.5.0") + } + if serverCalled { + t.Error("server should not have been called when cache is fresh") + } +} + +func TestChecker_NetworkFailure(t *testing.T) { + cacheDir := t.TempDir() + + checker := NewChecker("1.0.0") + checker.cacheDir = cacheDir + checker.githubAPIURL = "http://localhost:1" // invalid port + + // Should not panic, should return nil + result := checker.check(context.Background()) + if result != nil { + t.Error("expected nil result on network failure") + } +} + +func TestChecker_CheckInBackground(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"tag_name": "` + testLatestVersion + `"}`)) + })) + defer server.Close() + + checker := NewChecker("1.0.0") + checker.cacheDir = t.TempDir() + checker.githubAPIURL = server.URL + + ctx := context.Background() + ch := checker.CheckInBackground(ctx) + + // Should receive a result + select { + case result, ok := <-ch: + if !ok { + t.Error("channel closed without result") + } + if result == nil { + t.Error("expected non-nil result") + } else if result.LatestVersion != testLatestVersion { + t.Errorf("LatestVersion = %q, want %q", result.LatestVersion, "v1.2.0") + } + case <-time.After(5 * time.Second): + t.Error("timed out waiting for result") + } + + // Channel should be closed + select { + case _, ok := <-ch: + if ok { + t.Error("expected channel to be closed") + } + case <-time.After(time.Second): + t.Error("channel not closed") + } +} + +func TestChecker_NoUpdateNeeded(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"tag_name": "v1.0.0"}`)) + })) + defer server.Close() + + checker := NewChecker("1.0.0") // same version + checker.cacheDir = t.TempDir() + checker.githubAPIURL = server.URL + + result := checker.check(context.Background()) + if result != nil { + t.Error("expected nil result when no update is needed") + } +} + +func TestChecker_NoCacheDir(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"tag_name": "` + testLatestVersion + `"}`)) + })) + defer server.Close() + + checker := NewChecker("1.0.0") + checker.cacheDir = "/nonexistent/path/that/cannot/be/created/\x00" + checker.githubAPIURL = server.URL + + // Should still work, just without caching + result := checker.check(context.Background()) + if result == nil { + t.Error("expected non-nil result even without cache") + } +} + +func TestChecker_CorruptCache(t *testing.T) { + cacheDir := t.TempDir() + + // Write corrupt cache file + cachePath := filepath.Join(cacheDir, cacheFileName) + err := os.WriteFile(cachePath, []byte("not valid json"), 0o600) + if err != nil { + t.Fatal(err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"tag_name": "` + testLatestVersion + `"}`)) + })) + defer server.Close() + + checker := NewChecker("1.0.0") + checker.cacheDir = cacheDir + checker.githubAPIURL = server.URL + + // Should fetch fresh due to corrupt cache + result := checker.check(context.Background()) + if result == nil { + t.Fatal("expected non-nil result") + } + if result.LatestVersion != testLatestVersion { + t.Errorf("LatestVersion = %q, want %q", result.LatestVersion, "v1.2.0") + } +} + +func TestGetUpdateCommand(t *testing.T) { + // Just verify it returns something (actual value depends on runtime.GOOS) + cmd := getUpdateCommand() + // We can't easily test platform-specific behavior, but we can ensure it doesn't panic + _ = cmd +} + +func TestNewChecker(t *testing.T) { + checker := NewChecker("1.0.7") + if checker.currentVersion != "1.0.7" { + t.Errorf("currentVersion = %q, want %q", checker.currentVersion, "1.0.7") + } + if checker.githubAPIURL != githubReleasesURL { + t.Errorf("githubAPIURL = %q, want %q", checker.githubAPIURL, githubReleasesURL) + } + if checker.cacheTTL != cacheTTL { + t.Errorf("cacheTTL = %v, want %v", checker.cacheTTL, cacheTTL) + } + if checker.httpClient == nil { + t.Error("httpClient should not be nil") + } +} + +func TestCacheFileJSON(t *testing.T) { + // Test that cacheFile serializes/deserializes correctly + original := &cacheFile{ + LatestVersion: "v1.2.3", + CheckedAt: time.Now().UTC().Truncate(time.Second), + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded cacheFile + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if decoded.LatestVersion != original.LatestVersion { + t.Errorf("LatestVersion = %q, want %q", decoded.LatestVersion, original.LatestVersion) + } +}