diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index 2980b0ce..c7253f26 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -14,6 +14,8 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/delete" "github.com/brevdev/brev-cli/pkg/cmd/envvars" "github.com/brevdev/brev-cli/pkg/cmd/fu" + "github.com/brevdev/brev-cli/pkg/cmd/gpucreate" + "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" "github.com/brevdev/brev-cli/pkg/cmd/healthcheck" "github.com/brevdev/brev-cli/pkg/cmd/hello" "github.com/brevdev/brev-cli/pkg/cmd/importideconfig" @@ -270,6 +272,8 @@ func createCmdTree(cmd *cobra.Command, t *terminal.Terminal, loginCmdStore *stor } cmd.AddCommand(workspacegroups.NewCmdWorkspaceGroups(t, loginCmdStore)) cmd.AddCommand(scale.NewCmdScale(t, noLoginCmdStore)) + cmd.AddCommand(gpusearch.NewCmdGPUSearch(t, noLoginCmdStore)) + cmd.AddCommand(gpucreate.NewCmdGPUCreate(t, loginCmdStore)) cmd.AddCommand(configureenvvars.NewCmdConfigureEnvVars(t, loginCmdStore)) cmd.AddCommand(importideconfig.NewCmdImportIDEConfig(t, noLoginCmdStore)) cmd.AddCommand(shell.NewCmdShell(t, loginCmdStore, noLoginCmdStore)) diff --git a/pkg/cmd/gpucreate/gpucreate.go b/pkg/cmd/gpucreate/gpucreate.go new file mode 100644 index 00000000..f6c5e14d --- /dev/null +++ b/pkg/cmd/gpucreate/gpucreate.go @@ -0,0 +1,471 @@ +// Package gpucreate provides a command to create GPU instances with retry logic +package gpucreate + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + "sync" + "time" + + "github.com/brevdev/brev-cli/pkg/cmd/util" + "github.com/brevdev/brev-cli/pkg/config" + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/featureflag" + "github.com/brevdev/brev-cli/pkg/store" + "github.com/brevdev/brev-cli/pkg/terminal" + "github.com/spf13/cobra" +) + +var ( + long = `Create GPU instances with automatic retry across multiple instance types. + +This command attempts to create GPU instances, trying different instance types +until the desired number of instances are successfully created. Instance types +can be specified directly or piped from 'brev gpus'. + +The command will: +1. Try to create instances using the provided instance types (in order) +2. Continue until the desired count is reached +3. Optionally try multiple instance types in parallel +4. 
Clean up any extra instances that were created beyond the requested count` + + example = ` + # Create a single instance with a specific GPU type + brev gpu-create --name my-instance --type g5.xlarge + + # Pipe instance types from brev gpus (tries each type until one succeeds) + brev gpus --min-vram 24 | brev gpu-create --name my-instance + + # Create 3 instances, trying types in parallel + brev gpus --gpu-name A100 | brev gpu-create --name my-cluster --count 3 --parallel 5 + + # Try multiple specific types in order + brev gpu-create --name my-instance --type g5.xlarge,g5.2xlarge,g4dn.xlarge +` +) + +// GPUCreateStore defines the interface for GPU create operations +type GPUCreateStore interface { + util.GetWorkspaceByNameOrIDErrStore + GetActiveOrganizationOrDefault() (*entity.Organization, error) + GetCurrentUser() (*entity.User, error) + GetWorkspace(workspaceID string) (*entity.Workspace, error) + CreateWorkspace(organizationID string, options *store.CreateWorkspacesOptions) (*entity.Workspace, error) + DeleteWorkspace(workspaceID string) (*entity.Workspace, error) +} + +// CreateResult holds the result of a workspace creation attempt +type CreateResult struct { + Workspace *entity.Workspace + InstanceType string + Error error +} + +// NewCmdGPUCreate creates the gpu-create command +func NewCmdGPUCreate(t *terminal.Terminal, gpuCreateStore GPUCreateStore) *cobra.Command { + var name string + var instanceTypes string + var count int + var parallel int + var detached bool + var timeout int + + cmd := &cobra.Command{ + Annotations: map[string]string{"workspace": ""}, + Use: "gpu-create", + Aliases: []string{"gpu-retry", "gcreate", "provision"}, + DisableFlagsInUseLine: true, + Short: "Create GPU instances with automatic retry", + Long: long, + Example: example, + RunE: func(cmd *cobra.Command, args []string) error { + // Parse instance types from flag or stdin + types, err := parseInstanceTypes(instanceTypes) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if len(types) == 0 { + return breverrors.NewValidationError("no instance types provided. 
Use --type flag or pipe from 'brev gpus'") + } + + if name == "" { + return breverrors.NewValidationError("--name flag is required") + } + + if count < 1 { + return breverrors.NewValidationError("--count must be at least 1") + } + + if parallel < 1 { + parallel = 1 + } + + opts := GPUCreateOptions{ + Name: name, + InstanceTypes: types, + Count: count, + Parallel: parallel, + Detached: detached, + Timeout: time.Duration(timeout) * time.Second, + } + + err = RunGPUCreate(t, gpuCreateStore, opts) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil + }, + } + + cmd.Flags().StringVarP(&name, "name", "n", "", "Base name for the instances (required)") + cmd.Flags().StringVarP(&instanceTypes, "type", "t", "", "Comma-separated list of instance types to try") + cmd.Flags().IntVarP(&count, "count", "c", 1, "Number of instances to create") + cmd.Flags().IntVarP(&parallel, "parallel", "p", 1, "Number of parallel creation attempts") + cmd.Flags().BoolVarP(&detached, "detached", "d", false, "Don't wait for instances to be ready") + cmd.Flags().IntVar(&timeout, "timeout", 300, "Timeout in seconds for each instance to become ready") + + return cmd +} + +// GPUCreateOptions holds the options for GPU instance creation +type GPUCreateOptions struct { + Name string + InstanceTypes []string + Count int + Parallel int + Detached bool + Timeout time.Duration +} + +// parseInstanceTypes parses instance types from flag value or stdin +func parseInstanceTypes(flagValue string) ([]string, error) { + var types []string + + // First check if there's a flag value + if flagValue != "" { + parts := strings.Split(flagValue, ",") + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + types = append(types, p) + } + } + } + + // Check if there's piped input from stdin + stat, _ := os.Stdin.Stat() + if (stat.Mode() & os.ModeCharDevice) == 0 { + // Data is being piped to stdin + scanner := bufio.NewScanner(os.Stdin) + lineNum := 0 + for scanner.Scan() { + line := scanner.Text() + lineNum++ + + // Skip header line (first line typically contains column names) + if lineNum == 1 && (strings.Contains(line, "TYPE") || strings.Contains(line, "GPU")) { + continue + } + + // Skip empty lines + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // Skip summary lines (e.g., "Found X GPU instance types") + if strings.HasPrefix(line, "Found ") { + continue + } + + // Extract the first column (TYPE) from the table output + // The format is: TYPE GPU COUNT VRAM/GPU TOTAL VRAM CAPABILITY VCPUs $/HR + fields := strings.Fields(line) + if len(fields) > 0 { + instanceType := fields[0] + // Validate it looks like an instance type (contains letters and possibly numbers/dots) + if isValidInstanceType(instanceType) { + types = append(types, instanceType) + } + } + } + + if err := scanner.Err(); err != nil { + return nil, breverrors.WrapAndTrace(err) + } + } + + return types, nil +} + +// isValidInstanceType checks if a string looks like a valid instance type +func isValidInstanceType(s string) bool { + // Instance types typically have formats like: + // g5.xlarge, p4d.24xlarge, n1-highmem-4:nvidia-tesla-t4:1 + if len(s) < 2 { + return false + } + + // Should contain alphanumeric characters + hasLetter := false + hasNumber := false + for _, c := range s { + if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') { + hasLetter = true + } + if c >= '0' && c <= '9' { + hasNumber = true + } + } + + return hasLetter && hasNumber +} + +// RunGPUCreate executes the GPU create with retry logic +func RunGPUCreate(t
*terminal.Terminal, gpuCreateStore GPUCreateStore, opts GPUCreateOptions) error { + user, err := gpuCreateStore.GetCurrentUser() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + org, err := gpuCreateStore.GetActiveOrganizationOrDefault() + if err != nil { + return breverrors.WrapAndTrace(err) + } + if org == nil { + return breverrors.NewValidationError("no organization found") + } + + t.Vprintf("Attempting to create %d instance(s) with %d parallel attempts\n", opts.Count, opts.Parallel) + t.Vprintf("Instance types to try: %s\n\n", strings.Join(opts.InstanceTypes, ", ")) + + // Track successful creations + var successfulWorkspaces []*entity.Workspace + var mu sync.Mutex + var wg sync.WaitGroup + + // Create a context for cancellation + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Channel to coordinate attempts + typesChan := make(chan string, len(opts.InstanceTypes)) + for _, it := range opts.InstanceTypes { + typesChan <- it + } + close(typesChan) + + // Results channel + resultsChan := make(chan CreateResult, len(opts.InstanceTypes)) + + // Track instance index for naming + instanceIndex := 0 + var indexMu sync.Mutex + + // Start parallel workers + workerCount := opts.Parallel + if workerCount > len(opts.InstanceTypes) { + workerCount = len(opts.InstanceTypes) + } + + for i := 0; i < workerCount; i++ { + wg.Add(1) + go func(workerID int) { + defer wg.Done() + + for instanceType := range typesChan { + // Check if we've already created enough + mu.Lock() + if len(successfulWorkspaces) >= opts.Count { + mu.Unlock() + return + } + mu.Unlock() + + // Check context + select { + case <-ctx.Done(): + return + default: + } + + // Get unique instance name + indexMu.Lock() + currentIndex := instanceIndex + instanceIndex++ + indexMu.Unlock() + + instanceName := opts.Name + if opts.Count > 1 || currentIndex > 0 { + instanceName = fmt.Sprintf("%s-%d", opts.Name, currentIndex+1) + } + + t.Vprintf("[Worker %d] Trying %s for instance '%s'...\n", workerID+1, instanceType, instanceName) + + // Attempt to create the workspace + workspace, err := createWorkspaceWithType(gpuCreateStore, org.ID, instanceName, instanceType, user) + + result := CreateResult{ + Workspace: workspace, + InstanceType: instanceType, + Error: err, + } + + if err != nil { + t.Vprintf("[Worker %d] %s Failed: %s\n", workerID+1, t.Yellow(instanceType), err.Error()) + } else { + t.Vprintf("[Worker %d] %s Success! 
Created instance '%s'\n", workerID+1, t.Green(instanceType), instanceName) + mu.Lock() + successfulWorkspaces = append(successfulWorkspaces, workspace) + if len(successfulWorkspaces) >= opts.Count { + cancel() // Signal other workers to stop + } + mu.Unlock() + } + + resultsChan <- result + } + }(i) + } + + // Wait for all workers to finish + go func() { + wg.Wait() + close(resultsChan) + }() + + // Collect results + for range resultsChan { + // Just drain the channel + } + + // Check if we created enough instances + if len(successfulWorkspaces) < opts.Count { + t.Vprintf("\n%s Only created %d/%d instances\n", t.Yellow("Warning:"), len(successfulWorkspaces), opts.Count) + + if len(successfulWorkspaces) > 0 { + t.Vprintf("Successfully created instances:\n") + for _, ws := range successfulWorkspaces { + t.Vprintf(" - %s (ID: %s)\n", ws.Name, ws.ID) + } + } + + return breverrors.NewValidationError(fmt.Sprintf("could only create %d/%d instances", len(successfulWorkspaces), opts.Count)) + } + + // If we created more than needed, clean up extras + if len(successfulWorkspaces) > opts.Count { + extras := successfulWorkspaces[opts.Count:] + t.Vprintf("\nCleaning up %d extra instance(s)...\n", len(extras)) + + for _, ws := range extras { + t.Vprintf(" Deleting %s...", ws.Name) + _, err := gpuCreateStore.DeleteWorkspace(ws.ID) + if err != nil { + t.Vprintf(" %s\n", t.Red("Failed")) + } else { + t.Vprintf(" %s\n", t.Green("Done")) + } + } + + successfulWorkspaces = successfulWorkspaces[:opts.Count] + } + + // Wait for instances to be ready (unless detached) + if !opts.Detached { + t.Vprintf("\nWaiting for instance(s) to be ready...\n") + t.Vprintf("You can safely ctrl+c to exit\n") + + for _, ws := range successfulWorkspaces { + err := pollUntilReady(t, ws.ID, gpuCreateStore, opts.Timeout) + if err != nil { + t.Vprintf(" %s: %s\n", ws.Name, t.Yellow("Timeout waiting for ready state")) + } + } + } + + // Print summary + fmt.Print("\n") + t.Vprint(t.Green(fmt.Sprintf("Successfully created %d instance(s)!\n\n", len(successfulWorkspaces)))) + + for _, ws := range successfulWorkspaces { + t.Vprintf("Instance: %s\n", t.Green(ws.Name)) + t.Vprintf(" ID: %s\n", ws.ID) + t.Vprintf(" Type: %s\n", ws.InstanceType) + displayConnectBreadCrumb(t, ws) + fmt.Print("\n") + } + + return nil +} + +// createWorkspaceWithType creates a workspace with the specified instance type +func createWorkspaceWithType(gpuCreateStore GPUCreateStore, orgID, name, instanceType string, user *entity.User) (*entity.Workspace, error) { + clusterID := config.GlobalConfig.GetDefaultClusterID() + cwOptions := store.NewCreateWorkspacesOptions(clusterID, name) + cwOptions.WithInstanceType(instanceType) + cwOptions = resolveWorkspaceUserOptions(cwOptions, user) + + workspace, err := gpuCreateStore.CreateWorkspace(orgID, cwOptions) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + return workspace, nil +} + +// resolveWorkspaceUserOptions sets workspace template and class based on user type +func resolveWorkspaceUserOptions(options *store.CreateWorkspacesOptions, user *entity.User) *store.CreateWorkspacesOptions { + if options.WorkspaceTemplateID == "" { + if featureflag.IsAdmin(user.GlobalUserType) { + options.WorkspaceTemplateID = store.DevWorkspaceTemplateID + } else { + options.WorkspaceTemplateID = store.UserWorkspaceTemplateID + } + } + if options.WorkspaceClassID == "" { + if featureflag.IsAdmin(user.GlobalUserType) { + options.WorkspaceClassID = store.DevWorkspaceClassID + } else { + options.WorkspaceClassID = 
store.UserWorkspaceClassID + } + } + return options +} + +// pollUntilReady waits for a workspace to reach the running state +func pollUntilReady(t *terminal.Terminal, wsID string, gpuCreateStore GPUCreateStore, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + + for time.Now().Before(deadline) { + ws, err := gpuCreateStore.GetWorkspace(wsID) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if ws.Status == entity.Running { + t.Vprintf(" %s: %s\n", ws.Name, t.Green("Ready")) + return nil + } + + if ws.Status == entity.Failure { + return breverrors.NewValidationError(fmt.Sprintf("instance %s failed", ws.Name)) + } + + time.Sleep(5 * time.Second) + } + + return breverrors.NewValidationError("timeout waiting for instance to be ready") +} + +// displayConnectBreadCrumb shows connection instructions +func displayConnectBreadCrumb(t *terminal.Terminal, workspace *entity.Workspace) { + t.Vprintf(" Connect:\n") + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("brev open %s", workspace.Name))) + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("brev shell %s", workspace.Name))) +} diff --git a/pkg/cmd/gpucreate/gpucreate_test.go b/pkg/cmd/gpucreate/gpucreate_test.go new file mode 100644 index 00000000..8c9b935d --- /dev/null +++ b/pkg/cmd/gpucreate/gpucreate_test.go @@ -0,0 +1,346 @@ +package gpucreate + +import ( + "testing" + + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/store" + "github.com/stretchr/testify/assert" +) + +// MockGPUCreateStore is a mock implementation of GPUCreateStore for testing +type MockGPUCreateStore struct { + User *entity.User + Org *entity.Organization + Workspaces map[string]*entity.Workspace + CreateError error + CreateErrorTypes map[string]error // Errors for specific instance types + DeleteError error + CreatedWorkspaces []*entity.Workspace + DeletedWorkspaceIDs []string +} + +func NewMockGPUCreateStore() *MockGPUCreateStore { + return &MockGPUCreateStore{ + User: &entity.User{ + ID: "user-123", + GlobalUserType: "Standard", + }, + Org: &entity.Organization{ + ID: "org-123", + Name: "test-org", + }, + Workspaces: make(map[string]*entity.Workspace), + CreateErrorTypes: make(map[string]error), + CreatedWorkspaces: []*entity.Workspace{}, + DeletedWorkspaceIDs: []string{}, + } +} + +func (m *MockGPUCreateStore) GetCurrentUser() (*entity.User, error) { + return m.User, nil +} + +func (m *MockGPUCreateStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + return m.Org, nil +} + +func (m *MockGPUCreateStore) GetWorkspace(workspaceID string) (*entity.Workspace, error) { + if ws, ok := m.Workspaces[workspaceID]; ok { + return ws, nil + } + return &entity.Workspace{ + ID: workspaceID, + Status: entity.Running, + }, nil +} + +func (m *MockGPUCreateStore) CreateWorkspace(organizationID string, options *store.CreateWorkspacesOptions) (*entity.Workspace, error) { + // Check for type-specific errors first + if err, ok := m.CreateErrorTypes[options.InstanceType]; ok { + return nil, err + } + + if m.CreateError != nil { + return nil, m.CreateError + } + + ws := &entity.Workspace{ + ID: "ws-" + options.Name, + Name: options.Name, + InstanceType: options.InstanceType, + Status: entity.Running, + } + m.Workspaces[ws.ID] = ws + m.CreatedWorkspaces = append(m.CreatedWorkspaces, ws) + return ws, nil +} + +func (m *MockGPUCreateStore) DeleteWorkspace(workspaceID string) (*entity.Workspace, error) { + if m.DeleteError != nil { + return nil, m.DeleteError + } + + m.DeletedWorkspaceIDs = append(m.DeletedWorkspaceIDs, 
workspaceID) + ws := m.Workspaces[workspaceID] + delete(m.Workspaces, workspaceID) + return ws, nil +} + +func (m *MockGPUCreateStore) GetWorkspaceByNameOrID(orgID string, nameOrID string) ([]entity.Workspace, error) { + return []entity.Workspace{}, nil +} + +func TestIsValidInstanceType(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"Valid AWS instance type", "g5.xlarge", true}, + {"Valid AWS large instance", "p4d.24xlarge", true}, + {"Valid GCP instance type", "n1-highmem-4:nvidia-tesla-t4:1", true}, + {"Single letter", "a", false}, + {"No numbers", "xlarge", false}, + {"No letters", "12345", false}, + {"Empty string", "", false}, + {"Single character", "1", false}, + {"Valid with colon", "g5:xlarge", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isValidInstanceType(tt.input) + assert.Equal(t, tt.expected, result, "Validation failed for %s", tt.input) + }) + } +} + +func TestParseInstanceTypesFromFlag(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + {"Single type", "g5.xlarge", []string{"g5.xlarge"}}, + {"Multiple types comma separated", "g5.xlarge,g5.2xlarge,p3.2xlarge", []string{"g5.xlarge", "g5.2xlarge", "p3.2xlarge"}}, + {"With spaces", "g5.xlarge, g5.2xlarge, p3.2xlarge", []string{"g5.xlarge", "g5.2xlarge", "p3.2xlarge"}}, + {"Empty string", "", []string{}}, + {"Only spaces", " ", []string{}}, + {"Trailing comma", "g5.xlarge,", []string{"g5.xlarge"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseInstanceTypes(tt.input) + assert.NoError(t, err) + + // Handle nil vs empty slice + if len(tt.expected) == 0 { + assert.Empty(t, result) + } else { + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestGPUCreateOptions(t *testing.T) { + opts := GPUCreateOptions{ + Name: "my-instance", + InstanceTypes: []string{"g5.xlarge", "g5.2xlarge"}, + Count: 2, + Parallel: 3, + Detached: true, + } + + assert.Equal(t, "my-instance", opts.Name) + assert.Len(t, opts.InstanceTypes, 2) + assert.Equal(t, 2, opts.Count) + assert.Equal(t, 3, opts.Parallel) + assert.True(t, opts.Detached) +} + +func TestResolveWorkspaceUserOptionsStandard(t *testing.T) { + user := &entity.User{ + ID: "user-123", + GlobalUserType: "Standard", + } + + options := &store.CreateWorkspacesOptions{} + result := resolveWorkspaceUserOptions(options, user) + + assert.Equal(t, store.UserWorkspaceTemplateID, result.WorkspaceTemplateID) + assert.Equal(t, store.UserWorkspaceClassID, result.WorkspaceClassID) +} + +func TestResolveWorkspaceUserOptionsAdmin(t *testing.T) { + user := &entity.User{ + ID: "user-123", + GlobalUserType: "Admin", + } + + options := &store.CreateWorkspacesOptions{} + result := resolveWorkspaceUserOptions(options, user) + + assert.Equal(t, store.DevWorkspaceTemplateID, result.WorkspaceTemplateID) + assert.Equal(t, store.DevWorkspaceClassID, result.WorkspaceClassID) +} + +func TestResolveWorkspaceUserOptionsPreserveExisting(t *testing.T) { + user := &entity.User{ + ID: "user-123", + GlobalUserType: "Standard", + } + + options := &store.CreateWorkspacesOptions{ + WorkspaceTemplateID: "custom-template", + WorkspaceClassID: "custom-class", + } + result := resolveWorkspaceUserOptions(options, user) + + // Should preserve existing values + assert.Equal(t, "custom-template", result.WorkspaceTemplateID) + assert.Equal(t, "custom-class", result.WorkspaceClassID) +} + +func TestMockGPUCreateStoreBasics(t *testing.T) { + mock := 
NewMockGPUCreateStore() + + user, err := mock.GetCurrentUser() + assert.NoError(t, err) + assert.Equal(t, "user-123", user.ID) + + org, err := mock.GetActiveOrganizationOrDefault() + assert.NoError(t, err) + assert.Equal(t, "org-123", org.ID) +} + +func TestMockGPUCreateStoreCreateWorkspace(t *testing.T) { + mock := NewMockGPUCreateStore() + + options := store.NewCreateWorkspacesOptions("cluster-1", "test-instance") + options.WithInstanceType("g5.xlarge") + + ws, err := mock.CreateWorkspace("org-123", options) + assert.NoError(t, err) + assert.Equal(t, "test-instance", ws.Name) + assert.Equal(t, "g5.xlarge", ws.InstanceType) + assert.Len(t, mock.CreatedWorkspaces, 1) +} + +func TestMockGPUCreateStoreDeleteWorkspace(t *testing.T) { + mock := NewMockGPUCreateStore() + + // First create a workspace + options := store.NewCreateWorkspacesOptions("cluster-1", "test-instance") + ws, _ := mock.CreateWorkspace("org-123", options) + + // Then delete it + _, err := mock.DeleteWorkspace(ws.ID) + assert.NoError(t, err) + assert.Contains(t, mock.DeletedWorkspaceIDs, ws.ID) +} + +func TestMockGPUCreateStoreTypeSpecificError(t *testing.T) { + mock := NewMockGPUCreateStore() + mock.CreateErrorTypes["g5.xlarge"] = assert.AnError + + options := store.NewCreateWorkspacesOptions("cluster-1", "test-instance") + options.WithInstanceType("g5.xlarge") + + _, err := mock.CreateWorkspace("org-123", options) + assert.Error(t, err) + + // Different type should work + options2 := store.NewCreateWorkspacesOptions("cluster-1", "test-instance-2") + options2.WithInstanceType("g5.2xlarge") + + ws, err := mock.CreateWorkspace("org-123", options2) + assert.NoError(t, err) + assert.NotNil(t, ws) +} + +func TestParseInstanceTypesFromTableOutput(t *testing.T) { + // Simulated table output from brev gpus command + // Note: This tests the parsing logic, not actual stdin reading + tableLines := []string{ + "TYPE GPU COUNT VRAM/GPU TOTAL VRAM CAPABILITY VCPUs $/HR", + "g5.xlarge A10G 1 24 GB 24 GB 8.6 4 $1.01", + "g5.2xlarge A10G 1 24 GB 24 GB 8.6 8 $1.21", + "p4d.24xlarge A100 8 40 GB 320 GB 8.0 96 $32.77", + "", + "Found 3 GPU instance types", + } + + // Test parsing each line (simulating the scanner behavior) + var types []string + lineNum := 0 + for _, line := range tableLines { + lineNum++ + + // Skip header line + if lineNum == 1 && (contains(line, "TYPE") || contains(line, "GPU")) { + continue + } + + // Skip empty lines and summary + if line == "" || startsWith(line, "Found ") { + continue + } + + // Extract first column + fields := splitFields(line) + if len(fields) > 0 && isValidInstanceType(fields[0]) { + types = append(types, fields[0]) + } + } + + assert.Len(t, types, 3) + assert.Contains(t, types, "g5.xlarge") + assert.Contains(t, types, "g5.2xlarge") + assert.Contains(t, types, "p4d.24xlarge") +} + +// Helper functions for testing +func contains(s, substr string) bool { + return len(s) >= len(substr) && findSubstring(s, substr) >= 0 +} + +func findSubstring(s, substr string) int { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return i + } + } + return -1 +} + +func startsWith(s, prefix string) bool { + return len(s) >= len(prefix) && s[:len(prefix)] == prefix +} + +func splitFields(s string) []string { + var fields []string + var current string + inField := false + + for _, c := range s { + if c == ' ' || c == '\t' { + if inField { + fields = append(fields, current) + current = "" + inField = false + } + } else { + current += string(c) + inField = true + } + } + + if inField { + 
fields = append(fields, current) + } + + return fields +} diff --git a/pkg/cmd/gpusearch/gpusearch.go b/pkg/cmd/gpusearch/gpusearch.go new file mode 100644 index 00000000..54d1bd37 --- /dev/null +++ b/pkg/cmd/gpusearch/gpusearch.go @@ -0,0 +1,450 @@ +// Package gpusearch provides a command to search and filter GPU instance types +package gpusearch + +import ( + "fmt" + "os" + "regexp" + "sort" + "strconv" + "strings" + + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/terminal" + "github.com/jedib0t/go-pretty/v6/table" + "github.com/spf13/cobra" +) + +// MemoryBytes represents the memory size with value and unit +type MemoryBytes struct { + Value int64 `json:"value"` + Unit string `json:"unit"` +} + +// GPU represents a GPU configuration within an instance type +type GPU struct { + Count int `json:"count"` + Name string `json:"name"` + Manufacturer string `json:"manufacturer"` + Memory string `json:"memory"` + MemoryBytes MemoryBytes `json:"memory_bytes"` +} + +// BasePrice represents the pricing information +type BasePrice struct { + Currency string `json:"currency"` + Amount string `json:"amount"` +} + +// InstanceType represents an instance type from the API +type InstanceType struct { + Type string `json:"type"` + SupportedGPUs []GPU `json:"supported_gpus"` + SupportedStorage []interface{} `json:"supported_storage"` // Complex objects, not used in filtering + Memory string `json:"memory"` + VCPU int `json:"vcpu"` + BasePrice BasePrice `json:"base_price"` + Location string `json:"location"` + SubLocation string `json:"sub_location"` + AvailableLocations []string `json:"available_locations"` +} + +// InstanceTypesResponse represents the API response +type InstanceTypesResponse struct { + Items []InstanceType `json:"items"` +} + +// GPUSearchStore defines the interface for fetching instance types +type GPUSearchStore interface { + GetInstanceTypes() (*InstanceTypesResponse, error) +} + +var ( + long = `Search and filter GPU instance types available on Brev. + +Filter instances by GPU name, VRAM, total VRAM, and GPU compute capability. 
+Sort results by various columns to find the best instance for your needs.` + + example = ` + # List all GPU instances + brev gpu-search + + # Filter by GPU name (case-insensitive, partial match) + brev gpu-search --gpu-name A100 + brev gpu-search --gpu-name "L40S" + + # Filter by minimum VRAM per GPU (in GB) + brev gpu-search --min-vram 24 + + # Filter by minimum total VRAM (in GB) + brev gpu-search --min-total-vram 80 + + # Filter by minimum GPU compute capability + brev gpu-search --min-capability 8.0 + + # Sort by different columns (price, gpu-count, vram, total-vram, vcpu) + brev gpu-search --sort price + brev gpu-search --sort total-vram --desc + + # Combine filters + brev gpu-search --gpu-name A100 --min-vram 40 --sort price +` +) + +// NewCmdGPUSearch creates the gpu-search command +func NewCmdGPUSearch(t *terminal.Terminal, store GPUSearchStore) *cobra.Command { + var gpuName string + var minVRAM float64 + var minTotalVRAM float64 + var minCapability float64 + var sortBy string + var descending bool + + cmd := &cobra.Command{ + Annotations: map[string]string{"workspace": ""}, + Use: "gpu-search", + Aliases: []string{"gpu", "gpus", "gpu-list"}, + DisableFlagsInUseLine: true, + Short: "Search and filter GPU instance types", + Long: long, + Example: example, + RunE: func(cmd *cobra.Command, args []string) error { + err := RunGPUSearch(t, store, gpuName, minVRAM, minTotalVRAM, minCapability, sortBy, descending) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil + }, + } + + cmd.Flags().StringVarP(&gpuName, "gpu-name", "g", "", "Filter by GPU name (case-insensitive, partial match)") + cmd.Flags().Float64VarP(&minVRAM, "min-vram", "v", 0, "Minimum VRAM per GPU in GB") + cmd.Flags().Float64VarP(&minTotalVRAM, "min-total-vram", "t", 0, "Minimum total VRAM (GPU count * VRAM) in GB") + cmd.Flags().Float64VarP(&minCapability, "min-capability", "c", 0, "Minimum GPU compute capability (e.g., 8.0 for Ampere)") + cmd.Flags().StringVarP(&sortBy, "sort", "s", "price", "Sort by: price, gpu-count, vram, total-vram, vcpu, type") + cmd.Flags().BoolVarP(&descending, "desc", "d", false, "Sort in descending order") + + return cmd +} + +// GPUInstanceInfo holds processed GPU instance information for display +type GPUInstanceInfo struct { + Type string + GPUName string + GPUCount int + VRAMPerGPU float64 // in GB + TotalVRAM float64 // in GB + Capability float64 + VCPUs int + Memory string + PricePerHour float64 + Manufacturer string +} + +// RunGPUSearch executes the GPU search with filters and sorting +func RunGPUSearch(t *terminal.Terminal, store GPUSearchStore, gpuName string, minVRAM, minTotalVRAM, minCapability float64, sortBy string, descending bool) error { + response, err := store.GetInstanceTypes() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if response == nil || len(response.Items) == 0 { + t.Vprint(t.Yellow("No instance types found")) + return nil + } + + // Process and filter instances + instances := processInstances(response.Items) + + // Apply filters + filtered := filterInstances(instances, gpuName, minVRAM, minTotalVRAM, minCapability) + + if len(filtered) == 0 { + t.Vprint(t.Yellow("No GPU instances match the specified filters")) + return nil + } + + // Sort instances + sortInstances(filtered, sortBy, descending) + + // Display results + displayGPUTable(t, filtered) + + t.Vprintf("\n%s\n", t.Green(fmt.Sprintf("Found %d GPU instance types", len(filtered)))) + + return nil +} + +// parseMemoryToGB converts memory string like "22GiB360MiB" or "40GiB" 
to GB +func parseMemoryToGB(memory string) float64 { + // Handle memory_bytes if provided (in MiB) + // Otherwise parse the string format + + var totalGB float64 + + // Match GiB values + gibRe := regexp.MustCompile(`(\d+(?:\.\d+)?)\s*GiB`) + gibMatches := gibRe.FindAllStringSubmatch(memory, -1) + for _, match := range gibMatches { + if val, err := strconv.ParseFloat(match[1], 64); err == nil { + totalGB += val + } + } + + // Match MiB values and convert to GB + mibRe := regexp.MustCompile(`(\d+(?:\.\d+)?)\s*MiB`) + mibMatches := mibRe.FindAllStringSubmatch(memory, -1) + for _, match := range mibMatches { + if val, err := strconv.ParseFloat(match[1], 64); err == nil { + totalGB += val / 1024 + } + } + + // Match GB values (in case API uses GB instead of GiB) + gbRe := regexp.MustCompile(`(\d+(?:\.\d+)?)\s*GB`) + gbMatches := gbRe.FindAllStringSubmatch(memory, -1) + for _, match := range gbMatches { + if val, err := strconv.ParseFloat(match[1], 64); err == nil { + totalGB += val + } + } + + return totalGB +} + +// gpuCapabilityEntry represents a GPU pattern and its compute capability +type gpuCapabilityEntry struct { + pattern string + capability float64 +} + +// getGPUCapability returns the compute capability for known GPU types +func getGPUCapability(gpuName string) float64 { + gpuName = strings.ToUpper(gpuName) + + // Order matters: more specific patterns must come before less specific ones + // (e.g., "A100" before "A10", "L40S" before "L40") + capabilities := []gpuCapabilityEntry{ + // NVIDIA Professional (before other RTX patterns) + {"RTXPRO6000", 12.0}, + + // NVIDIA Blackwell + {"B200", 10.0}, + {"RTX5090", 10.0}, + + // NVIDIA Hopper + {"H100", 9.0}, + {"H200", 9.0}, + + // NVIDIA Ada Lovelace (L40S before L40, L4; RTX*Ada before RTX*) + {"L40S", 8.9}, + {"L40", 8.9}, + {"L4", 8.9}, + {"RTX6000ADA", 8.9}, + {"RTX4000ADA", 8.9}, + {"RTX4090", 8.9}, + {"RTX4080", 8.9}, + + // NVIDIA Ampere (A100 before A10G, A10) + {"A100", 8.0}, + {"A10G", 8.6}, + {"A10", 8.6}, + {"A40", 8.6}, + {"A6000", 8.6}, + {"A5000", 8.6}, + {"A4000", 8.6}, + {"A30", 8.0}, + {"A16", 8.6}, + {"RTX3090", 8.6}, + {"RTX3080", 8.6}, + + // NVIDIA Turing + {"T4", 7.5}, + {"RTX6000", 7.5}, + {"RTX2080", 7.5}, + + // NVIDIA Volta + {"V100", 7.0}, + + // NVIDIA Pascal (P100 before P40, P4) + {"P100", 6.0}, + {"P40", 6.1}, + {"P4", 6.1}, + + // NVIDIA Maxwell + {"M60", 5.2}, + + // NVIDIA Kepler + {"K80", 3.7}, + + // Gaudi (Habana) - not CUDA compatible + {"HL-205", 0}, + {"GAUDI3", 0}, + {"GAUDI2", 0}, + {"GAUDI", 0}, + } + + for _, entry := range capabilities { + if strings.Contains(gpuName, entry.pattern) { + return entry.capability + } + } + return 0 +} + +// processInstances converts raw instance types to GPUInstanceInfo +func processInstances(items []InstanceType) []GPUInstanceInfo { + var instances []GPUInstanceInfo + + for _, item := range items { + if len(item.SupportedGPUs) == 0 { + continue // Skip non-GPU instances + } + + for _, gpu := range item.SupportedGPUs { + vramPerGPU := parseMemoryToGB(gpu.Memory) + // Also check memory_bytes as fallback + if vramPerGPU == 0 && gpu.MemoryBytes.Value > 0 { + // Convert based on unit + if gpu.MemoryBytes.Unit == "MiB" { + vramPerGPU = float64(gpu.MemoryBytes.Value) / 1024 // MiB to GiB + } else if gpu.MemoryBytes.Unit == "GiB" { + vramPerGPU = float64(gpu.MemoryBytes.Value) + } + } + + totalVRAM := vramPerGPU * float64(gpu.Count) + capability := getGPUCapability(gpu.Name) + + price := 0.0 + if item.BasePrice.Amount != "" { + price, _ = 
strconv.ParseFloat(item.BasePrice.Amount, 64) + } + + instances = append(instances, GPUInstanceInfo{ + Type: item.Type, + GPUName: gpu.Name, + GPUCount: gpu.Count, + VRAMPerGPU: vramPerGPU, + TotalVRAM: totalVRAM, + Capability: capability, + VCPUs: item.VCPU, + Memory: item.Memory, + PricePerHour: price, + Manufacturer: gpu.Manufacturer, + }) + } + } + + return instances +} + +// filterInstances applies all filters to the instance list +func filterInstances(instances []GPUInstanceInfo, gpuName string, minVRAM, minTotalVRAM, minCapability float64) []GPUInstanceInfo { + var filtered []GPUInstanceInfo + + for _, inst := range instances { + // Filter out non-NVIDIA GPUs (AMD, Intel/Habana, etc.) + if !strings.Contains(strings.ToUpper(inst.Manufacturer), "NVIDIA") { + continue + } + + // Filter by GPU name (case-insensitive partial match) + if gpuName != "" && !strings.Contains(strings.ToLower(inst.GPUName), strings.ToLower(gpuName)) { + continue + } + + // Filter by minimum VRAM per GPU + if minVRAM > 0 && inst.VRAMPerGPU < minVRAM { + continue + } + + // Filter by minimum total VRAM + if minTotalVRAM > 0 && inst.TotalVRAM < minTotalVRAM { + continue + } + + // Filter by minimum GPU capability + if minCapability > 0 && inst.Capability < minCapability { + continue + } + + filtered = append(filtered, inst) + } + + return filtered +} + +// sortInstances sorts the instance list by the specified column +func sortInstances(instances []GPUInstanceInfo, sortBy string, descending bool) { + sort.Slice(instances, func(i, j int) bool { + var less bool + switch strings.ToLower(sortBy) { + case "price": + less = instances[i].PricePerHour < instances[j].PricePerHour + case "gpu-count": + less = instances[i].GPUCount < instances[j].GPUCount + case "vram": + less = instances[i].VRAMPerGPU < instances[j].VRAMPerGPU + case "total-vram": + less = instances[i].TotalVRAM < instances[j].TotalVRAM + case "vcpu": + less = instances[i].VCPUs < instances[j].VCPUs + case "type": + less = instances[i].Type < instances[j].Type + case "capability": + less = instances[i].Capability < instances[j].Capability + default: + less = instances[i].PricePerHour < instances[j].PricePerHour + } + + if descending { + return !less + } + return less + }) +} + +// getBrevTableOptions returns table styling options +func getBrevTableOptions() table.Options { + options := table.OptionsDefault + options.DrawBorder = false + options.SeparateColumns = false + options.SeparateRows = false + options.SeparateHeader = false + return options +} + +// displayGPUTable renders the GPU instances as a table +func displayGPUTable(t *terminal.Terminal, instances []GPUInstanceInfo) { + ta := table.NewWriter() + ta.SetOutputMirror(os.Stdout) + ta.Style().Options = getBrevTableOptions() + + header := table.Row{"TYPE", "GPU", "COUNT", "VRAM/GPU", "TOTAL VRAM", "CAPABILITY", "VCPUs", "$/HR"} + ta.AppendHeader(header) + + for _, inst := range instances { + vramStr := fmt.Sprintf("%.0f GB", inst.VRAMPerGPU) + totalVramStr := fmt.Sprintf("%.0f GB", inst.TotalVRAM) + capStr := "-" + if inst.Capability > 0 { + capStr = fmt.Sprintf("%.1f", inst.Capability) + } + priceStr := fmt.Sprintf("$%.2f", inst.PricePerHour) + + row := table.Row{ + inst.Type, + t.Green(inst.GPUName), + inst.GPUCount, + vramStr, + totalVramStr, + capStr, + inst.VCPUs, + priceStr, + } + ta.AppendRow(row) + } + + ta.Render() +} diff --git a/pkg/cmd/gpusearch/gpusearch_test.go b/pkg/cmd/gpusearch/gpusearch_test.go new file mode 100644 index 00000000..0714874f --- /dev/null +++ 
b/pkg/cmd/gpusearch/gpusearch_test.go @@ -0,0 +1,388 @@ +package gpusearch + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// MockGPUSearchStore is a mock implementation of GPUSearchStore for testing +type MockGPUSearchStore struct { + Response *InstanceTypesResponse + Err error +} + +func (m *MockGPUSearchStore) GetInstanceTypes() (*InstanceTypesResponse, error) { + if m.Err != nil { + return nil, m.Err + } + return m.Response, nil +} + +func createTestInstanceTypes() *InstanceTypesResponse { + return &InstanceTypesResponse{ + Items: []InstanceType{ + { + Type: "g5.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "A10G", Manufacturer: "NVIDIA", Memory: "24GiB"}, + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "1.006"}, + }, + { + Type: "g5.2xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "A10G", Manufacturer: "NVIDIA", Memory: "24GiB"}, + }, + Memory: "32GiB", + VCPU: 8, + BasePrice: BasePrice{Currency: "USD", Amount: "1.212"}, + }, + { + Type: "p3.2xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "V100", Manufacturer: "NVIDIA", Memory: "16GiB"}, + }, + Memory: "61GiB", + VCPU: 8, + BasePrice: BasePrice{Currency: "USD", Amount: "3.06"}, + }, + { + Type: "p3.8xlarge", + SupportedGPUs: []GPU{ + {Count: 4, Name: "V100", Manufacturer: "NVIDIA", Memory: "16GiB"}, + }, + Memory: "244GiB", + VCPU: 32, + BasePrice: BasePrice{Currency: "USD", Amount: "12.24"}, + }, + { + Type: "p4d.24xlarge", + SupportedGPUs: []GPU{ + {Count: 8, Name: "A100", Manufacturer: "NVIDIA", Memory: "40GiB"}, + }, + Memory: "1152GiB", + VCPU: 96, + BasePrice: BasePrice{Currency: "USD", Amount: "32.77"}, + }, + { + Type: "g4dn.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "T4", Manufacturer: "NVIDIA", Memory: "16GiB"}, + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "0.526"}, + }, + { + Type: "g6.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "L4", Manufacturer: "NVIDIA", Memory: "24GiB"}, + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "0.805"}, + }, + }, + } +} + +func TestParseMemoryToGB(t *testing.T) { + tests := []struct { + name string + input string + expected float64 + }{ + {"Simple GiB", "24GiB", 24}, + {"GiB with MiB", "22GiB360MiB", 22.3515625}, + {"Simple GB", "16GB", 16}, + {"Large GiB", "1152GiB", 1152}, + {"Empty string", "", 0}, + {"MiB only", "512MiB", 0.5}, + {"With spaces", "24 GiB", 24}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseMemoryToGB(tt.input) + assert.InDelta(t, tt.expected, result, 0.01, "Memory parsing failed for %s", tt.input) + }) + } +} + +func TestGetGPUCapability(t *testing.T) { + tests := []struct { + name string + gpuName string + expected float64 + }{ + {"A100", "A100", 8.0}, + {"A10G", "A10G", 8.6}, + {"V100", "V100", 7.0}, + {"T4", "T4", 7.5}, + {"L4", "L4", 8.9}, + {"L40S", "L40S", 8.9}, + {"H100", "H100", 9.0}, + {"Unknown GPU", "Unknown", 0}, + {"Case insensitive", "a100", 8.0}, + {"Gaudi", "HL-205", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getGPUCapability(tt.gpuName) + assert.Equal(t, tt.expected, result, "GPU capability mismatch for %s", tt.gpuName) + }) + } +} + +func TestProcessInstances(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + assert.Len(t, instances, 7, "Expected 7 GPU instances") + + // Check specific instance + var a10gInstance *GPUInstanceInfo + for i := range 
instances { + if instances[i].Type == "g5.xlarge" { + a10gInstance = &instances[i] + break + } + } + + assert.NotNil(t, a10gInstance, "g5.xlarge instance should exist") + assert.Equal(t, "A10G", a10gInstance.GPUName) + assert.Equal(t, 1, a10gInstance.GPUCount) + assert.Equal(t, 24.0, a10gInstance.VRAMPerGPU) + assert.Equal(t, 24.0, a10gInstance.TotalVRAM) + assert.Equal(t, 8.6, a10gInstance.Capability) + assert.Equal(t, 4, a10gInstance.VCPUs) + assert.InDelta(t, 1.006, a10gInstance.PricePerHour, 0.001) +} + +func TestFilterInstancesByGPUName(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Filter by A10G + filtered := filterInstances(instances, "A10G", 0, 0, 0) + assert.Len(t, filtered, 2, "Should have 2 A10G instances") + + // Filter by V100 + filtered = filterInstances(instances, "V100", 0, 0, 0) + assert.Len(t, filtered, 2, "Should have 2 V100 instances") + + // Filter by lowercase (case-insensitive) + filtered = filterInstances(instances, "v100", 0, 0, 0) + assert.Len(t, filtered, 2, "Should have 2 V100 instances (case-insensitive)") + + // Filter by partial match + filtered = filterInstances(instances, "A1", 0, 0, 0) + assert.Len(t, filtered, 3, "Should have 3 instances matching 'A1' (A10G and A100)") +} + +func TestFilterInstancesByMinVRAM(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Filter by min VRAM 24GB + filtered := filterInstances(instances, "", 24, 0, 0) + assert.Len(t, filtered, 4, "Should have 4 instances with >= 24GB VRAM") + + // Filter by min VRAM 40GB + filtered = filterInstances(instances, "", 40, 0, 0) + assert.Len(t, filtered, 1, "Should have 1 instance with >= 40GB VRAM") + assert.Equal(t, "A100", filtered[0].GPUName) +} + +func TestFilterInstancesByMinTotalVRAM(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Filter by min total VRAM 60GB + filtered := filterInstances(instances, "", 0, 60, 0) + assert.Len(t, filtered, 2, "Should have 2 instances with >= 60GB total VRAM") + + // Filter by min total VRAM 300GB + filtered = filterInstances(instances, "", 0, 300, 0) + assert.Len(t, filtered, 1, "Should have 1 instance with >= 300GB total VRAM") + assert.Equal(t, "p4d.24xlarge", filtered[0].Type) +} + +func TestFilterInstancesByMinCapability(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Filter by capability >= 8.0 + filtered := filterInstances(instances, "", 0, 0, 8.0) + assert.Len(t, filtered, 4, "Should have 4 instances with capability >= 8.0") + + // Filter by capability >= 8.5 + filtered = filterInstances(instances, "", 0, 0, 8.5) + assert.Len(t, filtered, 3, "Should have 3 instances with capability >= 8.5") +} + +func TestFilterInstancesCombined(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Filter by GPU name and min VRAM + filtered := filterInstances(instances, "A10G", 24, 0, 0) + assert.Len(t, filtered, 2, "Should have 2 A10G instances with >= 24GB VRAM") + + // Filter by GPU name, min VRAM, and capability + filtered = filterInstances(instances, "", 24, 0, 8.5) + assert.Len(t, filtered, 3, "Should have 3 instances with >= 24GB VRAM and capability >= 8.5") +} + +func TestSortInstancesByPrice(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Sort by price ascending + sortInstances(instances, 
"price", false) + assert.Equal(t, "g4dn.xlarge", instances[0].Type, "Cheapest should be g4dn.xlarge") + assert.Equal(t, "p4d.24xlarge", instances[len(instances)-1].Type, "Most expensive should be p4d.24xlarge") + + // Sort by price descending + sortInstances(instances, "price", true) + assert.Equal(t, "p4d.24xlarge", instances[0].Type, "Most expensive should be first when descending") +} + +func TestSortInstancesByGPUCount(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Sort by GPU count ascending + sortInstances(instances, "gpu-count", false) + assert.Equal(t, 1, instances[0].GPUCount, "Instances with 1 GPU should be first") + + // Sort by GPU count descending + sortInstances(instances, "gpu-count", true) + assert.Equal(t, 8, instances[0].GPUCount, "Instance with 8 GPUs should be first when descending") +} + +func TestSortInstancesByVRAM(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Sort by VRAM ascending + sortInstances(instances, "vram", false) + assert.Equal(t, 16.0, instances[0].VRAMPerGPU, "Instances with 16GB VRAM should be first") + + // Sort by VRAM descending + sortInstances(instances, "vram", true) + assert.Equal(t, 40.0, instances[0].VRAMPerGPU, "Instance with 40GB VRAM should be first when descending") +} + +func TestSortInstancesByTotalVRAM(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Sort by total VRAM ascending + sortInstances(instances, "total-vram", false) + assert.Equal(t, 16.0, instances[0].TotalVRAM, "Instances with 16GB total VRAM should be first") + + // Sort by total VRAM descending + sortInstances(instances, "total-vram", true) + assert.Equal(t, 320.0, instances[0].TotalVRAM, "Instance with 320GB total VRAM should be first when descending") +} + +func TestSortInstancesByVCPU(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Sort by vCPU ascending + sortInstances(instances, "vcpu", false) + assert.Equal(t, 4, instances[0].VCPUs, "Instances with 4 vCPUs should be first") + + // Sort by vCPU descending + sortInstances(instances, "vcpu", true) + assert.Equal(t, 96, instances[0].VCPUs, "Instance with 96 vCPUs should be first when descending") +} + +func TestSortInstancesByCapability(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Sort by capability ascending + sortInstances(instances, "capability", false) + assert.Equal(t, 7.0, instances[0].Capability, "Instances with capability 7.0 should be first") + + // Sort by capability descending + sortInstances(instances, "capability", true) + assert.Equal(t, 8.9, instances[0].Capability, "Instance with capability 8.9 should be first when descending") +} + +func TestSortInstancesByType(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Sort by type ascending + sortInstances(instances, "type", false) + assert.Equal(t, "g4dn.xlarge", instances[0].Type, "g4dn.xlarge should be first alphabetically") + + // Sort by type descending + sortInstances(instances, "type", true) + assert.Equal(t, "p4d.24xlarge", instances[0].Type, "p4d.24xlarge should be first when descending") +} + +func TestEmptyInstanceTypes(t *testing.T) { + response := &InstanceTypesResponse{Items: []InstanceType{}} + instances := processInstances(response.Items) + + assert.Len(t, instances, 0, 
"Should have 0 instances") + + filtered := filterInstances(instances, "A100", 0, 0, 0) + assert.Len(t, filtered, 0, "Filtered should also be empty") +} + +func TestNonGPUInstancesAreFiltered(t *testing.T) { + response := &InstanceTypesResponse{ + Items: []InstanceType{ + { + Type: "m5.xlarge", + SupportedGPUs: []GPU{}, // No GPUs + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "0.192"}, + }, + { + Type: "g5.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "A10G", Manufacturer: "NVIDIA", Memory: "24GiB"}, + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "1.006"}, + }, + }, + } + + instances := processInstances(response.Items) + assert.Len(t, instances, 1, "Should only have 1 GPU instance, non-GPU instances should be filtered") + assert.Equal(t, "g5.xlarge", instances[0].Type) +} + +func TestMemoryBytesAsFallback(t *testing.T) { + response := &InstanceTypesResponse{ + Items: []InstanceType{ + { + Type: "test.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "TestGPU", Manufacturer: "NVIDIA", Memory: "", MemoryBytes: MemoryBytes{Value: 24576, Unit: "MiB"}}, // 24GB in MiB + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "1.00"}, + }, + }, + } + + instances := processInstances(response.Items) + assert.Len(t, instances, 1) + assert.Equal(t, 24.0, instances[0].VRAMPerGPU, "Should fall back to MemoryBytes when Memory string is empty") +} diff --git a/pkg/store/instancetypes.go b/pkg/store/instancetypes.go new file mode 100644 index 00000000..4f12710a --- /dev/null +++ b/pkg/store/instancetypes.go @@ -0,0 +1,48 @@ +package store + +import ( + "encoding/json" + + "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + resty "github.com/go-resty/resty/v2" +) + +const ( + instanceTypesAPIURL = "https://api.brev.dev" + instanceTypesAPIPath = "v1/instance/types" +) + +// GetInstanceTypes fetches all available instance types from the public API +func (s NoAuthHTTPStore) GetInstanceTypes() (*gpusearch.InstanceTypesResponse, error) { + return fetchInstanceTypes() +} + +// GetInstanceTypes fetches all available instance types from the public API +func (s AuthHTTPStore) GetInstanceTypes() (*gpusearch.InstanceTypesResponse, error) { + return fetchInstanceTypes() +} + +// fetchInstanceTypes fetches instance types from the public Brev API +func fetchInstanceTypes() (*gpusearch.InstanceTypesResponse, error) { + client := resty.New() + client.SetBaseURL(instanceTypesAPIURL) + + res, err := client.R(). + SetHeader("Accept", "application/json"). + Get(instanceTypesAPIPath) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if res.IsError() { + return nil, NewHTTPResponseError(res) + } + + var result gpusearch.InstanceTypesResponse + err = json.Unmarshal(res.Body(), &result) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + return &result, nil +}