diff --git a/cmd/gha.go b/cmd/gha.go index 3d146a1..f8aacf6 100644 --- a/cmd/gha.go +++ b/cmd/gha.go @@ -1,7 +1,6 @@ package cmd import ( - "context" "encoding/json" "fmt" "os" @@ -112,7 +111,7 @@ func runGHAUpload(cmd *cobra.Command, args []string) error { // List all artifacts fmt.Println("📦 Fetching workflow artifacts...") - ctx := context.Background() + ctx := cmd.Context() artifacts, err := collector.ListArtifacts(ctx) if err != nil { return fmt.Errorf("failed to list artifacts: %w", err) @@ -155,32 +154,34 @@ func runGHAUpload(cmd *cobra.Command, args []string) error { uploadResults := make([]map[string]string, 0, len(artifacts)) for i, artifact := range artifacts { - fmt.Printf(" [%d/%d] Uploading %s...\n", i+1, len(artifacts), artifact.Name) - - // Download and extract artifact - artifactDir, err := collector.DownloadArtifact(ctx, artifact) - if err != nil { - fmt.Printf(" ❌ Failed to download: %v\n", err) - continue - } - defer os.RemoveAll(artifactDir) + func() { + fmt.Printf(" [%d/%d] Uploading %s...\n", i+1, len(artifacts), artifact.Name) + + // Download and extract artifact + artifactDir, err := collector.DownloadArtifact(ctx, artifact) + if err != nil { + fmt.Printf(" ❌ Failed to download: %v\n", err) + return + } + defer os.RemoveAll(artifactDir) - // Upload to Vulnetix - uploadResp, err := uploader.UploadArtifact(txnResp.TxnID, artifact.Name, artifactDir) - if err != nil { - fmt.Printf(" ❌ Failed to upload: %v\n", err) - continue - } + // Upload to Vulnetix + uploadResp, err := uploader.UploadArtifact(txnResp.TxnID, artifact.Name, artifactDir) + if err != nil { + fmt.Printf(" ❌ Failed to upload: %v\n", err) + return + } - fmt.Printf(" ✅ Uploaded successfully\n") - fmt.Printf(" UUID: %s\n", uploadResp.UUID) - fmt.Printf(" Queue Path: %s\n", uploadResp.QueuePath) + fmt.Printf(" ✅ Uploaded successfully\n") + fmt.Printf(" UUID: %s\n", uploadResp.UUID) + fmt.Printf(" Queue Path: %s\n", uploadResp.QueuePath) - uploadResults = append(uploadResults, map[string]string{ - "name": artifact.Name, - "uuid": uploadResp.UUID, - "queue_path": uploadResp.QueuePath, - }) + uploadResults = append(uploadResults, map[string]string{ + "name": artifact.Name, + "uuid": uploadResp.UUID, + "queue_path": uploadResp.QueuePath, + }) + }() } fmt.Println() @@ -194,10 +195,13 @@ func runGHAUpload(cmd *cobra.Command, args []string) error { // Output JSON if requested if ghaOutputJSON { output := map[string]interface{}{ - "txnid": txnResp.TxnID, + "txnid": txnResp.TxnID, "artifacts": uploadResults, } - jsonData, _ := json.MarshalIndent(output, "", " ") + jsonData, err := json.MarshalIndent(output, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal JSON output: %w", err) + } fmt.Println() fmt.Println(string(jsonData)) } diff --git a/internal/github/artifact.go b/internal/github/artifact.go index b9be726..e4a7e25 100644 --- a/internal/github/artifact.go +++ b/internal/github/artifact.go @@ -9,9 +9,23 @@ import ( "net/http" "os" "path/filepath" + "regexp" + "strings" "time" ) +const ( + // maxArtifactSize is the maximum size for an artifact download (1GB) + maxArtifactSize = 1024 * 1024 * 1024 + // artifactDownloadTimeout is the timeout for downloading artifacts + artifactDownloadTimeout = 10 * time.Minute +) + +var ( + // artifactNameRegex matches safe characters for artifact names + artifactNameRegex = regexp.MustCompile(`[^a-zA-Z0-9\-_\.]`) +) + // Artifact represents a GitHub Actions artifact type Artifact struct { ID int64 `json:"id"` @@ -68,11 +82,20 @@ func NewArtifactCollector(token, apiURL, repository, runID string) *ArtifactColl repository: repository, runID: runID, client: &http.Client{ - Timeout: 60 * time.Second, + Timeout: artifactDownloadTimeout, }, } } +// sanitizeArtifactName sanitizes artifact names to prevent path traversal +func sanitizeArtifactName(name string) string { + // Remove any path separators and special characters + sanitized := artifactNameRegex.ReplaceAllString(name, "_") + // Remove leading dots to prevent hidden files + sanitized = strings.TrimLeft(sanitized, ".") + return sanitized +} + // CollectMetadata collects metadata from GitHub Actions environment func CollectMetadata(artifactNames []string) *ArtifactMetadata { // Collect standard GitHub Actions environment variables @@ -162,8 +185,19 @@ func (c *ArtifactCollector) DownloadArtifact(ctx context.Context, artifact Artif return "", fmt.Errorf("GitHub token is required") } + // Check artifact size + if artifact.SizeInBytes > maxArtifactSize { + return "", fmt.Errorf("artifact size (%d bytes) exceeds maximum allowed size (%d bytes)", artifact.SizeInBytes, maxArtifactSize) + } + + // Sanitize artifact name for directory pattern + sanitizedName := sanitizeArtifactName(artifact.Name) + if sanitizedName == "" { + sanitizedName = "artifact" + } + // Create temporary directory for extraction - tmpDir, err := os.MkdirTemp("", fmt.Sprintf("artifact-%s-*", artifact.Name)) + tmpDir, err := os.MkdirTemp("", fmt.Sprintf("artifact-%s-*", sanitizedName)) if err != nil { return "", fmt.Errorf("failed to create temp directory: %w", err) } @@ -199,7 +233,9 @@ func (c *ArtifactCollector) DownloadArtifact(ctx context.Context, artifact Artif return "", fmt.Errorf("failed to create zip file: %w", err) } - _, err = io.Copy(zipFile, resp.Body) + // Limit the reader to prevent resource exhaustion + limitedReader := io.LimitReader(resp.Body, maxArtifactSize) + _, err = io.Copy(zipFile, limitedReader) zipFile.Close() if err != nil { os.RemoveAll(tmpDir) @@ -226,19 +262,39 @@ func extractZip(zipPath, destDir string) error { } defer reader.Close() + // Clean the destination directory path for comparison + destDir = filepath.Clean(destDir) + for _, file := range reader.File { + // Prevent Zip Slip vulnerability by checking for path traversal + if strings.Contains(file.Name, "..") { + return fmt.Errorf("zip file contains potentially unsafe path: %s", file.Name) + } + + // Join and clean the path path := filepath.Join(destDir, file.Name) + path = filepath.Clean(path) + + // Verify the resulting path is within destDir + if !strings.HasPrefix(path, destDir+string(os.PathSeparator)) && path != destDir { + return fmt.Errorf("zip file contains entry outside destination directory: %s", file.Name) + } if file.FileInfo().IsDir() { - os.MkdirAll(path, file.Mode()) + // Use safe permissions for directories + if err := os.MkdirAll(path, 0755); err != nil { + return err + } continue } + // Create parent directory with safe permissions if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { return err } - destFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.Mode()) + // Use safe permissions for files (0644 for regular files) + destFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) if err != nil { return err } @@ -249,12 +305,24 @@ func extractZip(zipPath, destDir string) error { return err } - _, err = io.Copy(destFile, fileReader) - destFile.Close() - fileReader.Close() - - if err != nil { - return err + // Copy with proper error handling + _, copyErr := io.Copy(destFile, fileReader) + + // Close file reader first + readerCloseErr := fileReader.Close() + + // Close destination file and check for close errors + destCloseErr := destFile.Close() + + // Check all errors in order + if copyErr != nil { + return copyErr + } + if readerCloseErr != nil { + return readerCloseErr + } + if destCloseErr != nil { + return destCloseErr } } diff --git a/internal/github/artifact_test.go b/internal/github/artifact_test.go index 281566e..e54d96d 100644 --- a/internal/github/artifact_test.go +++ b/internal/github/artifact_test.go @@ -1,8 +1,15 @@ package github import ( + "archive/zip" + "context" + "net/http" + "net/http/httptest" "os" + "path/filepath" + "strings" "testing" + "time" ) func TestCollectMetadata(t *testing.T) { @@ -27,8 +34,7 @@ func TestCollectMetadata(t *testing.T) { // Set environment variables for key, value := range testEnvVars { - os.Setenv(key, value) - defer os.Unsetenv(key) + t.Setenv(key, value) } artifactNames := []string{"artifact1.zip", "artifact2.zip"} @@ -131,3 +137,295 @@ func TestNewArtifactCollector(t *testing.T) { t.Error("Expected client to be initialized") } } + +func TestSanitizeArtifactName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"normal name", "my-artifact", "my-artifact"}, + {"with spaces", "my artifact", "my_artifact"}, + {"with path separator", "path/to/artifact", "path_to_artifact"}, + {"with dots", "../artifact", "_artifact"}, + {"with special chars", "artifact@#$%", "artifact____"}, + {"empty after sanitization", "...", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := sanitizeArtifactName(tt.input) + if result != tt.expected { + t.Errorf("sanitizeArtifactName(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestListArtifacts(t *testing.T) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check auth header + if r.Header.Get("Authorization") != "Bearer test-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Return mock artifact list + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "total_count": 2, + "artifacts": [ + { + "id": 1, + "name": "artifact1", + "size_in_bytes": 1024, + "url": "https://api.github.com/repos/test/repo/actions/artifacts/1", + "archive_download_url": "https://api.github.com/repos/test/repo/actions/artifacts/1/zip" + }, + { + "id": 2, + "name": "artifact2", + "size_in_bytes": 2048, + "url": "https://api.github.com/repos/test/repo/actions/artifacts/2", + "archive_download_url": "https://api.github.com/repos/test/repo/actions/artifacts/2/zip" + } + ] + }`)) + })) + defer server.Close() + + collector := NewArtifactCollector("test-token", server.URL, "test/repo", "123") + + ctx := context.Background() + artifacts, err := collector.ListArtifacts(ctx) + if err != nil { + t.Fatalf("ListArtifacts failed: %v", err) + } + + if len(artifacts) != 2 { + t.Errorf("Expected 2 artifacts, got %d", len(artifacts)) + } + + if artifacts[0].Name != "artifact1" { + t.Errorf("Expected artifact name 'artifact1', got '%s'", artifacts[0].Name) + } + + if artifacts[1].SizeInBytes != 2048 { + t.Errorf("Expected artifact size 2048, got %d", artifacts[1].SizeInBytes) + } +} + +func TestListArtifacts_NoToken(t *testing.T) { + collector := NewArtifactCollector("", "https://api.github.com", "test/repo", "123") + + ctx := context.Background() + _, err := collector.ListArtifacts(ctx) + if err == nil { + t.Error("Expected error when token is missing, got nil") + } + + if !strings.Contains(err.Error(), "GitHub token is required") { + t.Errorf("Expected 'GitHub token is required' error, got: %v", err) + } +} + +func TestExtractZip_ZipSlipProtection(t *testing.T) { + // Create a malicious zip file with path traversal + tmpDir := t.TempDir() + zipPath := filepath.Join(tmpDir, "malicious.zip") + + // Create a zip with path traversal attempt + zipFile, err := os.Create(zipPath) + if err != nil { + t.Fatalf("Failed to create zip file: %v", err) + } + + w := zip.NewWriter(zipFile) + + // Try to create entry with ".." in path + _, err = w.Create("../../etc/passwd") + if err != nil { + zipFile.Close() + t.Fatalf("Failed to create zip entry: %v", err) + } + + w.Close() + zipFile.Close() + + // Attempt to extract + destDir := filepath.Join(tmpDir, "extracted") + err = extractZip(zipPath, destDir) + + // Should fail due to path traversal protection + if err == nil { + t.Error("Expected error for zip slip attempt, got nil") + } + + if !strings.Contains(err.Error(), "unsafe path") && !strings.Contains(err.Error(), "outside destination") { + t.Errorf("Expected path traversal error, got: %v", err) + } +} + +func TestExtractZip_ValidZip(t *testing.T) { + tmpDir := t.TempDir() + zipPath := filepath.Join(tmpDir, "valid.zip") + + // Create a valid zip file + zipFile, err := os.Create(zipPath) + if err != nil { + t.Fatalf("Failed to create zip file: %v", err) + } + + w := zip.NewWriter(zipFile) + + // Add a test file + fileWriter, err := w.Create("test.txt") + if err != nil { + zipFile.Close() + t.Fatalf("Failed to create zip entry: %v", err) + } + + _, err = fileWriter.Write([]byte("test content")) + if err != nil { + zipFile.Close() + t.Fatalf("Failed to write zip entry: %v", err) + } + + w.Close() + zipFile.Close() + + // Extract + destDir := filepath.Join(tmpDir, "extracted") + err = extractZip(zipPath, destDir) + if err != nil { + t.Fatalf("extractZip failed: %v", err) + } + + // Verify extracted file + extractedFile := filepath.Join(destDir, "test.txt") + content, err := os.ReadFile(extractedFile) + if err != nil { + t.Fatalf("Failed to read extracted file: %v", err) + } + + if string(content) != "test content" { + t.Errorf("Expected 'test content', got '%s'", string(content)) + } + + // Check file permissions are safe + info, err := os.Stat(extractedFile) + if err != nil { + t.Fatalf("Failed to stat extracted file: %v", err) + } + + mode := info.Mode() + if mode != 0644 { + t.Errorf("Expected file mode 0644, got %v", mode) + } +} + +func TestDownloadArtifact_SizeLimit(t *testing.T) { + // Create a test server that returns artifact data + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/zip") + w.WriteHeader(http.StatusOK) + // Write some data + w.Write(make([]byte, 1024)) + })) + defer server.Close() + + collector := NewArtifactCollector("test-token", server.URL, "test/repo", "123") + + // Create artifact with size exceeding limit + artifact := Artifact{ + ID: 1, + Name: "large-artifact", + SizeInBytes: maxArtifactSize + 1, + ArchiveDownloadURL: server.URL, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + ExpiresAt: time.Now().Add(24 * time.Hour), + } + + ctx := context.Background() + _, err := collector.DownloadArtifact(ctx, artifact) + + if err == nil { + t.Error("Expected error for artifact exceeding size limit, got nil") + } + + if !strings.Contains(err.Error(), "exceeds maximum allowed size") { + t.Errorf("Expected size limit error, got: %v", err) + } +} + +func TestDownloadArtifact_Success(t *testing.T) { + tmpDir := t.TempDir() + + // Create a valid zip file to serve + zipPath := filepath.Join(tmpDir, "artifact.zip") + zipFile, err := os.Create(zipPath) + if err != nil { + t.Fatalf("Failed to create zip file: %v", err) + } + + w := zip.NewWriter(zipFile) + fileWriter, err := w.Create("test.txt") + if err != nil { + zipFile.Close() + t.Fatalf("Failed to create zip entry: %v", err) + } + fileWriter.Write([]byte("test content")) + w.Close() + zipFile.Close() + + // Read the zip file + zipData, err := os.ReadFile(zipPath) + if err != nil { + t.Fatalf("Failed to read zip file: %v", err) + } + + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer test-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.Header().Set("Content-Type", "application/zip") + w.WriteHeader(http.StatusOK) + w.Write(zipData) + })) + defer server.Close() + + collector := NewArtifactCollector("test-token", server.URL, "test/repo", "123") + + artifact := Artifact{ + ID: 1, + Name: "test-artifact", + SizeInBytes: int64(len(zipData)), + ArchiveDownloadURL: server.URL, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + ExpiresAt: time.Now().Add(24 * time.Hour), + } + + ctx := context.Background() + extractDir, err := collector.DownloadArtifact(ctx, artifact) + if err != nil { + t.Fatalf("DownloadArtifact failed: %v", err) + } + defer os.RemoveAll(extractDir) + + // Verify extracted file + extractedFile := filepath.Join(extractDir, "test.txt") + content, err := os.ReadFile(extractedFile) + if err != nil { + t.Fatalf("Failed to read extracted file: %v", err) + } + + if string(content) != "test content" { + t.Errorf("Expected 'test content', got '%s'", string(content)) + } +} diff --git a/internal/github/uploader.go b/internal/github/uploader.go index c4a97db..0d4f656 100644 --- a/internal/github/uploader.go +++ b/internal/github/uploader.go @@ -9,9 +9,15 @@ import ( "net/http" "os" "path/filepath" + "regexp" "time" ) +var ( + // txnIDRegex validates transaction ID format + txnIDRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) +) + // TransactionRequest represents the initial transaction creation request type TransactionRequest struct { Meta *ArtifactMetadata `json:"_meta"` @@ -55,20 +61,50 @@ type ArtifactStatusDetail struct { type ArtifactUploader struct { baseURL string orgID string + apiKey string client *http.Client } // NewArtifactUploader creates a new artifact uploader func NewArtifactUploader(baseURL, orgID string) *ArtifactUploader { + // Try to get API key from environment + apiKey := os.Getenv("VULNETIX_API_KEY") + return &ArtifactUploader{ baseURL: baseURL, orgID: orgID, + apiKey: apiKey, client: &http.Client{ Timeout: 120 * time.Second, }, } } +// validateTxnID validates transaction ID format +func validateTxnID(txnID string) error { + if txnID == "" { + return fmt.Errorf("transaction ID cannot be empty") + } + + // Transaction ID should be alphanumeric with hyphens and underscores + if !txnIDRegex.MatchString(txnID) { + return fmt.Errorf("invalid transaction ID format: must contain only alphanumeric characters, hyphens, and underscores") + } + + return nil +} + +// addAuthHeaders adds authentication headers to the request +func (u *ArtifactUploader) addAuthHeaders(req *http.Request) { + req.Header.Set("User-Agent", "Vulnetix-CLI/1.0") + + // Add API key if available + if u.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+u.apiKey) + req.Header.Set("X-API-Key", u.apiKey) + } +} + // InitiateTransaction initiates a new artifact upload transaction func (u *ArtifactUploader) InitiateTransaction(metadata *ArtifactMetadata, artifactNames []string) (*TransactionResponse, error) { url := fmt.Sprintf("%s/%s/github/artifact-upload", u.baseURL, u.orgID) @@ -89,7 +125,7 @@ func (u *ArtifactUploader) InitiateTransaction(metadata *ArtifactMetadata, artif } req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", "Vulnetix-CLI/1.0") + u.addAuthHeaders(req) resp, err := u.client.Do(req) if err != nil { @@ -120,6 +156,11 @@ func (u *ArtifactUploader) InitiateTransaction(metadata *ArtifactMetadata, artif // UploadArtifact uploads a single artifact file to the specified transaction func (u *ArtifactUploader) UploadArtifact(txnID, artifactName, artifactDir string) (*ArtifactUploadResponse, error) { + // Validate transaction ID + if err := validateTxnID(txnID); err != nil { + return nil, fmt.Errorf("invalid transaction ID: %w", err) + } + url := fmt.Sprintf("%s/%s/github/artifact-upload/%s", u.baseURL, u.orgID, txnID) // Find all files in the artifact directory @@ -184,7 +225,7 @@ func (u *ArtifactUploader) UploadArtifact(txnID, artifactName, artifactDir strin } req.Header.Set("Content-Type", contentType) - req.Header.Set("User-Agent", "Vulnetix-CLI/1.0") + u.addAuthHeaders(req) resp, err := u.client.Do(req) if err != nil { @@ -215,6 +256,11 @@ func (u *ArtifactUploader) UploadArtifact(txnID, artifactName, artifactDir strin // GetTransactionStatus retrieves the status of a transaction func (u *ArtifactUploader) GetTransactionStatus(txnID string) (*StatusResponse, error) { + // Validate transaction ID + if err := validateTxnID(txnID); err != nil { + return nil, fmt.Errorf("invalid transaction ID: %w", err) + } + url := fmt.Sprintf("%s/%s/github/artifact-upload/%s/status", u.baseURL, u.orgID, txnID) req, err := http.NewRequest("GET", url, nil) @@ -222,7 +268,7 @@ func (u *ArtifactUploader) GetTransactionStatus(txnID string) (*StatusResponse, return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("User-Agent", "Vulnetix-CLI/1.0") + u.addAuthHeaders(req) resp, err := u.client.Do(req) if err != nil { @@ -249,6 +295,10 @@ func (u *ArtifactUploader) GetTransactionStatus(txnID string) (*StatusResponse, // GetArtifactStatus retrieves the status of a specific artifact by UUID func (u *ArtifactUploader) GetArtifactStatus(artifactUUID string) (*StatusResponse, error) { + if artifactUUID == "" { + return nil, fmt.Errorf("artifact UUID cannot be empty") + } + url := fmt.Sprintf("%s/%s/github/artifact/%s/status", u.baseURL, u.orgID, artifactUUID) req, err := http.NewRequest("GET", url, nil) @@ -256,7 +306,7 @@ func (u *ArtifactUploader) GetArtifactStatus(artifactUUID string) (*StatusRespon return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("User-Agent", "Vulnetix-CLI/1.0") + u.addAuthHeaders(req) resp, err := u.client.Do(req) if err != nil { diff --git a/internal/github/uploader_test.go b/internal/github/uploader_test.go index ec5b91e..a2e609a 100644 --- a/internal/github/uploader_test.go +++ b/internal/github/uploader_test.go @@ -1,6 +1,13 @@ package github import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" "testing" ) @@ -75,3 +82,321 @@ func TestArtifactStatusDetail(t *testing.T) { t.Errorf("Expected queue path '/queue/path', got '%s'", status.QueuePath) } } + +func TestValidateTxnID(t *testing.T) { + tests := []struct { + name string + txnID string + expectErr bool + }{ + {"valid alphanumeric", "abc123", false}, + {"valid with hyphens", "abc-123-def", false}, + {"valid with underscores", "abc_123_def", false}, + {"empty string", "", true}, + {"with special chars", "abc@123", true}, + {"with spaces", "abc 123", true}, + {"with path separator", "abc/123", true}, + {"with dots", "abc.123", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateTxnID(tt.txnID) + if tt.expectErr && err == nil { + t.Errorf("Expected error for txnID '%s', got nil", tt.txnID) + } + if !tt.expectErr && err != nil { + t.Errorf("Expected no error for txnID '%s', got: %v", tt.txnID, err) + } + }) + } +} + +func TestNewArtifactUploader_WithAPIKey(t *testing.T) { + // Set API key environment variable + t.Setenv("VULNETIX_API_KEY", "test-api-key") + + uploader := NewArtifactUploader("https://api.vulnetix.com", "test-org-id") + + if uploader.apiKey != "test-api-key" { + t.Errorf("Expected API key 'test-api-key', got '%s'", uploader.apiKey) + } +} + +func TestInitiateTransaction(t *testing.T) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method + if r.Method != "POST" { + t.Errorf("Expected POST request, got %s", r.Method) + } + + // Verify content type + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) + } + + // Read and verify request body + body, _ := io.ReadAll(r.Body) + var req TransactionRequest + if err := json.Unmarshal(body, &req); err != nil { + t.Errorf("Failed to unmarshal request: %v", err) + } + + // Return mock response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + resp := TransactionResponse{ + TxnID: "test-txn-123", + Success: true, + Message: "Transaction initiated", + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + uploader := &ArtifactUploader{ + baseURL: server.URL, + orgID: "test-org", + client: &http.Client{}, + } + + metadata := &ArtifactMetadata{ + Repository: "test/repo", + RunID: "123", + Artifacts: []string{"artifact1"}, + } + + txnResp, err := uploader.InitiateTransaction(metadata, []string{"artifact1"}) + if err != nil { + t.Fatalf("InitiateTransaction failed: %v", err) + } + + if txnResp.TxnID != "test-txn-123" { + t.Errorf("Expected TxnID 'test-txn-123', got '%s'", txnResp.TxnID) + } + + if !txnResp.Success { + t.Error("Expected Success to be true") + } +} + +func TestInitiateTransaction_WithAuth(t *testing.T) { + apiKey := "test-api-key" + + // Create a test server that checks auth + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify auth headers + if r.Header.Get("Authorization") != "Bearer "+apiKey { + t.Errorf("Expected Authorization header 'Bearer %s', got '%s'", apiKey, r.Header.Get("Authorization")) + } + if r.Header.Get("X-API-Key") != apiKey { + t.Errorf("Expected X-API-Key header '%s', got '%s'", apiKey, r.Header.Get("X-API-Key")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + resp := TransactionResponse{ + TxnID: "test-txn-123", + Success: true, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + uploader := &ArtifactUploader{ + baseURL: server.URL, + orgID: "test-org", + apiKey: apiKey, + client: &http.Client{}, + } + + metadata := &ArtifactMetadata{ + Repository: "test/repo", + RunID: "123", + } + + _, err := uploader.InitiateTransaction(metadata, []string{"artifact1"}) + if err != nil { + t.Fatalf("InitiateTransaction failed: %v", err) + } +} + +func TestUploadArtifact(t *testing.T) { + // Create temporary artifact directory with test files + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + if err := os.WriteFile(testFile, []byte("test content"), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method + if r.Method != "POST" { + t.Errorf("Expected POST request, got %s", r.Method) + } + + // Verify content type is multipart + contentType := r.Header.Get("Content-Type") + if !strings.HasPrefix(contentType, "multipart/form-data") { + t.Errorf("Expected multipart/form-data content type, got %s", contentType) + } + + // Return mock response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + resp := ArtifactUploadResponse{ + UUID: "artifact-uuid-123", + QueuePath: "/queue/path", + Success: true, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + uploader := &ArtifactUploader{ + baseURL: server.URL, + orgID: "test-org", + client: &http.Client{}, + } + + uploadResp, err := uploader.UploadArtifact("test-txn-123", "test-artifact", tmpDir) + if err != nil { + t.Fatalf("UploadArtifact failed: %v", err) + } + + if uploadResp.UUID != "artifact-uuid-123" { + t.Errorf("Expected UUID 'artifact-uuid-123', got '%s'", uploadResp.UUID) + } + + if !uploadResp.Success { + t.Error("Expected Success to be true") + } +} + +func TestUploadArtifact_InvalidTxnID(t *testing.T) { + tmpDir := t.TempDir() + + uploader := &ArtifactUploader{ + baseURL: "https://api.vulnetix.com", + orgID: "test-org", + client: &http.Client{}, + } + + _, err := uploader.UploadArtifact("invalid/txnid", "test-artifact", tmpDir) + if err == nil { + t.Error("Expected error for invalid transaction ID, got nil") + } + + if !strings.Contains(err.Error(), "invalid transaction ID") { + t.Errorf("Expected invalid transaction ID error, got: %v", err) + } +} + +func TestGetTransactionStatus(t *testing.T) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method + if r.Method != "GET" { + t.Errorf("Expected GET request, got %s", r.Method) + } + + // Verify URL path + expectedPath := "/test-org/github/artifact-upload/test-txn-123/status" + if r.URL.Path != expectedPath { + t.Errorf("Expected path %s, got %s", expectedPath, r.URL.Path) + } + + // Return mock response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + resp := StatusResponse{ + Status: "completed", + TxnID: "test-txn-123", + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + uploader := &ArtifactUploader{ + baseURL: server.URL, + orgID: "test-org", + client: &http.Client{}, + } + + statusResp, err := uploader.GetTransactionStatus("test-txn-123") + if err != nil { + t.Fatalf("GetTransactionStatus failed: %v", err) + } + + if statusResp.Status != "completed" { + t.Errorf("Expected status 'completed', got '%s'", statusResp.Status) + } + + if statusResp.TxnID != "test-txn-123" { + t.Errorf("Expected TxnID 'test-txn-123', got '%s'", statusResp.TxnID) + } +} + +func TestGetArtifactStatus(t *testing.T) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method + if r.Method != "GET" { + t.Errorf("Expected GET request, got %s", r.Method) + } + + // Return mock response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + resp := StatusResponse{ + Status: "processing", + Artifacts: []ArtifactStatusDetail{ + { + UUID: "artifact-uuid-123", + Name: "test-artifact", + Status: "processing", + }, + }, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + uploader := &ArtifactUploader{ + baseURL: server.URL, + orgID: "test-org", + client: &http.Client{}, + } + + statusResp, err := uploader.GetArtifactStatus("artifact-uuid-123") + if err != nil { + t.Fatalf("GetArtifactStatus failed: %v", err) + } + + if statusResp.Status != "processing" { + t.Errorf("Expected status 'processing', got '%s'", statusResp.Status) + } + + if len(statusResp.Artifacts) != 1 { + t.Errorf("Expected 1 artifact, got %d", len(statusResp.Artifacts)) + } +} + +func TestGetArtifactStatus_EmptyUUID(t *testing.T) { + uploader := &ArtifactUploader{ + baseURL: "https://api.vulnetix.com", + orgID: "test-org", + client: &http.Client{}, + } + + _, err := uploader.GetArtifactStatus("") + if err == nil { + t.Error("Expected error for empty UUID, got nil") + } + + if !strings.Contains(err.Error(), "UUID cannot be empty") { + t.Errorf("Expected empty UUID error, got: %v", err) + } +}