diff --git a/.golangci.yml b/.golangci.yml index a245d98..f4845d8 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -14,6 +14,8 @@ linters: - ineffassign - staticcheck - unused + - goimports + - gofmt # Linter-specific settings linters-settings: @@ -25,7 +27,11 @@ linters-settings: # Issues configuration issues: + exclude-dirs: + - vendor max-issues-per-linter: 0 max-same-issues: 0 exclude: - "Error return value is not checked.*\\.Close\\(" + + diff --git a/cmd/hardn/main.go b/cmd/hardn/main.go index e62d26a..2a0051b 100644 --- a/cmd/hardn/main.go +++ b/cmd/hardn/main.go @@ -9,17 +9,12 @@ import ( "github.com/spf13/cobra" "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/dns" - "github.com/abbott/hardn/pkg/firewall" + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/infrastructure" + "github.com/abbott/hardn/pkg/interfaces" "github.com/abbott/hardn/pkg/logging" - "github.com/abbott/hardn/pkg/menu" "github.com/abbott/hardn/pkg/osdetect" - "github.com/abbott/hardn/pkg/packages" - "github.com/abbott/hardn/pkg/security" - "github.com/abbott/hardn/pkg/ssh" - "github.com/abbott/hardn/pkg/updates" - "github.com/abbott/hardn/pkg/user" - "github.com/abbott/hardn/pkg/utils" + "github.com/abbott/hardn/pkg/version" ) // Version information - populated by build flags @@ -30,24 +25,29 @@ var ( ) var ( - configFile string - username string - dryRun bool - createUser bool - disableRoot bool - installLinux bool - installPython bool - installAll bool - configureUfw bool - configureDns bool - runAll bool - updateSources bool - printLogs bool - showVersion bool // Flag to display version information - setupSudoEnv bool - cfg *config.Config + configFile string + username string + dryRun bool + createUser bool + disableRoot bool + installLinux bool + installPython bool + installAll bool + configureUfw bool + configureDns bool + runAll bool + updateSources bool + printLogs bool + showVersion bool // Flag to display version information + setupSudoEnv bool + debugUpdates bool + testUpdateAvailable bool + cfg *config.Config ) +// Create provider as a global for dependency injection +var provider = interfaces.NewProvider() + func main() { // Setup colors color.NoColor = false @@ -69,11 +69,18 @@ func main() { } func init() { + // Set version for help output + rootCmd.Version = Version + + if rootCmd.Version != "" { + logging.LogInfo("Current version :::: : %s", rootCmd.Version) + } + rootCmd.PersistentFlags().StringVarP(&configFile, "config", "f", "", "Specify configuration file path") rootCmd.AddCommand(setupSudoEnvCmd) - // "Specify configuration file path (optionally set HARDN_CONFIG as variable)") + rootCmd.PersistentFlags().StringVarP(&username, "username", "u", "", "Specify username to create") rootCmd.PersistentFlags().BoolVarP(&createUser, "create-user", "c", false, "Create non-root user with sudo access") rootCmd.PersistentFlags().BoolVarP(&disableRoot, "disable-root", "d", false, "Disable root SSH access") @@ -89,6 +96,8 @@ func init() { rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "Show version information") rootCmd.PersistentFlags().BoolVarP(&setupSudoEnv, "setup-sudo-env", "e", false, "Configure sudoers to preserve HARDN_CONFIG environment variable") + rootCmd.PersistentFlags().BoolVar(&debugUpdates, "debug-updates", false, "Enable debugging for update checks") + rootCmd.PersistentFlags().BoolVar(&testUpdateAvailable, "test-update", false, "Force update notification for testing") } var rootCmd = &cobra.Command{ @@ -96,16 +105,12 @@ var rootCmd = &cobra.Command{ Short: "Linux hardening tool", Long: `A simple hardening tool for Debian, Ubuntu, Proxmox and Alpine Linux.`, Run: func(cmd *cobra.Command, args []string) { + // Create version service + versionService := version.NewService(Version, BuildDate, GitCommit) + // Check if version flag is set and display version info if showVersion { - fmt.Println("hardn - Linux hardening tool") - fmt.Printf("Version: %s\n", Version) - if BuildDate != "" { - fmt.Printf("Build Date: %s\n", BuildDate) - } - if GitCommit != "" { - fmt.Printf("Git Commit: %s\n", GitCommit) - } + versionService.PrintVersionInfo() return } @@ -151,103 +156,235 @@ var rootCmd = &cobra.Command{ os.Exit(1) } - // If no flags provided, show menu + // Create service factory + serviceFactory := infrastructure.NewServiceFactory(provider, osInfo) + serviceFactory.SetConfig(cfg) + + // If no specific flags provided, show the interactive menu if !createUser && !disableRoot && !installLinux && !installPython && !installAll && !configureUfw && !configureDns && !runAll && - !updateSources && !printLogs { - menu.ShowMainMenu(cfg, osInfo) + !updateSources && !printLogs && !setupSudoEnv { + + // Create menu factory and main menu with version service + menuFactory := infrastructure.NewMenuFactory(serviceFactory, cfg, osInfo) + mainMenu := menuFactory.CreateMainMenu(versionService) + + if testUpdateAvailable { + // Force the update notification to appear with a hard-coded newer version + mainMenu.SetTestUpdateAvailable("99.0.0") + } + + // Show main menu with version info + mainMenu.ShowMainMenu(Version, BuildDate, GitCommit) return } - // Process command line flags + // Process command line options using the new architecture + + // Get required managers + sshManager := serviceFactory.CreateSSHManager() + firewallManager := serviceFactory.CreateFirewallManager() + dnsManager := serviceFactory.CreateDNSManager() + packageManager := serviceFactory.CreatePackageManager() + userManager := serviceFactory.CreateUserManager() + menuManager := serviceFactory.CreateMenuManager() + environmentManager := serviceFactory.CreateEnvironmentManager() + + // Handle a complete system hardening request if runAll { - runAllHardening(cfg, osInfo) + logging.LogInfo("Running complete system hardening...") + + // Create a comprehensive hardening configuration + hardeningConfig := &model.HardeningConfig{ + CreateUser: cfg.Username != "", + Username: cfg.Username, + SudoNoPassword: cfg.SudoNoPassword, + SshKeys: cfg.SshKeys, + SshPort: cfg.SshPort, + SshListenAddresses: []string{cfg.SshListenAddress}, + SshAllowedUsers: cfg.SshAllowedUsers, + EnableFirewall: cfg.EnableUfwSshPolicy, + AllowedPorts: []int{}, + FirewallProfiles: []model.FirewallProfile{}, + ConfigureDns: cfg.ConfigureDns, + Nameservers: cfg.Nameservers, + EnableAppArmor: cfg.EnableAppArmor, + EnableLynis: cfg.EnableLynis, + EnableUnattendedUpgrades: cfg.EnableUnattendedUpgrades, + } + + // Run all hardening steps + if err := menuManager.HardenSystem(hardeningConfig); err != nil { + logging.LogError("Failed to complete system hardening: %v", err) + } else { + logging.LogSuccess("System hardening completed successfully!") + fmt.Printf("Check the log file at %s for details.\n", cfg.LogFile) + } return } - // Handle individual operations - if updateSources || installLinux || installPython || installAll || createUser || runAll { - packages.WriteSources(cfg, osInfo) + // Handle individual operations based on flags + + // Update package sources + if updateSources { + if err := packageManager.UpdatePackageSources(); err != nil { + logging.LogError("Failed to update package sources: %v", err) + } else { + logging.LogSuccess("Package sources updated") + } + + // Handle Proxmox-specific sources if osInfo.OsType != "alpine" && osInfo.IsProxmox { - packages.WriteProxmoxRepos(cfg, osInfo) + if err := packageManager.UpdateProxmoxSources(); err != nil { + logging.LogError("Failed to update Proxmox sources: %v", err) + } else { + logging.LogSuccess("Proxmox sources updated") + } } } + // Disable root SSH access if disableRoot { - err := ssh.DisableRootSSHAccess(cfg, osInfo) - if err != nil { + if err := sshManager.DisableRootAccess(); err != nil { logging.LogError("Failed to disable root SSH access: %v", err) } else { - logging.LogSuccess("Disabled root SSH access") + logging.LogSuccess("Root SSH access disabled") } } - if installPython || installAll { - packages.InstallPythonPackages(cfg, osInfo) - } + // Install Linux packages + if installLinux || installAll { + logging.LogInfo("Installing Linux packages...") - if installLinux || installAll || runAll { - installLinuxPackages(cfg, osInfo) + if installAll { + // Use the enhanced method that handles all package types appropriately + if err := packageManager.InstallAllLinuxPackages(); err != nil { + logging.LogError("Failed to install Linux packages: %v", err) + } else { + logging.LogSuccess("All Linux packages installed successfully") + } + } else { + // Just install core packages when specifically requested + if osInfo.OsType == "alpine" && len(cfg.AlpineCorePackages) > 0 { + if err := packageManager.InstallLinuxPackages(cfg.AlpineCorePackages, "core"); err != nil { + logging.LogError("Failed to install Alpine core packages: %v", err) + } else { + logging.LogSuccess("Alpine core packages installed successfully") + } + } else if len(cfg.LinuxCorePackages) > 0 { + if err := packageManager.InstallLinuxPackages(cfg.LinuxCorePackages, "core"); err != nil { + logging.LogError("Failed to install Linux core packages: %v", err) + } else { + logging.LogSuccess("Linux core packages installed successfully") + } + } + } } - if createUser || runAll { - // Install sudo if needed - if osInfo.OsType == "alpine" { - if !packages.IsPackageInstalled("sudo") { - packages.InstallPackages([]string{"sudo"}, osInfo, cfg) + // Install Python packages + // Here's the update for installPython code path + if installPython || installAll { + logging.LogInfo("Installing Python packages...") + + if installAll { + // Use the enhanced method for all Python packages + if err := packageManager.InstallAllPythonPackages(cfg.UseUvPackageManager); err != nil { + logging.LogError("Failed to install Python packages: %v", err) + } else { + logging.LogSuccess("All Python packages installed successfully") } } else { - if !packages.IsPackageInstalled("sudo") { - packages.InstallPackages([]string{"sudo"}, osInfo, cfg) + // Handle specific Python package installation + if osInfo.OsType == "alpine" && len(cfg.AlpinePythonPackages) > 0 { + if err := packageManager.InstallPythonPackages( + cfg.AlpinePythonPackages, + cfg.PythonPipPackages, + cfg.UseUvPackageManager, + ); err != nil { + logging.LogError("Failed to install Alpine Python packages: %v", err) + } else { + logging.LogSuccess("Alpine Python packages installed successfully") + } + } else { + // For Debian/Ubuntu + pythonPackages := cfg.PythonPackages + // Add non-WSL packages if not in WSL + if os.Getenv("WSL") == "" && len(cfg.NonWslPythonPackages) > 0 { + pythonPackages = append(pythonPackages, cfg.NonWslPythonPackages...) + } + + if err := packageManager.InstallPythonPackages( + pythonPackages, + cfg.PythonPipPackages, + cfg.UseUvPackageManager, + ); err != nil { + logging.LogError("Failed to install Python packages: %v", err) + } else { + logging.LogSuccess("Python packages installed successfully") + } } } + } - err := user.CreateUser(cfg.Username, cfg, osInfo) - if err != nil { + // Create user + if createUser { + if err := userManager.CreateUser(cfg.Username, true, cfg.SudoNoPassword, cfg.SshKeys); err != nil { logging.LogError("Failed to create user: %v", err) + } else { + logging.LogSuccess("User '%s' created successfully", cfg.Username) } - ssh.WriteSSHConfig(cfg, osInfo) - } - - if configureUfw || runAll { - firewall.ConfigureUFW(cfg, osInfo) - } - - if configureDns || runAll { - dns.ConfigureDNS(cfg, osInfo) - } - if runAll && cfg.EnableAppArmor { - security.SetupAppArmor(cfg, osInfo) + // Configure SSH after user creation + // TODO: This might need to be refactored to avoid duplicating the SSH configuration + if err := sshManager.ConfigureSSH( + cfg.SshPort, + []string{cfg.SshListenAddress}, + cfg.PermitRootLogin, + cfg.SshAllowedUsers, + []string{cfg.SshKeyPath}, + ); err != nil { + logging.LogError("Failed to configure SSH: %v", err) + } } - if runAll && cfg.EnableLynis { - security.SetupLynis(cfg, osInfo) + // Configure firewall + if configureUfw { + if err := firewallManager.ConfigureSecureFirewall(cfg.SshPort, []int{}, []model.FirewallProfile{}); err != nil { + logging.LogError("Failed to configure firewall: %v", err) + } else { + logging.LogSuccess("Firewall configured successfully") + } } - if runAll && cfg.EnableUnattendedUpgrades { - updates.SetupUnattendedUpgrades(cfg, osInfo) + // Configure DNS + if configureDns { + if err := dnsManager.ConfigureDNS(cfg.Nameservers, "lan"); err != nil { + logging.LogError("Failed to configure DNS: %v", err) + } else { + logging.LogSuccess("DNS configured successfully") + } } + // Print logs if printLogs { logging.PrintLogs(cfg.LogFile) } - // Output completion message - if runAll { - logging.LogSuccess("Script completed all hardening operations.") - } else if createUser || disableRoot || installLinux || installPython || - installAll || configureUfw || configureDns || updateSources { - logging.LogSuccess("Script completed selected hardening operations.") - } - + // Setting up sudo environment preservation if setupSudoEnv { - if err := utils.SetupSudoEnvPreservation(); err != nil { + if err := environmentManager.SetupSudoPreservation(); err != nil { logging.LogError("Failed to configure sudoers: %v", err) os.Exit(1) } + logging.LogSuccess("Sudo environment configured to preserve HARDN_CONFIG") return } + + // Output completion message for operations other than the all-in-one run + if createUser || disableRoot || installLinux || installPython || + installAll || configureUfw || configureDns || updateSources { + logging.LogSuccess("Script completed selected hardening operations.") + } }, } @@ -264,128 +401,21 @@ This command must be run with sudo privileges. Example: sudo hardn setup-sudo-env`, Run: func(cmd *cobra.Command, args []string) { - if err := utils.SetupSudoEnvPreservation(); err != nil { - logging.LogError("Failed to configure sudoers: %v", err) - os.Exit(1) - } - }, -} - -// Run all hardening operations -func runAllHardening(cfg *config.Config, osInfo *osdetect.OSInfo) { - utils.PrintLogo() - logging.LogInfo("Running complete system hardening...") - - // Setup hushlogin - utils.SetupHushlogin(cfg) - - // Update package repositories - packages.WriteSources(cfg, osInfo) - if osInfo.OsType != "alpine" && osInfo.IsProxmox { - packages.WriteProxmoxRepos(cfg, osInfo) - } - - // Install packages - installLinuxPackages(cfg, osInfo) - - // Create user - if cfg.Username != "" { - err := user.CreateUser(cfg.Username, cfg, osInfo) + // Detect OS + osInfo, err := osdetect.DetectOS() if err != nil { - logging.LogError("Failed to create user: %v", err) - } - } - - // Configure SSH - ssh.WriteSSHConfig(cfg, osInfo) - - // Disable root SSH access if requested - if cfg.DisableRoot { - ssh.DisableRootSSHAccess(cfg, osInfo) - } - - // Configure UFW - if cfg.EnableUfwSshPolicy { - firewall.ConfigureUFW(cfg, osInfo) - } - - // Configure DNS - if cfg.ConfigureDns { - dns.ConfigureDNS(cfg, osInfo) - } - - // Setup AppArmor if enabled - if cfg.EnableAppArmor { - security.SetupAppArmor(cfg, osInfo) - } - - // Setup Lynis if enabled - if cfg.EnableLynis { - security.SetupLynis(cfg, osInfo) - } - - // Setup unattended upgrades if enabled - if cfg.EnableUnattendedUpgrades { - updates.SetupUnattendedUpgrades(cfg, osInfo) - } - - logging.LogSuccess("System hardening completed successfully!") - fmt.Printf("Check the log file at %s for details.\n", cfg.LogFile) -} - -// Install Linux packages based on OS type -func installLinuxPackages(cfg *config.Config, osInfo *osdetect.OSInfo) { - if osInfo.OsType == "alpine" { - fmt.Println("\nInstalling Alpine Linux packages...") - - // Install core Alpine packages first - if len(cfg.AlpineCorePackages) > 0 { - logging.LogInfo("Installing Alpine core packages...") - packages.InstallPackages(cfg.AlpineCorePackages, osInfo, cfg) + logging.LogError("Failed to detect OS: %v", err) + os.Exit(1) } - // Check subnet to determine which package sets to install - isDmz, _ := utils.CheckSubnet(cfg.DmzSubnet) - if isDmz { - if len(cfg.AlpineDmzPackages) > 0 { - logging.LogInfo("Installing Alpine DMZ packages...") - packages.InstallPackages(cfg.AlpineDmzPackages, osInfo, cfg) - } - } else { - // Install both - if len(cfg.AlpineDmzPackages) > 0 { - logging.LogInfo("Installing Alpine DMZ packages...") - packages.InstallPackages(cfg.AlpineDmzPackages, osInfo, cfg) - } - if len(cfg.AlpineLabPackages) > 0 { - logging.LogInfo("Installing Alpine LAB packages...") - packages.InstallPackages(cfg.AlpineLabPackages, osInfo, cfg) - } - } - } else { - // Install core Linux packages first - if len(cfg.LinuxCorePackages) > 0 { - logging.LogInfo("Installing Linux core packages...") - packages.InstallPackages(cfg.LinuxCorePackages, osInfo, cfg) - } + // Create service factory + serviceFactory := infrastructure.NewServiceFactory(provider, osInfo) + environmentManager := serviceFactory.CreateEnvironmentManager() - // Check subnet to determine which package sets to install - isDmz, _ := utils.CheckSubnet(cfg.DmzSubnet) - if isDmz { - if len(cfg.LinuxDmzPackages) > 0 { - logging.LogInfo("Installing Debian DMZ packages...") - packages.InstallPackages(cfg.LinuxDmzPackages, osInfo, cfg) - } - } else { - // Install both - if len(cfg.LinuxDmzPackages) > 0 { - logging.LogInfo("Installing Debian DMZ packages...") - packages.InstallPackages(cfg.LinuxDmzPackages, osInfo, cfg) - } - if len(cfg.LinuxLabPackages) > 0 { - logging.LogInfo("Installing Debian Lab packages...") - packages.InstallPackages(cfg.LinuxLabPackages, osInfo, cfg) - } + if err := environmentManager.SetupSudoPreservation(); err != nil { + logging.LogError("Failed to configure sudoers: %v", err) + os.Exit(1) } - } + logging.LogSuccess("Sudo environment configured to preserve HARDN_CONFIG") + }, } diff --git a/docs/configuration.md b/docs/configuration.md index b2d36a2..a0c056c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -147,7 +147,7 @@ sshAllowedUsers: # List of users allowed to access via SSH - "george" sshListenAddress: "0.0.0.0" # IP address to listen on sshKeyPath: ".ssh_%u" # Path to SSH keys (%u = username) -sshConfigFile: "/etc/ssh/sshd_config.d/manage.conf" # SSH config file location +sshConfigFile: "/etc/ssh/sshd_config.d/hardn.conf" # SSH config file location ``` **Important**: The `sshPort` setting is the single source of truth for SSH port configuration throughout the application. diff --git a/go.mod b/go.mod index 2967b35..a5f3509 100644 --- a/go.mod +++ b/go.mod @@ -1,18 +1,24 @@ module github.com/abbott/hardn -go 1.21 +go 1.23.0 + +toolchain go1.24.0 require ( github.com/fatih/color v1.18.0 github.com/spf13/cobra v1.9.1 - golang.org/x/text v0.22.0 + github.com/stretchr/testify v1.10.0 + golang.org/x/text v0.23.0 gopkg.in/yaml.v3 v3.0.1 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.6 // indirect + github.com/stretchr/objx v0.5.2 // indirect golang.org/x/sys v0.30.0 // indirect ) diff --git a/go.sum b/go.sum index f55533a..5d5a74b 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -7,16 +9,22 @@ github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHP github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= -golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/hardn.yml.example b/hardn.yml.example index 25052b5..9fbdaed 100644 --- a/hardn.yml.example +++ b/hardn.yml.example @@ -28,7 +28,7 @@ sshAllowedUsers: # List of users allowed to access via SSH - "george" sshListenAddress: "0.0.0.0" # IP address to listen on sshKeyPath: ".ssh_%u" # Path to SSH keys (use %u for username substitution) -sshConfigFile: "/etc/ssh/sshd_config.d/manage.conf" # SSH config file location +sshConfigFile: "/etc/ssh/sshd_config.d/hardn.conf" # SSH config file location ################################################# # User Configuration diff --git a/pkg/adapter/secondary/file_backup_repository.go b/pkg/adapter/secondary/file_backup_repository.go new file mode 100644 index 0000000..a75f78f --- /dev/null +++ b/pkg/adapter/secondary/file_backup_repository.go @@ -0,0 +1,259 @@ +// pkg/adapter/secondary/file_backup_repository.go +package secondary + +import ( + "fmt" + "os" + "path/filepath" + "time" + + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/interfaces" + "github.com/abbott/hardn/pkg/port/secondary" +) + +// FileBackupRepository implements BackupRepository using file operations +type FileBackupRepository struct { + fs interfaces.FileSystem + commander interfaces.Commander + config *model.BackupConfig +} + +// NewFileBackupRepository creates a new FileBackupRepository +func NewFileBackupRepository( + fs interfaces.FileSystem, + commander interfaces.Commander, + backupDir string, + enabled bool, +) secondary.BackupRepository { + return &FileBackupRepository{ + fs: fs, + commander: commander, + config: &model.BackupConfig{ + Enabled: enabled, + BackupDir: backupDir, + }, + } +} + +// BackupFile backs up a file with a timestamp +func (r *FileBackupRepository) BackupFile(filePath string) error { + if !r.config.Enabled { + return nil // Backups disabled, silently succeed + } + + // Create backup directory for today + backupDir := filepath.Join(r.config.BackupDir, time.Now().Format("2006-01-02")) + if err := r.fs.MkdirAll(backupDir, 0755); err != nil { + return fmt.Errorf("failed to create backup directory %s: %w", backupDir, err) + } + + // Get filename without path + fileName := filepath.Base(filePath) + + // Check if file exists + _, err := r.fs.Stat(filePath) + if os.IsNotExist(err) { + return nil // File doesn't exist, nothing to backup + } + + // Create backup with timestamp + backupFile := filepath.Join(backupDir, fmt.Sprintf("%s.%s.bak", fileName, time.Now().Format("150405"))) + + // Read original file + data, err := r.fs.ReadFile(filePath) + if err != nil { + return fmt.Errorf("failed to read file %s for backup: %w", filePath, err) + } + + // Write backup file + if err := r.fs.WriteFile(backupFile, data, 0644); err != nil { + return fmt.Errorf("failed to write backup file %s: %w", backupFile, err) + } + + return nil +} + +// ListBackups returns a list of all backups for a specific file +func (r *FileBackupRepository) ListBackups(filePath string) ([]model.BackupFile, error) { + var backups []model.BackupFile + + // Get filename without path + fileName := filepath.Base(filePath) + + // Walk through backup directory + if err := filepath.Walk(r.config.BackupDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return fmt.Errorf("error accessing path %s: %w", path, err) + } + + // Skip directories + if info.IsDir() { + return nil + } + + // Check if this is a backup of our file + if matched, err := filepath.Match(fmt.Sprintf("%s.*.bak", fileName), info.Name()); err != nil { + return fmt.Errorf("error matching pattern for file %s: %w", info.Name(), err) + } else if matched { + backup := model.BackupFile{ + OriginalPath: filePath, + BackupPath: path, + Created: info.ModTime(), + Size: info.Size(), + } + backups = append(backups, backup) + } + + return nil + }); err != nil { + return nil, fmt.Errorf("failed to list backups for %s: %w", filePath, err) + } + + return backups, nil +} + +// RestoreBackup restores a file from backup +func (r *FileBackupRepository) RestoreBackup(backupPath, originalPath string) error { + // Check if backup exists + fileInfo, err := r.fs.Stat(backupPath) + if os.IsNotExist(err) { + return fmt.Errorf("backup file %s does not exist", backupPath) + } + if err != nil { + return fmt.Errorf("failed to access backup file %s: %w", backupPath, err) + } + + // Make sure it's not a directory + if fileInfo.IsDir() { + return fmt.Errorf("backup path %s is a directory, not a file", backupPath) + } + + // Read backup file + data, err := r.fs.ReadFile(backupPath) + if err != nil { + return fmt.Errorf("failed to read backup file %s: %w", backupPath, err) + } + + // Create directory for restored file if needed + targetDir := filepath.Dir(originalPath) + if err := r.fs.MkdirAll(targetDir, 0755); err != nil { + return fmt.Errorf("failed to create directory %s for restored file: %w", targetDir, err) + } + + // Write restored file + if err := r.fs.WriteFile(originalPath, data, 0644); err != nil { + return fmt.Errorf("failed to write restored file %s: %w", originalPath, err) + } + + return nil +} + +// CleanupOldBackups removes backups older than specified date +func (r *FileBackupRepository) CleanupOldBackups(before time.Time) error { + // Check if backup directory exists + backupDirInfo, err := r.fs.Stat(r.config.BackupDir) + if err != nil { + if os.IsNotExist(err) { + // Backup path doesn't exist yet - nothing to clean + return nil + } + return fmt.Errorf("failed to access backup directory %s: %w", r.config.BackupDir, err) + } + + // Make sure it's a directory + if !backupDirInfo.IsDir() { + return fmt.Errorf("backup path %s is not a directory", r.config.BackupDir) + } + + // Since we don't have ReadDir in our interface, we'll use a different approach + // We'll have the repository implementation check known date directories directly + + // Get current and past dates to check (e.g., past 90 days) + var datesToCheck []string + for i := 0; i < 90; i++ { + date := time.Now().AddDate(0, 0, -i) + dateStr := date.Format("2006-01-02") + datesToCheck = append(datesToCheck, dateStr) + } + + // Check each possible date directory + for _, dateStr := range datesToCheck { + dirPath := filepath.Join(r.config.BackupDir, dateStr) + + // Check if directory exists + dirInfo, err := r.fs.Stat(dirPath) + if err != nil { + if os.IsNotExist(err) { + // Directory doesn't exist, skip + continue + } + // Other error, log and continue + fmt.Printf("Warning: Error checking directory %s: %v\n", dirPath, err) + continue + } + + // Skip if not a directory + if !dirInfo.IsDir() { + continue + } + + // Parse date from directory name + dirDate, err := time.Parse("2006-01-02", dateStr) + if err != nil { + // Should never happen since we're generating these dates + continue + } + + // If directory is older than cutoff, remove it + if dirDate.Before(before) { + if err := r.fs.RemoveAll(dirPath); err != nil { + return fmt.Errorf("failed to remove old backup directory %s: %w", dirPath, err) + } + } + } + + return nil +} + +// VerifyBackupDirectory ensures the backup directory exists and is writable +func (r *FileBackupRepository) VerifyBackupDirectory() error { + // Create backup directory if it doesn't exist + if err := r.fs.MkdirAll(r.config.BackupDir, 0755); err != nil { + return fmt.Errorf("failed to create backup directory %s: %w", r.config.BackupDir, err) + } + + // Check if directory is writable by writing a test file + testFile := filepath.Join(r.config.BackupDir, ".write_test") + if err := r.fs.WriteFile(testFile, []byte("test"), 0644); err != nil { + return fmt.Errorf("backup directory %s is not writable: %w", r.config.BackupDir, err) + } + + // Clean up test file + if err := r.fs.Remove(testFile); err != nil { + // Log warning but don't fail the operation since this is just cleanup + fmt.Printf("Warning: Failed to remove test file %s: %v\n", testFile, err) + } + return nil +} + +// GetBackupConfig retrieves the current backup configuration +func (r *FileBackupRepository) GetBackupConfig() (*model.BackupConfig, error) { + // Return a copy to prevent direct modification + config := *r.config + return &config, nil +} + +// SetBackupConfig updates the backup configuration +func (r *FileBackupRepository) SetBackupConfig(config model.BackupConfig) error { + // Update the configuration + r.config.Enabled = config.Enabled + r.config.BackupDir = config.BackupDir + + // If enabling backups, verify the directory exists and is writable + if r.config.Enabled { + return r.VerifyBackupDirectory() + } + + return nil +} diff --git a/pkg/adapter/secondary/file_dns_repository.go b/pkg/adapter/secondary/file_dns_repository.go new file mode 100644 index 0000000..382fdd1 --- /dev/null +++ b/pkg/adapter/secondary/file_dns_repository.go @@ -0,0 +1,190 @@ +// pkg/adapter/secondary/file_dns_repository.go +package secondary + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/interfaces" + "github.com/abbott/hardn/pkg/port/secondary" +) + +// FileDNSRepository implements DNSRepository using file operations +type FileDNSRepository struct { + fs interfaces.FileSystem + commander interfaces.Commander + osType string +} + +// NewFileDNSRepository creates a new FileDNSRepository +func NewFileDNSRepository( + fs interfaces.FileSystem, + commander interfaces.Commander, + osType string, +) secondary.DNSRepository { + return &FileDNSRepository{ + fs: fs, + commander: commander, + osType: osType, + } +} + +// SaveDNSConfig persists the DNS configuration +func (r *FileDNSRepository) SaveDNSConfig(config model.DNSConfig) error { + // Check if systemd-resolved is active + systemdActive := false + if _, err := r.commander.Execute("systemctl", "is-active", "systemd-resolved"); err == nil { + systemdActive = true + } + + // Check if resolvconf is installed + resolvconfInstalled := false + if _, err := r.commander.Execute("which", "resolvconf"); err == nil { + resolvconfInstalled = true + } + + if systemdActive { + return r.configureSystemdResolved(config) + } else if resolvconfInstalled { + return r.configureResolvconf(config) + } else { + return r.configureDirectResolv(config) + } +} + +// GetDNSConfig retrieves the current DNS configuration +func (r *FileDNSRepository) GetDNSConfig() (*model.DNSConfig, error) { + // Read /etc/resolv.conf to get current configuration + data, err := r.fs.ReadFile("/etc/resolv.conf") + if err != nil { + return nil, fmt.Errorf("failed to read resolv.conf: %w", err) + } + + config := model.DNSConfig{} + + // Parse file + lines := strings.Split(string(data), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + + directive := fields[0] + value := fields[1] + + switch directive { + case "nameserver": + config.Nameservers = append(config.Nameservers, value) + case "domain": + config.Domain = value + case "search": + config.Search = fields[1:] + } + } + + return &config, nil +} + +// configureSystemdResolved configures DNS using systemd-resolved +func (r *FileDNSRepository) configureSystemdResolved(config model.DNSConfig) error { + // Create resolved.conf content + var content strings.Builder + + content.WriteString("[Resolve]\n") + content.WriteString(fmt.Sprintf("DNS=%s\n", strings.Join(config.Nameservers, " "))) + + if config.Domain != "" { + content.WriteString(fmt.Sprintf("Domains=%s\n", config.Domain)) + } + + // Write resolved.conf + if err := r.fs.WriteFile("/etc/systemd/resolved.conf", []byte(content.String()), 0644); err != nil { + return fmt.Errorf("failed to write systemd-resolved config: %w", err) + } + + // Restart systemd-resolved + if _, err := r.commander.Execute("systemctl", "restart", "systemd-resolved"); err != nil { + return fmt.Errorf("failed to restart systemd-resolved: %w", err) + } + + return nil +} + +// configureResolvconf configures DNS using resolvconf +func (r *FileDNSRepository) configureResolvconf(config model.DNSConfig) error { + var content strings.Builder + + // Add domain if specified + if config.Domain != "" { + content.WriteString(fmt.Sprintf("domain %s\n", config.Domain)) + } + + // Add search domains + if len(config.Search) > 0 { + content.WriteString(fmt.Sprintf("search %s\n", strings.Join(config.Search, " "))) + } else if config.Domain != "" { + content.WriteString(fmt.Sprintf("search %s\n", config.Domain)) + } + + // Add nameservers + for _, nameserver := range config.Nameservers { + content.WriteString(fmt.Sprintf("nameserver %s\n", nameserver)) + } + + // Create resolvconf directory if it doesn't exist + resolvconfDir := "/etc/resolvconf/resolv.conf.d" + if err := r.fs.MkdirAll(resolvconfDir, 0755); err != nil { + return fmt.Errorf("failed to create resolvconf directory: %w", err) + } + + // Write head file + headPath := filepath.Join(resolvconfDir, "head") + if err := r.fs.WriteFile(headPath, []byte(content.String()), 0644); err != nil { + return fmt.Errorf("failed to write resolvconf head file: %w", err) + } + + // Update resolvconf + if _, err := r.commander.Execute("resolvconf", "-u"); err != nil { + return fmt.Errorf("failed to update resolvconf: %w", err) + } + + return nil +} + +// configureDirectResolv configures DNS by directly writing to resolv.conf +func (r *FileDNSRepository) configureDirectResolv(config model.DNSConfig) error { + var content strings.Builder + + // Add domain if specified + if config.Domain != "" { + content.WriteString(fmt.Sprintf("domain %s\n", config.Domain)) + } + + // Add search domains + if len(config.Search) > 0 { + content.WriteString(fmt.Sprintf("search %s\n", strings.Join(config.Search, " "))) + } else if config.Domain != "" { + content.WriteString(fmt.Sprintf("search %s\n", config.Domain)) + } + + // Add nameservers + for _, nameserver := range config.Nameservers { + content.WriteString(fmt.Sprintf("nameserver %s\n", nameserver)) + } + + // Write resolv.conf + if err := r.fs.WriteFile("/etc/resolv.conf", []byte(content.String()), 0644); err != nil { + return fmt.Errorf("failed to write resolv.conf: %w", err) + } + + return nil +} diff --git a/pkg/adapter/secondary/file_environment_repository.go b/pkg/adapter/secondary/file_environment_repository.go new file mode 100644 index 0000000..f634c7e --- /dev/null +++ b/pkg/adapter/secondary/file_environment_repository.go @@ -0,0 +1,150 @@ +// pkg/adapter/secondary/file_environment_repository.go +package secondary + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/interfaces" + "github.com/abbott/hardn/pkg/port/secondary" +) + +// FileEnvironmentRepository implements EnvironmentRepository using file operations +type FileEnvironmentRepository struct { + fs interfaces.FileSystem + commander interfaces.Commander +} + +// NewFileEnvironmentRepository creates a new FileEnvironmentRepository +func NewFileEnvironmentRepository( + fs interfaces.FileSystem, + commander interfaces.Commander, +) secondary.EnvironmentRepository { + return &FileEnvironmentRepository{ + fs: fs, + commander: commander, + } +} + +// SetupSudoPreservation configures sudo to preserve the HARDN_CONFIG environment variable +func (r *FileEnvironmentRepository) SetupSudoPreservation(username string) error { + // Check if username is empty + if username == "" { + return fmt.Errorf("username cannot be empty") + } + + // Ensure sudoers.d directory exists + sudoersDir := "/etc/sudoers.d" + if _, err := r.fs.Stat(sudoersDir); os.IsNotExist(err) { + return fmt.Errorf("sudoers.d directory does not exist; your system may not support sudo drop-in configurations") + } + + // Create/modify sudoers file for the user + sudoersFile := filepath.Join(sudoersDir, username) + + // Check if file already exists + var content string + fileInfo, err := r.fs.Stat(sudoersFile) + if err == nil && fileInfo != nil { + // Read existing content + data, err := r.fs.ReadFile(sudoersFile) + if err != nil { + return fmt.Errorf("failed to read existing sudoers file %s: %w", sudoersFile, err) + } + content = string(data) + + // Check if HARDN_CONFIG is already in the file + if strings.Contains(content, "env_keep += \"HARDN_CONFIG\"") { + return nil // Already configured + } + + // Append to existing content + content = strings.TrimSpace(content) + "\n" + } + + // env_keep directive + content += fmt.Sprintf("Defaults:%s env_keep += \"HARDN_CONFIG\"\n", username) + + // Create a temporary file for validation + tempDir := os.TempDir() + tempFile := filepath.Join(tempDir, "hardn_sudoers_temp") + if err := r.fs.WriteFile(tempFile, []byte(content), 0440); err != nil { + return fmt.Errorf("failed to create temporary sudoers file at %s: %w", tempFile, err) + } + + // Validate the sudoers file + _, err = r.commander.Execute("visudo", "-c", "-f", tempFile) + if err != nil { + // Clean up temp file + if err := r.fs.Remove(tempFile); err != nil { + // Log warning but don't fail the operation since this is just cleanup + fmt.Printf("Warning: Failed to remove test file %s: %v\n", tempFile, err) + } + return fmt.Errorf("invalid sudoers configuration: %w", err) + } + + // Clean up temp file + if err := r.fs.Remove(tempFile); err != nil { + // Log warning but don't fail the operation since this is just cleanup + fmt.Printf("Warning: Failed to remove test file %s: %v\n", tempFile, err) + } + + // Write the validated content to the actual sudoers file + if err := r.fs.WriteFile(sudoersFile, []byte(content), 0440); err != nil { + return fmt.Errorf("failed to write sudoers file %s: %w", sudoersFile, err) + } + + return nil +} + +// IsSudoPreservationEnabled checks if the HARDN_CONFIG environment variable is preserved in sudo +func (r *FileEnvironmentRepository) IsSudoPreservationEnabled(username string) (bool, error) { + // Check if username is empty + if username == "" { + return false, fmt.Errorf("username cannot be empty") + } + + // Check if sudoers file exists + sudoersFile := filepath.Join("/etc/sudoers.d", username) + fileInfo, err := r.fs.Stat(sudoersFile) + if err != nil || fileInfo == nil { + return false, nil // File doesn't exist, preservation not enabled + } + + // Read file content + data, err := r.fs.ReadFile(sudoersFile) + if err != nil { + return false, fmt.Errorf("failed to read sudoers file %s: %w", sudoersFile, err) + } + + // Check if HARDN_CONFIG is preserved + return strings.Contains(string(data), "env_keep += \"HARDN_CONFIG\""), nil +} + +// GetEnvironmentConfig retrieves the current environment configuration +func (r *FileEnvironmentRepository) GetEnvironmentConfig() (*model.EnvironmentConfig, error) { + config := &model.EnvironmentConfig{ + ConfigPath: os.Getenv("HARDN_CONFIG"), + PreserveSudo: false, // Will be determined below + } + + // Get username + username := os.Getenv("SUDO_USER") + if username == "" { + username = os.Getenv("USER") + } + config.Username = username + + // Check sudo preservation if username is not empty + if username != "" { + isEnabled, err := r.IsSudoPreservationEnabled(username) + if err == nil { + config.PreserveSudo = isEnabled + } + } + + return config, nil +} diff --git a/pkg/adapter/secondary/file_logs_repository.go b/pkg/adapter/secondary/file_logs_repository.go new file mode 100644 index 0000000..81a897a --- /dev/null +++ b/pkg/adapter/secondary/file_logs_repository.go @@ -0,0 +1,99 @@ +// pkg/adapter/secondary/file_logs_repository.go +package secondary + +import ( + "bufio" + "fmt" + "strings" + + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/interfaces" + "github.com/abbott/hardn/pkg/port/secondary" +) + +// FileLogsRepository implements LogsRepository using file operations +type FileLogsRepository struct { + fs interfaces.FileSystem + logFilePath string +} + +// NewFileLogsRepository creates a new FileLogsRepository +func NewFileLogsRepository( + fs interfaces.FileSystem, + logFilePath string, +) secondary.LogsRepository { + return &FileLogsRepository{ + fs: fs, + logFilePath: logFilePath, + } +} + +// GetLogs retrieves logs from the configured log file +func (r *FileLogsRepository) GetLogs() ([]model.LogEntry, error) { + // Check if log file exists + _, err := r.fs.Stat(r.logFilePath) + if err != nil { + return nil, fmt.Errorf("failed to access log file: %w", err) + } + + // Read log file + data, err := r.fs.ReadFile(r.logFilePath) + if err != nil { + return nil, fmt.Errorf("failed to read log file: %w", err) + } + + // Parse log entries + var entries []model.LogEntry + scanner := bufio.NewScanner(strings.NewReader(string(data))) + for scanner.Scan() { + line := scanner.Text() + + // Parse log line (assuming a simple format of "TIME LEVEL: MESSAGE") + parts := strings.SplitN(line, " ", 3) + if len(parts) >= 3 { + // Extract time and message parts + timeStr := parts[0] + levelStr := strings.TrimSuffix(parts[1], ":") + messageStr := parts[2] + + // Create log entry + entry := model.LogEntry{ + Time: timeStr, + Level: levelStr, + Message: messageStr, + } + + entries = append(entries, entry) + } + } + + return entries, nil +} + +// GetLogConfig retrieves the current log configuration +func (r *FileLogsRepository) GetLogConfig() (*model.LogsConfig, error) { + return &model.LogsConfig{ + LogFilePath: r.logFilePath, + }, nil +} + +// PrintLogs prints the logs to the console +func (r *FileLogsRepository) PrintLogs() error { + // Check if log file exists + _, err := r.fs.Stat(r.logFilePath) + if err != nil { + return fmt.Errorf("failed to access log file %s: %w", r.logFilePath, err) + } + + // Read log file + data, err := r.fs.ReadFile(r.logFilePath) + if err != nil { + return fmt.Errorf("failed to read log file %s: %w", r.logFilePath, err) + } + + // Print log contents + fmt.Printf("\n# Contents of %s:\n\n", r.logFilePath) + fmt.Println(string(data)) + + return nil +} diff --git a/pkg/adapter/secondary/file_ssh_repository.go b/pkg/adapter/secondary/file_ssh_repository.go new file mode 100644 index 0000000..0cdaf53 --- /dev/null +++ b/pkg/adapter/secondary/file_ssh_repository.go @@ -0,0 +1,217 @@ +// pkg/adapter/secondary/file_ssh_repository.go +package secondary + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/interfaces" + "github.com/abbott/hardn/pkg/port/secondary" +) + +// FileSSHRepository implements SSHRepository using file operations +type FileSSHRepository struct { + fs interfaces.FileSystem + commander interfaces.Commander + osType string +} + +// NewFileSSHRepository creates a new FileSSHRepository +func NewFileSSHRepository( + fs interfaces.FileSystem, + commander interfaces.Commander, + osType string, +) secondary.SSHRepository { + return &FileSSHRepository{ + fs: fs, + commander: commander, + osType: osType, + } +} + +// SaveSSHConfig writes the SSH configuration to the appropriate file +func (r *FileSSHRepository) SaveSSHConfig(config model.SSHConfig) error { + // Determine config file path based on OS type + configFile := config.ConfigFilePath + if configFile == "" { + if r.osType == "alpine" { + configFile = "/etc/ssh/sshd_config" + } else { + configFile = "/etc/ssh/sshd_config.d/hardn.conf" + } + } + + // Format SSH configuration content + var content strings.Builder + + content.WriteString("# SSH configuration managed by Hardn\n\n") + content.WriteString("Protocol 2\n") + content.WriteString("StrictModes yes\n\n") + + // Port configuration + content.WriteString(fmt.Sprintf("Port %d\n", config.Port)) + + // Listen addresses + for _, addr := range config.ListenAddresses { + content.WriteString(fmt.Sprintf("ListenAddress %s\n", addr)) + } + content.WriteString("\n") + + // Authentication methods + if len(config.AuthMethods) > 0 { + content.WriteString(fmt.Sprintf("AuthenticationMethods %s\n", strings.Join(config.AuthMethods, ","))) + } else { + content.WriteString("AuthenticationMethods publickey\n") + } + content.WriteString("PubkeyAuthentication yes\n\n") + + // Root login setting + rootLoginValue := "no" + if config.PermitRootLogin { + rootLoginValue = "yes" + } + content.WriteString(fmt.Sprintf("PermitRootLogin %s\n", rootLoginValue)) + + // Allowed users + if len(config.AllowedUsers) > 0 { + content.WriteString(fmt.Sprintf("AllowUsers %s\n", strings.Join(config.AllowedUsers, " "))) + } + content.WriteString("\n") + + // Password authentication + content.WriteString("PasswordAuthentication no\n") + content.WriteString("PermitEmptyPasswords no\n\n") + + // Authorized keys + if len(config.KeyPaths) > 0 { + for _, path := range config.KeyPaths { + content.WriteString(fmt.Sprintf("AuthorizedKeysFile %s\n", path)) + } + } else { + content.WriteString("AuthorizedKeysFile .ssh/authorized_keys\n") + } + + // Create directory if it doesn't exist + dir := filepath.Dir(configFile) + if err := r.fs.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create directory for SSH config: %w", err) + } + + // Write the configuration file + if err := r.fs.WriteFile(configFile, []byte(content.String()), 0644); err != nil { + return fmt.Errorf("failed to write SSH config file: %w", err) + } + + // Restart SSH service based on OS type + var cmd string + var args []string + + if r.osType == "alpine" { + cmd = "rc-service" + args = []string{"sshd", "restart"} + } else { + cmd = "systemctl" + args = []string{"restart", "ssh"} + } + + if _, err := r.commander.Execute(cmd, args...); err != nil { + return fmt.Errorf("failed to restart SSH service: %w", err) + } + + return nil +} + +// GetSSHConfig reads the current SSH configuration +func (r *FileSSHRepository) GetSSHConfig() (*model.SSHConfig, error) { + // Implementation to parse SSH config file and return configuration + // ... (implementation details omitted for brevity) + return &model.SSHConfig{Port: 22}, nil +} + +// DisableRootAccess disables SSH access for the root user +func (r *FileSSHRepository) DisableRootAccess() error { + // Get current config + config, err := r.GetSSHConfig() + if err != nil { + return err + } + + // Disable root login + config.PermitRootLogin = false + + // Remove 'root' from AllowedUsers + var newAllowedUsers []string + for _, user := range config.AllowedUsers { + if user != "root" { + newAllowedUsers = append(newAllowedUsers, user) + } + } + config.AllowedUsers = newAllowedUsers + + // Save the modified configuration + return r.SaveSSHConfig(*config) +} + +// AddAuthorizedKey adds an SSH public key to a user's authorized_keys +func (r *FileSSHRepository) AddAuthorizedKey(username string, publicKey string) error { + var homeDir string + var sshDir string + var authKeysFile string + + // Determine paths based on user + if username == "root" { + homeDir = "/root" + } else { + homeDir = fmt.Sprintf("/home/%s", username) + } + + sshDir = filepath.Join(homeDir, ".ssh") + authKeysFile = filepath.Join(sshDir, "authorized_keys") + + // Create .ssh directory if it doesn't exist + if err := r.fs.MkdirAll(sshDir, 0700); err != nil { + return fmt.Errorf("failed to create SSH directory for user %s: %w", username, err) + } + + // Check if authorized_keys file exists + fileInfo, err := r.fs.Stat(authKeysFile) + var content string + + if err == nil && fileInfo != nil { + // File exists, read content and append + data, err := r.fs.ReadFile(authKeysFile) + if err != nil { + return fmt.Errorf("failed to read authorized_keys file: %w", err) + } + + content = string(data) + // Check if key already exists + if strings.Contains(content, publicKey) { + return nil // Key already exists + } + + // Ensure file ends with newline + if !strings.HasSuffix(content, "\n") { + content += "\n" + } + content += publicKey + "\n" + } else { + // File doesn't exist, create new + content = publicKey + "\n" + } + + // Write the file + if err := r.fs.WriteFile(authKeysFile, []byte(content), 0600); err != nil { + return fmt.Errorf("failed to write authorized_keys file: %w", err) + } + + // Set correct ownership + chownCmd := fmt.Sprintf("chown -R %s:%s %s", username, username, sshDir) + if _, err := r.commander.Execute("sh", "-c", chownCmd); err != nil { + return fmt.Errorf("failed to set ownership on SSH directory: %w", err) + } + + return nil +} diff --git a/pkg/adapter/secondary/os_package_respository.go b/pkg/adapter/secondary/os_package_respository.go new file mode 100644 index 0000000..fa0eddc --- /dev/null +++ b/pkg/adapter/secondary/os_package_respository.go @@ -0,0 +1,332 @@ +// pkg/adapter/secondary/os_package_repository.go +package secondary + +import ( + "fmt" + "strings" + + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/interfaces" + "github.com/abbott/hardn/pkg/port/secondary" +) + +// OSPackageRepository implements PackageRepository using OS operations +type OSPackageRepository struct { + fs interfaces.FileSystem + commander interfaces.Commander + osType string + osVersion string + osCodename string + isProxmox bool + config *model.PackageSources +} + +// NewOSPackageRepository creates a new OSPackageRepository +func NewOSPackageRepository( + fs interfaces.FileSystem, + commander interfaces.Commander, + osType string, + osVersion string, + osCodename string, + isProxmox bool, + config *model.PackageSources, +) secondary.PackageRepository { + return &OSPackageRepository{ + fs: fs, + commander: commander, + osType: osType, + osVersion: osVersion, + osCodename: osCodename, + isProxmox: isProxmox, + config: config, + } +} + +// InstallPackages installs packages based on the request +func (r *OSPackageRepository) InstallPackages(request model.PackageInstallRequest) error { + if len(request.Packages) == 0 && len(request.PipPackages) == 0 { + return nil + } + + if request.IsPython { + return r.installPythonPackages(request) + } + + // Standard Linux packages installation + var args []string + + if r.osType == "alpine" { + args = append([]string{"add", "--no-cache"}, request.Packages...) + _, err := r.commander.Execute("apk", args...) + if err != nil { + return fmt.Errorf("failed to install Alpine packages: %w", err) + } + } else { + // Hold Proxmox packages if necessary + if r.isProxmox { + if err := r.holdProxmoxPackages(); err != nil { + return err + } + } + + // Update package lists + _, err := r.commander.Execute("apt-get", "update") + if err != nil { + return fmt.Errorf("failed to update package lists: %w", err) + } + + // Install packages + args = append([]string{"install", "--yes"}, request.Packages...) + _, err = r.commander.Execute("apt-get", args...) + if err != nil { + return fmt.Errorf("failed to install Debian/Ubuntu packages: %w", err) + } + + // Clean up - check errors but don't fail the entire installation for cleanup issues + if _, err := r.commander.Execute("apt-get", "autoremove", "--yes"); err != nil { + fmt.Printf("Warning: Failed to autoremove packages: %v\n", err) + } + if _, err := r.commander.Execute("apt-get", "clean"); err != nil { + fmt.Printf("Warning: Failed to clean apt cache: %v\n", err) + } + if _, err := r.commander.Execute("rm", "-rf", "/var/lib/apt/lists/*"); err != nil { + fmt.Printf("Warning: Failed to remove apt lists: %v\n", err) + } + + // Unhold Proxmox packages + if r.isProxmox { + if err := r.unholdProxmoxPackages(); err != nil { + fmt.Printf("Warning: Failed to unhold Proxmox packages: %v\n", err) + } + } + } + + return nil +} + +// installPythonPackages handles Python package installation +func (r *OSPackageRepository) installPythonPackages(request model.PackageInstallRequest) error { + if r.osType == "alpine" { + // Use Alpine's package manager for Python packages + if len(request.Packages) > 0 { + args := append([]string{"add", "--no-cache"}, request.Packages...) + _, err := r.commander.Execute("apk", args...) + if err != nil { + return fmt.Errorf("failed to install Alpine Python packages: %w", err) + } + } + } else { + // For Debian/Ubuntu systems + if len(request.Packages) > 0 { + // Install system packages first + _, err := r.commander.Execute("apt-get", "update") + if err != nil { + return fmt.Errorf("failed to update package lists for Python installation: %w", err) + } + + args := append([]string{"install", "--yes"}, request.Packages...) + _, err = r.commander.Execute("apt-get", args...) + if err != nil { + return fmt.Errorf("failed to install Python system packages: %w", err) + } + } + } + + // Handle pip/UV packages + if len(request.PipPackages) > 0 { + if request.UseUv { + // Check if UV is installed + _, err := r.commander.Execute("which", "uv") + if err != nil { + // Install UV + _, err = r.commander.Execute("pip3", "install", "uv") + if err != nil { + return fmt.Errorf("failed to install UV package manager: %w", err) + } + } + + // Install packages using UV + args := append([]string{"pip", "install"}, request.PipPackages...) + _, err = r.commander.Execute("uv", args...) + if err != nil { + return fmt.Errorf("failed to install Python pip packages with UV: %w", err) + } + } else { + // Use standard pip + args := append([]string{"install"}, request.PipPackages...) + _, err := r.commander.Execute("pip3", args...) + if err != nil { + return fmt.Errorf("failed to install Python pip packages: %w", err) + } + } + } + + return nil +} + +// UpdatePackageSources updates package sources configuration +func (r *OSPackageRepository) UpdatePackageSources(sources model.PackageSources) error { + if r.osType == "alpine" { + return r.updateAlpineSources(sources) + } + + // Debian/Ubuntu + return r.updateDebianSources(sources) +} + +// updateAlpineSources updates Alpine Linux repository configuration +func (r *OSPackageRepository) updateAlpineSources(sources model.PackageSources) error { + // Format Alpine version for repositories + versionPrefix := r.osVersion + if idx := strings.LastIndex(versionPrefix, "."); idx != -1 { + versionPrefix = versionPrefix[:idx] + } + + // Create Alpine repository file content + content := fmt.Sprintf(`# Main repositories +https://dl-cdn.alpinelinux.org/alpine/v%s/main +https://dl-cdn.alpinelinux.org/alpine/v%s/community + +# Security updates +https://dl-cdn.alpinelinux.org/alpine/v%s/main +https://dl-cdn.alpinelinux.org/alpine/v%s/community +`, versionPrefix, versionPrefix, versionPrefix, versionPrefix) + + // testing repo if enabled + if sources.AlpineTestingRepo { + content += ` +# Testing repository (use with caution) +https://dl-cdn.alpinelinux.org/alpine/edge/testing +` + } + + // Write the file + if err := r.fs.WriteFile("/etc/apk/repositories", []byte(content), 0644); err != nil { + return fmt.Errorf("failed to write Alpine repositories: %w", err) + } + + // Update package index + _, err := r.commander.Execute("apk", "update") + if err != nil { + return fmt.Errorf("failed to update Alpine package index: %w", err) + } + + return nil +} + +// updateDebianSources updates Debian/Ubuntu repository configuration +func (r *OSPackageRepository) updateDebianSources(sources model.PackageSources) error { + // Prepare content by replacing CODENAME placeholder + var content strings.Builder + for _, repo := range sources.DebianRepos { + content.WriteString(strings.ReplaceAll(repo, "CODENAME", r.osCodename)) + content.WriteString("\n") + } + + // Backup original file + backupFile := "/etc/apt/sources.list.bak" + originalData, err := r.fs.ReadFile("/etc/apt/sources.list") + if err == nil { + if err := r.fs.WriteFile(backupFile, originalData, 0644); err != nil { + fmt.Printf("Warning: Failed to create backup of sources.list: %v\n", err) + } + } + + // Write the file + if err := r.fs.WriteFile("/etc/apt/sources.list", []byte(content.String()), 0644); err != nil { + return fmt.Errorf("failed to write Debian/Ubuntu sources list: %w", err) + } + + return nil +} + +// UpdateProxmoxSources updates Proxmox-specific sources +func (r *OSPackageRepository) UpdateProxmoxSources(sources model.PackageSources) error { + if !r.isProxmox { + return nil + } + + // Create directory if it doesn't exist + if err := r.fs.MkdirAll("/etc/apt/sources.list.d", 0755); err != nil { + return fmt.Errorf("failed to create sources.list.d directory: %w", err) + } + + // Write Ceph repository + var cephContent strings.Builder + for _, repo := range sources.ProxmoxCephRepo { + cephContent.WriteString(strings.ReplaceAll(repo, "CODENAME", r.osCodename)) + cephContent.WriteString("\n") + } + + if err := r.fs.WriteFile("/etc/apt/sources.list.d/ceph.list", []byte(cephContent.String()), 0644); err != nil { + return fmt.Errorf("failed to write Proxmox Ceph repository: %w", err) + } + + // Write Enterprise repository + var enterpriseContent strings.Builder + for _, repo := range sources.ProxmoxEnterpriseRepo { + enterpriseContent.WriteString(strings.ReplaceAll(repo, "CODENAME", r.osCodename)) + enterpriseContent.WriteString("\n") + } + + if err := r.fs.WriteFile("/etc/apt/sources.list.d/pve-enterprise.list", []byte(enterpriseContent.String()), 0644); err != nil { + return fmt.Errorf("failed to write Proxmox Enterprise repository: %w", err) + } + + return nil +} + +// IsPackageInstalled checks if a package is installed +func (r *OSPackageRepository) IsPackageInstalled(packageName string) (bool, error) { + if r.osType == "alpine" { + // Alpine method + _, err := r.commander.Execute("apk", "info", "-e", packageName) + if err != nil { + return false, nil // Package not installed + } + return true, nil + } else { + // Debian/Ubuntu method + _, err := r.commander.Execute("dpkg", "-l", packageName) + if err != nil { + return false, nil // Package not installed + } + return true, nil + } +} + +// GetPackageSources retrieves the current package sources configuration +func (r *OSPackageRepository) GetPackageSources() (*model.PackageSources, error) { + // Return the injected configuration + return r.config, nil +} + +// holdProxmoxPackages holds Proxmox packages to prevent accidental removal +func (r *OSPackageRepository) holdProxmoxPackages() error { + packages := []string{"proxmox-archive-keyring", "proxmox-backup-client", "proxmox-ve", "pve-kernel"} + + for _, pkg := range packages { + _, err := r.commander.Execute("apt-mark", "hold", pkg) + if err != nil { + // Non-fatal, just log and continue + fmt.Printf("Warning: Failed to hold package %s: %v\n", pkg, err) + } + } + + return nil +} + +// unholdProxmoxPackages releases held Proxmox packages +func (r *OSPackageRepository) unholdProxmoxPackages() error { + packages := []string{"proxmox-archive-keyring", "proxmox-backup-client", "proxmox-ve", "pve-kernel"} + + for _, pkg := range packages { + _, err := r.commander.Execute("apt-mark", "unhold", pkg) + if err != nil { + // Non-fatal, just log and continue + fmt.Printf("Warning: Failed to unhold package %s: %v\n", pkg, err) + } + } + + return nil +} diff --git a/pkg/adapter/secondary/os_user_repository.go b/pkg/adapter/secondary/os_user_repository.go new file mode 100644 index 0000000..0bd26e0 --- /dev/null +++ b/pkg/adapter/secondary/os_user_repository.go @@ -0,0 +1,208 @@ +// pkg/adapter/secondary/os_user_repository.go +package secondary + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/interfaces" + "github.com/abbott/hardn/pkg/port/secondary" +) + +// OSUserRepository implements UserRepository using OS operations +type OSUserRepository struct { + fs interfaces.FileSystem + commander interfaces.Commander + osType string // e.g., "alpine", "debian", etc. +} + +// NewOSUserRepository creates a new OSUserRepository +func NewOSUserRepository( + fs interfaces.FileSystem, + commander interfaces.Commander, + osType string, +) secondary.UserRepository { + return &OSUserRepository{ + fs: fs, + commander: commander, + osType: osType, + } +} + +// UserExists checks if a user exists +func (r *OSUserRepository) UserExists(username string) (bool, error) { + _, err := r.commander.Execute("id", username) + if err != nil { + // Command failed, user probably doesn't exist + return false, nil + } + return true, nil +} + +// CreateUser creates a new system user +func (r *OSUserRepository) CreateUser(user model.User) error { + // Check if user already exists + exists, err := r.UserExists(user.Username) + if err != nil { + return fmt.Errorf("error checking user existence: %w", err) + } + if exists { + return fmt.Errorf("user %s already exists", user.Username) + } + + // Create the user based on OS type + if r.osType == "alpine" { + // Alpine user creation + _, err := r.commander.Execute("adduser", "-D", "-g", "", user.Username) + if err != nil { + return fmt.Errorf("failed to create user %s on Alpine: %w", user.Username, err) + } + + // Add to wheel group for sudo + if user.HasSudo { + _, err := r.commander.Execute("addgroup", user.Username, "wheel") + if err != nil { + return fmt.Errorf("failed to add user %s to wheel group: %w", user.Username, err) + } + } + } else { + // Debian/Ubuntu user creation + _, err := r.commander.Execute("adduser", "--disabled-password", "--gecos", "", user.Username) + if err != nil { + return fmt.Errorf("failed to create user %s on Debian/Ubuntu: %w", user.Username, err) + } + + // Add to sudo group + if user.HasSudo { + _, err := r.commander.Execute("usermod", "-aG", "sudo", user.Username) + if err != nil { + return fmt.Errorf("failed to add user %s to sudo group: %w", user.Username, err) + } + } + } + + // Set up SSH keys + for _, key := range user.SshKeys { + if err := r.AddSSHKey(user.Username, key); err != nil { + return err + } + } + + // Configure sudo if needed + if user.HasSudo { + if err := r.ConfigureSudo(user.Username, user.SudoNoPassword); err != nil { + return err + } + } + + return nil +} + +// GetUser retrieves user information +func (r *OSUserRepository) GetUser(username string) (*model.User, error) { + // Implementation... + return nil, nil +} + +// AddSSHKey adds an SSH key for a user +func (r *OSUserRepository) AddSSHKey(username, publicKey string) error { + // Common path for SSH keys + var sshDir string + var homePath string + + if r.osType == "alpine" { + homePath = fmt.Sprintf("/home/%s", username) + sshDir = filepath.Join(homePath, ".ssh") + + // Create .ssh directory if it doesn't exist + if err := r.fs.MkdirAll(sshDir, 0700); err != nil { + return fmt.Errorf("failed to create SSH directory for user %s: %w", username, err) + } + + // Create authorized_keys file if it doesn't exist + authKeysPath := filepath.Join(sshDir, "authorized_keys") + authKeysExists := false + _, err := r.fs.Stat(authKeysPath) + if err == nil { + authKeysExists = true + } + + if authKeysExists { + // Read existing keys + existingContent, err := r.fs.ReadFile(authKeysPath) + if err != nil { + return fmt.Errorf("failed to read authorized_keys: %w", err) + } + + // Append new key if not already present + if !strings.Contains(string(existingContent), publicKey) { + newContent := string(existingContent) + if !strings.HasSuffix(newContent, "\n") { + newContent += "\n" + } + newContent += publicKey + "\n" + + if err := r.fs.WriteFile(authKeysPath, []byte(newContent), 0600); err != nil { + return fmt.Errorf("failed to update authorized_keys: %w", err) + } + } + } else { + // Create new file + if err := r.fs.WriteFile(authKeysPath, []byte(publicKey+"\n"), 0600); err != nil { + return fmt.Errorf("failed to create authorized_keys: %w", err) + } + } + + // Set correct ownership + _, err = r.commander.Execute("chown", "-R", fmt.Sprintf("%s:%s", username, username), sshDir) + if err != nil { + return fmt.Errorf("failed to set ownership for SSH directory: %w", err) + } + } else { + // Debian/Ubuntu - use su to run commands as the user + _, err := r.commander.Execute("su", "-", username, "-c", "mkdir -p ~/.ssh && chmod 700 ~/.ssh") + if err != nil { + return fmt.Errorf("failed to create SSH directory for user %s: %w", username, err) + } + + // Add the key using a here-document style input + _, err = r.commander.ExecuteWithInput(publicKey+"\n", "su", "-", username, "-c", "cat >> ~/.ssh/authorized_keys") + if err != nil { + return fmt.Errorf("failed to add SSH key for user %s: %w", username, err) + } + + _, err = r.commander.Execute("su", "-", username, "-c", "chmod 600 ~/.ssh/authorized_keys") + if err != nil { + return fmt.Errorf("failed to set permissions for authorized_keys: %w", err) + } + } + + return nil +} + +// ConfigureSudo configures sudo access for a user +func (r *OSUserRepository) ConfigureSudo(username string, noPassword bool) error { + // Create sudoers directory if needed + sudoersDir := "/etc/sudoers.d" + if err := r.fs.MkdirAll(sudoersDir, 0755); err != nil { + return fmt.Errorf("failed to create sudoers directory: %w", err) + } + + // Create user sudoers file + sudoersFile := filepath.Join(sudoersDir, username) + + var sudoersContent string + if noPassword { + sudoersContent = fmt.Sprintf("%s ALL=(ALL) NOPASSWD: ALL\n", username) + } else { + sudoersContent = fmt.Sprintf("%s ALL=(ALL) ALL\n", username) + } + + if err := r.fs.WriteFile(sudoersFile, []byte(sudoersContent), 0440); err != nil { + return fmt.Errorf("failed to write sudoers file: %w", err) + } + + return nil +} diff --git a/pkg/adapter/secondary/ufw_firewall_repository.go b/pkg/adapter/secondary/ufw_firewall_repository.go new file mode 100644 index 0000000..6c479ac --- /dev/null +++ b/pkg/adapter/secondary/ufw_firewall_repository.go @@ -0,0 +1,274 @@ +// pkg/adapter/secondary/ufw_firewall_repository.go +package secondary + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/interfaces" + "github.com/abbott/hardn/pkg/port/secondary" +) + +// UFWFirewallRepository implements FirewallRepository using UFW +type UFWFirewallRepository struct { + fs interfaces.FileSystem + commander interfaces.Commander +} + +// NewUFWFirewallRepository creates a new UFWFirewallRepository +func NewUFWFirewallRepository( + fs interfaces.FileSystem, + commander interfaces.Commander, +) secondary.FirewallRepository { + return &UFWFirewallRepository{ + fs: fs, + commander: commander, + } +} + +// IsUFWInstalled checks if UFW is installed +func (r *UFWFirewallRepository) IsUFWInstalled() bool { + _, err := r.commander.Execute("which", "ufw") + return err == nil +} + +// pkg/adapter/secondary/ufw_firewall_repository.go +// Add this method to the UFWFirewallRepository struct + +// GetFirewallStatus retrieves the current status of the firewall +func (r *UFWFirewallRepository) GetFirewallStatus() (bool, bool, bool, []string, error) { + // Check if UFW is installed + _, err := r.commander.Execute("which", "ufw") + isInstalled := (err == nil) + + // Default values if not installed + isEnabled := false + isConfigured := false + var rules []string + + if isInstalled { + // Check if UFW is enabled + statusOutput, err := r.commander.Execute("ufw", "status") + if err == nil { + statusText := string(statusOutput) + isEnabled = strings.Contains(statusText, "Status: active") + + // Extract rules (skip header lines) + lines := strings.Split(statusText, "\n") + ruleSection := false + for _, line := range lines { + line = strings.TrimSpace(line) + + // Skip empty lines + if line == "" { + continue + } + + // Skip header lines + if strings.Contains(line, "Status:") || + strings.Contains(line, "Logging:") || + strings.Contains(line, "Default:") || + strings.Contains(line, "New profiles:") || + strings.Contains(line, "To Action From") { + continue + } + + // Check if we've reached the rule section + if strings.Contains(line, "--") { + ruleSection = true + continue + } + + // Add rule lines + if ruleSection && line != "" { + rules = append(rules, line) + } + } + + // Check if we have default policies configured + isConfigured = strings.Contains(statusText, "deny (incoming)") && + strings.Contains(statusText, "allow (outgoing)") + } + } + + return isInstalled, isEnabled, isConfigured, rules, nil +} + +// SaveFirewallConfig applies the specified firewall configuration +func (r *UFWFirewallRepository) SaveFirewallConfig(config model.FirewallConfig) error { + // Ensure UFW is installed + if !r.IsUFWInstalled() { + return fmt.Errorf("UFW firewall is not installed") + } + + // Set default policies + if _, err := r.commander.Execute("ufw", "default", config.DefaultIncoming, "incoming"); err != nil { + return fmt.Errorf("failed to set incoming policy: %w", err) + } + + if _, err := r.commander.Execute("ufw", "default", config.DefaultOutgoing, "outgoing"); err != nil { + return fmt.Errorf("failed to set outgoing policy: %w", err) + } + + // Reset rules (disable and enable later) + if _, err := r.commander.Execute("ufw", "disable"); err != nil { + return fmt.Errorf("failed to disable UFW: %w", err) + } + + // Reset rules + if _, err := r.commander.Execute("ufw", "reset"); err != nil { + return fmt.Errorf("failed to reset UFW rules: %w", err) + } + + // Apply application profiles + if err := r.applyAppProfiles(config.ApplicationProfiles); err != nil { + return err + } + + // Add rules + for _, rule := range config.Rules { + if err := r.AddRule(rule); err != nil { + return err + } + } + + // Enable firewall if configured + if config.Enabled { + if err := r.EnableFirewall(); err != nil { + return err + } + } + + return nil +} + +// GetFirewallConfig retrieves the current firewall configuration +func (r *UFWFirewallRepository) GetFirewallConfig() (*model.FirewallConfig, error) { + // This would parse the output of 'ufw status verbose' + // Implementation details omitted for brevity + return &model.FirewallConfig{ + Enabled: true, + DefaultIncoming: "deny", + DefaultOutgoing: "allow", + }, nil +} + +// AddRule adds a firewall rule +func (r *UFWFirewallRepository) AddRule(rule model.FirewallRule) error { + var args []string + + // Build command arguments + args = append(args, rule.Action) + + // Add port specification + portSpec := fmt.Sprintf("%d/%s", rule.Port, rule.Protocol) + args = append(args, portSpec) + + // Add source IP if specified + if rule.SourceIP != "" { + args = append(args, "from", rule.SourceIP) + } + + // Add description if specified + if rule.Description != "" { + args = append(args, "comment", rule.Description) + } + + // Execute command + if _, err := r.commander.Execute("ufw", args...); err != nil { + return fmt.Errorf("failed to add rule %s %s: %w", rule.Action, portSpec, err) + } + + return nil +} + +// RemoveRule removes a firewall rule +func (r *UFWFirewallRepository) RemoveRule(rule model.FirewallRule) error { + var args []string + + // Build command arguments + args = append(args, "delete", rule.Action) + + // Add port specification + portSpec := fmt.Sprintf("%d/%s", rule.Port, rule.Protocol) + args = append(args, portSpec) + + // Add source IP if specified + if rule.SourceIP != "" { + args = append(args, "from", rule.SourceIP) + } + + // Execute command + if _, err := r.commander.Execute("ufw", args...); err != nil { + return fmt.Errorf("failed to remove rule %s %s: %w", rule.Action, portSpec, err) + } + + return nil +} + +// AddProfile adds a firewall application profile +func (r *UFWFirewallRepository) AddProfile(profile model.FirewallProfile) error { + // Apply a single profile + return r.applyAppProfiles([]model.FirewallProfile{profile}) +} + +// applyAppProfiles applies firewall application profiles +func (r *UFWFirewallRepository) applyAppProfiles(profiles []model.FirewallProfile) error { + if len(profiles) == 0 { + return nil + } + + // Create applications directory if it doesn't exist + appsDir := "/etc/ufw/applications.d" + if err := r.fs.MkdirAll(appsDir, 0755); err != nil { + return fmt.Errorf("failed to create UFW applications directory: %w", err) + } + + // Create profile file + profilesPath := filepath.Join(appsDir, "hardn") + + var content strings.Builder + for _, profile := range profiles { + content.WriteString(fmt.Sprintf("[%s]\n", profile.Name)) + content.WriteString(fmt.Sprintf("title=%s\n", profile.Title)) + content.WriteString(fmt.Sprintf("description=%s\n", profile.Description)) + content.WriteString(fmt.Sprintf("ports=%s\n\n", strings.Join(profile.Ports, ","))) + } + + // Write profiles file + if err := r.fs.WriteFile(profilesPath, []byte(content.String()), 0644); err != nil { + return fmt.Errorf("failed to write UFW application profiles: %w", err) + } + + // Apply each profile + for _, profile := range profiles { + args := []string{"allow", "from", "any", "to", "any", "app", profile.Name} + if _, err := r.commander.Execute("ufw", args...); err != nil { + return fmt.Errorf("failed to apply profile %s: %w", profile.Name, err) + } + } + + return nil +} + +// EnableFirewall enables the firewall +func (r *UFWFirewallRepository) EnableFirewall() error { + // Use non-interactive mode + // The 'yes | ufw enable' approach is replaced with a direct command + if _, err := r.commander.Execute("sh", "-c", "yes | ufw enable"); err != nil { + return fmt.Errorf("failed to enable UFW: %w", err) + } + + return nil +} + +// DisableFirewall disables the firewall +func (r *UFWFirewallRepository) DisableFirewall() error { + if _, err := r.commander.Execute("ufw", "disable"); err != nil { + return fmt.Errorf("failed to disable UFW: %w", err) + } + + return nil +} diff --git a/pkg/application/backup_manager.go b/pkg/application/backup_manager.go new file mode 100644 index 0000000..4c54184 --- /dev/null +++ b/pkg/application/backup_manager.go @@ -0,0 +1,102 @@ +// pkg/application/backup_manager.go +package application + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/domain/service" +) + +// BackupManager is an application service for backup operations +type BackupManager struct { + backupService service.BackupService +} + +// NewBackupManager creates a new BackupManager +func NewBackupManager(backupService service.BackupService) *BackupManager { + return &BackupManager{ + backupService: backupService, + } +} + +// BackupFile creates a backup of the specified file +func (m *BackupManager) BackupFile(filePath string) error { + return m.backupService.BackupFile(filePath) +} + +// GetBackupConfig retrieves the current backup configuration +func (m *BackupManager) GetBackupConfig() (*model.BackupConfig, error) { + return m.backupService.GetBackupConfig() +} + +// ToggleBackups enables or disables backups +func (m *BackupManager) ToggleBackups() error { + config, err := m.backupService.GetBackupConfig() + if err != nil { + return fmt.Errorf("failed to get backup config: %w", err) + } + + return m.backupService.EnableBackups(!config.Enabled) +} + +// SetBackupDirectory changes the backup directory +func (m *BackupManager) SetBackupDirectory(directory string) error { + // Expand path if it starts with ~ + if len(directory) > 0 && directory[:1] == "~" { + home, err := os.UserHomeDir() + if err == nil { + directory = filepath.Join(home, directory[1:]) + } + } + + return m.backupService.SetBackupDirectory(directory) +} + +// VerifyBackupDirectory ensures the backup directory exists and is writable +func (m *BackupManager) VerifyBackupDirectory() error { + return m.backupService.VerifyBackupDirectory() +} + +// CleanupOldBackups removes backups older than the specified number of days +func (m *BackupManager) CleanupOldBackups(days int) error { + return m.backupService.CleanupOldBackups(days) +} + +// GetBackupStatus returns a simple status indicating if backups are enabled +// and the current backup directory +func (m *BackupManager) GetBackupStatus() (bool, string, error) { + config, err := m.backupService.GetBackupConfig() + if err != nil { + return false, "", fmt.Errorf("failed to get backup status: %w", err) + } + + return config.Enabled, config.BackupDir, nil +} + +// VerifyBackupPath checks if the backup path exists and is writable +func (m *BackupManager) VerifyBackupPath() (bool, error) { + config, err := m.backupService.GetBackupConfig() + if err != nil { + return false, fmt.Errorf("failed to get backup config: %w", err) + } + + // Check if directory exists + if _, err := os.Stat(config.BackupDir); os.IsNotExist(err) { + return false, nil + } + + // Check if directory is writable by trying to create a test file + testFile := filepath.Join(config.BackupDir, ".write_test") + err = os.WriteFile(testFile, []byte("test"), 0644) + if err != nil { + return false, nil + } + + // Clean up test file + os.Remove(testFile) + + return true, nil +} diff --git a/pkg/application/dns_manager.go b/pkg/application/dns_manager.go new file mode 100644 index 0000000..8b70077 --- /dev/null +++ b/pkg/application/dns_manager.go @@ -0,0 +1,44 @@ +// pkg/application/dns_manager.go +package application + +import ( + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/domain/service" +) + +// DNSManager is an application service for DNS configuration +type DNSManager struct { + dnsService service.DNSService +} + +// NewDNSManager creates a new DNSManager +func NewDNSManager(dnsService service.DNSService) *DNSManager { + return &DNSManager{ + dnsService: dnsService, + } +} + +// ConfigureDNS applies DNS configuration with the specified nameservers +func (m *DNSManager) ConfigureDNS(nameservers []string, domain string) error { + // Create DNS config + config := model.DNSConfig{ + Nameservers: nameservers, + Domain: domain, + Search: []string{domain}, + } + + return m.dnsService.ConfigureDNS(config) +} + +// ConfigureSecureDNS applies DNS configuration with secure default nameservers +func (m *DNSManager) ConfigureSecureDNS() error { + // Use Cloudflare DNS by default (secure and privacy-focused) + cloudflareNameservers := []string{"1.1.1.1", "1.0.0.1"} + + return m.ConfigureDNS(cloudflareNameservers, "lan") +} + +// GetCurrentConfig retrieves the current DNS configuration +func (m *DNSManager) GetCurrentConfig() (*model.DNSConfig, error) { + return m.dnsService.GetCurrentConfig() +} diff --git a/pkg/application/environment_manager.go b/pkg/application/environment_manager.go new file mode 100644 index 0000000..ee81ed4 --- /dev/null +++ b/pkg/application/environment_manager.go @@ -0,0 +1,62 @@ +// pkg/application/environment_manager.go +package application + +import ( + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/domain/service" +) + +// EnvironmentManager is an application service for environment variable management +type EnvironmentManager struct { + environmentService service.EnvironmentService +} + +// NewEnvironmentManager creates a new EnvironmentManager +func NewEnvironmentManager( + environmentService service.EnvironmentService, +) *EnvironmentManager { + return &EnvironmentManager{ + environmentService: environmentService, + } +} + +// SetupSudoPreservation configures sudo to preserve environment variables +func (m *EnvironmentManager) SetupSudoPreservation() error { + return m.environmentService.SetupSudoPreservation() +} + +// IsSudoPreservationEnabled checks if sudo preservation is enabled +func (m *EnvironmentManager) IsSudoPreservationEnabled() (bool, error) { + return m.environmentService.IsSudoPreservationEnabled() +} + +// GetEnvironmentConfig retrieves the current environment configuration +func (m *EnvironmentManager) GetEnvironmentConfig() (*model.EnvironmentConfig, error) { + return m.environmentService.GetEnvironmentConfig() +} + +// GetConfigPath returns the path to the configuration file +func (m *EnvironmentManager) GetConfigPath() (string, error) { + config, err := m.environmentService.GetEnvironmentConfig() + if err != nil { + return "", err + } + + return config.ConfigPath, nil +} + +// IsEnvironmentVariableSet checks if a specific environment variable is set +func (m *EnvironmentManager) IsEnvironmentVariableSet(name string) (bool, string) { + value, exists := "", false + + // Currently only HARDN_CONFIG is supported + if name == "HARDN_CONFIG" { + config, err := m.environmentService.GetEnvironmentConfig() + if err == nil && config.ConfigPath != "" { + exists = true + value = config.ConfigPath + } + } + + return exists, value +} diff --git a/pkg/application/firewall_manager.go b/pkg/application/firewall_manager.go new file mode 100644 index 0000000..ba1c77e --- /dev/null +++ b/pkg/application/firewall_manager.go @@ -0,0 +1,103 @@ +// pkg/application/firewall_manager.go +package application + +import ( + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/domain/service" +) + +// FirewallManager is an application service for firewall configuration +type FirewallManager struct { + firewallService service.FirewallService +} + +// NewFirewallManager creates a new FirewallManager +func NewFirewallManager(firewallService service.FirewallService) *FirewallManager { + return &FirewallManager{ + firewallService: firewallService, + } +} + +// ConfigureFirewall applies a complete firewall configuration +func (m *FirewallManager) ConfigureFirewall( + defaultIncoming string, + defaultOutgoing string, + rules []model.FirewallRule, + profiles []model.FirewallProfile, +) error { + config := model.FirewallConfig{ + Enabled: true, + DefaultIncoming: defaultIncoming, + DefaultOutgoing: defaultOutgoing, + Rules: rules, + ApplicationProfiles: profiles, + } + + return m.firewallService.ConfigureFirewall(config) +} + +// ConfigureSecureFirewall sets up a firewall with secure defaults +func (m *FirewallManager) ConfigureSecureFirewall(sshPort int, allowedPorts []int, profiles []model.FirewallProfile) error { + // Create default SSH rule + sshRule := model.FirewallRule{ + Action: "allow", + Protocol: "tcp", + Port: sshPort, + SourceIP: "", + Description: "SSH access", + } + + // Create additional rules for allowed ports + var rules []model.FirewallRule + rules = append(rules, sshRule) + + for _, port := range allowedPorts { + rule := model.FirewallRule{ + Action: "allow", + Protocol: "tcp", + Port: port, + SourceIP: "", + Description: "Custom allowed port", + } + rules = append(rules, rule) + } + + // Create default configuration + config := model.FirewallConfig{ + Enabled: true, + DefaultIncoming: "deny", + DefaultOutgoing: "allow", + Rules: rules, + ApplicationProfiles: profiles, // Use the profiles parameter here + } + + return m.firewallService.ConfigureFirewall(config) +} + +// AddSSHRule adds a rule to allow SSH access +func (m *FirewallManager) AddSSHRule(port int) error { + rule := model.FirewallRule{ + Action: "allow", + Protocol: "tcp", + Port: port, + SourceIP: "", + Description: "SSH access", + } + + return m.firewallService.AddRule(rule) +} + +// EnableFirewall enables the firewall +func (m *FirewallManager) EnableFirewall() error { + return m.firewallService.EnableFirewall() +} + +// DisableFirewall disables the firewall +func (m *FirewallManager) DisableFirewall() error { + return m.firewallService.DisableFirewall() +} + +// GetFirewallStatus retrieves the current status of the firewall +func (m *FirewallManager) GetFirewallStatus() (bool, bool, bool, []string, error) { + return m.firewallService.GetFirewallStatus() +} diff --git a/pkg/application/log_manager.go b/pkg/application/log_manager.go new file mode 100644 index 0000000..21a937a --- /dev/null +++ b/pkg/application/log_manager.go @@ -0,0 +1,34 @@ +// pkg/application/logs_manager.go +package application + +import ( + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/domain/service" +) + +// LogsManager is an application service for log operations +type LogsManager struct { + logsService service.LogsService +} + +// NewLogsManager creates a new LogsManager +func NewLogsManager(logsService service.LogsService) *LogsManager { + return &LogsManager{ + logsService: logsService, + } +} + +// GetLogs retrieves logs from the system +func (m *LogsManager) GetLogs() ([]model.LogEntry, error) { + return m.logsService.GetLogs() +} + +// GetLogConfig retrieves the current log configuration +func (m *LogsManager) GetLogConfig() (*model.LogsConfig, error) { + return m.logsService.GetLogConfig() +} + +// PrintLogs prints the logs to the console +func (m *LogsManager) PrintLogs() error { + return m.logsService.PrintLogs() +} diff --git a/pkg/application/menu_manager.go b/pkg/application/menu_manager.go new file mode 100644 index 0000000..0db72ac --- /dev/null +++ b/pkg/application/menu_manager.go @@ -0,0 +1,175 @@ +// pkg/application/menu_manager.go +package application + +import ( + "fmt" + + "github.com/abbott/hardn/pkg/domain/model" +) + +// MenuManager orchestrates menu-related operations +type MenuManager struct { + userManager *UserManager + sshManager *SSHManager + firewallManager *FirewallManager + dnsManager *DNSManager + packageManager *PackageManager + backupManager *BackupManager + securityManager *SecurityManager + environmentManager *EnvironmentManager + logsManager *LogsManager +} + +// In the struct definition: +func NewMenuManager( + userManager *UserManager, + sshManager *SSHManager, + firewallManager *FirewallManager, + dnsManager *DNSManager, + packageManager *PackageManager, + backupManager *BackupManager, + securityManager *SecurityManager, + environmentManager *EnvironmentManager, + logsManager *LogsManager, +) *MenuManager { + return &MenuManager{ + userManager: userManager, + sshManager: sshManager, + firewallManager: firewallManager, + dnsManager: dnsManager, + packageManager: packageManager, + backupManager: backupManager, + securityManager: securityManager, + environmentManager: environmentManager, + logsManager: logsManager, + } +} + +// Methods for handling menu operations +// func (m *MenuManager) CreateUser(username string, hasSudo bool, sshKeys []string) error { +// return m.userManager.CreateUser(username, hasSudo, true, sshKeys) +// } + +// CreateUser creates a user with the specified settings +func (m *MenuManager) CreateUser(username string, hasSudo bool, sudoNoPassword bool, sshKeys []string) error { + // Create the user + err := m.userManager.CreateUser(username, hasSudo, sudoNoPassword, sshKeys) + if err != nil { + return err + } + + // If SSH keys are provided, ensure they're added + for _, key := range sshKeys { + if err := m.sshManager.AddSSHKey(username, key); err != nil { + return fmt.Errorf("error adding SSH key: %w", err) + } + } + + return nil +} + +// AddSSHKey adds an SSH key for the specified user +func (m *MenuManager) AddSSHKey(username, publicKey string) error { + return m.sshManager.AddSSHKey(username, publicKey) +} + +// DisableRootSsh disables SSH access for the root user +func (m *MenuManager) DisableRootSsh() error { + return m.sshManager.DisableRootAccess() +} + +// HardenSystem applies comprehensive system hardening +func (m *MenuManager) HardenSystem(config *model.HardeningConfig) error { + return m.securityManager.HardenSystem(config) +} + +// ConfigureDNS configures DNS with the specified nameservers +func (m *MenuManager) ConfigureDNS(nameservers []string, domain string) error { + return m.dnsManager.ConfigureDNS(nameservers, domain) +} + +// ConfigureSecureFirewall configures the firewall with secure settings +func (m *MenuManager) ConfigureSecureFirewall(sshPort int, allowedPorts []int, profiles []model.FirewallProfile) error { + return m.firewallManager.ConfigureSecureFirewall(sshPort, allowedPorts, profiles) +} + +// InstallLinuxPackages installs Linux packages based on the specified type +func (m *MenuManager) InstallLinuxPackages(packages []string, packageType string) error { + return m.packageManager.InstallLinuxPackages(packages, packageType) +} + +// InstallPythonPackages installs Python packages +func (m *MenuManager) InstallPythonPackages(systemPackages []string, pipPackages []string, useUv bool) error { + return m.packageManager.InstallPythonPackages(systemPackages, pipPackages, useUv) +} + +// UpdatePackageSources updates package sources configuration +func (m *MenuManager) UpdatePackageSources() error { + return m.packageManager.UpdatePackageSources() +} + +// UpdateProxmoxSources updates Proxmox-specific package sources +func (m *MenuManager) UpdateProxmoxSources() error { + return m.packageManager.UpdateProxmoxSources() +} + +// GetFirewallStatus retrieves the current status of the firewall +func (m *MenuManager) GetFirewallStatus() (bool, bool, bool, []string, error) { + return m.firewallManager.GetFirewallStatus() +} + +// GetBackupStatus returns the backup status and directory +func (m *MenuManager) GetBackupStatus() (bool, string, error) { + return m.backupManager.GetBackupStatus() +} + +// VerifyBackupPath checks if the backup path exists and is writable +func (m *MenuManager) VerifyBackupPath() (bool, error) { + return m.backupManager.VerifyBackupPath() +} + +// ToggleBackups enables or disables backups +func (m *MenuManager) ToggleBackups() error { + return m.backupManager.ToggleBackups() +} + +// SetBackupDirectory changes the backup directory +func (m *MenuManager) SetBackupDirectory(directory string) error { + return m.backupManager.SetBackupDirectory(directory) +} + +// VerifyBackupDirectory ensures the backup directory exists and is writable +func (m *MenuManager) VerifyBackupDirectory() error { + return m.backupManager.VerifyBackupDirectory() +} + +// Add these methods to pkg/application/menu_manager.go + +// Add these fields and methods to MenuManager + +// Replace the existing methods with these: + +// SetupSudoPreservation configures sudo to preserve the HARDN_CONFIG environment variable +func (m *MenuManager) SetupSudoPreservation() error { + return m.environmentManager.SetupSudoPreservation() +} + +// IsSudoPreservationEnabled checks if sudo is configured to preserve the HARDN_CONFIG environment variable +func (m *MenuManager) IsSudoPreservationEnabled() (bool, error) { + return m.environmentManager.IsSudoPreservationEnabled() +} + +// GetEnvironmentConfig retrieves the current environment configuration +func (m *MenuManager) GetEnvironmentConfig() (*model.EnvironmentConfig, error) { + return m.environmentManager.GetEnvironmentConfig() +} + +// PrintLogs prints the log file content to the console +func (m *MenuManager) PrintLogs() error { + return m.logsManager.PrintLogs() +} + +// GetLogConfig retrieves the current log configuration +func (m *MenuManager) GetLogConfig() (*model.LogsConfig, error) { + return m.logsManager.GetLogConfig() +} diff --git a/pkg/application/package_manager.go b/pkg/application/package_manager.go new file mode 100644 index 0000000..87c6633 --- /dev/null +++ b/pkg/application/package_manager.go @@ -0,0 +1,158 @@ +// pkg/application/package_manager.go +package application + +import ( + "os" + + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/domain/service" + "github.com/abbott/hardn/pkg/interfaces" +) + +// PackageManager is an application service for package management +// PackageManager is an application service for package management +type PackageManager struct { + packageService service.PackageService + config *model.PackageSources + osInfo *model.OSInfo + networkOps interfaces.NetworkOperations + dmzSubnet string +} + +// NewPackageManager creates a new PackageManager +func NewPackageManager( + packageService service.PackageService, + config *model.PackageSources, + osInfo *model.OSInfo, + networkOps interfaces.NetworkOperations, + dmzSubnet string, +) *PackageManager { + return &PackageManager{ + packageService: packageService, + config: config, + osInfo: osInfo, + networkOps: networkOps, + dmzSubnet: dmzSubnet, + } +} + +// InstallLinuxPackages installs system packages based on the specified type +func (m *PackageManager) InstallLinuxPackages(packages []string, packageType string) error { + // Create a package installation request + request := model.PackageInstallRequest{ + Packages: packages, + PackageType: packageType, + IsPython: false, + } + + // Call the domain service + return m.packageService.InstallPackages(request) +} + +// InstallPythonPackages installs Python packages +func (m *PackageManager) InstallPythonPackages( + systemPackages []string, + pipPackages []string, + useUv bool, +) error { + // Create a Python package installation request + request := model.PackageInstallRequest{ + Packages: systemPackages, + PipPackages: pipPackages, + UseUv: useUv, + IsPython: true, + } + + // Call the domain service + return m.packageService.InstallPackages(request) +} + +// UpdatePackageSources updates package sources configuration +func (m *PackageManager) UpdatePackageSources() error { + return m.packageService.UpdatePackageSources() +} + +// UpdateProxmoxSources updates Proxmox-specific package sources configuration +func (m *PackageManager) UpdateProxmoxSources() error { + return m.packageService.UpdateProxmoxSources() +} + +// InstallAllLinuxPackages installs all appropriate packages based on OS type and environment +func (m *PackageManager) InstallAllLinuxPackages() error { + // Check if we're in a DMZ subnet + isDMZ, _ := m.networkOps.CheckSubnet(m.dmzSubnet) + + var corePackages, dmzPackages, labPackages []string + + // Determine packages based on OS type + if m.osInfo.Type == "alpine" { + // Get Alpine packages from the configuration + if m.config != nil { + corePackages = m.config.AlpineCorePackages + dmzPackages = m.config.AlpineDmzPackages + labPackages = m.config.AlpineLabPackages + } + } else { + // Get Debian/Ubuntu packages from the configuration + if m.config != nil { + corePackages = m.config.DebianCorePackages + dmzPackages = m.config.DebianDmzPackages + labPackages = m.config.DebianLabPackages + } + } + + // Install core packages + if len(corePackages) > 0 { + if err := m.InstallLinuxPackages(corePackages, "core"); err != nil { + return err + } + } + + // Install DMZ packages + if len(dmzPackages) > 0 { + if err := m.InstallLinuxPackages(dmzPackages, "dmz"); err != nil { + return err + } + } + + // Install lab packages if not in DMZ + if !isDMZ && len(labPackages) > 0 { + if err := m.InstallLinuxPackages(labPackages, "lab"); err != nil { + return err + } + } + + return nil +} + +// InstallAllPythonPackages installs all appropriate Python packages based on OS type +func (m *PackageManager) InstallAllPythonPackages(useUv bool) error { + var systemPackages []string + var pipPackages []string + + if m.osInfo.Type == "alpine" { + // Get Alpine Python packages + if m.config != nil { + systemPackages = m.config.AlpinePythonPackages + } + } else { + // Get Debian/Ubuntu Python packages + if m.config != nil { + systemPackages = m.config.DebianPythonPackages + + // Add non-WSL packages if not in WSL + if os.Getenv("WSL") == "" && len(m.config.NonWslPythonPackages) > 0 { + systemPackages = append(systemPackages, m.config.NonWslPythonPackages...) + } + + pipPackages = m.config.PythonPipPackages + } + } + + // Install Python packages + if len(systemPackages) > 0 || len(pipPackages) > 0 { + return m.InstallPythonPackages(systemPackages, pipPackages, useUv) + } + + return nil +} diff --git a/pkg/application/security_manager.go b/pkg/application/security_manager.go new file mode 100644 index 0000000..ad97368 --- /dev/null +++ b/pkg/application/security_manager.go @@ -0,0 +1,78 @@ +// pkg/application/security_manager.go +package application + +import ( + "github.com/abbott/hardn/pkg/domain/model" +) + +// SecurityManager provides high-level security operations combining multiple services +type SecurityManager struct { + userManager *UserManager + sshManager *SSHManager + firewallManager *FirewallManager + dnsManager *DNSManager +} + +// NewSecurityManager creates a new SecurityManager +func NewSecurityManager( + userManager *UserManager, + sshManager *SSHManager, + firewallManager *FirewallManager, + dnsManager *DNSManager, +) *SecurityManager { + return &SecurityManager{ + userManager: userManager, + sshManager: sshManager, + firewallManager: firewallManager, + dnsManager: dnsManager, + } +} + +// HardenSystem applies comprehensive system hardening +func (m *SecurityManager) HardenSystem(config *model.HardeningConfig) error { + // Create non-root user if requested + if config.CreateUser && config.Username != "" { + if err := m.userManager.CreateUser( + config.Username, + true, + config.SudoNoPassword, + config.SshKeys, + ); err != nil { + return err + } + } + + // Configure SSH with secure settings + if err := m.sshManager.ConfigureSSH( + config.SshPort, + config.SshListenAddresses, + false, // Never allow root login + config.SshAllowedUsers, + config.SshKeyPaths, + ); err != nil { + return err + } + + // Configure firewall + if config.EnableFirewall { + if err := m.firewallManager.ConfigureSecureFirewall( + config.SshPort, + config.AllowedPorts, + config.FirewallProfiles, + ); err != nil { + return err + } + } + + // Configure DNS if enabled + if config.ConfigureDns { + if err := m.dnsManager.ConfigureDNS( + config.Nameservers, + "lan", + ); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/application/ssh_manager.go b/pkg/application/ssh_manager.go new file mode 100644 index 0000000..b2c6eae --- /dev/null +++ b/pkg/application/ssh_manager.go @@ -0,0 +1,67 @@ +// pkg/application/ssh_manager.go +package application + +import ( + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/domain/service" +) + +// SSHManager is an application service for SSH configuration +type SSHManager struct { + sshService service.SSHService +} + +// NewSSHManager creates a new SSHManager +func NewSSHManager(sshService service.SSHService) *SSHManager { + return &SSHManager{ + sshService: sshService, + } +} + +// ConfigureSSH applies SSH configuration with the specified settings +func (m *SSHManager) ConfigureSSH( + port int, + listenAddresses []string, + permitRootLogin bool, + allowedUsers []string, + keyPaths []string, +) error { + // Create SSH config object + config := model.SSHConfig{ + Port: port, + ListenAddresses: listenAddresses, + PermitRootLogin: permitRootLogin, + AllowedUsers: allowedUsers, + KeyPaths: keyPaths, + AuthMethods: []string{"publickey"}, + } + + // Call domain service + return m.sshService.ConfigureSSH(config) +} + +// SecureSSH applies recommended security settings to SSH +func (m *SSHManager) SecureSSH(port int, allowedUsers []string) error { + // Create SSH config with secure defaults + config := model.SSHConfig{ + Port: port, + ListenAddresses: []string{"0.0.0.0"}, + PermitRootLogin: true, + AllowedUsers: allowedUsers, + AuthMethods: []string{"publickey"}, + KeyPaths: []string{".ssh/authorized_keys"}, + } + + // Apply the configuration + return m.sshService.ConfigureSSH(config) +} + +// DisableRootAccess disables SSH access for the root user +func (m *SSHManager) DisableRootAccess() error { + return m.sshService.DisableRootAccess() +} + +// AddSSHKey adds an SSH public key for a user +func (m *SSHManager) AddSSHKey(username string, publicKey string) error { + return m.sshService.AddAuthorizedKey(username, publicKey) +} diff --git a/pkg/application/user_manager.go b/pkg/application/user_manager.go new file mode 100644 index 0000000..d47d859 --- /dev/null +++ b/pkg/application/user_manager.go @@ -0,0 +1,36 @@ +// pkg/application/user_manager.go +package application + +import ( + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/domain/service" +) + +// UserManager is an application service for user management +type UserManager struct { + userService service.UserService +} + +// NewUserManager creates a new UserManager +func NewUserManager(userService service.UserService) *UserManager { + return &UserManager{ + userService: userService, + } +} + +// CreateUser creates a new system user with the specified settings +func (m *UserManager) CreateUser(username string, hasSudo bool, sudoNoPassword bool, sshKeys []string) error { + user := model.User{ + Username: username, + HasSudo: hasSudo, + SudoNoPassword: sudoNoPassword, + SshKeys: sshKeys, + } + + return m.userService.CreateUser(user) +} + +// AddSSHKey adds an SSH key to an existing user +func (m *UserManager) AddSSHKey(username string, publicKey string) error { + return m.userService.AddSSHKey(username, publicKey) +} diff --git a/pkg/backup/backup.go b/pkg/backup/backup.go deleted file mode 100644 index de64124..0000000 --- a/pkg/backup/backup.go +++ /dev/null @@ -1,198 +0,0 @@ -package backup - -import ( - "fmt" - "os" - "path/filepath" - "time" - - "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/logging" -) - -// BackupFile backs up a file with a timestamp -func BackupFile(filePath string, cfg *config.Config) error { - if !cfg.EnableBackups { - logging.LogInfo("Backups disabled. Skipping backup of %s", filePath) - return nil - } - - if cfg.DryRun { - logging.LogInfo("[DRY-RUN] Backup %s to %s", filePath, cfg.BackupPath) - return nil - } - - // Create backup directory with date - backupDir := filepath.Join(cfg.BackupPath, time.Now().Format("2006-01-02")) - if err := os.MkdirAll(backupDir, 0755); err != nil { - return fmt.Errorf("failed to create backup directory %s: %w", backupDir, err) - } - - // Get filename without path - fileName := filepath.Base(filePath) - - // Check if file exists - if _, err := os.Stat(filePath); os.IsNotExist(err) { - logging.LogInfo("File %s does not exist, no backup needed", filePath) - return nil - } - - // Create backup with timestamp - backupFile := filepath.Join(backupDir, fmt.Sprintf("%s.%s.bak", fileName, time.Now().Format("150405"))) - - // Read original file - data, err := os.ReadFile(filePath) - if err != nil { - return fmt.Errorf("failed to read file %s for backup: %w", filePath, err) - } - - // Write backup file - if err := os.WriteFile(backupFile, data, 0644); err != nil { - return fmt.Errorf("failed to write backup file %s: %w", backupFile, err) - } - - logging.LogInfo("Backed up %s to %s", filePath, backupFile) - return nil -} - -// ListBackups returns a list of all backups for a specific file -func ListBackups(filePath string, cfg *config.Config) ([]string, error) { - // Get filename without path - fileName := filepath.Base(filePath) - var backups []string - - // Walk through backup directories - err := filepath.Walk(cfg.BackupPath, func(path string, info os.FileInfo, err error) error { - if err != nil { - return fmt.Errorf("error accessing path %s: %w", path, err) - } - - // Skip directories - if info.IsDir() { - return nil - } - - // Check if this is a backup of our file - if matched, err := filepath.Match(fmt.Sprintf("%s.*.bak", fileName), info.Name()); err != nil { - return fmt.Errorf("error matching pattern for file %s: %w", info.Name(), err) - } else if matched { - backups = append(backups, path) - } - - return nil - }) - - if err != nil { - return nil, fmt.Errorf("failed to list backups for %s: %w", filePath, err) - } - - return backups, nil -} - -// RestoreBackup restores a file from backup -func RestoreBackup(backupPath, originalPath string, cfg *config.Config) error { - if cfg.DryRun { - logging.LogInfo("[DRY-RUN] Restore backup %s to %s", backupPath, originalPath) - return nil - } - - // Check if backup exists - if _, err := os.Stat(backupPath); os.IsNotExist(err) { - return fmt.Errorf("backup file %s does not exist: %w", backupPath, err) - } - - // Read backup file - data, err := os.ReadFile(backupPath) - if err != nil { - return fmt.Errorf("failed to read backup file %s: %w", backupPath, err) - } - - // Make sure target directory exists - targetDir := filepath.Dir(originalPath) - if err := os.MkdirAll(targetDir, 0755); err != nil { - return fmt.Errorf("failed to create directory %s for restored file: %w", targetDir, err) - } - - // Write restored file - if err := os.WriteFile(originalPath, data, 0644); err != nil { - return fmt.Errorf("failed to write restored file %s: %w", originalPath, err) - } - - logging.LogSuccess("Restored %s from backup %s", originalPath, backupPath) - return nil -} - -// CleanupOldBackups removes backups older than specified days -func CleanupOldBackups(cfg *config.Config, daysToKeep int) error { - if cfg.DryRun { - logging.LogInfo("[DRY-RUN] Clean up backups older than %d days in %s", daysToKeep, cfg.BackupPath) - return nil - } - - // Calculate cutoff time - cutoff := time.Now().AddDate(0, 0, -daysToKeep) - - // Get all date-based directories in backup path - entries, err := os.ReadDir(cfg.BackupPath) - if err != nil { - if os.IsNotExist(err) { - // Backup path doesn't exist yet - nothing to clean - return nil - } - return fmt.Errorf("failed to read backup directory %s: %w", cfg.BackupPath, err) - } - - // Check each directory - for _, entry := range entries { - if !entry.IsDir() { - continue - } - - // Check if directory name is a date format - dirDate, err := time.Parse("2006-01-02", entry.Name()) - if err != nil { - // Not a date directory, skip - continue - } - - // If directory is older than cutoff, remove it - if dirDate.Before(cutoff) { - dirPath := filepath.Join(cfg.BackupPath, entry.Name()) - if err := os.RemoveAll(dirPath); err != nil { - logging.LogError("Failed to remove old backup directory %s: %v", dirPath, err) - } else { - logging.LogInfo("Removed old backup directory %s", dirPath) - } - } - } - - logging.LogSuccess("Backup cleanup completed") - return nil -} - -// VerifyBackupDirectory ensures the backup directory exists and is writable -func VerifyBackupDirectory(cfg *config.Config) error { - if cfg.DryRun { - logging.LogInfo("[DRY-RUN] Verify backup directory %s exists and is writable", cfg.BackupPath) - return nil - } - - // Create backup directory if it doesn't exist - if err := os.MkdirAll(cfg.BackupPath, 0755); err != nil { - return fmt.Errorf("failed to create backup directory %s: %w", cfg.BackupPath, err) - } - - // Check if directory is writable by writing a test file - testFile := filepath.Join(cfg.BackupPath, ".write_test") - if err := os.WriteFile(testFile, []byte("test"), 0644); err != nil { - return fmt.Errorf("backup directory %s is not writable: %w", cfg.BackupPath, err) - } - - // Clean up test file - if err := os.Remove(testFile); err != nil { - logging.LogError("Failed to remove test file %s: %v", testFile, err) - } - - logging.LogInfo("Backup directory %s verified", cfg.BackupPath) - return nil -} \ No newline at end of file diff --git a/pkg/config/config.go b/pkg/config/config.go index 85feee3..8db64ee 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -88,6 +88,11 @@ type Config struct { LcAll string `yaml:"lcAll"` Tz string `yaml:"tz"` PythonUnbuffered string `yaml:"pythonUnbuffered"` + + // Logs Configuration (embedded for easy access to LogFile) + LogsConfig struct { + LogFilePath string + } } // Default configuration @@ -110,7 +115,7 @@ func DefaultConfig() *Config { // SshAllowedUsers: []string{"george"}, SshListenAddress: "0.0.0.0", SshKeyPath: ".ssh_%u", - SshConfigFile: "/etc/ssh/sshd_config.d/manage.conf", + SshConfigFile: "/etc/ssh/sshd_config.d/hardn.conf", // User Configuration SudoNoPassword: true, @@ -192,9 +197,9 @@ func ConfigFileSearchPath(explicitPath string) []string { func FindConfigFile(explicitPath string) (string, bool) { // Log environment variable for debugging envPath := os.Getenv("HARDN_CONFIG") - if envPath != "" { - logging.LogInfo("HARDN_CONFIG environment variable is set to: %s", envPath) - } + // if envPath != "" { + // logging.LogInfo("HARDN_CONFIG environment variable is set to: %s", envPath) + // } // First priority: explicit path from command line if explicitPath != "" { @@ -327,7 +332,15 @@ func LoadConfig(filePath string) (*Config, error) { fmt.Println() } - return LoadConfigWithEnvPriority(filePath) + cfg, err := LoadConfigWithEnvPriority(filePath) + if err != nil { + return nil, err + } + + // Initialize LogsConfig + cfg.LogsConfig.LogFilePath = cfg.LogFile + + return cfg, nil } // GetDefaultConfigLocation returns the appropriate location for a new config file @@ -456,4 +469,4 @@ func SaveConfig(config *Config, filePath string) error { func isInteractive() bool { fileInfo, _ := os.Stdin.Stat() return (fileInfo.Mode() & os.ModeCharDevice) != 0 -} \ No newline at end of file +} diff --git a/pkg/config/embed.go b/pkg/config/embed.go index b8ea593..26937d3 100644 --- a/pkg/config/embed.go +++ b/pkg/config/embed.go @@ -40,7 +40,7 @@ sshAllowedUsers: # List of users allowed to access via SSH - "george" sshListenAddress: "0.0.0.0" # IP address to listen on sshKeyPath: ".ssh_%u" # Path to SSH keys (use %u for username substitution) -sshConfigFile: "/etc/ssh/sshd_config.d/manage.conf" # SSH config file location +sshConfigFile: "/etc/ssh/sshd_config.d/hardn.conf" # SSH config file location ################################################# # User Configuration diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go deleted file mode 100644 index 92b173c..0000000 --- a/pkg/dns/dns.go +++ /dev/null @@ -1,116 +0,0 @@ -package dns - -import ( - "fmt" - "os" - "os/exec" - "strings" - - "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/logging" - "github.com/abbott/hardn/pkg/osdetect" - "github.com/abbott/hardn/pkg/utils" -) - -// ConfigureDNS configures DNS settings based on the configuration -func ConfigureDNS(cfg *config.Config, osInfo *osdetect.OSInfo) error { - if cfg.DryRun { - logging.LogInfo("[DRY-RUN] Configure DNS with the following settings:") - logging.LogInfo("[DRY-RUN] - Domain: lan") - logging.LogInfo("[DRY-RUN] - Search: lan") - logging.LogInfo("[DRY-RUN] - Primary nameserver: %s", cfg.Nameservers[0]) - if len(cfg.Nameservers) > 1 { - logging.LogInfo("[DRY-RUN] - Secondary nameserver: %s", cfg.Nameservers[1]) - } - - // Check systemd-resolved - cmd := exec.Command("systemctl", "is-active", "systemd-resolved") - if err := cmd.Run(); err == nil { - logging.LogInfo("[DRY-RUN] systemd-resolved detected - Configure via /etc/systemd/resolved.conf") - logging.LogInfo("[DRY-RUN] Restart systemd-resolved service") - } else if _, err := exec.LookPath("resolvconf"); err == nil { - // Check resolvconf - logging.LogInfo("[DRY-RUN] resolvconf detected - Configure via /etc/resolvconf/resolv.conf.d/head") - logging.LogInfo("[DRY-RUN] Update resolvconf with 'resolvconf -u'") - } else { - // Direct configuration - logging.LogInfo("[DRY-RUN] Write DNS configuration directly to /etc/resolv.conf") - } - return nil - } - - logging.LogInfo("Configuring DNS settings...") - - if len(cfg.Nameservers) == 0 { - return fmt.Errorf("no nameservers configured in configuration") - } - - primaryNameserver := cfg.Nameservers[0] - - // Check if systemd-resolved is active - cmd := exec.Command("systemctl", "is-active", "systemd-resolved") - if err := cmd.Run(); err == nil { - logging.LogInfo("systemd-resolved detected, configuring via resolved.conf") - utils.BackupFile("/etc/systemd/resolved.conf", cfg) - - // Create resolved.conf content - content := "[Resolve]\n" - content += fmt.Sprintf("DNS=%s", strings.Join(cfg.Nameservers, " ")) - content += "\nDomains=lan\n" - - // Write resolved.conf - if err := os.WriteFile("/etc/systemd/resolved.conf", []byte(content), 0644); err != nil { - return fmt.Errorf("failed to write resolved.conf for nameserver %s: %w", primaryNameserver, err) - } - - // Restart systemd-resolved - restartCmd := exec.Command("systemctl", "restart", "systemd-resolved") - if err := restartCmd.Run(); err != nil { - return fmt.Errorf("failed to restart systemd-resolved for nameserver %s: %w", primaryNameserver, err) - } - } else if _, err := exec.LookPath("resolvconf"); err == nil { - // resolvconf is installed - logging.LogInfo("resolvconf detected, using resolvconf mechanism") - utils.BackupFile("/etc/resolvconf/resolv.conf.d/head", cfg) - - // Create head file content - content := "domain lan\nsearch lan\n" - content += fmt.Sprintf("nameserver %s\n", cfg.Nameservers[0]) - if len(cfg.Nameservers) > 1 { - content += fmt.Sprintf("nameserver %s\n", cfg.Nameservers[1]) - } - - // Write head file - if err := os.MkdirAll("/etc/resolvconf/resolv.conf.d", 0755); err != nil { - return fmt.Errorf("failed to create resolvconf directory for nameserver %s: %w", primaryNameserver, err) - } - if err := os.WriteFile("/etc/resolvconf/resolv.conf.d/head", []byte(content), 0644); err != nil { - return fmt.Errorf("failed to write resolvconf head file with nameserver %s: %w", primaryNameserver, err) - } - - // Update resolvconf - resolvCmd := exec.Command("resolvconf", "-u") - if err := resolvCmd.Run(); err != nil { - return fmt.Errorf("failed to update resolvconf with nameserver %s: %w", primaryNameserver, err) - } - } else { - // Direct approach - logging.LogInfo("Using direct DNS configuration") - utils.BackupFile("/etc/resolv.conf", cfg) - - // Create resolv.conf content - content := "domain lan\nsearch lan\n" - content += fmt.Sprintf("nameserver %s\n", cfg.Nameservers[0]) - if len(cfg.Nameservers) > 1 { - content += fmt.Sprintf("nameserver %s\n", cfg.Nameservers[1]) - } - - // Write resolv.conf - if err := os.WriteFile("/etc/resolv.conf", []byte(content), 0644); err != nil { - return fmt.Errorf("failed to write resolv.conf with nameserver %s: %w", primaryNameserver, err) - } - } - - logging.LogSuccess("DNS configured successfully with nameserver %s", primaryNameserver) - return nil -} \ No newline at end of file diff --git a/pkg/domain/model/backup.go b/pkg/domain/model/backup.go new file mode 100644 index 0000000..4ebea88 --- /dev/null +++ b/pkg/domain/model/backup.go @@ -0,0 +1,18 @@ +// pkg/domain/model/backup.go +package model + +import "time" + +// BackupConfig represents backup configuration settings +type BackupConfig struct { + Enabled bool // Whether backups are enabled + BackupDir string // Directory to store backups +} + +// BackupFile represents information about a backed up file +type BackupFile struct { + OriginalPath string // Path of the original file + BackupPath string // Full path to the backup + Created time.Time // When the backup was created + Size int64 // Size of the backup in bytes +} diff --git a/pkg/domain/model/dns_config.go b/pkg/domain/model/dns_config.go new file mode 100644 index 0000000..14964b1 --- /dev/null +++ b/pkg/domain/model/dns_config.go @@ -0,0 +1,9 @@ +// pkg/domain/model/dns_config.go +package model + +// DNSConfig represents DNS configuration settings +type DNSConfig struct { + Nameservers []string + Domain string + Search []string +} diff --git a/pkg/domain/model/environment.go b/pkg/domain/model/environment.go new file mode 100644 index 0000000..7db0e94 --- /dev/null +++ b/pkg/domain/model/environment.go @@ -0,0 +1,14 @@ +// pkg/domain/model/environment.go +package model + +// EnvironmentConfig represents environment variable configuration settings +type EnvironmentConfig struct { + // ConfigPath is the path to the configuration file specified by HARDN_CONFIG + ConfigPath string + + // PreserveSudo indicates whether HARDN_CONFIG should be preserved in sudo + PreserveSudo bool + + // Username of the current user for sudo configuration + Username string +} diff --git a/pkg/domain/model/firewall.go b/pkg/domain/model/firewall.go new file mode 100644 index 0000000..f5c4535 --- /dev/null +++ b/pkg/domain/model/firewall.go @@ -0,0 +1,28 @@ +// pkg/domain/model/firewall.go +package model + +// FirewallRule represents a firewall rule +type FirewallRule struct { + Action string // allow, deny + Protocol string // tcp, udp, icmp + Port int + SourceIP string // source IP or subnet + Description string +} + +// FirewallProfile represents a firewall application profile +type FirewallProfile struct { + Name string + Title string + Description string + Ports []string // formatted as "port/protocol" +} + +// FirewallConfig represents the full firewall configuration +type FirewallConfig struct { + Enabled bool + DefaultIncoming string // allow, deny + DefaultOutgoing string // allow, deny + Rules []FirewallRule + ApplicationProfiles []FirewallProfile +} diff --git a/pkg/domain/model/harderning_config.go b/pkg/domain/model/harderning_config.go new file mode 100644 index 0000000..f72d58f --- /dev/null +++ b/pkg/domain/model/harderning_config.go @@ -0,0 +1,35 @@ +// pkg/domain/model/hardening_config.go +package model + +// HardeningConfig represents a comprehensive system hardening configuration +type HardeningConfig struct { + // User settings + CreateUser bool + Username string + SudoNoPassword bool + SshKeys []string + + // SSH settings + SshPort int + SshListenAddresses []string + SshAllowedUsers []string + SshKeyPaths []string + + // Firewall settings + EnableFirewall bool + AllowedPorts []int + FirewallProfiles []FirewallProfile + + // DNS settings + ConfigureDns bool + Nameservers []string + + // Feature toggles + EnableAppArmor bool + EnableLynis bool + EnableUnattendedUpgrades bool + + UseUvPackageManager bool + // UpdateRepositories bool + InstallPackages bool +} diff --git a/pkg/domain/model/logs.go b/pkg/domain/model/logs.go new file mode 100644 index 0000000..29b86ab --- /dev/null +++ b/pkg/domain/model/logs.go @@ -0,0 +1,14 @@ +// pkg/domain/model/logs.go +package model + +// LogEntry represents a single log entry +type LogEntry struct { + Level string + Message string + Time string +} + +// LogsConfig represents log configuration settings +type LogsConfig struct { + LogFilePath string +} diff --git a/pkg/domain/model/package.go b/pkg/domain/model/package.go new file mode 100644 index 0000000..97744df --- /dev/null +++ b/pkg/domain/model/package.go @@ -0,0 +1,44 @@ +// pkg/domain/model/package.go +package model + +// PackageInstallRequest represents a request to install packages +type PackageInstallRequest struct { + Packages []string + PipPackages []string + PackageType string // Core, DMZ, Lab, etc. + UseUv bool // Whether to use UV for Python packages + IsPython bool // Whether this is a Python package install request + IsSystemPython bool // Whether to install system Python packages +} + +// RepositorySource represents a package repository source +type RepositorySource struct { + URL string + Distribution string + Components []string + Enabled bool +} + +// PackageSources represents package repository sources configuration +type PackageSources struct { + // Repository sources + DebianRepos []string + ProxmoxSrcRepos []string + ProxmoxCephRepo []string + ProxmoxEnterpriseRepo []string + AlpineTestingRepo bool + + // Package lists by OS and environment + DebianCorePackages []string + DebianDmzPackages []string + DebianLabPackages []string + AlpineCorePackages []string + AlpineDmzPackages []string + AlpineLabPackages []string + + // Python packages + DebianPythonPackages []string + NonWslPythonPackages []string + PythonPipPackages []string + AlpinePythonPackages []string +} diff --git a/pkg/domain/model/ssh_config.go b/pkg/domain/model/ssh_config.go new file mode 100644 index 0000000..373b7b0 --- /dev/null +++ b/pkg/domain/model/ssh_config.go @@ -0,0 +1,21 @@ +// pkg/domain/model/ssh_config.go +package model + +// SSHConfig represents SSH server configuration settings +type SSHConfig struct { + Port int + ListenAddresses []string + PermitRootLogin bool + AllowedUsers []string + KeyPaths []string + AuthMethods []string + ConfigFilePath string +} + +// SSHKey represents an SSH public key +type SSHKey struct { + User string + PublicKey string + KeyType string + Comment string +} diff --git a/pkg/domain/model/system.go b/pkg/domain/model/system.go new file mode 100644 index 0000000..c0fdf2c --- /dev/null +++ b/pkg/domain/model/system.go @@ -0,0 +1,10 @@ +// pkg/domain/model/system.go +package model + +// OSInfo represents operating system information +type OSInfo struct { + Type string // alpine, debian, ubuntu, etc. + Version string // version number + Codename string // release name + IsProxmox bool // whether this is a Proxmox installation +} diff --git a/pkg/domain/model/user.go b/pkg/domain/model/user.go new file mode 100644 index 0000000..405f541 --- /dev/null +++ b/pkg/domain/model/user.go @@ -0,0 +1,10 @@ +// pkg/domain/model/user.go +package model + +// User represents a system user +type User struct { + Username string + HasSudo bool + SshKeys []string + SudoNoPassword bool +} diff --git a/pkg/domain/service/backup_service.go b/pkg/domain/service/backup_service.go new file mode 100644 index 0000000..3cb4e17 --- /dev/null +++ b/pkg/domain/service/backup_service.go @@ -0,0 +1,106 @@ +// pkg/domain/service/backup_service.go +package service + +import ( + "fmt" + "time" + + "github.com/abbott/hardn/pkg/domain/model" +) + +// BackupService defines operations for backup functionality +type BackupService interface { + // BackupFile backs up a file with a timestamp + BackupFile(filePath string) error + + // ListBackups returns a list of all backups for a specific file + ListBackups(filePath string) ([]model.BackupFile, error) + + // RestoreBackup restores a file from backup + RestoreBackup(backupPath, originalPath string) error + + // CleanupOldBackups removes backups older than specified days + CleanupOldBackups(daysToKeep int) error + + // VerifyBackupDirectory ensures the backup directory exists and is writable + VerifyBackupDirectory() error + + // GetBackupConfig retrieves the current backup configuration + GetBackupConfig() (*model.BackupConfig, error) + + // EnableBackups enables backups + EnableBackups(enabled bool) error + + // SetBackupDirectory changes the backup directory + SetBackupDirectory(directory string) error +} + +// BackupServiceImpl implements BackupService +type BackupServiceImpl struct { + repository BackupRepository +} + +// NewBackupServiceImpl creates a new BackupServiceImpl +func NewBackupServiceImpl(repository BackupRepository) *BackupServiceImpl { + return &BackupServiceImpl{ + repository: repository, + } +} + +// BackupRepository defines the repository operations needed by BackupService +type BackupRepository interface { + BackupFile(filePath string) error + ListBackups(filePath string) ([]model.BackupFile, error) + RestoreBackup(backupPath, originalPath string) error + CleanupOldBackups(before time.Time) error + VerifyBackupDirectory() error + GetBackupConfig() (*model.BackupConfig, error) + SetBackupConfig(config model.BackupConfig) error +} + +// Implementation of BackupService methods +func (s *BackupServiceImpl) BackupFile(filePath string) error { + return s.repository.BackupFile(filePath) +} + +func (s *BackupServiceImpl) ListBackups(filePath string) ([]model.BackupFile, error) { + return s.repository.ListBackups(filePath) +} + +func (s *BackupServiceImpl) RestoreBackup(backupPath, originalPath string) error { + return s.repository.RestoreBackup(backupPath, originalPath) +} + +func (s *BackupServiceImpl) CleanupOldBackups(daysToKeep int) error { + // Convert days to a specific time + cutoffTime := time.Now().AddDate(0, 0, -daysToKeep) + return s.repository.CleanupOldBackups(cutoffTime) +} + +func (s *BackupServiceImpl) VerifyBackupDirectory() error { + return s.repository.VerifyBackupDirectory() +} + +func (s *BackupServiceImpl) GetBackupConfig() (*model.BackupConfig, error) { + return s.repository.GetBackupConfig() +} + +func (s *BackupServiceImpl) EnableBackups(enabled bool) error { + config, err := s.repository.GetBackupConfig() + if err != nil { + return fmt.Errorf("failed to get backup config: %w", err) + } + + config.Enabled = enabled + return s.repository.SetBackupConfig(*config) +} + +func (s *BackupServiceImpl) SetBackupDirectory(directory string) error { + config, err := s.repository.GetBackupConfig() + if err != nil { + return fmt.Errorf("failed to get backup config: %w", err) + } + + config.BackupDir = directory + return s.repository.SetBackupConfig(*config) +} diff --git a/pkg/domain/service/backup_service_test.go b/pkg/domain/service/backup_service_test.go new file mode 100644 index 0000000..d4dd171 --- /dev/null +++ b/pkg/domain/service/backup_service_test.go @@ -0,0 +1,676 @@ +package service + +import ( + "errors" + "reflect" + "testing" + "time" + + "github.com/abbott/hardn/pkg/domain/model" +) + +// MockBackupRepository implements BackupRepository interface for testing +type MockBackupRepository struct { + // BackupFile tracking + BackupFileCalled bool + BackupFilePath string + BackupFileError error + + // ListBackups tracking + ListBackupsCalled bool + ListBackupsPath string + ListBackupsResult []model.BackupFile + ListBackupsError error + + // RestoreBackup tracking + RestoreBackupCalled bool + BackupPath string + OriginalPath string + RestoreBackupError error + + // CleanupOldBackups tracking + CleanupCalled bool + CleanupBeforeTime time.Time + CleanupError error + + // VerifyBackupDirectory tracking + VerifyCalled bool + VerifyError error + + // GetBackupConfig tracking + GetConfigCalled bool + BackupConfig *model.BackupConfig + GetConfigError error + + // SetBackupConfig tracking + SetConfigCalled bool + SetConfigValue model.BackupConfig + SetConfigError error +} + +func (m *MockBackupRepository) BackupFile(filePath string) error { + m.BackupFileCalled = true + m.BackupFilePath = filePath + return m.BackupFileError +} + +func (m *MockBackupRepository) ListBackups(filePath string) ([]model.BackupFile, error) { + m.ListBackupsCalled = true + m.ListBackupsPath = filePath + return m.ListBackupsResult, m.ListBackupsError +} + +func (m *MockBackupRepository) RestoreBackup(backupPath, originalPath string) error { + m.RestoreBackupCalled = true + m.BackupPath = backupPath + m.OriginalPath = originalPath + return m.RestoreBackupError +} + +func (m *MockBackupRepository) CleanupOldBackups(before time.Time) error { + m.CleanupCalled = true + m.CleanupBeforeTime = before + return m.CleanupError +} + +func (m *MockBackupRepository) VerifyBackupDirectory() error { + m.VerifyCalled = true + return m.VerifyError +} + +func (m *MockBackupRepository) GetBackupConfig() (*model.BackupConfig, error) { + m.GetConfigCalled = true + return m.BackupConfig, m.GetConfigError +} + +func (m *MockBackupRepository) SetBackupConfig(config model.BackupConfig) error { + m.SetConfigCalled = true + m.SetConfigValue = config + return m.SetConfigError +} + +func TestNewBackupServiceImpl(t *testing.T) { + repo := &MockBackupRepository{} + + service := NewBackupServiceImpl(repo) + + if service == nil { + t.Fatal("Expected non-nil service") + } + + if service.repository != repo { + t.Error("Repository not properly set") + } +} + +func TestBackupServiceImpl_BackupFile(t *testing.T) { + tests := []struct { + name string + filePath string + mockError error + wantErr bool + }{ + { + name: "successful backup", + filePath: "/etc/hosts", + mockError: nil, + wantErr: false, + }, + { + name: "backup error", + filePath: "/etc/passwd", + mockError: errors.New("permission denied"), + wantErr: true, + }, + { + name: "empty path", + filePath: "", + mockError: errors.New("empty path"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup + repo := &MockBackupRepository{ + BackupFileError: tt.mockError, + } + + service := NewBackupServiceImpl(repo) + + // Execute + err := service.BackupFile(tt.filePath) + + // Verify + if (err != nil) != tt.wantErr { + t.Errorf("BackupServiceImpl.BackupFile() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !repo.BackupFileCalled { + t.Error("Expected BackupFile to be called") + } + + if repo.BackupFilePath != tt.filePath { + t.Errorf("Wrong file path passed to repository. Got %v, want %v", + repo.BackupFilePath, tt.filePath) + } + }) + } +} + +func TestBackupServiceImpl_ListBackups(t *testing.T) { + mockBackups := []model.BackupFile{ + { + OriginalPath: "/etc/hosts", + BackupPath: "/backup/2023-01-01/hosts.123456.bak", + Created: time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC), + Size: 1024, + }, + } + + tests := []struct { + name string + filePath string + mockBackups []model.BackupFile + mockError error + wantErr bool + wantBackups []model.BackupFile + }{ + { + name: "successful list", + filePath: "/etc/hosts", + mockBackups: mockBackups, + mockError: nil, + wantErr: false, + wantBackups: mockBackups, + }, + { + name: "list error", + filePath: "/etc/nonexistent", + mockBackups: nil, + mockError: errors.New("file not found"), + wantErr: true, + wantBackups: nil, + }, + { + name: "empty list", + filePath: "/etc/fstab", + mockBackups: []model.BackupFile{}, + mockError: nil, + wantErr: false, + wantBackups: []model.BackupFile{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup + repo := &MockBackupRepository{ + ListBackupsResult: tt.mockBackups, + ListBackupsError: tt.mockError, + } + + service := NewBackupServiceImpl(repo) + + // Execute + backups, err := service.ListBackups(tt.filePath) + + // Verify + if (err != nil) != tt.wantErr { + t.Errorf("BackupServiceImpl.ListBackups() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !repo.ListBackupsCalled { + t.Error("Expected ListBackups to be called") + } + + if repo.ListBackupsPath != tt.filePath { + t.Errorf("Wrong file path passed to repository. Got %v, want %v", + repo.ListBackupsPath, tt.filePath) + } + + if !reflect.DeepEqual(backups, tt.wantBackups) { + t.Errorf("BackupServiceImpl.ListBackups() = %v, want %v", backups, tt.wantBackups) + } + }) + } +} + +func TestBackupServiceImpl_RestoreBackup(t *testing.T) { + tests := []struct { + name string + backupPath string + originalPath string + mockError error + wantErr bool + }{ + { + name: "successful restore", + backupPath: "/backup/2023-01-01/hosts.123456.bak", + originalPath: "/etc/hosts", + mockError: nil, + wantErr: false, + }, + { + name: "restore error", + backupPath: "/backup/2023-01-01/hosts.123456.bak", + originalPath: "/etc/hosts", + mockError: errors.New("permission denied"), + wantErr: true, + }, + { + name: "empty backup path", + backupPath: "", + originalPath: "/etc/hosts", + mockError: errors.New("empty backup path"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup + repo := &MockBackupRepository{ + RestoreBackupError: tt.mockError, + } + + service := NewBackupServiceImpl(repo) + + // Execute + err := service.RestoreBackup(tt.backupPath, tt.originalPath) + + // Verify + if (err != nil) != tt.wantErr { + t.Errorf("BackupServiceImpl.RestoreBackup() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !repo.RestoreBackupCalled { + t.Error("Expected RestoreBackup to be called") + } + + if repo.BackupPath != tt.backupPath { + t.Errorf("Wrong backup path passed to repository. Got %v, want %v", + repo.BackupPath, tt.backupPath) + } + + if repo.OriginalPath != tt.originalPath { + t.Errorf("Wrong original path passed to repository. Got %v, want %v", + repo.OriginalPath, tt.originalPath) + } + }) + } +} + +func TestBackupServiceImpl_CleanupOldBackups(t *testing.T) { + tests := []struct { + name string + daysToKeep int + mockError error + wantErr bool + expectedDays int // roughly how many days before now the cutoff should be + }{ + { + name: "successful cleanup", + daysToKeep: 30, + mockError: nil, + wantErr: false, + expectedDays: 30, + }, + { + name: "cleanup error", + daysToKeep: 7, + mockError: errors.New("permission denied"), + wantErr: true, + expectedDays: 7, + }, + { + name: "zero days", + daysToKeep: 0, + mockError: nil, + wantErr: false, + expectedDays: 0, + }, + { + name: "negative days", + daysToKeep: -1, // should be treated as "keep everything" + mockError: nil, + wantErr: false, + expectedDays: -1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup + repo := &MockBackupRepository{ + CleanupError: tt.mockError, + } + + service := NewBackupServiceImpl(repo) + + // Get current time for comparison + now := time.Now() + + // Execute + err := service.CleanupOldBackups(tt.daysToKeep) + + // Verify + if (err != nil) != tt.wantErr { + t.Errorf("BackupServiceImpl.CleanupOldBackups() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !repo.CleanupCalled { + t.Error("Expected CleanupOldBackups to be called") + } + + // Verify the cutoff time is within reasonable bounds (allowing for test execution time) + if tt.expectedDays >= 0 { + expectedTime := now.AddDate(0, 0, -tt.expectedDays) + timeDiff := repo.CleanupBeforeTime.Sub(expectedTime) + if timeDiff < -2*time.Second || timeDiff > 2*time.Second { + t.Errorf("Wrong cutoff time. Got %v, expected close to %v (diff: %v)", + repo.CleanupBeforeTime, expectedTime, timeDiff) + } + } + }) + } +} + +func TestBackupServiceImpl_VerifyBackupDirectory(t *testing.T) { + tests := []struct { + name string + mockError error + wantErr bool + }{ + { + name: "successful verification", + mockError: nil, + wantErr: false, + }, + { + name: "verification error", + mockError: errors.New("directory not writable"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup + repo := &MockBackupRepository{ + VerifyError: tt.mockError, + } + + service := NewBackupServiceImpl(repo) + + // Execute + err := service.VerifyBackupDirectory() + + // Verify + if (err != nil) != tt.wantErr { + t.Errorf("BackupServiceImpl.VerifyBackupDirectory() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !repo.VerifyCalled { + t.Error("Expected VerifyBackupDirectory to be called") + } + }) + } +} + +func TestBackupServiceImpl_GetBackupConfig(t *testing.T) { + mockConfig := &model.BackupConfig{ + Enabled: true, + BackupDir: "/var/backups/hardn", + } + + tests := []struct { + name string + mockConfig *model.BackupConfig + mockError error + wantErr bool + wantConfig *model.BackupConfig + }{ + { + name: "successful get config", + mockConfig: mockConfig, + mockError: nil, + wantErr: false, + wantConfig: mockConfig, + }, + { + name: "get config error", + mockConfig: nil, + mockError: errors.New("configuration error"), + wantErr: true, + wantConfig: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup + repo := &MockBackupRepository{ + BackupConfig: tt.mockConfig, + GetConfigError: tt.mockError, + } + + service := NewBackupServiceImpl(repo) + + // Execute + config, err := service.GetBackupConfig() + + // Verify + if (err != nil) != tt.wantErr { + t.Errorf("BackupServiceImpl.GetBackupConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !repo.GetConfigCalled { + t.Error("Expected GetBackupConfig to be called") + } + + if !reflect.DeepEqual(config, tt.wantConfig) { + t.Errorf("BackupServiceImpl.GetBackupConfig() = %v, want %v", config, tt.wantConfig) + } + }) + } +} + +func TestBackupServiceImpl_EnableBackups(t *testing.T) { + mockConfig := &model.BackupConfig{ + Enabled: false, + BackupDir: "/var/backups/hardn", + } + + tests := []struct { + name string + enable bool + mockConfig *model.BackupConfig + getConfigError error + setConfigError error + wantErr bool + expectedEnabled bool + }{ + { + name: "enable backups", + enable: true, + mockConfig: mockConfig, + getConfigError: nil, + setConfigError: nil, + wantErr: false, + expectedEnabled: true, + }, + { + name: "disable backups", + enable: false, + mockConfig: &model.BackupConfig{Enabled: true, BackupDir: "/var/backups/hardn"}, + getConfigError: nil, + setConfigError: nil, + wantErr: false, + expectedEnabled: false, + }, + { + name: "get config error", + enable: true, + mockConfig: nil, + getConfigError: errors.New("failed to get config"), + wantErr: true, + }, + { + name: "set config error", + enable: true, + mockConfig: mockConfig, + getConfigError: nil, + setConfigError: errors.New("failed to set config"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup + repo := &MockBackupRepository{ + BackupConfig: tt.mockConfig, + GetConfigError: tt.getConfigError, + SetConfigError: tt.setConfigError, + } + + service := NewBackupServiceImpl(repo) + + // Execute + err := service.EnableBackups(tt.enable) + + // Verify + if (err != nil) != tt.wantErr { + t.Errorf("BackupServiceImpl.EnableBackups() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !repo.GetConfigCalled { + t.Error("Expected GetBackupConfig to be called") + } + + // If no error getting config, check that SetBackupConfig was called + if tt.getConfigError == nil && !tt.wantErr { + if !repo.SetConfigCalled { + t.Error("Expected SetBackupConfig to be called") + } + + // Check that the enabled state was properly updated + if repo.SetConfigValue.Enabled != tt.expectedEnabled { + t.Errorf("Wrong enabled value. Got %v, want %v", + repo.SetConfigValue.Enabled, tt.expectedEnabled) + } + + // Verify backup directory unchanged + if repo.SetConfigValue.BackupDir != tt.mockConfig.BackupDir { + t.Errorf("Backup directory changed unexpectedly. Got %v, want %v", + repo.SetConfigValue.BackupDir, tt.mockConfig.BackupDir) + } + } + }) + } +} + +func TestBackupServiceImpl_SetBackupDirectory(t *testing.T) { + mockConfig := &model.BackupConfig{ + Enabled: true, + BackupDir: "/var/backups/hardn", + } + + tests := []struct { + name string + directory string + mockConfig *model.BackupConfig + getConfigError error + setConfigError error + wantErr bool + expectedDir string + }{ + { + name: "set new directory", + directory: "/new/backup/path", + mockConfig: mockConfig, + getConfigError: nil, + setConfigError: nil, + wantErr: false, + expectedDir: "/new/backup/path", + }, + { + name: "set empty directory", + directory: "", + mockConfig: mockConfig, + getConfigError: nil, + setConfigError: nil, + wantErr: false, + expectedDir: "", + }, + { + name: "get config error", + directory: "/new/backup/path", + mockConfig: nil, + getConfigError: errors.New("failed to get config"), + wantErr: true, + }, + { + name: "set config error", + directory: "/new/backup/path", + mockConfig: mockConfig, + getConfigError: nil, + setConfigError: errors.New("failed to set config"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup + repo := &MockBackupRepository{ + BackupConfig: tt.mockConfig, + GetConfigError: tt.getConfigError, + SetConfigError: tt.setConfigError, + } + + service := NewBackupServiceImpl(repo) + + // Execute + err := service.SetBackupDirectory(tt.directory) + + // Verify + if (err != nil) != tt.wantErr { + t.Errorf("BackupServiceImpl.SetBackupDirectory() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !repo.GetConfigCalled { + t.Error("Expected GetBackupConfig to be called") + } + + // If no error getting config, check that SetBackupConfig was called + if tt.getConfigError == nil && !tt.wantErr { + if !repo.SetConfigCalled { + t.Error("Expected SetBackupConfig to be called") + } + + // Check that the directory was properly updated + if repo.SetConfigValue.BackupDir != tt.expectedDir { + t.Errorf("Wrong directory. Got %v, want %v", + repo.SetConfigValue.BackupDir, tt.expectedDir) + } + + // Verify enabled state unchanged + if repo.SetConfigValue.Enabled != tt.mockConfig.Enabled { + t.Errorf("Enabled state changed unexpectedly. Got %v, want %v", + repo.SetConfigValue.Enabled, tt.mockConfig.Enabled) + } + } + }) + } +} diff --git a/pkg/domain/service/dns_service.go b/pkg/domain/service/dns_service.go new file mode 100644 index 0000000..ad8d720 --- /dev/null +++ b/pkg/domain/service/dns_service.go @@ -0,0 +1,42 @@ +// pkg/domain/service/dns_service.go +package service + +import "github.com/abbott/hardn/pkg/domain/model" + +// DNSService defines operations for DNS configuration +type DNSService interface { + // ConfigureDNS applies DNS configuration settings + ConfigureDNS(config model.DNSConfig) error + + // GetCurrentConfig retrieves the current DNS configuration + GetCurrentConfig() (*model.DNSConfig, error) +} + +// DNSServiceImpl implements DNSService +type DNSServiceImpl struct { + repository DNSRepository + osInfo model.OSInfo +} + +// NewDNSServiceImpl creates a new DNSServiceImpl +func NewDNSServiceImpl(repository DNSRepository, osInfo model.OSInfo) *DNSServiceImpl { + return &DNSServiceImpl{ + repository: repository, + osInfo: osInfo, + } +} + +// DNSRepository defines the repository operations needed by DNSService +type DNSRepository interface { + SaveDNSConfig(config model.DNSConfig) error + GetDNSConfig() (*model.DNSConfig, error) +} + +// Implementation of DNSService methods +func (s *DNSServiceImpl) ConfigureDNS(config model.DNSConfig) error { + return s.repository.SaveDNSConfig(config) +} + +func (s *DNSServiceImpl) GetCurrentConfig() (*model.DNSConfig, error) { + return s.repository.GetDNSConfig() +} diff --git a/pkg/domain/service/dns_service_test.go b/pkg/domain/service/dns_service_test.go new file mode 100644 index 0000000..6d74701 --- /dev/null +++ b/pkg/domain/service/dns_service_test.go @@ -0,0 +1,239 @@ +package service + +import ( + "errors" + "reflect" + "testing" + + "github.com/abbott/hardn/pkg/domain/model" +) + +// MockDNSRepository implements DNSRepository interface for testing +type MockDNSRepository struct { + SavedConfig model.DNSConfig + SaveError error + ReturnedConfig *model.DNSConfig + GetConfigError error + SaveCallCount int + GetConfigCalled bool +} + +func (m *MockDNSRepository) SaveDNSConfig(config model.DNSConfig) error { + m.SavedConfig = config + m.SaveCallCount++ + return m.SaveError +} + +func (m *MockDNSRepository) GetDNSConfig() (*model.DNSConfig, error) { + m.GetConfigCalled = true + return m.ReturnedConfig, m.GetConfigError +} + +func TestNewDNSServiceImpl(t *testing.T) { + repo := &MockDNSRepository{} + osInfo := model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye"} + + service := NewDNSServiceImpl(repo, osInfo) + + if service == nil { + t.Fatal("Expected non-nil service") + } + + if service.repository != repo { + t.Error("Repository not properly set") + } + + if !reflect.DeepEqual(service.osInfo, osInfo) { + t.Error("OSInfo not properly set") + } +} + +func TestDNSServiceImpl_ConfigureDNS(t *testing.T) { + tests := []struct { + name string + config model.DNSConfig + mockSaveError error + expectError bool + expectSaveCall bool + }{ + { + name: "successful configuration", + config: model.DNSConfig{ + Nameservers: []string{"1.1.1.1", "1.0.0.1"}, + Domain: "example.com", + Search: []string{"example.com", "test.com"}, + }, + mockSaveError: nil, + expectError: false, + expectSaveCall: true, + }, + { + name: "empty nameservers", + config: model.DNSConfig{ + Nameservers: []string{}, + Domain: "example.com", + }, + mockSaveError: nil, + expectError: false, + expectSaveCall: true, + }, + { + name: "repository error", + config: model.DNSConfig{ + Nameservers: []string{"8.8.8.8"}, + }, + mockSaveError: errors.New("mock save error"), + expectError: true, + expectSaveCall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockDNSRepository{ + SaveError: tc.mockSaveError, + } + osInfo := model.OSInfo{Type: "debian", Version: "11"} + service := NewDNSServiceImpl(repo, osInfo) + + // Execute + err := service.ConfigureDNS(tc.config) + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if tc.expectSaveCall { + if repo.SaveCallCount != 1 { + t.Errorf("Expected SaveDNSConfig to be called once, got %d", repo.SaveCallCount) + } + + if !reflect.DeepEqual(repo.SavedConfig, tc.config) { + t.Errorf("Wrong config saved. Got %+v, expected %+v", repo.SavedConfig, tc.config) + } + } else if repo.SaveCallCount > 0 { + t.Error("SaveDNSConfig should not have been called") + } + }) + } +} + +func TestDNSServiceImpl_GetCurrentConfig(t *testing.T) { + tests := []struct { + name string + mockConfig *model.DNSConfig + mockError error + expectError bool + expectedConfig *model.DNSConfig + }{ + { + name: "successful retrieval", + mockConfig: &model.DNSConfig{ + Nameservers: []string{"1.1.1.1", "1.0.0.1"}, + Domain: "example.com", + Search: []string{"example.com"}, + }, + mockError: nil, + expectError: false, + expectedConfig: &model.DNSConfig{ + Nameservers: []string{"1.1.1.1", "1.0.0.1"}, + Domain: "example.com", + Search: []string{"example.com"}, + }, + }, + { + name: "repository error", + mockConfig: nil, + mockError: errors.New("mock retrieval error"), + expectError: true, + expectedConfig: nil, + }, + { + name: "empty nameservers", + mockConfig: &model.DNSConfig{ + Nameservers: []string{}, + Domain: "example.com", + }, + mockError: nil, + expectError: false, + expectedConfig: &model.DNSConfig{ + Nameservers: []string{}, + Domain: "example.com", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockDNSRepository{ + ReturnedConfig: tc.mockConfig, + GetConfigError: tc.mockError, + } + osInfo := model.OSInfo{Type: "alpine", Version: "3.16"} + service := NewDNSServiceImpl(repo, osInfo) + + // Execute + config, err := service.GetCurrentConfig() + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if !repo.GetConfigCalled { + t.Error("Expected GetDNSConfig to be called") + } + + if tc.expectedConfig != nil { + if config == nil { + t.Fatal("Expected non-nil config but got nil") + } + if !reflect.DeepEqual(config, tc.expectedConfig) { + t.Errorf("Wrong config returned. Got %+v, expected %+v", config, tc.expectedConfig) + } + } else if config != nil { + t.Error("Expected nil config but got non-nil") + } + }) + } +} + +func TestDNSServiceImpl_OSTypes(t *testing.T) { + // Test with different OS types to ensure the service works consistently + osTypes := []string{"debian", "ubuntu", "alpine", "unknown"} + + for _, osType := range osTypes { + t.Run(osType+" OS type", func(t *testing.T) { + // Setup + repo := &MockDNSRepository{} + osInfo := model.OSInfo{Type: osType, Version: "1.0"} + service := NewDNSServiceImpl(repo, osInfo) + + // Test a simple configuration + config := model.DNSConfig{ + Nameservers: []string{"8.8.8.8"}, + } + + // Execute + err := service.ConfigureDNS(config) + + // Verify + if err != nil { + t.Errorf("Failed to configure DNS on %s: %v", osType, err) + } + + if !reflect.DeepEqual(repo.SavedConfig, config) { + t.Errorf("Wrong config saved for %s. Got %+v, expected %+v", osType, repo.SavedConfig, config) + } + }) + } +} diff --git a/pkg/domain/service/environment_service.go b/pkg/domain/service/environment_service.go new file mode 100644 index 0000000..eca93a2 --- /dev/null +++ b/pkg/domain/service/environment_service.go @@ -0,0 +1,70 @@ +// pkg/domain/service/environment_service.go +package service + +import "github.com/abbott/hardn/pkg/domain/model" + +// EnvironmentService defines operations for environment variable management +type EnvironmentService interface { + // SetupSudoPreservation configures sudo to preserve the HARDN_CONFIG environment variable + SetupSudoPreservation() error + + // IsSudoPreservationEnabled checks if the HARDN_CONFIG environment variable is preserved in sudo + IsSudoPreservationEnabled() (bool, error) + + // GetEnvironmentConfig retrieves the current environment configuration + GetEnvironmentConfig() (*model.EnvironmentConfig, error) +} + +// EnvironmentServiceImpl implements EnvironmentService +type EnvironmentServiceImpl struct { + repository EnvironmentRepository +} + +// NewEnvironmentServiceImpl creates a new EnvironmentServiceImpl +func NewEnvironmentServiceImpl(repository EnvironmentRepository) *EnvironmentServiceImpl { + return &EnvironmentServiceImpl{ + repository: repository, + } +} + +// EnvironmentRepository defines the repository operations needed by EnvironmentService +type EnvironmentRepository interface { + SetupSudoPreservation(username string) error + IsSudoPreservationEnabled(username string) (bool, error) + GetEnvironmentConfig() (*model.EnvironmentConfig, error) +} + +// SetupSudoPreservation configures sudo to preserve the HARDN_CONFIG environment variable +func (s *EnvironmentServiceImpl) SetupSudoPreservation() error { + // Get current config to obtain username + config, err := s.repository.GetEnvironmentConfig() + if err != nil { + return err + } + + if config.Username == "" { + return nil // No username, nothing to do + } + + return s.repository.SetupSudoPreservation(config.Username) +} + +// IsSudoPreservationEnabled checks if the HARDN_CONFIG environment variable is preserved in sudo +func (s *EnvironmentServiceImpl) IsSudoPreservationEnabled() (bool, error) { + // Get current config to obtain username + config, err := s.repository.GetEnvironmentConfig() + if err != nil { + return false, err + } + + if config.Username == "" { + return false, nil // No username, no preservation + } + + return s.repository.IsSudoPreservationEnabled(config.Username) +} + +// GetEnvironmentConfig retrieves the current environment configuration +func (s *EnvironmentServiceImpl) GetEnvironmentConfig() (*model.EnvironmentConfig, error) { + return s.repository.GetEnvironmentConfig() +} diff --git a/pkg/domain/service/environment_service_test.go b/pkg/domain/service/environment_service_test.go new file mode 100644 index 0000000..0c23a90 --- /dev/null +++ b/pkg/domain/service/environment_service_test.go @@ -0,0 +1,337 @@ +package service + +import ( + "errors" + "reflect" + "testing" + + "github.com/abbott/hardn/pkg/domain/model" +) + +// MockEnvironmentRepository implements EnvironmentRepository interface for testing +type MockEnvironmentRepository struct { + // SetupSudoPreservation tracking + PreservedUsername string + SetupError error + SetupCallCount int + + // IsSudoPreservationEnabled tracking + CheckedUsername string + PreservationEnabled bool + CheckError error + CheckCallCount int + + // GetEnvironmentConfig tracking + ReturnedConfig *model.EnvironmentConfig + GetConfigError error + GetConfigCallCount int +} + +func (m *MockEnvironmentRepository) SetupSudoPreservation(username string) error { + m.PreservedUsername = username + m.SetupCallCount++ + return m.SetupError +} + +func (m *MockEnvironmentRepository) IsSudoPreservationEnabled(username string) (bool, error) { + m.CheckedUsername = username + m.CheckCallCount++ + return m.PreservationEnabled, m.CheckError +} + +func (m *MockEnvironmentRepository) GetEnvironmentConfig() (*model.EnvironmentConfig, error) { + m.GetConfigCallCount++ + return m.ReturnedConfig, m.GetConfigError +} + +func TestNewEnvironmentServiceImpl(t *testing.T) { + repo := &MockEnvironmentRepository{} + + service := NewEnvironmentServiceImpl(repo) + + if service == nil { + t.Fatal("Expected non-nil service") + } + + if service.repository != repo { + t.Error("Repository not properly set") + } +} + +func TestEnvironmentServiceImpl_SetupSudoPreservation(t *testing.T) { + tests := []struct { + name string + configUsername string + getConfigError error + setupError error + expectError bool + expectSetupCalled bool + }{ + { + name: "successful setup", + configUsername: "testuser", + getConfigError: nil, + setupError: nil, + expectError: false, + expectSetupCalled: true, + }, + { + name: "empty username", + configUsername: "", + getConfigError: nil, + setupError: nil, + expectError: false, + expectSetupCalled: false, // Should not call setup if username is empty + }, + { + name: "get config error", + configUsername: "testuser", + getConfigError: errors.New("mock get config error"), + setupError: nil, + expectError: true, + expectSetupCalled: false, + }, + { + name: "setup error", + configUsername: "testuser", + getConfigError: nil, + setupError: errors.New("mock setup error"), + expectError: true, + expectSetupCalled: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockEnvironmentRepository{ + ReturnedConfig: &model.EnvironmentConfig{ + Username: tc.configUsername, + }, + GetConfigError: tc.getConfigError, + SetupError: tc.setupError, + } + + service := NewEnvironmentServiceImpl(repo) + + // Execute + err := service.SetupSudoPreservation() + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if tc.expectSetupCalled { + if repo.SetupCallCount != 1 { + t.Errorf("Expected SetupSudoPreservation to be called once, got %d", repo.SetupCallCount) + } + + if repo.PreservedUsername != tc.configUsername { + t.Errorf("Wrong username passed. Got %s, expected %s", repo.PreservedUsername, tc.configUsername) + } + } else { + if repo.SetupCallCount > 0 { + t.Errorf("Expected SetupSudoPreservation not to be called, but was called %d times", repo.SetupCallCount) + } + } + + // GetEnvironmentConfig should always be called + if repo.GetConfigCallCount != 1 { + t.Errorf("Expected GetEnvironmentConfig to be called once, got %d", repo.GetConfigCallCount) + } + }) + } +} + +func TestEnvironmentServiceImpl_IsSudoPreservationEnabled(t *testing.T) { + tests := []struct { + name string + configUsername string + getConfigError error + preservationEnabled bool + checkError error + expectError bool + expectEnabled bool + expectCheckCalled bool + }{ + { + name: "preservation enabled", + configUsername: "testuser", + getConfigError: nil, + preservationEnabled: true, + checkError: nil, + expectError: false, + expectEnabled: true, + expectCheckCalled: true, + }, + { + name: "preservation disabled", + configUsername: "testuser", + getConfigError: nil, + preservationEnabled: false, + checkError: nil, + expectError: false, + expectEnabled: false, + expectCheckCalled: true, + }, + { + name: "empty username", + configUsername: "", + getConfigError: nil, + preservationEnabled: false, + checkError: nil, + expectError: false, + expectEnabled: false, + expectCheckCalled: false, // Should not call check if username is empty + }, + { + name: "get config error", + configUsername: "testuser", + getConfigError: errors.New("mock get config error"), + preservationEnabled: false, + checkError: nil, + expectError: true, + expectEnabled: false, + expectCheckCalled: false, + }, + { + name: "check error", + configUsername: "testuser", + getConfigError: nil, + preservationEnabled: false, + checkError: errors.New("mock check error"), + expectError: true, + expectEnabled: false, + expectCheckCalled: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockEnvironmentRepository{ + ReturnedConfig: &model.EnvironmentConfig{ + Username: tc.configUsername, + }, + GetConfigError: tc.getConfigError, + PreservationEnabled: tc.preservationEnabled, + CheckError: tc.checkError, + } + + service := NewEnvironmentServiceImpl(repo) + + // Execute + enabled, err := service.IsSudoPreservationEnabled() + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if enabled != tc.expectEnabled { + t.Errorf("Wrong enabled status. Got %v, expected %v", enabled, tc.expectEnabled) + } + + if tc.expectCheckCalled { + if repo.CheckCallCount != 1 { + t.Errorf("Expected IsSudoPreservationEnabled to be called once, got %d", repo.CheckCallCount) + } + + if repo.CheckedUsername != tc.configUsername { + t.Errorf("Wrong username checked. Got %s, expected %s", repo.CheckedUsername, tc.configUsername) + } + } else { + if repo.CheckCallCount > 0 { + t.Errorf("Expected IsSudoPreservationEnabled not to be called, but was called %d times", repo.CheckCallCount) + } + } + + // GetEnvironmentConfig should always be called + if repo.GetConfigCallCount != 1 { + t.Errorf("Expected GetEnvironmentConfig to be called once, got %d", repo.GetConfigCallCount) + } + }) + } +} + +func TestEnvironmentServiceImpl_GetEnvironmentConfig(t *testing.T) { + tests := []struct { + name string + returnedConfig *model.EnvironmentConfig + getConfigError error + expectError bool + }{ + { + name: "successful config retrieval", + returnedConfig: &model.EnvironmentConfig{ + ConfigPath: "/path/to/config", + PreserveSudo: true, + Username: "testuser", + }, + getConfigError: nil, + expectError: false, + }, + { + name: "get config error", + returnedConfig: nil, + getConfigError: errors.New("mock get config error"), + expectError: true, + }, + { + name: "empty config", + returnedConfig: &model.EnvironmentConfig{ + ConfigPath: "", + PreserveSudo: false, + Username: "", + }, + getConfigError: nil, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockEnvironmentRepository{ + ReturnedConfig: tc.returnedConfig, + GetConfigError: tc.getConfigError, + } + + service := NewEnvironmentServiceImpl(repo) + + // Execute + config, err := service.GetEnvironmentConfig() + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if tc.returnedConfig != nil { + if config == nil { + t.Fatal("Expected non-nil config but got nil") + } + if !reflect.DeepEqual(config, tc.returnedConfig) { + t.Errorf("Wrong config returned. Got %+v, expected %+v", config, tc.returnedConfig) + } + } else if config != nil { + t.Error("Expected nil config but got non-nil") + } + + // GetEnvironmentConfig should be called + if repo.GetConfigCallCount != 1 { + t.Errorf("Expected GetEnvironmentConfig to be called once, got %d", repo.GetConfigCallCount) + } + }) + } +} diff --git a/pkg/domain/service/firewall_service.go b/pkg/domain/service/firewall_service.go new file mode 100644 index 0000000..d0226be --- /dev/null +++ b/pkg/domain/service/firewall_service.go @@ -0,0 +1,92 @@ +// pkg/domain/service/firewall_service.go +package service + +import "github.com/abbott/hardn/pkg/domain/model" + +// FirewallService defines operations for firewall configuration +type FirewallService interface { + + // GetFirewallStatus retrieves the current status of the firewall + GetFirewallStatus() (isInstalled bool, isEnabled bool, isConfigured bool, rules []string, err error) + + // ConfigureFirewall applies the firewall configuration + ConfigureFirewall(config model.FirewallConfig) error + + // AddRule adds a firewall rule + AddRule(rule model.FirewallRule) error + + // RemoveRule removes a firewall rule + RemoveRule(rule model.FirewallRule) error + + // AddProfile adds a firewall application profile + AddProfile(profile model.FirewallProfile) error + + // GetCurrentConfig retrieves the current firewall configuration + GetCurrentConfig() (*model.FirewallConfig, error) + + // EnableFirewall enables the firewall + EnableFirewall() error + + // DisableFirewall disables the firewall + DisableFirewall() error +} + +// FirewallServiceImpl implements FirewallService +type FirewallServiceImpl struct { + repository FirewallRepository + osInfo model.OSInfo +} + +// NewFirewallServiceImpl creates a new FirewallServiceImpl +func NewFirewallServiceImpl(repository FirewallRepository, osInfo model.OSInfo) *FirewallServiceImpl { + return &FirewallServiceImpl{ + repository: repository, + osInfo: osInfo, + } +} + +// FirewallRepository defines the repository operations needed by FirewallService +type FirewallRepository interface { + GetFirewallStatus() (bool, bool, bool, []string, error) + SaveFirewallConfig(config model.FirewallConfig) error + GetFirewallConfig() (*model.FirewallConfig, error) + AddRule(rule model.FirewallRule) error + RemoveRule(rule model.FirewallRule) error + AddProfile(profile model.FirewallProfile) error + EnableFirewall() error + DisableFirewall() error +} + +// GetFirewallStatus retrieves the current status of the firewall +func (s *FirewallServiceImpl) GetFirewallStatus() (bool, bool, bool, []string, error) { + return s.repository.GetFirewallStatus() +} + +// Implementation of FirewallService methods +func (s *FirewallServiceImpl) ConfigureFirewall(config model.FirewallConfig) error { + return s.repository.SaveFirewallConfig(config) +} + +func (s *FirewallServiceImpl) AddRule(rule model.FirewallRule) error { + return s.repository.AddRule(rule) +} + +func (s *FirewallServiceImpl) RemoveRule(rule model.FirewallRule) error { + return s.repository.RemoveRule(rule) +} + +func (s *FirewallServiceImpl) AddProfile(profile model.FirewallProfile) error { + return s.repository.AddProfile(profile) +} + +func (s *FirewallServiceImpl) GetCurrentConfig() (*model.FirewallConfig, error) { + return s.repository.GetFirewallConfig() +} + +func (s *FirewallServiceImpl) EnableFirewall() error { + return s.repository.EnableFirewall() +} + +func (s *FirewallServiceImpl) DisableFirewall() error { + return s.repository.DisableFirewall() +} diff --git a/pkg/domain/service/firewall_service_test.go b/pkg/domain/service/firewall_service_test.go new file mode 100644 index 0000000..444b3c9 --- /dev/null +++ b/pkg/domain/service/firewall_service_test.go @@ -0,0 +1,691 @@ +package service + +import ( + "errors" + "reflect" + "testing" + + "github.com/abbott/hardn/pkg/domain/model" +) + +// MockFirewallRepository implements FirewallRepository interface for testing +type MockFirewallRepository struct { + // Status information + Installed bool + Enabled bool + Configured bool + Rules []string + StatusError error + StatusCallCount int + + // Configuration information + SavedConfig model.FirewallConfig + SaveConfigError error + SaveConfigCallCount int + + ReturnedConfig *model.FirewallConfig + GetConfigError error + GetConfigCallCount int + + // Rule management + AddedRule model.FirewallRule + AddRuleError error + AddRuleCallCount int + + RemovedRule model.FirewallRule + RemoveRuleError error + RemoveRuleCallCount int + + // Profile management + AddedProfile model.FirewallProfile + AddProfileError error + AddProfileCallCount int + + // Firewall state + EnableError error + EnableCallCount int + + DisableError error + DisableCallCount int +} + +func (m *MockFirewallRepository) GetFirewallStatus() (bool, bool, bool, []string, error) { + m.StatusCallCount++ + return m.Installed, m.Enabled, m.Configured, m.Rules, m.StatusError +} + +func (m *MockFirewallRepository) SaveFirewallConfig(config model.FirewallConfig) error { + m.SavedConfig = config + m.SaveConfigCallCount++ + return m.SaveConfigError +} + +func (m *MockFirewallRepository) GetFirewallConfig() (*model.FirewallConfig, error) { + m.GetConfigCallCount++ + return m.ReturnedConfig, m.GetConfigError +} + +func (m *MockFirewallRepository) AddRule(rule model.FirewallRule) error { + m.AddedRule = rule + m.AddRuleCallCount++ + return m.AddRuleError +} + +func (m *MockFirewallRepository) RemoveRule(rule model.FirewallRule) error { + m.RemovedRule = rule + m.RemoveRuleCallCount++ + return m.RemoveRuleError +} + +func (m *MockFirewallRepository) AddProfile(profile model.FirewallProfile) error { + m.AddedProfile = profile + m.AddProfileCallCount++ + return m.AddProfileError +} + +func (m *MockFirewallRepository) EnableFirewall() error { + m.EnableCallCount++ + return m.EnableError +} + +func (m *MockFirewallRepository) DisableFirewall() error { + m.DisableCallCount++ + return m.DisableError +} + +func TestNewFirewallServiceImpl(t *testing.T) { + repo := &MockFirewallRepository{} + osInfo := model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye"} + + service := NewFirewallServiceImpl(repo, osInfo) + + if service == nil { + t.Fatal("Expected non-nil service") + } + + if service.repository != repo { + t.Error("Repository not properly set") + } + + if !reflect.DeepEqual(service.osInfo, osInfo) { + t.Error("OSInfo not properly set") + } +} + +func TestFirewallServiceImpl_GetFirewallStatus(t *testing.T) { + tests := []struct { + name string + installed bool + enabled bool + configured bool + rules []string + statusError error + expectError bool + }{ + { + name: "successful status retrieval", + installed: true, + enabled: true, + configured: true, + rules: []string{"22/tcp ALLOW", "80/tcp DENY"}, + statusError: nil, + expectError: false, + }, + { + name: "firewall not installed", + installed: false, + enabled: false, + configured: false, + rules: []string{}, + statusError: nil, + expectError: false, + }, + { + name: "repository error", + installed: false, + enabled: false, + configured: false, + rules: []string{}, + statusError: errors.New("mock status error"), + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockFirewallRepository{ + Installed: tc.installed, + Enabled: tc.enabled, + Configured: tc.configured, + Rules: tc.rules, + StatusError: tc.statusError, + } + + osInfo := model.OSInfo{Type: "debian", Version: "11"} + service := NewFirewallServiceImpl(repo, osInfo) + + // Execute + installed, enabled, configured, rules, err := service.GetFirewallStatus() + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if repo.StatusCallCount != 1 { + t.Errorf("Expected GetFirewallStatus to be called once, got %d", repo.StatusCallCount) + } + + if installed != tc.installed { + t.Errorf("Incorrect installed status. Got %v, expected %v", installed, tc.installed) + } + + if enabled != tc.enabled { + t.Errorf("Incorrect enabled status. Got %v, expected %v", enabled, tc.enabled) + } + + if configured != tc.configured { + t.Errorf("Incorrect configured status. Got %v, expected %v", configured, tc.configured) + } + + if !reflect.DeepEqual(rules, tc.rules) { + t.Errorf("Incorrect rules. Got %v, expected %v", rules, tc.rules) + } + }) + } +} + +func TestFirewallServiceImpl_ConfigureFirewall(t *testing.T) { + tests := []struct { + name string + config model.FirewallConfig + saveConfigError error + expectError bool + }{ + { + name: "successful configuration", + config: model.FirewallConfig{ + Enabled: true, + DefaultIncoming: "deny", + DefaultOutgoing: "allow", + Rules: []model.FirewallRule{ + {Action: "allow", Protocol: "tcp", Port: 22}, + }, + ApplicationProfiles: []model.FirewallProfile{ + {Name: "OpenSSH", Title: "Secure Shell", Ports: []string{"22/tcp"}}, + }, + }, + saveConfigError: nil, + expectError: false, + }, + { + name: "empty rules", + config: model.FirewallConfig{ + Enabled: true, + DefaultIncoming: "deny", + DefaultOutgoing: "allow", + Rules: []model.FirewallRule{}, + ApplicationProfiles: []model.FirewallProfile{}, + }, + saveConfigError: nil, + expectError: false, + }, + { + name: "repository error", + config: model.FirewallConfig{ + Enabled: true, + DefaultIncoming: "deny", + DefaultOutgoing: "allow", + }, + saveConfigError: errors.New("mock save config error"), + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockFirewallRepository{ + SaveConfigError: tc.saveConfigError, + } + + osInfo := model.OSInfo{Type: "debian", Version: "11"} + service := NewFirewallServiceImpl(repo, osInfo) + + // Execute + err := service.ConfigureFirewall(tc.config) + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if repo.SaveConfigCallCount != 1 { + t.Errorf("Expected SaveFirewallConfig to be called once, got %d", repo.SaveConfigCallCount) + } + + if !reflect.DeepEqual(repo.SavedConfig, tc.config) { + t.Errorf("Wrong config saved. Got %+v, expected %+v", repo.SavedConfig, tc.config) + } + }) + } +} + +func TestFirewallServiceImpl_AddRule(t *testing.T) { + tests := []struct { + name string + rule model.FirewallRule + addRuleError error + expectError bool + }{ + { + name: "successful rule addition", + rule: model.FirewallRule{ + Action: "allow", + Protocol: "tcp", + Port: 22, + SourceIP: "", + Description: "SSH", + }, + addRuleError: nil, + expectError: false, + }, + { + name: "rule with source IP", + rule: model.FirewallRule{ + Action: "allow", + Protocol: "tcp", + Port: 80, + SourceIP: "192.168.1.0/24", + Description: "Web from LAN", + }, + addRuleError: nil, + expectError: false, + }, + { + name: "repository error", + rule: model.FirewallRule{ + Action: "allow", + Protocol: "tcp", + Port: 443, + }, + addRuleError: errors.New("mock add rule error"), + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockFirewallRepository{ + AddRuleError: tc.addRuleError, + } + + osInfo := model.OSInfo{Type: "ubuntu", Version: "20.04"} + service := NewFirewallServiceImpl(repo, osInfo) + + // Execute + err := service.AddRule(tc.rule) + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if repo.AddRuleCallCount != 1 { + t.Errorf("Expected AddRule to be called once, got %d", repo.AddRuleCallCount) + } + + if !reflect.DeepEqual(repo.AddedRule, tc.rule) { + t.Errorf("Wrong rule added. Got %+v, expected %+v", repo.AddedRule, tc.rule) + } + }) + } +} + +func TestFirewallServiceImpl_RemoveRule(t *testing.T) { + tests := []struct { + name string + rule model.FirewallRule + removeRuleError error + expectError bool + }{ + { + name: "successful rule removal", + rule: model.FirewallRule{ + Action: "allow", + Protocol: "tcp", + Port: 22, + Description: "SSH", + }, + removeRuleError: nil, + expectError: false, + }, + { + name: "repository error", + rule: model.FirewallRule{ + Action: "allow", + Protocol: "tcp", + Port: 80, + }, + removeRuleError: errors.New("mock remove rule error"), + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockFirewallRepository{ + RemoveRuleError: tc.removeRuleError, + } + + osInfo := model.OSInfo{Type: "ubuntu", Version: "20.04"} + service := NewFirewallServiceImpl(repo, osInfo) + + // Execute + err := service.RemoveRule(tc.rule) + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if repo.RemoveRuleCallCount != 1 { + t.Errorf("Expected RemoveRule to be called once, got %d", repo.RemoveRuleCallCount) + } + + if !reflect.DeepEqual(repo.RemovedRule, tc.rule) { + t.Errorf("Wrong rule removed. Got %+v, expected %+v", repo.RemovedRule, tc.rule) + } + }) + } +} + +func TestFirewallServiceImpl_AddProfile(t *testing.T) { + tests := []struct { + name string + profile model.FirewallProfile + addProfileError error + expectError bool + }{ + { + name: "successful profile addition", + profile: model.FirewallProfile{ + Name: "OpenSSH", + Title: "Secure Shell", + Description: "SSH server", + Ports: []string{"22/tcp"}, + }, + addProfileError: nil, + expectError: false, + }, + { + name: "multiple ports", + profile: model.FirewallProfile{ + Name: "NGINX", + Title: "Web Server", + Description: "NGINX web server", + Ports: []string{"80/tcp", "443/tcp"}, + }, + addProfileError: nil, + expectError: false, + }, + { + name: "repository error", + profile: model.FirewallProfile{ + Name: "Invalid", + Ports: []string{}, + }, + addProfileError: errors.New("mock add profile error"), + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockFirewallRepository{ + AddProfileError: tc.addProfileError, + } + + osInfo := model.OSInfo{Type: "debian", Version: "11"} + service := NewFirewallServiceImpl(repo, osInfo) + + // Execute + err := service.AddProfile(tc.profile) + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if repo.AddProfileCallCount != 1 { + t.Errorf("Expected AddProfile to be called once, got %d", repo.AddProfileCallCount) + } + + if !reflect.DeepEqual(repo.AddedProfile, tc.profile) { + t.Errorf("Wrong profile added. Got %+v, expected %+v", repo.AddedProfile, tc.profile) + } + }) + } +} + +func TestFirewallServiceImpl_GetCurrentConfig(t *testing.T) { + tests := []struct { + name string + mockConfig *model.FirewallConfig + mockError error + expectError bool + expectedConfig *model.FirewallConfig + }{ + { + name: "successful retrieval", + mockConfig: &model.FirewallConfig{ + Enabled: true, + DefaultIncoming: "deny", + DefaultOutgoing: "allow", + Rules: []model.FirewallRule{ + {Action: "allow", Protocol: "tcp", Port: 22}, + }, + }, + mockError: nil, + expectError: false, + expectedConfig: &model.FirewallConfig{ + Enabled: true, + DefaultIncoming: "deny", + DefaultOutgoing: "allow", + Rules: []model.FirewallRule{ + {Action: "allow", Protocol: "tcp", Port: 22}, + }, + }, + }, + { + name: "repository error", + mockConfig: nil, + mockError: errors.New("mock get config error"), + expectError: true, + expectedConfig: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockFirewallRepository{ + ReturnedConfig: tc.mockConfig, + GetConfigError: tc.mockError, + } + + osInfo := model.OSInfo{Type: "alpine", Version: "3.16"} + service := NewFirewallServiceImpl(repo, osInfo) + + // Execute + config, err := service.GetCurrentConfig() + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if repo.GetConfigCallCount != 1 { + t.Errorf("Expected GetFirewallConfig to be called once, got %d", repo.GetConfigCallCount) + } + + if tc.expectedConfig != nil { + if config == nil { + t.Fatal("Expected non-nil config but got nil") + } + if !reflect.DeepEqual(config, tc.expectedConfig) { + t.Errorf("Wrong config returned. Got %+v, expected %+v", config, tc.expectedConfig) + } + } else if config != nil { + t.Error("Expected nil config but got non-nil") + } + }) + } +} + +func TestFirewallServiceImpl_EnableFirewall(t *testing.T) { + tests := []struct { + name string + enableError error + expectError bool + }{ + { + name: "successful enable", + enableError: nil, + expectError: false, + }, + { + name: "repository error", + enableError: errors.New("mock enable error"), + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockFirewallRepository{ + EnableError: tc.enableError, + } + + osInfo := model.OSInfo{Type: "debian", Version: "11"} + service := NewFirewallServiceImpl(repo, osInfo) + + // Execute + err := service.EnableFirewall() + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if repo.EnableCallCount != 1 { + t.Errorf("Expected EnableFirewall to be called once, got %d", repo.EnableCallCount) + } + }) + } +} + +func TestFirewallServiceImpl_DisableFirewall(t *testing.T) { + tests := []struct { + name string + disableError error + expectError bool + }{ + { + name: "successful disable", + disableError: nil, + expectError: false, + }, + { + name: "repository error", + disableError: errors.New("mock disable error"), + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockFirewallRepository{ + DisableError: tc.disableError, + } + + osInfo := model.OSInfo{Type: "debian", Version: "11"} + service := NewFirewallServiceImpl(repo, osInfo) + + // Execute + err := service.DisableFirewall() + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if repo.DisableCallCount != 1 { + t.Errorf("Expected DisableFirewall to be called once, got %d", repo.DisableCallCount) + } + }) + } +} + +func TestFirewallServiceImpl_OSTypes(t *testing.T) { + // Test with different OS types to ensure the service works consistently + osTypes := []string{"debian", "ubuntu", "alpine", "proxmox", "unknown"} + + for _, osType := range osTypes { + t.Run(osType+" OS type", func(t *testing.T) { + // Setup + repo := &MockFirewallRepository{} + osInfo := model.OSInfo{Type: osType, Version: "1.0"} + service := NewFirewallServiceImpl(repo, osInfo) + + // Test a simple rule addition + rule := model.FirewallRule{ + Action: "allow", + Protocol: "tcp", + Port: 22, + } + + // Execute + err := service.AddRule(rule) + + // Verify + if err != nil { + t.Errorf("Failed to add rule on %s: %v", osType, err) + } + + if !reflect.DeepEqual(repo.AddedRule, rule) { + t.Errorf("Wrong rule added for %s. Got %+v, expected %+v", osType, repo.AddedRule, rule) + } + }) + } +} diff --git a/pkg/domain/service/logs_service.go b/pkg/domain/service/logs_service.go new file mode 100644 index 0000000..6ab8dda --- /dev/null +++ b/pkg/domain/service/logs_service.go @@ -0,0 +1,48 @@ +// pkg/domain/service/logs_service.go +package service + +import "github.com/abbott/hardn/pkg/domain/model" + +// LogsService defines operations for log management +type LogsService interface { + // GetLogs retrieves logs from the configured log file + GetLogs() ([]model.LogEntry, error) + + // GetLogConfig retrieves the current log configuration + GetLogConfig() (*model.LogsConfig, error) + + // PrintLogs prints the logs to the console + PrintLogs() error +} + +// LogsServiceImpl implements LogsService +type LogsServiceImpl struct { + repository LogsRepository +} + +// NewLogsServiceImpl creates a new LogsServiceImpl +func NewLogsServiceImpl(repository LogsRepository) *LogsServiceImpl { + return &LogsServiceImpl{ + repository: repository, + } +} + +// LogsRepository defines the repository operations needed by LogsService +type LogsRepository interface { + GetLogs() ([]model.LogEntry, error) + GetLogConfig() (*model.LogsConfig, error) + PrintLogs() error +} + +// Implementation of LogsService methods +func (s *LogsServiceImpl) GetLogs() ([]model.LogEntry, error) { + return s.repository.GetLogs() +} + +func (s *LogsServiceImpl) GetLogConfig() (*model.LogsConfig, error) { + return s.repository.GetLogConfig() +} + +func (s *LogsServiceImpl) PrintLogs() error { + return s.repository.PrintLogs() +} diff --git a/pkg/domain/service/logs_service_test.go b/pkg/domain/service/logs_service_test.go new file mode 100644 index 0000000..d04e7c8 --- /dev/null +++ b/pkg/domain/service/logs_service_test.go @@ -0,0 +1,115 @@ +package service + +import ( + "errors" + "testing" + + "github.com/abbott/hardn/pkg/domain/model" + "github.com/stretchr/testify/assert" +) + +// mockLogsRepository implements the LogsRepository interface for testing +type mockLogsRepository struct { + logs []model.LogEntry + config *model.LogsConfig + err error + printCalled bool +} + +func (m *mockLogsRepository) GetLogs() ([]model.LogEntry, error) { + return m.logs, m.err +} + +func (m *mockLogsRepository) GetLogConfig() (*model.LogsConfig, error) { + return m.config, m.err +} + +func (m *mockLogsRepository) PrintLogs() error { + m.printCalled = true + return m.err +} + +func TestNewLogsServiceImpl(t *testing.T) { + mockRepo := &mockLogsRepository{} + service := NewLogsServiceImpl(mockRepo) + + assert.NotNil(t, service) + assert.Equal(t, mockRepo, service.repository) +} + +func TestLogsServiceImpl_GetLogs(t *testing.T) { + t.Run("success case", func(t *testing.T) { + expectedLogs := []model.LogEntry{ + {Level: "INFO", Message: "Test message 1", Time: "2023-01-01T12:00:00Z"}, + {Level: "ERROR", Message: "Test error", Time: "2023-01-01T12:01:00Z"}, + } + mockRepo := &mockLogsRepository{logs: expectedLogs, err: nil} + service := NewLogsServiceImpl(mockRepo) + + logs, err := service.GetLogs() + + assert.NoError(t, err) + assert.Equal(t, expectedLogs, logs) + }) + + t.Run("error case", func(t *testing.T) { + expectedErr := errors.New("failed to retrieve logs") + mockRepo := &mockLogsRepository{logs: nil, err: expectedErr} + service := NewLogsServiceImpl(mockRepo) + + logs, err := service.GetLogs() + + assert.Error(t, err) + assert.Equal(t, expectedErr, err) + assert.Nil(t, logs) + }) +} + +func TestLogsServiceImpl_GetLogConfig(t *testing.T) { + t.Run("success case", func(t *testing.T) { + expectedConfig := &model.LogsConfig{LogFilePath: "/var/log/app.log"} + mockRepo := &mockLogsRepository{config: expectedConfig, err: nil} + service := NewLogsServiceImpl(mockRepo) + + config, err := service.GetLogConfig() + + assert.NoError(t, err) + assert.Equal(t, expectedConfig, config) + }) + + t.Run("error case", func(t *testing.T) { + expectedErr := errors.New("failed to get log configuration") + mockRepo := &mockLogsRepository{config: nil, err: expectedErr} + service := NewLogsServiceImpl(mockRepo) + + config, err := service.GetLogConfig() + + assert.Error(t, err) + assert.Equal(t, expectedErr, err) + assert.Nil(t, config) + }) +} + +func TestLogsServiceImpl_PrintLogs(t *testing.T) { + t.Run("success case", func(t *testing.T) { + mockRepo := &mockLogsRepository{err: nil} + service := NewLogsServiceImpl(mockRepo) + + err := service.PrintLogs() + + assert.NoError(t, err) + assert.True(t, mockRepo.printCalled) + }) + + t.Run("error case", func(t *testing.T) { + expectedErr := errors.New("failed to print logs") + mockRepo := &mockLogsRepository{err: expectedErr} + service := NewLogsServiceImpl(mockRepo) + + err := service.PrintLogs() + + assert.Error(t, err) + assert.Equal(t, expectedErr, err) + assert.True(t, mockRepo.printCalled) + }) +} diff --git a/pkg/domain/service/package_service.go b/pkg/domain/service/package_service.go new file mode 100644 index 0000000..77c20c5 --- /dev/null +++ b/pkg/domain/service/package_service.go @@ -0,0 +1,73 @@ +// pkg/domain/service/package_service.go +package service + +import "github.com/abbott/hardn/pkg/domain/model" + +// PackageService defines operations for package management +type PackageService interface { + // InstallPackages installs the specified packages + InstallPackages(request model.PackageInstallRequest) error + + // UpdatePackageSources updates package repository sources + UpdatePackageSources() error + + // UpdateProxmoxSources updates Proxmox-specific package sources + UpdateProxmoxSources() error + + // IsPackageInstalled checks if a package is installed + IsPackageInstalled(packageName string) (bool, error) +} + +// PackageServiceImpl implements PackageService +type PackageServiceImpl struct { + repository PackageRepository + osInfo model.OSInfo +} + +// NewPackageServiceImpl creates a new PackageServiceImpl +func NewPackageServiceImpl(repository PackageRepository, osInfo model.OSInfo) *PackageServiceImpl { + return &PackageServiceImpl{ + repository: repository, + osInfo: osInfo, + } +} + +// PackageRepository defines the repository operations needed by PackageService +type PackageRepository interface { + InstallPackages(request model.PackageInstallRequest) error + UpdatePackageSources(sources model.PackageSources) error + UpdateProxmoxSources(sources model.PackageSources) error + IsPackageInstalled(packageName string) (bool, error) + GetPackageSources() (*model.PackageSources, error) +} + +// Implementation of PackageService methods +func (s *PackageServiceImpl) InstallPackages(request model.PackageInstallRequest) error { + // Skip calling repository for empty package requests + if len(request.Packages) == 0 && len(request.PipPackages) == 0 { + return nil + } + return s.repository.InstallPackages(request) +} + +func (s *PackageServiceImpl) UpdatePackageSources() error { + sources, err := s.repository.GetPackageSources() + if err != nil { + return err + } + + return s.repository.UpdatePackageSources(*sources) +} + +func (s *PackageServiceImpl) UpdateProxmoxSources() error { + sources, err := s.repository.GetPackageSources() + if err != nil { + return err + } + + return s.repository.UpdateProxmoxSources(*sources) +} + +func (s *PackageServiceImpl) IsPackageInstalled(packageName string) (bool, error) { + return s.repository.IsPackageInstalled(packageName) +} diff --git a/pkg/domain/service/package_service_test.go b/pkg/domain/service/package_service_test.go new file mode 100644 index 0000000..dcc12cc --- /dev/null +++ b/pkg/domain/service/package_service_test.go @@ -0,0 +1,511 @@ +package service + +import ( + "errors" + "reflect" + "testing" + + "github.com/abbott/hardn/pkg/domain/model" +) + +// MockPackageRepository implements PackageRepository interface for testing +type MockPackageRepository struct { + // Install Packages tracking + InstalledRequest model.PackageInstallRequest + InstallError error + InstallCallCount int + + // Update package sources tracking + UpdatedSources model.PackageSources + UpdateSourcesError error + UpdateSourcesCalled bool + + // Update Proxmox sources tracking + UpdatedProxmoxSources model.PackageSources + UpdateProxmoxError error + UpdateProxmoxCalled bool + + // Package installed check tracking + CheckedPackage string + PackageInstalledResult bool + PackageInstalledError error + PackageInstalledCalled bool + + // Package sources retrieval tracking + ReturnedSources *model.PackageSources + GetSourcesError error + GetSourcesCalled bool +} + +func (m *MockPackageRepository) InstallPackages(request model.PackageInstallRequest) error { + m.InstalledRequest = request + m.InstallCallCount++ + return m.InstallError +} + +func (m *MockPackageRepository) UpdatePackageSources(sources model.PackageSources) error { + m.UpdatedSources = sources + m.UpdateSourcesCalled = true + return m.UpdateSourcesError +} + +func (m *MockPackageRepository) UpdateProxmoxSources(sources model.PackageSources) error { + m.UpdatedProxmoxSources = sources + m.UpdateProxmoxCalled = true + return m.UpdateProxmoxError +} + +func (m *MockPackageRepository) IsPackageInstalled(packageName string) (bool, error) { + m.CheckedPackage = packageName + m.PackageInstalledCalled = true + return m.PackageInstalledResult, m.PackageInstalledError +} + +func (m *MockPackageRepository) GetPackageSources() (*model.PackageSources, error) { + m.GetSourcesCalled = true + return m.ReturnedSources, m.GetSourcesError +} + +func TestNewPackageServiceImpl(t *testing.T) { + repo := &MockPackageRepository{} + osInfo := model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye"} + + service := NewPackageServiceImpl(repo, osInfo) + + if service == nil { + t.Fatal("Expected non-nil service") + } + + if service.repository != repo { + t.Error("Repository not properly set") + } + + if !reflect.DeepEqual(service.osInfo, osInfo) { + t.Error("OSInfo not properly set") + } +} + +func TestPackageServiceImpl_InstallPackages(t *testing.T) { + tests := []struct { + name string + request model.PackageInstallRequest + installError error + osInfo model.OSInfo + expectError bool + }{ + { + name: "debian system packages", + request: model.PackageInstallRequest{ + Packages: []string{"ufw", "unattended-upgrades"}, + PackageType: "Core", + IsPython: false, + }, + installError: nil, + osInfo: model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye"}, + expectError: false, + }, + { + name: "alpine system packages", + request: model.PackageInstallRequest{ + Packages: []string{"ufw", "python3"}, + PackageType: "Core", + IsPython: false, + }, + installError: nil, + osInfo: model.OSInfo{Type: "alpine", Version: "3.16"}, + expectError: false, + }, + { + name: "debian python packages", + request: model.PackageInstallRequest{ + Packages: []string{"python3-pip"}, + PipPackages: []string{"requests", "paramiko"}, + PackageType: "Core", + IsPython: true, + IsSystemPython: true, + }, + installError: nil, + osInfo: model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye"}, + expectError: false, + }, + { + name: "proxmox packages", + request: model.PackageInstallRequest{ + Packages: []string{"ufw", "zfsutils-linux"}, + PackageType: "Core", + IsPython: false, + }, + installError: nil, + osInfo: model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye", IsProxmox: true}, + expectError: false, + }, + { + name: "repository error", + request: model.PackageInstallRequest{ + Packages: []string{"ufw", "fail2ban"}, + PackageType: "Core", + IsPython: false, + }, + installError: errors.New("mock installation error"), + osInfo: model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye"}, + expectError: true, + }, + { + name: "empty package request", + request: model.PackageInstallRequest{ + Packages: []string{}, + PipPackages: []string{}, + IsPython: false, + }, + installError: nil, + osInfo: model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye"}, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockPackageRepository{ + InstallError: tc.installError, + } + + service := NewPackageServiceImpl(repo, tc.osInfo) + + // Execute + err := service.InstallPackages(tc.request) + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + // Handle the empty package case differently + if len(tc.request.Packages) == 0 && len(tc.request.PipPackages) == 0 { + // For empty requests, expect repository method not to be called + if repo.InstallCallCount != 0 { + t.Errorf("Expected InstallPackages not to be called for empty request, but was called %d times", repo.InstallCallCount) + } + return // Skip further checks for empty requests + } + + // For non-empty requests, expect the method to be called once + if repo.InstallCallCount != 1 { + t.Errorf("Expected InstallPackages to be called once, got %d", repo.InstallCallCount) + } + + if !reflect.DeepEqual(repo.InstalledRequest, tc.request) { + t.Errorf("Wrong request passed to repository. Got %+v, expected %+v", repo.InstalledRequest, tc.request) + } + }) + } +} + +func TestPackageServiceImpl_UpdatePackageSources(t *testing.T) { + tests := []struct { + name string + sources *model.PackageSources + getSourcesError error + updateSourcesError error + osInfo model.OSInfo + expectError bool + }{ + { + name: "debian successful update", + sources: &model.PackageSources{ + DebianRepos: []string{ + "deb http://deb.debian.org/debian CODENAME main contrib non-free", + "deb http://security.debian.org/debian-security CODENAME-security main contrib non-free", + }, + }, + getSourcesError: nil, + updateSourcesError: nil, + osInfo: model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye"}, + expectError: false, + }, + { + name: "alpine successful update", + sources: &model.PackageSources{ + AlpineTestingRepo: true, + }, + getSourcesError: nil, + updateSourcesError: nil, + osInfo: model.OSInfo{Type: "alpine", Version: "3.16"}, + expectError: false, + }, + { + name: "get sources error", + sources: nil, + getSourcesError: errors.New("mock get sources error"), + osInfo: model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye"}, + expectError: true, + }, + { + name: "update sources error", + sources: &model.PackageSources{ + DebianRepos: []string{ + "deb http://deb.debian.org/debian CODENAME main contrib non-free", + }, + }, + getSourcesError: nil, + updateSourcesError: errors.New("mock update sources error"), + osInfo: model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye"}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockPackageRepository{ + ReturnedSources: tc.sources, + GetSourcesError: tc.getSourcesError, + UpdateSourcesError: tc.updateSourcesError, + } + + service := NewPackageServiceImpl(repo, tc.osInfo) + + // Execute + err := service.UpdatePackageSources() + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if !repo.GetSourcesCalled { + t.Error("Expected GetPackageSources to be called") + } + + // Check if the update method was called when expected + if tc.getSourcesError == nil { + if !repo.UpdateSourcesCalled { + t.Error("Expected UpdatePackageSources to be called") + } + + if tc.sources != nil && !reflect.DeepEqual(repo.UpdatedSources, *tc.sources) { + t.Errorf("Wrong sources passed to repository. Got %+v, expected %+v", repo.UpdatedSources, *tc.sources) + } + } else { + if repo.UpdateSourcesCalled { + t.Error("UpdatePackageSources should not have been called when GetPackageSources fails") + } + } + }) + } +} + +func TestPackageServiceImpl_UpdateProxmoxSources(t *testing.T) { + tests := []struct { + name string + sources *model.PackageSources + getSourcesError error + updateProxmoxError error + osInfo model.OSInfo + expectError bool + }{ + { + name: "proxmox successful update", + sources: &model.PackageSources{ + ProxmoxCephRepo: []string{ + "deb http://download.proxmox.com/debian/ceph-pacific CODENAME main", + }, + ProxmoxEnterpriseRepo: []string{ + "# deb https://enterprise.proxmox.com/debian/pve CODENAME pve-enterprise", + }, + }, + getSourcesError: nil, + updateProxmoxError: nil, + osInfo: model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye", IsProxmox: true}, + expectError: false, + }, + { + name: "get sources error", + sources: nil, + getSourcesError: errors.New("mock get sources error"), + osInfo: model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye", IsProxmox: true}, + expectError: true, + }, + { + name: "update proxmox error", + sources: &model.PackageSources{ + ProxmoxCephRepo: []string{ + "deb http://download.proxmox.com/debian/ceph-pacific CODENAME main", + }, + }, + getSourcesError: nil, + updateProxmoxError: errors.New("mock update proxmox error"), + osInfo: model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye", IsProxmox: true}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockPackageRepository{ + ReturnedSources: tc.sources, + GetSourcesError: tc.getSourcesError, + UpdateProxmoxError: tc.updateProxmoxError, + } + + service := NewPackageServiceImpl(repo, tc.osInfo) + + // Execute + err := service.UpdateProxmoxSources() + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if !repo.GetSourcesCalled { + t.Error("Expected GetPackageSources to be called") + } + + // Check if the update method was called when expected + if tc.getSourcesError == nil { + if !repo.UpdateProxmoxCalled { + t.Error("Expected UpdateProxmoxSources to be called") + } + + if tc.sources != nil && !reflect.DeepEqual(repo.UpdatedProxmoxSources, *tc.sources) { + t.Errorf("Wrong sources passed to repository. Got %+v, expected %+v", repo.UpdatedProxmoxSources, *tc.sources) + } + } else { + if repo.UpdateProxmoxCalled { + t.Error("UpdateProxmoxSources should not have been called when GetPackageSources fails") + } + } + }) + } +} + +func TestPackageServiceImpl_IsPackageInstalled(t *testing.T) { + tests := []struct { + name string + packageName string + isInstalled bool + checkError error + osInfo model.OSInfo + expectError bool + expectInstalled bool + }{ + { + name: "package is installed", + packageName: "ufw", + isInstalled: true, + checkError: nil, + osInfo: model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye"}, + expectError: false, + expectInstalled: true, + }, + { + name: "package not installed", + packageName: "nonexistent-pkg", + isInstalled: false, + checkError: nil, + osInfo: model.OSInfo{Type: "alpine", Version: "3.16"}, + expectError: false, + expectInstalled: false, + }, + { + name: "repository error", + packageName: "ufw", + isInstalled: false, + checkError: errors.New("mock check error"), + osInfo: model.OSInfo{Type: "debian", Version: "11", Codename: "bullseye"}, + expectError: true, + expectInstalled: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup + repo := &MockPackageRepository{ + PackageInstalledResult: tc.isInstalled, + PackageInstalledError: tc.checkError, + } + + service := NewPackageServiceImpl(repo, tc.osInfo) + + // Execute + installed, err := service.IsPackageInstalled(tc.packageName) + + // Verify + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if installed != tc.expectInstalled { + t.Errorf("Wrong installed status. Got %v, expected %v", installed, tc.expectInstalled) + } + + if !repo.PackageInstalledCalled { + t.Error("Expected IsPackageInstalled to be called") + } + + if repo.CheckedPackage != tc.packageName { + t.Errorf("Wrong package name passed. Got %s, expected %s", repo.CheckedPackage, tc.packageName) + } + }) + } +} + +func TestPackageServiceImpl_OSTypeHandling(t *testing.T) { + osTypes := []struct { + osType string + osVersion string + isProxmox bool + }{ + {osType: "debian", osVersion: "11", isProxmox: false}, + {osType: "ubuntu", osVersion: "20.04", isProxmox: false}, + {osType: "alpine", osVersion: "3.16", isProxmox: false}, + {osType: "debian", osVersion: "11", isProxmox: true}, // Proxmox case + } + + for _, os := range osTypes { + t.Run(os.osType+"_"+os.osVersion, func(t *testing.T) { + // Setup + repo := &MockPackageRepository{ + ReturnedSources: &model.PackageSources{ + DebianRepos: []string{"deb http://example.com CODENAME main"}, + }, + } + + osInfo := model.OSInfo{ + Type: os.osType, + Version: os.osVersion, + IsProxmox: os.isProxmox, + } + + service := NewPackageServiceImpl(repo, osInfo) + + // Execute + err := service.UpdatePackageSources() + + // Verify + if err != nil { + t.Errorf("Failed to update package sources on %s %s: %v", os.osType, os.osVersion, err) + } + + // Just a basic test to ensure service handles different OSes gracefully + if !repo.UpdateSourcesCalled { + t.Errorf("Expected UpdatePackageSources to be called for %s", os.osType) + } + }) + } +} diff --git a/pkg/domain/service/ssh_service.go b/pkg/domain/service/ssh_service.go new file mode 100644 index 0000000..3c5c675 --- /dev/null +++ b/pkg/domain/service/ssh_service.go @@ -0,0 +1,58 @@ +// pkg/domain/service/ssh_service.go +package service + +import "github.com/abbott/hardn/pkg/domain/model" + +// SSHService defines operations for SSH configuration +type SSHService interface { + // ConfigureSSH applies SSH configuration settings + ConfigureSSH(config model.SSHConfig) error + + // DisableRootAccess disables SSH access for the root user + DisableRootAccess() error + + // AddAuthorizedKey adds an SSH public key to a user's authorized_keys + AddAuthorizedKey(username string, publicKey string) error + + // GetCurrentConfig retrieves the current SSH configuration + GetCurrentConfig() (*model.SSHConfig, error) +} + +// SSHServiceImpl implements SSHService +type SSHServiceImpl struct { + repository SSHRepository + osInfo model.OSInfo +} + +// NewSSHServiceImpl creates a new SSHServiceImpl +func NewSSHServiceImpl(repository SSHRepository, osInfo model.OSInfo) *SSHServiceImpl { + return &SSHServiceImpl{ + repository: repository, + osInfo: osInfo, + } +} + +// SSHRepository defines the repository operations needed by SSHService +type SSHRepository interface { + SaveSSHConfig(config model.SSHConfig) error + GetSSHConfig() (*model.SSHConfig, error) + DisableRootAccess() error + AddAuthorizedKey(username string, publicKey string) error +} + +// Implement SSHService methods +func (s *SSHServiceImpl) ConfigureSSH(config model.SSHConfig) error { + return s.repository.SaveSSHConfig(config) +} + +func (s *SSHServiceImpl) DisableRootAccess() error { + return s.repository.DisableRootAccess() +} + +func (s *SSHServiceImpl) AddAuthorizedKey(username string, publicKey string) error { + return s.repository.AddAuthorizedKey(username, publicKey) +} + +func (s *SSHServiceImpl) GetCurrentConfig() (*model.SSHConfig, error) { + return s.repository.GetSSHConfig() +} diff --git a/pkg/domain/service/ssh_service_test.go b/pkg/domain/service/ssh_service_test.go new file mode 100644 index 0000000..e9976ef --- /dev/null +++ b/pkg/domain/service/ssh_service_test.go @@ -0,0 +1,354 @@ +package service + +import ( + "fmt" + "testing" + + "github.com/abbott/hardn/pkg/domain/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockSSHRepository is a mock implementation of the SSHRepository interface for testing +type MockSSHRepository struct { + mock.Mock +} + +func (m *MockSSHRepository) SaveSSHConfig(config model.SSHConfig) error { + args := m.Called(config) + return args.Error(0) +} + +func (m *MockSSHRepository) GetSSHConfig() (*model.SSHConfig, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + + // Safely perform type assertion + config, ok := args.Get(0).(*model.SSHConfig) + if !ok { + return nil, fmt.Errorf("invalid type assertion, expected *model.SSHConfig") + } + + return config, args.Error(1) +} + +func (m *MockSSHRepository) DisableRootAccess() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockSSHRepository) AddAuthorizedKey(username string, publicKey string) error { + args := m.Called(username, publicKey) + return args.Error(0) +} + +func TestSSHServiceImpl_ConfigureSSH(t *testing.T) { + // Setup + mockRepo := new(MockSSHRepository) + osInfo := model.OSInfo{ + Type: "debian", + Version: "11", + Codename: "bullseye", + } + service := NewSSHServiceImpl(mockRepo, osInfo) + + // Test data + config := model.SSHConfig{ + Port: 2222, + ListenAddresses: []string{"0.0.0.0"}, + PermitRootLogin: false, + AllowedUsers: []string{"user1", "user2"}, + KeyPaths: []string{".ssh/authorized_keys"}, + AuthMethods: []string{"publickey"}, + } + + // Setup expectations + mockRepo.On("SaveSSHConfig", config).Return(nil) + + // Execute + err := service.ConfigureSSH(config) + + // Assert + assert.NoError(t, err) + mockRepo.AssertExpectations(t) +} + +func TestSSHServiceImpl_ConfigureSSH_Error(t *testing.T) { + // Setup + mockRepo := new(MockSSHRepository) + osInfo := model.OSInfo{ + Type: "debian", + Version: "11", + Codename: "bullseye", + } + service := NewSSHServiceImpl(mockRepo, osInfo) + + // Test data + config := model.SSHConfig{ + Port: 2222, + ListenAddresses: []string{"0.0.0.0"}, + PermitRootLogin: false, + AllowedUsers: []string{"user1", "user2"}, + KeyPaths: []string{".ssh/authorized_keys"}, + AuthMethods: []string{"publickey"}, + } + + // Setup expectations with an error + expectedErr := fmt.Errorf("failed to save config") + mockRepo.On("SaveSSHConfig", config).Return(expectedErr) + + // Execute + err := service.ConfigureSSH(config) + + // Assert + assert.Error(t, err) + assert.Equal(t, expectedErr, err) + mockRepo.AssertExpectations(t) +} + +func TestSSHServiceImpl_DisableRootAccess(t *testing.T) { + // Setup + mockRepo := new(MockSSHRepository) + osInfo := model.OSInfo{ + Type: "debian", + Version: "11", + Codename: "bullseye", + } + service := NewSSHServiceImpl(mockRepo, osInfo) + + // Setup expectations + mockRepo.On("DisableRootAccess").Return(nil) + + // Execute + err := service.DisableRootAccess() + + // Assert + assert.NoError(t, err) + mockRepo.AssertExpectations(t) +} + +func TestSSHServiceImpl_DisableRootAccess_Error(t *testing.T) { + // Setup + mockRepo := new(MockSSHRepository) + osInfo := model.OSInfo{ + Type: "debian", + Version: "11", + Codename: "bullseye", + } + service := NewSSHServiceImpl(mockRepo, osInfo) + + // Setup expectations + expectedErr := fmt.Errorf("failed to disable root access") + mockRepo.On("DisableRootAccess").Return(expectedErr) + + // Execute + err := service.DisableRootAccess() + + // Assert + assert.Error(t, err) + assert.Equal(t, expectedErr, err) + mockRepo.AssertExpectations(t) +} + +func TestSSHServiceImpl_AddAuthorizedKey(t *testing.T) { + // Setup + mockRepo := new(MockSSHRepository) + osInfo := model.OSInfo{ + Type: "debian", + Version: "11", + Codename: "bullseye", + } + service := NewSSHServiceImpl(mockRepo, osInfo) + + // Test data + username := "testuser" + publicKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAI... testuser@example.com" + + // Setup expectations + mockRepo.On("AddAuthorizedKey", username, publicKey).Return(nil) + + // Execute + err := service.AddAuthorizedKey(username, publicKey) + + // Assert + assert.NoError(t, err) + mockRepo.AssertExpectations(t) +} + +func TestSSHServiceImpl_AddAuthorizedKey_Error(t *testing.T) { + // Setup + mockRepo := new(MockSSHRepository) + osInfo := model.OSInfo{ + Type: "debian", + Version: "11", + Codename: "bullseye", + } + service := NewSSHServiceImpl(mockRepo, osInfo) + + // Test data + username := "testuser" + publicKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAI... testuser@example.com" + + // Setup expectations + expectedErr := fmt.Errorf("failed to add authorized key") + mockRepo.On("AddAuthorizedKey", username, publicKey).Return(expectedErr) + + // Execute + err := service.AddAuthorizedKey(username, publicKey) + + // Assert + assert.Error(t, err) + assert.Equal(t, expectedErr, err) + mockRepo.AssertExpectations(t) +} + +func TestSSHServiceImpl_GetCurrentConfig(t *testing.T) { + // Setup + mockRepo := new(MockSSHRepository) + osInfo := model.OSInfo{ + Type: "debian", + Version: "11", + Codename: "bullseye", + } + service := NewSSHServiceImpl(mockRepo, osInfo) + + // Test data + expectedConfig := &model.SSHConfig{ + Port: 2222, + ListenAddresses: []string{"0.0.0.0"}, + PermitRootLogin: false, + AllowedUsers: []string{"user1", "user2"}, + KeyPaths: []string{".ssh/authorized_keys"}, + AuthMethods: []string{"publickey"}, + } + + // Setup expectations + mockRepo.On("GetSSHConfig").Return(expectedConfig, nil) + + // Execute + config, err := service.GetCurrentConfig() + + // Assert + assert.NoError(t, err) + assert.Equal(t, expectedConfig, config) + mockRepo.AssertExpectations(t) +} + +func TestSSHServiceImpl_GetCurrentConfig_Error(t *testing.T) { + // Setup + mockRepo := new(MockSSHRepository) + osInfo := model.OSInfo{ + Type: "debian", + Version: "11", + Codename: "bullseye", + } + service := NewSSHServiceImpl(mockRepo, osInfo) + + // Setup expectations + expectedErr := fmt.Errorf("failed to get config") + mockRepo.On("GetSSHConfig").Return(nil, expectedErr) + + // Execute + config, err := service.GetCurrentConfig() + + // Assert + assert.Error(t, err) + assert.Nil(t, config) + assert.Equal(t, expectedErr, err) + mockRepo.AssertExpectations(t) +} + +func TestSSHServiceImpl_ConfigureSSHWithDifferentOSTypes(t *testing.T) { + // Test cases for different OS types + testCases := []struct { + name string + osInfo model.OSInfo + config model.SSHConfig + }{ + { + name: "Debian", + osInfo: model.OSInfo{ + Type: "debian", + Version: "11", + Codename: "bullseye", + }, + config: model.SSHConfig{ + Port: 2222, + ListenAddresses: []string{"0.0.0.0"}, + PermitRootLogin: false, + AllowedUsers: []string{"debianuser"}, + KeyPaths: []string{".ssh/authorized_keys"}, + AuthMethods: []string{"publickey"}, + }, + }, + { + name: "Alpine", + osInfo: model.OSInfo{ + Type: "alpine", + Version: "3.15", + Codename: "3.15", + }, + config: model.SSHConfig{ + Port: 2222, + ListenAddresses: []string{"0.0.0.0"}, + PermitRootLogin: false, + AllowedUsers: []string{"alpineuser"}, + KeyPaths: []string{".ssh/authorized_keys"}, + AuthMethods: []string{"publickey"}, + }, + }, + { + name: "Ubuntu", + osInfo: model.OSInfo{ + Type: "ubuntu", + Version: "20.04", + Codename: "focal", + }, + config: model.SSHConfig{ + Port: 2222, + ListenAddresses: []string{"0.0.0.0"}, + PermitRootLogin: false, + AllowedUsers: []string{"ubuntuuser"}, + KeyPaths: []string{".ssh/authorized_keys"}, + AuthMethods: []string{"publickey"}, + }, + }, + { + name: "Proxmox", + osInfo: model.OSInfo{ + Type: "debian", + Version: "11", + Codename: "bullseye", + IsProxmox: true, + }, + config: model.SSHConfig{ + Port: 2222, + ListenAddresses: []string{"0.0.0.0"}, + PermitRootLogin: false, + AllowedUsers: []string{"proxmoxuser"}, + KeyPaths: []string{".ssh/authorized_keys"}, + AuthMethods: []string{"publickey"}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup + mockRepo := new(MockSSHRepository) + service := NewSSHServiceImpl(mockRepo, tc.osInfo) + + // Setup expectations + mockRepo.On("SaveSSHConfig", tc.config).Return(nil) + + // Execute + err := service.ConfigureSSH(tc.config) + + // Assert + assert.NoError(t, err) + mockRepo.AssertExpectations(t) + }) + } +} diff --git a/pkg/domain/service/user_service.go b/pkg/domain/service/user_service.go new file mode 100644 index 0000000..4404c40 --- /dev/null +++ b/pkg/domain/service/user_service.go @@ -0,0 +1,49 @@ +// pkg/domain/service/user_service.go +package service + +import "github.com/abbott/hardn/pkg/domain/model" + +// UserService defines operations for user management +type UserService interface { + CreateUser(user model.User) error + GetUser(username string) (*model.User, error) + AddSSHKey(username, publicKey string) error + ConfigureSudo(username string, noPassword bool) error +} + +// UserServiceImpl implements UserService +type UserServiceImpl struct { + repository UserRepository +} + +// NewUserServiceImpl creates a new UserServiceImpl +func NewUserServiceImpl(repository UserRepository) *UserServiceImpl { + return &UserServiceImpl{ + repository: repository, + } +} + +// UserRepository defines user data operations needed by UserService +type UserRepository interface { + CreateUser(user model.User) error + GetUser(username string) (*model.User, error) + AddSSHKey(username, publicKey string) error + ConfigureSudo(username string, noPassword bool) error +} + +// Implement UserService methods... +func (s *UserServiceImpl) CreateUser(user model.User) error { + return s.repository.CreateUser(user) +} + +func (s *UserServiceImpl) GetUser(username string) (*model.User, error) { + return s.repository.GetUser(username) +} + +func (s *UserServiceImpl) AddSSHKey(username, publicKey string) error { + return s.repository.AddSSHKey(username, publicKey) +} + +func (s *UserServiceImpl) ConfigureSudo(username string, noPassword bool) error { + return s.repository.ConfigureSudo(username, noPassword) +} diff --git a/pkg/domain/service/user_service_test.go b/pkg/domain/service/user_service_test.go new file mode 100644 index 0000000..5cbf089 --- /dev/null +++ b/pkg/domain/service/user_service_test.go @@ -0,0 +1,341 @@ +package service + +import ( + "fmt" + "testing" + + "github.com/abbott/hardn/pkg/domain/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockUserRepository is a mock implementation of the UserRepository interface for testing +type MockUserRepository struct { + mock.Mock +} + +func (m *MockUserRepository) CreateUser(user model.User) error { + args := m.Called(user) + return args.Error(0) +} + +func (m *MockUserRepository) GetUser(username string) (*model.User, error) { + args := m.Called(username) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + // Safely perform type assertion + user, ok := args.Get(0).(*model.User) + if !ok { + return nil, fmt.Errorf("invalid type assertion, expected *model.User") + } + + return user, args.Error(1) +} + +func (m *MockUserRepository) AddSSHKey(username, publicKey string) error { + args := m.Called(username, publicKey) + return args.Error(0) +} + +func (m *MockUserRepository) ConfigureSudo(username string, noPassword bool) error { + args := m.Called(username, noPassword) + return args.Error(0) +} + +func (m *MockUserRepository) UserExists(username string) (bool, error) { + args := m.Called(username) + return args.Bool(0), args.Error(1) +} + +func TestUserServiceImpl_CreateUser(t *testing.T) { + // Setup + mockRepo := new(MockUserRepository) + service := NewUserServiceImpl(mockRepo) + + // Test data + user := model.User{ + Username: "testuser", + HasSudo: true, + SshKeys: []string{"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAI... testuser@example.com"}, + SudoNoPassword: true, + } + + // Setup expectations + mockRepo.On("CreateUser", user).Return(nil) + + // Execute + err := service.CreateUser(user) + + // Assert + assert.NoError(t, err) + mockRepo.AssertExpectations(t) +} + +func TestUserServiceImpl_CreateUser_Error(t *testing.T) { + // Setup + mockRepo := new(MockUserRepository) + service := NewUserServiceImpl(mockRepo) + + // Test data + user := model.User{ + Username: "testuser", + HasSudo: true, + SshKeys: []string{"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAI... testuser@example.com"}, + SudoNoPassword: true, + } + + // Setup expectations + expectedErr := fmt.Errorf("failed to create user") + mockRepo.On("CreateUser", user).Return(expectedErr) + + // Execute + err := service.CreateUser(user) + + // Assert + assert.Error(t, err) + assert.Equal(t, expectedErr, err) + mockRepo.AssertExpectations(t) +} + +func TestUserServiceImpl_GetUser(t *testing.T) { + // Setup + mockRepo := new(MockUserRepository) + service := NewUserServiceImpl(mockRepo) + + // Test data + username := "testuser" + expectedUser := &model.User{ + Username: username, + HasSudo: true, + SshKeys: []string{"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAI... testuser@example.com"}, + SudoNoPassword: true, + } + + // Setup expectations + mockRepo.On("GetUser", username).Return(expectedUser, nil) + + // Execute + user, err := service.GetUser(username) + + // Assert + assert.NoError(t, err) + assert.Equal(t, expectedUser, user) + mockRepo.AssertExpectations(t) +} + +func TestUserServiceImpl_GetUser_Error(t *testing.T) { + // Setup + mockRepo := new(MockUserRepository) + service := NewUserServiceImpl(mockRepo) + + // Test data + username := "nonexistentuser" + + // Setup expectations + expectedErr := fmt.Errorf("user not found") + mockRepo.On("GetUser", username).Return(nil, expectedErr) + + // Execute + user, err := service.GetUser(username) + + // Assert + assert.Error(t, err) + assert.Nil(t, user) + assert.Equal(t, expectedErr, err) + mockRepo.AssertExpectations(t) +} + +func TestUserServiceImpl_AddSSHKey(t *testing.T) { + // Setup + mockRepo := new(MockUserRepository) + service := NewUserServiceImpl(mockRepo) + + // Test data + username := "testuser" + publicKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAI... testuser@example.com" + + // Setup expectations + mockRepo.On("AddSSHKey", username, publicKey).Return(nil) + + // Execute + err := service.AddSSHKey(username, publicKey) + + // Assert + assert.NoError(t, err) + mockRepo.AssertExpectations(t) +} + +func TestUserServiceImpl_AddSSHKey_Error(t *testing.T) { + // Setup + mockRepo := new(MockUserRepository) + service := NewUserServiceImpl(mockRepo) + + // Test data + username := "testuser" + publicKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAI... testuser@example.com" + + // Setup expectations + expectedErr := fmt.Errorf("failed to add SSH key") + mockRepo.On("AddSSHKey", username, publicKey).Return(expectedErr) + + // Execute + err := service.AddSSHKey(username, publicKey) + + // Assert + assert.Error(t, err) + assert.Equal(t, expectedErr, err) + mockRepo.AssertExpectations(t) +} + +func TestUserServiceImpl_ConfigureSudo(t *testing.T) { + // Setup + mockRepo := new(MockUserRepository) + service := NewUserServiceImpl(mockRepo) + + // Test data + username := "testuser" + noPassword := true + + // Setup expectations + mockRepo.On("ConfigureSudo", username, noPassword).Return(nil) + + // Execute + err := service.ConfigureSudo(username, noPassword) + + // Assert + assert.NoError(t, err) + mockRepo.AssertExpectations(t) +} + +func TestUserServiceImpl_ConfigureSudo_Error(t *testing.T) { + // Setup + mockRepo := new(MockUserRepository) + service := NewUserServiceImpl(mockRepo) + + // Test data + username := "testuser" + noPassword := true + + // Setup expectations + expectedErr := fmt.Errorf("failed to configure sudo") + mockRepo.On("ConfigureSudo", username, noPassword).Return(expectedErr) + + // Execute + err := service.ConfigureSudo(username, noPassword) + + // Assert + assert.Error(t, err) + assert.Equal(t, expectedErr, err) + mockRepo.AssertExpectations(t) +} + +func TestUserServiceImpl_WithSpecialCharacters(t *testing.T) { + // Setup + mockRepo := new(MockUserRepository) + service := NewUserServiceImpl(mockRepo) + + // Test data with special characters to check for escaping issues + userWithSpecialChars := model.User{ + Username: "user-with.special_chars", + HasSudo: true, + SshKeys: []string{"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAI... user@example.com"}, + SudoNoPassword: true, + } + + // Setup expectations + mockRepo.On("CreateUser", userWithSpecialChars).Return(nil) + + // Execute + err := service.CreateUser(userWithSpecialChars) + + // Assert + assert.NoError(t, err) + mockRepo.AssertExpectations(t) +} + +func TestUserServiceImpl_UserWithEmptyValues(t *testing.T) { + // Setup + mockRepo := new(MockUserRepository) + service := NewUserServiceImpl(mockRepo) + + // Test data with empty values + userWithEmptyValues := model.User{ + Username: "minimal-user", + HasSudo: false, + SshKeys: []string{}, + SudoNoPassword: false, + } + + // Setup expectations + mockRepo.On("CreateUser", userWithEmptyValues).Return(nil) + + // Execute + err := service.CreateUser(userWithEmptyValues) + + // Assert + assert.NoError(t, err) + mockRepo.AssertExpectations(t) +} + +func TestUserServiceImpl_UserWithMultipleSSHKeys(t *testing.T) { + // Setup + mockRepo := new(MockUserRepository) + service := NewUserServiceImpl(mockRepo) + + // Test data with multiple SSH keys + userWithMultipleKeys := model.User{ + Username: "user-with-keys", + HasSudo: true, + SshKeys: []string{ + "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAI... key1@example.com", + "ssh-rsa AAAAB3NzaC1yc2EAAAADA... key2@example.com", + "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTY... key3@example.com", + }, + SudoNoPassword: true, + } + + // Setup expectations + mockRepo.On("CreateUser", userWithMultipleKeys).Return(nil) + + // Execute + err := service.CreateUser(userWithMultipleKeys) + + // Assert + assert.NoError(t, err) + mockRepo.AssertExpectations(t) +} + +func TestUserServiceImpl_UserWithExistingUsername(t *testing.T) { + // This would be an integration test or would require more complex repository mocking + // that would check if the user exists first and then return an appropriate error + + // In a real implementation, the repository would check if the user exists before creating + // For this unit test example, we're just simulating the error returned when a duplicate user + // creation is attempted + + // Setup + mockRepo := new(MockUserRepository) + service := NewUserServiceImpl(mockRepo) + + // Test data + existingUser := model.User{ + Username: "existing-user", + HasSudo: true, + SshKeys: []string{"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAI... user@example.com"}, + SudoNoPassword: true, + } + + // Setup expectations + existsErr := fmt.Errorf("user 'existing-user' already exists") + mockRepo.On("CreateUser", existingUser).Return(existsErr) + + // Execute + err := service.CreateUser(existingUser) + + // Assert + assert.Error(t, err) + assert.Equal(t, existsErr, err) + mockRepo.AssertExpectations(t) +} diff --git a/pkg/firewall/firewall.go b/pkg/firewall/firewall.go deleted file mode 100644 index 130c2d5..0000000 --- a/pkg/firewall/firewall.go +++ /dev/null @@ -1,198 +0,0 @@ -package firewall - -import ( - "fmt" - "os" - "os/exec" - "strconv" - "strings" - - "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/logging" - "github.com/abbott/hardn/pkg/osdetect" - "github.com/abbott/hardn/pkg/utils" -) - -// ConfigureUFW sets up the Uncomplicated Firewall with the specified configuration -func ConfigureUFW(cfg *config.Config, osInfo *osdetect.OSInfo) error { - if cfg.DryRun { - logging.LogInfo("[DRY-RUN] Configure UFW firewall:") - logging.LogInfo("[DRY-RUN] - Enable UFW firewall with default policies (deny incoming, allow outgoing)") - - // Show SSH port policy - logging.LogInfo("[DRY-RUN] - Allow SSH on port %d/tcp", cfg.SshPort) - - if osInfo.OsType == "alpine" { - logging.LogInfo("[DRY-RUN] - Configure UFW to start on boot using OpenRC") - } - - // Show security recommendation if using default SSH port - if cfg.SshPort == 22 { - logging.LogInfo("[DRY-RUN] - SECURITY RECOMMENDATION: You are using the default SSH port (22)") - logging.LogInfo("[DRY-RUN] - Consider setting a non-standard SSH port (e.g., 2208) in your configuration file") - logging.LogInfo("[DRY-RUN] - This can help reduce automated SSH attacks targeting the default port") - } - - // Log application profiles - WriteUfwAppProfiles(cfg, osInfo) - return nil - } - - logging.LogInfo("Configuring UFW firewall...") - - // Show security recommendation if using the default SSH port - if cfg.SshPort == 22 { - logging.LogInfo("SECURITY RECOMMENDATION: You are using the default SSH port (22)") - logging.LogInfo("Consider setting a non-standard SSH port (e.g., 2208) in your configuration file") - logging.LogInfo("This can help reduce automated SSH attacks targeting the default port") - } - - // Install UFW if not already installed - ufwInstalled := false - if _, err := exec.LookPath("ufw"); err == nil { - ufwInstalled = true - } else { - if osInfo.OsType == "alpine" { - // Install UFW with apk - cmd := exec.Command("apk", "add", "ufw") - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to install UFW on Alpine: %w", err) - } - ufwInstalled = true - } else { - // Install UFW with apt - cmd := exec.Command("apt-get", "install", "-y", "ufw") - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to install UFW on Debian/Ubuntu: %w", err) - } - ufwInstalled = true - } - } - - if !ufwInstalled { - return fmt.Errorf("failed to install or find UFW") - } - - // Set default policies (always deny incoming, allow outgoing) - defaultInCmd := exec.Command("ufw", "default", "deny", "incoming") - output, err := defaultInCmd.CombinedOutput() // Capture both stdout and stderr - if err != nil { - logging.LogError("Failed to set default incoming policy: %v, output: %s", err, string(output)) - return fmt.Errorf("failed to set default incoming policy in UFW: %w", err) - } - logging.LogSuccess("Set default incoming policy to deny") - - defaultOutCmd := exec.Command("ufw", "default", "allow", "outgoing") - output, err = defaultOutCmd.CombinedOutput() // Capture both stdout and stderr - if err != nil { - logging.LogError("Failed to set default outgoing policy: %v, output: %s", err, string(output)) - return fmt.Errorf("failed to set default outgoing policy in UFW: %w", err) - } - logging.LogSuccess("Set default outgoing policy to allow") - - sshPortStr := strconv.Itoa(cfg.SshPort) - sshAllowCmd := exec.Command("ufw", "allow", sshPortStr+"/tcp", "comment", "SSH") - if err := sshAllowCmd.Run(); err != nil { - logging.LogError("Failed to allow SSH port %s/tcp: %v", sshPortStr, err) - return fmt.Errorf("failed to create UFW rule to allow SSH on port %s/tcp: %w", sshPortStr, err) - } else { - logging.LogSuccess("Configured UFW rule for SSH on port %s/tcp", sshPortStr) - } - - // Configure application profiles - if err := WriteUfwAppProfiles(cfg, osInfo); err != nil { - logging.LogError("Failed to configure UFW application profiles: %v", err) - // Continue with firewall setup even if app profiles fail - } - - // Enable UFW - enableCmd := exec.Command("ufw", "enable") - // Force non-interactive mode for 'ufw enable' - enableCmd.Env = append(enableCmd.Env, "DEBIAN_FRONTEND=noninteractive") - // The 'yes' command pipes "y" to ufw enable, which would normally prompt for confirmation - enableCmd = exec.Command("sh", "-c", "yes | ufw enable") - if err := enableCmd.Run(); err != nil { - return fmt.Errorf("failed to enable UFW: %w", err) - } - - // Configure boot service on Alpine - if osInfo.OsType == "alpine" { - bootCmd := exec.Command("rc-update", "add", "ufw", "default") - if err := bootCmd.Run(); err != nil { - logging.LogError("Failed to add UFW to Alpine boot services: %v", err) - } - - startCmd := exec.Command("rc-service", "ufw", "start") - if err := startCmd.Run(); err != nil { - logging.LogError("Failed to start UFW service on Alpine: %v", err) - } - } - - logging.LogSuccess("UFW configured and enabled with firewall rules") - return nil -} - -// WriteUfwAppProfiles writes user-defined UFW application profiles -func WriteUfwAppProfiles(cfg *config.Config, osInfo *osdetect.OSInfo) error { - if cfg.DryRun { - logging.LogInfo("[DRY-RUN] Write UFW application profiles to /etc/ufw/applications.d/hardn") - - // Check if we need to create a default SSH profile - if len(cfg.UfwAppProfiles) == 0 && cfg.SshPort != 0 { - logging.LogInfo("[DRY-RUN] - No application profiles defined, creating default SSH profile") - logging.LogInfo("[DRY-RUN] - SSH profile for port %d/tcp", cfg.SshPort) - } else if len(cfg.UfwAppProfiles) > 0 { - for _, profile := range cfg.UfwAppProfiles { - logging.LogInfo("[DRY-RUN] - Profile: %s (%s)", profile.Name, profile.Title) - logging.LogInfo("[DRY-RUN] Description: %s", profile.Description) - logging.LogInfo("[DRY-RUN] Ports: %s", strings.Join(profile.Ports, ", ")) - } - } - return nil - } - - // If there are no profiles to write, return - if len(cfg.UfwAppProfiles) == 0 { - logging.LogInfo("No UFW application profiles to configure") - return nil - } - - logging.LogInfo("Writing UFW application profiles...") - - // Create applications.d directory if it doesn't exist - if err := os.MkdirAll("/etc/ufw/applications.d", 0755); err != nil { - return fmt.Errorf("failed to create UFW applications directory: %w", err) - } - - // Backup existing profiles file if it exists - utils.BackupFile("/etc/ufw/applications.d/hardn", cfg) - - // Create content for UFW applications file - var content strings.Builder - for _, profile := range cfg.UfwAppProfiles { - content.WriteString(fmt.Sprintf("[%s]\n", profile.Name)) - content.WriteString(fmt.Sprintf("title=%s\n", profile.Title)) - content.WriteString(fmt.Sprintf("description=%s\n", profile.Description)) - content.WriteString(fmt.Sprintf("ports=%s\n", strings.Join(profile.Ports, ","))) - content.WriteString("\n") - } - - // Write the file - if err := os.WriteFile("/etc/ufw/applications.d/hardn", []byte(content.String()), 0644); err != nil { - return fmt.Errorf("failed to write UFW application profiles file: %w", err) - } - - // Apply the profiles - for _, profile := range cfg.UfwAppProfiles { - allowCmd := exec.Command("ufw", "allow", fmt.Sprintf("from any to any app '%s'", profile.Name)) - if err := allowCmd.Run(); err != nil { - logging.LogError("Failed to enable UFW application profile %s: %v", profile.Name, err) - return fmt.Errorf("failed to enable UFW application profile %s: %w", profile.Name, err) - } else { - logging.LogSuccess("Enabled UFW application profile: %s", profile.Name) - } - } - - logging.LogSuccess("UFW application profiles configured") - return nil -} \ No newline at end of file diff --git a/pkg/infrastructure/menu_factory.go b/pkg/infrastructure/menu_factory.go new file mode 100644 index 0000000..d0ac91e --- /dev/null +++ b/pkg/infrastructure/menu_factory.go @@ -0,0 +1,78 @@ +// pkg/infrastructure/menu_factory.go +package infrastructure + +import ( + "github.com/abbott/hardn/pkg/application" + "github.com/abbott/hardn/pkg/config" + "github.com/abbott/hardn/pkg/menu" + "github.com/abbott/hardn/pkg/osdetect" + "github.com/abbott/hardn/pkg/version" +) + +// MenuFactory creates menu components +type MenuFactory struct { + serviceFactory *ServiceFactory + config *config.Config + osInfo *osdetect.OSInfo +} + +func NewMenuFactory( + serviceFactory *ServiceFactory, + config *config.Config, + osInfo *osdetect.OSInfo, +) *MenuFactory { + // Set the config in the service factory + serviceFactory.SetConfig(config) + + return &MenuFactory{ + serviceFactory: serviceFactory, + config: config, + osInfo: osInfo, + } +} + +// CreateRunAllMenu creates a RunAllMenu with all dependencies wired up +func (f *MenuFactory) CreateRunAllMenu() *menu.RunAllMenu { + menuManager := f.serviceFactory.CreateMenuManager() + return menu.NewRunAllMenu(menuManager, f.config, f.osInfo) +} + +// CreateDryRunMenu creates a DryRunMenu with all dependencies wired up +func (f *MenuFactory) CreateDryRunMenu() *menu.DryRunMenu { + menuManager := f.serviceFactory.CreateMenuManager() + return menu.NewDryRunMenu(menuManager, f.config) +} + +func (f *MenuFactory) CreateHelpMenu() *menu.HelpMenu { + return menu.NewHelpMenu() +} + +// CreateMainMenu creates the main menu with all dependencies wired up +func (f *MenuFactory) CreateMainMenu(versionService *version.Service) *menu.MainMenu { + // Create required managers + userManager := f.serviceFactory.CreateUserManager() + sshManager := f.serviceFactory.CreateSSHManager() + firewallManager := f.serviceFactory.CreateFirewallManager() + dnsManager := f.serviceFactory.CreateDNSManager() + packageManager := f.serviceFactory.CreatePackageManager() + backupManager := f.serviceFactory.CreateBackupManager() + environmentManager := f.serviceFactory.CreateEnvironmentManager() + logsManager := f.serviceFactory.CreateLogsManager() + securityManager := application.NewSecurityManager( + userManager, sshManager, firewallManager, dnsManager) + + // Create menu manager (use := instead of = since we're not declaring it above anymore) + menuManager := application.NewMenuManager( + userManager, + sshManager, + firewallManager, + dnsManager, + packageManager, + backupManager, + securityManager, + environmentManager, + logsManager) + + // Create menu with all necessary fields initialized + return menu.NewMainMenu(menuManager, f.config, f.osInfo, versionService) +} diff --git a/pkg/infrastructure/service_factory.go b/pkg/infrastructure/service_factory.go new file mode 100644 index 0000000..c26100b --- /dev/null +++ b/pkg/infrastructure/service_factory.go @@ -0,0 +1,214 @@ +// pkg/infrastructure/service_factory.go +package infrastructure + +import ( + "github.com/abbott/hardn/pkg/adapter/secondary" + "github.com/abbott/hardn/pkg/application" + "github.com/abbott/hardn/pkg/config" + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/domain/service" + "github.com/abbott/hardn/pkg/interfaces" + "github.com/abbott/hardn/pkg/osdetect" +) + +// ServiceFactory creates and wires application components +type ServiceFactory struct { + provider *interfaces.Provider + osInfo *osdetect.OSInfo + config *config.Config +} + +// NewServiceFactory creates a new ServiceFactory +func NewServiceFactory(provider *interfaces.Provider, osInfo *osdetect.OSInfo) *ServiceFactory { + return &ServiceFactory{ + provider: provider, + osInfo: osInfo, + } +} + +// SetConfig sets the configuration +func (f *ServiceFactory) SetConfig(config *config.Config) { + f.config = config +} + +// CreateUserManager creates a UserManager with all required dependencies +func (f *ServiceFactory) CreateUserManager() *application.UserManager { + // Create repository + userRepo := secondary.NewOSUserRepository(f.provider.FS, f.provider.Commander, f.osInfo.OsType) + + // Create domain service + userService := service.NewUserServiceImpl(userRepo) + + // Create application service + return application.NewUserManager(userService) +} + +// CreateSSHManager creates an SSHManager with all required dependencies +func (f *ServiceFactory) CreateSSHManager() *application.SSHManager { + // Create repository + sshRepo := secondary.NewFileSSHRepository(f.provider.FS, f.provider.Commander, f.osInfo.OsType) + + // Create domain service + sshService := service.NewSSHServiceImpl(sshRepo, convertOSInfo(f.osInfo)) + + // Create application service + return application.NewSSHManager(sshService) +} + +// Helper to convert osdetect.OSInfo to domain model.OSInfo +func convertOSInfo(info *osdetect.OSInfo) model.OSInfo { + return model.OSInfo{ + Type: info.OsType, + Codename: info.OsCodename, + Version: info.OsVersion, + IsProxmox: info.IsProxmox, + } +} + +// CreateFirewallManager creates a FirewallManager +func (f *ServiceFactory) CreateFirewallManager() *application.FirewallManager { + // Create repository + firewallRepo := secondary.NewUFWFirewallRepository(f.provider.FS, f.provider.Commander) + + // Create domain service + firewallService := service.NewFirewallServiceImpl(firewallRepo, convertOSInfo(f.osInfo)) + + // Create application service + return application.NewFirewallManager(firewallService) +} + +// CreateDNSManager creates a DNSManager +func (f *ServiceFactory) CreateDNSManager() *application.DNSManager { + // Create repository + dnsRepo := secondary.NewFileDNSRepository(f.provider.FS, f.provider.Commander, f.osInfo.OsType) + + // Create domain service + dnsService := service.NewDNSServiceImpl(dnsRepo, convertOSInfo(f.osInfo)) + + // Create application service + return application.NewDNSManager(dnsService) +} + +// CreatePackageManager creates a PackageManager +func (f *ServiceFactory) CreatePackageManager() *application.PackageManager { + // Convert config to PackageSources model + sources := &model.PackageSources{ + // Standard repositories + DebianRepos: f.config.DebianRepos, + ProxmoxSrcRepos: f.config.ProxmoxSrcRepos, + ProxmoxCephRepo: f.config.ProxmoxCephRepo, + ProxmoxEnterpriseRepo: f.config.ProxmoxEnterpriseRepo, + AlpineTestingRepo: f.config.AlpineTestingRepo, + + // Package lists + DebianCorePackages: f.config.LinuxCorePackages, + DebianDmzPackages: f.config.LinuxDmzPackages, + DebianLabPackages: f.config.LinuxLabPackages, + AlpineCorePackages: f.config.AlpineCorePackages, + AlpineDmzPackages: f.config.AlpineDmzPackages, + AlpineLabPackages: f.config.AlpineLabPackages, + + // Python packages + DebianPythonPackages: f.config.PythonPackages, + NonWslPythonPackages: f.config.NonWslPythonPackages, + PythonPipPackages: f.config.PythonPipPackages, + AlpinePythonPackages: f.config.AlpinePythonPackages, + } + + // Create repository + packageRepo := secondary.NewOSPackageRepository( + f.provider.FS, + f.provider.Commander, + f.osInfo.OsType, + f.osInfo.OsVersion, + f.osInfo.OsCodename, + f.osInfo.IsProxmox, + sources, + ) + + // Create domain service + packageService := service.NewPackageServiceImpl(packageRepo, convertOSInfo(f.osInfo)) + + // Create application service with all required dependencies + return application.NewPackageManager( + packageService, + sources, + &model.OSInfo{ + Type: f.osInfo.OsType, + Version: f.osInfo.OsVersion, + Codename: f.osInfo.OsCodename, + IsProxmox: f.osInfo.IsProxmox, + }, + f.provider.Network, + f.config.DmzSubnet, + ) +} + +// CreateMenuManager creates a MenuManager with all required dependencies +func (f *ServiceFactory) CreateMenuManager() *application.MenuManager { + userManager := f.CreateUserManager() + sshManager := f.CreateSSHManager() + firewallManager := f.CreateFirewallManager() + dnsManager := f.CreateDNSManager() + packageManager := f.CreatePackageManager() + backupManager := f.CreateBackupManager() + environmentManager := f.CreateEnvironmentManager() + logsManager := f.CreateLogsManager() + securityManager := application.NewSecurityManager( + userManager, sshManager, firewallManager, dnsManager) + + return application.NewMenuManager( + userManager, + sshManager, + firewallManager, + dnsManager, + packageManager, + backupManager, + securityManager, + environmentManager, + logsManager) +} + +// CreateBackupManager creates a BackupManager +func (f *ServiceFactory) CreateBackupManager() *application.BackupManager { + // Create repository + backupRepo := secondary.NewFileBackupRepository( + f.provider.FS, + f.provider.Commander, + f.config.BackupPath, + f.config.EnableBackups, + ) + + // Create domain service + backupService := service.NewBackupServiceImpl(backupRepo) + + // Create application service + return application.NewBackupManager(backupService) +} + +// CreateEnvironmentManager creates an EnvironmentManager with all required dependencies +func (f *ServiceFactory) CreateEnvironmentManager() *application.EnvironmentManager { + // Create repository + environmentRepo := secondary.NewFileEnvironmentRepository(f.provider.FS, f.provider.Commander) + + // Create domain service + environmentService := service.NewEnvironmentServiceImpl(environmentRepo) + + // Create application service + return application.NewEnvironmentManager(environmentService) +} + +// CreateLogsManager creates a LogsManager +func (f *ServiceFactory) CreateLogsManager() *application.LogsManager { + // Create repository + logsRepo := secondary.NewFileLogsRepository( + f.provider.FS, + f.config.LogFile, + ) + + // Create domain service + logsService := service.NewLogsServiceImpl(logsRepo) + + // Create application service + return application.NewLogsManager(logsService) +} diff --git a/pkg/interfaces/interfaces.go b/pkg/interfaces/interfaces.go new file mode 100644 index 0000000..47ee33c --- /dev/null +++ b/pkg/interfaces/interfaces.go @@ -0,0 +1,29 @@ +// pkg/interfaces/interfaces.go +package interfaces + +import ( + "io/fs" + "os" +) + +// FileSystem abstracts filesystem operations +type FileSystem interface { + ReadFile(filename string) ([]byte, error) + WriteFile(filename string, data []byte, perm fs.FileMode) error + MkdirAll(path string, perm fs.FileMode) error + Stat(name string) (os.FileInfo, error) + Remove(name string) error + RemoveAll(path string) error +} + +// Commander abstracts command execution +type Commander interface { + Execute(command string, args ...string) ([]byte, error) + ExecuteWithInput(input string, command string, args ...string) ([]byte, error) +} + +// NetworkOperations abstracts network-related operations +type NetworkOperations interface { + GetInterfaces() ([]string, error) + CheckSubnet(subnet string) (bool, error) +} diff --git a/pkg/interfaces/mock_interfaces.go b/pkg/interfaces/mock_interfaces.go new file mode 100644 index 0000000..a3e6aaf --- /dev/null +++ b/pkg/interfaces/mock_interfaces.go @@ -0,0 +1,250 @@ +// pkg/interfaces/mock_interfaces.go +package interfaces + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + "time" +) + +// MockFileSystem provides a mock implementation of FileSystem for testing +type MockFileSystem struct { + // Maps to store mock data + Files map[string][]byte + FileInfos map[string]os.FileInfo + Directories map[string]bool + + // Error responses for specific operations + ReadFileError map[string]error + WriteFileError map[string]error + MkdirAllError map[string]error + StatError map[string]error + RemoveError map[string]error + RemoveAllError map[string]error +} + +// NewMockFileSystem creates a new initialized MockFileSystem +func NewMockFileSystem() *MockFileSystem { + return &MockFileSystem{ + Files: make(map[string][]byte), + FileInfos: make(map[string]os.FileInfo), + Directories: make(map[string]bool), + ReadFileError: make(map[string]error), + WriteFileError: make(map[string]error), + MkdirAllError: make(map[string]error), + StatError: make(map[string]error), + RemoveError: make(map[string]error), + RemoveAllError: make(map[string]error), + } +} + +func (m MockFileSystem) ReadFile(filename string) ([]byte, error) { + if err, ok := m.ReadFileError[filename]; ok && err != nil { + return nil, err + } + + data, exists := m.Files[filename] + if !exists { + return nil, fmt.Errorf("file not found: %s", filename) + } + return data, nil +} + +func (m MockFileSystem) WriteFile(filename string, data []byte, perm fs.FileMode) error { + if err, ok := m.WriteFileError[filename]; ok && err != nil { + return err + } + + // Ensure directory exists in our mock + dir := filepath.Dir(filename) + if dir != "." { + m.Directories[dir] = true + } + + m.Files[filename] = data + return nil +} + +func (m MockFileSystem) MkdirAll(path string, perm fs.FileMode) error { + if err, ok := m.MkdirAllError[path]; ok && err != nil { + return err + } + + m.Directories[path] = true + return nil +} + +func (m MockFileSystem) Stat(name string) (os.FileInfo, error) { + if err, ok := m.StatError[name]; ok && err != nil { + return nil, err + } + + info, exists := m.FileInfos[name] + if exists { + return info, nil + } + + if _, exists := m.Files[name]; exists { + return mockFileInfo{name: name, isDir: false}, nil + } + + if _, exists := m.Directories[name]; exists { + return mockFileInfo{name: name, isDir: true}, nil + } + + return nil, os.ErrNotExist +} + +func (m MockFileSystem) Remove(name string) error { + if err, ok := m.RemoveError[name]; ok && err != nil { + return err + } + + delete(m.Files, name) + delete(m.FileInfos, name) + delete(m.Directories, name) + return nil +} + +func (m MockFileSystem) RemoveAll(path string) error { + if err, ok := m.RemoveAllError[path]; ok && err != nil { + return err + } + + // Remove all files and directories with this prefix + for key := range m.Files { + if strings.HasPrefix(key, path) { + delete(m.Files, key) + } + } + + for key := range m.FileInfos { + if strings.HasPrefix(key, path) { + delete(m.FileInfos, key) + } + } + + for key := range m.Directories { + if strings.HasPrefix(key, path) { + delete(m.Directories, key) + } + } + + return nil +} + +// Mock implementation of os.FileInfo for testing +type mockFileInfo struct { + name string + size int64 + mode os.FileMode + isDir bool +} + +func (m mockFileInfo) Name() string { return filepath.Base(m.name) } +func (m mockFileInfo) Size() int64 { return m.size } +func (m mockFileInfo) Mode() os.FileMode { return m.mode } +func (m mockFileInfo) ModTime() time.Time { return time.Time{} } +func (m mockFileInfo) IsDir() bool { return m.isDir } +func (m mockFileInfo) Sys() interface{} { return nil } + +// MockCommander provides a mock implementation of Commander for testing +type MockCommander struct { + // Maps command + args string to output + CommandOutputs map[string][]byte + CommandErrors map[string]error + + // Track executed commands for verification + ExecutedCommands []string +} + +// NewMockCommander creates a new MockCommander +func NewMockCommander() *MockCommander { + return &MockCommander{ + CommandOutputs: make(map[string][]byte), + CommandErrors: make(map[string]error), + ExecutedCommands: []string{}, + } +} + +func (m *MockCommander) Execute(command string, args ...string) ([]byte, error) { + // Create command string for lookup + cmdString := command + for _, arg := range args { + cmdString += " " + arg + } + + // Record this command was executed + m.ExecutedCommands = append(m.ExecutedCommands, cmdString) + + // Return mock response + if err, ok := m.CommandErrors[cmdString]; ok && err != nil { + return nil, err + } + + if output, ok := m.CommandOutputs[cmdString]; ok { + return output, nil + } + + // Default empty response + return []byte{}, nil +} + +func (m *MockCommander) ExecuteWithInput(input string, command string, args ...string) ([]byte, error) { + // Create command string for lookup + cmdString := "INPUT:" + input + "|" + command + for _, arg := range args { + cmdString += " " + arg + } + + // Record this command was executed + m.ExecutedCommands = append(m.ExecutedCommands, cmdString) + + // Return mock response + if err, ok := m.CommandErrors[cmdString]; ok && err != nil { + return nil, err + } + + if output, ok := m.CommandOutputs[cmdString]; ok { + return output, nil + } + + // Default empty response + return []byte{}, nil +} + +// MockNetworkOperations provides a mock implementation of NetworkOperations +type MockNetworkOperations struct { + // Mock data + Interfaces []string + Subnets map[string]bool + + // Errors + GetInterfacesError error + CheckSubnetError error +} + +// NewMockNetworkOperations creates a new MockNetworkOperations +func NewMockNetworkOperations() *MockNetworkOperations { + return &MockNetworkOperations{ + Interfaces: []string{}, + Subnets: make(map[string]bool), + } +} + +func (m MockNetworkOperations) GetInterfaces() ([]string, error) { + if m.GetInterfacesError != nil { + return nil, m.GetInterfacesError + } + return m.Interfaces, nil +} + +func (m MockNetworkOperations) CheckSubnet(subnet string) (bool, error) { + if m.CheckSubnetError != nil { + return false, m.CheckSubnetError + } + return m.Subnets[subnet], nil +} diff --git a/pkg/interfaces/os_commander.go b/pkg/interfaces/os_commander.go new file mode 100644 index 0000000..206e155 --- /dev/null +++ b/pkg/interfaces/os_commander.go @@ -0,0 +1,24 @@ +// pkg/interfaces/os_commander.go +package interfaces + +import ( + "bytes" + "os/exec" +) + +// OSCommander is an implementation of Commander using os/exec +type OSCommander struct{} + +func (c OSCommander) Execute(command string, args ...string) ([]byte, error) { + cmd := exec.Command(command, args...) + return cmd.CombinedOutput() +} + +func (c OSCommander) ExecuteWithInput(input string, command string, args ...string) ([]byte, error) { + cmd := exec.Command(command, args...) + + stdin := bytes.NewBufferString(input) + cmd.Stdin = stdin + + return cmd.CombinedOutput() +} diff --git a/pkg/interfaces/os_filesystem.go b/pkg/interfaces/os_filesystem.go new file mode 100644 index 0000000..adb175a --- /dev/null +++ b/pkg/interfaces/os_filesystem.go @@ -0,0 +1,42 @@ +// pkg/interfaces/os_filesystem.go +package interfaces + +import ( + "io/fs" + "os" + "path/filepath" +) + +// OSFileSystem is an implementation of FileSystem using the os package +type OSFileSystem struct{} + +func (fs OSFileSystem) ReadFile(filename string) ([]byte, error) { + return os.ReadFile(filename) +} + +func (fs OSFileSystem) WriteFile(filename string, data []byte, perm fs.FileMode) error { + // Ensure directory exists + dir := filepath.Dir(filename) + if dir != "." { + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + } + return os.WriteFile(filename, data, perm) +} + +func (fs OSFileSystem) MkdirAll(path string, perm fs.FileMode) error { + return os.MkdirAll(path, perm) +} + +func (fs OSFileSystem) Stat(name string) (os.FileInfo, error) { + return os.Stat(name) +} + +func (fs OSFileSystem) Remove(name string) error { + return os.Remove(name) +} + +func (fs OSFileSystem) RemoveAll(path string) error { + return os.RemoveAll(path) +} diff --git a/pkg/interfaces/os_network.go b/pkg/interfaces/os_network.go new file mode 100644 index 0000000..c6b5838 --- /dev/null +++ b/pkg/interfaces/os_network.go @@ -0,0 +1,56 @@ +package interfaces + +import ( + "fmt" + "net" + "strings" +) + +// GetInterfaces returns a list of network interfaces +func (o OSNetworkOperations) GetInterfaces() ([]string, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("failed to get network interfaces: %w", err) + } + + var names []string + for _, iface := range interfaces { + names = append(names, iface.Name) + } + + return names, nil +} + +// CheckSubnet checks if the specified subnet is present in the system's interfaces +func (o OSNetworkOperations) CheckSubnet(subnet string) (bool, error) { + interfaces, err := net.Interfaces() + if err != nil { + return false, fmt.Errorf("failed to get network interfaces: %w", err) + } + + for _, iface := range interfaces { + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if !ok { + continue + } + + ip := ipNet.IP.To4() + if ip == nil { + continue + } + + // Check if IP matches subnet + if strings.HasPrefix(ip.String(), subnet+".") { + return true, nil + } + } + } + + return false, nil +} diff --git a/pkg/interfaces/provider.go b/pkg/interfaces/provider.go new file mode 100644 index 0000000..f176be8 --- /dev/null +++ b/pkg/interfaces/provider.go @@ -0,0 +1,34 @@ +// pkg/interfaces/provider.go +package interfaces + +// Provider holds interfaces for dependency injection +type Provider struct { + FS FileSystem + Commander Commander + Network NetworkOperations +} + +// NewProvider creates a new Provider with default implementations +func NewProvider() *Provider { + return &Provider{ + FS: OSFileSystem{}, + Commander: OSCommander{}, + Network: OSNetworkOperations{}, + } +} + +// MockProvider creates a Provider with mock implementations for testing +func MockProvider() *Provider { + return &Provider{ + FS: MockFileSystem{}, + Commander: &MockCommander{}, + Network: MockNetworkOperations{}, + } +} + +// OSNetworkOperations implements NetworkOperations +type OSNetworkOperations struct{} + +// Add implementations for OSNetworkOperations... + +// Mock implementations are defined in mocks.go diff --git a/pkg/logging/logging.go b/pkg/logging/logging.go index ab41170..d2f67d9 100644 --- a/pkg/logging/logging.go +++ b/pkg/logging/logging.go @@ -104,4 +104,4 @@ func PrintLogs(logPath string) { fmt.Printf("\n# Contents of %s:\n\n", logPath) fmt.Println(string(data)) -} \ No newline at end of file +} diff --git a/pkg/menu/backup.go b/pkg/menu/backup.go deleted file mode 100644 index a6116b4..0000000 --- a/pkg/menu/backup.go +++ /dev/null @@ -1,261 +0,0 @@ -// pkg/menu/backup.go - -package menu - -import ( - "fmt" - "os" - "path/filepath" - "strconv" - - "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/logging" - "github.com/abbott/hardn/pkg/style" - "github.com/abbott/hardn/pkg/utils" -) - -// BackupOptionsMenu displays and handles backup configuration options -func BackupOptionsMenu(cfg *config.Config) { - utils.PrintHeader() - fmt.Println(style.Bolded("Backup Settings", style.Blue)) - - // Display current settings - fmt.Println() - fmt.Println(style.Bolded("Current Backup Configuration:", style.Blue)) - - // Format backup status - backupStatus := "Disabled" - statusColor := style.Red - if cfg.EnableBackups { - backupStatus = "Enabled" - statusColor = style.Green - } - - // Display status with formatter - formatter := style.NewStatusFormatter([]string{"Backups", "Backup Path"}, 2) - - // Determine symbol and color based on backup status - symbol := style.SymCrossMark - color := style.Red - if cfg.EnableBackups { - symbol = style.SymEnabled - color = style.Green - } - - fmt.Println(formatter.FormatLine( - symbol, - color, - "Backups", - backupStatus, - statusColor, - "", - "bold")) - - // Display backup path - fmt.Println(formatter.FormatLine( - style.SymInfo, - style.Cyan, - "Backup Path", - cfg.BackupPath, - style.Cyan, - "", - "light")) - - // Check backup path status - if cfg.EnableBackups { - pathExists := checkBackupPath(cfg.BackupPath) - if pathExists { - fmt.Printf("%s Backup directory exists and is writable\n", - style.Colored(style.Green, style.SymCheckMark)) - } else { - fmt.Printf("%s Backup directory doesn't exist or isn't writable\n", - style.Colored(style.Yellow, style.SymWarning)) - fmt.Printf("%s Directory will be created when needed\n", style.BulletItem) - } - } - - // Create menu options - menuOptions := []style.MenuOption{ - { - Number: 1, - Title: fmt.Sprintf("Toggle backups (currently: %s)", backupStatus), - Description: "Enable or disable automatic backups of modified files", - }, - { - Number: 2, - Title: "Change backup path", - Description: fmt.Sprintf("Current: %s", cfg.BackupPath), - }, - } - - // Add option to test backup directory if backups are enabled - if cfg.EnableBackups { - menuOptions = append(menuOptions, style.MenuOption{ - Number: 3, - Title: "Verify backup directory", - Description: "Test if backup directory exists and is writable", - }) - } - - // Create menu - menu := style.NewMenu("Select an option", menuOptions) - menu.SetExitOption(style.MenuOption{ - Number: 0, - Title: "Return to main menu", - Description: "", - }) - - // Display menu - menu.Print() - - choiceStr := ReadInput() - choice, _ := strconv.Atoi(choiceStr) - - switch choice { - case 1: - // Toggle backups - cfg.EnableBackups = !cfg.EnableBackups - if cfg.EnableBackups { - fmt.Printf("\n%s Backups have been %s\n", - style.Colored(style.Green, style.SymCheckMark), - style.Bolded("enabled", style.Green)) - fmt.Printf("%s Modified files will be backed up to: %s\n", - style.BulletItem, - style.Colored(style.Cyan, cfg.BackupPath)) - } else { - fmt.Printf("\n%s Backups have been %s\n", - style.Colored(style.Yellow, style.SymInfo), - style.Bolded("disabled", style.Yellow)) - fmt.Printf("%s No automatic backups will be created\n", style.BulletItem) - } - - // Save config changes - saveBackupConfig(cfg) - - // Return to this menu after changing setting - fmt.Printf("\n%s Press any key to continue...", style.BulletItem) - ReadKey() - BackupOptionsMenu(cfg) - - case 2: - // Change backup path - fmt.Printf("\n%s Current backup path: %s\n", - style.BulletItem, - style.Colored(style.Cyan, cfg.BackupPath)) - fmt.Printf("%s Enter new backup path: ", style.BulletItem) - - newPath := ReadInput() - if newPath != "" { - // Expand path if it starts with ~ - if newPath[:1] == "~" { - home, err := os.UserHomeDir() - if err == nil { - newPath = filepath.Join(home, newPath[1:]) - } - } - - cfg.BackupPath = newPath - fmt.Printf("\n%s Backup path updated to: %s\n", - style.Colored(style.Green, style.SymCheckMark), - style.Colored(style.Cyan, cfg.BackupPath)) - - // Save config changes - saveBackupConfig(cfg) - } else { - fmt.Printf("\n%s Backup path unchanged\n", style.BulletItem) - } - - // Return to this menu after changing path - fmt.Printf("\n%s Press any key to continue...", style.BulletItem) - ReadKey() - BackupOptionsMenu(cfg) - - case 3: - // Verify backup directory (only available if backups are enabled) - if cfg.EnableBackups { - fmt.Printf("\n%s Verifying backup directory: %s\n", - style.BulletItem, - style.Colored(style.Cyan, cfg.BackupPath)) - - // Try to create directory and test write access - err := verifyBackupDirectory(cfg.BackupPath) - if err == nil { - fmt.Printf("\n%s Backup directory is valid and writable\n", - style.Colored(style.Green, style.SymCheckMark)) - } else { - fmt.Printf("\n%s Backup directory verification failed: %v\n", - style.Colored(style.Red, style.SymCrossMark), - err) - fmt.Printf("%s Please choose a different backup path\n", style.BulletItem) - } - } - - // Return to this menu after verification - fmt.Printf("\n%s Press any key to continue...", style.BulletItem) - ReadKey() - BackupOptionsMenu(cfg) - - case 0: - // Return to main menu - return - - default: - fmt.Printf("\n%s Invalid option. Please try again.\n", - style.Colored(style.Red, style.SymCrossMark)) - - fmt.Printf("\n%s Press any key to continue...", style.BulletItem) - ReadKey() - BackupOptionsMenu(cfg) - } -} - -// Helper function to save backup configuration -func saveBackupConfig(cfg *config.Config) { - // Save config changes - configFile := "hardn.yml" // Default config file - if err := config.SaveConfig(cfg, configFile); err != nil { - logging.LogError("Failed to save configuration: %v", err) - fmt.Printf("\n%s Failed to save configuration: %v\n", - style.Colored(style.Red, style.SymCrossMark), - err) - } -} - -// Helper function to check if backup path exists -func checkBackupPath(path string) bool { - // Check if directory exists - if _, err := os.Stat(path); os.IsNotExist(err) { - return false - } - - // Check if directory is writable by writing a test file - testFile := filepath.Join(path, ".write_test") - err := os.WriteFile(testFile, []byte("test"), 0644) - if err != nil { - return false - } - - // Clean up test file - os.Remove(testFile) - - return true -} - -// Helper function to verify backup directory -func verifyBackupDirectory(path string) error { - // Create backup directory if it doesn't exist - if err := os.MkdirAll(path, 0755); err != nil { - return fmt.Errorf("failed to create backup directory: %w", err) - } - - // Check if directory is writable by writing a test file - testFile := filepath.Join(path, ".write_test") - if err := os.WriteFile(testFile, []byte("test"), 0644); err != nil { - return fmt.Errorf("backup directory is not writable: %w", err) - } - - // Clean up test file - os.Remove(testFile) - - return nil -} \ No newline at end of file diff --git a/pkg/menu/backup_menu.go b/pkg/menu/backup_menu.go new file mode 100644 index 0000000..ceba7bc --- /dev/null +++ b/pkg/menu/backup_menu.go @@ -0,0 +1,258 @@ +// pkg/menu/backup_menu.go +package menu + +import ( + "fmt" + "strconv" + + "github.com/abbott/hardn/pkg/application" + "github.com/abbott/hardn/pkg/config" + "github.com/abbott/hardn/pkg/style" + "github.com/abbott/hardn/pkg/utils" +) + +// BackupMenu handles backup configuration +type BackupMenu struct { + menuManager *application.MenuManager + config *config.Config +} + +// NewBackupMenu creates a new BackupMenu +func NewBackupMenu( + menuManager *application.MenuManager, + config *config.Config, +) *BackupMenu { + return &BackupMenu{ + menuManager: menuManager, + config: config, + } +} + +// Show displays the backup menu and handles user input +func (m *BackupMenu) Show() { + utils.PrintHeader() + fmt.Println(style.Bolded("Backup Settings", style.Blue)) + + // Get backup status from application layer + enabled, backupPath, err := m.menuManager.GetBackupStatus() + if err != nil { + fmt.Printf("\n%s Error retrieving backup status: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } + + // Display current settings + fmt.Println() + fmt.Println(style.Bolded("Current Backup Configuration:", style.Blue)) + + // Format backup status + backupStatus := "Disabled" + statusColor := style.Red + if enabled { + backupStatus = "Enabled" + statusColor = style.Green + } + + // Display status with formatter + formatter := style.NewStatusFormatter([]string{"Backups", "Backup Path"}, 2) + + // Determine symbol and color based on backup status + symbol := style.SymCrossMark + color := style.Red + if enabled { + symbol = style.SymEnabled + color = style.Green + } + + fmt.Println(formatter.FormatLine( + symbol, + color, + "Backups", + backupStatus, + statusColor, + "", + "bold")) + + // Display backup path + fmt.Println(formatter.FormatLine( + style.SymInfo, + style.Cyan, + "Backup Path", + backupPath, + style.Cyan, + "", + "light")) + + // Check backup path status + if enabled { + // Use application layer to check path status + pathExists, err := m.menuManager.VerifyBackupPath() + if err != nil { + fmt.Printf("%s Error checking backup path: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } else if pathExists { + fmt.Printf("%s Backup directory exists and is writable\n", + style.Colored(style.Green, style.SymCheckMark)) + } else { + fmt.Printf("%s Backup directory doesn't exist or isn't writable\n", + style.Colored(style.Yellow, style.SymWarning)) + fmt.Printf("%s Directory will be created when needed\n", style.BulletItem) + } + } + + // Create menu options + menuOptions := []style.MenuOption{ + { + Number: 1, + Title: fmt.Sprintf("Toggle backups (currently: %s)", backupStatus), + Description: "Enable or disable automatic backups of modified files", + }, + { + Number: 2, + Title: "Change backup path", + Description: fmt.Sprintf("Current: %s", backupPath), + }, + } + + // Add option to test backup directory if backups are enabled + if enabled { + menuOptions = append(menuOptions, style.MenuOption{ + Number: 3, + Title: "Verify backup directory", + Description: "Test if backup directory exists and is writable", + }) + } + + // Create menu + menu := style.NewMenu("Select an option", menuOptions) + menu.SetExitOption(style.MenuOption{ + Number: 0, + Title: "Return to main menu", + Description: "", + }) + + // Display menu + menu.Print() + + choiceStr := ReadInput() + choice, _ := strconv.Atoi(choiceStr) + + switch choice { + case 1: + // Toggle backups using application layer + err := m.menuManager.ToggleBackups() + if err != nil { + fmt.Printf("\n%s Error toggling backups: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } else { + // Get the new status + enabled, backupPath, _ = m.menuManager.GetBackupStatus() + + if enabled { + fmt.Printf("\n%s Backups have been %s\n", + style.Colored(style.Green, style.SymCheckMark), + style.Bolded("enabled", style.Green)) + fmt.Printf("%s Modified files will be backed up to: %s\n", + style.BulletItem, + style.Colored(style.Cyan, backupPath)) + } else { + fmt.Printf("\n%s Backups have been %s\n", + style.Colored(style.Yellow, style.SymInfo), + style.Bolded("disabled", style.Yellow)) + fmt.Printf("%s No automatic backups will be created\n", style.BulletItem) + } + + // Update config to keep it in sync + m.config.EnableBackups = enabled + + // Save config changes + configFile := "hardn.yml" // Default config file + if err := config.SaveConfig(m.config, configFile); err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } + } + + // Return to this menu after changing setting + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) + ReadKey() + m.Show() + + case 2: + // Change backup path + fmt.Printf("\n%s Current backup path: %s\n", + style.BulletItem, + style.Colored(style.Cyan, backupPath)) + fmt.Printf("%s Enter new backup path: ", style.BulletItem) + + newPath := ReadInput() + if newPath != "" { + // Use application layer to set backup directory + err := m.menuManager.SetBackupDirectory(newPath) + if err != nil { + fmt.Printf("\n%s Failed to set backup path: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } else { + // Update local path for display + _, updatedPath, _ := m.menuManager.GetBackupStatus() + + fmt.Printf("\n%s Backup path updated to: %s\n", + style.Colored(style.Green, style.SymCheckMark), + style.Colored(style.Cyan, updatedPath)) + + // Update config to keep it in sync + m.config.BackupPath = updatedPath + + // Save config + configFile := "hardn.yml" // Default config file + if err := config.SaveConfig(m.config, configFile); err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } + } + } else { + fmt.Printf("\n%s Backup path unchanged\n", style.BulletItem) + } + + // Return to this menu after changing path + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) + ReadKey() + m.Show() + + case 3: + // Verify backup directory (only available if backups are enabled) + if enabled { + fmt.Printf("\n%s Verifying backup directory: %s\n", + style.BulletItem, + style.Colored(style.Cyan, backupPath)) + + // Use application layer to verify directory + err := m.menuManager.VerifyBackupDirectory() + if err == nil { + fmt.Printf("\n%s Backup directory is valid and writable\n", + style.Colored(style.Green, style.SymCheckMark)) + } else { + fmt.Printf("\n%s Backup directory verification failed: %v\n", + style.Colored(style.Red, style.SymCrossMark), + err) + fmt.Printf("%s Please choose a different backup path\n", style.BulletItem) + } + } + + // Return to this menu after verification + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) + ReadKey() + m.Show() + + case 0: + // Return to main menu + return + + default: + fmt.Printf("\n%s Invalid option. Please try again.\n", + style.Colored(style.Red, style.SymCrossMark)) + + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) + ReadKey() + m.Show() + } +} diff --git a/pkg/menu/disable_root_menu.go b/pkg/menu/disable_root_menu.go new file mode 100644 index 0000000..02fce8c --- /dev/null +++ b/pkg/menu/disable_root_menu.go @@ -0,0 +1,254 @@ +// pkg/menu/disable_root_menu.go +package menu + +import ( + "fmt" + "os" + "os/exec" + "strings" + + "github.com/abbott/hardn/pkg/application" + "github.com/abbott/hardn/pkg/config" + "github.com/abbott/hardn/pkg/osdetect" + "github.com/abbott/hardn/pkg/style" + "github.com/abbott/hardn/pkg/utils" +) + +// DisableRootMenu handles disabling root SSH access +type DisableRootMenu struct { + menuManager *application.MenuManager + config *config.Config + osInfo *osdetect.OSInfo +} + +// NewDisableRootMenu creates a new DisableRootMenu +func NewDisableRootMenu( + menuManager *application.MenuManager, + config *config.Config, + osInfo *osdetect.OSInfo, +) *DisableRootMenu { + return &DisableRootMenu{ + menuManager: menuManager, + config: config, + osInfo: osInfo, + } +} + +// Show displays the disable root menu and handles user input +func (m *DisableRootMenu) Show() { + utils.PrintHeader() + fmt.Println(style.Bolded("Disable Root SSH Access", style.Blue)) + + // Check current status of root SSH access + rootAccessEnabled, err := m.checkRootLoginEnabled() + if err != nil { + fmt.Printf("\n%s Error checking root SSH status: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + rootAccessEnabled = true // Assume vulnerable if can't check + } + + fmt.Println() + if rootAccessEnabled { + fmt.Printf("%s %s Root SSH access is currently %s\n", + style.Colored(style.Yellow, style.SymWarning), + style.Bolded("WARNING:"), + style.Bolded("ENABLED", style.Red)) + } else { + fmt.Printf("%s Root SSH access is already %s\n", + style.Colored(style.Green, style.SymCheckMark), + style.Bolded("DISABLED", style.Green)) + + // Display information about status but continue to show menu + fmt.Println(style.Colored(style.Green, "\nNo further action needed to disable root SSH access.")) + } + + // Security warning + fmt.Println(style.Colored(style.Yellow, "\nBefore disabling root SSH access, ensure that:")) + fmt.Printf("%s You have created at least one non-root user with sudo privileges\n", style.BulletItem) + fmt.Printf("%s You have tested SSH access with this non-root user\n", style.BulletItem) + fmt.Printf("%s You have a backup method to access this system if SSH fails\n", style.BulletItem) + + // Create menu options + menuOptions := []style.MenuOption{} + + // Always show option 1, but dim it when already disabled + if rootAccessEnabled { + menuOptions = append(menuOptions, style.MenuOption{ + Number: 1, + Title: "Disable root SSH access", + Description: "Modify SSH config to prevent root login", + }) + } else { + // For dimmed text, we need to store just the plain text in the Title field + // and then apply the dimming in the description to maintain proper spacing + menuOptions = append(menuOptions, style.MenuOption{ + Number: 1, + Title: "Disable root SSH access", + Description: "ALREADY DISABLED", + Style: "strike", + }) + } + + // Add options to view SSH configuration + menuOptions = append(menuOptions, style.MenuOption{ + Number: 2, + Title: "View current SSH configuration", + Description: "Show details of SSH security settings", + }) + + // Create menu + menu := style.NewMenu("Select an option", menuOptions) + menu.SetExitOption(style.MenuOption{ + Number: 0, + Title: "Return to main menu", + Description: "Keep current settings", + }) + + // Display menu + menu.Print() + + choice := ReadMenuInput() + + // Handle 'q' as a special exit case + if choice == "q" { + return + } + + switch choice { + case "1": + // Handle the case where root access is already disabled + if !rootAccessEnabled { + fmt.Printf("\n%s Root SSH access is already disabled\n", + style.Colored(style.Green, style.SymCheckMark)) + fmt.Println(style.Dimmed("\nNo action needed.")) + break + } + + // Confirmation step + fmt.Printf("\n%s Are you sure you want to disable root SSH access? (y/n): ", + style.Colored(style.Yellow, style.SymWarning)) + confirm := ReadInput() + + if strings.ToLower(confirm) != "y" && strings.ToLower(confirm) != "yes" { + fmt.Println("\nOperation cancelled. Root SSH access remains enabled.") + break + } + + fmt.Println("\nDisabling root SSH access...") + + if m.config.DryRun { + fmt.Printf("%s [DRY-RUN] Would disable root SSH access\n", style.BulletItem) + } else { + // Call application layer to disable root SSH access + err := m.menuManager.DisableRootSsh() + if err != nil { + fmt.Printf("\n%s Failed to disable root SSH access: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } else { + fmt.Printf("\n%s Root SSH access has been disabled\n", + style.Colored(style.Green, style.SymCheckMark)) + + // Restart SSH service + fmt.Println(style.Dimmed("Restarting SSH service...")) + var restartErr error + if m.osInfo.OsType == "alpine" { + restartErr = exec.Command("rc-service", "sshd", "restart").Run() + } else { + restartErr = exec.Command("systemctl", "restart", "ssh").Run() + } + + if restartErr != nil { + fmt.Printf("%s Failed to restart SSH service: %v\n", + style.Colored(style.Yellow, style.SymWarning), restartErr) + fmt.Println(style.Dimmed("You may need to restart the SSH service manually.")) + } else { + fmt.Printf("%s SSH service restarted successfully\n", + style.Colored(style.Green, style.SymCheckMark)) + } + } + } + case "2": + // View current SSH configuration + fmt.Println("\nCurrent SSH Configuration:") + fmt.Println(style.Dimmed("-------------------------------------")) + + // Display root login status + rootStatus := "Enabled" + if !rootAccessEnabled { + rootStatus = "Disabled" + } + + var color string + if rootAccessEnabled { + color = style.Red + } else { + color = style.Green + } + fmt.Printf("%s Root SSH login: %s\n", style.BulletItem, + style.Colored(color, rootStatus)) + + // Display SSH port + fmt.Printf("%s SSH port: %d\n", style.BulletItem, m.config.SshPort) + + // Display additional SSH settings if available + fmt.Printf("%s Allowed users: %s\n", style.BulletItem, + strings.Join(m.config.SshAllowedUsers, ", ")) + case "0": + return + default: + fmt.Printf("\n%s Invalid option. Please try again.\n", + style.Colored(style.Red, style.SymCrossMark)) + } + + fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) + ReadKey() +} + +// checkRootLoginEnabled checks if SSH root login is enabled by asking the application layer +func (m *DisableRootMenu) checkRootLoginEnabled() (bool, error) { + // In a full implementation, we would call through to the application layer + // For now, we'll use a simple file check similar to the old implementation + + // This is temporary and should be replaced with a proper call to the application layer + // as it becomes available + var rootLoginEnabled bool + + // Check SSH config file - THIS SHOULD BE REPLACED with app layer method + var sshConfigPaths []string + + if m.osInfo.OsType == "alpine" { + sshConfigPaths = []string{"/etc/ssh/sshd_config"} + } else { + // For Debian/Ubuntu, check both main config and config.d + sshConfigPaths = []string{ + "/etc/ssh/sshd_config.d/hardn.conf", + "/etc/ssh/sshd_config", + } + } + + // Check each potential config file + for _, configPath := range sshConfigPaths { + // Check if the file exists and parse it + content, err := os.ReadFile(configPath) + if err != nil { + continue // Try next config file if this one can't be read + } + + lines := strings.Split(string(content), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "PermitRootLogin") { + parts := strings.Fields(line) + if len(parts) >= 2 && (parts[1] == "no" || parts[1] == "No") { + rootLoginEnabled = false + return rootLoginEnabled, nil + } + rootLoginEnabled = true + return rootLoginEnabled, nil + } + } + } + + // If not explicitly set, assume it's enabled + return true, nil +} diff --git a/pkg/menu/dns.go b/pkg/menu/dns_menu.go similarity index 52% rename from pkg/menu/dns.go rename to pkg/menu/dns_menu.go index 12dc757..a9cb264 100644 --- a/pkg/menu/dns.go +++ b/pkg/menu/dns_menu.go @@ -1,5 +1,4 @@ -// pkg/menu/dns.go - +// pkg/menu/dns_menu.go package menu import ( @@ -8,51 +7,71 @@ import ( "os/exec" "strings" + "github.com/abbott/hardn/pkg/application" "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/dns" - "github.com/abbott/hardn/pkg/logging" "github.com/abbott/hardn/pkg/osdetect" "github.com/abbott/hardn/pkg/style" "github.com/abbott/hardn/pkg/utils" ) -// ConfigureDnsMenu handles DNS configuration options -func ConfigureDnsMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { +// DNSMenu handles DNS configuration +type DNSMenu struct { + menuManager *application.MenuManager + config *config.Config + osInfo *osdetect.OSInfo +} + +// NewDNSMenu creates a new DNSMenu +func NewDNSMenu( + menuManager *application.MenuManager, + config *config.Config, + osInfo *osdetect.OSInfo, +) *DNSMenu { + return &DNSMenu{ + menuManager: menuManager, + config: config, + osInfo: osInfo, + } +} + +// Show displays the DNS configuration menu and handles user input +func (m *DNSMenu) Show() { utils.PrintHeader() fmt.Println(style.Bolded("DNS Configuration", style.Blue)) - // Check current DNS status + // Check current DNS status - this would ideally come from the application layer + // but for now we'll reuse the existing code until it's refactored currentNameservers, dnsImplementation := getCurrentDnsSettings() // Display current configuration fmt.Println() fmt.Println(style.Bolded("Current DNS Configuration:", style.Blue)) - + // Create formatter for status display formatter := style.NewStatusFormatter([]string{"DNS Implementation", "Nameservers"}, 2) - + // Show DNS implementation if dnsImplementation != "" { - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "DNS Implementation", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "DNS Implementation", dnsImplementation, style.Cyan, "", "light")) } else { fmt.Println(formatter.FormatWarning("DNS Implementation", "Unknown", "Could not detect DNS setup")) } - + // Show current nameservers if len(currentNameservers) > 0 { - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "Nameservers", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "Nameservers", strings.Join(currentNameservers, ", "), style.Cyan, "", "light")) } else { fmt.Println(formatter.FormatWarning("Nameservers", "None detected", "DNS resolution may not work")) } - + // Show configured nameservers fmt.Println() fmt.Println(style.Bolded("Configured Nameservers:", style.Blue)) - - if len(cfg.Nameservers) > 0 { - for i, ns := range cfg.Nameservers { + + if len(m.config.Nameservers) > 0 { + for i, ns := range m.config.Nameservers { fmt.Printf("%s Nameserver %d: %s\n", style.BulletItem, i+1, style.Colored(style.Cyan, ns)) } } else { @@ -64,35 +83,35 @@ func ConfigureDnsMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { {Number: 1, Title: "Configure DNS", Description: "Apply nameserver settings from configuration"}, {Number: 2, Title: "Add nameserver", Description: "Add a new DNS server to configuration"}, } - + // Add remove option if nameservers exist - if len(cfg.Nameservers) > 0 { + if len(m.config.Nameservers) > 0 { menuOptions = append(menuOptions, style.MenuOption{ - Number: 3, - Title: "Remove nameserver", + Number: 3, + Title: "Remove nameserver", Description: "Remove a DNS server from configuration", }) } - + // Add popular DNS provider options menuOptions = append(menuOptions, style.MenuOption{ - Number: 4, - Title: "Use Cloudflare DNS", + Number: 4, + Title: "Use Cloudflare DNS", Description: "Set nameservers to 1.1.1.1, 1.0.0.1", }) - + menuOptions = append(menuOptions, style.MenuOption{ - Number: 5, - Title: "Use Google DNS", + Number: 5, + Title: "Use Google DNS", Description: "Set nameservers to 8.8.8.8, 8.8.4.4", }) - + menuOptions = append(menuOptions, style.MenuOption{ - Number: 6, - Title: "Use Quad9 DNS", + Number: 6, + Title: "Use Quad9 DNS", Description: "Set nameservers to 9.9.9.9, 149.112.112.112", }) - + // Create menu menu := style.NewMenu("Select an option", menuOptions) menu.SetExitOption(style.MenuOption{ @@ -100,210 +119,241 @@ func ConfigureDnsMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { Title: "Return to main menu", Description: "", }) - + // Display menu menu.Print() - - choice := ReadInput() - + + choice := ReadMenuInput() + + // Handle 'q' as a special exit case + if choice == "q" { + return + } + switch choice { case "1": // Configure DNS with current settings - if len(cfg.Nameservers) == 0 { - fmt.Printf("\n%s No nameservers configured. Please add nameservers first.\n", + if len(m.config.Nameservers) == 0 { + fmt.Printf("\n%s No nameservers configured. Please add nameservers first.\n", style.Colored(style.Yellow, style.SymWarning)) - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - ConfigureDnsMenu(cfg, osInfo) + m.Show() return } - + fmt.Println("\nConfiguring DNS settings...") - - if cfg.DryRun { - fmt.Printf("%s [DRY-RUN] Would configure DNS with nameservers: %s\n", - style.BulletItem, strings.Join(cfg.Nameservers, ", ")) + + if m.config.DryRun { + fmt.Printf("%s [DRY-RUN] Would configure DNS with nameservers: %s\n", + style.BulletItem, strings.Join(m.config.Nameservers, ", ")) } else { - if err := dns.ConfigureDNS(cfg, osInfo); err != nil { - fmt.Printf("\n%s Failed to configure DNS: %v\n", + err := m.menuManager.ConfigureDNS(m.config.Nameservers, "lan") + if err != nil { + fmt.Printf("\n%s Failed to configure DNS: %v\n", style.Colored(style.Red, style.SymCrossMark), err) - logging.LogError("Failed to configure DNS: %v", err) } else { - fmt.Printf("\n%s DNS configured successfully\n", + fmt.Printf("\n%s DNS configured successfully\n", style.Colored(style.Green, style.SymCheckMark)) - fmt.Printf("%s Nameservers: %s\n", - style.BulletItem, strings.Join(cfg.Nameservers, ", ")) + fmt.Printf("%s Nameservers: %s\n", + style.BulletItem, strings.Join(m.config.Nameservers, ", ")) } } - + case "2": // Add nameserver - fmt.Printf("\n%s Enter nameserver IP address: ", style.BulletItem) - newNameserver := ReadInput() - - if newNameserver == "" { - fmt.Printf("\n%s Nameserver cannot be empty\n", - style.Colored(style.Red, style.SymCrossMark)) - } else { - // Validate IP format (basic check) - parts := strings.Split(newNameserver, ".") - if len(parts) != 4 { - fmt.Printf("\n%s Invalid IP address format\n", - style.Colored(style.Red, style.SymCrossMark)) - } else { - // Check for duplicate - isDuplicate := false - for _, ns := range cfg.Nameservers { - if ns == newNameserver { - isDuplicate = true - break - } - } - - if isDuplicate { - fmt.Printf("\n%s Nameserver %s is already configured\n", - style.Colored(style.Yellow, style.SymWarning), newNameserver) - } else { - // Add new nameserver - cfg.Nameservers = append(cfg.Nameservers, newNameserver) - - // Save config - saveDnsConfig(cfg) - - fmt.Printf("\n%s Nameserver %s added to configuration\n", - style.Colored(style.Green, style.SymCheckMark), newNameserver) - } - } - } - - fmt.Printf("\n%s Press any key to continue...", style.BulletItem) - ReadKey() - ConfigureDnsMenu(cfg, osInfo) - + m.addNameserver() + m.Show() + return + case "3": - // Remove nameserver - if len(cfg.Nameservers) == 0 { - fmt.Printf("\n%s No nameservers to remove\n", + // Remove nameserver (only if nameservers exist) + if len(m.config.Nameservers) == 0 { + fmt.Printf("\n%s No nameservers to remove\n", style.Colored(style.Yellow, style.SymWarning)) } else { - fmt.Println() - for i, ns := range cfg.Nameservers { - fmt.Printf("%s %d: %s\n", style.BulletItem, i+1, ns) - } - - fmt.Printf("\n%s Enter nameserver number to remove (1-%d): ", - style.BulletItem, len(cfg.Nameservers)) - numStr := ReadInput() - - // Parse number - num := 0 - fmt.Sscanf(numStr, "%d", &num) - - if num < 1 || num > len(cfg.Nameservers) { - fmt.Printf("\n%s Invalid nameserver number\n", - style.Colored(style.Red, style.SymCrossMark)) - } else { - // Remove nameserver (adjust for 0-based index) - removedNs := cfg.Nameservers[num-1] - cfg.Nameservers = append(cfg.Nameservers[:num-1], cfg.Nameservers[num:]...) - - // Save config - saveDnsConfig(cfg) - - fmt.Printf("\n%s Nameserver %s removed from configuration\n", - style.Colored(style.Green, style.SymCheckMark), removedNs) - } + m.removeNameserver() } - - fmt.Printf("\n%s Press any key to continue...", style.BulletItem) - ReadKey() - ConfigureDnsMenu(cfg, osInfo) - + + m.Show() + return + case "4": // Use Cloudflare DNS fmt.Println("\nSetting Cloudflare DNS servers...") - cfg.Nameservers = []string{"1.1.1.1", "1.0.0.1"} - + m.config.Nameservers = []string{"1.1.1.1", "1.0.0.1"} + // Save config - saveDnsConfig(cfg) - - fmt.Printf("\n%s Nameservers set to Cloudflare DNS: 1.1.1.1, 1.0.0.1\n", + if err := config.SaveConfig(m.config, "hardn.yml"); err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } + + fmt.Printf("\n%s Nameservers set to Cloudflare DNS: 1.1.1.1, 1.0.0.1\n", style.Colored(style.Green, style.SymCheckMark)) - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - ConfigureDnsMenu(cfg, osInfo) - + m.Show() + return + case "5": // Use Google DNS fmt.Println("\nSetting Google DNS servers...") - cfg.Nameservers = []string{"8.8.8.8", "8.8.4.4"} - + m.config.Nameservers = []string{"8.8.8.8", "8.8.4.4"} + // Save config - saveDnsConfig(cfg) - - fmt.Printf("\n%s Nameservers set to Google DNS: 8.8.8.8, 8.8.4.4\n", + if err := config.SaveConfig(m.config, "hardn.yml"); err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } + + fmt.Printf("\n%s Nameservers set to Google DNS: 8.8.8.8, 8.8.4.4\n", style.Colored(style.Green, style.SymCheckMark)) - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - ConfigureDnsMenu(cfg, osInfo) - + m.Show() + return + case "6": // Use Quad9 DNS fmt.Println("\nSetting Quad9 DNS servers...") - cfg.Nameservers = []string{"9.9.9.9", "149.112.112.112"} - + m.config.Nameservers = []string{"9.9.9.9", "149.112.112.112"} + // Save config - saveDnsConfig(cfg) - - fmt.Printf("\n%s Nameservers set to Quad9 DNS: 9.9.9.9, 149.112.112.112\n", + if err := config.SaveConfig(m.config, "hardn.yml"); err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } + + fmt.Printf("\n%s Nameservers set to Quad9 DNS: 9.9.9.9, 149.112.112.112\n", style.Colored(style.Green, style.SymCheckMark)) - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - ConfigureDnsMenu(cfg, osInfo) - + m.Show() + return + case "0": // Return to main menu return - + default: - fmt.Printf("\n%s Invalid option. Please try again.\n", + fmt.Printf("\n%s Invalid option. Please try again.\n", style.Colored(style.Red, style.SymCrossMark)) - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - ConfigureDnsMenu(cfg, osInfo) + m.Show() return } - + fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) ReadKey() } -// Helper function to save DNS configuration -func saveDnsConfig(cfg *config.Config) { - // Save config changes - configFile := "hardn.yml" // Default config file - if err := config.SaveConfig(cfg, configFile); err != nil { - logging.LogError("Failed to save configuration: %v", err) - fmt.Printf("\n%s Failed to save configuration: %v\n", +// addNameserver handles adding a new nameserver +func (m *DNSMenu) addNameserver() { + fmt.Printf("\n%s Enter nameserver IP address: ", style.BulletItem) + newNameserver := ReadInput() + + if newNameserver == "" { + fmt.Printf("\n%s Nameserver cannot be empty\n", + style.Colored(style.Red, style.SymCrossMark)) + return + } + + // Validate IP format (basic check) + parts := strings.Split(newNameserver, ".") + if len(parts) != 4 { + fmt.Printf("\n%s Invalid IP address format\n", + style.Colored(style.Red, style.SymCrossMark)) + return + } + + // Check for duplicate + isDuplicate := false + for _, ns := range m.config.Nameservers { + if ns == newNameserver { + isDuplicate = true + break + } + } + + if isDuplicate { + fmt.Printf("\n%s Nameserver %s is already configured\n", + style.Colored(style.Yellow, style.SymWarning), newNameserver) + return + } + + // Add new nameserver + m.config.Nameservers = append(m.config.Nameservers, newNameserver) + + // Save config + if err := config.SaveConfig(m.config, "hardn.yml"); err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + return + } + + fmt.Printf("\n%s Nameserver %s added to configuration\n", + style.Colored(style.Green, style.SymCheckMark), newNameserver) +} + +// removeNameserver handles removing a nameserver +func (m *DNSMenu) removeNameserver() { + fmt.Println() + for i, ns := range m.config.Nameservers { + fmt.Printf("%s %d: %s\n", style.BulletItem, i+1, ns) + } + + fmt.Printf("\n%s Enter nameserver number to remove (1-%d): ", + style.BulletItem, len(m.config.Nameservers)) + numStr := ReadInput() + + // Parse number + num := 0 + n, err := fmt.Sscanf(numStr, "%d", &num) + if err != nil || n != 1 { + fmt.Printf("\n%s Invalid nameserver number: not a valid number\n", + style.Colored(style.Red, style.SymCrossMark)) + return + } + + if num < 1 || num > len(m.config.Nameservers) { + fmt.Printf("\n%s Invalid nameserver number: out of range\n", + style.Colored(style.Red, style.SymCrossMark)) + return + } + + // Remove nameserver (adjust for 0-based index) + removedNs := m.config.Nameservers[num-1] + m.config.Nameservers = append(m.config.Nameservers[:num-1], m.config.Nameservers[num:]...) + + // Save config + if err := config.SaveConfig(m.config, "hardn.yml"); err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", style.Colored(style.Red, style.SymCrossMark), err) + return } + + fmt.Printf("\n%s Nameserver %s removed from configuration\n", + style.Colored(style.Green, style.SymCheckMark), removedNs) } -// Helper function to get current DNS settings +// getCurrentDnsSettings retrieves the current DNS settings +// This is a temporary function that will be replaced by application layer calls later func getCurrentDnsSettings() ([]string, string) { var nameservers []string dnsImplementation := "" - + // Check if systemd-resolved is active systemdCmd := exec.Command("systemctl", "is-active", "systemd-resolved") if err := systemdCmd.Run(); err == nil { dnsImplementation = "systemd-resolved" - + // Get nameservers from resolved resolvectlCmd := exec.Command("resolvectl", "dns") output, err := resolvectlCmd.CombinedOutput() @@ -327,7 +377,7 @@ func getCurrentDnsSettings() ([]string, string) { } else { dnsImplementation = "direct (/etc/resolv.conf)" } - + // If we couldn't get nameservers from implementation-specific means, // try to parse /etc/resolv.conf directly if len(nameservers) == 0 { @@ -344,6 +394,6 @@ func getCurrentDnsSettings() ([]string, string) { } } } - + return nameservers, dnsImplementation -} \ No newline at end of file +} diff --git a/pkg/menu/dry_run.go b/pkg/menu/dry_run_menu.go similarity index 76% rename from pkg/menu/dry_run.go rename to pkg/menu/dry_run_menu.go index bf5edfb..5295f58 100644 --- a/pkg/menu/dry_run.go +++ b/pkg/menu/dry_run_menu.go @@ -1,35 +1,51 @@ -// pkg/menu/dry_run.go - +// pkg/menu/dry_run_menu.go package menu import ( "fmt" + "github.com/abbott/hardn/pkg/application" "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/logging" "github.com/abbott/hardn/pkg/style" "github.com/abbott/hardn/pkg/utils" ) -// ToggleDryRunMenu handles toggling the dry-run mode setting -func ToggleDryRunMenu(cfg *config.Config) { +// DryRunMenu handles the dry-run mode configuration +type DryRunMenu struct { + menuManager *application.MenuManager + config *config.Config +} + +// NewDryRunMenu creates a new DryRunMenu +func NewDryRunMenu( + menuManager *application.MenuManager, + config *config.Config, +) *DryRunMenu { + return &DryRunMenu{ + menuManager: menuManager, + config: config, + } +} + +// Show displays the dry-run mode menu and handles user input +func (m *DryRunMenu) Show() { utils.PrintHeader() fmt.Println(style.Bolded("Dry-Run Mode Settings", style.Blue)) - + // Create a formatter with just the label we need formatter := style.NewStatusFormatter([]string{"Dry-run Mode"}, 2) // Display current status fmt.Println() - if cfg.DryRun { + if m.config.DryRun { fmt.Println(formatter.FormatLine(style.SymInfo, style.BrightCyan, "Dry-run Mode", "Enabled", style.Green, "", "bold")) fmt.Println(style.Dimmed("\nIn this mode, the script will preview changes without applying them.")) - + // Create menu options menuOptions := []style.MenuOption{ {Number: 1, Title: "Disable dry-run mode", Description: "Apply changes to the system for real"}, } - + // Create and customize menu menu := style.NewMenu("Select an option", menuOptions) menu.SetExitOption(style.MenuOption{ @@ -37,15 +53,15 @@ func ToggleDryRunMenu(cfg *config.Config) { Title: "Return to main menu", Description: "Keep dry-run mode enabled", }) - + // Display the menu menu.Print() - + choiceStr := ReadInput() - + switch choiceStr { case "1": - cfg.DryRun = false + m.config.DryRun = false fmt.Println("\n" + formatter.FormatLine(style.SymInfo, style.BrightCyan, "Dry-run Mode", "Disabled", style.Yellow, "", "bold")) fmt.Println(style.Dimmed("\nChanges will now be applied to the system. Proceed with caution.")) case "0": @@ -56,12 +72,12 @@ func ToggleDryRunMenu(cfg *config.Config) { } else { fmt.Println(formatter.FormatLine(style.SymInfo, style.BrightCyan, "Dry-run Mode", "Disabled", style.Yellow, "", "bold")) fmt.Println(style.Dimmed("\nIn this mode, changes will be applied to the system. Proceed with caution.")) - + // Create menu options menuOptions := []style.MenuOption{ {Number: 1, Title: "Enable dry-run mode", Description: "Preview changes without applying them"}, } - + // Create and customize menu menu := style.NewMenu("Select an option", menuOptions) menu.SetExitOption(style.MenuOption{ @@ -69,15 +85,15 @@ func ToggleDryRunMenu(cfg *config.Config) { Title: "Return to main menu", Description: "Keep dry-run mode disabled", }) - + // Display the menu menu.Print() - + choiceStr := ReadInput() - + switch choiceStr { case "1": - cfg.DryRun = true + m.config.DryRun = true fmt.Println("\n" + formatter.FormatLine(style.SymInfo, style.BrightCyan, "Dry-run Mode", "Enabled", style.Green, "", "bold")) fmt.Println(style.Dimmed("\nChanges will be simulated without affecting the system.")) case "0": @@ -88,11 +104,14 @@ func ToggleDryRunMenu(cfg *config.Config) { } // Save config changes + // In a future iteration, this could use m.menuManager.SaveConfiguration() + // For now, we'll use the direct approach configFile := "hardn.yml" // Default config file - if err := config.SaveConfig(cfg, configFile); err != nil { - logging.LogError("Failed to save configuration: %v", err) + if err := config.SaveConfig(m.config, configFile); err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) } fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) ReadKey() -} \ No newline at end of file +} diff --git a/pkg/menu/env.go b/pkg/menu/env.go deleted file mode 100644 index d899e79..0000000 --- a/pkg/menu/env.go +++ /dev/null @@ -1,142 +0,0 @@ -// pkg/menu/env.go - -package menu - -import ( - "fmt" - "os" - "os/user" - "path/filepath" - "strings" - - "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/style" - "github.com/abbott/hardn/pkg/utils" -) - -// EnvironmentSettingsMenu displays and handles environment variable configuration -func EnvironmentSettingsMenu(cfg *config.Config) { - utils.PrintHeader() - fmt.Println(style.Bolded("Environment Variable Settings", style.Blue)) - - // Check if HARDN_CONFIG is set - configEnv := os.Getenv("HARDN_CONFIG") - if configEnv != "" { - fmt.Printf("\n%s Current HARDN_CONFIG: %s\n", style.BulletItem, style.Colored(style.Green, configEnv)) - } else { - fmt.Printf("\n%s HARDN_CONFIG environment variable is not set\n", style.BulletItem) - } - - // Check sudo preservation status - sudoPreservation := checkSudoEnvPreservation() - if sudoPreservation { - fmt.Printf("%s Sudo preservation: %s\n", style.BulletItem, style.Colored(style.Green, "Enabled")) - } else { - fmt.Printf("%s Sudo preservation: %s\n", style.BulletItem, style.Colored(style.Red, "Disabled")) - } - - // Create menu options - menuOptions := []style.MenuOption{ - {Number: 1, Title: "Setup sudo environment preservation", Description: "Configure sudo to preserve HARDN_CONFIG"}, - {Number: 2, Title: "Show environment variables guide", Description: "Learn how to set up environment variables"}, - } - - // Create and customize menu - menu := style.NewMenu("Select an option", menuOptions) - menu.SetExitOption(style.MenuOption{ - Number: 0, - Title: "Return to main menu", - Description: "", - }) - - // Display the menu - menu.Print() - - choice := ReadInput() - - switch choice { - case "1": - // Run sudo env setup - fmt.Printf("\n%s Setting up sudo environment preservation...\n", style.BulletItem) - - // Check if running as root - if os.Geteuid() != 0 { - fmt.Printf("\n%s This operation requires sudo privileges.\n", style.Colored(style.Red, style.SymWarning)) - fmt.Printf("%s Please run: sudo hardn setup-sudo-env\n", style.BulletItem) - } else { - err := utils.SetupSudoEnvPreservation() - if err != nil { - fmt.Printf("\n%s Failed to configure sudo: %v\n", style.Colored(style.Red, style.SymCrossMark), err) - } - } - - fmt.Printf("\n%s Press any key to continue...", style.BulletItem) - ReadKey() - EnvironmentSettingsMenu(cfg) - - case "2": - // Show environment guide - utils.PrintHeader() - fmt.Println(style.Bolded("Environment Variables Guide", style.Blue)) - - fmt.Printf("\n%s HARDN_CONFIG Environment Variable\n", style.Bolded("", style.Blue)) - fmt.Println(style.Dimmed("------------------------------------")) - fmt.Println("Set this variable to specify a custom config file location:") - fmt.Println(style.Colored(style.Cyan, " export HARDN_CONFIG=/path/to/your/config.yml")) - - fmt.Printf("\n%s Using with sudo\n", style.Bolded("", style.Blue)) - fmt.Println(style.Dimmed("------------------------------------")) - fmt.Println("To preserve the variable when using sudo, run:") - fmt.Println(style.Colored(style.Cyan, " sudo hardn setup-sudo-env")) - - fmt.Printf("\n%s For persistent configuration:\n", style.Bolded("", style.Blue)) - fmt.Println(style.Colored(style.Cyan, " echo 'export HARDN_CONFIG=/path/to/config.yml' >> ~/.bashrc")) - - fmt.Printf("\n%s Press any key to continue...", style.BulletItem) - ReadKey() - EnvironmentSettingsMenu(cfg) - - case "0": - return - - default: - fmt.Printf("\n%s Invalid option. Please try again.\n", style.Colored(style.Red, style.SymCrossMark)) - fmt.Printf("\n%s Press any key to continue...", style.BulletItem) - ReadKey() - EnvironmentSettingsMenu(cfg) - } -} - -// Helper function to check if sudo preservation is enabled -func checkSudoEnvPreservation() bool { - // First check for SUDO_USER which is the original user when using sudo - username := os.Getenv("SUDO_USER") - - // If that's empty, fall back to USER - if username == "" { - username = os.Getenv("USER") - - // If that's still empty, try to get username another way - if username == "" { - currentUser, err := user.Current() - if err != nil { - return false - } - username = currentUser.Username - } - } - - // Check if sudoers file exists - sudoersFile := filepath.Join("/etc/sudoers.d", username) - if _, err := os.Stat(sudoersFile); os.IsNotExist(err) { - return false - } - - // Check file content - data, err := os.ReadFile(sudoersFile) - if err != nil { - return false - } - - return strings.Contains(string(data), "env_keep += \"HARDN_CONFIG\"") -} \ No newline at end of file diff --git a/pkg/menu/environment_settings_menu.go b/pkg/menu/environment_settings_menu.go new file mode 100644 index 0000000..c3b9d8c --- /dev/null +++ b/pkg/menu/environment_settings_menu.go @@ -0,0 +1,150 @@ +// pkg/menu/environment_settings_menu.go +package menu + +import ( + "fmt" + "os" + + "github.com/abbott/hardn/pkg/application" + "github.com/abbott/hardn/pkg/config" + "github.com/abbott/hardn/pkg/style" + "github.com/abbott/hardn/pkg/utils" +) + +// EnvironmentSettingsMenu handles environment variable configuration +type EnvironmentSettingsMenu struct { + menuManager *application.MenuManager + config *config.Config +} + +// NewEnvironmentSettingsMenu creates a new EnvironmentSettingsMenu +func NewEnvironmentSettingsMenu( + menuManager *application.MenuManager, + config *config.Config, +) *EnvironmentSettingsMenu { + return &EnvironmentSettingsMenu{ + menuManager: menuManager, + config: config, + } +} + +// Show displays the environment settings menu and handles user input +func (m *EnvironmentSettingsMenu) Show() { + utils.PrintHeader() + fmt.Println(style.Bolded("Environment Variable Settings", style.Blue)) + + // Check if HARDN_CONFIG is set + configEnv := os.Getenv("HARDN_CONFIG") + if configEnv != "" { + fmt.Printf("\n%s Current HARDN_CONFIG: %s\n", style.BulletItem, style.Colored(style.Green, configEnv)) + } else { + fmt.Printf("\n%s HARDN_CONFIG environment variable is not set\n", style.BulletItem) + } + + // Check sudo preservation status + sudoPreservation := m.checkSudoEnvPreservation() + if sudoPreservation { + fmt.Printf("%s Sudo preservation: %s\n", style.BulletItem, style.Colored(style.Green, "Enabled")) + } else { + fmt.Printf("%s Sudo preservation: %s\n", style.BulletItem, style.Colored(style.Red, "Disabled")) + } + + // Create menu options + menuOptions := []style.MenuOption{ + {Number: 1, Title: "Setup sudo environment preservation", Description: "Configure sudo to preserve HARDN_CONFIG"}, + {Number: 2, Title: "Show environment variables guide", Description: "Learn how to set up environment variables"}, + } + + // Create and customize menu + menu := style.NewMenu("Select an option", menuOptions) + menu.SetExitOption(style.MenuOption{ + Number: 0, + Title: "Return to main menu", + Description: "", + }) + + // Display the menu + menu.Print() + + choice := ReadMenuInput() + + // Handle 'q' as a special exit case + if choice == "q" { + return + } + + switch choice { + case "1": + // Run sudo env setup + fmt.Printf("\n%s Setting up sudo environment preservation...\n", style.BulletItem) + + // Check if running as root + if os.Geteuid() != 0 { + fmt.Printf("\n%s This operation requires sudo privileges.\n", style.Colored(style.Red, style.SymWarning)) + fmt.Printf("%s Please run: sudo hardn setup-sudo-env\n", style.BulletItem) + } else { + if m.config.DryRun { + fmt.Printf("%s [DRY-RUN] Would configure sudo to preserve HARDN_CONFIG environment variable\n", style.BulletItem) + } else { + // Use application layer through menuManager + err := m.menuManager.SetupSudoPreservation() + if err != nil { + fmt.Printf("\n%s Failed to configure sudo: %v\n", style.Colored(style.Red, style.SymCrossMark), err) + } else { + fmt.Printf("\n%s Successfully configured sudo to preserve HARDN_CONFIG\n", style.Colored(style.Green, style.SymCheckMark)) + } + } + } + + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) + ReadKey() + m.Show() + + case "2": + // Show environment guide + m.showEnvironmentGuide() + m.Show() + + case "0": + return + + default: + fmt.Printf("\n%s Invalid option. Please try again.\n", style.Colored(style.Red, style.SymCrossMark)) + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) + ReadKey() + m.Show() + } +} + +// showEnvironmentGuide displays a guide on how to set up environment variables +func (m *EnvironmentSettingsMenu) showEnvironmentGuide() { + utils.PrintHeader() + fmt.Println(style.Bolded("Environment Variables Guide", style.Blue)) + + fmt.Printf("\n%s HARDN_CONFIG Environment Variable\n", style.Bolded("", style.Blue)) + fmt.Println(style.Dimmed("------------------------------------")) + fmt.Println("Set this variable to specify a custom config file location:") + fmt.Println(style.Colored(style.Cyan, " export HARDN_CONFIG=/path/to/your/config.yml")) + + fmt.Printf("\n%s Using with sudo\n", style.Bolded("", style.Blue)) + fmt.Println(style.Dimmed("------------------------------------")) + fmt.Println("To preserve the variable when using sudo, run:") + fmt.Println(style.Colored(style.Cyan, " sudo hardn setup-sudo-env")) + + fmt.Printf("\n%s For persistent configuration:\n", style.Bolded("", style.Blue)) + fmt.Println(style.Colored(style.Cyan, " echo 'export HARDN_CONFIG=/path/to/config.yml' >> ~/.bashrc")) + + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) + ReadKey() +} + +// checkSudoEnvPreservation checks if sudo preservation is enabled +func (m *EnvironmentSettingsMenu) checkSudoEnvPreservation() bool { + // Use application layer through menuManager + isEnabled, err := m.menuManager.IsSudoPreservationEnabled() + if err != nil { + // If there's an error checking, assume disabled + return false + } + return isEnabled +} diff --git a/pkg/menu/firewall.go b/pkg/menu/firewall_menu.go similarity index 56% rename from pkg/menu/firewall.go rename to pkg/menu/firewall_menu.go index 143ecb0..7f97f27 100644 --- a/pkg/menu/firewall.go +++ b/pkg/menu/firewall_menu.go @@ -1,60 +1,86 @@ -// pkg/menu/firewall.go - +// pkg/menu/firewall_menu.go package menu import ( "fmt" - "os/exec" "strconv" "strings" + "github.com/abbott/hardn/pkg/application" "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/firewall" - "github.com/abbott/hardn/pkg/logging" + "github.com/abbott/hardn/pkg/domain/model" "github.com/abbott/hardn/pkg/osdetect" "github.com/abbott/hardn/pkg/style" "github.com/abbott/hardn/pkg/utils" ) -// UfwMenu handles UFW firewall configuration -func UfwMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { +// FirewallMenu handles UFW firewall configuration +type FirewallMenu struct { + menuManager *application.MenuManager + config *config.Config + osInfo *osdetect.OSInfo +} + +// NewFirewallMenu creates a new FirewallMenu +func NewFirewallMenu( + menuManager *application.MenuManager, + config *config.Config, + osInfo *osdetect.OSInfo, +) *FirewallMenu { + return &FirewallMenu{ + menuManager: menuManager, + config: config, + osInfo: osInfo, + } +} + +// Show displays the firewall menu and handles user input +func (m *FirewallMenu) Show() { utils.PrintHeader() fmt.Println(style.Bolded("UFW Firewall Configuration", style.Blue)) - // Check current UFW status - isInstalled, isEnabled, isConfigured, rules := checkUfwStatus() + // Check current UFW status - this would ideally come from the application layer + isInstalled, isEnabled, isConfigured, rules, err := m.menuManager.GetFirewallStatus() + if err != nil { + fmt.Printf("\n%s Error getting firewall status: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + isInstalled = false + isEnabled = false + isConfigured = false + rules = []string{} + } // Display current status fmt.Println() fmt.Println(style.Bolded("Current Firewall Status:", style.Blue)) - + // Create formatter for status display formatter := style.NewStatusFormatter([]string{"UFW Installed", "UFW Status", "SSH Port"}, 2) - + // Installation status if isInstalled { fmt.Println(formatter.FormatSuccess("UFW Installed", "Yes", "Uncomplicated Firewall is available")) } else { fmt.Println(formatter.FormatWarning("UFW Installed", "No", "Firewall package not found")) } - + // Enabled status if isEnabled { fmt.Println(formatter.FormatSuccess("UFW Status", "Active", "Firewall is running")) } else { fmt.Println(formatter.FormatWarning("UFW Status", "Inactive", "Firewall is not running")) } - + // SSH port status - sshPortStr := strconv.Itoa(cfg.SshPort) + sshPortStr := strconv.Itoa(m.config.SshPort) sshPortDisplay := fmt.Sprintf("Port %s/tcp", sshPortStr) - - if cfg.SshPort == 22 { + + if m.config.SshPort == 22 { fmt.Println(formatter.FormatWarning("SSH Port", sshPortDisplay, "Using default port (consider changing)")) } else { fmt.Println(formatter.FormatSuccess("SSH Port", sshPortDisplay, "Using non-standard port (good security)")) } - + // Display configuration information fmt.Println() if isConfigured && len(rules) > 0 { @@ -71,13 +97,13 @@ func UfwMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { } else if isInstalled { fmt.Printf("%s No firewall rules configured\n", style.Colored(style.Yellow, style.SymWarning)) } - + // Display app profiles if defined - if len(cfg.UfwAppProfiles) > 0 { + if len(m.config.UfwAppProfiles) > 0 { fmt.Println() fmt.Println(style.Bolded("Configured Application Profiles:", style.Blue)) - for _, profile := range cfg.UfwAppProfiles { - fmt.Printf("%s %s: %s (%s)\n", + for _, profile := range m.config.UfwAppProfiles { + fmt.Printf("%s %s: %s (%s)\n", style.BulletItem, style.Bolded(profile.Name, style.Cyan), profile.Title, @@ -87,47 +113,47 @@ func UfwMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { // Create menu options var menuOptions []style.MenuOption - + // Install UFW if not installed if !isInstalled { menuOptions = append(menuOptions, style.MenuOption{ - Number: 1, - Title: "Install UFW", + Number: 1, + Title: "Install UFW", Description: "Install Uncomplicated Firewall package", }) } else { // Standard options when UFW is installed - + // Enable/disable option if !isEnabled { menuOptions = append(menuOptions, style.MenuOption{ - Number: 1, - Title: "Enable firewall", + Number: 1, + Title: "Enable firewall", Description: "Start UFW and set to run at boot", }) } else { menuOptions = append(menuOptions, style.MenuOption{ - Number: 1, - Title: "Disable firewall", + Number: 1, + Title: "Disable firewall", Description: "Stop UFW (not recommended)", }) } - + // Configure option menuOptions = append(menuOptions, style.MenuOption{ - Number: 2, - Title: "Configure firewall", + Number: 2, + Title: "Configure firewall", Description: "Set up default policies and SSH rules", }) - + // Manage application profiles menuOptions = append(menuOptions, style.MenuOption{ - Number: 3, - Title: "Manage application profiles", + Number: 3, + Title: "Manage application profiles", Description: "Configure custom application rules", }) } - + // Create menu menu := style.NewMenu("Select an option", menuOptions) menu.SetExitOption(style.MenuOption{ @@ -135,161 +161,163 @@ func UfwMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { Title: "Return to main menu", Description: "", }) - + // Display menu menu.Print() - - choice := ReadInput() - + + choice := ReadMenuInput() + + // Handle 'q' as a special exit case + if choice == "q" { + return + } + switch choice { case "1": if !isInstalled { - // Install UFW + // Install UFW - this should call through to an application service fmt.Println("\nInstalling UFW...") - - if cfg.DryRun { + + if m.config.DryRun { fmt.Printf("%s [DRY-RUN] Would install UFW package\n", style.BulletItem) } else { - var installCmd *exec.Cmd - if osInfo.OsType == "alpine" { - installCmd = exec.Command("apk", "add", "ufw") - } else { - installCmd = exec.Command("apt-get", "install", "-y", "ufw") - } - - if err := installCmd.Run(); err != nil { - fmt.Printf("\n%s Failed to install UFW: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - logging.LogError("Failed to install UFW: %v", err) - } else { - fmt.Printf("\n%s UFW installed successfully\n", - style.Colored(style.Green, style.SymCheckMark)) - } + // TODO: This should go through the application layer + // For now, we'll just provide a message + fmt.Printf("%s This operation isn't yet implemented in the new architecture\n", + style.Colored(style.Yellow, style.SymWarning)) } } else if isEnabled { - // Disable UFW - fmt.Printf("\n%s WARNING: Disabling the firewall will remove protection from your system.\n", + // Disable firewall through application layer + fmt.Printf("\n%s WARNING: Disabling the firewall will remove protection from your system.\n", style.Colored(style.Red, style.SymWarning)) fmt.Printf("%s Are you sure you want to disable UFW? (y/n): ", style.BulletItem) - + confirm := ReadInput() if strings.ToLower(confirm) == "y" || strings.ToLower(confirm) == "yes" { - if cfg.DryRun { + if m.config.DryRun { fmt.Printf("%s [DRY-RUN] Would disable UFW\n", style.BulletItem) } else { - disableCmd := exec.Command("ufw", "disable") - if err := disableCmd.Run(); err != nil { - fmt.Printf("\n%s Failed to disable UFW: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - logging.LogError("Failed to disable UFW: %v", err) - } else { - fmt.Printf("\n%s UFW disabled\n", - style.Colored(style.Yellow, style.SymInfo)) - } + // Call to application layer to disable firewall + // TODO: Implement this in MenuManager and FirewallManager + fmt.Printf("%s This operation isn't yet implemented in the new architecture\n", + style.Colored(style.Yellow, style.SymWarning)) } } else { fmt.Println("\nOperation cancelled. UFW remains enabled.") } } else { - // Enable UFW + // Enable firewall through application layer fmt.Println("\nEnabling UFW...") - - if cfg.DryRun { + + if m.config.DryRun { fmt.Printf("%s [DRY-RUN] Would enable UFW\n", style.BulletItem) } else { - // First ensure there's an SSH rule to prevent lockout - sshPort := strconv.Itoa(cfg.SshPort) - allowCmd := exec.Command("ufw", "allow", sshPort+"/tcp", "comment", "SSH") - if err := allowCmd.Run(); err != nil { - fmt.Printf("%s Warning: Failed to add SSH rule before enabling UFW\n", - style.Colored(style.Yellow, style.SymWarning)) - logging.LogWarning("Failed to add SSH rule before enabling UFW: %v", err) + // Convert app profiles to domain model format + var profiles []model.FirewallProfile + for _, profile := range m.config.UfwAppProfiles { + profiles = append(profiles, model.FirewallProfile{ + Name: profile.Name, + Title: profile.Title, + Description: profile.Description, + Ports: profile.Ports, + }) } - - // Enable UFW non-interactively - enableCmd := exec.Command("sh", "-c", "yes | ufw enable") - if err := enableCmd.Run(); err != nil { - fmt.Printf("\n%s Failed to enable UFW: %v\n", + + // Call application layer to configure firewall with profiles + err := m.menuManager.ConfigureSecureFirewall(m.config.SshPort, []int{}, profiles) + if err != nil { + fmt.Printf("\n%s Failed to enable and configure firewall: %v\n", style.Colored(style.Red, style.SymCrossMark), err) - logging.LogError("Failed to enable UFW: %v", err) } else { - fmt.Printf("\n%s UFW enabled successfully\n", + fmt.Printf("\n%s Firewall enabled and configured successfully\n", style.Colored(style.Green, style.SymCheckMark)) } } } - + // Return to firewall menu fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - UfwMenu(cfg, osInfo) - + m.Show() + case "2": // Configure UFW fmt.Println("\nConfiguring UFW firewall...") - - if cfg.DryRun { + + if m.config.DryRun { fmt.Printf("%s [DRY-RUN] Would configure UFW with default policies and SSH rules\n", style.BulletItem) - fmt.Printf("%s [DRY-RUN] SSH port: %d/tcp\n", style.BulletItem, cfg.SshPort) + fmt.Printf("%s [DRY-RUN] SSH port: %d/tcp\n", style.BulletItem, m.config.SshPort) } else { - if err := firewall.ConfigureUFW(cfg, osInfo); err != nil { - fmt.Printf("\n%s Failed to configure UFW: %v\n", + // Convert app profiles to domain model format + var profiles []model.FirewallProfile + for _, profile := range m.config.UfwAppProfiles { + profiles = append(profiles, model.FirewallProfile{ + Name: profile.Name, + Title: profile.Title, + Description: profile.Description, + Ports: profile.Ports, + }) + } + + // Call application layer to configure firewall + err := m.menuManager.ConfigureSecureFirewall(m.config.SshPort, []int{}, profiles) + if err != nil { + fmt.Printf("\n%s Failed to configure firewall: %v\n", style.Colored(style.Red, style.SymCrossMark), err) - logging.LogError("Failed to configure UFW: %v", err) } else { - fmt.Printf("\n%s UFW configured successfully\n", + fmt.Printf("\n%s Firewall configured successfully\n", style.Colored(style.Green, style.SymCheckMark)) - + // Show important rules fmt.Printf("%s Default policy: deny (incoming), allow (outgoing)\n", style.BulletItem) - fmt.Printf("%s SSH allowed on port %d/tcp\n", style.BulletItem, cfg.SshPort) - + fmt.Printf("%s SSH allowed on port %d/tcp\n", style.BulletItem, m.config.SshPort) + // Show app profiles if configured - if len(cfg.UfwAppProfiles) > 0 { - fmt.Printf("%s Application profiles: %d configured\n", - style.BulletItem, len(cfg.UfwAppProfiles)) + if len(m.config.UfwAppProfiles) > 0 { + fmt.Printf("%s Application profiles: %d configured\n", + style.BulletItem, len(m.config.UfwAppProfiles)) } } } - + case "3": // Manage application profiles - manageUfwAppProfilesMenu(cfg, osInfo) - UfwMenu(cfg, osInfo) + m.manageAppProfiles() + m.Show() return - + case "0": // Return to main menu return - + default: - fmt.Printf("\n%s Invalid option. Please try again.\n", + fmt.Printf("\n%s Invalid option. Please try again.\n", style.Colored(style.Red, style.SymCrossMark)) - + // Return to firewall menu fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - UfwMenu(cfg, osInfo) + m.Show() return } - + fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) ReadKey() } -// Helper function to manage UFW application profiles -func manageUfwAppProfilesMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { +// manageAppProfiles handles the application profiles management submenu +func (m *FirewallMenu) manageAppProfiles() { utils.PrintHeader() fmt.Println(style.Bolded("Manage UFW Application Profiles", style.Blue)) - + // Display current profiles fmt.Println() fmt.Println(style.Bolded("Configured Application Profiles:", style.Blue)) - - if len(cfg.UfwAppProfiles) == 0 { + + if len(m.config.UfwAppProfiles) == 0 { fmt.Printf("%s No application profiles configured\n", style.BulletItem) } else { - for i, profile := range cfg.UfwAppProfiles { + for i, profile := range m.config.UfwAppProfiles { fmt.Printf("%s %d: %s\n", style.BulletItem, i+1, style.Bolded(profile.Name, style.Cyan)) fmt.Printf(" Title: %s\n", profile.Title) fmt.Printf(" Description: %s\n", profile.Description) @@ -297,27 +325,27 @@ func manageUfwAppProfilesMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { fmt.Println() } } - + // Create menu options menuOptions := []style.MenuOption{ {Number: 1, Title: "Add application profile", Description: "Create a new UFW application profile"}, } - + // Only add remove option if profiles exist - if len(cfg.UfwAppProfiles) > 0 { + if len(m.config.UfwAppProfiles) > 0 { menuOptions = append(menuOptions, style.MenuOption{ - Number: 2, - Title: "Remove application profile", + Number: 2, + Title: "Remove application profile", Description: "Delete an existing UFW application profile", }) - + menuOptions = append(menuOptions, style.MenuOption{ - Number: 3, - Title: "Apply profiles", + Number: 3, + Title: "Apply profiles", Description: "Enable configured application profiles in UFW", }) } - + // Create menu menu := style.NewMenu("Select an option", menuOptions) menu.SetExitOption(style.MenuOption{ @@ -325,111 +353,116 @@ func manageUfwAppProfilesMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { Title: "Return to firewall menu", Description: "", }) - + // Display menu menu.Print() - - choice := ReadInput() - + + choice := ReadMenuInput() + + // Handle 'q' as a special exit case + if choice == "q" { + return + } + switch choice { case "1": // Add application profile - addUfwAppProfile(cfg, osInfo) - manageUfwAppProfilesMenu(cfg, osInfo) + m.addAppProfile() + m.manageAppProfiles() return - + case "2": // Remove application profile (only if profiles exist) - if len(cfg.UfwAppProfiles) == 0 { - fmt.Printf("\n%s No profiles to remove\n", + if len(m.config.UfwAppProfiles) == 0 { + fmt.Printf("\n%s No profiles to remove\n", style.Colored(style.Yellow, style.SymWarning)) } else { - removeUfwAppProfile(cfg, osInfo) + m.removeAppProfile() } - - manageUfwAppProfilesMenu(cfg, osInfo) + + m.manageAppProfiles() return - + case "3": // Apply profiles (only if profiles exist) - if len(cfg.UfwAppProfiles) == 0 { - fmt.Printf("\n%s No profiles to apply\n", + if len(m.config.UfwAppProfiles) == 0 { + fmt.Printf("\n%s No profiles to apply\n", style.Colored(style.Yellow, style.SymWarning)) } else { - applyUfwAppProfiles(cfg, osInfo) + m.applyAppProfiles() } - - manageUfwAppProfilesMenu(cfg, osInfo) + + m.manageAppProfiles() return - + case "0": // Return to firewall menu return - + default: - fmt.Printf("\n%s Invalid option. Please try again.\n", + fmt.Printf("\n%s Invalid option. Please try again.\n", style.Colored(style.Red, style.SymCrossMark)) - + // Return to app profiles menu fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - manageUfwAppProfilesMenu(cfg, osInfo) + m.manageAppProfiles() return } } -// Helper function to add a UFW application profile -func addUfwAppProfile(cfg *config.Config, osInfo *osdetect.OSInfo) { +// addAppProfile handles adding a new application profile +func (m *FirewallMenu) addAppProfile() { fmt.Println() fmt.Println(style.Bolded("Add UFW Application Profile:", style.Blue)) - + // Get profile details fmt.Printf("%s Enter profile name (e.g., 'WebServer'): ", style.BulletItem) name := ReadInput() - + if name == "" { - fmt.Printf("\n%s Profile name cannot be empty\n", + fmt.Printf("\n%s Profile name cannot be empty\n", style.Colored(style.Red, style.SymCrossMark)) return } - + // Check for duplicate name - for _, profile := range cfg.UfwAppProfiles { + for _, profile := range m.config.UfwAppProfiles { if strings.EqualFold(profile.Name, name) { - fmt.Printf("\n%s A profile with this name already exists\n", + fmt.Printf("\n%s A profile with this name already exists\n", style.Colored(style.Red, style.SymCrossMark)) return } } - + fmt.Printf("%s Enter profile title (e.g., 'Web Server'): ", style.BulletItem) title := ReadInput() - + fmt.Printf("%s Enter profile description: ", style.BulletItem) description := ReadInput() - + fmt.Printf("%s Enter ports (e.g., '80/tcp,443/tcp'): ", style.BulletItem) portsStr := ReadInput() - + if portsStr == "" { - fmt.Printf("\n%s Ports cannot be empty\n", + fmt.Printf("\n%s Ports cannot be empty\n", style.Colored(style.Red, style.SymCrossMark)) return } - + // Split ports by comma ports := strings.Split(portsStr, ",") - + // Validate port format for i, port := range ports { ports[i] = strings.TrimSpace(port) if !strings.Contains(ports[i], "/") { - fmt.Printf("\n%s Invalid port format '%s'. Must include protocol (e.g., '80/tcp')\n", + fmt.Printf("\n%s Invalid port format '%s'. Must include protocol (e.g., '80/tcp')\n", style.Colored(style.Red, style.SymCrossMark), ports[i]) return } } - + // Create new profile newProfile := config.UfwAppProfile{ Name: name, @@ -437,157 +470,88 @@ func addUfwAppProfile(cfg *config.Config, osInfo *osdetect.OSInfo) { Description: description, Ports: ports, } - + // Add to configuration - cfg.UfwAppProfiles = append(cfg.UfwAppProfiles, newProfile) - + m.config.UfwAppProfiles = append(m.config.UfwAppProfiles, newProfile) + // Save configuration - saveFirewallConfig(cfg) - - fmt.Printf("\n%s Application profile '%s' added successfully\n", + if err := config.SaveConfig(m.config, "hardn.yml"); err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + return + } + + fmt.Printf("\n%s Application profile '%s' added successfully\n", style.Colored(style.Green, style.SymCheckMark), name) } -// Helper function to remove a UFW application profile -func removeUfwAppProfile(cfg *config.Config, osInfo *osdetect.OSInfo) { +// removeAppProfile handles removing an application profile +func (m *FirewallMenu) removeAppProfile() { fmt.Println() fmt.Println(style.Bolded("Remove UFW Application Profile:", style.Blue)) - + // Display numbered list of profiles - for i, profile := range cfg.UfwAppProfiles { - fmt.Printf("%s %d: %s (%s)\n", + for i, profile := range m.config.UfwAppProfiles { + fmt.Printf("%s %d: %s (%s)\n", style.BulletItem, i+1, profile.Name, strings.Join(profile.Ports, ", ")) } - + // Get profile to remove - fmt.Printf("\n%s Enter profile number to remove (1-%d): ", - style.BulletItem, len(cfg.UfwAppProfiles)) + fmt.Printf("\n%s Enter profile number to remove (1-%d): ", + style.BulletItem, len(m.config.UfwAppProfiles)) numStr := ReadInput() - + // Parse number num, err := strconv.Atoi(numStr) - if err != nil || num < 1 || num > len(cfg.UfwAppProfiles) { - fmt.Printf("\n%s Invalid profile number\n", + if err != nil || num < 1 || num > len(m.config.UfwAppProfiles) { + fmt.Printf("\n%s Invalid profile number\n", style.Colored(style.Red, style.SymCrossMark)) return } - + // Get profile name for confirmation - profileName := cfg.UfwAppProfiles[num-1].Name - + profileName := m.config.UfwAppProfiles[num-1].Name + // Confirm removal - fmt.Printf("%s Are you sure you want to remove profile '%s'? (y/n): ", + fmt.Printf("%s Are you sure you want to remove profile '%s'? (y/n): ", style.BulletItem, profileName) confirm := ReadInput() - + if strings.ToLower(confirm) == "y" || strings.ToLower(confirm) == "yes" { // Remove profile (adjust for 0-based index) - cfg.UfwAppProfiles = append( - cfg.UfwAppProfiles[:num-1], - cfg.UfwAppProfiles[num:]... + m.config.UfwAppProfiles = append( + m.config.UfwAppProfiles[:num-1], + m.config.UfwAppProfiles[num:]..., ) - + // Save configuration - saveFirewallConfig(cfg) - - fmt.Printf("\n%s Application profile '%s' removed successfully\n", + if err := config.SaveConfig(m.config, "hardn.yml"); err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + return + } + + fmt.Printf("\n%s Application profile '%s' removed successfully\n", style.Colored(style.Green, style.SymCheckMark), profileName) } else { fmt.Println("\nRemoval cancelled.") } } -// Helper function to apply UFW application profiles -func applyUfwAppProfiles(cfg *config.Config, osInfo *osdetect.OSInfo) { +// applyAppProfiles handles applying application profiles +func (m *FirewallMenu) applyAppProfiles() { fmt.Println() fmt.Println(style.Bolded("Apply UFW Application Profiles:", style.Blue)) - - if cfg.DryRun { + + if m.config.DryRun { fmt.Printf("%s [DRY-RUN] Would write profiles to /etc/ufw/applications.d/hardn\n", style.BulletItem) - for _, profile := range cfg.UfwAppProfiles { - fmt.Printf("%s [DRY-RUN] Profile: %s (%s)\n", + for _, profile := range m.config.UfwAppProfiles { + fmt.Printf("%s [DRY-RUN] Profile: %s (%s)\n", style.BulletItem, profile.Name, strings.Join(profile.Ports, ", ")) } } else { - // Use firewall package to write and apply profiles - if err := firewall.WriteUfwAppProfiles(cfg, osInfo); err != nil { - fmt.Printf("\n%s Failed to apply application profiles: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - logging.LogError("Failed to apply UFW application profiles: %v", err) - } else { - fmt.Printf("\n%s Application profiles applied successfully\n", - style.Colored(style.Green, style.SymCheckMark)) - } + // This should call the application layer, but for now we'll just provide a message + fmt.Printf("\n%s This operation isn't yet implemented in the new architecture\n", + style.Colored(style.Yellow, style.SymWarning)) } } - -// Helper function to save firewall configuration -func saveFirewallConfig(cfg *config.Config) { - // Save config changes - configFile := "hardn.yml" // Default config file - if err := config.SaveConfig(cfg, configFile); err != nil { - logging.LogError("Failed to save configuration: %v", err) - fmt.Printf("\n%s Failed to save configuration: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } -} - -// Helper function to check UFW status -func checkUfwStatus() (bool, bool, bool, []string) { - // Check if UFW is installed - _, err := exec.LookPath("ufw") - isInstalled := (err == nil) - - // Default values if not installed - isEnabled := false - isConfigured := false - var rules []string - - if isInstalled { - // Check if UFW is enabled - statusCmd := exec.Command("ufw", "status") - statusOutput, err := statusCmd.CombinedOutput() - if err == nil { - statusText := string(statusOutput) - isEnabled = strings.Contains(statusText, "Status: active") - - // Extract rules (skip header lines) - lines := strings.Split(statusText, "\n") - ruleSection := false - for _, line := range lines { - line = strings.TrimSpace(line) - - // Skip empty lines - if line == "" { - continue - } - - // Skip header lines - if strings.Contains(line, "Status:") || - strings.Contains(line, "Logging:") || - strings.Contains(line, "Default:") || - strings.Contains(line, "New profiles:") || - strings.Contains(line, "To Action From") { - continue - } - - // Check if we've reached the rule section - if strings.Contains(line, "--") { - ruleSection = true - continue - } - - // Add rule lines - if ruleSection && line != "" { - rules = append(rules, line) - } - } - - // Check if we have default policies configured - isConfigured = strings.Contains(statusText, "deny (incoming)") && - strings.Contains(statusText, "allow (outgoing)") - } - } - - return isInstalled, isEnabled, isConfigured, rules -} \ No newline at end of file diff --git a/pkg/menu/help.go b/pkg/menu/help_menu.go similarity index 80% rename from pkg/menu/help.go rename to pkg/menu/help_menu.go index 08a7460..04a6fe0 100644 --- a/pkg/menu/help.go +++ b/pkg/menu/help_menu.go @@ -1,5 +1,4 @@ -// pkg/menu/help.go - +// pkg/menu/help_menu.go package menu import ( @@ -9,60 +8,73 @@ import ( "github.com/abbott/hardn/pkg/utils" ) -// HelpMenu displays usage information and command-line options -func HelpMenu() { +// HelpMenu provides usage information and examples +type HelpMenu struct { + // The help menu doesn't need many dependencies + // since it just displays information +} + +// NewHelpMenu creates a new HelpMenu +func NewHelpMenu() *HelpMenu { + return &HelpMenu{} +} + +// Show displays the help menu with command line options and examples +func (m *HelpMenu) Show() { utils.PrintLogo() fmt.Println(style.Bolded("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~", style.BrightGreen)) - + fmt.Println(style.Bolded("\nCommand Line Usage:", style.Blue)) fmt.Println(style.Dimmed("-----------------------------------------------------")) fmt.Println(" hardn [options]") - + fmt.Println(style.Bolded("\nCommand Line Options:", style.Blue)) fmt.Println(style.Dimmed("-----------------------------------------------------")) - + // Create a formatter with appropriate column widths formatter := style.NewStatusFormatter([]string{"Option", "Description"}, 4) // Display all command line options - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-f, --config-file string", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-f, --config-file string", "Configuration file path", style.Cyan, "", "light")) - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-u, --username string", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-u, --username string", "Specify username to create", style.Cyan, "", "light")) - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-c, --create-user", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-c, --create-user", "Create user", style.Cyan, "", "light")) - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-d, --disable-root", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-d, --disable-root", "Disable root SSH access", style.Cyan, "", "light")) - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-g, --configure-dns", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-g, --configure-dns", "Configure DNS resolvers", style.Cyan, "", "light")) - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-w, --configure-ufw", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-w, --configure-ufw", "Configure UFW", style.Cyan, "", "light")) - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-r, --run-all", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-r, --run-all", "Run all hardening operations", style.Cyan, "", "light")) - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-n, --dry-run", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-n, --dry-run", "Preview changes without applying them", style.Cyan, "", "light")) - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-l, --install-linux", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-l, --install-linux", "Install specified Linux packages", style.Cyan, "", "light")) - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-i, --install-python", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-i, --install-python", "Install specified Python packages", style.Cyan, "", "light")) - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-a, --install-all", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-a, --install-all", "Install all specified packages", style.Cyan, "", "light")) - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-s, --configure-sources", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-s, --configure-sources", "Configure package sources", style.Cyan, "", "light")) - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-p, --print-logs", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-p, --print-logs", "View logs", style.Cyan, "", "light")) - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-h, --help", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-h, --help", "View usage information", style.Cyan, "", "light")) - + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "-e, --setup-sudo-env", + "Configure sudoers for HARDN_CONFIG", style.Cyan, "", "light")) + // Usage examples fmt.Println(style.Bolded("\nExamples:", style.Blue)) fmt.Println(style.Dimmed("-----------------------------------------------------")) fmt.Printf("%s Run all hardening operations:\n", style.BulletItem) fmt.Printf(" %s\n", style.Colored(style.Cyan, "sudo hardn -r")) - + fmt.Printf("\n%s Create a non-root user with SSH access:\n", style.BulletItem) fmt.Printf(" %s\n", style.Colored(style.Cyan, "sudo hardn -u george -c")) - + fmt.Printf("\n%s Using a custom configuration file:\n", style.BulletItem) fmt.Printf(" %s\n", style.Colored(style.Cyan, "sudo hardn -f /path/to/config.yml")) @@ -72,4 +84,4 @@ func HelpMenu() { fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) ReadKey() -} \ No newline at end of file +} diff --git a/pkg/menu/input.go b/pkg/menu/input.go new file mode 100644 index 0000000..1cb8b84 --- /dev/null +++ b/pkg/menu/input.go @@ -0,0 +1,144 @@ +// pkg/menu/input.go +package menu + +import ( + "bufio" + "fmt" + "os" + "os/exec" + "strings" +) + +// Shared reader for all menus +var reader = bufio.NewReader(os.Stdin) + +// ReadInput reads a line of input from the user +func ReadInput() string { + input, _ := reader.ReadString('\n') + return strings.TrimSpace(input) +} + +// ReadKey reads a single key pressed by the user +func ReadKey() string { + // Configure terminal for raw input + if err := exec.Command("stty", "-F", "/dev/tty", "cbreak", "min", "1").Run(); err != nil { + fmt.Printf("Warning: Failed to configure terminal: %v\n", err) + // Try to continue anyway + } + defer func() { + if err := exec.Command("stty", "-F", "/dev/tty", "-cbreak").Run(); err != nil { + fmt.Printf("Warning: Failed to restore terminal: %v\n", err) + } + }() + + // Read the first byte + var firstByte = make([]byte, 1) + n, err := os.Stdin.Read(firstByte) + if err != nil || n != 1 { + return "" // Return empty on read error + } + + // If it's an escape character (27), read and discard the sequence + if firstByte[0] == 27 { + // Read and discard the next two bytes (common for arrow keys) + var discardBytes = make([]byte, 2) + _, err := os.Stdin.Read(discardBytes) + if err != nil { + // Just log and continue if this fails + fmt.Printf("Warning: Failed to read escape sequence: %v\n", err) + } + // Return empty to indicate a special key was pressed + return "" + } + + return string(firstByte) +} + +// ReadMenuInput reads input for a menu, supporting both immediate 'q' exit and +// normal buffered input with backspace support for other entries +func ReadMenuInput() string { + // fmt.Print("> ") + + var buffer strings.Builder + var displayedChars int + + for { + // Read a single key in raw mode + key := ReadRawKey() + + // Handle Enter (return the result) + if key == "\r" || key == "\n" { + fmt.Println() // Move to next line + return buffer.String() + } + + // Handle immediate 'q' exit if it's the first key + if buffer.Len() == 0 && (key == "q" || key == "Q") { + fmt.Println("q") + return "q" + } + + // Handle backspace/delete + if key == "\b" || key == "\x7f" { // \b = backspace, \x7f = delete + if buffer.Len() > 0 { + // Remove last character from our buffer + str := buffer.String() + buffer.Reset() + buffer.WriteString(str[:len(str)-1]) + + // Update display (backspace, space, backspace) + fmt.Print("\b \b") + displayedChars-- + } + continue + } + + // Only accept digits, q/Q and control characters + if (key >= "0" && key <= "9") || key == "q" || key == "Q" { + buffer.WriteString(key) + fmt.Print(key) // Echo the character + displayedChars++ + } + } +} + +// ReadRawKey reads a single key in raw mode +func ReadRawKey() string { + // Configure terminal for raw input + if err := exec.Command("stty", "-F", "/dev/tty", "raw", "-echo").Run(); err != nil { + fmt.Printf("Warning: Failed to configure terminal: %v\n", err) + // Try to continue anyway + } + defer func() { + if err := exec.Command("stty", "-F", "/dev/tty", "sane").Run(); err != nil { + fmt.Printf("Warning: Failed to restore terminal: %v\n", err) + } + }() + + var b = make([]byte, 1) + n, err := os.Stdin.Read(b) + if err != nil || n != 1 { + return "" // Return empty on read error + } + + // Convert control characters to strings + if b[0] == 13 { + return "\r" // Return/Enter key + } else if b[0] == 127 { + return "\x7f" // Delete key + } else if b[0] == 8 { + return "\b" // Backspace key + } else if b[0] == 27 { + // Possibly an arrow key or other escape sequence + // Read and discard two more bytes + var seq = make([]byte, 2) + _, err := os.Stdin.Read(seq) + if err != nil { + // Just log and continue if this fails + fmt.Printf("Warning: Failed to read escape sequence: %v\n", err) + } + return "" // Ignore escape sequences + } + + return string(b) +} diff --git a/pkg/menu/linux.go b/pkg/menu/linux.go deleted file mode 100644 index afd75aa..0000000 --- a/pkg/menu/linux.go +++ /dev/null @@ -1,237 +0,0 @@ -// pkg/menu/linux.go - -package menu - -import ( - "fmt" - "strings" - - "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/logging" - "github.com/abbott/hardn/pkg/osdetect" - "github.com/abbott/hardn/pkg/packages" - "github.com/abbott/hardn/pkg/style" - "github.com/abbott/hardn/pkg/utils" -) - -// LinuxPackagesMenu handles installation of Linux packages -func LinuxPackagesMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { - utils.PrintHeader() - fmt.Println(style.Bolded("Linux Packages Installation", style.Blue)) - - // Display current packages - fmt.Println() - fmt.Println(style.Bolded("Configured Packages:", style.Blue)) - - if osInfo.OsType == "alpine" { - // Alpine packages - if len(cfg.AlpineCorePackages) > 0 { - fmt.Printf("%s Core packages: %s\n", style.BulletItem, - style.Colored(style.Cyan, strings.Join(cfg.AlpineCorePackages, ", "))) - } - - if len(cfg.AlpineDmzPackages) > 0 { - fmt.Printf("%s DMZ packages: %s\n", style.BulletItem, - style.Colored(style.Cyan, strings.Join(cfg.AlpineDmzPackages, ", "))) - } - - if len(cfg.AlpineLabPackages) > 0 { - fmt.Printf("%s Lab packages: %s\n", style.BulletItem, - style.Colored(style.Cyan, strings.Join(cfg.AlpineLabPackages, ", "))) - } - } else { - // Debian/Ubuntu packages - if len(cfg.LinuxCorePackages) > 0 { - fmt.Printf("%s Core packages: %s\n", style.BulletItem, - style.Colored(style.Cyan, strings.Join(cfg.LinuxCorePackages, ", "))) - } - - if len(cfg.LinuxDmzPackages) > 0 { - fmt.Printf("%s DMZ packages: %s\n", style.BulletItem, - style.Colored(style.Cyan, strings.Join(cfg.LinuxDmzPackages, ", "))) - } - - if len(cfg.LinuxLabPackages) > 0 { - fmt.Printf("%s Lab packages: %s\n", style.BulletItem, - style.Colored(style.Cyan, strings.Join(cfg.LinuxLabPackages, ", "))) - } - } - - // Check subnet status for package selection - isDmz, _ := utils.CheckSubnet(cfg.DmzSubnet) - if isDmz { - fmt.Printf("\n%s DMZ subnet detected: %s\n", - style.Colored(style.Yellow, style.SymInfo), - style.Colored(style.Yellow, cfg.DmzSubnet)) - fmt.Printf("%s Only DMZ packages will be installed\n", style.BulletItem) - } else { - fmt.Printf("\n%s Not in DMZ subnet\n", - style.Colored(style.Green, style.SymInfo)) - fmt.Printf("%s Both DMZ and Lab packages will be installed\n", style.BulletItem) - } - - // Create menu options - menuOptions := []style.MenuOption{ - {Number: 1, Title: "Install Core Packages", Description: "Install essential system packages"}, - {Number: 2, Title: "Install DMZ Packages", Description: "Install packages for DMZ environments"}, - {Number: 3, Title: "Install Lab Packages", Description: "Install packages for development/lab environments"}, - {Number: 4, Title: "Install All Packages", Description: "Install all configured Linux packages"}, - } - - // Create menu - menu := style.NewMenu("Select an option", menuOptions) - menu.SetExitOption(style.MenuOption{ - Number: 0, - Title: "Return to main menu", - Description: "", - }) - - // Display menu - menu.Print() - - choice := ReadInput() - - switch choice { - case "1": - // Install core packages - fmt.Println("\nInstalling Core Linux packages...") - - if osInfo.OsType == "alpine" { - if len(cfg.AlpineCorePackages) > 0 { - installPackages(cfg.AlpineCorePackages, "Core", osInfo, cfg) - } else { - fmt.Printf("\n%s No Alpine Core packages configured\n", - style.Colored(style.Yellow, style.SymWarning)) - } - } else { - if len(cfg.LinuxCorePackages) > 0 { - installPackages(cfg.LinuxCorePackages, "Core", osInfo, cfg) - } else { - fmt.Printf("\n%s No Linux Core packages configured\n", - style.Colored(style.Yellow, style.SymWarning)) - } - } - - case "2": - // Install DMZ packages - fmt.Println("\nInstalling DMZ Linux packages...") - - if osInfo.OsType == "alpine" { - if len(cfg.AlpineDmzPackages) > 0 { - installPackages(cfg.AlpineDmzPackages, "DMZ", osInfo, cfg) - } else { - fmt.Printf("\n%s No Alpine DMZ packages configured\n", - style.Colored(style.Yellow, style.SymWarning)) - } - } else { - if len(cfg.LinuxDmzPackages) > 0 { - installPackages(cfg.LinuxDmzPackages, "DMZ", osInfo, cfg) - } else { - fmt.Printf("\n%s No Linux DMZ packages configured\n", - style.Colored(style.Yellow, style.SymWarning)) - } - } - - case "3": - // Install Lab packages - fmt.Println("\nInstalling Lab Linux packages...") - - if osInfo.OsType == "alpine" { - if len(cfg.AlpineLabPackages) > 0 { - installPackages(cfg.AlpineLabPackages, "Lab", osInfo, cfg) - } else { - fmt.Printf("\n%s No Alpine Lab packages configured\n", - style.Colored(style.Yellow, style.SymWarning)) - } - } else { - if len(cfg.LinuxLabPackages) > 0 { - installPackages(cfg.LinuxLabPackages, "Lab", osInfo, cfg) - } else { - fmt.Printf("\n%s No Linux Lab packages configured\n", - style.Colored(style.Yellow, style.SymWarning)) - } - } - - case "4": - // Install all packages - fmt.Println("\nInstalling All Linux packages...") - fmt.Println(style.Dimmed("This may take some time. Please wait...")) - - if osInfo.OsType == "alpine" { - // Install Alpine packages - if len(cfg.AlpineCorePackages) > 0 { - installPackages(cfg.AlpineCorePackages, "Core", osInfo, cfg) - } - - if isDmz { - if len(cfg.AlpineDmzPackages) > 0 { - installPackages(cfg.AlpineDmzPackages, "DMZ", osInfo, cfg) - } - } else { - if len(cfg.AlpineDmzPackages) > 0 { - installPackages(cfg.AlpineDmzPackages, "DMZ", osInfo, cfg) - } - - if len(cfg.AlpineLabPackages) > 0 { - installPackages(cfg.AlpineLabPackages, "Lab", osInfo, cfg) - } - } - } else { - // Install Debian/Ubuntu packages - if len(cfg.LinuxCorePackages) > 0 { - installPackages(cfg.LinuxCorePackages, "Core", osInfo, cfg) - } - - if isDmz { - if len(cfg.LinuxDmzPackages) > 0 { - installPackages(cfg.LinuxDmzPackages, "DMZ", osInfo, cfg) - } - } else { - if len(cfg.LinuxDmzPackages) > 0 { - installPackages(cfg.LinuxDmzPackages, "DMZ", osInfo, cfg) - } - - if len(cfg.LinuxLabPackages) > 0 { - installPackages(cfg.LinuxLabPackages, "Lab", osInfo, cfg) - } - } - } - - fmt.Printf("\n%s All Linux packages installed successfully!\n", - style.Colored(style.Green, style.SymCheckMark)) - - case "0": - return - - default: - fmt.Printf("\n%s Invalid option. No changes were made.\n", - style.Colored(style.Yellow, style.SymWarning)) - } - - fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) - ReadKey() -} - -// Helper function to install packages with nice formatting -func installPackages(pkgs []string, pkgType string, osInfo *osdetect.OSInfo, cfg *config.Config) { - if len(pkgs) == 0 { - return - } - - fmt.Printf("\n%s Installing %s packages: %s\n", - style.BulletItem, - pkgType, - style.Dimmed(strings.Join(pkgs, ", "))) - - if err := packages.InstallPackages(pkgs, osInfo, cfg); err != nil { - fmt.Printf("\n%s Failed to install %s packages: %v\n", - style.Colored(style.Red, style.SymCrossMark), - pkgType, - err) - logging.LogError("Failed to install %s packages: %v", pkgType, err) - } else { - fmt.Printf("\n%s %s packages installed successfully!\n", - style.Colored(style.Green, style.SymCheckMark), - pkgType) - } -} \ No newline at end of file diff --git a/pkg/menu/linux_packages_menu.go b/pkg/menu/linux_packages_menu.go new file mode 100644 index 0000000..af5ea54 --- /dev/null +++ b/pkg/menu/linux_packages_menu.go @@ -0,0 +1,271 @@ +// pkg/menu/linux_packages_menu.go +package menu + +import ( + "fmt" + "strings" + + "github.com/abbott/hardn/pkg/application" + "github.com/abbott/hardn/pkg/config" + "github.com/abbott/hardn/pkg/interfaces" + "github.com/abbott/hardn/pkg/osdetect" + "github.com/abbott/hardn/pkg/style" + "github.com/abbott/hardn/pkg/utils" +) + +// LinuxPackagesMenu handles Linux packages installation +type LinuxPackagesMenu struct { + menuManager *application.MenuManager + config *config.Config + osInfo *osdetect.OSInfo +} + +// NewLinuxPackagesMenu creates a new LinuxPackagesMenu +func NewLinuxPackagesMenu( + menuManager *application.MenuManager, + config *config.Config, + osInfo *osdetect.OSInfo, +) *LinuxPackagesMenu { + return &LinuxPackagesMenu{ + menuManager: menuManager, + config: config, + osInfo: osInfo, + } +} + +// Show displays the Linux packages menu and handles user input +func (m *LinuxPackagesMenu) Show() { + utils.PrintHeader() + fmt.Println(style.Bolded("Linux Packages Installation", style.Blue)) + + // Display current packages + fmt.Println() + fmt.Println(style.Bolded("Configured Packages:", style.Blue)) + + if m.osInfo.OsType == "alpine" { + // Alpine packages + if len(m.config.AlpineCorePackages) > 0 { + fmt.Printf("%sCore packages: %s\n", style.BulletItem, + style.Colored(style.Cyan, strings.Join(m.config.AlpineCorePackages, ", "))) + } + + if len(m.config.AlpineDmzPackages) > 0 { + fmt.Printf("%sDMZ packages: %s\n", style.BulletItem, + style.Colored(style.Cyan, strings.Join(m.config.AlpineDmzPackages, ", "))) + } + + if len(m.config.AlpineLabPackages) > 0 { + fmt.Printf("%sLab packages: %s\n", style.BulletItem, + style.Colored(style.Cyan, strings.Join(m.config.AlpineLabPackages, ", "))) + } + } else { + // Debian/Ubuntu packages + if len(m.config.LinuxCorePackages) > 0 { + fmt.Printf("%sCore packages: %s\n", style.BulletItem, + style.Colored(style.Cyan, strings.Join(m.config.LinuxCorePackages, ", "))) + } + + if len(m.config.LinuxDmzPackages) > 0 { + fmt.Printf("%sDMZ packages: %s\n", style.BulletItem, + style.Colored(style.Cyan, strings.Join(m.config.LinuxDmzPackages, ", "))) + } + + if len(m.config.LinuxLabPackages) > 0 { + fmt.Printf("%sLab packages: %s\n", style.BulletItem, + style.Colored(style.Cyan, strings.Join(m.config.LinuxLabPackages, ", "))) + } + } + + // Check subnet status for package selection + provider := interfaces.NewProvider() + isDmz, _ := utils.CheckSubnet(m.config.DmzSubnet, provider.Network) + if isDmz { + fmt.Printf("\n%s DMZ subnet detected: %s\n", + style.Colored(style.Yellow, style.SymInfo), + style.Colored(style.Yellow, m.config.DmzSubnet)) + fmt.Printf("%sOnly Core and DMZ packages can be installed\n", style.BulletItem) + } else { + fmt.Printf("\n%s Not in DMZ subnet\n", + style.Colored(style.Green, style.SymInfo)) + fmt.Printf("%sCore, DMZ and Lab packages can be installed\n", style.BulletItem) + } + + // Create menu options + menuOptions := []style.MenuOption{ + {Number: 1, Title: "Install Core Packages", Description: "Install essential system packages"}, + {Number: 2, Title: "Install DMZ Packages", Description: "Install packages for DMZ environments"}, + {Number: 3, Title: "Install Lab Packages", Description: "Install packages for development/lab environments"}, + {Number: 4, Title: "Install All Packages", Description: "Install all configured Linux packages"}, + } + + // Create menu + menu := style.NewMenu("Select an option", menuOptions) + menu.SetExitOption(style.MenuOption{ + Number: 0, + Title: "Return to main menu", + Description: "", + }) + + // Display menu + menu.Print() + + choice := ReadMenuInput() + + // Handle 'q' as a special exit case + if choice == "q" { + return + } + + switch choice { + case "1": + // Install core packages + fmt.Println("\nInstalling Core Linux packages...") + + if m.osInfo.OsType == "alpine" { + if len(m.config.AlpineCorePackages) > 0 { + m.installPackages(m.config.AlpineCorePackages, "Core") + } else { + fmt.Printf("\n%s No Alpine Core packages configured\n", + style.Colored(style.Yellow, style.SymWarning)) + } + } else { + if len(m.config.LinuxCorePackages) > 0 { + m.installPackages(m.config.LinuxCorePackages, "Core") + } else { + fmt.Printf("\n%s No Linux Core packages configured\n", + style.Colored(style.Yellow, style.SymWarning)) + } + } + + case "2": + // Install DMZ packages + fmt.Println("\nInstalling DMZ Linux packages...") + + if m.osInfo.OsType == "alpine" { + if len(m.config.AlpineDmzPackages) > 0 { + m.installPackages(m.config.AlpineDmzPackages, "DMZ") + } else { + fmt.Printf("\n%s No Alpine DMZ packages configured\n", + style.Colored(style.Yellow, style.SymWarning)) + } + } else { + if len(m.config.LinuxDmzPackages) > 0 { + m.installPackages(m.config.LinuxDmzPackages, "DMZ") + } else { + fmt.Printf("\n%s No Linux DMZ packages configured\n", + style.Colored(style.Yellow, style.SymWarning)) + } + } + + case "3": + // Install Lab packages + fmt.Println("\nInstalling Lab Linux packages...") + + if m.osInfo.OsType == "alpine" { + if len(m.config.AlpineLabPackages) > 0 { + m.installPackages(m.config.AlpineLabPackages, "Lab") + } else { + fmt.Printf("\n%s No Alpine Lab packages configured\n", + style.Colored(style.Yellow, style.SymWarning)) + } + } else { + if len(m.config.LinuxLabPackages) > 0 { + m.installPackages(m.config.LinuxLabPackages, "Lab") + } else { + fmt.Printf("\n%s No Linux Lab packages configured\n", + style.Colored(style.Yellow, style.SymWarning)) + } + } + + case "4": + // Install all packages + fmt.Println("\nInstalling All Linux packages...") + fmt.Println(style.Dimmed("This may take some time. Please wait...")) + + if m.osInfo.OsType == "alpine" { + // Install Alpine packages + if len(m.config.AlpineCorePackages) > 0 { + m.installPackages(m.config.AlpineCorePackages, "Core") + } + + if isDmz { + if len(m.config.AlpineDmzPackages) > 0 { + m.installPackages(m.config.AlpineDmzPackages, "DMZ") + } + } else { + if len(m.config.AlpineDmzPackages) > 0 { + m.installPackages(m.config.AlpineDmzPackages, "DMZ") + } + + if len(m.config.AlpineLabPackages) > 0 { + m.installPackages(m.config.AlpineLabPackages, "Lab") + } + } + } else { + // Install Debian/Ubuntu packages + if len(m.config.LinuxCorePackages) > 0 { + m.installPackages(m.config.LinuxCorePackages, "Core") + } + + if isDmz { + if len(m.config.LinuxDmzPackages) > 0 { + m.installPackages(m.config.LinuxDmzPackages, "DMZ") + } + } else { + if len(m.config.LinuxDmzPackages) > 0 { + m.installPackages(m.config.LinuxDmzPackages, "DMZ") + } + + if len(m.config.LinuxLabPackages) > 0 { + m.installPackages(m.config.LinuxLabPackages, "Lab") + } + } + } + + fmt.Printf("\n%s All Linux packages installed successfully!\n", + style.Colored(style.Green, style.SymCheckMark)) + + case "0": + return + + default: + fmt.Printf("\n%s Invalid option. No changes were made.\n", + style.Colored(style.Yellow, style.SymWarning)) + } + + fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) + ReadKey() +} + +// installPackages handles installing packages with nice formatting +func (m *LinuxPackagesMenu) installPackages(pkgs []string, pkgType string) { + if len(pkgs) == 0 { + return + } + + fmt.Printf("\n%s Installing %s packages: %s\n", + style.BulletItem, + pkgType, + style.Dimmed(strings.Join(pkgs, ", "))) + + if m.config.DryRun { + fmt.Printf("\n%s [DRY-RUN] Would install %s packages: %s\n", + style.Colored(style.Green, style.SymInfo), + pkgType, + strings.Join(pkgs, ", ")) + return + } + + // Use the application layer through menuManager + err := m.menuManager.InstallLinuxPackages(pkgs, pkgType) + if err != nil { + fmt.Printf("\n%s Failed to install %s packages: %v\n", + style.Colored(style.Red, style.SymCrossMark), + pkgType, + err) + } else { + fmt.Printf("\n%s %s packages installed successfully\n", + style.Colored(style.Green, style.SymCheckMark), + pkgType) + } +} diff --git a/pkg/menu/logs.go b/pkg/menu/logs.go deleted file mode 100644 index a7dd7ca..0000000 --- a/pkg/menu/logs.go +++ /dev/null @@ -1,32 +0,0 @@ -// pkg/menu/logs.go - -package menu - -import ( - "fmt" - - "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/logging" - "github.com/abbott/hardn/pkg/style" - "github.com/abbott/hardn/pkg/utils" -) - -// ViewLogsMenu displays the contents of the log file -func ViewLogsMenu(cfg *config.Config) { - utils.PrintHeader() - fmt.Println(style.Bolded("View Logs", style.Blue)) - - // Display log file path - fmt.Printf("\n%s Log file: %s\n", - style.BulletItem, style.Colored(style.Cyan, cfg.LogFile)) - - // Print separator before log content - fmt.Println(style.Bolded("\nLog Contents:", style.Blue)) - fmt.Println(style.Dimmed("-----------------------------------------------------")) - - // Use the logging package to print the logs - logging.PrintLogs(cfg.LogFile) - - fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) - ReadKey() -} \ No newline at end of file diff --git a/pkg/menu/logs_menu.go b/pkg/menu/logs_menu.go new file mode 100644 index 0000000..e5b7204 --- /dev/null +++ b/pkg/menu/logs_menu.go @@ -0,0 +1,64 @@ +// pkg/menu/logs_menu.go +package menu + +import ( + "fmt" + + "github.com/abbott/hardn/pkg/application" + "github.com/abbott/hardn/pkg/config" + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/style" + "github.com/abbott/hardn/pkg/utils" +) + +// LogsMenu handles viewing log information +type LogsMenu struct { + menuManager *application.MenuManager + config *config.Config +} + +// NewLogsMenu creates a new LogsMenu +func NewLogsMenu( + menuManager *application.MenuManager, + config *config.Config, +) *LogsMenu { + return &LogsMenu{ + menuManager: menuManager, + config: config, + } +} + +// Show displays the logs menu and handles user input +func (m *LogsMenu) Show() { + utils.PrintHeader() + fmt.Println(style.Bolded("View Logs", style.Blue)) + + // Get log configuration + logConfig, err := m.menuManager.GetLogConfig() + if err != nil { + fmt.Printf("\n%s Error getting log configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + // Create a domain model LogsConfig from the application config + logConfig = &model.LogsConfig{ + LogFilePath: m.config.LogFile, + } + } + + // Display log file path + fmt.Printf("\n%s Log file: %s\n", + style.BulletItem, style.Colored(style.Cyan, logConfig.LogFilePath)) + + // Print separator before log content + fmt.Println(style.Bolded("\nLog Contents:", style.Blue)) + fmt.Println(style.Dimmed("-----------------------------------------------------")) + + // Use the menu manager to print the logs + err = m.menuManager.PrintLogs() + if err != nil { + fmt.Printf("\n%s Error displaying logs: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } + + fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) + ReadKey() +} diff --git a/pkg/menu/main_menu.go b/pkg/menu/main_menu.go new file mode 100644 index 0000000..41baf27 --- /dev/null +++ b/pkg/menu/main_menu.go @@ -0,0 +1,472 @@ +// pkg/menu/main_menu.go +package menu + +import ( + "fmt" + "os" + "strings" + + "github.com/abbott/hardn/pkg/application" + "github.com/abbott/hardn/pkg/config" + "github.com/abbott/hardn/pkg/osdetect" + "github.com/abbott/hardn/pkg/status" + "github.com/abbott/hardn/pkg/style" + "github.com/abbott/hardn/pkg/utils" + "github.com/abbott/hardn/pkg/version" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +// MainMenu is the main menu of the application +type MainMenu struct { + menuManager *application.MenuManager + config *config.Config + osInfo *osdetect.OSInfo + + // Version service for update checks + versionService *version.Service + + // Update state fields + updateAvailable bool + latestVersion string + updateURL string +} + +// NewMainMenu creates a new MainMenu +func NewMainMenu( + menuManager *application.MenuManager, + config *config.Config, + osInfo *osdetect.OSInfo, + versionService *version.Service, +) *MainMenu { + return &MainMenu{ + menuManager: menuManager, + config: config, + osInfo: osInfo, + versionService: versionService, + } +} + +// refreshConfig refreshes any configuration values that might have been changed +// by sub-menus like RunAllMenu or DryRunMenu +func (m *MainMenu) refreshConfig() { + // If we added ways for sub-menus to notify the main menu of changes, + // we would handle them here + + // For now, we're using a shared config pointer, so changes are automatically visible + // This method is a placeholder for future extensibility +} + +// CheckForUpdates checks for new versions and updates the menu state +func (m *MainMenu) CheckForUpdates() { + if m.versionService == nil || m.versionService.CurrentVersion == "" { + return + } + + // Run in a goroutine to avoid blocking the menu display + go func() { + // Use the unified version service + result := m.versionService.CheckForUpdates(&version.UpdateOptions{ + Debug: os.Getenv("HARDN_DEBUG") != "", + }) + + if result.Error != nil { + return // Silently fail for menu updates + } + + if result.UpdateAvailable { + m.updateAvailable = true + m.latestVersion = result.LatestVersion + m.updateURL = result.ReleaseURL + } + }() +} + +// showDryRunMenu creates and displays the dry-run configuration menu +func (m *MainMenu) showDryRunMenu() { + // Display contextual information about dry-run mode + utils.PrintHeader() + fmt.Println(style.Bolded("Dry-Run Mode Configuration", style.Blue)) + + fmt.Println() + fmt.Println(style.Dimmed("Dry-run mode allows you to preview changes without applying them to your system.")) + fmt.Println(style.Dimmed("This is useful for testing and understanding what actions will be performed.")) + + // Check if any critical operations have been performed + // This is just an example - you'd need to track this information + criticalChanges := false // Placeholder for tracking if changes have been made + + if criticalChanges && m.config.DryRun { + fmt.Printf("\n%s You've already performed operations in dry-run mode.\n", + style.Colored(style.Yellow, style.SymInfo)) + fmt.Printf("%s Disabling dry-run mode will apply future changes for real.\n", + style.BulletItem) + } + + fmt.Println() + fmt.Printf("%s Press any key to continue to dry-run configuration...", style.BulletItem) + ReadKey() + + // Create and show the dry-run menu + dryRunMenu := NewDryRunMenu(m.menuManager, m.config) + dryRunMenu.Show() + + // After returning from the dry-run menu, inform about the status + utils.PrintHeader() + + // Quick feedback on the configuration change before returning to main menu + fmt.Printf("\n%s Dry-run mode is now %s\n", + style.Colored(style.Cyan, style.SymInfo), + style.Bolded(map[bool]string{ + true: "ENABLED - Changes will only be simulated", + false: "DISABLED - Changes will be applied to the system", + }[m.config.DryRun], map[bool]string{ + true: style.Green, + false: style.Yellow, + }[m.config.DryRun])) + + fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) + ReadKey() +} + +// ShowMainMenu displays the main menu and handles user input +func (m *MainMenu) ShowMainMenu(currentVersion, buildDate, gitCommit string) { + // Initialize version service if not already done + if m.versionService == nil && currentVersion != "" { + m.versionService = version.NewService(currentVersion, buildDate, gitCommit) + } + + // Check for updates when the menu starts + if m.versionService != nil { + // See if we should force an update notification for testing + if os.Getenv("HARDN_FORCE_UPDATE") != "" { + m.updateAvailable = true + m.latestVersion = "99.0.0" + m.updateURL = "https://github.com/abbott/hardn/releases/latest" + } else { + m.CheckForUpdates() + } + } + + for { + + // Refresh any configuration that might have been changed + m.refreshConfig() + + utils.PrintLogo() + + // Define separator line + separator := "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" + sepWidth := len(separator) + + // Get security status - this would need to be adapted to use the new architecture + securityStatus, err := status.CheckSecurityStatus(m.config, m.osInfo) + var riskLevel, riskDescription, riskColor string + if err == nil { + riskLevel, riskDescription, riskColor = status.GetSecurityRiskLevel(securityStatus) + } + + // Prepare OS display information + var osDisplay string + if m.osInfo != nil { + if m.osInfo.IsProxmox { + osDisplay = " Proxmox " + } else { + osName := cases.Title(language.English).String(m.osInfo.OsType) + osCodename := cases.Title(language.English).String(m.osInfo.OsCodename) + + if m.osInfo.OsType == "alpine" { + osDisplay = fmt.Sprintf(" %s Linux %s ", osName, m.osInfo.OsVersion) + } else { + osDisplay = fmt.Sprintf(" %s %s ", osName, osCodename) + } + } + + // Remove ANSI codes for accurate length calculation + osDisplayStripped := style.StripAnsi(osDisplay) + osDisplayWidth := len(osDisplayStripped) + + // Calculate padding for centering OS display, accounting for spaces + leftPadding := (sepWidth - osDisplayWidth) / 2 + rightPadding := sepWidth - osDisplayWidth - leftPadding + + // Print centered OS display within the separator line + var envLine = separator[:leftPadding] + osDisplay + separator[:rightPadding] + + fmt.Println(style.Colored(style.Green, envLine)) + } else { + // Print separator without OS info + fmt.Println(style.Bolded(separator, style.Green)) + } + + // Display version information after the OS display + if m.versionService != nil && m.versionService.CurrentVersion != "" { + versionDisplay := fmt.Sprintf(" Version %s ", m.versionService.CurrentVersion) + + // Center version information just like OS display + versionDisplayStripped := style.StripAnsi(versionDisplay) + versionDisplayWidth := len(versionDisplayStripped) + + leftPadding := (sepWidth - versionDisplayWidth) / 2 + rightPadding := sepWidth - versionDisplayWidth - leftPadding + + // Print centered version within the separator line + versionLine := separator[:leftPadding] + versionDisplay + separator[:rightPadding] + fmt.Println(style.Colored(style.BrightCyan, versionLine)) + + // Show build information if available + if m.versionService.BuildDate != "" || m.versionService.GitCommit != "" { + fmt.Println() + if m.versionService.BuildDate != "" { + fmt.Printf("%s Build Date: %s\n", style.BulletItem, style.Dimmed(m.versionService.BuildDate)) + } + if m.versionService.GitCommit != "" { + fmt.Printf("%s Git Commit: %s\n", style.BulletItem, style.Dimmed(m.versionService.GitCommit)) + } + } + } + + // Display update notification if a newer version is available + if m.updateAvailable { + fmt.Println() + updateMsg := fmt.Sprintf(" Update available: %s → %s ", m.versionService.CurrentVersion, m.latestVersion) + + // Center the update message + updateMsgStripped := style.StripAnsi(updateMsg) + msgWidth := len(updateMsgStripped) + + leftPadding := (sepWidth - msgWidth) / 2 + rightPadding := sepWidth - msgWidth - leftPadding + + updateLine := separator[:leftPadding] + updateMsg + separator[:rightPadding] + fmt.Println(style.Colored(style.Yellow, updateLine)) + + // Show update instructions + fmt.Printf("%s Visit: %s\n", + style.BulletItem, + style.Colored(style.BrightCyan, m.updateURL)) + } + + fmt.Println() + // 2 spaces buffer + + // Format and print risk status using the same formatter, with bold label + if riskLevel != "" { + + // Create a formatter that includes all labels (including Risk) + formatter := style.NewStatusFormatter([]string{ + "Risk", + "SSH Root Login", + "Firewall", + "Users", + "SSH Port", + "SSH Auth", + "AppArmor", + "Auto Updates", + }, 2) + + boldRiskLabel := style.Bold + "Risk Level" + style.Reset + riskDescription = style.SymApprox + " " + riskDescription + fmt.Println(formatter.FormatLine(style.SymDotTri, riskColor, boldRiskLabel, riskLevel, riskColor, riskDescription, "light")) + } + + fmt.Println() + + // Display detailed security status if available + if err == nil { + // Create formatter here since it wasn't created in the risk level section + formatter := style.NewStatusFormatter([]string{ + "Risk", + "SSH Root Login", + "Firewall", + "Users", + "SSH Port", + "SSH Auth", + "AppArmor", + "Auto Updates", + }, 2) // 2 spaces buffer + + status.DisplaySecurityStatus(m.config, securityStatus, formatter) + } + + // Display dry-run mode if active + fmt.Println() + + // Format the dry-run mode status like other status lines + formatter := style.NewStatusFormatter([]string{ + "Dry-run Mode", + }, 2) + + if m.config.DryRun { + fmt.Println(formatter.FormatLine(style.SymAsterisk, style.BrightGreen, "Dry-run Mode", "Enabled", style.BrightGreen, "", "light")) + } else { + fmt.Println(formatter.FormatLine(style.SymAsterisk, style.BrightYellow, "Dry-run Mode", "Disabled", style.BrightYellow, "", "light")) + } + + // Create menu options + menuOptions := []style.MenuOption{ + {Number: 1, Title: "Sudo User", Description: "Create non-root user with sudo access"}, + {Number: 2, Title: "Root SSH", Description: "Disable SSH access for root user"}, + {Number: 3, Title: "DNS", Description: "Configure DNS settings"}, + {Number: 4, Title: "Firewall", Description: "Configure UFW rules"}, + {Number: 5, Title: "Run All", Description: "Run all hardening operations"}, + {Number: 6, Title: "Dry-Run", Description: "Preview changes without applying them"}, + {Number: 7, Title: "Linux Packages", Description: "Install specified Linux packages"}, + {Number: 8, Title: "Python Packages", Description: "Install specified Python packages"}, + {Number: 9, Title: "Package Sources", Description: "Configure package source"}, + {Number: 10, Title: "Backup", Description: "Configure backup settings"}, + {Number: 11, Title: "Environment", Description: "Configure environment variable support"}, + {Number: 12, Title: "Logs", Description: "View log file"}, + {Number: 13, Title: "Help", Description: "View usage information"}, + } + + // Create and customize menu + menu := style.NewMenu("Select an option", menuOptions) + // Set custom exit option + menu.SetExitOption(style.MenuOption{ + Number: 0, + Title: "Exit", + Description: "Tip: Press 'q' to exit immediately", + }) + + // Display the menu + menu.Print() + + choice := ReadMenuInput() + + // Handle the special exit case for main menu + if choice == "q" { + utils.PrintHeader() + fmt.Println("Hardn has exited.") + fmt.Println() + return + } + + // Process the menu choice - using menuManager instead of direct calls + switch choice { + case "1": // Sudo User + // Create and show user menu + userMenu := NewUserMenu(m.menuManager, m.config, m.osInfo) + userMenu.Show() + + case "2": // Root SSH + // Create and show disable root menu + disableRootMenu := NewDisableRootMenu(m.menuManager, m.config, m.osInfo) + disableRootMenu.Show() + + case "3": // DNS + // ConfigureDnsMenu(m.config, m.osInfo) + dnsMenu := NewDNSMenu(m.menuManager, m.config, m.osInfo) + dnsMenu.Show() + + case "4": // Firewall + // UfwMenu(m.config, m.osInfo) + firewallMenu := NewFirewallMenu(m.menuManager, m.config, m.osInfo) + firewallMenu.Show() + + case "5": // Run All + // Check for prerequisites + if m.config.Username == "" && !m.config.DryRun { + // For actual runs (not dry-run), having a username is essential + fmt.Printf("\n%s No username defined for user creation\n", + style.Colored(style.Yellow, style.SymWarning)) + fmt.Printf("%s Would you like to set a username now? (y/n): ", style.BulletItem) + + confirm := ReadInput() + if strings.ToLower(confirm) == "y" || strings.ToLower(confirm) == "yes" { + // Launch the user menu to set a username first + userMenu := NewUserMenu(m.menuManager, m.config, m.osInfo) + userMenu.Show() + + // If still no username, abort Run All + if m.config.Username == "" { + fmt.Printf("\n%s Run All requires a username for user creation. Operation cancelled.\n", + style.Colored(style.Red, style.SymCrossMark)) + fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) + ReadKey() + break + } + } else { + // User chose not to set a username, continue with warning + fmt.Printf("\n%s Continuing without user creation\n", + style.Colored(style.Yellow, style.SymWarning)) + } + } + + // Create and show the Run All menu + runAllMenu := NewRunAllMenu(m.menuManager, m.config, m.osInfo) + runAllMenu.Show() + + // After returning from Run All menu, check if the dry-run mode was toggled + // This affects how the main menu status is displayed + // Note: This would automatically be handled on the next menu refresh + + case "6": // Dry-Run + m.showDryRunMenu() + + case "7": // Linux Packages + // LinuxPackagesMenu(m.config, m.osInfo) + // This needs a packages manager in application layer + linuxMenu := NewLinuxPackagesMenu(m.menuManager, m.config, m.osInfo) + linuxMenu.Show() + + case "8": // Python Packages + // PythonPackagesMenu(m.config, m.osInfo) + pythonMenu := NewPythonPackagesMenu(m.menuManager, m.config, m.osInfo) + pythonMenu.Show() + + case "9": // Package Sources + // UpdateSourcesMenu(m.config, m.osInfo) + sourcesMenu := NewSourcesMenu(m.menuManager, m.config, m.osInfo) + sourcesMenu.Show() + + case "10": // Backup + // BackupOptionsMenu(m.config) + backupMenu := NewBackupMenu(m.menuManager, m.config) + backupMenu.Show() + + case "11": // Environment + // EnvironmentSettingsMenu(m.config) + envMenu := NewEnvironmentSettingsMenu(m.menuManager, m.config) + envMenu.Show() + + case "12": // Logs + // Viewing logs doesn't need to go through menuManager + // ViewLogsMenu(m.config) + logsMenu := NewLogsMenu(m.menuManager, m.config) + logsMenu.Show() + + case "13": // Helpcase "13": // Help + helpMenu := NewHelpMenu() + helpMenu.Show() + // helpMenu := menuFactory.CreateHelpMenu() + // helpMenu.Show() + + case "0": // Exit + utils.PrintHeader() + fmt.Println("Hardn has exited.") + fmt.Println() + return + + default: + utils.PrintHeader() + fmt.Printf("%s Invalid option. Please try again.\n", + style.Colored(style.Red, style.SymCrossMark)) + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) + ReadKey() + } + } +} + +func (m *MainMenu) SetTestUpdateAvailable(testVersion string) { + if m.versionService != nil { + result := m.versionService.CheckForUpdates(&version.UpdateOptions{ + ForceUpdate: true, + ForcedVersion: testVersion, + }) + + m.updateAvailable = result.UpdateAvailable + m.latestVersion = result.LatestVersion + m.updateURL = result.ReleaseURL + } +} diff --git a/pkg/menu/menu.go b/pkg/menu/menu.go deleted file mode 100644 index fbd1524..0000000 --- a/pkg/menu/menu.go +++ /dev/null @@ -1,254 +0,0 @@ -package menu - -import ( - "bufio" - "fmt" - "os" - "os/exec" - "strings" - - "golang.org/x/text/cases" - "golang.org/x/text/language" - - "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/osdetect" - "github.com/abbott/hardn/pkg/status" - "github.com/abbott/hardn/pkg/style" - "github.com/abbott/hardn/pkg/utils" -) - -// Shared reader for all menus -var reader = bufio.NewReader(os.Stdin) - -// ReadInput reads a line of input from the user -func ReadInput() string { - input, _ := reader.ReadString('\n') - return strings.TrimSpace(input) -} - -// ReadKey reads a single key pressed by the user -func ReadKey() string { - // Configure terminal for raw input - exec.Command("stty", "-F", "/dev/tty", "cbreak", "min", "1").Run() - defer exec.Command("stty", "-F", "/dev/tty", "-cbreak").Run() - - // Read the first byte - var firstByte = make([]byte, 1) - os.Stdin.Read(firstByte) - - // If it's an escape character (27), read and discard the sequence - if firstByte[0] == 27 { - // Read and discard the next two bytes (common for arrow keys) - var discardBytes = make([]byte, 2) - os.Stdin.Read(discardBytes) - - // Return empty to indicate a special key was pressed - return "" - } - - return string(firstByte) -} - - -func RiskStatus(symbol string, color string, label string, status string, description string) string { - padding := strings.Repeat(" ", 6) // "Risk" is short, so hardcode reasonable padding - - return style.Colored(color, symbol) + " " + label + - padding + style.Bolded(status, color) + " " + style.Dimmed(description) -} - -// ShowMainMenu displays the main menu and handles user input -func ShowMainMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { - for { - utils.PrintLogo() - - // Define separator line - separator := "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" - sepWidth := len(separator) - - securityStatus, err := status.CheckSecurityStatus(cfg, osInfo) - var riskLevel, riskDescription, riskColor string - if err == nil { - riskLevel, riskDescription, riskColor = status.GetSecurityRiskLevel(securityStatus) - } - - // Prepare OS display information - var osDisplay string - if osInfo != nil { - titleCaser := cases.Title(language.English) // Title case formatter - - // If Proxmox is detected, treat it as the OS - if osInfo.IsProxmox { - osDisplay = " Proxmox " - } else { - // Convert OS type and codename to title case - osName := titleCaser.String(osInfo.OsType) - osCodename := titleCaser.String(osInfo.OsCodename) - - if osInfo.OsType == "alpine" { - osDisplay = fmt.Sprintf(" %s Linux %s ", osName, osInfo.OsVersion) - } else { - osDisplay = fmt.Sprintf(" %s %s ", osName, osCodename) - } - } - - // Remove ANSI codes for accurate length calculation - osDisplayStripped := style.StripAnsi(osDisplay) - osDisplayWidth := len(osDisplayStripped) - - // Calculate padding for centering OS display, accounting for spaces - leftPadding := (sepWidth - osDisplayWidth) / 2 - rightPadding := sepWidth - osDisplayWidth - leftPadding - - // Print centered OS display within the separator line - var envLine = separator[:leftPadding] + osDisplay + separator[:rightPadding] - - fmt.Println(style.Colored(style.Green, envLine)) - - } else { - // Print separator without OS info - fmt.Println(style.Bolded(separator, style.Green)) - } - - fmt.Println() - - // Create a formatter that includes all labels (including Risk) - formatter := style.NewStatusFormatter([]string{ - "Risk", - "SSH Root Login", - "Firewall", - "Users", - "SSH Port", - "SSH Auth", - "AppArmor", - "Auto Updates", - }, 2) // 2 spaces buffer - - // Format and print risk status using the same formatter, with bold label - if riskLevel != "" { - // Use bold for "Risk" label - boldRiskLabel := style.Bold + "Risk Level" + style.Reset - - riskDescription = style.SymApprox + " " + riskDescription - - // Display risk status with appropriate formatting - fmt.Println(formatter.FormatLine(style.SymDotTri, riskColor, boldRiskLabel, riskLevel, riskColor, riskDescription, "light")) - } - - fmt.Println() - - // Display detailed security status if available - if err == nil { - // Pass the formatter to the security status display to ensure consistent formatting - status.DisplaySecurityStatus(cfg, securityStatus, formatter) - } - - // Display dry-run mode if active - fmt.Println() - - // Format the dry-run mode status like other status lines - if cfg.DryRun { - fmt.Println(formatter.FormatLine(style.SymAsterisk, style.BrightGreen, "Dry-run Mode", "Enabled", style.BrightGreen, "", "light")) - } else { - fmt.Println(formatter.FormatLine(style.SymAsterisk, style.BrightYellow, "Dry-run Mode", "Disabled", style.BrightYellow, "", "light")) - } - - // Create menu options - menuOptions := []style.MenuOption{ - {Number: 1, Title: "Sudo User", Description: "Create non-root user with sudo access"}, - {Number: 2, Title: "Root SSH", Description: "Disable SSH access for root user"}, - {Number: 3, Title: "DNS", Description: "Configure DNS settings"}, - {Number: 4, Title: "Firewall", Description: "Configure UFW rules"}, - {Number: 5, Title: "Run All", Description: "Run all hardening operations"}, - {Number: 6, Title: "Dry-Run", Description: "Preview changes without applying them"}, - {Number: 7, Title: "Linux Packages", Description: "Install specified Linux packages"}, - {Number: 8, Title: "Python Packages", Description: "Install specified Python packages"}, - {Number: 9, Title: "Package Sources", Description: "Configure package source"}, - {Number: 10, Title: "Backup", Description: "Configure backup settings"}, - {Number: 11, Title: "Environment", Description: "Configure environment variable support"}, - {Number: 12, Title: "Logs", Description: "View log file"}, - {Number: 13, Title: "Help", Description: "View usage information"}, - } - - // Create and customize menu - menu := style.NewMenu("Select an option", menuOptions) - // Set custom exit option - menu.SetExitOption(style.MenuOption{ - Number: 0, - Title: "Exit", - Description: "Tip: Press 'q' to exit immediately", - }) - - // Display the menu - menu.Print() - - // First check if q is pressed immediately without Enter - firstKey := ReadKey() - if firstKey == "q" || firstKey == "Q" { - fmt.Println("q") - utils.PrintHeader() - fmt.Println("Hardn has exited.") - fmt.Println() - return - } - - // If firstKey is empty (like from an arrow key), try reading again - if firstKey == "" { - firstKey = ReadKey() - if firstKey == "" || firstKey == "q" || firstKey == "Q" { - fmt.Println("q") - utils.PrintHeader() - fmt.Println("Hardn has exited.") - fmt.Println() - return - } - } - - // Read the rest of the line with standard input - restKey := ReadInput() - - // Combine the inputs for the complete choice - choice := firstKey + restKey - - // Process the menu choice - switch choice { - case "1": - UserCreationMenu(cfg, osInfo) - case "2": - DisableRootMenu(cfg, osInfo) - case "3": - ConfigureDnsMenu(cfg, osInfo) - case "4": - UfwMenu(cfg, osInfo) - case "5": - RunAllHardeningMenu(cfg, osInfo) - case "6": - ToggleDryRunMenu(cfg) - case "7": - LinuxPackagesMenu(cfg, osInfo) - case "8": - PythonPackagesMenu(cfg, osInfo) - case "9": - UpdateSourcesMenu(cfg, osInfo) - case "10": - BackupOptionsMenu(cfg) - case "11": - EnvironmentSettingsMenu(cfg) - case "12": - ViewLogsMenu(cfg) - case "13": - HelpMenu() - case "0": - utils.PrintHeader() - fmt.Println("Hardn has exited.") - fmt.Println() - return - default: - utils.PrintHeader() - fmt.Printf("%s Invalid option. Please try again.\n", - style.Colored(style.Red, style.SymCrossMark)) - fmt.Printf("\n%s Press any key to continue...", style.BulletItem) - ReadKey() - } - } -} diff --git a/pkg/menu/python.go b/pkg/menu/python.go deleted file mode 100644 index 7e3f6de..0000000 --- a/pkg/menu/python.go +++ /dev/null @@ -1,155 +0,0 @@ -// pkg/menu/python.go - -package menu - -import ( - "fmt" - "os" - "strings" - - "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/logging" - "github.com/abbott/hardn/pkg/osdetect" - "github.com/abbott/hardn/pkg/packages" - "github.com/abbott/hardn/pkg/style" - "github.com/abbott/hardn/pkg/utils" -) - -// PythonPackagesMenu handles Python package installation and configuration -func PythonPackagesMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { - utils.PrintHeader() - fmt.Println(style.Bolded("Python Packages Installation", style.Blue)) - - // Get OS-specific package information - var packageDisplay string - if osInfo.OsType == "alpine" { - packageDisplay = fmt.Sprintf("Alpine Python packages: %s", - style.Colored(style.Cyan, strings.Join(cfg.AlpinePythonPackages, ", "))) - } else { - // For Debian/Ubuntu - allPackages := append([]string{}, cfg.PythonPackages...) - if os.Getenv("WSL") == "" { - allPackages = append(allPackages, cfg.NonWslPythonPackages...) - } - - packageDisplay = fmt.Sprintf("System Python packages: %s", - style.Colored(style.Cyan, strings.Join(allPackages, ", "))) - } - - // Display pip packages if available - pipPackageDisplay := "" - if len(cfg.PythonPipPackages) > 0 { - pipPackageDisplay = fmt.Sprintf("\n%s Pip packages: %s", - style.BulletItem, - style.Colored(style.Cyan, strings.Join(cfg.PythonPipPackages, ", "))) - } - - // Display current Python package management settings - fmt.Println() - fmt.Println(style.Bolded("Current Package Management Settings:", style.Blue)) - - // Create a formatter with the label we need - formatter := style.NewStatusFormatter([]string{"UV Package Manager"}, 2) - - // Show UV package manager status - if cfg.UseUvPackageManager { - fmt.Println(formatter.FormatSuccess("UV Package Manager", "Enabled", "Modern, fast package manager")) - } else { - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "UV Package Manager", "Disabled", - style.Yellow, "Using standard pip", "light")) - } - - // Show package information - fmt.Printf("\n%s %s", style.BulletItem, packageDisplay) - if pipPackageDisplay != "" { - fmt.Print(pipPackageDisplay) - } - - // Create menu options - var menuOptions []style.MenuOption - - // Toggle UV option - if cfg.UseUvPackageManager { - menuOptions = append(menuOptions, style.MenuOption{ - Number: 1, - Title: "Disable UV Package Manager", - Description: "Revert to standard pip for Python packages", - }) - } else { - menuOptions = append(menuOptions, style.MenuOption{ - Number: 1, - Title: "Enable UV Package Manager", - Description: "Use UV for faster Python package installation", - }) - } - - // Install packages option - menuOptions = append(menuOptions, style.MenuOption{ - Number: 2, - Title: "Install Python Packages", - Description: "Install all configured Python packages", - }) - - // Create menu - menu := style.NewMenu("Select an option", menuOptions) - menu.SetExitOption(style.MenuOption{ - Number: 0, - Title: "Return to main menu", - Description: "", - }) - - // Display menu - menu.Print() - - choice := ReadInput() - - switch choice { - case "1": - // Toggle UV package manager - if cfg.UseUvPackageManager { - cfg.UseUvPackageManager = false - fmt.Printf("\n%s UV package manager has been %s. Will use standard pip.\n", - style.Colored(style.Green, style.SymCheckMark), - style.Bolded("disabled", style.Yellow)) - } else { - cfg.UseUvPackageManager = true - fmt.Printf("\n%s UV package manager has been %s. Will use UV for Python packages.\n", - style.Colored(style.Green, style.SymCheckMark), - style.Bolded("enabled", style.Green)) - } - - // Save config changes - configFile := "hardn.yml" // Default config file - if err := config.SaveConfig(cfg, configFile); err != nil { - logging.LogError("Failed to save configuration: %v", err) - } - - // Return to this menu after toggling - fmt.Printf("\n%s Press any key to continue...", style.BulletItem) - ReadKey() - PythonPackagesMenu(cfg, osInfo) - - case "2": - // Install packages - fmt.Println("\nInstalling Python packages...") - fmt.Println(style.Dimmed("This may take some time. Please wait...")) - - if err := packages.InstallPythonPackages(cfg, osInfo); err != nil { - fmt.Printf("\n%s Failed to install Python packages: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } else { - fmt.Printf("\n%s Python packages installed successfully!\n", - style.Colored(style.Green, style.SymCheckMark)) - } - - case "0": - return - - default: - fmt.Printf("\n%s Invalid option. No changes were made.\n", - style.Colored(style.Yellow, style.SymWarning)) - } - - fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) - ReadKey() -} \ No newline at end of file diff --git a/pkg/menu/python_packages_menu.go b/pkg/menu/python_packages_menu.go new file mode 100644 index 0000000..8052c27 --- /dev/null +++ b/pkg/menu/python_packages_menu.go @@ -0,0 +1,222 @@ +// pkg/menu/python_packages_menu.go +package menu + +import ( + "fmt" + "os" + "strings" + + "github.com/abbott/hardn/pkg/application" + "github.com/abbott/hardn/pkg/config" + "github.com/abbott/hardn/pkg/osdetect" + "github.com/abbott/hardn/pkg/style" + "github.com/abbott/hardn/pkg/utils" +) + +// PythonPackagesMenu handles Python packages installation +type PythonPackagesMenu struct { + menuManager *application.MenuManager + config *config.Config + osInfo *osdetect.OSInfo +} + +// NewPythonPackagesMenu creates a new PythonPackagesMenu +func NewPythonPackagesMenu( + menuManager *application.MenuManager, + config *config.Config, + osInfo *osdetect.OSInfo, +) *PythonPackagesMenu { + return &PythonPackagesMenu{ + menuManager: menuManager, + config: config, + osInfo: osInfo, + } +} + +// Show displays the Python packages menu and handles user input +func (m *PythonPackagesMenu) Show() { + utils.PrintHeader() + fmt.Println(style.Bolded("Python Packages Installation", style.Blue)) + + // Get OS-specific package information + var packageDisplay string + if m.osInfo.OsType == "alpine" { + packageDisplay = fmt.Sprintf("Alpine Python packages: %s", + style.Colored(style.Cyan, strings.Join(m.config.AlpinePythonPackages, ", "))) + } else { + // For Debian/Ubuntu + allPackages := append([]string{}, m.config.PythonPackages...) + if os.Getenv("WSL") == "" { + allPackages = append(allPackages, m.config.NonWslPythonPackages...) + } + + packageDisplay = fmt.Sprintf("System Python packages: %s", + style.Colored(style.Cyan, strings.Join(allPackages, ", "))) + } + + // Display pip packages if available + pipPackageDisplay := "" + if len(m.config.PythonPipPackages) > 0 { + pipPackageDisplay = fmt.Sprintf("\n%s Pip packages: %s", + style.BulletItem, + style.Colored(style.Cyan, strings.Join(m.config.PythonPipPackages, ", "))) + } + + // Display current Python package management settings + fmt.Println() + fmt.Println(style.Bolded("Current Package Management Settings:", style.Blue)) + + // Create a formatter with the label we need + formatter := style.NewStatusFormatter([]string{"UV Package Manager"}, 2) + + // Show UV package manager status + if m.config.UseUvPackageManager { + fmt.Println(formatter.FormatSuccess("UV Package Manager", "Enabled", "Modern, fast package manager")) + } else { + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "UV Package Manager", "Disabled", + style.Yellow, "Using standard pip", "light")) + } + + // Show package information + fmt.Printf("\n%s %s", style.BulletItem, packageDisplay) + if pipPackageDisplay != "" { + fmt.Print(pipPackageDisplay) + } + + // Create menu options + var menuOptions []style.MenuOption + + // Toggle UV option + if m.config.UseUvPackageManager { + menuOptions = append(menuOptions, style.MenuOption{ + Number: 1, + Title: "Disable UV Package Manager", + Description: "Revert to standard pip for Python packages", + }) + } else { + menuOptions = append(menuOptions, style.MenuOption{ + Number: 1, + Title: "Enable UV Package Manager", + Description: "Use UV for faster Python package installation", + }) + } + + // Install packages option + menuOptions = append(menuOptions, style.MenuOption{ + Number: 2, + Title: "Install Python Packages", + Description: "Install all configured Python packages", + }) + + // Create menu + menu := style.NewMenu("Select an option", menuOptions) + menu.SetExitOption(style.MenuOption{ + Number: 0, + Title: "Return to main menu", + Description: "", + }) + + // Display menu + menu.Print() + + choice := ReadMenuInput() + + // Handle 'q' as a special exit case + if choice == "q" { + return + } + + switch choice { + case "1": + // Toggle UV package manager + m.config.UseUvPackageManager = !m.config.UseUvPackageManager + + if m.config.UseUvPackageManager { + fmt.Printf("\n%s UV package manager has been %s. Will use UV for Python packages.\n", + style.Colored(style.Green, style.SymCheckMark), + style.Bolded("enabled", style.Green)) + } else { + fmt.Printf("\n%s UV package manager has been %s. Will use standard pip.\n", + style.Colored(style.Green, style.SymCheckMark), + style.Bolded("disabled", style.Yellow)) + } + + // Save config changes + configFile := "hardn.yml" // Default config file + if err := config.SaveConfig(m.config, configFile); err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } + + // Return to this menu after toggling + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) + ReadKey() + m.Show() + + case "2": + // Install packages + fmt.Println("\nInstalling Python packages...") + fmt.Println(style.Dimmed("This may take some time. Please wait...")) + + if m.config.DryRun { + if m.osInfo.OsType == "alpine" { + fmt.Printf("\n%s [DRY-RUN] Would install Alpine Python packages: %s\n", + style.Colored(style.Green, style.SymInfo), + strings.Join(m.config.AlpinePythonPackages, ", ")) + } else { + allPackages := append([]string{}, m.config.PythonPackages...) + if os.Getenv("WSL") == "" { + allPackages = append(allPackages, m.config.NonWslPythonPackages...) + } + + fmt.Printf("\n%s [DRY-RUN] Would install Python packages: %s\n", + style.Colored(style.Green, style.SymInfo), + strings.Join(allPackages, ", ")) + + if len(m.config.PythonPipPackages) > 0 { + packageManager := "pip" + if m.config.UseUvPackageManager { + packageManager = "UV" + } + fmt.Printf("\n%s [DRY-RUN] Would install Pip packages using %s: %s\n", + style.Colored(style.Green, style.SymInfo), + packageManager, + strings.Join(m.config.PythonPipPackages, ", ")) + } + } + } else { + // Use the application layer through menuManager + var systemPackages []string + if m.osInfo.OsType == "alpine" { + systemPackages = m.config.AlpinePythonPackages + } else { + systemPackages = m.config.PythonPackages + if os.Getenv("WSL") == "" { + systemPackages = append(systemPackages, m.config.NonWslPythonPackages...) + } + } + + err := m.menuManager.InstallPythonPackages( + systemPackages, + m.config.PythonPipPackages, + m.config.UseUvPackageManager) + + if err != nil { + fmt.Printf("\n%s Failed to install Python packages: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } else { + fmt.Printf("\n%s Python packages installed successfully\n", + style.Colored(style.Green, style.SymCheckMark)) + } + } + case "0": + return + + default: + fmt.Printf("\n%s Invalid option. No changes were made.\n", + style.Colored(style.Yellow, style.SymWarning)) + } + + fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) + ReadKey() +} diff --git a/pkg/menu/root.go b/pkg/menu/root.go deleted file mode 100644 index 2a5a773..0000000 --- a/pkg/menu/root.go +++ /dev/null @@ -1,129 +0,0 @@ -// pkg/menu/root.go - -package menu - -import ( - "bufio" - "fmt" - "os" - "os/exec" - "strings" - - "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/osdetect" - "github.com/abbott/hardn/pkg/ssh" - "github.com/abbott/hardn/pkg/style" - "github.com/abbott/hardn/pkg/utils" -) - -// DisableRootMenu handles disabling root SSH access -func DisableRootMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { - utils.PrintHeader() - fmt.Println(style.Bolded("Disable Root SSH Access", style.Blue)) - - // Check current status of root SSH access - rootAccessEnabled := CheckRootLoginEnabled(osInfo) - - fmt.Println() - if rootAccessEnabled { - fmt.Printf("%s %s Root SSH access is currently %s\n", - style.Colored(style.Yellow, style.SymWarning), - style.Bolded("WARNING:"), - style.Bolded("ENABLED", style.Red)) - } else { - fmt.Printf("%s Root SSH access is already %s\n", - style.Colored(style.Green, style.SymCheckMark), - style.Bolded("DISABLED", style.Green)) - - fmt.Printf("\n%s Nothing to do. Press any key to return to the main menu...", style.BulletItem) - ReadKey() - return - } - - // Security warning - fmt.Println(style.Colored(style.Yellow, "\nBefore proceeding, ensure that:")) - fmt.Printf("%s You have created at least one non-root user with sudo privileges\n", style.BulletItem) - fmt.Printf("%s You have tested SSH access with this non-root user\n", style.BulletItem) - fmt.Printf("%s You have a backup method to access this system if SSH fails\n", style.BulletItem) - - // Create menu options - menuOptions := []style.MenuOption{ - {Number: 1, Title: "Disable root SSH access", Description: "Modify SSH config to prevent root login"}, - } - - // Create menu - menu := style.NewMenu("Select an option", menuOptions) - menu.SetExitOption(style.MenuOption{ - Number: 0, - Title: "Return to main menu", - Description: "Keep root SSH access enabled", - }) - - // Display menu - menu.Print() - - choice := ReadInput() - - switch choice { - case "1": - fmt.Println("\nDisabling root SSH access...") - err := ssh.DisableRootSSHAccess(cfg, osInfo) - if err == nil { - fmt.Printf("\n%s Root SSH access has been disabled\n", - style.Colored(style.Green, style.SymCheckMark)) - - // Restart SSH service - fmt.Println(style.Dimmed("Restarting SSH service...")) - if osInfo.OsType == "alpine" { - exec.Command("rc-service", "sshd", "restart").Run() - } else { - exec.Command("systemctl", "restart", "ssh").Run() - } - } else { - fmt.Printf("\n%s Failed to disable root SSH access: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } - case "0": - fmt.Println("\nOperation cancelled. Root SSH access remains enabled.") - default: - fmt.Printf("\n%s Invalid option. No changes were made.\n", - style.Colored(style.Yellow, style.SymWarning)) - } - - fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) - ReadKey() -} - -// CheckRootLoginEnabled checks if SSH root login is enabled -func CheckRootLoginEnabled(osInfo *osdetect.OSInfo) bool { - var sshConfigPath string - if osInfo.OsType == "alpine" { - sshConfigPath = "/etc/ssh/sshd_config" - } else { - // For Debian/Ubuntu, check both main config and config.d - sshConfigPath = "/etc/ssh/sshd_config" - if _, err := os.Stat("/etc/ssh/sshd_config.d/manage.conf"); err == nil { - sshConfigPath = "/etc/ssh/sshd_config.d/manage.conf" - } - } - - file, err := os.Open(sshConfigPath) - if err != nil { - return true // Assume vulnerable if can't check - } - defer file.Close() - - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "PermitRootLogin") { - fields := strings.Fields(line) - if len(fields) >= 2 && fields[1] == "no" { - return false - } - return true - } - } - - return true // Default to vulnerable if not explicitly set -} \ No newline at end of file diff --git a/pkg/menu/run_all.go b/pkg/menu/run_all.go deleted file mode 100644 index 04e763f..0000000 --- a/pkg/menu/run_all.go +++ /dev/null @@ -1,446 +0,0 @@ -// pkg/menu/run_all.go - -package menu - -import ( - "fmt" - "strings" - - "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/dns" - "github.com/abbott/hardn/pkg/firewall" - "github.com/abbott/hardn/pkg/logging" - "github.com/abbott/hardn/pkg/osdetect" - "github.com/abbott/hardn/pkg/packages" - "github.com/abbott/hardn/pkg/security" - "github.com/abbott/hardn/pkg/ssh" - "github.com/abbott/hardn/pkg/style" - "github.com/abbott/hardn/pkg/updates" - "github.com/abbott/hardn/pkg/user" - "github.com/abbott/hardn/pkg/utils" -) - -// RunAllHardeningMenu handles running all hardening steps -func RunAllHardeningMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { - utils.PrintHeader() - fmt.Println(style.Bolded("Run All Hardening Steps", style.Blue)) - - // Create a formatter for status - formatter := style.NewStatusFormatter([]string{"Dry-Run Mode", "Username", "SSH Port"}, 2) - - // Display current configuration status - fmt.Println() - fmt.Println(style.Bolded("Current Configuration:", style.Blue)) - - // Show dry-run status - if cfg.DryRun { - fmt.Println(formatter.FormatSuccess("Dry-Run Mode", "Enabled", "No actual changes will be made")) - } else { - fmt.Println(formatter.FormatWarning("Dry-Run Mode", "Disabled", "System will be modified!")) - } - - // Show username (or warn if not set) - if cfg.Username != "" { - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "Username", cfg.Username, style.Cyan, "", "light")) - } else { - fmt.Println(formatter.FormatWarning("Username", "Not set", "User creation will be skipped")) - } - - // Show SSH port - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "SSH Port", fmt.Sprintf("%d", cfg.SshPort), - style.Cyan, "", "light")) - - // Show enabled features - fmt.Println() - fmt.Println(style.Bolded("Enabled Features:", style.Blue)) - - featuresTable := []struct { - name string - enabled bool - desc string - }{ - {"AppArmor", cfg.EnableAppArmor, "Application control system"}, - {"Lynis", cfg.EnableLynis, "Security audit tool"}, - {"Unattended Upgrades", cfg.EnableUnattendedUpgrades, "Automatic security updates"}, - {"UFW SSH Policy", cfg.EnableUfwSshPolicy, "Firewall rules for SSH"}, - {"DNS Configuration", cfg.ConfigureDns, "DNS settings"}, - {"Root SSH Disable", cfg.DisableRoot, "Disable root SSH access"}, - } - - featuresFormatter := style.NewStatusFormatter([]string{"Feature"}, 2) - for _, feature := range featuresTable { - if feature.enabled { - fmt.Println(featuresFormatter.FormatSuccess("Feature: "+feature.name, "Enabled", feature.desc)) - } else { - fmt.Println(featuresFormatter.FormatLine(style.SymInfo, style.Yellow, "Feature: "+feature.name, - "Disabled", style.Yellow, feature.desc, "light")) - } - } - - // Security warning - fmt.Println() - fmt.Println(style.Bolded("SECURITY WARNING:", style.Red)) - fmt.Println(style.Colored(style.Yellow, "This will run ALL hardening steps on your system.")) - if !cfg.DryRun { - fmt.Println(style.Colored(style.Red, "Your system will be modified. This cannot be undone.")) - } else { - fmt.Println(style.Colored(style.Green, "Dry-run mode is enabled. Changes will only be simulated.")) - } - - // Create menu options - menuOptions := []style.MenuOption{ - {Number: 1, Title: "Run all hardening steps", Description: "Execute all configured hardening measures"}, - } - - // Add dry-run toggle option - if cfg.DryRun { - menuOptions = append(menuOptions, style.MenuOption{ - Number: 2, - Title: "Disable dry-run mode and run", - Description: "Apply real changes to the system", - }) - } else { - menuOptions = append(menuOptions, style.MenuOption{ - Number: 2, - Title: "Enable dry-run mode and run", - Description: "Simulate changes without applying them", - }) - } - - // Create menu - menu := style.NewMenu("Select an option", menuOptions) - menu.SetExitOption(style.MenuOption{ - Number: 0, - Title: "Return to main menu", - Description: "Cancel operation", - }) - - // Display menu - menu.Print() - - choice := ReadInput() - - switch choice { - case "1": - // Run with current settings - runAllHardening(cfg, osInfo) - case "2": - // Toggle dry-run mode and run - cfg.DryRun = !cfg.DryRun - if cfg.DryRun { - fmt.Printf("\n%s Dry-run mode has been %s\n", - style.Colored(style.Green, style.SymCheckMark), - style.Bolded("enabled", style.Green)) - } else { - fmt.Printf("\n%s Dry-run mode has been %s\n", - style.Colored(style.Yellow, style.SymWarning), - style.Bolded("disabled", style.Yellow)) - fmt.Println(style.Bolded("CAUTION: ", style.Red) + - style.Bolded("Your system will be modified!", style.Yellow)) - } - - // Confirm before proceeding with actual changes - if !cfg.DryRun { - fmt.Print("\nType 'yes' to confirm you want to apply real changes: ") - confirm := ReadInput() - if strings.ToLower(confirm) != "yes" { - fmt.Printf("\n%s Operation cancelled. No changes were made.\n", - style.Colored(style.Yellow, style.SymInfo)) - fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) - ReadKey() - return - } - } - - // Save config changes - configFile := "hardn.yml" // Default config file - if err := config.SaveConfig(cfg, configFile); err != nil { - logging.LogError("Failed to save configuration: %v", err) - } - - // Run with new dry-run setting - runAllHardening(cfg, osInfo) - case "0": - fmt.Println("\nOperation cancelled. No changes were made.") - return - default: - fmt.Printf("\n%s Invalid option. Please try again.\n", - style.Colored(style.Red, style.SymCrossMark)) - fmt.Printf("\n%s Press any key to continue...", style.BulletItem) - ReadKey() - RunAllHardeningMenu(cfg, osInfo) - return - } - - fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) - ReadKey() -} - -// runAllHardening executes all hardening steps -func runAllHardening(cfg *config.Config, osInfo *osdetect.OSInfo) { - utils.PrintLogo() - fmt.Println(style.Bolded("Executing All Hardening Steps", style.Blue)) - logging.LogInfo("Running complete system hardening...") - - // Track progress with step counting - totalSteps := 8 // Base steps, may increase based on enabled features - if cfg.EnableAppArmor { - totalSteps++ - } - if cfg.EnableLynis { - totalSteps++ - } - if cfg.EnableUnattendedUpgrades { - totalSteps++ - } - currentStep := 0 - - // Function to show progress - showProgress := func(stepName string) { - currentStep++ - fmt.Printf("\n%s [%d/%d] %s\n", - style.Colored(style.Cyan, style.SymArrowRight), - currentStep, - totalSteps, - style.Bolded(stepName, style.Cyan)) - } - - // 1. Setup hushlogin - showProgress("Setup basic configuration") - if err := utils.SetupHushlogin(cfg); err != nil { - fmt.Printf("%s Failed to setup hushlogin: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } else if !cfg.DryRun { - fmt.Printf("%s Hushlogin configured\n", - style.Colored(style.Green, style.SymCheckMark)) - } - - // 2. Update package repositories - showProgress("Update package repositories") - if err := packages.WriteSources(cfg, osInfo); err != nil { - fmt.Printf("%s Failed to configure package sources: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } else if !cfg.DryRun { - fmt.Printf("%s Package sources updated\n", - style.Colored(style.Green, style.SymCheckMark)) - } - - // 3. Handle Proxmox repositories if needed - if osInfo.OsType != "alpine" && osInfo.IsProxmox { - if err := packages.WriteProxmoxRepos(cfg, osInfo); err != nil { - fmt.Printf("%s Failed to configure Proxmox repositories: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } else if !cfg.DryRun { - fmt.Printf("%s Proxmox repositories configured\n", - style.Colored(style.Green, style.SymCheckMark)) - } - } - - // 4. Install packages - showProgress("Install system packages") - installSystemPackages(cfg, osInfo) - - // 5. Create user - showProgress("Configure user account") - if cfg.Username != "" { - if err := user.CreateUser(cfg.Username, cfg, osInfo); err != nil { - fmt.Printf("%s Failed to create user: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } else if !cfg.DryRun { - fmt.Printf("%s User '%s' configured\n", - style.Colored(style.Green, style.SymCheckMark), - cfg.Username) - } - } else { - fmt.Printf("%s No username specified, skipping user creation\n", - style.Colored(style.Yellow, style.SymWarning)) - } - - // 6. Configure SSH - showProgress("Configure SSH") - if err := ssh.WriteSSHConfig(cfg, osInfo); err != nil { - fmt.Printf("%s Failed to configure SSH: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } else if !cfg.DryRun { - fmt.Printf("%s SSH configured\n", - style.Colored(style.Green, style.SymCheckMark)) - } - - // 7. Disable root SSH access if requested - showProgress("Configure root SSH access") - if cfg.DisableRoot { - if err := ssh.DisableRootSSHAccess(cfg, osInfo); err != nil { - fmt.Printf("%s Failed to disable root SSH access: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } else if !cfg.DryRun { - fmt.Printf("%s Root SSH access disabled\n", - style.Colored(style.Green, style.SymCheckMark)) - } - } else { - fmt.Printf("%s Root SSH access will remain enabled (not configured to disable)\n", - style.Colored(style.Yellow, style.SymInfo)) - } - - // 8. Configure UFW - showProgress("Configure firewall") - if cfg.EnableUfwSshPolicy { - if err := firewall.ConfigureUFW(cfg, osInfo); err != nil { - fmt.Printf("%s Failed to configure firewall: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } else if !cfg.DryRun { - fmt.Printf("%s Firewall configured\n", - style.Colored(style.Green, style.SymCheckMark)) - } - } else { - fmt.Printf("%s Firewall configuration skipped (not enabled in config)\n", - style.Colored(style.Yellow, style.SymInfo)) - } - - // 9. Configure DNS - showProgress("Configure DNS") - if cfg.ConfigureDns { - if err := dns.ConfigureDNS(cfg, osInfo); err != nil { - fmt.Printf("%s Failed to configure DNS: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } else if !cfg.DryRun { - fmt.Printf("%s DNS configured\n", - style.Colored(style.Green, style.SymCheckMark)) - } - } else { - fmt.Printf("%s DNS configuration skipped (not enabled in config)\n", - style.Colored(style.Yellow, style.SymInfo)) - } - - // 10. Setup AppArmor if enabled - if cfg.EnableAppArmor { - showProgress("Configure AppArmor") - if err := security.SetupAppArmor(cfg, osInfo); err != nil { - fmt.Printf("%s Failed to configure AppArmor: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } else if !cfg.DryRun { - fmt.Printf("%s AppArmor configured\n", - style.Colored(style.Green, style.SymCheckMark)) - } - } - - // 11. Setup Lynis if enabled - if cfg.EnableLynis { - showProgress("Install Lynis security audit") - if err := security.SetupLynis(cfg, osInfo); err != nil { - fmt.Printf("%s Failed to configure Lynis: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } else if !cfg.DryRun { - fmt.Printf("%s Lynis installed and audit completed\n", - style.Colored(style.Green, style.SymCheckMark)) - } - } - - // 12. Setup unattended upgrades if enabled - if cfg.EnableUnattendedUpgrades { - showProgress("Configure automatic updates") - if err := updates.SetupUnattendedUpgrades(cfg, osInfo); err != nil { - fmt.Printf("%s Failed to configure unattended upgrades: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } else if !cfg.DryRun { - fmt.Printf("%s Automatic updates configured\n", - style.Colored(style.Green, style.SymCheckMark)) - } - } - - // Final status - fmt.Println() - if cfg.DryRun { - fmt.Printf("%s System hardening %s (DRY-RUN)\n", - style.Colored(style.Green, style.SymCheckMark), - style.Bolded("simulation completed", style.Green)) - fmt.Println(style.Dimmed("No actual changes were made to your system.")) - } else { - fmt.Printf("%s System hardening %s\n", - style.Colored(style.Green, style.SymCheckMark), - style.Bolded("completed successfully", style.Green)) - } - - logging.LogSuccess("System hardening completed") - fmt.Printf("\n%s Check the log file at %s for details\n", - style.Colored(style.Cyan, style.SymInfo), - style.Colored(style.Cyan, cfg.LogFile)) -} - -// Helper function to install system packages -func installSystemPackages(cfg *config.Config, osInfo *osdetect.OSInfo) { - if osInfo.OsType == "alpine" { - // Install Alpine packages - if len(cfg.AlpineCorePackages) > 0 { - fmt.Printf("%s Installing Alpine core packages...\n", style.BulletItem) - if !cfg.DryRun { - packages.InstallPackages(cfg.AlpineCorePackages, osInfo, cfg) - } - } - - // Check subnet to determine which package sets to install - isDmz, _ := utils.CheckSubnet(cfg.DmzSubnet) - if isDmz { - if len(cfg.AlpineDmzPackages) > 0 { - fmt.Printf("%s Installing Alpine DMZ packages...\n", style.BulletItem) - if !cfg.DryRun { - packages.InstallPackages(cfg.AlpineDmzPackages, osInfo, cfg) - } - } - } else { - if len(cfg.AlpineDmzPackages) > 0 { - fmt.Printf("%s Installing Alpine DMZ packages...\n", style.BulletItem) - if !cfg.DryRun { - packages.InstallPackages(cfg.AlpineDmzPackages, osInfo, cfg) - } - } - - if len(cfg.AlpineLabPackages) > 0 { - fmt.Printf("%s Installing Alpine LAB packages...\n", style.BulletItem) - if !cfg.DryRun { - packages.InstallPackages(cfg.AlpineLabPackages, osInfo, cfg) - } - } - } - - // Install Python packages if defined - if len(cfg.AlpinePythonPackages) > 0 { - fmt.Printf("%s Installing Alpine Python packages...\n", style.BulletItem) - if !cfg.DryRun { - packages.InstallPackages(cfg.AlpinePythonPackages, osInfo, cfg) - } - } - } else { - // Install core Linux packages first - if len(cfg.LinuxCorePackages) > 0 { - fmt.Printf("%s Installing Linux core packages...\n", style.BulletItem) - if !cfg.DryRun { - packages.InstallPackages(cfg.LinuxCorePackages, osInfo, cfg) - } - } - - // Check subnet to determine which package sets to install - isDmz, _ := utils.CheckSubnet(cfg.DmzSubnet) - if isDmz { - if len(cfg.LinuxDmzPackages) > 0 { - fmt.Printf("%s Installing Debian DMZ packages...\n", style.BulletItem) - if !cfg.DryRun { - packages.InstallPackages(cfg.LinuxDmzPackages, osInfo, cfg) - } - } - } else { - // Install both - if len(cfg.LinuxDmzPackages) > 0 { - fmt.Printf("%s Installing Debian DMZ packages...\n", style.BulletItem) - if !cfg.DryRun { - packages.InstallPackages(cfg.LinuxDmzPackages, osInfo, cfg) - } - } - if len(cfg.LinuxLabPackages) > 0 { - fmt.Printf("%s Installing Debian Lab packages...\n", style.BulletItem) - if !cfg.DryRun { - packages.InstallPackages(cfg.LinuxLabPackages, osInfo, cfg) - } - } - } - } -} \ No newline at end of file diff --git a/pkg/menu/run_all_menu.go b/pkg/menu/run_all_menu.go new file mode 100644 index 0000000..a6efc0f --- /dev/null +++ b/pkg/menu/run_all_menu.go @@ -0,0 +1,417 @@ +// pkg/menu/run_all_menu.go +package menu + +import ( + "fmt" + "strings" + + "github.com/abbott/hardn/pkg/application" + "github.com/abbott/hardn/pkg/config" + "github.com/abbott/hardn/pkg/domain/model" + "github.com/abbott/hardn/pkg/osdetect" + "github.com/abbott/hardn/pkg/style" + "github.com/abbott/hardn/pkg/utils" +) + +// RunAllMenu handles the "Run All Hardening" functionality through the new architecture +type RunAllMenu struct { + menuManager *application.MenuManager + config *config.Config + osInfo *osdetect.OSInfo +} + +// NewRunAllMenu creates a new RunAllMenu +func NewRunAllMenu( + menuManager *application.MenuManager, + config *config.Config, + osInfo *osdetect.OSInfo, +) *RunAllMenu { + return &RunAllMenu{ + menuManager: menuManager, + config: config, + osInfo: osInfo, + } +} + +// Show displays the Run All menu and handles user input +func (m *RunAllMenu) Show() { + utils.PrintHeader() + fmt.Println(style.Bolded("Run All Hardening Steps", style.Blue)) + + // Create a formatter for status + formatter := style.NewStatusFormatter([]string{"Dry-Run Mode", "Username", "SSH Port"}, 2) + + // Display current configuration status + fmt.Println() + fmt.Println(style.Bolded("Current Configuration:", style.Blue)) + + // Show dry-run status + if m.config.DryRun { + fmt.Println(formatter.FormatSuccess("Dry-Run Mode", "Enabled", "No actual changes will be made")) + } else { + fmt.Println(formatter.FormatWarning("Dry-Run Mode", "Disabled", "System will be modified!")) + } + + // Show username (or warn if not set) + if m.config.Username != "" { + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "Username", m.config.Username, style.Cyan, "", "light")) + } else { + fmt.Println(formatter.FormatWarning("Username", "Not set", "User creation will be skipped")) + } + + // Show SSH port + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "SSH Port", fmt.Sprintf("%d", m.config.SshPort), + style.Cyan, "", "light")) + + // Show enabled features + fmt.Println() + fmt.Println(style.Bolded("Enabled Features:", style.Blue)) + + featuresTable := []struct { + name string + enabled bool + desc string + }{ + {"AppArmor", m.config.EnableAppArmor, "Application control system"}, + {"Lynis", m.config.EnableLynis, "Security audit tool"}, + {"Unattended Upgrades", m.config.EnableUnattendedUpgrades, "Automatic security updates"}, + {"UFW SSH Policy", m.config.EnableUfwSshPolicy, "Firewall rules for SSH"}, + {"DNS Configuration", m.config.ConfigureDns, "DNS settings"}, + {"Root SSH Disable", m.config.DisableRoot, "Disable root SSH access"}, + } + + featuresFormatter := style.NewStatusFormatter([]string{"Feature"}, 2) + for _, feature := range featuresTable { + if feature.enabled { + fmt.Println(featuresFormatter.FormatSuccess("Feature: "+feature.name, "Enabled", feature.desc)) + } else { + fmt.Println(featuresFormatter.FormatLine(style.SymInfo, style.Yellow, "Feature: "+feature.name, + "Disabled", style.Yellow, feature.desc, "light")) + } + } + + // Security warning + fmt.Println() + fmt.Println(style.Bolded("SECURITY WARNING:", style.Red)) + fmt.Println(style.Colored(style.Yellow, "This will run ALL hardening steps on your system.")) + if !m.config.DryRun { + fmt.Println(style.Colored(style.Red, "Your system will be modified. This cannot be undone.")) + } else { + fmt.Println(style.Colored(style.Green, "Dry-run mode is enabled. Changes will only be simulated.")) + } + + // Create menu options + menuOptions := []style.MenuOption{ + {Number: 1, Title: "Run all hardening steps", Description: "Execute all configured hardening measures"}, + } + + // Add dry-run toggle option + if m.config.DryRun { + menuOptions = append(menuOptions, style.MenuOption{ + Number: 2, + Title: "Disable dry-run mode and run", + Description: "Apply real changes to the system", + }) + } else { + menuOptions = append(menuOptions, style.MenuOption{ + Number: 2, + Title: "Enable dry-run mode and run", + Description: "Simulate changes without applying them", + }) + } + + // Create menu + menu := style.NewMenu("Select an option", menuOptions) + menu.SetExitOption(style.MenuOption{ + Number: 0, + Title: "Return to main menu", + Description: "Cancel operation", + }) + + // Display menu + menu.Print() + + choice := ReadMenuInput() + + // Handle 'q' as a special exit case + if choice == "q" { + return + } + + switch choice { + case "1": + // Run with current settings + m.runAllHardening() + case "2": + // Toggle dry-run mode and run + m.config.DryRun = !m.config.DryRun + if m.config.DryRun { + fmt.Printf("\n%s Dry-run mode has been %s\n", + style.Colored(style.Green, style.SymCheckMark), + style.Bolded("enabled", style.Green)) + } else { + fmt.Printf("\n%s Dry-run mode has been %s\n", + style.Colored(style.Yellow, style.SymWarning), + style.Bolded("disabled", style.Yellow)) + fmt.Println(style.Bolded("CAUTION: ", style.Red) + + style.Bolded("Your system will be modified!", style.Yellow)) + } + + // Confirm before proceeding with actual changes + if !m.config.DryRun { + fmt.Print("\nType 'yes' to confirm you want to apply real changes: ") + confirm := ReadInput() + if strings.ToLower(confirm) != "yes" { + fmt.Printf("\n%s Operation cancelled. No changes were made.\n", + style.Colored(style.Yellow, style.SymInfo)) + fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) + ReadKey() + return + } + } + + // Save config changes + configFile := "hardn.yml" // Default config file + if err := config.SaveConfig(m.config, configFile); err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } + + // Run with new dry-run setting + m.runAllHardening() + case "0": + fmt.Println("\nOperation cancelled. No changes were made.") + return + default: + fmt.Printf("\n%s Invalid option. Please try again.\n", + style.Colored(style.Red, style.SymCrossMark)) + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) + ReadKey() + m.Show() + return + } + + fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) + ReadKey() +} + +// runAllHardening uses the MenuManager to execute all hardening steps +func (m *RunAllMenu) runAllHardening() { + utils.PrintLogo() + fmt.Println(style.Bolded("Executing All Hardening Steps", style.Blue)) + + // Build a comprehensive HardeningConfig from current configuration + hardening := model.HardeningConfig{ + CreateUser: m.config.Username != "", + Username: m.config.Username, + SudoNoPassword: m.config.SudoNoPassword, + SshKeys: m.config.SshKeys, + SshPort: m.config.SshPort, + SshListenAddresses: []string{m.config.SshListenAddress}, + SshAllowedUsers: m.config.SshAllowedUsers, + EnableFirewall: m.config.EnableUfwSshPolicy, + AllowedPorts: m.config.UfwAllowedPorts, + ConfigureDns: m.config.ConfigureDns, + Nameservers: m.config.Nameservers, + EnableAppArmor: m.config.EnableAppArmor, + EnableLynis: m.config.EnableLynis, + EnableUnattendedUpgrades: m.config.EnableUnattendedUpgrades, + } + + // Track progress with step counting + totalSteps := calculateTotalSteps(&hardening) + currentStep := 0 + + // Function to show progress + showProgress := func(stepName string) { + currentStep++ + fmt.Printf("\n%s [%d/%d] %s\n", + style.Colored(style.Cyan, style.SymArrowRight), + currentStep, + totalSteps, + style.Bolded(stepName, style.Cyan)) + } + + // Begin hardening steps + showProgress("Preparing system hardening") + + if m.config.DryRun { + // In dry-run mode, show what would happen + // updateRepositories := true + // installPackages := true + useUvPackageManager := m.config.UseUvPackageManager + dryRunHardening(&hardening, showProgress, m.osInfo.IsProxmox, useUvPackageManager) + } else { + // Execute the hardening through the MenuManager + err := m.menuManager.HardenSystem(&hardening) + + if err != nil { + fmt.Printf("\n%s System hardening failed: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + return + } + + // Show steps completed when not in dry-run mode + if hardening.CreateUser { + showProgress("User account configured") + } + + showProgress("SSH configuration completed") + + if hardening.EnableFirewall { + showProgress("Firewall configured") + } + + if hardening.ConfigureDns { + showProgress("DNS settings applied") + } + + if hardening.EnableAppArmor { + showProgress("AppArmor configured") + } + + if hardening.EnableLynis { + showProgress("Lynis security audit completed") + } + + if hardening.EnableUnattendedUpgrades { + showProgress("Automatic updates configured") + } + } + + // Final status + fmt.Println() + if m.config.DryRun { + fmt.Printf("%s System hardening %s (DRY-RUN)\n", + style.Colored(style.Green, style.SymCheckMark), + style.Bolded("simulation completed", style.Green)) + fmt.Println(style.Dimmed("No actual changes were made to your system.")) + } else { + fmt.Printf("%s System hardening %s\n", + style.Colored(style.Green, style.SymCheckMark), + style.Bolded("completed successfully", style.Green)) + } + + fmt.Printf("\n%s Check the log file at %s for details\n", + style.Colored(style.Cyan, style.SymInfo), + style.Colored(style.Cyan, m.config.LogFile)) +} + +// calculateTotalSteps determines the total number of hardening steps +func calculateTotalSteps(config *model.HardeningConfig) int { + // Start with base steps (always performed) + totalSteps := 7 // Preparation, repositories, packages, Python packages, SSH config, completion + + // Add optional steps + if config.CreateUser { + totalSteps++ + } + + if config.EnableFirewall { + totalSteps++ + } + + if config.ConfigureDns { + totalSteps++ + } + + if config.EnableAppArmor { + totalSteps++ + } + + if config.EnableLynis { + totalSteps++ + } + + if config.EnableUnattendedUpgrades { + totalSteps++ + } + + return totalSteps +} + +// dryRunHardening simulates the hardening process without making changes +func dryRunHardening(config *model.HardeningConfig, showProgress func(string), isProxmox bool, useUvPackageManager bool) { + // Simulate user creation + if config.CreateUser { + showProgress("Simulating user account creation") + fmt.Printf("%s Would create user '%s' with sudo %s\n", + style.BulletItem, + config.Username, + map[bool]string{true: "without password", false: "with password"}[config.SudoNoPassword]) + + if len(config.SshKeys) > 0 { + fmt.Printf("%s Would configure %d SSH keys\n", + style.BulletItem, + len(config.SshKeys)) + } + } + + // Simulate package repository update + showProgress("Simulating package repository update") + fmt.Printf("%s Would update package sources for system\n", style.BulletItem) + + if isProxmox { + fmt.Printf("%s Would configure Proxmox-specific repositories\n", style.BulletItem) + } + + // Simulate package installation + showProgress("Simulating package installation") + fmt.Printf("%s Would install core system packages\n", style.BulletItem) + + // Check if DMZ subnet is detected (this is a simulation) + fmt.Printf("%s Would determine network environment (DMZ vs. Lab)\n", style.BulletItem) + fmt.Printf("%s Would install appropriate packages for environment\n", style.BulletItem) + + // Simulate Python package installation + showProgress("Simulating Python package installation") + packageManager := "pip" + if useUvPackageManager { + packageManager = "UV" + } + fmt.Printf("%s Would install Python packages with %s\n", + style.BulletItem, + packageManager) + + // Simulate SSH configuration + showProgress("Simulating SSH configuration") + fmt.Printf("%s Would configure SSH on port %d\n", + style.BulletItem, + config.SshPort) + + // Simulate firewall configuration + if config.EnableFirewall { + showProgress("Simulating firewall configuration") + fmt.Printf("%s Would configure firewall to allow SSH on port %d\n", + style.BulletItem, + config.SshPort) + } + + // Simulate DNS configuration + if config.ConfigureDns { + showProgress("Simulating DNS configuration") + if len(config.Nameservers) > 0 { + fmt.Printf("%s Would configure nameservers: %s\n", + style.BulletItem, + strings.Join(config.Nameservers, ", ")) + } + } + + // Simulate AppArmor setup + if config.EnableAppArmor { + showProgress("Simulating AppArmor configuration") + fmt.Printf("%s Would install and activate AppArmor\n", style.BulletItem) + } + + // Simulate Lynis installation + if config.EnableLynis { + showProgress("Simulating Lynis security audit") + fmt.Printf("%s Would install and run Lynis security audit\n", style.BulletItem) + } + + // Simulate unattended upgrades setup + if config.EnableUnattendedUpgrades { + showProgress("Simulating automatic updates configuration") + fmt.Printf("%s Would configure unattended security updates\n", style.BulletItem) + } +} diff --git a/pkg/menu/sources.go b/pkg/menu/sources_menu.go similarity index 66% rename from pkg/menu/sources.go rename to pkg/menu/sources_menu.go index 7858192..810efba 100644 --- a/pkg/menu/sources.go +++ b/pkg/menu/sources_menu.go @@ -1,4 +1,4 @@ -// pkg/menu/sources.go +// pkg/menu/sources_menu.go package menu @@ -7,110 +7,133 @@ import ( "os" "strings" + "github.com/abbott/hardn/pkg/application" "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/logging" "github.com/abbott/hardn/pkg/osdetect" - "github.com/abbott/hardn/pkg/packages" "github.com/abbott/hardn/pkg/style" "github.com/abbott/hardn/pkg/utils" + "golang.org/x/text/cases" + "golang.org/x/text/language" ) -// UpdateSourcesMenu handles configuration of package sources -func UpdateSourcesMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { +// SourcesMenu handles configuration of package sources +type SourcesMenu struct { + menuManager *application.MenuManager + config *config.Config + osInfo *osdetect.OSInfo +} + +// NewSourcesMenu creates a new SourcesMenu +func NewSourcesMenu( + menuManager *application.MenuManager, + config *config.Config, + osInfo *osdetect.OSInfo, +) *SourcesMenu { + return &SourcesMenu{ + menuManager: menuManager, + config: config, + osInfo: osInfo, + } +} + +// Show displays the sources menu and handles user input +func (m *SourcesMenu) Show() { utils.PrintHeader() fmt.Println(style.Bolded("Package Sources Configuration", style.Blue)) // Display current OS info fmt.Println() fmt.Println(style.Bolded("System Information:", style.Blue)) - + // Create formatter for status display formatter := style.NewStatusFormatter([]string{"OS Type", "Version", "Codename", "Proxmox"}, 2) - + // Show OS type - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "OS Type", - strings.Title(osInfo.OsType), style.Cyan, "", "light")) - + osName := cases.Title(language.English).String(m.osInfo.OsType) + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "OS Type", + osName, style.Cyan, "", "light")) + // Show OS version - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "Version", - osInfo.OsVersion, style.Cyan, "", "light")) - + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "Version", + m.osInfo.OsVersion, style.Cyan, "", "light")) + // Show OS codename (if not Alpine) - if osInfo.OsType != "alpine" { - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "Codename", - strings.Title(osInfo.OsCodename), style.Cyan, "", "light")) + if m.osInfo.OsType != "alpine" { + osCodename := cases.Title(language.English).String(m.osInfo.OsCodename) + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "Codename", + osCodename, style.Cyan, "", "light")) } - + // Show Proxmox status proxmoxStatus := "No" - if osInfo.IsProxmox { + if m.osInfo.IsProxmox { proxmoxStatus = "Yes" } - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "Proxmox", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "Proxmox", proxmoxStatus, style.Cyan, "", "light")) - + // Display current source configuration fmt.Println() fmt.Println(style.Bolded("Current Source Configuration:", style.Blue)) - - if osInfo.OsType == "alpine" { + + if m.osInfo.OsType == "alpine" { // Show Alpine repository status - showAlpineRepositories(cfg) + m.showAlpineRepositories() } else { // Show Debian/Ubuntu repository status - showDebianRepositories(cfg, osInfo) + m.showDebianRepositories() } // Create menu options based on OS type var menuOptions []style.MenuOption - - if osInfo.OsType == "alpine" { + + if m.osInfo.OsType == "alpine" { // Alpine specific options menuOptions = append(menuOptions, style.MenuOption{ - Number: 1, - Title: "Update Alpine repositories", + Number: 1, + Title: "Update Alpine repositories", Description: "Configure main and community repositories", }) - + // Testing repository toggle - if cfg.AlpineTestingRepo { + if m.config.AlpineTestingRepo { menuOptions = append(menuOptions, style.MenuOption{ - Number: 2, - Title: "Disable testing repository", + Number: 2, + Title: "Disable testing repository", Description: "Remove edge/testing repository", }) } else { menuOptions = append(menuOptions, style.MenuOption{ - Number: 2, - Title: "Enable testing repository", + Number: 2, + Title: "Enable testing repository", Description: "Add edge/testing repository (not recommended for production)", }) } } else { // Debian/Ubuntu options menuOptions = append(menuOptions, style.MenuOption{ - Number: 1, - Title: "Update package sources", + Number: 1, + Title: "Update package sources", Description: "Configure main system repositories", }) - + // Proxmox specific options - if osInfo.IsProxmox { + if m.osInfo.IsProxmox { menuOptions = append(menuOptions, style.MenuOption{ - Number: 2, - Title: "Configure Proxmox repositories", + Number: 2, + Title: "Configure Proxmox repositories", Description: "Set up Proxmox-specific repositories", }) } - + // Add option to edit sources menuOptions = append(menuOptions, style.MenuOption{ - Number: 3, - Title: "Edit repositories", + Number: 3, + Title: "Edit repositories", Description: "Modify repository configuration", }) } - + // Create menu menu := style.NewMenu("Select an option", menuOptions) menu.SetExitOption(style.MenuOption{ @@ -118,124 +141,129 @@ func UpdateSourcesMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { Title: "Return to main menu", Description: "", }) - + // Display menu menu.Print() - - choice := ReadInput() - + + choice := ReadMenuInput() + + // Handle 'q' as a special exit case + if choice == "q" { + return + } + switch choice { case "1": // Update main repositories fmt.Println("\nUpdating package sources...") - - if cfg.DryRun { - fmt.Printf("%s [DRY-RUN] Would update package sources for %s\n", - style.BulletItem, osInfo.OsType) + + if m.config.DryRun { + fmt.Printf("%s [DRY-RUN] Would update package sources for %s\n", + style.BulletItem, m.osInfo.OsType) } else { - if err := packages.WriteSources(cfg, osInfo); err != nil { - fmt.Printf("\n%s Failed to update package sources: %v\n", + // Use application layer to update sources + if err := m.menuManager.UpdatePackageSources(); err != nil { + fmt.Printf("\n%s Failed to update package sources: %v\n", style.Colored(style.Red, style.SymCrossMark), err) - logging.LogError("Failed to update package sources: %v", err) } else { - fmt.Printf("\n%s Package sources updated successfully\n", + fmt.Printf("\n%s Package sources updated successfully\n", style.Colored(style.Green, style.SymCheckMark)) } } - + case "2": - if osInfo.OsType == "alpine" { + if m.osInfo.OsType == "alpine" { // Toggle Alpine testing repository - cfg.AlpineTestingRepo = !cfg.AlpineTestingRepo - - if cfg.AlpineTestingRepo { - fmt.Printf("\n%s Alpine testing repository %s\n", + m.config.AlpineTestingRepo = !m.config.AlpineTestingRepo + + if m.config.AlpineTestingRepo { + fmt.Printf("\n%s Alpine testing repository %s\n", style.Colored(style.Yellow, style.SymWarning), style.Bolded("enabled", style.Green)) - fmt.Printf("%s WARNING: Testing repositories may contain unstable packages\n", + fmt.Printf("%s WARNING: Testing repositories may contain unstable packages\n", style.Colored(style.Yellow, style.SymWarning)) } else { - fmt.Printf("\n%s Alpine testing repository %s\n", + fmt.Printf("\n%s Alpine testing repository %s\n", style.Colored(style.Green, style.SymCheckMark), style.Bolded("disabled", style.Green)) } - + // Save config changes - saveSourcesConfig(cfg) - + m.saveSourcesConfig() + // Apply the change - if !cfg.DryRun { - if err := packages.WriteSources(cfg, osInfo); err != nil { - fmt.Printf("\n%s Failed to update package sources: %v\n", + if !m.config.DryRun { + // Use application layer to update sources + if err := m.menuManager.UpdatePackageSources(); err != nil { + fmt.Printf("\n%s Failed to update package sources: %v\n", style.Colored(style.Red, style.SymCrossMark), err) - logging.LogError("Failed to update package sources: %v", err) } else { - fmt.Printf("%s Package sources updated successfully\n", + fmt.Printf("%s Package sources updated successfully\n", style.Colored(style.Green, style.SymCheckMark)) } } - } else if osInfo.IsProxmox { + } else if m.osInfo.IsProxmox { // Configure Proxmox repositories fmt.Println("\nConfiguring Proxmox repositories...") - - if cfg.DryRun { + + if m.config.DryRun { fmt.Printf("%s [DRY-RUN] Would configure Proxmox repositories\n", style.BulletItem) } else { - if err := packages.WriteProxmoxRepos(cfg, osInfo); err != nil { - fmt.Printf("\n%s Failed to configure Proxmox repositories: %v\n", + // Use application layer to update Proxmox sources + if err := m.menuManager.UpdateProxmoxSources(); err != nil { + fmt.Printf("\n%s Failed to configure Proxmox repositories: %v\n", style.Colored(style.Red, style.SymCrossMark), err) - logging.LogError("Failed to configure Proxmox repositories: %v", err) } else { - fmt.Printf("\n%s Proxmox repositories configured successfully\n", + fmt.Printf("\n%s Proxmox repositories configured successfully\n", style.Colored(style.Green, style.SymCheckMark)) fmt.Printf("%s Created /etc/apt/sources.list.d/ceph.list\n", style.BulletItem) fmt.Printf("%s Created /etc/apt/sources.list.d/pve-enterprise.list\n", style.BulletItem) } } } else { - fmt.Printf("\n%s Invalid option for this OS type\n", + fmt.Printf("\n%s Invalid option for this OS type\n", style.Colored(style.Red, style.SymCrossMark)) } - + case "3": - if osInfo.OsType != "alpine" { + if m.osInfo.OsType != "alpine" { // Edit repositories submenu - editRepositoriesMenu(cfg, osInfo) - UpdateSourcesMenu(cfg, osInfo) + m.editRepositoriesMenu() + m.Show() return } else { - fmt.Printf("\n%s Invalid option for this OS type\n", + fmt.Printf("\n%s Invalid option for this OS type\n", style.Colored(style.Red, style.SymCrossMark)) } - + case "0": // Return to main menu return - + default: - fmt.Printf("\n%s Invalid option. Please try again.\n", + fmt.Printf("\n%s Invalid option. Please try again.\n", style.Colored(style.Red, style.SymCrossMark)) - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - UpdateSourcesMenu(cfg, osInfo) + m.Show() return } - + fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) ReadKey() } // Helper function to show Alpine repositories -func showAlpineRepositories(cfg *config.Config) { +func (m *SourcesMenu) showAlpineRepositories() { // Check if repositories file exists reposFile := "/etc/apk/repositories" reposContent := "" - + if data, err := os.ReadFile(reposFile); err == nil { reposContent = string(data) } - + // Display repositories if reposContent != "" { lines := strings.Split(reposContent, "\n") @@ -244,10 +272,10 @@ func showAlpineRepositories(cfg *config.Config) { if line == "" || strings.HasPrefix(line, "#") { continue } - + // Color differently for testing repo if strings.Contains(line, "edge/testing") { - fmt.Printf("%s %s\n", + fmt.Printf("%s %s\n", style.Colored(style.Yellow, style.SymWarning), style.Colored(style.Yellow, line)) } else { @@ -255,31 +283,31 @@ func showAlpineRepositories(cfg *config.Config) { } } } else { - fmt.Printf("%s Could not read %s\n", + fmt.Printf("%s Could not read %s\n", style.Colored(style.Yellow, style.SymWarning), reposFile) } - + // Show testing repository flag fmt.Println() - if cfg.AlpineTestingRepo { - fmt.Printf("%s Testing repository: %s\n", + if m.config.AlpineTestingRepo { + fmt.Printf("%s Testing repository: %s\n", style.BulletItem, style.Colored(style.Yellow, "Enabled")) } else { - fmt.Printf("%s Testing repository: %s\n", + fmt.Printf("%s Testing repository: %s\n", style.BulletItem, style.Colored(style.Green, "Disabled")) } } // Helper function to show Debian/Ubuntu repositories -func showDebianRepositories(cfg *config.Config, osInfo *osdetect.OSInfo) { +func (m *SourcesMenu) showDebianRepositories() { // Check if sources file exists sourcesFile := "/etc/apt/sources.list" sourcesContent := "" - + if data, err := os.ReadFile(sourcesFile); err == nil { sourcesContent = string(data) } - + // Show main sources if sourcesContent != "" { fmt.Printf("%s %s:\n", style.BulletItem, style.Bolded("Main sources", style.Cyan)) @@ -289,19 +317,19 @@ func showDebianRepositories(cfg *config.Config, osInfo *osdetect.OSInfo) { if line == "" || strings.HasPrefix(line, "#") { continue } - + fmt.Printf(" %s\n", line) } } else { - fmt.Printf("%s Could not read %s\n", + fmt.Printf("%s Could not read %s\n", style.Colored(style.Yellow, style.SymWarning), sourcesFile) } - + // Check Proxmox repositories if relevant - if osInfo.IsProxmox { + if m.osInfo.IsProxmox { fmt.Println() fmt.Printf("%s %s:\n", style.BulletItem, style.Bolded("Proxmox repositories", style.Cyan)) - + // Check Ceph repo cephFile := "/etc/apt/sources.list.d/ceph.list" if data, err := os.ReadFile(cephFile); err == nil { @@ -312,14 +340,14 @@ func showDebianRepositories(cfg *config.Config, osInfo *osdetect.OSInfo) { if line == "" || strings.HasPrefix(line, "#") { continue } - + fmt.Printf(" %s\n", line) } } else { - fmt.Printf(" %s Ceph repository not configured\n", + fmt.Printf(" %s Ceph repository not configured\n", style.Colored(style.Yellow, style.SymWarning)) } - + // Check Enterprise repo pveFile := "/etc/apt/sources.list.d/pve-enterprise.list" if data, err := os.ReadFile(pveFile); err == nil { @@ -330,62 +358,62 @@ func showDebianRepositories(cfg *config.Config, osInfo *osdetect.OSInfo) { if line == "" || strings.HasPrefix(line, "#") { continue } - + fmt.Printf(" %s\n", line) } } else { - fmt.Printf(" %s Enterprise repository not configured\n", + fmt.Printf(" %s Enterprise repository not configured\n", style.Colored(style.Yellow, style.SymWarning)) } } - + // Show configured repositories fmt.Println() fmt.Printf("%s %s:\n", style.BulletItem, style.Bolded("Configured repositories", style.Cyan)) - - if len(cfg.DebianRepos) > 0 { - for _, repo := range cfg.DebianRepos { + + if len(m.config.DebianRepos) > 0 { + for _, repo := range m.config.DebianRepos { // Replace CODENAME placeholder with actual codename - displayRepo := strings.ReplaceAll(repo, "CODENAME", osInfo.OsCodename) + displayRepo := strings.ReplaceAll(repo, "CODENAME", m.osInfo.OsCodename) fmt.Printf(" %s\n", displayRepo) } } else { - fmt.Printf(" %s No repositories configured\n", + fmt.Printf(" %s No repositories configured\n", style.Colored(style.Yellow, style.SymWarning)) } - + // Show Proxmox configured repositories if relevant - if osInfo.IsProxmox { + if m.osInfo.IsProxmox { fmt.Println() - fmt.Printf("%s %s:\n", + fmt.Printf("%s %s:\n", style.BulletItem, style.Bolded("Configured Proxmox repositories", style.Cyan)) - + // Show Proxmox source repos - if len(cfg.ProxmoxSrcRepos) > 0 { + if len(m.config.ProxmoxSrcRepos) > 0 { fmt.Printf(" %s Source repositories:\n", style.BulletItem) - for _, repo := range cfg.ProxmoxSrcRepos { + for _, repo := range m.config.ProxmoxSrcRepos { // Replace CODENAME placeholder with actual codename - displayRepo := strings.ReplaceAll(repo, "CODENAME", osInfo.OsCodename) + displayRepo := strings.ReplaceAll(repo, "CODENAME", m.osInfo.OsCodename) fmt.Printf(" %s\n", displayRepo) } } - + // Show Ceph repos - if len(cfg.ProxmoxCephRepo) > 0 { + if len(m.config.ProxmoxCephRepo) > 0 { fmt.Printf(" %s Ceph repositories:\n", style.BulletItem) - for _, repo := range cfg.ProxmoxCephRepo { + for _, repo := range m.config.ProxmoxCephRepo { // Replace CODENAME placeholder with actual codename - displayRepo := strings.ReplaceAll(repo, "CODENAME", osInfo.OsCodename) + displayRepo := strings.ReplaceAll(repo, "CODENAME", m.osInfo.OsCodename) fmt.Printf(" %s\n", displayRepo) } } - + // Show Enterprise repos - if len(cfg.ProxmoxEnterpriseRepo) > 0 { + if len(m.config.ProxmoxEnterpriseRepo) > 0 { fmt.Printf(" %s Enterprise repositories:\n", style.BulletItem) - for _, repo := range cfg.ProxmoxEnterpriseRepo { + for _, repo := range m.config.ProxmoxEnterpriseRepo { // Replace CODENAME placeholder with actual codename - displayRepo := strings.ReplaceAll(repo, "CODENAME", osInfo.OsCodename) + displayRepo := strings.ReplaceAll(repo, "CODENAME", m.osInfo.OsCodename) fmt.Printf(" %s\n", displayRepo) } } @@ -393,37 +421,37 @@ func showDebianRepositories(cfg *config.Config, osInfo *osdetect.OSInfo) { } // Helper function to edit repositories -func editRepositoriesMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { +func (m *SourcesMenu) editRepositoriesMenu() { utils.PrintHeader() fmt.Println(style.Bolded("Edit Repositories", style.Blue)) - + // Create menu options var menuOptions []style.MenuOption - + // Basic options for all debian-based systems menuOptions = append(menuOptions, style.MenuOption{ - Number: 1, - Title: "Add repository", + Number: 1, + Title: "Add repository", Description: "Add a new repository to configuration", }) - - if len(cfg.DebianRepos) > 0 { + + if len(m.config.DebianRepos) > 0 { menuOptions = append(menuOptions, style.MenuOption{ - Number: 2, - Title: "Remove repository", + Number: 2, + Title: "Remove repository", Description: "Remove a repository from configuration", }) } - + // Proxmox specific options - if osInfo.IsProxmox { + if m.osInfo.IsProxmox { menuOptions = append(menuOptions, style.MenuOption{ - Number: 3, - Title: "Edit Proxmox repositories", + Number: 3, + Title: "Edit Proxmox repositories", Description: "Modify Proxmox-specific repositories", }) } - + // Create menu menu := style.NewMenu("Select an option", menuOptions) menu.SetExitOption(style.MenuOption{ @@ -431,139 +459,146 @@ func editRepositoriesMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { Title: "Return to sources menu", Description: "", }) - + // Display menu menu.Print() - - choice := ReadInput() - + + choice := ReadMenuInput() + + // Handle 'q' as a special exit case + if choice == "q" { + return + } + switch choice { case "1": // Add repository - fmt.Printf("\n%s Enter repository (e.g., 'deb http://deb.debian.org/debian CODENAME main'):\n", + fmt.Printf("\n%s Enter repository (e.g., 'deb http://deb.debian.org/debian CODENAME main'):\n", style.BulletItem) fmt.Printf("%s Use CODENAME as placeholder for the OS codename\n", style.BulletItem) fmt.Printf("> ") newRepo := ReadInput() - + if newRepo == "" { - fmt.Printf("\n%s Repository cannot be empty\n", + fmt.Printf("\n%s Repository cannot be empty\n", style.Colored(style.Red, style.SymCrossMark)) } else { // Check for duplicate isDuplicate := false - for _, repo := range cfg.DebianRepos { + for _, repo := range m.config.DebianRepos { if repo == newRepo { isDuplicate = true break } } - + if isDuplicate { - fmt.Printf("\n%s Repository already exists in configuration\n", + fmt.Printf("\n%s Repository already exists in configuration\n", style.Colored(style.Yellow, style.SymWarning)) } else { // Add new repository - cfg.DebianRepos = append(cfg.DebianRepos, newRepo) - + m.config.DebianRepos = append(m.config.DebianRepos, newRepo) + // Save config - saveSourcesConfig(cfg) - - fmt.Printf("\n%s Repository added to configuration\n", + m.saveSourcesConfig() + + fmt.Printf("\n%s Repository added to configuration\n", style.Colored(style.Green, style.SymCheckMark)) } } - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - editRepositoriesMenu(cfg, osInfo) - + m.editRepositoriesMenu() + case "2": // Remove repository - if len(cfg.DebianRepos) == 0 { - fmt.Printf("\n%s No repositories to remove\n", + if len(m.config.DebianRepos) == 0 { + fmt.Printf("\n%s No repositories to remove\n", style.Colored(style.Yellow, style.SymWarning)) } else { fmt.Println() - for i, repo := range cfg.DebianRepos { + for i, repo := range m.config.DebianRepos { // Replace CODENAME placeholder with actual codename for display - displayRepo := strings.ReplaceAll(repo, "CODENAME", osInfo.OsCodename) + displayRepo := strings.ReplaceAll(repo, "CODENAME", m.osInfo.OsCodename) fmt.Printf("%s %d: %s\n", style.BulletItem, i+1, displayRepo) } - - fmt.Printf("\n%s Enter repository number to remove (1-%d): ", - style.BulletItem, len(cfg.DebianRepos)) + + fmt.Printf("\n%s Enter repository number to remove (1-%d): ", + style.BulletItem, len(m.config.DebianRepos)) numStr := ReadInput() - + // Parse number num := 0 - fmt.Sscanf(numStr, "%d", &num) - - if num < 1 || num > len(cfg.DebianRepos) { - fmt.Printf("\n%s Invalid repository number\n", + n, err := fmt.Sscanf(numStr, "%d", &num) + if err != nil || n != 1 { + fmt.Printf("\n%s Invalid repository number: not a valid number\n", + style.Colored(style.Red, style.SymCrossMark)) + } else if num < 1 || num > len(m.config.DebianRepos) { + fmt.Printf("\n%s Invalid repository number: out of range\n", style.Colored(style.Red, style.SymCrossMark)) } else { // Remove repository (adjust for 0-based index) - removedRepo := cfg.DebianRepos[num-1] - cfg.DebianRepos = append(cfg.DebianRepos[:num-1], cfg.DebianRepos[num:]...) - + removedRepo := m.config.DebianRepos[num-1] + m.config.DebianRepos = append(m.config.DebianRepos[:num-1], m.config.DebianRepos[num:]...) + // Save config - saveSourcesConfig(cfg) - + m.saveSourcesConfig() + // Replace CODENAME placeholder with actual codename for display - displayRepo := strings.ReplaceAll(removedRepo, "CODENAME", osInfo.OsCodename) - fmt.Printf("\n%s Repository removed from configuration:\n", + displayRepo := strings.ReplaceAll(removedRepo, "CODENAME", m.osInfo.OsCodename) + fmt.Printf("\n%s Repository removed from configuration:\n", style.Colored(style.Green, style.SymCheckMark)) fmt.Printf("%s %s\n", style.BulletItem, displayRepo) } } - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - editRepositoriesMenu(cfg, osInfo) - + m.editRepositoriesMenu() + case "3": // Edit Proxmox repositories (only for Proxmox) - if osInfo.IsProxmox { - editProxmoxRepositoriesMenu(cfg, osInfo) - editRepositoriesMenu(cfg, osInfo) + if m.osInfo.IsProxmox { + m.editProxmoxRepositoriesMenu() + m.editRepositoriesMenu() return } else { - fmt.Printf("\n%s Invalid option for this OS type\n", + fmt.Printf("\n%s Invalid option for this OS type\n", style.Colored(style.Red, style.SymCrossMark)) - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - editRepositoriesMenu(cfg, osInfo) + m.editRepositoriesMenu() } - + case "0": // Return to sources menu return - + default: - fmt.Printf("\n%s Invalid option. Please try again.\n", + fmt.Printf("\n%s Invalid option. Please try again.\n", style.Colored(style.Red, style.SymCrossMark)) - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - editRepositoriesMenu(cfg, osInfo) + m.editRepositoriesMenu() return } } // Helper function to edit Proxmox repositories -func editProxmoxRepositoriesMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { +func (m *SourcesMenu) editProxmoxRepositoriesMenu() { utils.PrintHeader() fmt.Println(style.Bolded("Edit Proxmox Repositories", style.Blue)) - + // Create menu options menuOptions := []style.MenuOption{ {Number: 1, Title: "Edit source repositories", Description: "Modify main Proxmox repositories"}, {Number: 2, Title: "Edit Ceph repositories", Description: "Modify Proxmox Ceph repositories"}, {Number: 3, Title: "Edit Enterprise repositories", Description: "Modify Proxmox Enterprise repositories"}, } - + // Create menu menu := style.NewMenu("Select an option", menuOptions) menu.SetExitOption(style.MenuOption{ @@ -571,90 +606,95 @@ func editProxmoxRepositoriesMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { Title: "Return to edit repositories menu", Description: "", }) - + // Display menu menu.Print() - - choice := ReadInput() - + + choice := ReadMenuInput() + + // Handle 'q' as a special exit case + if choice == "q" { + return + } + switch choice { case "1": // Edit source repositories - editProxmoxRepoList(cfg, osInfo, "source", - "Proxmox Source Repositories", &cfg.ProxmoxSrcRepos) - editProxmoxRepositoriesMenu(cfg, osInfo) + m.editProxmoxRepoList("source", + "Proxmox Source Repositories", &m.config.ProxmoxSrcRepos) + m.editProxmoxRepositoriesMenu() return - + case "2": // Edit Ceph repositories - editProxmoxRepoList(cfg, osInfo, "ceph", - "Proxmox Ceph Repositories", &cfg.ProxmoxCephRepo) - editProxmoxRepositoriesMenu(cfg, osInfo) + m.editProxmoxRepoList("ceph", + "Proxmox Ceph Repositories", &m.config.ProxmoxCephRepo) + m.editProxmoxRepositoriesMenu() return - + case "3": // Edit Enterprise repositories - editProxmoxRepoList(cfg, osInfo, "enterprise", - "Proxmox Enterprise Repositories", &cfg.ProxmoxEnterpriseRepo) - editProxmoxRepositoriesMenu(cfg, osInfo) + m.editProxmoxRepoList("enterprise", + "Proxmox Enterprise Repositories", &m.config.ProxmoxEnterpriseRepo) + m.editProxmoxRepositoriesMenu() return - + case "0": // Return to edit repositories menu return - + default: - fmt.Printf("\n%s Invalid option. Please try again.\n", + fmt.Printf("\n%s Invalid option. Please try again.\n", style.Colored(style.Red, style.SymCrossMark)) - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - editProxmoxRepositoriesMenu(cfg, osInfo) + m.editProxmoxRepositoriesMenu() return } } // Helper function to edit a Proxmox repository list -func editProxmoxRepoList(cfg *config.Config, osInfo *osdetect.OSInfo, +func (m *SourcesMenu) editProxmoxRepoList( repoType, title string, repoList *[]string) { - + utils.PrintHeader() fmt.Println(style.Bolded(title, style.Blue)) - + // Display current repositories fmt.Println() fmt.Println(style.Bolded("Current Repositories:", style.Blue)) - + if len(*repoList) == 0 { fmt.Printf("%s No repositories configured\n", style.BulletItem) } else { for i, repo := range *repoList { // Replace CODENAME placeholder with actual codename for display - displayRepo := strings.ReplaceAll(repo, "CODENAME", osInfo.OsCodename) + displayRepo := strings.ReplaceAll(repo, "CODENAME", m.osInfo.OsCodename) fmt.Printf("%s %d: %s\n", style.BulletItem, i+1, displayRepo) } } - + // Create menu options menuOptions := []style.MenuOption{ {Number: 1, Title: "Add repository", Description: "Add a new repository to configuration"}, } - + if len(*repoList) > 0 { menuOptions = append(menuOptions, style.MenuOption{ - Number: 2, - Title: "Remove repository", + Number: 2, + Title: "Remove repository", Description: "Remove a repository from configuration", }) } - + // Add options to use default repositories menuOptions = append(menuOptions, style.MenuOption{ - Number: 3, - Title: "Use default repositories", + Number: 3, + Title: "Use default repositories", Description: "Reset to recommended repositories", }) - + // Create menu menu := style.NewMenu("Select an option", menuOptions) menu.SetExitOption(style.MenuOption{ @@ -662,12 +702,17 @@ func editProxmoxRepoList(cfg *config.Config, osInfo *osdetect.OSInfo, Title: "Return to previous menu", Description: "", }) - + // Display menu menu.Print() - - choice := ReadInput() - + + choice := ReadMenuInput() + + // Handle 'q' as a special exit case + if choice == "q" { + return + } + switch choice { case "1": // Add repository @@ -675,9 +720,9 @@ func editProxmoxRepoList(cfg *config.Config, osInfo *osdetect.OSInfo, fmt.Printf("%s Use CODENAME as placeholder for the OS codename\n", style.BulletItem) fmt.Printf("> ") newRepo := ReadInput() - + if newRepo == "" { - fmt.Printf("\n%s Repository cannot be empty\n", + fmt.Printf("\n%s Repository cannot be empty\n", style.Colored(style.Red, style.SymCrossMark)) } else { // Check for duplicate @@ -688,78 +733,80 @@ func editProxmoxRepoList(cfg *config.Config, osInfo *osdetect.OSInfo, break } } - + if isDuplicate { - fmt.Printf("\n%s Repository already exists in configuration\n", + fmt.Printf("\n%s Repository already exists in configuration\n", style.Colored(style.Yellow, style.SymWarning)) } else { // Add new repository *repoList = append(*repoList, newRepo) - + // Save config - saveSourcesConfig(cfg) - - fmt.Printf("\n%s Repository added to configuration\n", + m.saveSourcesConfig() + + fmt.Printf("\n%s Repository added to configuration\n", style.Colored(style.Green, style.SymCheckMark)) } } - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - editProxmoxRepoList(cfg, osInfo, repoType, title, repoList) + m.editProxmoxRepoList(repoType, title, repoList) return - + case "2": // Remove repository if len(*repoList) == 0 { - fmt.Printf("\n%s No repositories to remove\n", + fmt.Printf("\n%s No repositories to remove\n", style.Colored(style.Yellow, style.SymWarning)) } else { fmt.Println() for i, repo := range *repoList { // Replace CODENAME placeholder with actual codename for display - displayRepo := strings.ReplaceAll(repo, "CODENAME", osInfo.OsCodename) + displayRepo := strings.ReplaceAll(repo, "CODENAME", m.osInfo.OsCodename) fmt.Printf("%s %d: %s\n", style.BulletItem, i+1, displayRepo) } - - fmt.Printf("\n%s Enter repository number to remove (1-%d): ", + + fmt.Printf("\n%s Enter repository number to remove (1-%d): ", style.BulletItem, len(*repoList)) numStr := ReadInput() - + // Parse number num := 0 - fmt.Sscanf(numStr, "%d", &num) - - if num < 1 || num > len(*repoList) { - fmt.Printf("\n%s Invalid repository number\n", + n, err := fmt.Sscanf(numStr, "%d", &num) + if err != nil || n != 1 { + fmt.Printf("\n%s Invalid repository number: not a valid number\n", + style.Colored(style.Red, style.SymCrossMark)) + } else if num < 1 || num > len(*repoList) { + fmt.Printf("\n%s Invalid repository number: out of range\n", style.Colored(style.Red, style.SymCrossMark)) } else { // Remove repository (adjust for 0-based index) removedRepo := (*repoList)[num-1] *repoList = append((*repoList)[:num-1], (*repoList)[num:]...) - + // Save config - saveSourcesConfig(cfg) - + m.saveSourcesConfig() + // Replace CODENAME placeholder with actual codename for display - displayRepo := strings.ReplaceAll(removedRepo, "CODENAME", osInfo.OsCodename) - fmt.Printf("\n%s Repository removed from configuration:\n", + displayRepo := strings.ReplaceAll(removedRepo, "CODENAME", m.osInfo.OsCodename) + fmt.Printf("\n%s Repository removed from configuration:\n", style.Colored(style.Green, style.SymCheckMark)) fmt.Printf("%s %s\n", style.BulletItem, displayRepo) } } - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - editProxmoxRepoList(cfg, osInfo, repoType, title, repoList) + m.editProxmoxRepoList(repoType, title, repoList) return - + case "3": // Use default repositories - fmt.Printf("\n%s Reset to default repositories? This will overwrite current configuration. (y/n): ", + fmt.Printf("\n%s Reset to default repositories? This will overwrite current configuration. (y/n): ", style.Colored(style.Yellow, style.SymWarning)) confirm := ReadInput() - + if strings.ToLower(confirm) == "y" || strings.ToLower(confirm) == "yes" { // Set default repositories based on type switch repoType { @@ -780,43 +827,42 @@ func editProxmoxRepoList(cfg *config.Config, osInfo *osdetect.OSInfo, "#deb https://enterprise.proxmox.com/debian/pve CODENAME pve-enterprise", } } - + // Save config - saveSourcesConfig(cfg) - - fmt.Printf("\n%s Repositories reset to defaults\n", + m.saveSourcesConfig() + + fmt.Printf("\n%s Repositories reset to defaults\n", style.Colored(style.Green, style.SymCheckMark)) } else { fmt.Println("\nOperation cancelled.") } - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - editProxmoxRepoList(cfg, osInfo, repoType, title, repoList) + m.editProxmoxRepoList(repoType, title, repoList) return - + case "0": // Return to previous menu return - + default: - fmt.Printf("\n%s Invalid option. Please try again.\n", + fmt.Printf("\n%s Invalid option. Please try again.\n", style.Colored(style.Red, style.SymCrossMark)) - + fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - editProxmoxRepoList(cfg, osInfo, repoType, title, repoList) + m.editProxmoxRepoList(repoType, title, repoList) return } } // Helper function to save sources configuration -func saveSourcesConfig(cfg *config.Config) { +func (m *SourcesMenu) saveSourcesConfig() { // Save config changes configFile := "hardn.yml" // Default config file - if err := config.SaveConfig(cfg, configFile); err != nil { - logging.LogError("Failed to save configuration: %v", err) - fmt.Printf("\n%s Failed to save configuration: %v\n", + if err := config.SaveConfig(m.config, configFile); err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", style.Colored(style.Red, style.SymCrossMark), err) } -} \ No newline at end of file +} diff --git a/pkg/menu/user.go b/pkg/menu/user_menu.go similarity index 63% rename from pkg/menu/user.go rename to pkg/menu/user_menu.go index c392ce8..2f24f97 100644 --- a/pkg/menu/user.go +++ b/pkg/menu/user_menu.go @@ -1,5 +1,4 @@ -// pkg/menu/user.go - +// pkg/menu/user_menu.go package menu import ( @@ -7,62 +6,80 @@ import ( osuser "os/user" "strings" + "github.com/abbott/hardn/pkg/application" "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/logging" "github.com/abbott/hardn/pkg/osdetect" - "github.com/abbott/hardn/pkg/ssh" "github.com/abbott/hardn/pkg/style" "github.com/abbott/hardn/pkg/utils" - "github.com/abbott/hardn/pkg/user" ) -// UserCreationMenu handles creating a non-root user with sudo access -func UserCreationMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { +// UserMenu handles user-related operations through the menu system +type UserMenu struct { + menuManager *application.MenuManager + config *config.Config + osInfo *osdetect.OSInfo +} + +// NewUserMenu creates a new UserMenu +func NewUserMenu( + menuManager *application.MenuManager, + config *config.Config, + osInfo *osdetect.OSInfo, +) *UserMenu { + return &UserMenu{ + menuManager: menuManager, + config: config, + osInfo: osInfo, + } +} + +// Show displays the user menu and handles input +func (m *UserMenu) Show() { utils.PrintHeader() fmt.Println(style.Bolded("User Creation", style.Blue)) // Display current user settings fmt.Println() fmt.Println(style.Bolded("Current User Configuration:", style.Blue)) - + // Format user status formatter := style.NewStatusFormatter([]string{"Username", "Sudo Access", "SSH Keys"}, 2) - + // Username status - if cfg.Username != "" { - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "Username", - cfg.Username, style.Cyan, "", "light")) + if m.config.Username != "" { + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "Username", + m.config.Username, style.Cyan, "", "light")) } else { fmt.Println(formatter.FormatWarning("Username", "Not set", "Please provide a username")) } - + // Sudo access status sudoStatus := "No password required" - if !cfg.SudoNoPassword { + if !m.config.SudoNoPassword { sudoStatus = "Password required" } - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "Sudo Access", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "Sudo Access", sudoStatus, style.Cyan, "", "light")) - + // SSH key status - keyCount := len(cfg.SshKeys) + keyCount := len(m.config.SshKeys) keyStatus := "None configured" if keyCount > 0 { keyStatus = fmt.Sprintf("%d key(s) configured", keyCount) } - fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "SSH Keys", + fmt.Println(formatter.FormatLine(style.SymInfo, style.Cyan, "SSH Keys", keyStatus, style.Cyan, "", "light")) - + // Check if user already exists var userExists bool - var username string = cfg.Username - + var username string = m.config.Username + if username != "" { _, err := osuser.Lookup(username) userExists = (err == nil) - + if userExists { - fmt.Printf("\n%s User '%s' already exists on the system\n", + fmt.Printf("\n%s User '%s' already exists on the system\n", style.Colored(style.Yellow, style.SymInfo), username) } } @@ -73,55 +90,55 @@ func UserCreationMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { // Add or change username option if username == "" { menuOptions = append(menuOptions, style.MenuOption{ - Number: 1, - Title: "Enter username", + Number: 1, + Title: "Enter username", Description: "Specify username to create", }) } else { menuOptions = append(menuOptions, style.MenuOption{ - Number: 1, - Title: "Change username", + Number: 1, + Title: "Change username", Description: fmt.Sprintf("Current: %s", username), }) } - + // Toggle sudo password option - if cfg.SudoNoPassword { + if m.config.SudoNoPassword { menuOptions = append(menuOptions, style.MenuOption{ - Number: 2, - Title: "Require sudo password", + Number: 2, + Title: "Require sudo password", Description: "Change sudo to require password", }) } else { menuOptions = append(menuOptions, style.MenuOption{ - Number: 2, - Title: "Allow sudo without password", + Number: 2, + Title: "Allow sudo without password", Description: "Change sudo to not require password", }) } - + // Manage SSH keys option menuOptions = append(menuOptions, style.MenuOption{ - Number: 3, - Title: "Manage SSH keys", + Number: 3, + Title: "Manage SSH keys", Description: "Add or remove SSH public keys", }) - + // Create user option (only if username is set and user doesn't exist) if username != "" && !userExists { menuOptions = append(menuOptions, style.MenuOption{ - Number: 4, - Title: "Create user", + Number: 4, + Title: "Create user", Description: fmt.Sprintf("Create user '%s' with current settings", username), }) } else if username != "" && userExists { menuOptions = append(menuOptions, style.MenuOption{ - Number: 4, - Title: "Update user", + Number: 4, + Title: "Update user", Description: fmt.Sprintf("Update SSH configuration for user '%s'", username), }) } - + // Create menu menu := style.NewMenu("Select an option", menuOptions) menu.SetExitOption(style.MenuOption{ @@ -129,12 +146,17 @@ func UserCreationMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { Title: "Return to main menu", Description: "", }) - + // Display menu menu.Print() - - choice := ReadInput() - + + choice := ReadMenuInput() + + // Handle 'q' as a special exit case + if choice == "q" { + return + } + switch choice { case "1": // Set or change username @@ -144,194 +166,189 @@ func UserCreationMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { fmt.Printf("\n%s Current username: %s\n", style.BulletItem, username) fmt.Printf("%s Enter new username (leave empty to keep current): ", style.BulletItem) } - + newUsername := ReadInput() if newUsername != "" { - cfg.Username = newUsername - + m.config.Username = newUsername + // Check if new user exists _, err := osuser.Lookup(newUsername) if err == nil { - fmt.Printf("\n%s User '%s' already exists on the system\n", + fmt.Printf("\n%s User '%s' already exists on the system\n", style.Colored(style.Yellow, style.SymInfo), newUsername) } - - fmt.Printf("\n%s Username set to: %s\n", + + fmt.Printf("\n%s Username set to: %s\n", style.Colored(style.Green, style.SymCheckMark), newUsername) - + // Save config changes - saveUserConfig(cfg) + err = config.SaveConfig(m.config, "hardn.yml") + if err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } } else if username != "" { fmt.Printf("\n%s Username unchanged: %s\n", style.BulletItem, username) } - + // Return to this menu after changing username fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - UserCreationMenu(cfg, osInfo) - + m.Show() + case "2": // Toggle sudo password requirement - cfg.SudoNoPassword = !cfg.SudoNoPassword - - if cfg.SudoNoPassword { - fmt.Printf("\n%s Sudo will %s for user '%s'\n", + m.config.SudoNoPassword = !m.config.SudoNoPassword + + if m.config.SudoNoPassword { + fmt.Printf("\n%s Sudo will %s for user '%s'\n", style.Colored(style.Green, style.SymCheckMark), style.Bolded("NOT require a password", style.Green), - cfg.Username) + m.config.Username) } else { - fmt.Printf("\n%s Sudo will %s for user '%s'\n", + fmt.Printf("\n%s Sudo will %s for user '%s'\n", style.Colored(style.Green, style.SymCheckMark), style.Bolded("require a password", style.Green), - cfg.Username) + m.config.Username) } - + // Save config changes - saveUserConfig(cfg) - + err := config.SaveConfig(m.config, "hardn.yml") + if err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } + // Return to this menu after toggling sudo fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - UserCreationMenu(cfg, osInfo) - + m.Show() + case "3": // Manage SSH keys - manageSshKeysMenu(cfg, osInfo) - UserCreationMenu(cfg, osInfo) - + m.manageSshKeys() + m.Show() + case "4": // Create or update user if username == "" { - fmt.Printf("\n%s No username provided. Please enter a username first.\n", + fmt.Printf("\n%s No username provided. Please enter a username first.\n", style.Colored(style.Red, style.SymCrossMark)) - + // Return to this menu fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - UserCreationMenu(cfg, osInfo) + m.Show() return } - + // Confirm keys are configured - if len(cfg.SshKeys) == 0 { - fmt.Printf("\n%s Warning: No SSH keys configured. User will not have SSH access.\n", + if len(m.config.SshKeys) == 0 { + fmt.Printf("\n%s Warning: No SSH keys configured. User will not have SSH access.\n", style.Colored(style.Yellow, style.SymWarning)) fmt.Printf("%s Would you like to continue anyway? (y/n): ", style.BulletItem) - + confirm := ReadInput() if !strings.EqualFold(confirm, "y") && !strings.EqualFold(confirm, "yes") { - fmt.Printf("\n%s Operation cancelled. Please add SSH keys first.\n", + fmt.Printf("\n%s Operation cancelled. Please add SSH keys first.\n", style.Colored(style.Yellow, style.SymInfo)) - + // Return to this menu fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - UserCreationMenu(cfg, osInfo) + m.Show() return } } - + // Determine action based on whether user exists action := "Creating" if userExists { action = "Updating" } - - // Create or update user + + // Create or update user using menuManager fmt.Printf("\n%s %s user '%s'...\n", style.BulletItem, action, username) - - if !userExists { - err := user.CreateUser(username, cfg, osInfo) - if err != nil { - fmt.Printf("\n%s Failed to create user: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - logging.LogError("Failed to create user: %v", err) - } else if !cfg.DryRun { - fmt.Printf("\n%s User '%s' created successfully\n", - style.Colored(style.Green, style.SymCheckMark), username) - } - } - - // Configure SSH - fmt.Printf("\n%s Configuring SSH for user '%s'...\n", style.BulletItem, username) - err := ssh.WriteSSHConfig(cfg, osInfo) + + err := m.menuManager.CreateUser(username, true, m.config.SudoNoPassword, m.config.SshKeys) if err != nil { - fmt.Printf("\n%s Failed to configure SSH: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - logging.LogError("Failed to configure SSH: %v", err) - } else if !cfg.DryRun { - fmt.Printf("\n%s SSH configured successfully\n", - style.Colored(style.Green, style.SymCheckMark)) + fmt.Printf("\n%s Failed to %s user: %v\n", + style.Colored(style.Red, style.SymCrossMark), strings.ToLower(action), err) + } else if !m.config.DryRun { + fmt.Printf("\n%s User '%s' %s successfully\n", + style.Colored(style.Green, style.SymCheckMark), + username, + strings.ToLower(action)+"d") } - + case "0": // Return to main menu return - + default: - fmt.Printf("\n%s Invalid option. Please try again.\n", + fmt.Printf("\n%s Invalid option. Please try again.\n", style.Colored(style.Red, style.SymCrossMark)) - + // Return to this menu fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - UserCreationMenu(cfg, osInfo) + m.Show() } - + fmt.Printf("\n%s Press any key to return to the main menu...", style.BulletItem) ReadKey() } -// Helper function to manage SSH keys -func manageSshKeysMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { +// manageSshKeys handles SSH key management +func (m *UserMenu) manageSshKeys() { utils.PrintHeader() fmt.Println(style.Bolded("Manage SSH Keys", style.Blue)) - + // Display current keys fmt.Println() fmt.Println(style.Bolded("Current SSH Keys:", style.Blue)) - - if len(cfg.SshKeys) == 0 { + + if len(m.config.SshKeys) == 0 { fmt.Printf("%s No SSH keys configured\n", style.BulletItem) } else { - for i, key := range cfg.SshKeys { + for i, key := range m.config.SshKeys { // Try to extract comment from key (usually contains email or identifier) keyParts := strings.Fields(key) keyInfo := "" if len(keyParts) >= 3 { keyInfo = keyParts[2] } - + // Truncate the key for display truncatedKey := key if len(key) > 30 { truncatedKey = key[:15] + "..." + key[len(key)-15:] } - - fmt.Printf("%s Key %d: %s", style.BulletItem, i+1, + + fmt.Printf("%s Key %d: %s", style.BulletItem, i+1, style.Colored(style.Cyan, truncatedKey)) - + if keyInfo != "" { fmt.Printf(" (%s)", keyInfo) } fmt.Println() } } - + // Create menu options menuOptions := []style.MenuOption{ {Number: 1, Title: "Add SSH key", Description: "Add a new SSH public key"}, } - + // Only add remove option if keys exist - if len(cfg.SshKeys) > 0 { + if len(m.config.SshKeys) > 0 { menuOptions = append(menuOptions, style.MenuOption{ - Number: 2, - Title: "Remove SSH key", + Number: 2, + Title: "Remove SSH key", Description: "Remove an existing SSH public key", }) } - + // Create menu menu := style.NewMenu("Select an option", menuOptions) menu.SetExitOption(style.MenuOption{ @@ -339,109 +356,128 @@ func manageSshKeysMenu(cfg *config.Config, osInfo *osdetect.OSInfo) { Title: "Return to user menu", Description: "", }) - + // Display menu menu.Print() - - choice := ReadInput() - + + choice := ReadMenuInput() + + // Handle 'q' as a special exit case + if choice == "q" { + return + } + switch choice { case "1": // Add SSH key fmt.Printf("\n%s Paste SSH public key (e.g., ssh-ed25519 AAAAC3NzaC1lZDI1...): \n", style.BulletItem) newKey := ReadInput() - + if newKey != "" { // Validate key format if !strings.HasPrefix(newKey, "ssh-") && !strings.HasPrefix(newKey, "ecdsa-") { - fmt.Printf("\n%s Invalid SSH key format. Key should start with 'ssh-' or 'ecdsa-'\n", + fmt.Printf("\n%s Invalid SSH key format. Key should start with 'ssh-' or 'ecdsa-'\n", style.Colored(style.Red, style.SymCrossMark)) } else { // Add key - cfg.SshKeys = append(cfg.SshKeys, newKey) - fmt.Printf("\n%s SSH key added successfully\n", + m.config.SshKeys = append(m.config.SshKeys, newKey) + fmt.Printf("\n%s SSH key added successfully\n", style.Colored(style.Green, style.SymCheckMark)) - + // Save config changes - saveUserConfig(cfg) + err := config.SaveConfig(m.config, "hardn.yml") + if err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } + + // If user already exists, add key to user + if m.config.Username != "" { + _, err := osuser.Lookup(m.config.Username) + if err == nil { + err = m.menuManager.AddSSHKey(m.config.Username, newKey) + if err != nil { + fmt.Printf("\n%s Failed to add SSH key to user: %v\n", + style.Colored(style.Yellow, style.SymWarning), err) + } else if !m.config.DryRun { + fmt.Printf("%s Key added to user '%s'\n", + style.BulletItem, m.config.Username) + } + } + } } } - + // Return to SSH keys menu fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - manageSshKeysMenu(cfg, osInfo) - + m.manageSshKeys() + case "2": // Only process if keys exist - if len(cfg.SshKeys) == 0 { - fmt.Printf("\n%s No keys to remove\n", + if len(m.config.SshKeys) == 0 { + fmt.Printf("\n%s No keys to remove\n", style.Colored(style.Yellow, style.SymWarning)) - + // Return to SSH keys menu fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - manageSshKeysMenu(cfg, osInfo) + m.manageSshKeys() return } - + // Remove SSH key - fmt.Printf("\n%s Enter key number to remove (1-%d): ", style.BulletItem, len(cfg.SshKeys)) + fmt.Printf("\n%s Enter key number to remove (1-%d): ", style.BulletItem, len(m.config.SshKeys)) keyNumStr := ReadInput() - keyNum := 0 - + // Parse key number - fmt.Sscanf(keyNumStr, "%d", &keyNum) - - if keyNum < 1 || keyNum > len(cfg.SshKeys) { - fmt.Printf("\n%s Invalid key number. Please enter a number between 1 and %d\n", - style.Colored(style.Red, style.SymCrossMark), len(cfg.SshKeys)) + keyNum := 0 + n, err := fmt.Sscanf(keyNumStr, "%d", &keyNum) + if err != nil || n != 1 { + fmt.Printf("\n%s Invalid key number: not a valid number\n", + style.Colored(style.Red, style.SymCrossMark)) + } else if keyNum < 1 || keyNum > len(m.config.SshKeys) { + fmt.Printf("\n%s Invalid key number. Please enter a number between 1 and %d\n", + style.Colored(style.Red, style.SymCrossMark), len(m.config.SshKeys)) } else { // Remove key (adjusting for 0-based indexing) - removedKey := cfg.SshKeys[keyNum-1] - cfg.SshKeys = append(cfg.SshKeys[:keyNum-1], cfg.SshKeys[keyNum:]...) - - fmt.Printf("\n%s SSH key %d removed successfully\n", + removedKey := m.config.SshKeys[keyNum-1] + m.config.SshKeys = append(m.config.SshKeys[:keyNum-1], m.config.SshKeys[keyNum:]...) + + fmt.Printf("\n%s SSH key %d removed successfully\n", style.Colored(style.Green, style.SymCheckMark), keyNum) - + // Show truncated key that was removed if len(removedKey) > 30 { removedKey = removedKey[:15] + "..." + removedKey[len(removedKey)-15:] } - fmt.Printf("%s Removed: %s\n", style.BulletItem, + fmt.Printf("%s Removed: %s\n", style.BulletItem, style.Colored(style.Yellow, removedKey)) - + // Save config changes - saveUserConfig(cfg) + err := config.SaveConfig(m.config, "hardn.yml") + if err != nil { + fmt.Printf("\n%s Failed to save configuration: %v\n", + style.Colored(style.Red, style.SymCrossMark), err) + } } - + // Return to SSH keys menu fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - manageSshKeysMenu(cfg, osInfo) - + m.manageSshKeys() + case "0": // Return to user menu return - + default: - fmt.Printf("\n%s Invalid option. Please try again.\n", + fmt.Printf("\n%s Invalid option. Please try again.\n", style.Colored(style.Red, style.SymCrossMark)) - + // Return to SSH keys menu fmt.Printf("\n%s Press any key to continue...", style.BulletItem) ReadKey() - manageSshKeysMenu(cfg, osInfo) + m.manageSshKeys() } } - -// Helper function to save user configuration -func saveUserConfig(cfg *config.Config) { - // Save config changes - configFile := "hardn.yml" // Default config file - if err := config.SaveConfig(cfg, configFile); err != nil { - logging.LogError("Failed to save configuration: %v", err) - fmt.Printf("\n%s Failed to save configuration: %v\n", - style.Colored(style.Red, style.SymCrossMark), err) - } -} \ No newline at end of file diff --git a/pkg/osdetect/osdetect.go b/pkg/osdetect/osdetect.go index ceded5a..14f0d63 100644 --- a/pkg/osdetect/osdetect.go +++ b/pkg/osdetect/osdetect.go @@ -77,4 +77,4 @@ func DetectOS() (*OSInfo, error) { } return osInfo, nil -} \ No newline at end of file +} diff --git a/pkg/packages/packages.go b/pkg/packages/packages.go deleted file mode 100644 index 0972975..0000000 --- a/pkg/packages/packages.go +++ /dev/null @@ -1,482 +0,0 @@ -package packages - -import ( - "fmt" - "os" - "os/exec" - "strings" - - "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/logging" - "github.com/abbott/hardn/pkg/osdetect" - "github.com/abbott/hardn/pkg/utils" -) - -// IsPackageInstalled checks if a package is installed -func IsPackageInstalled(packageName string) bool { - var cmd *exec.Cmd - - // Check for dpkg (Debian/Ubuntu) first - if _, err := exec.LookPath("dpkg"); err == nil { - cmd = exec.Command("dpkg", "-l", packageName) - output, err := cmd.CombinedOutput() - if err == nil && strings.Contains(string(output), packageName) { - return true - } - } - - // Check for apk (Alpine) - if _, err := exec.LookPath("apk"); err == nil { - cmd = exec.Command("apk", "info", "-e", packageName) - if err := cmd.Run(); err == nil { - return true - } - } - - return false -} - -// WriteSources writes the appropriate repository sources based on OS type -func WriteSources(cfg *config.Config, osInfo *osdetect.OSInfo) error { - if cfg.DryRun { - if osInfo.OsType == "alpine" { - logging.LogInfo("[DRY-RUN] Configure Alpine repositories in /etc/apk/repositories:") - logging.LogInfo("[DRY-RUN] - Add main repository: https://dl-cdn.alpinelinux.org/alpine/v%s/main", osInfo.OsVersion[:strings.LastIndex(osInfo.OsVersion, ".")]) - logging.LogInfo("[DRY-RUN] - Add community repository: https://dl-cdn.alpinelinux.org/alpine/v%s/community", osInfo.OsVersion[:strings.LastIndex(osInfo.OsVersion, ".")]) - if cfg.AlpineTestingRepo { - logging.LogInfo("[DRY-RUN] - Add testing repository: https://dl-cdn.alpinelinux.org/alpine/edge/testing") - } - } else if osInfo.IsProxmox { - logging.LogInfo("[DRY-RUN] Configure Proxmox repositories in /etc/apt/sources.list:") - for _, repo := range cfg.ProxmoxSrcRepos { - logging.LogInfo("[DRY-RUN] - Add: %s", strings.ReplaceAll(repo, "CODENAME", osInfo.OsCodename)) - } - } else { - logging.LogInfo("[DRY-RUN] Configure %s repositories in /etc/apt/sources.list:", osInfo.OsType) - for _, repo := range cfg.DebianRepos { - logging.LogInfo("[DRY-RUN] - Add: %s", strings.ReplaceAll(repo, "CODENAME", osInfo.OsCodename)) - } - } - return nil - } - - if osInfo.OsType == "alpine" { - logging.LogInfo("Configuring Alpine repositories...") - - // Format Alpine version for repositories - versionPrefix := osInfo.OsVersion - if idx := strings.LastIndex(versionPrefix, "."); idx != -1 { - versionPrefix = versionPrefix[:idx] - } - - // Create Alpine repository file content - content := fmt.Sprintf(`# Main repositories -https://dl-cdn.alpinelinux.org/alpine/v%s/main -https://dl-cdn.alpinelinux.org/alpine/v%s/community - -# Security updates -https://dl-cdn.alpinelinux.org/alpine/v%s/main -https://dl-cdn.alpinelinux.org/alpine/v%s/community -`, versionPrefix, versionPrefix, versionPrefix, versionPrefix) - - // testing repo if enabled - if cfg.AlpineTestingRepo { - content += ` -# Testing repository (use with caution) -https://dl-cdn.alpinelinux.org/alpine/edge/testing -` - logging.LogInfo("Alpine testing repository enabled") - } - - // Write the file - if err := os.WriteFile("/etc/apk/repositories", []byte(content), 0644); err != nil { - return fmt.Errorf("failed to write Alpine repositories for version %s: %w", versionPrefix, err) - } - - // Update package index - cmd := exec.Command("apk", "update") - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to update Alpine package index for version %s: %w", versionPrefix, err) - } - - logging.LogSuccess("Alpine repositories configured successfully") - } else if osInfo.IsProxmox { - logging.LogInfo("Writing Proxmox sources list to /etc/apt/sources.list") - - // Prepare content by replacing CODENAME placeholder - var content strings.Builder - for _, repo := range cfg.ProxmoxSrcRepos { - content.WriteString(strings.ReplaceAll(repo, "CODENAME", osInfo.OsCodename)) - content.WriteString("\n") - } - - // Backup original file - utils.BackupFile("/etc/apt/sources.list", cfg) - - // Write the file - if err := os.WriteFile("/etc/apt/sources.list", []byte(content.String()), 0644); err != nil { - return fmt.Errorf("failed to write Proxmox sources list for %s: %w", osInfo.OsCodename, err) - } - - logging.LogSuccess("Proxmox repositories configured successfully") - } else { - logging.LogInfo("Writing %s sources list to /etc/apt/sources.list", osInfo.OsCodename) - - // Prepare content by replacing CODENAME placeholder - var content strings.Builder - for _, repo := range cfg.DebianRepos { - content.WriteString(strings.ReplaceAll(repo, "CODENAME", osInfo.OsCodename)) - content.WriteString("\n") - } - - // Backup original file - utils.BackupFile("/etc/apt/sources.list", cfg) - - // Write the file - if err := os.WriteFile("/etc/apt/sources.list", []byte(content.String()), 0644); err != nil { - return fmt.Errorf("failed to write Debian/Ubuntu sources list for %s: %w", osInfo.OsCodename, err) - } - - logging.LogSuccess("Debian/Ubuntu repositories configured successfully") - } - - return nil -} - -// WriteProxmoxRepos writes Proxmox-specific repositories -func WriteProxmoxRepos(cfg *config.Config, osInfo *osdetect.OSInfo) error { - if !osInfo.IsProxmox { - return nil - } - - if cfg.DryRun { - logging.LogInfo("[DRY-RUN] Write Proxmox Ceph repository to /etc/apt/sources.list.d/ceph.list") - logging.LogInfo("[DRY-RUN] Write Proxmox Enterprise repository to /etc/apt/sources.list.d/pve-enterprise.list") - return nil - } - - logging.LogInfo("Writing Proxmox Ceph repository to /etc/apt/sources.list.d/ceph.list") - - // Prepare content for Ceph repository - var cephContent strings.Builder - for _, repo := range cfg.ProxmoxCephRepo { - cephContent.WriteString(strings.ReplaceAll(repo, "CODENAME", osInfo.OsCodename)) - cephContent.WriteString("\n") - } - - // Backup original files - utils.BackupFile("/etc/apt/sources.list.d/ceph.list", cfg) - - // Write Ceph repository - if err := os.MkdirAll("/etc/apt/sources.list.d", 0755); err != nil { - return fmt.Errorf("failed to create sources.list.d directory for Proxmox: %w", err) - } - - if err := os.WriteFile("/etc/apt/sources.list.d/ceph.list", []byte(cephContent.String()), 0644); err != nil { - return fmt.Errorf("failed to write Proxmox Ceph repository for %s: %w", osInfo.OsCodename, err) - } - - // Prepare content for Enterprise repository - logging.LogInfo("Writing Proxmox Enterprise repository to /etc/apt/sources.list.d/pve-enterprise.list") - - var enterpriseContent strings.Builder - for _, repo := range cfg.ProxmoxEnterpriseRepo { - enterpriseContent.WriteString(strings.ReplaceAll(repo, "CODENAME", osInfo.OsCodename)) - enterpriseContent.WriteString("\n") - } - - // Backup original file - utils.BackupFile("/etc/apt/sources.list.d/pve-enterprise.list", cfg) - - // Write Enterprise repository - if err := os.WriteFile("/etc/apt/sources.list.d/pve-enterprise.list", []byte(enterpriseContent.String()), 0644); err != nil { - return fmt.Errorf("failed to write Proxmox Enterprise repository for %s: %w", osInfo.OsCodename, err) - } - - logging.LogSuccess("Proxmox-specific repositories configured") - return nil -} - -// HoldProxmoxPackages holds Proxmox packages to prevent removal -func HoldProxmoxPackages(osInfo *osdetect.OSInfo, patterns []string) error { - if !osInfo.IsProxmox { - return nil - } - - logging.LogInfo("Holding Proxmox packages to prevent removal...") - - for _, pattern := range patterns { - // Get packages matching the pattern - cmd := exec.Command("dpkg-query", "-W", "-f=${binary:Package}\n") - output, err := cmd.Output() - if err != nil { - return fmt.Errorf("failed to query packages with pattern %s: %w", pattern, err) - } - - // Mark packages as held - for _, pkg := range strings.Split(string(output), "\n") { - if pkg == "" { - continue - } - - if strings.HasPrefix(pkg, pattern) { - holdCmd := exec.Command("apt-mark", "hold", pkg) - if err := holdCmd.Run(); err != nil { - logging.LogError("Failed to hold package %s: %v", pkg, err) - } - } - } - } - - logging.LogSuccess("Proxmox packages protected") - return nil -} - -// UnholdProxmoxPackages releases Proxmox packages after script completion -func UnholdProxmoxPackages(osInfo *osdetect.OSInfo, patterns []string) error { - if !osInfo.IsProxmox { - return nil - } - - logging.LogInfo("Unholding Proxmox packages...") - - for _, pattern := range patterns { - // Get packages matching the pattern - cmd := exec.Command("dpkg-query", "-W", "-f=${binary:Package}\n") - output, err := cmd.Output() - if err != nil { - return fmt.Errorf("failed to query packages with pattern %s for unhold: %w", pattern, err) - } - - // Mark packages as unhold - for _, pkg := range strings.Split(string(output), "\n") { - if pkg == "" { - continue - } - - if strings.HasPrefix(pkg, pattern) { - unholdCmd := exec.Command("apt-mark", "unhold", pkg) - if err := unholdCmd.Run(); err != nil { - logging.LogError("Failed to unhold package %s: %v", pkg, err) - } - } - } - } - - logging.LogSuccess("Proxmox packages released") - return nil -} - -// InstallPackages installs a list of packages based on OS type -func InstallPackages(packages []string, osInfo *osdetect.OSInfo, cfg *config.Config) error { - if len(packages) == 0 { - return nil - } - - // Check for dry-run mode - if cfg.DryRun { - logging.LogInfo("[DRY-RUN] Install: %s", strings.Join(packages, ", ")) - return nil - } - - packagesList := strings.Join(packages, ", ") - logging.LogInfo("Installing %s packages: %s", osInfo.OsType, packagesList) - - if osInfo.OsType == "alpine" { - cmd := exec.Command("apk", append([]string{"add", "--no-cache"}, packages...)...) - output, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("failed to install Alpine packages [%s]: %w\n%s", packagesList, err, output) - } - } else { - // Hold Proxmox packages if necessary - if osInfo.IsProxmox { - HoldProxmoxPackages(osInfo, []string{"proxmox", "pve"}) - } - - // Update package lists - updateCmd := exec.Command("apt-get", "update") - updateOutput, err := updateCmd.CombinedOutput() - if err != nil { - return fmt.Errorf("failed to update package lists for %s: %w\n%s", packagesList, err, updateOutput) - } - - // Install locales first for Debian/Ubuntu - localesCmd := exec.Command("apt-get", "install", "--yes", "locales") - localesOutput, err := localesCmd.CombinedOutput() - if err != nil { - logging.LogError("Failed to install locales: %v\n%s", err, localesOutput) - } else { - logging.LogInstall("locales") - } - - // Configure locales - sedCmd := exec.Command("sed", "-i", "/en_US.UTF-8/s/^# //g", "/etc/locale.gen") - if err := sedCmd.Run(); err != nil { - logging.LogError("Failed to configure locales: %v", err) - } - - localeGenCmd := exec.Command("locale-gen") - if err := localeGenCmd.Run(); err != nil { - logging.LogError("Failed to generate locales: %v", err) - } - - // Install packages - installCmd := exec.Command("apt-get", append([]string{"install", "--yes"}, packages...)...) - installOutput, err := installCmd.CombinedOutput() - if err != nil { - return fmt.Errorf("failed to install Debian/Ubuntu packages [%s]: %w\n%s", packagesList, err, installOutput) - } - - // Clean up - autoremoveCmd := exec.Command("apt-get", "autoremove", "--yes") - if err := autoremoveCmd.Run(); err != nil { - logging.LogError("Failed to autoremove packages after installing %s: %v", packagesList, err) - } - - cleanCmd := exec.Command("apt-get", "clean") - if err := cleanCmd.Run(); err != nil { - logging.LogError("Failed to clean package cache after installing %s: %v", packagesList, err) - } - - rmCmd := exec.Command("rm", "-rf", "/var/lib/apt/lists/*") - if err := rmCmd.Run(); err != nil { - logging.LogError("Failed to remove apt lists after installing %s: %v", packagesList, err) - } - - // Unhold Proxmox packages - if osInfo.IsProxmox { - UnholdProxmoxPackages(osInfo, []string{"proxmox", "pve"}) - } - } - - logging.LogInstall(packagesList) - logging.LogSuccess("Linux packages installed successfully!") - return nil -} - -// InstallPythonPackages installs Python packages with potential UV support -func InstallPythonPackages(cfg *config.Config, osInfo *osdetect.OSInfo) error { - if cfg.DryRun { - if osInfo.OsType == "alpine" { - logging.LogInfo("[DRY-RUN] Install Alpine Python packages: %s", strings.Join(cfg.AlpinePythonPackages, ", ")) - } else { - pyList := cfg.PythonPackages - if os.Getenv("WSL") == "" { - pyList = append(pyList, cfg.NonWslPythonPackages...) - } - logging.LogInfo("[DRY-RUN] Install Python packages: %s", strings.Join(pyList, ", ")) - - if cfg.UseUvPackageManager { - logging.LogInfo("[DRY-RUN] Use UV package manager for Python package installation") - if len(cfg.PythonPipPackages) > 0 { - logging.LogInfo("[DRY-RUN] Install Python pip packages with UV: %s", strings.Join(cfg.PythonPipPackages, ", ")) - } - } else { - logging.LogInfo("[DRY-RUN] Use standard pip for Python package installation") - if len(cfg.PythonPipPackages) > 0 { - logging.LogInfo("[DRY-RUN] Install Python pip packages with pip: %s", strings.Join(cfg.PythonPipPackages, ", ")) - } - } - } - return nil - } - - if osInfo.OsType == "alpine" { - // Use Alpine's package manager for Python packages - if len(cfg.AlpinePythonPackages) > 0 { - logging.LogInfo("Installing Alpine Python packages...") - return InstallPackages(cfg.AlpinePythonPackages, osInfo, cfg) - } else { - logging.LogInfo("No Alpine Python packages defined in config") - } - } else { - // For Debian/Ubuntu systems - pyList := cfg.PythonPackages - if os.Getenv("WSL") == "" { - pyList = append(pyList, cfg.NonWslPythonPackages...) - } - - pythonPackagesList := strings.Join(pyList, ", ") - - // Install system packages first - cmd := exec.Command("apt-get", "update") - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to update package lists for Python installation: %w", err) - } - - cmd = exec.Command("apt-get", append([]string{"install", "--yes"}, pyList...)...) - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to install Python system packages [%s]: %w", pythonPackagesList, err) - } - - // If UV package manager is enabled, install and use it for Python packages - if cfg.UseUvPackageManager { - logging.LogInfo("UV package manager enabled for Python - installing if needed") - - // Check if UV is installed - _, err := exec.LookPath("uv") - if err != nil { - logging.LogInfo("Installing UV Python package manager...") - - // Check if pip3 is installed - _, err := exec.LookPath("pip3") - if err != nil { - logging.LogInfo("Installing pip3 first...") - pip3Cmd := exec.Command("apt-get", "install", "-y", "python3-pip") - if err := pip3Cmd.Run(); err != nil { - return fmt.Errorf("failed to install pip3 for UV installation: %w", err) - } - } - - // Install UV - uvCmd := exec.Command("pip3", "install", "uv") - if err := uvCmd.Run(); err != nil { - logging.LogError("Failed to install UV package manager, will use pip instead") - cfg.UseUvPackageManager = false - } else { - logging.LogInstall("UV package manager") - } - } else { - logging.LogInfo("UV package manager already installed") - } - - // Check if there are Python pip packages to install - if len(cfg.PythonPipPackages) > 0 { - pipPackagesList := strings.Join(cfg.PythonPipPackages, ", ") - logging.LogInfo("Installing Python pip packages with UV: %s", pipPackagesList) - uvPipCmd := exec.Command("uv", append([]string{"pip", "install"}, cfg.PythonPipPackages...)...) - if err := uvPipCmd.Run(); err != nil { - return fmt.Errorf("failed to install Python pip packages with UV [%s]: %w", pipPackagesList, err) - } - logging.LogInstall("Python pip packages via UV: %s", pipPackagesList) - } - } else { - // Use standard pip if UV is not enabled - if len(cfg.PythonPipPackages) > 0 { - pipPackagesList := strings.Join(cfg.PythonPipPackages, ", ") - logging.LogInfo("Installing Python pip packages with pip: %s", pipPackagesList) - - // Check if pip3 is installed - _, err := exec.LookPath("pip3") - if err != nil { - logging.LogInfo("Installing pip3 first...") - pip3Cmd := exec.Command("apt-get", "install", "-y", "python3-pip") - if err := pip3Cmd.Run(); err != nil { - return fmt.Errorf("failed to install pip3 for package installation: %w", err) - } - } - - // Install pip packages - pipCmd := exec.Command("pip3", append([]string{"install"}, cfg.PythonPipPackages...)...) - if err := pipCmd.Run(); err != nil { - return fmt.Errorf("failed to install Python pip packages with pip [%s]: %w", pipPackagesList, err) - } - logging.LogInstall("Python pip packages: %s", pipPackagesList) - } - } - } - - logging.LogSuccess("Python packages installation completed") - return nil -} \ No newline at end of file diff --git a/pkg/port/secondary/backup_repository.go b/pkg/port/secondary/backup_repository.go new file mode 100644 index 0000000..1c42be8 --- /dev/null +++ b/pkg/port/secondary/backup_repository.go @@ -0,0 +1,31 @@ +package secondary + +import ( + "time" + + "github.com/abbott/hardn/pkg/domain/model" +) + +// BackupRepository defines the interface for backup operations +type BackupRepository interface { + // BackupFile backs up a file with a timestamp + BackupFile(filePath string) error + + // ListBackups returns a list of all backups for a specific file + ListBackups(filePath string) ([]model.BackupFile, error) + + // RestoreBackup restores a file from backup + RestoreBackup(backupPath, originalPath string) error + + // CleanupOldBackups removes backups older than specified date + CleanupOldBackups(before time.Time) error + + // VerifyBackupDirectory ensures the backup directory exists and is writable + VerifyBackupDirectory() error + + // GetBackupConfig retrieves the current backup configuration + GetBackupConfig() (*model.BackupConfig, error) + + // SetBackupConfig updates the backup configuration + SetBackupConfig(config model.BackupConfig) error +} diff --git a/pkg/port/secondary/dns_repository.go b/pkg/port/secondary/dns_repository.go new file mode 100644 index 0000000..7ab2520 --- /dev/null +++ b/pkg/port/secondary/dns_repository.go @@ -0,0 +1,13 @@ +// pkg/port/secondary/dns_repository.go +package secondary + +import "github.com/abbott/hardn/pkg/domain/model" + +// DNSRepository defines the interface for DNS configuration operations +type DNSRepository interface { + // SaveDNSConfig persists the DNS configuration + SaveDNSConfig(config model.DNSConfig) error + + // GetDNSConfig retrieves the current DNS configuration + GetDNSConfig() (*model.DNSConfig, error) +} diff --git a/pkg/port/secondary/environment_repository.go b/pkg/port/secondary/environment_repository.go new file mode 100644 index 0000000..6882b52 --- /dev/null +++ b/pkg/port/secondary/environment_repository.go @@ -0,0 +1,16 @@ +// pkg/port/secondary/environment_repository.go +package secondary + +import "github.com/abbott/hardn/pkg/domain/model" + +// EnvironmentRepository defines the interface for environment configuration operations +type EnvironmentRepository interface { + // SetupSudoPreservation configures sudo to preserve the HARDN_CONFIG environment variable + SetupSudoPreservation(username string) error + + // IsSudoPreservationEnabled checks if the HARDN_CONFIG environment variable is preserved in sudo + IsSudoPreservationEnabled(username string) (bool, error) + + // GetEnvironmentVariables retrieves the current environment configuration + GetEnvironmentConfig() (*model.EnvironmentConfig, error) +} diff --git a/pkg/port/secondary/firewall_repository.go b/pkg/port/secondary/firewall_repository.go new file mode 100644 index 0000000..71a11e2 --- /dev/null +++ b/pkg/port/secondary/firewall_repository.go @@ -0,0 +1,32 @@ +// pkg/port/secondary/firewall_repository.go +package secondary + +import "github.com/abbott/hardn/pkg/domain/model" + +// FirewallRepository defines the interface for firewall configuration operations +type FirewallRepository interface { + + // GetFirewallStatus retrieves the current status of the firewall + GetFirewallStatus() (isInstalled bool, isEnabled bool, isConfigured bool, rules []string, err error) + + // SaveFirewallConfig persists the firewall configuration + SaveFirewallConfig(config model.FirewallConfig) error + + // GetFirewallConfig retrieves the current firewall configuration + GetFirewallConfig() (*model.FirewallConfig, error) + + // AddRule adds a firewall rule + AddRule(rule model.FirewallRule) error + + // RemoveRule removes a firewall rule + RemoveRule(rule model.FirewallRule) error + + // AddProfile adds a firewall application profile + AddProfile(profile model.FirewallProfile) error + + // EnableFirewall enables the firewall + EnableFirewall() error + + // DisableFirewall disables the firewall + DisableFirewall() error +} diff --git a/pkg/port/secondary/logs_repository.go b/pkg/port/secondary/logs_repository.go new file mode 100644 index 0000000..ff68af3 --- /dev/null +++ b/pkg/port/secondary/logs_repository.go @@ -0,0 +1,16 @@ +// pkg/port/secondary/logs_repository.go +package secondary + +import "github.com/abbott/hardn/pkg/domain/model" + +// LogsRepository defines the interface for log operations +type LogsRepository interface { + // GetLogs retrieves logs from the configured log file + GetLogs() ([]model.LogEntry, error) + + // GetLogConfig retrieves the current log configuration + GetLogConfig() (*model.LogsConfig, error) + + // PrintLogs prints the logs to the console + PrintLogs() error +} diff --git a/pkg/port/secondary/package_repository.go b/pkg/port/secondary/package_repository.go new file mode 100644 index 0000000..7d184f9 --- /dev/null +++ b/pkg/port/secondary/package_repository.go @@ -0,0 +1,22 @@ +// pkg/port/secondary/package_repository.go +package secondary + +import "github.com/abbott/hardn/pkg/domain/model" + +// PackageRepository defines the interface for package management operations +type PackageRepository interface { + // InstallPackages installs packages based on the request + InstallPackages(request model.PackageInstallRequest) error + + // UpdatePackageSources updates package repository sources + UpdatePackageSources(sources model.PackageSources) error + + // UpdateProxmoxSources updates Proxmox-specific package sources + UpdateProxmoxSources(sources model.PackageSources) error + + // IsPackageInstalled checks if a package is installed + IsPackageInstalled(packageName string) (bool, error) + + // GetPackageSources retrieves the current package sources configuration + GetPackageSources() (*model.PackageSources, error) +} diff --git a/pkg/port/secondary/ssh_repository.go b/pkg/port/secondary/ssh_repository.go new file mode 100644 index 0000000..acc15fb --- /dev/null +++ b/pkg/port/secondary/ssh_repository.go @@ -0,0 +1,19 @@ +// pkg/port/secondary/ssh_repository.go +package secondary + +import "github.com/abbott/hardn/pkg/domain/model" + +// SSHRepository defines the interface for SSH configuration operations +type SSHRepository interface { + // SaveSSHConfig persists the SSH configuration + SaveSSHConfig(config model.SSHConfig) error + + // GetSSHConfig retrieves the current SSH configuration + GetSSHConfig() (*model.SSHConfig, error) + + // DisableRootAccess disables SSH access for the root user + DisableRootAccess() error + + // AddAuthorizedKey adds an SSH public key to a user's authorized_keys + AddAuthorizedKey(username string, publicKey string) error +} diff --git a/pkg/port/secondary/user_repository.go b/pkg/port/secondary/user_repository.go new file mode 100644 index 0000000..1e4d5ca --- /dev/null +++ b/pkg/port/secondary/user_repository.go @@ -0,0 +1,12 @@ +package secondary + +import "github.com/abbott/hardn/pkg/domain/model" + +// UserRepository defines the interface for user persistence operations +type UserRepository interface { + CreateUser(user model.User) error + GetUser(username string) (*model.User, error) + AddSSHKey(username, publicKey string) error + ConfigureSudo(username string, noPassword bool) error + UserExists(username string) (bool, error) +} diff --git a/pkg/security/security.go b/pkg/security/security.go index 5c2f8f6..3022630 100644 --- a/pkg/security/security.go +++ b/pkg/security/security.go @@ -139,4 +139,4 @@ func SetupLynis(cfg *config.Config, osInfo *osdetect.OSInfo) error { logging.LogSuccess("Lynis installed and system audit completed") return nil -} \ No newline at end of file +} diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go deleted file mode 100644 index 8a4a85f..0000000 --- a/pkg/ssh/ssh.go +++ /dev/null @@ -1,300 +0,0 @@ -package ssh - -import ( - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - - "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/logging" - "github.com/abbott/hardn/pkg/osdetect" - "github.com/abbott/hardn/pkg/utils" -) - -// WriteSSHConfig writes the SSH server configuration based on OS type -func WriteSSHConfig(cfg *config.Config, osInfo *osdetect.OSInfo) error { - if cfg.DryRun { - logging.LogInfo("[DRY-RUN] Configure SSH server with the following settings:") - logging.LogInfo("[DRY-RUN] - Protocol: 2") - logging.LogInfo("[DRY-RUN] - Port: %d", cfg.SshPort) - logging.LogInfo("[DRY-RUN] - Listen Address: %s", cfg.SshListenAddress) - logging.LogInfo("[DRY-RUN] - Authentication Method: publickey") - logging.LogInfo("[DRY-RUN] - PermitRootLogin: %t", cfg.PermitRootLogin) - logging.LogInfo("[DRY-RUN] - Allowed Users: %s", strings.Join(cfg.SshAllowedUsers, ", ")) - logging.LogInfo("[DRY-RUN] - Password Authentication: no") - logging.LogInfo("[DRY-RUN] - AuthorizedKeysFile: .ssh/authorized_keys %s/authorized_keys", cfg.SshKeyPath) - - if osInfo.OsType == "alpine" { - logging.LogInfo("[DRY-RUN] - Write config to /etc/ssh/sshd_config") - logging.LogInfo("[DRY-RUN] - Restart sshd service using OpenRC") - } else { - logging.LogInfo("[DRY-RUN] - Configure systemd socket at /etc/systemd/system/ssh.socket.d/listen.conf") - logging.LogInfo("[DRY-RUN] - Write config to %s", cfg.SshConfigFile) - logging.LogInfo("[DRY-RUN] - Restart ssh service using systemd") - } - return nil - } - - logging.LogInfo("Configuring SSH...") - - // Format SSH listen address and port - sshListenAddress := cfg.SshListenAddress - if !strings.Contains(sshListenAddress, ":") { - sshListenAddress = fmt.Sprintf("%s:%d", sshListenAddress, cfg.SshPort) - } - - if osInfo.OsType == "alpine" { - // Alpine uses /etc/ssh/sshd_config directly - // Backup original config - utils.BackupFile("/etc/ssh/sshd_config", cfg) - - // Determine root login setting - permitRootLogin := "no" - if cfg.PermitRootLogin { - permitRootLogin = "yes" - } - - // Create new config - configContent := fmt.Sprintf(`Protocol 2 -StrictModes yes - -Port %d -ListenAddress %s - -AuthenticationMethods publickey -PubkeyAuthentication yes - -HostbasedAcceptedKeyTypes ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,ssh-ed25519 - -PermitRootLogin %s -AllowUsers %s - -PasswordAuthentication no -PermitEmptyPasswords no - -AuthorizedKeysFile .ssh/authorized_keys %s/authorized_keys -`, cfg.SshPort, sshListenAddress, permitRootLogin, strings.Join(cfg.SshAllowedUsers, " "), cfg.SshKeyPath) - - // Write the file - if err := os.WriteFile("/etc/ssh/sshd_config", []byte(configContent), 0644); err != nil { - return fmt.Errorf("failed to write Alpine SSH config for port %d: %w", cfg.SshPort, err) - } - - // Restart SSH using OpenRC - cmd := exec.Command("rc-service", "sshd", "restart") - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to restart Alpine SSH service for port %d: %w", cfg.SshPort, err) - } - - logging.LogSuccess("SSH configured for Alpine Linux") - } else { - // Debian/Ubuntu with systemd - utils.BackupFile("/etc/systemd/system/ssh.socket.d/listen.conf", cfg) - - // Create socket config directory - if err := os.MkdirAll("/etc/systemd/system/ssh.socket.d", 0755); err != nil { - return fmt.Errorf("failed to create SSH socket directory: %w", err) - } - - // Write socket config - socketConfig := fmt.Sprintf(`[Socket] -ListenStream= -ListenStream=%d -`, cfg.SshPort) - - if err := os.WriteFile("/etc/systemd/system/ssh.socket.d/listen.conf", []byte(socketConfig), 0644); err != nil { - logging.LogError("Failed to set ssh port listener for port %d.", cfg.SshPort) - return fmt.Errorf("failed to write SSH socket config for port %d: %w", cfg.SshPort, err) - } - - // Ensure config directory exists - if err := os.MkdirAll(filepath.Dir(cfg.SshConfigFile), 0755); err != nil { - return fmt.Errorf("failed to create SSH config directory %s: %w", filepath.Dir(cfg.SshConfigFile), err) - } - - // Determine root login setting - permitRootLogin := "no" - if cfg.PermitRootLogin { - permitRootLogin = "yes" - } - - // Set SSH config - utils.BackupFile(cfg.SshConfigFile, cfg) - - configContent := fmt.Sprintf(`### Reference -### https://cryptsus.com/blog/how-to-secure-your-ssh-server-with-public-key-elliptic-curve-ed25519-crypto.html - -Protocol 2 -StrictModes yes - -ListenAddress %s - -AuthenticationMethods publickey -PubkeyAuthentication yes - -HostbasedAcceptedKeyTypes ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,ssh-ed25519 -#PubkeyAcceptedKeyTypes sk-ecdsa-sha2-nistp256@openssh.com,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,sk-ssh-ed25519@openssh.com - -PermitRootLogin %s -AllowUsers %s - -# To disable tunneled clear text passwords, change to no here! -PasswordAuthentication no -PermitEmptyPasswords no - -#AuthorizedKeysFile /etc/ssh/authorized_keys -# mkdir custom SSH path (e.g., /home/$USERNAME/$SSH_KEY_PATH) -AuthorizedKeysFile .ssh/authorized_keys %s/authorized_keys - -### PVE ONLY: DO NOT DISABLE -#X11Forwarding yes -#AuthorizedKeysFile /etc/pve/priv/authorized_keys -`, sshListenAddress, permitRootLogin, strings.Join(cfg.SshAllowedUsers, " "), cfg.SshKeyPath) - - if err := os.WriteFile(cfg.SshConfigFile, []byte(configContent), 0644); err != nil { - logging.LogError("Failed to create %s", cfg.SshConfigFile) - return fmt.Errorf("failed to write SSH config to %s: %w", cfg.SshConfigFile, err) - } - - // Restart SSH - cmd := exec.Command("systemctl", "restart", "ssh") - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to restart SSH service on port %d: %w", cfg.SshPort, err) - } - - logging.LogSuccess("SSH configured for Debian/Ubuntu") - } - - return nil -} - -// DisableRootSSHAccess disables root SSH access -func DisableRootSSHAccess(cfg *config.Config, osInfo *osdetect.OSInfo) error { - if cfg.DryRun { - logging.LogInfo("[DRY-RUN] Disable root SSH access with the following changes:") - if osInfo.OsType == "alpine" { - logging.LogInfo("[DRY-RUN] - Modify /etc/ssh/sshd_config to set 'PermitRootLogin no'") - logging.LogInfo("[DRY-RUN] - Remove 'root' from AllowUsers directive") - logging.LogInfo("[DRY-RUN] - Restart sshd service using OpenRC") - } else { - logging.LogInfo("[DRY-RUN] - Modify %s to set 'PermitRootLogin no'", cfg.SshConfigFile) - logging.LogInfo("[DRY-RUN] - Remove 'root' from AllowUsers directive") - logging.LogInfo("[DRY-RUN] - Restart ssh service using systemd") - } - return nil - } - - if osInfo.OsType == "alpine" { - // For Alpine, modify the main sshd_config file - configFile := "/etc/ssh/sshd_config" - if _, err := os.Stat(configFile); os.IsNotExist(err) { - return fmt.Errorf("/etc/ssh/sshd_config not found: %w", err) - } - - utils.BackupFile(configFile, cfg) - - // Read the file - content, err := os.ReadFile(configFile) - if err != nil { - return fmt.Errorf("failed to read Alpine SSH config: %w", err) - } - - // Modify the content - lines := strings.Split(string(content), "\n") - for i, line := range lines { - // Change PermitRootLogin - if strings.HasPrefix(line, "PermitRootLogin yes") { - lines[i] = "PermitRootLogin no" - } - - // Remove 'root' from AllowUsers - if strings.HasPrefix(line, "AllowUsers") { - // Get the users - fields := strings.Fields(line) - if len(fields) > 1 { - // Remove 'root' - var newUsers []string - for _, user := range fields[1:] { - if user != "root" { - newUsers = append(newUsers, user) - } - } - - // Put back together - lines[i] = "AllowUsers " + strings.Join(newUsers, " ") - } - } - } - - // Write back the file - if err := os.WriteFile(configFile, []byte(strings.Join(lines, "\n")), 0644); err != nil { - return fmt.Errorf("failed to write updated Alpine SSH config: %w", err) - } - - // Restart SSH - cmd := exec.Command("rc-service", "sshd", "restart") - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to restart Alpine SSH service after disabling root login: %w", err) - } - - logging.LogSuccess("Root SSH access disabled in Alpine Linux") - } else { - // For Debian/Ubuntu - configFile := cfg.SshConfigFile - if _, err := os.Stat(configFile); os.IsNotExist(err) { - return fmt.Errorf("SSH config file %s not found: %w", configFile, err) - } - - utils.BackupFile(configFile, cfg) - - // Read the file - content, err := os.ReadFile(configFile) - if err != nil { - return fmt.Errorf("failed to read SSH config %s: %w", configFile, err) - } - - // Modify the content - lines := strings.Split(string(content), "\n") - for i, line := range lines { - // Change PermitRootLogin - if strings.HasPrefix(line, "PermitRootLogin yes") { - lines[i] = "PermitRootLogin no" - } - - // Remove 'root' from AllowUsers - if strings.HasPrefix(line, "AllowUsers") { - // Get the users - fields := strings.Fields(line) - if len(fields) > 1 { - // Remove 'root' - var newUsers []string - for _, user := range fields[1:] { - if user != "root" { - newUsers = append(newUsers, user) - } - } - - // Put back together - lines[i] = "AllowUsers " + strings.Join(newUsers, " ") - } - } - } - - // Write back the file - if err := os.WriteFile(configFile, []byte(strings.Join(lines, "\n")), 0644); err != nil { - return fmt.Errorf("failed to write updated SSH config to %s: %w", configFile, err) - } - - // Restart SSH - cmd := exec.Command("systemctl", "restart", "ssh") - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to restart SSH service after disabling root login: %w", err) - } - - logging.LogSuccess("Root SSH access disabled in Debian/Ubuntu") - } - - return nil -} \ No newline at end of file diff --git a/pkg/status/status.go b/pkg/status/status.go index f062c34..c16988e 100644 --- a/pkg/status/status.go +++ b/pkg/status/status.go @@ -91,7 +91,7 @@ func DisplaySecurityStatus(cfg *config.Config, status *SecurityStatus, formatter if !status.SecureUsers { fmt.Println(formatter.FormatWarning("Users", "Root user only", "create non-root user")) } else { - fmt.Println(formatter.FormatSuccess("Users", "Non-root user found", "sudo privileges")) + fmt.Println(formatter.FormatSuccess("Users", "Non-root user found", "sudo enabled")) } // Display SSH port status @@ -186,8 +186,8 @@ func checkRootLoginEnabled(osInfo *osdetect.OSInfo) bool { } else { // For Debian/Ubuntu, check both main config and config.d sshConfigPath = "/etc/ssh/sshd_config" - if _, err := os.Stat("/etc/ssh/sshd_config.d/manage.conf"); err == nil { - sshConfigPath = "/etc/ssh/sshd_config.d/manage.conf" + if _, err := os.Stat("/etc/ssh/sshd_config.d/hardn.conf"); err == nil { + sshConfigPath = "/etc/ssh/sshd_config.d/hardn.conf" } } @@ -377,8 +377,8 @@ func checkPasswordAuth(osInfo *osdetect.OSInfo) bool { } else { // For Debian/Ubuntu, check both main config and config.d sshConfigPath = "/etc/ssh/sshd_config" - if _, err := os.Stat("/etc/ssh/sshd_config.d/manage.conf"); err == nil { - sshConfigPath = "/etc/ssh/sshd_config.d/manage.conf" + if _, err := os.Stat("/etc/ssh/sshd_config.d/hardn.conf"); err == nil { + sshConfigPath = "/etc/ssh/sshd_config.d/hardn.conf" } } @@ -410,8 +410,8 @@ func CheckRootLoginEnabled(osInfo *osdetect.OSInfo) bool { } else { // For Debian/Ubuntu, check both main config and config.d sshConfigPath = "/etc/ssh/sshd_config" - if _, err := os.Stat("/etc/ssh/sshd_config.d/manage.conf"); err == nil { - sshConfigPath = "/etc/ssh/sshd_config.d/manage.conf" + if _, err := os.Stat("/etc/ssh/sshd_config.d/hardn.conf"); err == nil { + sshConfigPath = "/etc/ssh/sshd_config.d/hardn.conf" } } @@ -434,4 +434,4 @@ func CheckRootLoginEnabled(osInfo *osdetect.OSInfo) bool { } return true // Default to vulnerable if not explicitly set -} \ No newline at end of file +} diff --git a/pkg/style/style.go b/pkg/style/style.go index 03d01d4..1b3e7f9 100644 --- a/pkg/style/style.go +++ b/pkg/style/style.go @@ -14,16 +14,17 @@ type MenuOption struct { Number int Title string Description string + Style string } // Menu provides a formatted menu display type Menu struct { - title string - options []MenuOption - exitOption *MenuOption - prompt string - maxNumLen int - titleWidth int + title string + options []MenuOption + exitOption *MenuOption + prompt string + maxNumLen int + titleWidth int } const ( @@ -83,6 +84,7 @@ const ( Blink = "\033[5m" Reverse = "\033[7m" Hidden = "\033[8m" + Strike = "\033[9m" // Cursor control CursorOn = "\033[?25h" @@ -140,6 +142,13 @@ func Dimmed(text string, color ...string) string { return Dim + text + Reset } +func Striked(text string, color ...string) string { + if len(color) > 0 { + return Strike + color[0] + text + Reset + } + return Strike + text + Reset +} + // Apply italic style with an optional color func Italicized(text string, color ...string) string { if len(color) > 0 { @@ -222,14 +231,30 @@ func CenterText(text string, width int) string { return strings.Repeat(" ", leftPadding) + text + strings.Repeat(" ", rightPadding) } +// PadRight adds spaces to the right of text to reach the specified width +// Uses StripAnsi to correctly calculate visible text length for styled text func PadRight(text string, width int) string { - if len(text) >= width { + // Get the visible length by removing ANSI escape sequences + visibleLen := len(StripAnsi(text)) + + if visibleLen >= width { return text } - return text + strings.Repeat(" ", width-len(text)) + // Calculate the correct amount of padding based on visible length + padding := width - visibleLen + + return text + strings.Repeat(" ", padding) } +// func PadRight(text string, width int) string { +// if len(text) >= width { +// return text +// } + +// return text + strings.Repeat(" ", width-len(text)) +// } + var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*[a-zA-Z]`) // StripAnsi removes ANSI escape codes from a string to get its true display length @@ -333,7 +358,6 @@ func (sf *StatusFormatter) Initialize() { sf.initialized = true } -// FormatLine formats a status line with proper alignment // FormatLine formats a status line with proper alignment func (sf *StatusFormatter) FormatLine(symbol string, symbolColor string, label string, status string, statusColor string, description string, statusWeight string) string { @@ -344,16 +368,16 @@ func (sf *StatusFormatter) FormatLine(symbol string, symbolColor string, // Calculate padding needed for label (strip ANSI codes for accuracy) labelText := StripAnsi(label) - + // Fix: Ensure padding size is never negative paddingSize := sf.maxLabelLen - len(labelText) if paddingSize < 0 { paddingSize = 0 // Prevent negative repeat count } - + // Always add at least one space padding between label and status - padding := strings.Repeat(" ", paddingSize + 1) - + padding := strings.Repeat(" ", paddingSize+1) + symbol = Colored(symbolColor, symbol) if statusWeight == "bold" { @@ -419,21 +443,21 @@ func PrintDivider(char string, length int, style ...string) { // NewMenu creates a new menu with the given title and options func NewMenu(title string, options []MenuOption) *Menu { // Calculate maximum number length and title width - maxNumLen := 1 // At least 1 digit + maxNumLen := 1 // At least 1 digit titleWidth := 20 // Minimum width - + for _, opt := range options { numLen := len(fmt.Sprintf("%d", opt.Number)) if numLen > maxNumLen { maxNumLen = numLen } - + titleLen := len(StripAnsi(opt.Title)) if titleLen > titleWidth { titleWidth = titleLen } } - + return &Menu{ title: title, options: options, @@ -446,13 +470,13 @@ func NewMenu(title string, options []MenuOption) *Menu { // SetExitOption sets a custom exit option (default is 0: Exit) func (m *Menu) SetExitOption(option MenuOption) { m.exitOption = &option - + // Update maxNumLen if necessary numLen := len(fmt.Sprintf("%d", option.Number)) if numLen > m.maxNumLen { m.maxNumLen = numLen } - + // Update titleWidth if necessary titleLen := len(StripAnsi(option.Title)) if titleLen > m.titleWidth { @@ -470,10 +494,10 @@ func (m *Menu) GetValidRange() string { if len(m.options) == 0 { return "0" } - + min := m.options[0].Number max := m.options[0].Number - + for _, opt := range m.options { if opt.Number < min { min = opt.Number @@ -482,25 +506,25 @@ func (m *Menu) GetValidRange() string { max = opt.Number } } - + // Include exit option in the range exitNum := 0 if m.exitOption != nil { exitNum = m.exitOption.Number } - + if exitNum < min { min = exitNum } - + if exitNum > max { max = exitNum } - + if min == max { return fmt.Sprintf("%d", min) } - + return fmt.Sprintf("%d-%d", min, max) } @@ -508,40 +532,56 @@ func (m *Menu) GetValidRange() string { func (m *Menu) FormatOption(opt MenuOption) string { // Format number with consistent padding numStr := fmt.Sprintf("%d)", opt.Number) - + // Add extra space for single-digit numbers to align with double-digit numbers if opt.Number < 10 { numStr = " " + numStr } - + numPadded := Bolded(numStr) - + // Add spacing after the number numPadded += " " - + + titlePadded := "" // Format title with consistent padding - titlePadded := PadRight(opt.Title, m.titleWidth + 4) // +4 for extra spacing - + if opt.Style == "" { + // opt.Title = Colored(opt.Style, opt.Title) + titlePadded += PadRight(opt.Title, m.titleWidth+4) + } else if opt.Style == "strike" { + // opt.Title = Bolded(opt.Title) + strikeTitle := Striked(opt.Title) + dimmedStrikeTitle := Dimmed(strikeTitle) + titlePadded += PadRight(dimmedStrikeTitle, m.titleWidth+4) + } + // Add description desc := Dimmed(opt.Description) - + return numPadded + titlePadded + desc } // Render returns the formatted menu as a string func (m *Menu) Render() string { var sb strings.Builder - + + // desc := Dimmed(opt.Description) + + // if m.title != "" { + // sb.WriteString("\n") + // sb.WriteString(Header(m.title)) + // } + // Title header sb.WriteString("\n") sb.WriteString(SubHeader(m.title)) - + // Options for _, opt := range m.options { sb.WriteString("\n") sb.WriteString(m.FormatOption(opt)) } - + // Exit option sb.WriteString("\n\n") if m.exitOption != nil { @@ -555,16 +595,16 @@ func (m *Menu) Render() string { } sb.WriteString(m.FormatOption(exit)) } - + // Prompt sb.WriteString("\n\n") sb.WriteString(BulletItem) sb.WriteString(fmt.Sprintf("%s [%s or q]: ", m.prompt, m.GetValidRange())) - + return sb.String() } // Print displays the menu on stdout func (m *Menu) Print() { fmt.Print(m.Render()) -} \ No newline at end of file +} diff --git a/pkg/testing/compare.go b/pkg/testing/compare.go new file mode 100644 index 0000000..07e86da --- /dev/null +++ b/pkg/testing/compare.go @@ -0,0 +1,151 @@ +// pkg/testing/compare.go +package testing + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + + "github.com/abbott/hardn/pkg/config" + "github.com/abbott/hardn/pkg/infrastructure" + "github.com/abbott/hardn/pkg/interfaces" + "github.com/abbott/hardn/pkg/osdetect" +) + +// ComparisonResult holds the result of a comparison test +type ComparisonResult struct { + Operation string + OutputsMatch bool + OldOutput interface{} + NewOutput interface{} + Errors []string +} + +// ComparisonTester runs comparison tests between old and new implementations +type ComparisonTester struct { + cfg *config.Config + osInfo *osdetect.OSInfo + mockProvider *interfaces.Provider + serviceFactory *infrastructure.ServiceFactory + tempDir string +} + +// NewComparisonTester creates a new ComparisonTester +func NewComparisonTester(cfg *config.Config, osInfo *osdetect.OSInfo) (*ComparisonTester, error) { + // Create temp directory for test files + tempDir, err := os.MkdirTemp("", "hardn-compare-*") + if err != nil { + return nil, fmt.Errorf("failed to create temp directory: %w", err) + } + + // Create mock provider + mockProvider := interfaces.NewProvider() + + // Create service factory + serviceFactory := infrastructure.NewServiceFactory(mockProvider, osInfo) + + return &ComparisonTester{ + cfg: cfg, + osInfo: osInfo, + mockProvider: mockProvider, + serviceFactory: serviceFactory, + tempDir: tempDir, + }, nil +} + +// Cleanup removes temporary files +func (c *ComparisonTester) Cleanup() { + os.RemoveAll(c.tempDir) +} + +// CompareSSHRootDisable compares old and new implementations of disabling root SSH access +func (c *ComparisonTester) CompareSSHRootDisable() ComparisonResult { + result := ComparisonResult{ + Operation: "DisableRootSSH", + } + + // Set up test config paths + oldConfigPath := filepath.Join(c.tempDir, "old_ssh_config") + newConfigPath := filepath.Join(c.tempDir, "new_ssh_config") + + // Create test config content + testConfig := "PermitRootLogin yes\nAllowUsers root user1" + + // Write test configs + if err := os.WriteFile(oldConfigPath, []byte(testConfig), 0644); err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("Failed to write old config: %v", err)) + return result + } + if err := os.WriteFile(newConfigPath, []byte(testConfig), 0644); err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("Failed to write new config: %v", err)) + return result + } + + // Run old implementation + // TODO: Call old implementation via the proper interfaces + + // Run new implementation + sshManager := c.serviceFactory.CreateSSHManager() + if err := sshManager.DisableRootAccess(); err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("New implementation error: %v", err)) + } + + // Read results + oldResult, err := os.ReadFile(oldConfigPath) + if err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("Failed to read old result: %v", err)) + return result + } + newResult, err := os.ReadFile(newConfigPath) + if err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("Failed to read new result: %v", err)) + return result + } + + // Compare results + result.OldOutput = string(oldResult) + result.NewOutput = string(newResult) + result.OutputsMatch = reflect.DeepEqual(oldResult, newResult) + + return result +} + +// RunAllComparisons runs all comparison tests +func (c *ComparisonTester) RunAllComparisons() []ComparisonResult { + var results []ComparisonResult + + // Run comparisons for different operations + results = append(results, c.CompareSSHRootDisable()) + // Add more comparisons as needed + + return results +} + +// PrintResults prints the comparison results +func (c *ComparisonTester) PrintResults(results []ComparisonResult) { + fmt.Println("Comparison Test Results") + fmt.Println("======================") + + for _, result := range results { + fmt.Printf("Operation: %s\n", result.Operation) + if result.OutputsMatch { + fmt.Println(" ✅ Outputs match") + } else { + fmt.Println(" ❌ Outputs differ:") + fmt.Println(" Old output:") + fmt.Println(result.OldOutput) + fmt.Println(" New output:") + fmt.Println(result.NewOutput) + } + + if len(result.Errors) > 0 { + fmt.Println(" Errors:") + for _, err := range result.Errors { + fmt.Printf(" - %s\n", err) + } + } + + fmt.Println() + } +} diff --git a/pkg/updates/updates.go b/pkg/updates/updates.go index cca3a6c..dcd4a09 100644 --- a/pkg/updates/updates.go +++ b/pkg/updates/updates.go @@ -126,4 +126,4 @@ func UpdateSystem(osInfo *osdetect.OSInfo) error { logging.LogSuccess("System packages updated successfully") return nil -} \ No newline at end of file +} diff --git a/pkg/user/user.go b/pkg/user/user.go deleted file mode 100644 index e98c419..0000000 --- a/pkg/user/user.go +++ /dev/null @@ -1,265 +0,0 @@ -package user - -import ( - "fmt" - "os" - "os/exec" - "os/user" - "path/filepath" - "strings" - - "github.com/abbott/hardn/pkg/config" - "github.com/abbott/hardn/pkg/logging" - "github.com/abbott/hardn/pkg/osdetect" - "github.com/abbott/hardn/pkg/utils" -) - -// CreateUser creates a new system user with SSH keys and sudo access -func CreateUser(username string, cfg *config.Config, osInfo *osdetect.OSInfo) error { - // Check if user already exists - _, err := user.Lookup(username) - if err == nil { - logging.LogInfo("User %s already exists. Skipping user creation.", username) - return nil - } - - logging.LogInfo("Creating user %s...", username) - - if cfg.DryRun { - logging.LogInfo("[DRY-RUN] Create user: %s", username) - logging.LogInfo("[DRY-RUN] Add user to sudo/wheel group") - logging.LogInfo("[DRY-RUN] Configure sudo with NOPASSWD: %t", cfg.SudoNoPassword) - logging.LogInfo("[DRY-RUN] Set up SSH keys in: %s", cfg.SshKeyPath) - return nil - } - - // Check if sudo is installed, install it if necessary - _, err = exec.LookPath("sudo") - if err != nil { - if osInfo.OsType == "alpine" { - cmd := exec.Command("apk", "add", "sudo") - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to install sudo on Alpine for user %s: %w", username, err) - } - logging.LogInstall("sudo") - } else { - cmd := exec.Command("apt-get", "update") - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to update package indexes for sudo installation: %w", err) - } - cmd = exec.Command("apt-get", "install", "-y", "sudo") - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to install sudo on Debian/Ubuntu: %w", err) - } - logging.LogInstall("sudo") - } - } - - if osInfo.OsType == "alpine" { - // Alpine user creation - cmd := exec.Command("adduser", "-D", "-g", "", username) - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to create user %s on Alpine: %w", username, err) - } - - // wheel group (sudo group for Alpine) - addGroupCmd := exec.Command("addgroup", username, "wheel") - if err := addGroupCmd.Run(); err != nil { - logging.LogError("Failed to add %s to wheel group: %v", username, err) - } else { - logging.LogInfo("Added %s to wheel group", username) - } - - // Configure sudo - sudoersDir := "/etc/sudoers.d" - if err := os.MkdirAll(sudoersDir, 0755); err != nil { - return fmt.Errorf("failed to create sudoers.d directory for user %s: %w", username, err) - } - - sudoersFile := filepath.Join(sudoersDir, username) - utils.BackupFile(sudoersFile, cfg) - - var sudoersContent string - if cfg.SudoNoPassword { - sudoersContent = fmt.Sprintf("%s ALL=(ALL) NOPASSWD: ALL\n", username) - } else { - sudoersContent = fmt.Sprintf("%s ALL=(ALL) ALL\n", username) - } - - if err := os.WriteFile(sudoersFile, []byte(sudoersContent), 0440); err != nil { - return fmt.Errorf("failed to write sudoers file for user %s: %w", username, err) - } - - // Extract the actual directory name from the SSH_KEY_PATH pattern - sshDir := strings.ReplaceAll(cfg.SshKeyPath, "%u", username) - userHomeDir := fmt.Sprintf("/home/%s", username) - sshDirPath := filepath.Join(userHomeDir, sshDir) - - // Create SSH key directory - if err := os.MkdirAll(sshDirPath, 0700); err != nil { - return fmt.Errorf("failed to create SSH directory %s for user %s: %w", sshDirPath, username, err) - } - - // SSH keys - authorizedKeysPath := filepath.Join(sshDirPath, "authorized_keys") - authorizedKeysContent := strings.Join(cfg.SshKeys, "\n") + "\n" - if err := os.WriteFile(authorizedKeysPath, []byte(authorizedKeysContent), 0600); err != nil { - return fmt.Errorf("failed to write authorized_keys for user %s: %w", username, err) - } - - // Set permissions - chownCmd := exec.Command("chown", "-R", fmt.Sprintf("%s:%s", username, username), sshDirPath) - if err := chownCmd.Run(); err != nil { - logging.LogError("Failed to set ownership for SSH directory for user %s: %v", username, err) - } - - // .hushlogin - hushLoginPath := filepath.Join(userHomeDir, ".hushlogin") - hushLoginFile, err := os.Create(hushLoginPath) - if err != nil { - logging.LogError("Failed to create .hushlogin file for user %s: %v", username, err) - } else { - hushLoginFile.Close() - chownHushCmd := exec.Command("chown", fmt.Sprintf("%s:%s", username, username), hushLoginPath) - if err := chownHushCmd.Run(); err != nil { - logging.LogError("Failed to set ownership for .hushlogin for user %s: %v", username, err) - } - } - } else { - // Debian/Ubuntu user creation - cmd := exec.Command("adduser", "--disabled-password", "--gecos", "", username) - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to create user %s on Debian/Ubuntu: %w", username, err) - } - - // sudo group - addGroupCmd := exec.Command("usermod", "-aG", "sudo", username) - if err := addGroupCmd.Run(); err != nil { - logging.LogError("Failed to add %s to sudo group: %v", username, err) - } else { - logging.LogInfo("Added %s to sudo group", username) - } - - // Configure sudo - sudoersDir := "/etc/sudoers.d" - if err := os.MkdirAll(sudoersDir, 0755); err != nil { - return fmt.Errorf("failed to create sudoers.d directory for user %s: %w", username, err) - } - - sudoersFile := filepath.Join(sudoersDir, username) - utils.BackupFile(sudoersFile, cfg) - - var sudoersContent string - if cfg.SudoNoPassword { - sudoersContent = fmt.Sprintf("%s ALL=(ALL) NOPASSWD: ALL\n", username) - } else { - sudoersContent = fmt.Sprintf("%s ALL=(ALL) ALL\n", username) - } - - if err := os.WriteFile(sudoersFile, []byte(sudoersContent), 0440); err != nil { - return fmt.Errorf("failed to write sudoers file for user %s: %w", username, err) - } - - // Extract the actual directory name from the SSH_KEY_PATH pattern - sshDir := strings.ReplaceAll(cfg.SshKeyPath, "%u", username) - - // Run commands as the new user to set up SSH - suCmd := exec.Command("su", "-", username, "-c", fmt.Sprintf("mkdir -p ~/%s && chmod 700 ~/%s", sshDir, sshDir)) - if err := suCmd.Run(); err != nil { - logging.LogError("Failed to create SSH directory for user %s: %v", username, err) - } - - // SSH keys - for _, key := range cfg.SshKeys { - suKeyCmd := exec.Command("su", "-", username, "-c", fmt.Sprintf("echo '%s' >> ~/%s/authorized_keys", key, sshDir)) - if err := suKeyCmd.Run(); err != nil { - logging.LogError("Failed to add SSH key for user %s: %v", username, err) - } - } - - // Set permissions for authorized_keys - suPermCmd := exec.Command("su", "-", username, "-c", fmt.Sprintf("chmod 600 ~/%s/authorized_keys", sshDir)) - if err := suPermCmd.Run(); err != nil { - logging.LogError("Failed to set permissions for authorized_keys for user %s: %v", username, err) - } - - // .hushlogin - suHushCmd := exec.Command("su", "-", username, "-c", "touch ~/.hushlogin") - if err := suHushCmd.Run(); err != nil { - logging.LogError("Failed to create .hushlogin file for user %s: %v", username, err) - } - } - - logging.LogSuccess("User %s created successfully", username) - return nil -} - -// DeleteUser deletes a user and their home directory -func DeleteUser(username string, osInfo *osdetect.OSInfo) error { - // Check if user exists - _, err := user.Lookup(username) - if err != nil { - return fmt.Errorf("user %s does not exist: %w", username, err) - } - - logging.LogInfo("Deleting user %s...", username) - - var cmd *exec.Cmd - if osInfo.OsType == "alpine" { - cmd = exec.Command("deluser", "--remove-home", username) - } else { - cmd = exec.Command("deluser", "--remove-home", username) - } - - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to delete user %s: %w", username, err) - } - - // Remove sudo configuration - sudoersFile := filepath.Join("/etc/sudoers.d", username) - if _, err := os.Stat(sudoersFile); err == nil { - if err := os.Remove(sudoersFile); err != nil { - logging.LogError("Failed to remove sudoers file for %s: %v", username, err) - } - } - - logging.LogSuccess("User %s deleted successfully", username) - return nil -} - -// ListUsers lists all non-system users -func ListUsers() ([]string, error) { - var users []string - - // Get all users from /etc/passwd - passwdFile, err := os.ReadFile("/etc/passwd") - if err != nil { - return nil, fmt.Errorf("failed to read /etc/passwd: %w", err) - } - - // Parse passwd file - lines := strings.Split(string(passwdFile), "\n") - for _, line := range lines { - if line == "" { - continue - } - - fields := strings.Split(line, ":") - if len(fields) < 7 { - continue - } - - username := fields[0] - uid := fields[2] - shell := fields[6] - - // Skip system users (UID < 1000) and users with nologin shell - uidInt := 0 - fmt.Sscanf(uid, "%d", &uidInt) - if uidInt >= 1000 && !strings.Contains(shell, "nologin") && !strings.Contains(shell, "false") { - users = append(users, username) - } - } - - return users, nil -} \ No newline at end of file diff --git a/pkg/utils/sudo.go b/pkg/utils/sudo.go deleted file mode 100644 index cc71725..0000000 --- a/pkg/utils/sudo.go +++ /dev/null @@ -1,94 +0,0 @@ -package utils - -import ( - "fmt" - "os" - "os/exec" - "os/user" - "path/filepath" - "strings" - - "github.com/abbott/hardn/pkg/logging" -) - -// SetupSudoEnvPreservation configures sudoers to preserve the HARDN_CONFIG environment variable -func SetupSudoEnvPreservation() error { - // Check if running as root - if os.Geteuid() != 0 { - return fmt.Errorf("this command must be run with sudo privileges") - } - - // Get current username (the real user, not root) - sudoUser := os.Getenv("SUDO_USER") - if sudoUser == "" { - // Fallback if SUDO_USER is not set - currentUser, err := user.Current() - if err != nil { - return fmt.Errorf("failed to determine current user: %w", err) - } - - // If we're still root and can't determine the real user, error out - if currentUser.Username == "root" { - return fmt.Errorf("cannot determine the real username; please run with sudo from a regular user account") - } - - sudoUser = currentUser.Username - } - - logging.LogInfo("Setting up sudo environment preservation for user: %s", sudoUser) - - // Ensure sudoers.d directory exists - sudoersDir := "/etc/sudoers.d" - if _, err := os.Stat(sudoersDir); os.IsNotExist(err) { - return fmt.Errorf("sudoers.d directory does not exist; your system may not support sudo drop-in configurations: %w", err) - } - - // Create/modify sudoers file for the user - sudoersFile := filepath.Join(sudoersDir, sudoUser) - - // Check if file already exists - var content string - if _, err := os.Stat(sudoersFile); err == nil { - // Read existing content - data, err := os.ReadFile(sudoersFile) - if err != nil { - return fmt.Errorf("failed to read existing sudoers file %s: %w", sudoersFile, err) - } - content = string(data) - - // Check if HARDN_CONFIG is already in the file - if strings.Contains(content, "env_keep += \"HARDN_CONFIG\"") { - logging.LogInfo("HARDN_CONFIG is already preserved in your sudoers configuration") - return nil - } - - // Append to existing content - content = strings.TrimSpace(content) + "\n" - } - - // env_keep directive - content += fmt.Sprintf("Defaults:%s env_keep += \"HARDN_CONFIG\"\n", sudoUser) - - // Create a temporary file for validation - tempFile := filepath.Join(os.TempDir(), "hardn_sudoers_temp") - if err := os.WriteFile(tempFile, []byte(content), 0440); err != nil { - return fmt.Errorf("failed to create temporary sudoers file at %s: %w", tempFile, err) - } - defer os.Remove(tempFile) - - // Validate the sudoers file - cmd := exec.Command("visudo", "-c", "-f", tempFile) - output, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("invalid sudoers configuration: %w\nOutput: %s", err, string(output)) - } - - // Write the validated content to the actual sudoers file - if err := os.WriteFile(sudoersFile, []byte(content), 0440); err != nil { - return fmt.Errorf("failed to write sudoers file %s: %w", sudoersFile, err) - } - - logging.LogSuccess("Successfully configured sudo to preserve HARDN_CONFIG environment variable for user: %s", sudoUser) - logging.LogInfo("You can now set HARDN_CONFIG environment variable and it will be preserved when using sudo") - return nil -} \ No newline at end of file diff --git a/pkg/utils/sudo_test.go b/pkg/utils/sudo_test.go deleted file mode 100644 index 22c3da8..0000000 --- a/pkg/utils/sudo_test.go +++ /dev/null @@ -1,32 +0,0 @@ -// pkg/utils/sudo_test.go - -package utils - -import ( - "os" - "testing" -) - -func TestCheckSudoEnvPreservation(t *testing.T) { - // This is a simple mock test since we can't actually create sudoers files in tests - // A real test would use a test fixture or mock the file system - - // Mock the username - origUser := os.Getenv("USER") - defer os.Setenv("USER", origUser) - - os.Setenv("USER", "testuser") - - // The real implementation would check for file existence and content - // but for testing we'll just return false since the file won't exist - // This is just a placeholder for the test structure - if checkSudoEnvPreservation() != false { - t.Error("Expected false when sudoers file doesn't exist") - } -} - -func checkSudoEnvPreservation() bool { - // This is a simplified version just for testing - // The real implementation is in the menu.go file - return false -} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 5a0e44e..380177e 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -2,14 +2,13 @@ package utils import ( "fmt" - "net" "os" "os/exec" "path/filepath" - "strings" "time" "github.com/abbott/hardn/pkg/config" + "github.com/abbott/hardn/pkg/interfaces" "github.com/abbott/hardn/pkg/logging" "github.com/abbott/hardn/pkg/style" ) @@ -94,64 +93,68 @@ func RunCommand(name string, args ...string) (string, error) { } // CheckSubnet checks if the specified subnet is present in the system's interfaces -func CheckSubnet(subnet string) (bool, error) { - interfaces, err := net.Interfaces() - if err != nil { - return false, fmt.Errorf("failed to get network interfaces: %w", err) - } - - for _, iface := range interfaces { - addrs, err := iface.Addrs() - if err != nil { - continue - } - - for _, addr := range addrs { - ipNet, ok := addr.(*net.IPNet) - if !ok { - continue - } - - ip := ipNet.IP.To4() - if ip == nil { - continue - } - - // Check if IP matches subnet - if strings.HasPrefix(ip.String(), subnet+".") { - logging.LogInfo("Target IP subnet %s.x found: %s", subnet, ip.String()) - return true, nil - } - } - } - - // Get all available IPs for logging - var availableIPs []string - for _, iface := range interfaces { - addrs, err := iface.Addrs() - if err != nil { - continue - } - - for _, addr := range addrs { - ipNet, ok := addr.(*net.IPNet) - if !ok { - continue - } - - ip := ipNet.IP.To4() - if ip == nil || strings.HasPrefix(ip.String(), "127.") { - continue - } - - availableIPs = append(availableIPs, ip.String()) - } - } - - logging.LogInfo("Target IP subnet %s.x not found. Available subnets: %s", subnet, strings.Join(availableIPs, ", ")) - return false, nil +func CheckSubnet(subnet string, networkOps interfaces.NetworkOperations) (bool, error) { + return networkOps.CheckSubnet(subnet) } +// func CheckSubnet(subnet string) (bool, error) { +// interfaces, err := net.Interfaces() +// if err != nil { +// return false, fmt.Errorf("failed to get network interfaces: %w", err) +// } + +// for _, iface := range interfaces { +// addrs, err := iface.Addrs() +// if err != nil { +// continue +// } + +// for _, addr := range addrs { +// ipNet, ok := addr.(*net.IPNet) +// if !ok { +// continue +// } + +// ip := ipNet.IP.To4() +// if ip == nil { +// continue +// } + +// // Check if IP matches subnet +// if strings.HasPrefix(ip.String(), subnet+".") { +// logging.LogInfo("Target IP subnet %s.x found: %s", subnet, ip.String()) +// return true, nil +// } +// } +// } + +// // Get all available IPs for logging +// var availableIPs []string +// for _, iface := range interfaces { +// addrs, err := iface.Addrs() +// if err != nil { +// continue +// } + +// for _, addr := range addrs { +// ipNet, ok := addr.(*net.IPNet) +// if !ok { +// continue +// } + +// ip := ipNet.IP.To4() +// if ip == nil || strings.HasPrefix(ip.String(), "127.") { +// continue +// } + +// availableIPs = append(availableIPs, ip.String()) +// } +// } + +// logging.LogInfo("Target IP subnet %s.x not found. Available subnets: %s", subnet, strings.Join(availableIPs, ", ")) +// return false, nil +// } + // SetupHushlogin creates a .hushlogin file in the home directory func SetupHushlogin(cfg *config.Config) error { if cfg.DryRun { @@ -188,4 +191,4 @@ func PrintLogs(logPath string) { fmt.Printf("\n# Contents of %s:\n\n", logPath) fmt.Println(string(data)) -} \ No newline at end of file +} diff --git a/pkg/version/checker.go b/pkg/version/checker.go new file mode 100644 index 0000000..1de2e46 --- /dev/null +++ b/pkg/version/checker.go @@ -0,0 +1,304 @@ +// Package version provides version checking and update notification functionality +package version + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "time" +) + +const ( + // GitHubAPIURL is the endpoint for checking the latest release + GitHubAPIURL = "https://api.github.com/repos/abbott/hardn/releases/latest" + + // CacheFileName is where we store the last check results + CacheFileName = ".hardn-version-cache.json" + + // CacheTTL defines how long the cache is valid (24 hours) + CacheTTL = 24 * time.Hour +) + +// GitHubRelease represents the JSON structure of a GitHub release +type GitHubRelease struct { + TagName string `json:"tag_name"` + Name string `json:"name"` + PublishedAt time.Time `json:"published_at"` + HTMLURL string `json:"html_url"` +} + +// VersionCache stores the cached check results +type VersionCache struct { + LastCheck time.Time `json:"last_check"` + LatestRelease GitHubRelease `json:"latest_release"` +} + +// CheckResult contains the result of a version check +type CheckResult struct { + CurrentVersion string + LatestVersion string + UpdateAvailable bool + ReleaseURL string + Error error +} + +// CheckForUpdates checks if a newer version is available on GitHub +func CheckForUpdates(currentVersion string, debug bool) CheckResult { + result := CheckResult{ + CurrentVersion: currentVersion, + } + + // Print debug info if enabled + if debug { + fmt.Println("DEBUG: Checking for updates...") + fmt.Println("DEBUG: Current version:", currentVersion) + } + + // Skip check if running without version info + if currentVersion == "" { + if debug { + fmt.Println("DEBUG: No version information provided. Skipping update check.") + } + return result + } + + // Use environment variable for cache control + if os.Getenv("HARDN_CLEAR_CACHE") != "" { + if debug { + fmt.Println("DEBUG: Clearing cache file") + } + os.Remove(getCacheFilePath()) + } + + // Try to load from cache first + cache, cacheValid := loadCache() + if cacheValid { + if debug { + fmt.Println("DEBUG: Using cached version information") + fmt.Println("DEBUG: Cached latest version:", cache.LatestRelease.TagName) + } + return compareVersions(currentVersion, cache.LatestRelease) + } + + if debug { + fmt.Println("DEBUG: No valid cache found. Fetching from GitHub API...") + } + + // Fetch from GitHub API with a short timeout + client := &http.Client{ + Timeout: 3 * time.Second, + } + + req, err := http.NewRequest("GET", GitHubAPIURL, nil) + if err != nil { + if debug { + fmt.Printf("DEBUG: Failed to create request: %v\n", err) + } + result.Error = fmt.Errorf("failed to create request: %w", err) + return result + } + + // Add User-Agent header to be a good API citizen + req.Header.Set("User-Agent", "hardn-version-checker") + + if debug { + fmt.Println("DEBUG: Sending request to GitHub API...") + } + + resp, err := client.Do(req) + if err != nil { + if debug { + fmt.Printf("DEBUG: Failed to check for updates: %v\n", err) + } + result.Error = fmt.Errorf("failed to check for updates: %w", err) + return result + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if debug { + fmt.Printf("DEBUG: GitHub API returned non-OK status: %s\n", resp.Status) + } + result.Error = fmt.Errorf("GitHub API returned non-OK status: %s", resp.Status) + return result + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + if debug { + fmt.Printf("DEBUG: Failed to read response: %v\n", err) + } + result.Error = fmt.Errorf("failed to read response: %w", err) + return result + } + + var release GitHubRelease + if err := json.Unmarshal(body, &release); err != nil { + if debug { + fmt.Printf("DEBUG: Failed to parse GitHub response: %v\n", err) + } + result.Error = fmt.Errorf("failed to parse GitHub response: %w", err) + return result + } + + if debug { + fmt.Println("DEBUG: Received latest version:", release.TagName) + fmt.Println("DEBUG: Saving to cache...") + } + + // Save to cache + saveCache(release) + + // Verify cache was written + if debug { + cacheFile := getCacheFilePath() + if _, err := os.Stat(cacheFile); err == nil { + fmt.Println("DEBUG: Successfully wrote cache file to:", cacheFile) + } else { + fmt.Printf("DEBUG: Failed to verify cache file: %v\n", err) + } + } + + return compareVersions(currentVersion, release) +} + +// compareVersions compares the current version with the latest release +func compareVersions(currentVersion string, release GitHubRelease) CheckResult { + result := CheckResult{ + CurrentVersion: currentVersion, + LatestVersion: strings.TrimPrefix(release.TagName, "v"), + ReleaseURL: release.HTMLURL, + } + + // Clean version strings (remove 'v' prefix if present) + current := strings.TrimPrefix(currentVersion, "v") + latest := strings.TrimPrefix(release.TagName, "v") + + // Handle pre-release suffixes + currentBase := current + latestBase := latest + + if strings.Contains(current, "-") { + parts := strings.SplitN(current, "-", 2) + currentBase = parts[0] + } + + if strings.Contains(latest, "-") { + parts := strings.SplitN(latest, "-", 2) + latestBase = parts[0] + } + + // Compare base versions + currentBaseParts := strings.Split(currentBase, ".") + latestBaseParts := strings.Split(latestBase, ".") + + // Ensure we have at least 3 components (major.minor.patch) + for len(currentBaseParts) < 3 { + currentBaseParts = append(currentBaseParts, "0") + } + + for len(latestBaseParts) < 3 { + latestBaseParts = append(latestBaseParts, "0") + } + + // Compare version components + for i := 0; i < 3; i++ { + currentNum, _ := strconv.Atoi(currentBaseParts[i]) + latestNum, _ := strconv.Atoi(latestBaseParts[i]) + + if latestNum > currentNum { + result.UpdateAvailable = true + return result + } else if currentNum > latestNum { + result.UpdateAvailable = false + return result + } + } + + // If base versions are equal, compare pre-release status + // A version without a pre-release suffix is considered newer than one with it + isCurrentPreRelease := strings.Contains(current, "-") + isLatestPreRelease := strings.Contains(latest, "-") + + if isCurrentPreRelease && !isLatestPreRelease { + result.UpdateAvailable = true + } else if !isCurrentPreRelease && isLatestPreRelease { + result.UpdateAvailable = false + } + + return result +} + +// loadCache tries to load the cached version check results +func loadCache() (VersionCache, bool) { + var cache VersionCache + + // Get cache file path + cacheFile := getCacheFilePath() + + // Try to read cache file + data, err := os.ReadFile(cacheFile) + if err != nil { + return cache, false + } + + // Parse JSON + if err := json.Unmarshal(data, &cache); err != nil { + return cache, false + } + + // Check if cache is still valid + if time.Since(cache.LastCheck) > CacheTTL { + return cache, false + } + + return cache, true +} + +// saveCache saves the version check results to cache +func saveCache(release GitHubRelease) { + cache := VersionCache{ + LastCheck: time.Now(), + LatestRelease: release, + } + + // Convert to JSON + data, err := json.MarshalIndent(cache, "", " ") + if err != nil { + fmt.Printf("Warning: Failed to marshal version cache: %v\n", err) + return + } + + // Get cache file path + cacheFile := getCacheFilePath() + + // Ensure directory exists + dir := filepath.Dir(cacheFile) + if err := os.MkdirAll(dir, 0755); err != nil { + fmt.Printf("Warning: Failed to create directory for version cache: %v\n", err) + return + } + + // Write to cache file + if err := os.WriteFile(cacheFile, data, 0644); err != nil { + // Log the error but don't fail the operation since this is just cache + fmt.Printf("Warning: Failed to write version cache: %v\n", err) + return + } +} + +// getCacheFilePath returns the path to the cache file +func getCacheFilePath() string { + // Check if HARDN_CACHE_PATH environment variable is set + if cachePath := os.Getenv("HARDN_CACHE_PATH"); cachePath != "" { + return cachePath + } + + // Use /tmp for easier access across users + return "/tmp/hardn-version-cache.json" +} diff --git a/pkg/version/service.go b/pkg/version/service.go new file mode 100644 index 0000000..5928bda --- /dev/null +++ b/pkg/version/service.go @@ -0,0 +1,97 @@ +package version + +import ( + "fmt" + "os" + "time" +) + +// UpdateOptions controls the behavior of the update checker +type UpdateOptions struct { + // Force an update to be available (for testing) + ForceUpdate bool + // Version to use if forcing an update + ForcedVersion string + // Show debug output + Debug bool + // Skip cache and force a fresh check + SkipCache bool + // Custom cache file location + CacheFilePath string + // Force immediate cache expiration + ClearCache bool +} + +// Service provides version checking functionality +type Service struct { + CurrentVersion string + BuildDate string + GitCommit string +} + +// NewService creates a version service instance +func NewService(currentVersion, buildDate, gitCommit string) *Service { + return &Service{ + CurrentVersion: currentVersion, + BuildDate: buildDate, + GitCommit: gitCommit, + } +} + +// CheckForUpdates checks if a newer version is available +func (s *Service) CheckForUpdates(options *UpdateOptions) CheckResult { + // Default options if nil + if options == nil { + options = &UpdateOptions{} + } + + // For testing purposes, we can force an update to be available + if options.ForceUpdate { + return CheckResult{ + CurrentVersion: s.CurrentVersion, + LatestVersion: options.ForcedVersion, + UpdateAvailable: true, + ReleaseURL: "https://github.com/abbott/hardn/releases/latest", + } + } + + // Set up environment variables for the underlying check function + if options.Debug { + os.Setenv("HARDN_DEBUG", "1") + defer os.Unsetenv("HARDN_DEBUG") + } + + if options.ClearCache { + os.Setenv("HARDN_CLEAR_CACHE", "1") + defer os.Unsetenv("HARDN_CLEAR_CACHE") + } + + if options.CacheFilePath != "" { + os.Setenv("HARDN_CACHE_PATH", options.CacheFilePath) + defer os.Unsetenv("HARDN_CACHE_PATH") + } + + // Perform the actual check + return CheckForUpdates(s.CurrentVersion, options.Debug) +} + +// PrintVersionInfo prints version information to stdout +func (s *Service) PrintVersionInfo() { + fmt.Println("hardn - Linux hardening tool") + fmt.Printf("Version: %s\n", s.CurrentVersion) + if s.BuildDate != "" { + fmt.Printf("Build Date: %s\n", s.BuildDate) + } + if s.GitCommit != "" { + fmt.Printf("Git Commit: %s\n", s.GitCommit) + } +} + +// GetCacheStatus returns information about the update cache +func (s *Service) GetCacheStatus() (bool, time.Time, error) { + cache, valid := loadCache() + if !valid { + return false, time.Time{}, fmt.Errorf("no valid cache found") + } + return true, cache.LastCheck, nil +} diff --git a/vendor/github.com/davecgh/go-spew/LICENSE b/vendor/github.com/davecgh/go-spew/LICENSE new file mode 100644 index 0000000..bc52e96 --- /dev/null +++ b/vendor/github.com/davecgh/go-spew/LICENSE @@ -0,0 +1,15 @@ +ISC License + +Copyright (c) 2012-2016 Dave Collins + +Permission to use, copy, modify, and/or distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. diff --git a/vendor/github.com/davecgh/go-spew/spew/bypass.go b/vendor/github.com/davecgh/go-spew/spew/bypass.go new file mode 100644 index 0000000..7929947 --- /dev/null +++ b/vendor/github.com/davecgh/go-spew/spew/bypass.go @@ -0,0 +1,145 @@ +// Copyright (c) 2015-2016 Dave Collins +// +// Permission to use, copy, modify, and distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +// NOTE: Due to the following build constraints, this file will only be compiled +// when the code is not running on Google App Engine, compiled by GopherJS, and +// "-tags safe" is not added to the go build command line. The "disableunsafe" +// tag is deprecated and thus should not be used. +// Go versions prior to 1.4 are disabled because they use a different layout +// for interfaces which make the implementation of unsafeReflectValue more complex. +// +build !js,!appengine,!safe,!disableunsafe,go1.4 + +package spew + +import ( + "reflect" + "unsafe" +) + +const ( + // UnsafeDisabled is a build-time constant which specifies whether or + // not access to the unsafe package is available. + UnsafeDisabled = false + + // ptrSize is the size of a pointer on the current arch. + ptrSize = unsafe.Sizeof((*byte)(nil)) +) + +type flag uintptr + +var ( + // flagRO indicates whether the value field of a reflect.Value + // is read-only. + flagRO flag + + // flagAddr indicates whether the address of the reflect.Value's + // value may be taken. + flagAddr flag +) + +// flagKindMask holds the bits that make up the kind +// part of the flags field. In all the supported versions, +// it is in the lower 5 bits. +const flagKindMask = flag(0x1f) + +// Different versions of Go have used different +// bit layouts for the flags type. This table +// records the known combinations. +var okFlags = []struct { + ro, addr flag +}{{ + // From Go 1.4 to 1.5 + ro: 1 << 5, + addr: 1 << 7, +}, { + // Up to Go tip. + ro: 1<<5 | 1<<6, + addr: 1 << 8, +}} + +var flagValOffset = func() uintptr { + field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag") + if !ok { + panic("reflect.Value has no flag field") + } + return field.Offset +}() + +// flagField returns a pointer to the flag field of a reflect.Value. +func flagField(v *reflect.Value) *flag { + return (*flag)(unsafe.Pointer(uintptr(unsafe.Pointer(v)) + flagValOffset)) +} + +// unsafeReflectValue converts the passed reflect.Value into a one that bypasses +// the typical safety restrictions preventing access to unaddressable and +// unexported data. It works by digging the raw pointer to the underlying +// value out of the protected value and generating a new unprotected (unsafe) +// reflect.Value to it. +// +// This allows us to check for implementations of the Stringer and error +// interfaces to be used for pretty printing ordinarily unaddressable and +// inaccessible values such as unexported struct fields. +func unsafeReflectValue(v reflect.Value) reflect.Value { + if !v.IsValid() || (v.CanInterface() && v.CanAddr()) { + return v + } + flagFieldPtr := flagField(&v) + *flagFieldPtr &^= flagRO + *flagFieldPtr |= flagAddr + return v +} + +// Sanity checks against future reflect package changes +// to the type or semantics of the Value.flag field. +func init() { + field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag") + if !ok { + panic("reflect.Value has no flag field") + } + if field.Type.Kind() != reflect.TypeOf(flag(0)).Kind() { + panic("reflect.Value flag field has changed kind") + } + type t0 int + var t struct { + A t0 + // t0 will have flagEmbedRO set. + t0 + // a will have flagStickyRO set + a t0 + } + vA := reflect.ValueOf(t).FieldByName("A") + va := reflect.ValueOf(t).FieldByName("a") + vt0 := reflect.ValueOf(t).FieldByName("t0") + + // Infer flagRO from the difference between the flags + // for the (otherwise identical) fields in t. + flagPublic := *flagField(&vA) + flagWithRO := *flagField(&va) | *flagField(&vt0) + flagRO = flagPublic ^ flagWithRO + + // Infer flagAddr from the difference between a value + // taken from a pointer and not. + vPtrA := reflect.ValueOf(&t).Elem().FieldByName("A") + flagNoPtr := *flagField(&vA) + flagPtr := *flagField(&vPtrA) + flagAddr = flagNoPtr ^ flagPtr + + // Check that the inferred flags tally with one of the known versions. + for _, f := range okFlags { + if flagRO == f.ro && flagAddr == f.addr { + return + } + } + panic("reflect.Value read-only flag has changed semantics") +} diff --git a/vendor/github.com/davecgh/go-spew/spew/bypasssafe.go b/vendor/github.com/davecgh/go-spew/spew/bypasssafe.go new file mode 100644 index 0000000..205c28d --- /dev/null +++ b/vendor/github.com/davecgh/go-spew/spew/bypasssafe.go @@ -0,0 +1,38 @@ +// Copyright (c) 2015-2016 Dave Collins +// +// Permission to use, copy, modify, and distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +// NOTE: Due to the following build constraints, this file will only be compiled +// when the code is running on Google App Engine, compiled by GopherJS, or +// "-tags safe" is added to the go build command line. The "disableunsafe" +// tag is deprecated and thus should not be used. +// +build js appengine safe disableunsafe !go1.4 + +package spew + +import "reflect" + +const ( + // UnsafeDisabled is a build-time constant which specifies whether or + // not access to the unsafe package is available. + UnsafeDisabled = true +) + +// unsafeReflectValue typically converts the passed reflect.Value into a one +// that bypasses the typical safety restrictions preventing access to +// unaddressable and unexported data. However, doing this relies on access to +// the unsafe package. This is a stub version which simply returns the passed +// reflect.Value when the unsafe package is not available. +func unsafeReflectValue(v reflect.Value) reflect.Value { + return v +} diff --git a/vendor/github.com/davecgh/go-spew/spew/common.go b/vendor/github.com/davecgh/go-spew/spew/common.go new file mode 100644 index 0000000..1be8ce9 --- /dev/null +++ b/vendor/github.com/davecgh/go-spew/spew/common.go @@ -0,0 +1,341 @@ +/* + * Copyright (c) 2013-2016 Dave Collins + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +package spew + +import ( + "bytes" + "fmt" + "io" + "reflect" + "sort" + "strconv" +) + +// Some constants in the form of bytes to avoid string overhead. This mirrors +// the technique used in the fmt package. +var ( + panicBytes = []byte("(PANIC=") + plusBytes = []byte("+") + iBytes = []byte("i") + trueBytes = []byte("true") + falseBytes = []byte("false") + interfaceBytes = []byte("(interface {})") + commaNewlineBytes = []byte(",\n") + newlineBytes = []byte("\n") + openBraceBytes = []byte("{") + openBraceNewlineBytes = []byte("{\n") + closeBraceBytes = []byte("}") + asteriskBytes = []byte("*") + colonBytes = []byte(":") + colonSpaceBytes = []byte(": ") + openParenBytes = []byte("(") + closeParenBytes = []byte(")") + spaceBytes = []byte(" ") + pointerChainBytes = []byte("->") + nilAngleBytes = []byte("") + maxNewlineBytes = []byte("\n") + maxShortBytes = []byte("") + circularBytes = []byte("") + circularShortBytes = []byte("") + invalidAngleBytes = []byte("") + openBracketBytes = []byte("[") + closeBracketBytes = []byte("]") + percentBytes = []byte("%") + precisionBytes = []byte(".") + openAngleBytes = []byte("<") + closeAngleBytes = []byte(">") + openMapBytes = []byte("map[") + closeMapBytes = []byte("]") + lenEqualsBytes = []byte("len=") + capEqualsBytes = []byte("cap=") +) + +// hexDigits is used to map a decimal value to a hex digit. +var hexDigits = "0123456789abcdef" + +// catchPanic handles any panics that might occur during the handleMethods +// calls. +func catchPanic(w io.Writer, v reflect.Value) { + if err := recover(); err != nil { + w.Write(panicBytes) + fmt.Fprintf(w, "%v", err) + w.Write(closeParenBytes) + } +} + +// handleMethods attempts to call the Error and String methods on the underlying +// type the passed reflect.Value represents and outputes the result to Writer w. +// +// It handles panics in any called methods by catching and displaying the error +// as the formatted value. +func handleMethods(cs *ConfigState, w io.Writer, v reflect.Value) (handled bool) { + // We need an interface to check if the type implements the error or + // Stringer interface. However, the reflect package won't give us an + // interface on certain things like unexported struct fields in order + // to enforce visibility rules. We use unsafe, when it's available, + // to bypass these restrictions since this package does not mutate the + // values. + if !v.CanInterface() { + if UnsafeDisabled { + return false + } + + v = unsafeReflectValue(v) + } + + // Choose whether or not to do error and Stringer interface lookups against + // the base type or a pointer to the base type depending on settings. + // Technically calling one of these methods with a pointer receiver can + // mutate the value, however, types which choose to satisify an error or + // Stringer interface with a pointer receiver should not be mutating their + // state inside these interface methods. + if !cs.DisablePointerMethods && !UnsafeDisabled && !v.CanAddr() { + v = unsafeReflectValue(v) + } + if v.CanAddr() { + v = v.Addr() + } + + // Is it an error or Stringer? + switch iface := v.Interface().(type) { + case error: + defer catchPanic(w, v) + if cs.ContinueOnMethod { + w.Write(openParenBytes) + w.Write([]byte(iface.Error())) + w.Write(closeParenBytes) + w.Write(spaceBytes) + return false + } + + w.Write([]byte(iface.Error())) + return true + + case fmt.Stringer: + defer catchPanic(w, v) + if cs.ContinueOnMethod { + w.Write(openParenBytes) + w.Write([]byte(iface.String())) + w.Write(closeParenBytes) + w.Write(spaceBytes) + return false + } + w.Write([]byte(iface.String())) + return true + } + return false +} + +// printBool outputs a boolean value as true or false to Writer w. +func printBool(w io.Writer, val bool) { + if val { + w.Write(trueBytes) + } else { + w.Write(falseBytes) + } +} + +// printInt outputs a signed integer value to Writer w. +func printInt(w io.Writer, val int64, base int) { + w.Write([]byte(strconv.FormatInt(val, base))) +} + +// printUint outputs an unsigned integer value to Writer w. +func printUint(w io.Writer, val uint64, base int) { + w.Write([]byte(strconv.FormatUint(val, base))) +} + +// printFloat outputs a floating point value using the specified precision, +// which is expected to be 32 or 64bit, to Writer w. +func printFloat(w io.Writer, val float64, precision int) { + w.Write([]byte(strconv.FormatFloat(val, 'g', -1, precision))) +} + +// printComplex outputs a complex value using the specified float precision +// for the real and imaginary parts to Writer w. +func printComplex(w io.Writer, c complex128, floatPrecision int) { + r := real(c) + w.Write(openParenBytes) + w.Write([]byte(strconv.FormatFloat(r, 'g', -1, floatPrecision))) + i := imag(c) + if i >= 0 { + w.Write(plusBytes) + } + w.Write([]byte(strconv.FormatFloat(i, 'g', -1, floatPrecision))) + w.Write(iBytes) + w.Write(closeParenBytes) +} + +// printHexPtr outputs a uintptr formatted as hexadecimal with a leading '0x' +// prefix to Writer w. +func printHexPtr(w io.Writer, p uintptr) { + // Null pointer. + num := uint64(p) + if num == 0 { + w.Write(nilAngleBytes) + return + } + + // Max uint64 is 16 bytes in hex + 2 bytes for '0x' prefix + buf := make([]byte, 18) + + // It's simpler to construct the hex string right to left. + base := uint64(16) + i := len(buf) - 1 + for num >= base { + buf[i] = hexDigits[num%base] + num /= base + i-- + } + buf[i] = hexDigits[num] + + // Add '0x' prefix. + i-- + buf[i] = 'x' + i-- + buf[i] = '0' + + // Strip unused leading bytes. + buf = buf[i:] + w.Write(buf) +} + +// valuesSorter implements sort.Interface to allow a slice of reflect.Value +// elements to be sorted. +type valuesSorter struct { + values []reflect.Value + strings []string // either nil or same len and values + cs *ConfigState +} + +// newValuesSorter initializes a valuesSorter instance, which holds a set of +// surrogate keys on which the data should be sorted. It uses flags in +// ConfigState to decide if and how to populate those surrogate keys. +func newValuesSorter(values []reflect.Value, cs *ConfigState) sort.Interface { + vs := &valuesSorter{values: values, cs: cs} + if canSortSimply(vs.values[0].Kind()) { + return vs + } + if !cs.DisableMethods { + vs.strings = make([]string, len(values)) + for i := range vs.values { + b := bytes.Buffer{} + if !handleMethods(cs, &b, vs.values[i]) { + vs.strings = nil + break + } + vs.strings[i] = b.String() + } + } + if vs.strings == nil && cs.SpewKeys { + vs.strings = make([]string, len(values)) + for i := range vs.values { + vs.strings[i] = Sprintf("%#v", vs.values[i].Interface()) + } + } + return vs +} + +// canSortSimply tests whether a reflect.Kind is a primitive that can be sorted +// directly, or whether it should be considered for sorting by surrogate keys +// (if the ConfigState allows it). +func canSortSimply(kind reflect.Kind) bool { + // This switch parallels valueSortLess, except for the default case. + switch kind { + case reflect.Bool: + return true + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: + return true + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + return true + case reflect.Float32, reflect.Float64: + return true + case reflect.String: + return true + case reflect.Uintptr: + return true + case reflect.Array: + return true + } + return false +} + +// Len returns the number of values in the slice. It is part of the +// sort.Interface implementation. +func (s *valuesSorter) Len() int { + return len(s.values) +} + +// Swap swaps the values at the passed indices. It is part of the +// sort.Interface implementation. +func (s *valuesSorter) Swap(i, j int) { + s.values[i], s.values[j] = s.values[j], s.values[i] + if s.strings != nil { + s.strings[i], s.strings[j] = s.strings[j], s.strings[i] + } +} + +// valueSortLess returns whether the first value should sort before the second +// value. It is used by valueSorter.Less as part of the sort.Interface +// implementation. +func valueSortLess(a, b reflect.Value) bool { + switch a.Kind() { + case reflect.Bool: + return !a.Bool() && b.Bool() + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: + return a.Int() < b.Int() + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + return a.Uint() < b.Uint() + case reflect.Float32, reflect.Float64: + return a.Float() < b.Float() + case reflect.String: + return a.String() < b.String() + case reflect.Uintptr: + return a.Uint() < b.Uint() + case reflect.Array: + // Compare the contents of both arrays. + l := a.Len() + for i := 0; i < l; i++ { + av := a.Index(i) + bv := b.Index(i) + if av.Interface() == bv.Interface() { + continue + } + return valueSortLess(av, bv) + } + } + return a.String() < b.String() +} + +// Less returns whether the value at index i should sort before the +// value at index j. It is part of the sort.Interface implementation. +func (s *valuesSorter) Less(i, j int) bool { + if s.strings == nil { + return valueSortLess(s.values[i], s.values[j]) + } + return s.strings[i] < s.strings[j] +} + +// sortValues is a sort function that handles both native types and any type that +// can be converted to error or Stringer. Other inputs are sorted according to +// their Value.String() value to ensure display stability. +func sortValues(values []reflect.Value, cs *ConfigState) { + if len(values) == 0 { + return + } + sort.Sort(newValuesSorter(values, cs)) +} diff --git a/vendor/github.com/davecgh/go-spew/spew/config.go b/vendor/github.com/davecgh/go-spew/spew/config.go new file mode 100644 index 0000000..2e3d22f --- /dev/null +++ b/vendor/github.com/davecgh/go-spew/spew/config.go @@ -0,0 +1,306 @@ +/* + * Copyright (c) 2013-2016 Dave Collins + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +package spew + +import ( + "bytes" + "fmt" + "io" + "os" +) + +// ConfigState houses the configuration options used by spew to format and +// display values. There is a global instance, Config, that is used to control +// all top-level Formatter and Dump functionality. Each ConfigState instance +// provides methods equivalent to the top-level functions. +// +// The zero value for ConfigState provides no indentation. You would typically +// want to set it to a space or a tab. +// +// Alternatively, you can use NewDefaultConfig to get a ConfigState instance +// with default settings. See the documentation of NewDefaultConfig for default +// values. +type ConfigState struct { + // Indent specifies the string to use for each indentation level. The + // global config instance that all top-level functions use set this to a + // single space by default. If you would like more indentation, you might + // set this to a tab with "\t" or perhaps two spaces with " ". + Indent string + + // MaxDepth controls the maximum number of levels to descend into nested + // data structures. The default, 0, means there is no limit. + // + // NOTE: Circular data structures are properly detected, so it is not + // necessary to set this value unless you specifically want to limit deeply + // nested data structures. + MaxDepth int + + // DisableMethods specifies whether or not error and Stringer interfaces are + // invoked for types that implement them. + DisableMethods bool + + // DisablePointerMethods specifies whether or not to check for and invoke + // error and Stringer interfaces on types which only accept a pointer + // receiver when the current type is not a pointer. + // + // NOTE: This might be an unsafe action since calling one of these methods + // with a pointer receiver could technically mutate the value, however, + // in practice, types which choose to satisify an error or Stringer + // interface with a pointer receiver should not be mutating their state + // inside these interface methods. As a result, this option relies on + // access to the unsafe package, so it will not have any effect when + // running in environments without access to the unsafe package such as + // Google App Engine or with the "safe" build tag specified. + DisablePointerMethods bool + + // DisablePointerAddresses specifies whether to disable the printing of + // pointer addresses. This is useful when diffing data structures in tests. + DisablePointerAddresses bool + + // DisableCapacities specifies whether to disable the printing of capacities + // for arrays, slices, maps and channels. This is useful when diffing + // data structures in tests. + DisableCapacities bool + + // ContinueOnMethod specifies whether or not recursion should continue once + // a custom error or Stringer interface is invoked. The default, false, + // means it will print the results of invoking the custom error or Stringer + // interface and return immediately instead of continuing to recurse into + // the internals of the data type. + // + // NOTE: This flag does not have any effect if method invocation is disabled + // via the DisableMethods or DisablePointerMethods options. + ContinueOnMethod bool + + // SortKeys specifies map keys should be sorted before being printed. Use + // this to have a more deterministic, diffable output. Note that only + // native types (bool, int, uint, floats, uintptr and string) and types + // that support the error or Stringer interfaces (if methods are + // enabled) are supported, with other types sorted according to the + // reflect.Value.String() output which guarantees display stability. + SortKeys bool + + // SpewKeys specifies that, as a last resort attempt, map keys should + // be spewed to strings and sorted by those strings. This is only + // considered if SortKeys is true. + SpewKeys bool +} + +// Config is the active configuration of the top-level functions. +// The configuration can be changed by modifying the contents of spew.Config. +var Config = ConfigState{Indent: " "} + +// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were +// passed with a Formatter interface returned by c.NewFormatter. It returns +// the formatted string as a value that satisfies error. See NewFormatter +// for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Errorf(format, c.NewFormatter(a), c.NewFormatter(b)) +func (c *ConfigState) Errorf(format string, a ...interface{}) (err error) { + return fmt.Errorf(format, c.convertArgs(a)...) +} + +// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were +// passed with a Formatter interface returned by c.NewFormatter. It returns +// the number of bytes written and any write error encountered. See +// NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Fprint(w, c.NewFormatter(a), c.NewFormatter(b)) +func (c *ConfigState) Fprint(w io.Writer, a ...interface{}) (n int, err error) { + return fmt.Fprint(w, c.convertArgs(a)...) +} + +// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were +// passed with a Formatter interface returned by c.NewFormatter. It returns +// the number of bytes written and any write error encountered. See +// NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Fprintf(w, format, c.NewFormatter(a), c.NewFormatter(b)) +func (c *ConfigState) Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) { + return fmt.Fprintf(w, format, c.convertArgs(a)...) +} + +// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it +// passed with a Formatter interface returned by c.NewFormatter. See +// NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Fprintln(w, c.NewFormatter(a), c.NewFormatter(b)) +func (c *ConfigState) Fprintln(w io.Writer, a ...interface{}) (n int, err error) { + return fmt.Fprintln(w, c.convertArgs(a)...) +} + +// Print is a wrapper for fmt.Print that treats each argument as if it were +// passed with a Formatter interface returned by c.NewFormatter. It returns +// the number of bytes written and any write error encountered. See +// NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Print(c.NewFormatter(a), c.NewFormatter(b)) +func (c *ConfigState) Print(a ...interface{}) (n int, err error) { + return fmt.Print(c.convertArgs(a)...) +} + +// Printf is a wrapper for fmt.Printf that treats each argument as if it were +// passed with a Formatter interface returned by c.NewFormatter. It returns +// the number of bytes written and any write error encountered. See +// NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Printf(format, c.NewFormatter(a), c.NewFormatter(b)) +func (c *ConfigState) Printf(format string, a ...interface{}) (n int, err error) { + return fmt.Printf(format, c.convertArgs(a)...) +} + +// Println is a wrapper for fmt.Println that treats each argument as if it were +// passed with a Formatter interface returned by c.NewFormatter. It returns +// the number of bytes written and any write error encountered. See +// NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Println(c.NewFormatter(a), c.NewFormatter(b)) +func (c *ConfigState) Println(a ...interface{}) (n int, err error) { + return fmt.Println(c.convertArgs(a)...) +} + +// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were +// passed with a Formatter interface returned by c.NewFormatter. It returns +// the resulting string. See NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Sprint(c.NewFormatter(a), c.NewFormatter(b)) +func (c *ConfigState) Sprint(a ...interface{}) string { + return fmt.Sprint(c.convertArgs(a)...) +} + +// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were +// passed with a Formatter interface returned by c.NewFormatter. It returns +// the resulting string. See NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Sprintf(format, c.NewFormatter(a), c.NewFormatter(b)) +func (c *ConfigState) Sprintf(format string, a ...interface{}) string { + return fmt.Sprintf(format, c.convertArgs(a)...) +} + +// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it +// were passed with a Formatter interface returned by c.NewFormatter. It +// returns the resulting string. See NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Sprintln(c.NewFormatter(a), c.NewFormatter(b)) +func (c *ConfigState) Sprintln(a ...interface{}) string { + return fmt.Sprintln(c.convertArgs(a)...) +} + +/* +NewFormatter returns a custom formatter that satisfies the fmt.Formatter +interface. As a result, it integrates cleanly with standard fmt package +printing functions. The formatter is useful for inline printing of smaller data +types similar to the standard %v format specifier. + +The custom formatter only responds to the %v (most compact), %+v (adds pointer +addresses), %#v (adds types), and %#+v (adds types and pointer addresses) verb +combinations. Any other verbs such as %x and %q will be sent to the the +standard fmt package for formatting. In addition, the custom formatter ignores +the width and precision arguments (however they will still work on the format +specifiers not handled by the custom formatter). + +Typically this function shouldn't be called directly. It is much easier to make +use of the custom formatter by calling one of the convenience functions such as +c.Printf, c.Println, or c.Printf. +*/ +func (c *ConfigState) NewFormatter(v interface{}) fmt.Formatter { + return newFormatter(c, v) +} + +// Fdump formats and displays the passed arguments to io.Writer w. It formats +// exactly the same as Dump. +func (c *ConfigState) Fdump(w io.Writer, a ...interface{}) { + fdump(c, w, a...) +} + +/* +Dump displays the passed parameters to standard out with newlines, customizable +indentation, and additional debug information such as complete types and all +pointer addresses used to indirect to the final value. It provides the +following features over the built-in printing facilities provided by the fmt +package: + + * Pointers are dereferenced and followed + * Circular data structures are detected and handled properly + * Custom Stringer/error interfaces are optionally invoked, including + on unexported types + * Custom types which only implement the Stringer/error interfaces via + a pointer receiver are optionally invoked when passing non-pointer + variables + * Byte arrays and slices are dumped like the hexdump -C command which + includes offsets, byte values in hex, and ASCII output + +The configuration options are controlled by modifying the public members +of c. See ConfigState for options documentation. + +See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to +get the formatted result as a string. +*/ +func (c *ConfigState) Dump(a ...interface{}) { + fdump(c, os.Stdout, a...) +} + +// Sdump returns a string with the passed arguments formatted exactly the same +// as Dump. +func (c *ConfigState) Sdump(a ...interface{}) string { + var buf bytes.Buffer + fdump(c, &buf, a...) + return buf.String() +} + +// convertArgs accepts a slice of arguments and returns a slice of the same +// length with each argument converted to a spew Formatter interface using +// the ConfigState associated with s. +func (c *ConfigState) convertArgs(args []interface{}) (formatters []interface{}) { + formatters = make([]interface{}, len(args)) + for index, arg := range args { + formatters[index] = newFormatter(c, arg) + } + return formatters +} + +// NewDefaultConfig returns a ConfigState with the following default settings. +// +// Indent: " " +// MaxDepth: 0 +// DisableMethods: false +// DisablePointerMethods: false +// ContinueOnMethod: false +// SortKeys: false +func NewDefaultConfig() *ConfigState { + return &ConfigState{Indent: " "} +} diff --git a/vendor/github.com/davecgh/go-spew/spew/doc.go b/vendor/github.com/davecgh/go-spew/spew/doc.go new file mode 100644 index 0000000..aacaac6 --- /dev/null +++ b/vendor/github.com/davecgh/go-spew/spew/doc.go @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2013-2016 Dave Collins + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +/* +Package spew implements a deep pretty printer for Go data structures to aid in +debugging. + +A quick overview of the additional features spew provides over the built-in +printing facilities for Go data types are as follows: + + * Pointers are dereferenced and followed + * Circular data structures are detected and handled properly + * Custom Stringer/error interfaces are optionally invoked, including + on unexported types + * Custom types which only implement the Stringer/error interfaces via + a pointer receiver are optionally invoked when passing non-pointer + variables + * Byte arrays and slices are dumped like the hexdump -C command which + includes offsets, byte values in hex, and ASCII output (only when using + Dump style) + +There are two different approaches spew allows for dumping Go data structures: + + * Dump style which prints with newlines, customizable indentation, + and additional debug information such as types and all pointer addresses + used to indirect to the final value + * A custom Formatter interface that integrates cleanly with the standard fmt + package and replaces %v, %+v, %#v, and %#+v to provide inline printing + similar to the default %v while providing the additional functionality + outlined above and passing unsupported format verbs such as %x and %q + along to fmt + +Quick Start + +This section demonstrates how to quickly get started with spew. See the +sections below for further details on formatting and configuration options. + +To dump a variable with full newlines, indentation, type, and pointer +information use Dump, Fdump, or Sdump: + spew.Dump(myVar1, myVar2, ...) + spew.Fdump(someWriter, myVar1, myVar2, ...) + str := spew.Sdump(myVar1, myVar2, ...) + +Alternatively, if you would prefer to use format strings with a compacted inline +printing style, use the convenience wrappers Printf, Fprintf, etc with +%v (most compact), %+v (adds pointer addresses), %#v (adds types), or +%#+v (adds types and pointer addresses): + spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2) + spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4) + spew.Fprintf(someWriter, "myVar1: %v -- myVar2: %+v", myVar1, myVar2) + spew.Fprintf(someWriter, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4) + +Configuration Options + +Configuration of spew is handled by fields in the ConfigState type. For +convenience, all of the top-level functions use a global state available +via the spew.Config global. + +It is also possible to create a ConfigState instance that provides methods +equivalent to the top-level functions. This allows concurrent configuration +options. See the ConfigState documentation for more details. + +The following configuration options are available: + * Indent + String to use for each indentation level for Dump functions. + It is a single space by default. A popular alternative is "\t". + + * MaxDepth + Maximum number of levels to descend into nested data structures. + There is no limit by default. + + * DisableMethods + Disables invocation of error and Stringer interface methods. + Method invocation is enabled by default. + + * DisablePointerMethods + Disables invocation of error and Stringer interface methods on types + which only accept pointer receivers from non-pointer variables. + Pointer method invocation is enabled by default. + + * DisablePointerAddresses + DisablePointerAddresses specifies whether to disable the printing of + pointer addresses. This is useful when diffing data structures in tests. + + * DisableCapacities + DisableCapacities specifies whether to disable the printing of + capacities for arrays, slices, maps and channels. This is useful when + diffing data structures in tests. + + * ContinueOnMethod + Enables recursion into types after invoking error and Stringer interface + methods. Recursion after method invocation is disabled by default. + + * SortKeys + Specifies map keys should be sorted before being printed. Use + this to have a more deterministic, diffable output. Note that + only native types (bool, int, uint, floats, uintptr and string) + and types which implement error or Stringer interfaces are + supported with other types sorted according to the + reflect.Value.String() output which guarantees display + stability. Natural map order is used by default. + + * SpewKeys + Specifies that, as a last resort attempt, map keys should be + spewed to strings and sorted by those strings. This is only + considered if SortKeys is true. + +Dump Usage + +Simply call spew.Dump with a list of variables you want to dump: + + spew.Dump(myVar1, myVar2, ...) + +You may also call spew.Fdump if you would prefer to output to an arbitrary +io.Writer. For example, to dump to standard error: + + spew.Fdump(os.Stderr, myVar1, myVar2, ...) + +A third option is to call spew.Sdump to get the formatted output as a string: + + str := spew.Sdump(myVar1, myVar2, ...) + +Sample Dump Output + +See the Dump example for details on the setup of the types and variables being +shown here. + + (main.Foo) { + unexportedField: (*main.Bar)(0xf84002e210)({ + flag: (main.Flag) flagTwo, + data: (uintptr) + }), + ExportedField: (map[interface {}]interface {}) (len=1) { + (string) (len=3) "one": (bool) true + } + } + +Byte (and uint8) arrays and slices are displayed uniquely like the hexdump -C +command as shown. + ([]uint8) (len=32 cap=32) { + 00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 |............... | + 00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 |!"#$%&'()*+,-./0| + 00000020 31 32 |12| + } + +Custom Formatter + +Spew provides a custom formatter that implements the fmt.Formatter interface +so that it integrates cleanly with standard fmt package printing functions. The +formatter is useful for inline printing of smaller data types similar to the +standard %v format specifier. + +The custom formatter only responds to the %v (most compact), %+v (adds pointer +addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb +combinations. Any other verbs such as %x and %q will be sent to the the +standard fmt package for formatting. In addition, the custom formatter ignores +the width and precision arguments (however they will still work on the format +specifiers not handled by the custom formatter). + +Custom Formatter Usage + +The simplest way to make use of the spew custom formatter is to call one of the +convenience functions such as spew.Printf, spew.Println, or spew.Printf. The +functions have syntax you are most likely already familiar with: + + spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2) + spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4) + spew.Println(myVar, myVar2) + spew.Fprintf(os.Stderr, "myVar1: %v -- myVar2: %+v", myVar1, myVar2) + spew.Fprintf(os.Stderr, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4) + +See the Index for the full list convenience functions. + +Sample Formatter Output + +Double pointer to a uint8: + %v: <**>5 + %+v: <**>(0xf8400420d0->0xf8400420c8)5 + %#v: (**uint8)5 + %#+v: (**uint8)(0xf8400420d0->0xf8400420c8)5 + +Pointer to circular struct with a uint8 field and a pointer to itself: + %v: <*>{1 <*>} + %+v: <*>(0xf84003e260){ui8:1 c:<*>(0xf84003e260)} + %#v: (*main.circular){ui8:(uint8)1 c:(*main.circular)} + %#+v: (*main.circular)(0xf84003e260){ui8:(uint8)1 c:(*main.circular)(0xf84003e260)} + +See the Printf example for details on the setup of variables being shown +here. + +Errors + +Since it is possible for custom Stringer/error interfaces to panic, spew +detects them and handles them internally by printing the panic information +inline with the output. Since spew is intended to provide deep pretty printing +capabilities on structures, it intentionally does not return any errors. +*/ +package spew diff --git a/vendor/github.com/davecgh/go-spew/spew/dump.go b/vendor/github.com/davecgh/go-spew/spew/dump.go new file mode 100644 index 0000000..f78d89f --- /dev/null +++ b/vendor/github.com/davecgh/go-spew/spew/dump.go @@ -0,0 +1,509 @@ +/* + * Copyright (c) 2013-2016 Dave Collins + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +package spew + +import ( + "bytes" + "encoding/hex" + "fmt" + "io" + "os" + "reflect" + "regexp" + "strconv" + "strings" +) + +var ( + // uint8Type is a reflect.Type representing a uint8. It is used to + // convert cgo types to uint8 slices for hexdumping. + uint8Type = reflect.TypeOf(uint8(0)) + + // cCharRE is a regular expression that matches a cgo char. + // It is used to detect character arrays to hexdump them. + cCharRE = regexp.MustCompile(`^.*\._Ctype_char$`) + + // cUnsignedCharRE is a regular expression that matches a cgo unsigned + // char. It is used to detect unsigned character arrays to hexdump + // them. + cUnsignedCharRE = regexp.MustCompile(`^.*\._Ctype_unsignedchar$`) + + // cUint8tCharRE is a regular expression that matches a cgo uint8_t. + // It is used to detect uint8_t arrays to hexdump them. + cUint8tCharRE = regexp.MustCompile(`^.*\._Ctype_uint8_t$`) +) + +// dumpState contains information about the state of a dump operation. +type dumpState struct { + w io.Writer + depth int + pointers map[uintptr]int + ignoreNextType bool + ignoreNextIndent bool + cs *ConfigState +} + +// indent performs indentation according to the depth level and cs.Indent +// option. +func (d *dumpState) indent() { + if d.ignoreNextIndent { + d.ignoreNextIndent = false + return + } + d.w.Write(bytes.Repeat([]byte(d.cs.Indent), d.depth)) +} + +// unpackValue returns values inside of non-nil interfaces when possible. +// This is useful for data types like structs, arrays, slices, and maps which +// can contain varying types packed inside an interface. +func (d *dumpState) unpackValue(v reflect.Value) reflect.Value { + if v.Kind() == reflect.Interface && !v.IsNil() { + v = v.Elem() + } + return v +} + +// dumpPtr handles formatting of pointers by indirecting them as necessary. +func (d *dumpState) dumpPtr(v reflect.Value) { + // Remove pointers at or below the current depth from map used to detect + // circular refs. + for k, depth := range d.pointers { + if depth >= d.depth { + delete(d.pointers, k) + } + } + + // Keep list of all dereferenced pointers to show later. + pointerChain := make([]uintptr, 0) + + // Figure out how many levels of indirection there are by dereferencing + // pointers and unpacking interfaces down the chain while detecting circular + // references. + nilFound := false + cycleFound := false + indirects := 0 + ve := v + for ve.Kind() == reflect.Ptr { + if ve.IsNil() { + nilFound = true + break + } + indirects++ + addr := ve.Pointer() + pointerChain = append(pointerChain, addr) + if pd, ok := d.pointers[addr]; ok && pd < d.depth { + cycleFound = true + indirects-- + break + } + d.pointers[addr] = d.depth + + ve = ve.Elem() + if ve.Kind() == reflect.Interface { + if ve.IsNil() { + nilFound = true + break + } + ve = ve.Elem() + } + } + + // Display type information. + d.w.Write(openParenBytes) + d.w.Write(bytes.Repeat(asteriskBytes, indirects)) + d.w.Write([]byte(ve.Type().String())) + d.w.Write(closeParenBytes) + + // Display pointer information. + if !d.cs.DisablePointerAddresses && len(pointerChain) > 0 { + d.w.Write(openParenBytes) + for i, addr := range pointerChain { + if i > 0 { + d.w.Write(pointerChainBytes) + } + printHexPtr(d.w, addr) + } + d.w.Write(closeParenBytes) + } + + // Display dereferenced value. + d.w.Write(openParenBytes) + switch { + case nilFound: + d.w.Write(nilAngleBytes) + + case cycleFound: + d.w.Write(circularBytes) + + default: + d.ignoreNextType = true + d.dump(ve) + } + d.w.Write(closeParenBytes) +} + +// dumpSlice handles formatting of arrays and slices. Byte (uint8 under +// reflection) arrays and slices are dumped in hexdump -C fashion. +func (d *dumpState) dumpSlice(v reflect.Value) { + // Determine whether this type should be hex dumped or not. Also, + // for types which should be hexdumped, try to use the underlying data + // first, then fall back to trying to convert them to a uint8 slice. + var buf []uint8 + doConvert := false + doHexDump := false + numEntries := v.Len() + if numEntries > 0 { + vt := v.Index(0).Type() + vts := vt.String() + switch { + // C types that need to be converted. + case cCharRE.MatchString(vts): + fallthrough + case cUnsignedCharRE.MatchString(vts): + fallthrough + case cUint8tCharRE.MatchString(vts): + doConvert = true + + // Try to use existing uint8 slices and fall back to converting + // and copying if that fails. + case vt.Kind() == reflect.Uint8: + // We need an addressable interface to convert the type + // to a byte slice. However, the reflect package won't + // give us an interface on certain things like + // unexported struct fields in order to enforce + // visibility rules. We use unsafe, when available, to + // bypass these restrictions since this package does not + // mutate the values. + vs := v + if !vs.CanInterface() || !vs.CanAddr() { + vs = unsafeReflectValue(vs) + } + if !UnsafeDisabled { + vs = vs.Slice(0, numEntries) + + // Use the existing uint8 slice if it can be + // type asserted. + iface := vs.Interface() + if slice, ok := iface.([]uint8); ok { + buf = slice + doHexDump = true + break + } + } + + // The underlying data needs to be converted if it can't + // be type asserted to a uint8 slice. + doConvert = true + } + + // Copy and convert the underlying type if needed. + if doConvert && vt.ConvertibleTo(uint8Type) { + // Convert and copy each element into a uint8 byte + // slice. + buf = make([]uint8, numEntries) + for i := 0; i < numEntries; i++ { + vv := v.Index(i) + buf[i] = uint8(vv.Convert(uint8Type).Uint()) + } + doHexDump = true + } + } + + // Hexdump the entire slice as needed. + if doHexDump { + indent := strings.Repeat(d.cs.Indent, d.depth) + str := indent + hex.Dump(buf) + str = strings.Replace(str, "\n", "\n"+indent, -1) + str = strings.TrimRight(str, d.cs.Indent) + d.w.Write([]byte(str)) + return + } + + // Recursively call dump for each item. + for i := 0; i < numEntries; i++ { + d.dump(d.unpackValue(v.Index(i))) + if i < (numEntries - 1) { + d.w.Write(commaNewlineBytes) + } else { + d.w.Write(newlineBytes) + } + } +} + +// dump is the main workhorse for dumping a value. It uses the passed reflect +// value to figure out what kind of object we are dealing with and formats it +// appropriately. It is a recursive function, however circular data structures +// are detected and handled properly. +func (d *dumpState) dump(v reflect.Value) { + // Handle invalid reflect values immediately. + kind := v.Kind() + if kind == reflect.Invalid { + d.w.Write(invalidAngleBytes) + return + } + + // Handle pointers specially. + if kind == reflect.Ptr { + d.indent() + d.dumpPtr(v) + return + } + + // Print type information unless already handled elsewhere. + if !d.ignoreNextType { + d.indent() + d.w.Write(openParenBytes) + d.w.Write([]byte(v.Type().String())) + d.w.Write(closeParenBytes) + d.w.Write(spaceBytes) + } + d.ignoreNextType = false + + // Display length and capacity if the built-in len and cap functions + // work with the value's kind and the len/cap itself is non-zero. + valueLen, valueCap := 0, 0 + switch v.Kind() { + case reflect.Array, reflect.Slice, reflect.Chan: + valueLen, valueCap = v.Len(), v.Cap() + case reflect.Map, reflect.String: + valueLen = v.Len() + } + if valueLen != 0 || !d.cs.DisableCapacities && valueCap != 0 { + d.w.Write(openParenBytes) + if valueLen != 0 { + d.w.Write(lenEqualsBytes) + printInt(d.w, int64(valueLen), 10) + } + if !d.cs.DisableCapacities && valueCap != 0 { + if valueLen != 0 { + d.w.Write(spaceBytes) + } + d.w.Write(capEqualsBytes) + printInt(d.w, int64(valueCap), 10) + } + d.w.Write(closeParenBytes) + d.w.Write(spaceBytes) + } + + // Call Stringer/error interfaces if they exist and the handle methods flag + // is enabled + if !d.cs.DisableMethods { + if (kind != reflect.Invalid) && (kind != reflect.Interface) { + if handled := handleMethods(d.cs, d.w, v); handled { + return + } + } + } + + switch kind { + case reflect.Invalid: + // Do nothing. We should never get here since invalid has already + // been handled above. + + case reflect.Bool: + printBool(d.w, v.Bool()) + + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: + printInt(d.w, v.Int(), 10) + + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + printUint(d.w, v.Uint(), 10) + + case reflect.Float32: + printFloat(d.w, v.Float(), 32) + + case reflect.Float64: + printFloat(d.w, v.Float(), 64) + + case reflect.Complex64: + printComplex(d.w, v.Complex(), 32) + + case reflect.Complex128: + printComplex(d.w, v.Complex(), 64) + + case reflect.Slice: + if v.IsNil() { + d.w.Write(nilAngleBytes) + break + } + fallthrough + + case reflect.Array: + d.w.Write(openBraceNewlineBytes) + d.depth++ + if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) { + d.indent() + d.w.Write(maxNewlineBytes) + } else { + d.dumpSlice(v) + } + d.depth-- + d.indent() + d.w.Write(closeBraceBytes) + + case reflect.String: + d.w.Write([]byte(strconv.Quote(v.String()))) + + case reflect.Interface: + // The only time we should get here is for nil interfaces due to + // unpackValue calls. + if v.IsNil() { + d.w.Write(nilAngleBytes) + } + + case reflect.Ptr: + // Do nothing. We should never get here since pointers have already + // been handled above. + + case reflect.Map: + // nil maps should be indicated as different than empty maps + if v.IsNil() { + d.w.Write(nilAngleBytes) + break + } + + d.w.Write(openBraceNewlineBytes) + d.depth++ + if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) { + d.indent() + d.w.Write(maxNewlineBytes) + } else { + numEntries := v.Len() + keys := v.MapKeys() + if d.cs.SortKeys { + sortValues(keys, d.cs) + } + for i, key := range keys { + d.dump(d.unpackValue(key)) + d.w.Write(colonSpaceBytes) + d.ignoreNextIndent = true + d.dump(d.unpackValue(v.MapIndex(key))) + if i < (numEntries - 1) { + d.w.Write(commaNewlineBytes) + } else { + d.w.Write(newlineBytes) + } + } + } + d.depth-- + d.indent() + d.w.Write(closeBraceBytes) + + case reflect.Struct: + d.w.Write(openBraceNewlineBytes) + d.depth++ + if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) { + d.indent() + d.w.Write(maxNewlineBytes) + } else { + vt := v.Type() + numFields := v.NumField() + for i := 0; i < numFields; i++ { + d.indent() + vtf := vt.Field(i) + d.w.Write([]byte(vtf.Name)) + d.w.Write(colonSpaceBytes) + d.ignoreNextIndent = true + d.dump(d.unpackValue(v.Field(i))) + if i < (numFields - 1) { + d.w.Write(commaNewlineBytes) + } else { + d.w.Write(newlineBytes) + } + } + } + d.depth-- + d.indent() + d.w.Write(closeBraceBytes) + + case reflect.Uintptr: + printHexPtr(d.w, uintptr(v.Uint())) + + case reflect.UnsafePointer, reflect.Chan, reflect.Func: + printHexPtr(d.w, v.Pointer()) + + // There were not any other types at the time this code was written, but + // fall back to letting the default fmt package handle it in case any new + // types are added. + default: + if v.CanInterface() { + fmt.Fprintf(d.w, "%v", v.Interface()) + } else { + fmt.Fprintf(d.w, "%v", v.String()) + } + } +} + +// fdump is a helper function to consolidate the logic from the various public +// methods which take varying writers and config states. +func fdump(cs *ConfigState, w io.Writer, a ...interface{}) { + for _, arg := range a { + if arg == nil { + w.Write(interfaceBytes) + w.Write(spaceBytes) + w.Write(nilAngleBytes) + w.Write(newlineBytes) + continue + } + + d := dumpState{w: w, cs: cs} + d.pointers = make(map[uintptr]int) + d.dump(reflect.ValueOf(arg)) + d.w.Write(newlineBytes) + } +} + +// Fdump formats and displays the passed arguments to io.Writer w. It formats +// exactly the same as Dump. +func Fdump(w io.Writer, a ...interface{}) { + fdump(&Config, w, a...) +} + +// Sdump returns a string with the passed arguments formatted exactly the same +// as Dump. +func Sdump(a ...interface{}) string { + var buf bytes.Buffer + fdump(&Config, &buf, a...) + return buf.String() +} + +/* +Dump displays the passed parameters to standard out with newlines, customizable +indentation, and additional debug information such as complete types and all +pointer addresses used to indirect to the final value. It provides the +following features over the built-in printing facilities provided by the fmt +package: + + * Pointers are dereferenced and followed + * Circular data structures are detected and handled properly + * Custom Stringer/error interfaces are optionally invoked, including + on unexported types + * Custom types which only implement the Stringer/error interfaces via + a pointer receiver are optionally invoked when passing non-pointer + variables + * Byte arrays and slices are dumped like the hexdump -C command which + includes offsets, byte values in hex, and ASCII output + +The configuration options are controlled by an exported package global, +spew.Config. See ConfigState for options documentation. + +See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to +get the formatted result as a string. +*/ +func Dump(a ...interface{}) { + fdump(&Config, os.Stdout, a...) +} diff --git a/vendor/github.com/davecgh/go-spew/spew/format.go b/vendor/github.com/davecgh/go-spew/spew/format.go new file mode 100644 index 0000000..b04edb7 --- /dev/null +++ b/vendor/github.com/davecgh/go-spew/spew/format.go @@ -0,0 +1,419 @@ +/* + * Copyright (c) 2013-2016 Dave Collins + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +package spew + +import ( + "bytes" + "fmt" + "reflect" + "strconv" + "strings" +) + +// supportedFlags is a list of all the character flags supported by fmt package. +const supportedFlags = "0-+# " + +// formatState implements the fmt.Formatter interface and contains information +// about the state of a formatting operation. The NewFormatter function can +// be used to get a new Formatter which can be used directly as arguments +// in standard fmt package printing calls. +type formatState struct { + value interface{} + fs fmt.State + depth int + pointers map[uintptr]int + ignoreNextType bool + cs *ConfigState +} + +// buildDefaultFormat recreates the original format string without precision +// and width information to pass in to fmt.Sprintf in the case of an +// unrecognized type. Unless new types are added to the language, this +// function won't ever be called. +func (f *formatState) buildDefaultFormat() (format string) { + buf := bytes.NewBuffer(percentBytes) + + for _, flag := range supportedFlags { + if f.fs.Flag(int(flag)) { + buf.WriteRune(flag) + } + } + + buf.WriteRune('v') + + format = buf.String() + return format +} + +// constructOrigFormat recreates the original format string including precision +// and width information to pass along to the standard fmt package. This allows +// automatic deferral of all format strings this package doesn't support. +func (f *formatState) constructOrigFormat(verb rune) (format string) { + buf := bytes.NewBuffer(percentBytes) + + for _, flag := range supportedFlags { + if f.fs.Flag(int(flag)) { + buf.WriteRune(flag) + } + } + + if width, ok := f.fs.Width(); ok { + buf.WriteString(strconv.Itoa(width)) + } + + if precision, ok := f.fs.Precision(); ok { + buf.Write(precisionBytes) + buf.WriteString(strconv.Itoa(precision)) + } + + buf.WriteRune(verb) + + format = buf.String() + return format +} + +// unpackValue returns values inside of non-nil interfaces when possible and +// ensures that types for values which have been unpacked from an interface +// are displayed when the show types flag is also set. +// This is useful for data types like structs, arrays, slices, and maps which +// can contain varying types packed inside an interface. +func (f *formatState) unpackValue(v reflect.Value) reflect.Value { + if v.Kind() == reflect.Interface { + f.ignoreNextType = false + if !v.IsNil() { + v = v.Elem() + } + } + return v +} + +// formatPtr handles formatting of pointers by indirecting them as necessary. +func (f *formatState) formatPtr(v reflect.Value) { + // Display nil if top level pointer is nil. + showTypes := f.fs.Flag('#') + if v.IsNil() && (!showTypes || f.ignoreNextType) { + f.fs.Write(nilAngleBytes) + return + } + + // Remove pointers at or below the current depth from map used to detect + // circular refs. + for k, depth := range f.pointers { + if depth >= f.depth { + delete(f.pointers, k) + } + } + + // Keep list of all dereferenced pointers to possibly show later. + pointerChain := make([]uintptr, 0) + + // Figure out how many levels of indirection there are by derferencing + // pointers and unpacking interfaces down the chain while detecting circular + // references. + nilFound := false + cycleFound := false + indirects := 0 + ve := v + for ve.Kind() == reflect.Ptr { + if ve.IsNil() { + nilFound = true + break + } + indirects++ + addr := ve.Pointer() + pointerChain = append(pointerChain, addr) + if pd, ok := f.pointers[addr]; ok && pd < f.depth { + cycleFound = true + indirects-- + break + } + f.pointers[addr] = f.depth + + ve = ve.Elem() + if ve.Kind() == reflect.Interface { + if ve.IsNil() { + nilFound = true + break + } + ve = ve.Elem() + } + } + + // Display type or indirection level depending on flags. + if showTypes && !f.ignoreNextType { + f.fs.Write(openParenBytes) + f.fs.Write(bytes.Repeat(asteriskBytes, indirects)) + f.fs.Write([]byte(ve.Type().String())) + f.fs.Write(closeParenBytes) + } else { + if nilFound || cycleFound { + indirects += strings.Count(ve.Type().String(), "*") + } + f.fs.Write(openAngleBytes) + f.fs.Write([]byte(strings.Repeat("*", indirects))) + f.fs.Write(closeAngleBytes) + } + + // Display pointer information depending on flags. + if f.fs.Flag('+') && (len(pointerChain) > 0) { + f.fs.Write(openParenBytes) + for i, addr := range pointerChain { + if i > 0 { + f.fs.Write(pointerChainBytes) + } + printHexPtr(f.fs, addr) + } + f.fs.Write(closeParenBytes) + } + + // Display dereferenced value. + switch { + case nilFound: + f.fs.Write(nilAngleBytes) + + case cycleFound: + f.fs.Write(circularShortBytes) + + default: + f.ignoreNextType = true + f.format(ve) + } +} + +// format is the main workhorse for providing the Formatter interface. It +// uses the passed reflect value to figure out what kind of object we are +// dealing with and formats it appropriately. It is a recursive function, +// however circular data structures are detected and handled properly. +func (f *formatState) format(v reflect.Value) { + // Handle invalid reflect values immediately. + kind := v.Kind() + if kind == reflect.Invalid { + f.fs.Write(invalidAngleBytes) + return + } + + // Handle pointers specially. + if kind == reflect.Ptr { + f.formatPtr(v) + return + } + + // Print type information unless already handled elsewhere. + if !f.ignoreNextType && f.fs.Flag('#') { + f.fs.Write(openParenBytes) + f.fs.Write([]byte(v.Type().String())) + f.fs.Write(closeParenBytes) + } + f.ignoreNextType = false + + // Call Stringer/error interfaces if they exist and the handle methods + // flag is enabled. + if !f.cs.DisableMethods { + if (kind != reflect.Invalid) && (kind != reflect.Interface) { + if handled := handleMethods(f.cs, f.fs, v); handled { + return + } + } + } + + switch kind { + case reflect.Invalid: + // Do nothing. We should never get here since invalid has already + // been handled above. + + case reflect.Bool: + printBool(f.fs, v.Bool()) + + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: + printInt(f.fs, v.Int(), 10) + + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + printUint(f.fs, v.Uint(), 10) + + case reflect.Float32: + printFloat(f.fs, v.Float(), 32) + + case reflect.Float64: + printFloat(f.fs, v.Float(), 64) + + case reflect.Complex64: + printComplex(f.fs, v.Complex(), 32) + + case reflect.Complex128: + printComplex(f.fs, v.Complex(), 64) + + case reflect.Slice: + if v.IsNil() { + f.fs.Write(nilAngleBytes) + break + } + fallthrough + + case reflect.Array: + f.fs.Write(openBracketBytes) + f.depth++ + if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) { + f.fs.Write(maxShortBytes) + } else { + numEntries := v.Len() + for i := 0; i < numEntries; i++ { + if i > 0 { + f.fs.Write(spaceBytes) + } + f.ignoreNextType = true + f.format(f.unpackValue(v.Index(i))) + } + } + f.depth-- + f.fs.Write(closeBracketBytes) + + case reflect.String: + f.fs.Write([]byte(v.String())) + + case reflect.Interface: + // The only time we should get here is for nil interfaces due to + // unpackValue calls. + if v.IsNil() { + f.fs.Write(nilAngleBytes) + } + + case reflect.Ptr: + // Do nothing. We should never get here since pointers have already + // been handled above. + + case reflect.Map: + // nil maps should be indicated as different than empty maps + if v.IsNil() { + f.fs.Write(nilAngleBytes) + break + } + + f.fs.Write(openMapBytes) + f.depth++ + if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) { + f.fs.Write(maxShortBytes) + } else { + keys := v.MapKeys() + if f.cs.SortKeys { + sortValues(keys, f.cs) + } + for i, key := range keys { + if i > 0 { + f.fs.Write(spaceBytes) + } + f.ignoreNextType = true + f.format(f.unpackValue(key)) + f.fs.Write(colonBytes) + f.ignoreNextType = true + f.format(f.unpackValue(v.MapIndex(key))) + } + } + f.depth-- + f.fs.Write(closeMapBytes) + + case reflect.Struct: + numFields := v.NumField() + f.fs.Write(openBraceBytes) + f.depth++ + if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) { + f.fs.Write(maxShortBytes) + } else { + vt := v.Type() + for i := 0; i < numFields; i++ { + if i > 0 { + f.fs.Write(spaceBytes) + } + vtf := vt.Field(i) + if f.fs.Flag('+') || f.fs.Flag('#') { + f.fs.Write([]byte(vtf.Name)) + f.fs.Write(colonBytes) + } + f.format(f.unpackValue(v.Field(i))) + } + } + f.depth-- + f.fs.Write(closeBraceBytes) + + case reflect.Uintptr: + printHexPtr(f.fs, uintptr(v.Uint())) + + case reflect.UnsafePointer, reflect.Chan, reflect.Func: + printHexPtr(f.fs, v.Pointer()) + + // There were not any other types at the time this code was written, but + // fall back to letting the default fmt package handle it if any get added. + default: + format := f.buildDefaultFormat() + if v.CanInterface() { + fmt.Fprintf(f.fs, format, v.Interface()) + } else { + fmt.Fprintf(f.fs, format, v.String()) + } + } +} + +// Format satisfies the fmt.Formatter interface. See NewFormatter for usage +// details. +func (f *formatState) Format(fs fmt.State, verb rune) { + f.fs = fs + + // Use standard formatting for verbs that are not v. + if verb != 'v' { + format := f.constructOrigFormat(verb) + fmt.Fprintf(fs, format, f.value) + return + } + + if f.value == nil { + if fs.Flag('#') { + fs.Write(interfaceBytes) + } + fs.Write(nilAngleBytes) + return + } + + f.format(reflect.ValueOf(f.value)) +} + +// newFormatter is a helper function to consolidate the logic from the various +// public methods which take varying config states. +func newFormatter(cs *ConfigState, v interface{}) fmt.Formatter { + fs := &formatState{value: v, cs: cs} + fs.pointers = make(map[uintptr]int) + return fs +} + +/* +NewFormatter returns a custom formatter that satisfies the fmt.Formatter +interface. As a result, it integrates cleanly with standard fmt package +printing functions. The formatter is useful for inline printing of smaller data +types similar to the standard %v format specifier. + +The custom formatter only responds to the %v (most compact), %+v (adds pointer +addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb +combinations. Any other verbs such as %x and %q will be sent to the the +standard fmt package for formatting. In addition, the custom formatter ignores +the width and precision arguments (however they will still work on the format +specifiers not handled by the custom formatter). + +Typically this function shouldn't be called directly. It is much easier to make +use of the custom formatter by calling one of the convenience functions such as +Printf, Println, or Fprintf. +*/ +func NewFormatter(v interface{}) fmt.Formatter { + return newFormatter(&Config, v) +} diff --git a/vendor/github.com/davecgh/go-spew/spew/spew.go b/vendor/github.com/davecgh/go-spew/spew/spew.go new file mode 100644 index 0000000..32c0e33 --- /dev/null +++ b/vendor/github.com/davecgh/go-spew/spew/spew.go @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2013-2016 Dave Collins + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +package spew + +import ( + "fmt" + "io" +) + +// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were +// passed with a default Formatter interface returned by NewFormatter. It +// returns the formatted string as a value that satisfies error. See +// NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Errorf(format, spew.NewFormatter(a), spew.NewFormatter(b)) +func Errorf(format string, a ...interface{}) (err error) { + return fmt.Errorf(format, convertArgs(a)...) +} + +// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were +// passed with a default Formatter interface returned by NewFormatter. It +// returns the number of bytes written and any write error encountered. See +// NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Fprint(w, spew.NewFormatter(a), spew.NewFormatter(b)) +func Fprint(w io.Writer, a ...interface{}) (n int, err error) { + return fmt.Fprint(w, convertArgs(a)...) +} + +// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were +// passed with a default Formatter interface returned by NewFormatter. It +// returns the number of bytes written and any write error encountered. See +// NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Fprintf(w, format, spew.NewFormatter(a), spew.NewFormatter(b)) +func Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) { + return fmt.Fprintf(w, format, convertArgs(a)...) +} + +// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it +// passed with a default Formatter interface returned by NewFormatter. See +// NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Fprintln(w, spew.NewFormatter(a), spew.NewFormatter(b)) +func Fprintln(w io.Writer, a ...interface{}) (n int, err error) { + return fmt.Fprintln(w, convertArgs(a)...) +} + +// Print is a wrapper for fmt.Print that treats each argument as if it were +// passed with a default Formatter interface returned by NewFormatter. It +// returns the number of bytes written and any write error encountered. See +// NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Print(spew.NewFormatter(a), spew.NewFormatter(b)) +func Print(a ...interface{}) (n int, err error) { + return fmt.Print(convertArgs(a)...) +} + +// Printf is a wrapper for fmt.Printf that treats each argument as if it were +// passed with a default Formatter interface returned by NewFormatter. It +// returns the number of bytes written and any write error encountered. See +// NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Printf(format, spew.NewFormatter(a), spew.NewFormatter(b)) +func Printf(format string, a ...interface{}) (n int, err error) { + return fmt.Printf(format, convertArgs(a)...) +} + +// Println is a wrapper for fmt.Println that treats each argument as if it were +// passed with a default Formatter interface returned by NewFormatter. It +// returns the number of bytes written and any write error encountered. See +// NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Println(spew.NewFormatter(a), spew.NewFormatter(b)) +func Println(a ...interface{}) (n int, err error) { + return fmt.Println(convertArgs(a)...) +} + +// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were +// passed with a default Formatter interface returned by NewFormatter. It +// returns the resulting string. See NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Sprint(spew.NewFormatter(a), spew.NewFormatter(b)) +func Sprint(a ...interface{}) string { + return fmt.Sprint(convertArgs(a)...) +} + +// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were +// passed with a default Formatter interface returned by NewFormatter. It +// returns the resulting string. See NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Sprintf(format, spew.NewFormatter(a), spew.NewFormatter(b)) +func Sprintf(format string, a ...interface{}) string { + return fmt.Sprintf(format, convertArgs(a)...) +} + +// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it +// were passed with a default Formatter interface returned by NewFormatter. It +// returns the resulting string. See NewFormatter for formatting details. +// +// This function is shorthand for the following syntax: +// +// fmt.Sprintln(spew.NewFormatter(a), spew.NewFormatter(b)) +func Sprintln(a ...interface{}) string { + return fmt.Sprintln(convertArgs(a)...) +} + +// convertArgs accepts a slice of arguments and returns a slice of the same +// length with each argument converted to a default spew Formatter interface. +func convertArgs(args []interface{}) (formatters []interface{}) { + formatters = make([]interface{}, len(args)) + for index, arg := range args { + formatters[index] = NewFormatter(arg) + } + return formatters +} diff --git a/vendor/github.com/pmezard/go-difflib/LICENSE b/vendor/github.com/pmezard/go-difflib/LICENSE new file mode 100644 index 0000000..c67dad6 --- /dev/null +++ b/vendor/github.com/pmezard/go-difflib/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2013, Patrick Mezard +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + The names of its contributors may not be used to endorse or promote +products derived from this software without specific prior written +permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/pmezard/go-difflib/difflib/difflib.go b/vendor/github.com/pmezard/go-difflib/difflib/difflib.go new file mode 100644 index 0000000..003e99f --- /dev/null +++ b/vendor/github.com/pmezard/go-difflib/difflib/difflib.go @@ -0,0 +1,772 @@ +// Package difflib is a partial port of Python difflib module. +// +// It provides tools to compare sequences of strings and generate textual diffs. +// +// The following class and functions have been ported: +// +// - SequenceMatcher +// +// - unified_diff +// +// - context_diff +// +// Getting unified diffs was the main goal of the port. Keep in mind this code +// is mostly suitable to output text differences in a human friendly way, there +// are no guarantees generated diffs are consumable by patch(1). +package difflib + +import ( + "bufio" + "bytes" + "fmt" + "io" + "strings" +) + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} + +func calculateRatio(matches, length int) float64 { + if length > 0 { + return 2.0 * float64(matches) / float64(length) + } + return 1.0 +} + +type Match struct { + A int + B int + Size int +} + +type OpCode struct { + Tag byte + I1 int + I2 int + J1 int + J2 int +} + +// SequenceMatcher compares sequence of strings. The basic +// algorithm predates, and is a little fancier than, an algorithm +// published in the late 1980's by Ratcliff and Obershelp under the +// hyperbolic name "gestalt pattern matching". The basic idea is to find +// the longest contiguous matching subsequence that contains no "junk" +// elements (R-O doesn't address junk). The same idea is then applied +// recursively to the pieces of the sequences to the left and to the right +// of the matching subsequence. This does not yield minimal edit +// sequences, but does tend to yield matches that "look right" to people. +// +// SequenceMatcher tries to compute a "human-friendly diff" between two +// sequences. Unlike e.g. UNIX(tm) diff, the fundamental notion is the +// longest *contiguous* & junk-free matching subsequence. That's what +// catches peoples' eyes. The Windows(tm) windiff has another interesting +// notion, pairing up elements that appear uniquely in each sequence. +// That, and the method here, appear to yield more intuitive difference +// reports than does diff. This method appears to be the least vulnerable +// to synching up on blocks of "junk lines", though (like blank lines in +// ordinary text files, or maybe "

" lines in HTML files). That may be +// because this is the only method of the 3 that has a *concept* of +// "junk" . +// +// Timing: Basic R-O is cubic time worst case and quadratic time expected +// case. SequenceMatcher is quadratic time for the worst case and has +// expected-case behavior dependent in a complicated way on how many +// elements the sequences have in common; best case time is linear. +type SequenceMatcher struct { + a []string + b []string + b2j map[string][]int + IsJunk func(string) bool + autoJunk bool + bJunk map[string]struct{} + matchingBlocks []Match + fullBCount map[string]int + bPopular map[string]struct{} + opCodes []OpCode +} + +func NewMatcher(a, b []string) *SequenceMatcher { + m := SequenceMatcher{autoJunk: true} + m.SetSeqs(a, b) + return &m +} + +func NewMatcherWithJunk(a, b []string, autoJunk bool, + isJunk func(string) bool) *SequenceMatcher { + + m := SequenceMatcher{IsJunk: isJunk, autoJunk: autoJunk} + m.SetSeqs(a, b) + return &m +} + +// Set two sequences to be compared. +func (m *SequenceMatcher) SetSeqs(a, b []string) { + m.SetSeq1(a) + m.SetSeq2(b) +} + +// Set the first sequence to be compared. The second sequence to be compared is +// not changed. +// +// SequenceMatcher computes and caches detailed information about the second +// sequence, so if you want to compare one sequence S against many sequences, +// use .SetSeq2(s) once and call .SetSeq1(x) repeatedly for each of the other +// sequences. +// +// See also SetSeqs() and SetSeq2(). +func (m *SequenceMatcher) SetSeq1(a []string) { + if &a == &m.a { + return + } + m.a = a + m.matchingBlocks = nil + m.opCodes = nil +} + +// Set the second sequence to be compared. The first sequence to be compared is +// not changed. +func (m *SequenceMatcher) SetSeq2(b []string) { + if &b == &m.b { + return + } + m.b = b + m.matchingBlocks = nil + m.opCodes = nil + m.fullBCount = nil + m.chainB() +} + +func (m *SequenceMatcher) chainB() { + // Populate line -> index mapping + b2j := map[string][]int{} + for i, s := range m.b { + indices := b2j[s] + indices = append(indices, i) + b2j[s] = indices + } + + // Purge junk elements + m.bJunk = map[string]struct{}{} + if m.IsJunk != nil { + junk := m.bJunk + for s, _ := range b2j { + if m.IsJunk(s) { + junk[s] = struct{}{} + } + } + for s, _ := range junk { + delete(b2j, s) + } + } + + // Purge remaining popular elements + popular := map[string]struct{}{} + n := len(m.b) + if m.autoJunk && n >= 200 { + ntest := n/100 + 1 + for s, indices := range b2j { + if len(indices) > ntest { + popular[s] = struct{}{} + } + } + for s, _ := range popular { + delete(b2j, s) + } + } + m.bPopular = popular + m.b2j = b2j +} + +func (m *SequenceMatcher) isBJunk(s string) bool { + _, ok := m.bJunk[s] + return ok +} + +// Find longest matching block in a[alo:ahi] and b[blo:bhi]. +// +// If IsJunk is not defined: +// +// Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where +// alo <= i <= i+k <= ahi +// blo <= j <= j+k <= bhi +// and for all (i',j',k') meeting those conditions, +// k >= k' +// i <= i' +// and if i == i', j <= j' +// +// In other words, of all maximal matching blocks, return one that +// starts earliest in a, and of all those maximal matching blocks that +// start earliest in a, return the one that starts earliest in b. +// +// If IsJunk is defined, first the longest matching block is +// determined as above, but with the additional restriction that no +// junk element appears in the block. Then that block is extended as +// far as possible by matching (only) junk elements on both sides. So +// the resulting block never matches on junk except as identical junk +// happens to be adjacent to an "interesting" match. +// +// If no blocks match, return (alo, blo, 0). +func (m *SequenceMatcher) findLongestMatch(alo, ahi, blo, bhi int) Match { + // CAUTION: stripping common prefix or suffix would be incorrect. + // E.g., + // ab + // acab + // Longest matching block is "ab", but if common prefix is + // stripped, it's "a" (tied with "b"). UNIX(tm) diff does so + // strip, so ends up claiming that ab is changed to acab by + // inserting "ca" in the middle. That's minimal but unintuitive: + // "it's obvious" that someone inserted "ac" at the front. + // Windiff ends up at the same place as diff, but by pairing up + // the unique 'b's and then matching the first two 'a's. + besti, bestj, bestsize := alo, blo, 0 + + // find longest junk-free match + // during an iteration of the loop, j2len[j] = length of longest + // junk-free match ending with a[i-1] and b[j] + j2len := map[int]int{} + for i := alo; i != ahi; i++ { + // look at all instances of a[i] in b; note that because + // b2j has no junk keys, the loop is skipped if a[i] is junk + newj2len := map[int]int{} + for _, j := range m.b2j[m.a[i]] { + // a[i] matches b[j] + if j < blo { + continue + } + if j >= bhi { + break + } + k := j2len[j-1] + 1 + newj2len[j] = k + if k > bestsize { + besti, bestj, bestsize = i-k+1, j-k+1, k + } + } + j2len = newj2len + } + + // Extend the best by non-junk elements on each end. In particular, + // "popular" non-junk elements aren't in b2j, which greatly speeds + // the inner loop above, but also means "the best" match so far + // doesn't contain any junk *or* popular non-junk elements. + for besti > alo && bestj > blo && !m.isBJunk(m.b[bestj-1]) && + m.a[besti-1] == m.b[bestj-1] { + besti, bestj, bestsize = besti-1, bestj-1, bestsize+1 + } + for besti+bestsize < ahi && bestj+bestsize < bhi && + !m.isBJunk(m.b[bestj+bestsize]) && + m.a[besti+bestsize] == m.b[bestj+bestsize] { + bestsize += 1 + } + + // Now that we have a wholly interesting match (albeit possibly + // empty!), we may as well suck up the matching junk on each + // side of it too. Can't think of a good reason not to, and it + // saves post-processing the (possibly considerable) expense of + // figuring out what to do with it. In the case of an empty + // interesting match, this is clearly the right thing to do, + // because no other kind of match is possible in the regions. + for besti > alo && bestj > blo && m.isBJunk(m.b[bestj-1]) && + m.a[besti-1] == m.b[bestj-1] { + besti, bestj, bestsize = besti-1, bestj-1, bestsize+1 + } + for besti+bestsize < ahi && bestj+bestsize < bhi && + m.isBJunk(m.b[bestj+bestsize]) && + m.a[besti+bestsize] == m.b[bestj+bestsize] { + bestsize += 1 + } + + return Match{A: besti, B: bestj, Size: bestsize} +} + +// Return list of triples describing matching subsequences. +// +// Each triple is of the form (i, j, n), and means that +// a[i:i+n] == b[j:j+n]. The triples are monotonically increasing in +// i and in j. It's also guaranteed that if (i, j, n) and (i', j', n') are +// adjacent triples in the list, and the second is not the last triple in the +// list, then i+n != i' or j+n != j'. IOW, adjacent triples never describe +// adjacent equal blocks. +// +// The last triple is a dummy, (len(a), len(b), 0), and is the only +// triple with n==0. +func (m *SequenceMatcher) GetMatchingBlocks() []Match { + if m.matchingBlocks != nil { + return m.matchingBlocks + } + + var matchBlocks func(alo, ahi, blo, bhi int, matched []Match) []Match + matchBlocks = func(alo, ahi, blo, bhi int, matched []Match) []Match { + match := m.findLongestMatch(alo, ahi, blo, bhi) + i, j, k := match.A, match.B, match.Size + if match.Size > 0 { + if alo < i && blo < j { + matched = matchBlocks(alo, i, blo, j, matched) + } + matched = append(matched, match) + if i+k < ahi && j+k < bhi { + matched = matchBlocks(i+k, ahi, j+k, bhi, matched) + } + } + return matched + } + matched := matchBlocks(0, len(m.a), 0, len(m.b), nil) + + // It's possible that we have adjacent equal blocks in the + // matching_blocks list now. + nonAdjacent := []Match{} + i1, j1, k1 := 0, 0, 0 + for _, b := range matched { + // Is this block adjacent to i1, j1, k1? + i2, j2, k2 := b.A, b.B, b.Size + if i1+k1 == i2 && j1+k1 == j2 { + // Yes, so collapse them -- this just increases the length of + // the first block by the length of the second, and the first + // block so lengthened remains the block to compare against. + k1 += k2 + } else { + // Not adjacent. Remember the first block (k1==0 means it's + // the dummy we started with), and make the second block the + // new block to compare against. + if k1 > 0 { + nonAdjacent = append(nonAdjacent, Match{i1, j1, k1}) + } + i1, j1, k1 = i2, j2, k2 + } + } + if k1 > 0 { + nonAdjacent = append(nonAdjacent, Match{i1, j1, k1}) + } + + nonAdjacent = append(nonAdjacent, Match{len(m.a), len(m.b), 0}) + m.matchingBlocks = nonAdjacent + return m.matchingBlocks +} + +// Return list of 5-tuples describing how to turn a into b. +// +// Each tuple is of the form (tag, i1, i2, j1, j2). The first tuple +// has i1 == j1 == 0, and remaining tuples have i1 == the i2 from the +// tuple preceding it, and likewise for j1 == the previous j2. +// +// The tags are characters, with these meanings: +// +// 'r' (replace): a[i1:i2] should be replaced by b[j1:j2] +// +// 'd' (delete): a[i1:i2] should be deleted, j1==j2 in this case. +// +// 'i' (insert): b[j1:j2] should be inserted at a[i1:i1], i1==i2 in this case. +// +// 'e' (equal): a[i1:i2] == b[j1:j2] +func (m *SequenceMatcher) GetOpCodes() []OpCode { + if m.opCodes != nil { + return m.opCodes + } + i, j := 0, 0 + matching := m.GetMatchingBlocks() + opCodes := make([]OpCode, 0, len(matching)) + for _, m := range matching { + // invariant: we've pumped out correct diffs to change + // a[:i] into b[:j], and the next matching block is + // a[ai:ai+size] == b[bj:bj+size]. So we need to pump + // out a diff to change a[i:ai] into b[j:bj], pump out + // the matching block, and move (i,j) beyond the match + ai, bj, size := m.A, m.B, m.Size + tag := byte(0) + if i < ai && j < bj { + tag = 'r' + } else if i < ai { + tag = 'd' + } else if j < bj { + tag = 'i' + } + if tag > 0 { + opCodes = append(opCodes, OpCode{tag, i, ai, j, bj}) + } + i, j = ai+size, bj+size + // the list of matching blocks is terminated by a + // sentinel with size 0 + if size > 0 { + opCodes = append(opCodes, OpCode{'e', ai, i, bj, j}) + } + } + m.opCodes = opCodes + return m.opCodes +} + +// Isolate change clusters by eliminating ranges with no changes. +// +// Return a generator of groups with up to n lines of context. +// Each group is in the same format as returned by GetOpCodes(). +func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode { + if n < 0 { + n = 3 + } + codes := m.GetOpCodes() + if len(codes) == 0 { + codes = []OpCode{OpCode{'e', 0, 1, 0, 1}} + } + // Fixup leading and trailing groups if they show no changes. + if codes[0].Tag == 'e' { + c := codes[0] + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + codes[0] = OpCode{c.Tag, max(i1, i2-n), i2, max(j1, j2-n), j2} + } + if codes[len(codes)-1].Tag == 'e' { + c := codes[len(codes)-1] + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + codes[len(codes)-1] = OpCode{c.Tag, i1, min(i2, i1+n), j1, min(j2, j1+n)} + } + nn := n + n + groups := [][]OpCode{} + group := []OpCode{} + for _, c := range codes { + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + // End the current group and start a new one whenever + // there is a large range with no changes. + if c.Tag == 'e' && i2-i1 > nn { + group = append(group, OpCode{c.Tag, i1, min(i2, i1+n), + j1, min(j2, j1+n)}) + groups = append(groups, group) + group = []OpCode{} + i1, j1 = max(i1, i2-n), max(j1, j2-n) + } + group = append(group, OpCode{c.Tag, i1, i2, j1, j2}) + } + if len(group) > 0 && !(len(group) == 1 && group[0].Tag == 'e') { + groups = append(groups, group) + } + return groups +} + +// Return a measure of the sequences' similarity (float in [0,1]). +// +// Where T is the total number of elements in both sequences, and +// M is the number of matches, this is 2.0*M / T. +// Note that this is 1 if the sequences are identical, and 0 if +// they have nothing in common. +// +// .Ratio() is expensive to compute if you haven't already computed +// .GetMatchingBlocks() or .GetOpCodes(), in which case you may +// want to try .QuickRatio() or .RealQuickRation() first to get an +// upper bound. +func (m *SequenceMatcher) Ratio() float64 { + matches := 0 + for _, m := range m.GetMatchingBlocks() { + matches += m.Size + } + return calculateRatio(matches, len(m.a)+len(m.b)) +} + +// Return an upper bound on ratio() relatively quickly. +// +// This isn't defined beyond that it is an upper bound on .Ratio(), and +// is faster to compute. +func (m *SequenceMatcher) QuickRatio() float64 { + // viewing a and b as multisets, set matches to the cardinality + // of their intersection; this counts the number of matches + // without regard to order, so is clearly an upper bound + if m.fullBCount == nil { + m.fullBCount = map[string]int{} + for _, s := range m.b { + m.fullBCount[s] = m.fullBCount[s] + 1 + } + } + + // avail[x] is the number of times x appears in 'b' less the + // number of times we've seen it in 'a' so far ... kinda + avail := map[string]int{} + matches := 0 + for _, s := range m.a { + n, ok := avail[s] + if !ok { + n = m.fullBCount[s] + } + avail[s] = n - 1 + if n > 0 { + matches += 1 + } + } + return calculateRatio(matches, len(m.a)+len(m.b)) +} + +// Return an upper bound on ratio() very quickly. +// +// This isn't defined beyond that it is an upper bound on .Ratio(), and +// is faster to compute than either .Ratio() or .QuickRatio(). +func (m *SequenceMatcher) RealQuickRatio() float64 { + la, lb := len(m.a), len(m.b) + return calculateRatio(min(la, lb), la+lb) +} + +// Convert range to the "ed" format +func formatRangeUnified(start, stop int) string { + // Per the diff spec at http://www.unix.org/single_unix_specification/ + beginning := start + 1 // lines start numbering with one + length := stop - start + if length == 1 { + return fmt.Sprintf("%d", beginning) + } + if length == 0 { + beginning -= 1 // empty ranges begin at line just before the range + } + return fmt.Sprintf("%d,%d", beginning, length) +} + +// Unified diff parameters +type UnifiedDiff struct { + A []string // First sequence lines + FromFile string // First file name + FromDate string // First file time + B []string // Second sequence lines + ToFile string // Second file name + ToDate string // Second file time + Eol string // Headers end of line, defaults to LF + Context int // Number of context lines +} + +// Compare two sequences of lines; generate the delta as a unified diff. +// +// Unified diffs are a compact way of showing line changes and a few +// lines of context. The number of context lines is set by 'n' which +// defaults to three. +// +// By default, the diff control lines (those with ---, +++, or @@) are +// created with a trailing newline. This is helpful so that inputs +// created from file.readlines() result in diffs that are suitable for +// file.writelines() since both the inputs and outputs have trailing +// newlines. +// +// For inputs that do not have trailing newlines, set the lineterm +// argument to "" so that the output will be uniformly newline free. +// +// The unidiff format normally has a header for filenames and modification +// times. Any or all of these may be specified using strings for +// 'fromfile', 'tofile', 'fromfiledate', and 'tofiledate'. +// The modification times are normally expressed in the ISO 8601 format. +func WriteUnifiedDiff(writer io.Writer, diff UnifiedDiff) error { + buf := bufio.NewWriter(writer) + defer buf.Flush() + wf := func(format string, args ...interface{}) error { + _, err := buf.WriteString(fmt.Sprintf(format, args...)) + return err + } + ws := func(s string) error { + _, err := buf.WriteString(s) + return err + } + + if len(diff.Eol) == 0 { + diff.Eol = "\n" + } + + started := false + m := NewMatcher(diff.A, diff.B) + for _, g := range m.GetGroupedOpCodes(diff.Context) { + if !started { + started = true + fromDate := "" + if len(diff.FromDate) > 0 { + fromDate = "\t" + diff.FromDate + } + toDate := "" + if len(diff.ToDate) > 0 { + toDate = "\t" + diff.ToDate + } + if diff.FromFile != "" || diff.ToFile != "" { + err := wf("--- %s%s%s", diff.FromFile, fromDate, diff.Eol) + if err != nil { + return err + } + err = wf("+++ %s%s%s", diff.ToFile, toDate, diff.Eol) + if err != nil { + return err + } + } + } + first, last := g[0], g[len(g)-1] + range1 := formatRangeUnified(first.I1, last.I2) + range2 := formatRangeUnified(first.J1, last.J2) + if err := wf("@@ -%s +%s @@%s", range1, range2, diff.Eol); err != nil { + return err + } + for _, c := range g { + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + if c.Tag == 'e' { + for _, line := range diff.A[i1:i2] { + if err := ws(" " + line); err != nil { + return err + } + } + continue + } + if c.Tag == 'r' || c.Tag == 'd' { + for _, line := range diff.A[i1:i2] { + if err := ws("-" + line); err != nil { + return err + } + } + } + if c.Tag == 'r' || c.Tag == 'i' { + for _, line := range diff.B[j1:j2] { + if err := ws("+" + line); err != nil { + return err + } + } + } + } + } + return nil +} + +// Like WriteUnifiedDiff but returns the diff a string. +func GetUnifiedDiffString(diff UnifiedDiff) (string, error) { + w := &bytes.Buffer{} + err := WriteUnifiedDiff(w, diff) + return string(w.Bytes()), err +} + +// Convert range to the "ed" format. +func formatRangeContext(start, stop int) string { + // Per the diff spec at http://www.unix.org/single_unix_specification/ + beginning := start + 1 // lines start numbering with one + length := stop - start + if length == 0 { + beginning -= 1 // empty ranges begin at line just before the range + } + if length <= 1 { + return fmt.Sprintf("%d", beginning) + } + return fmt.Sprintf("%d,%d", beginning, beginning+length-1) +} + +type ContextDiff UnifiedDiff + +// Compare two sequences of lines; generate the delta as a context diff. +// +// Context diffs are a compact way of showing line changes and a few +// lines of context. The number of context lines is set by diff.Context +// which defaults to three. +// +// By default, the diff control lines (those with *** or ---) are +// created with a trailing newline. +// +// For inputs that do not have trailing newlines, set the diff.Eol +// argument to "" so that the output will be uniformly newline free. +// +// The context diff format normally has a header for filenames and +// modification times. Any or all of these may be specified using +// strings for diff.FromFile, diff.ToFile, diff.FromDate, diff.ToDate. +// The modification times are normally expressed in the ISO 8601 format. +// If not specified, the strings default to blanks. +func WriteContextDiff(writer io.Writer, diff ContextDiff) error { + buf := bufio.NewWriter(writer) + defer buf.Flush() + var diffErr error + wf := func(format string, args ...interface{}) { + _, err := buf.WriteString(fmt.Sprintf(format, args...)) + if diffErr == nil && err != nil { + diffErr = err + } + } + ws := func(s string) { + _, err := buf.WriteString(s) + if diffErr == nil && err != nil { + diffErr = err + } + } + + if len(diff.Eol) == 0 { + diff.Eol = "\n" + } + + prefix := map[byte]string{ + 'i': "+ ", + 'd': "- ", + 'r': "! ", + 'e': " ", + } + + started := false + m := NewMatcher(diff.A, diff.B) + for _, g := range m.GetGroupedOpCodes(diff.Context) { + if !started { + started = true + fromDate := "" + if len(diff.FromDate) > 0 { + fromDate = "\t" + diff.FromDate + } + toDate := "" + if len(diff.ToDate) > 0 { + toDate = "\t" + diff.ToDate + } + if diff.FromFile != "" || diff.ToFile != "" { + wf("*** %s%s%s", diff.FromFile, fromDate, diff.Eol) + wf("--- %s%s%s", diff.ToFile, toDate, diff.Eol) + } + } + + first, last := g[0], g[len(g)-1] + ws("***************" + diff.Eol) + + range1 := formatRangeContext(first.I1, last.I2) + wf("*** %s ****%s", range1, diff.Eol) + for _, c := range g { + if c.Tag == 'r' || c.Tag == 'd' { + for _, cc := range g { + if cc.Tag == 'i' { + continue + } + for _, line := range diff.A[cc.I1:cc.I2] { + ws(prefix[cc.Tag] + line) + } + } + break + } + } + + range2 := formatRangeContext(first.J1, last.J2) + wf("--- %s ----%s", range2, diff.Eol) + for _, c := range g { + if c.Tag == 'r' || c.Tag == 'i' { + for _, cc := range g { + if cc.Tag == 'd' { + continue + } + for _, line := range diff.B[cc.J1:cc.J2] { + ws(prefix[cc.Tag] + line) + } + } + break + } + } + } + return diffErr +} + +// Like WriteContextDiff but returns the diff a string. +func GetContextDiffString(diff ContextDiff) (string, error) { + w := &bytes.Buffer{} + err := WriteContextDiff(w, diff) + return string(w.Bytes()), err +} + +// Split a string on "\n" while preserving them. The output can be used +// as input for UnifiedDiff and ContextDiff structures. +func SplitLines(s string) []string { + lines := strings.SplitAfter(s, "\n") + lines[len(lines)-1] += "\n" + return lines +} diff --git a/vendor/github.com/stretchr/objx/LICENSE b/vendor/github.com/stretchr/objx/LICENSE new file mode 100644 index 0000000..44d4d9d --- /dev/null +++ b/vendor/github.com/stretchr/objx/LICENSE @@ -0,0 +1,22 @@ +The MIT License + +Copyright (c) 2014 Stretchr, Inc. +Copyright (c) 2017-2018 objx contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/stretchr/objx/README.md b/vendor/github.com/stretchr/objx/README.md new file mode 100644 index 0000000..78dc1f8 --- /dev/null +++ b/vendor/github.com/stretchr/objx/README.md @@ -0,0 +1,80 @@ +# Objx +[![Build Status](https://travis-ci.org/stretchr/objx.svg?branch=master)](https://travis-ci.org/stretchr/objx) +[![Go Report Card](https://goreportcard.com/badge/github.com/stretchr/objx)](https://goreportcard.com/report/github.com/stretchr/objx) +[![Maintainability](https://api.codeclimate.com/v1/badges/1d64bc6c8474c2074f2b/maintainability)](https://codeclimate.com/github/stretchr/objx/maintainability) +[![Test Coverage](https://api.codeclimate.com/v1/badges/1d64bc6c8474c2074f2b/test_coverage)](https://codeclimate.com/github/stretchr/objx/test_coverage) +[![Sourcegraph](https://sourcegraph.com/github.com/stretchr/objx/-/badge.svg)](https://sourcegraph.com/github.com/stretchr/objx) +[![GoDoc](https://pkg.go.dev/badge/github.com/stretchr/objx?utm_source=godoc)](https://pkg.go.dev/github.com/stretchr/objx) + +Objx - Go package for dealing with maps, slices, JSON and other data. + +Get started: + +- Install Objx with [one line of code](#installation), or [update it with another](#staying-up-to-date) +- Check out the API Documentation http://pkg.go.dev/github.com/stretchr/objx + +## Overview +Objx provides the `objx.Map` type, which is a `map[string]interface{}` that exposes a powerful `Get` method (among others) that allows you to easily and quickly get access to data within the map, without having to worry too much about type assertions, missing data, default values etc. + +### Pattern +Objx uses a predictable pattern to make access data from within `map[string]interface{}` easy. Call one of the `objx.` functions to create your `objx.Map` to get going: + + m, err := objx.FromJSON(json) + +NOTE: Any methods or functions with the `Must` prefix will panic if something goes wrong, the rest will be optimistic and try to figure things out without panicking. + +Use `Get` to access the value you're interested in. You can use dot and array +notation too: + + m.Get("places[0].latlng") + +Once you have sought the `Value` you're interested in, you can use the `Is*` methods to determine its type. + + if m.Get("code").IsStr() { // Your code... } + +Or you can just assume the type, and use one of the strong type methods to extract the real value: + + m.Get("code").Int() + +If there's no value there (or if it's the wrong type) then a default value will be returned, or you can be explicit about the default value. + + Get("code").Int(-1) + +If you're dealing with a slice of data as a value, Objx provides many useful methods for iterating, manipulating and selecting that data. You can find out more by exploring the index below. + +### Reading data +A simple example of how to use Objx: + + // Use MustFromJSON to make an objx.Map from some JSON + m := objx.MustFromJSON(`{"name": "Mat", "age": 30}`) + + // Get the details + name := m.Get("name").Str() + age := m.Get("age").Int() + + // Get their nickname (or use their name if they don't have one) + nickname := m.Get("nickname").Str(name) + +### Ranging +Since `objx.Map` is a `map[string]interface{}` you can treat it as such. For example, to `range` the data, do what you would expect: + + m := objx.MustFromJSON(json) + for key, value := range m { + // Your code... + } + +## Installation +To install Objx, use go get: + + go get github.com/stretchr/objx + +### Staying up to date +To update Objx to the latest version, run: + + go get -u github.com/stretchr/objx + +### Supported go versions +We currently support the three recent major Go versions. + +## Contributing +Please feel free to submit issues, fork the repository and send pull requests! diff --git a/vendor/github.com/stretchr/objx/accessors.go b/vendor/github.com/stretchr/objx/accessors.go new file mode 100644 index 0000000..72f1d1c --- /dev/null +++ b/vendor/github.com/stretchr/objx/accessors.go @@ -0,0 +1,197 @@ +package objx + +import ( + "reflect" + "regexp" + "strconv" + "strings" +) + +const ( + // PathSeparator is the character used to separate the elements + // of the keypath. + // + // For example, `location.address.city` + PathSeparator string = "." + + // arrayAccessRegexString is the regex used to extract the array number + // from the access path + arrayAccessRegexString = `^(.+)\[([0-9]+)\]$` + + // mapAccessRegexString is the regex used to extract the map key + // from the access path + mapAccessRegexString = `^([^\[]*)\[([^\]]+)\](.*)$` +) + +// arrayAccessRegex is the compiled arrayAccessRegexString +var arrayAccessRegex = regexp.MustCompile(arrayAccessRegexString) + +// mapAccessRegex is the compiled mapAccessRegexString +var mapAccessRegex = regexp.MustCompile(mapAccessRegexString) + +// Get gets the value using the specified selector and +// returns it inside a new Obj object. +// +// If it cannot find the value, Get will return a nil +// value inside an instance of Obj. +// +// Get can only operate directly on map[string]interface{} and []interface. +// +// # Example +// +// To access the title of the third chapter of the second book, do: +// +// o.Get("books[1].chapters[2].title") +func (m Map) Get(selector string) *Value { + rawObj := access(m, selector, nil, false) + return &Value{data: rawObj} +} + +// Set sets the value using the specified selector and +// returns the object on which Set was called. +// +// Set can only operate directly on map[string]interface{} and []interface +// +// # Example +// +// To set the title of the third chapter of the second book, do: +// +// o.Set("books[1].chapters[2].title","Time to Go") +func (m Map) Set(selector string, value interface{}) Map { + access(m, selector, value, true) + return m +} + +// getIndex returns the index, which is hold in s by two branches. +// It also returns s without the index part, e.g. name[1] will return (1, name). +// If no index is found, -1 is returned +func getIndex(s string) (int, string) { + arrayMatches := arrayAccessRegex.FindStringSubmatch(s) + if len(arrayMatches) > 0 { + // Get the key into the map + selector := arrayMatches[1] + // Get the index into the array at the key + // We know this can't fail because arrayMatches[2] is an int for sure + index, _ := strconv.Atoi(arrayMatches[2]) + return index, selector + } + return -1, s +} + +// getKey returns the key which is held in s by two brackets. +// It also returns the next selector. +func getKey(s string) (string, string) { + selSegs := strings.SplitN(s, PathSeparator, 2) + thisSel := selSegs[0] + nextSel := "" + + if len(selSegs) > 1 { + nextSel = selSegs[1] + } + + mapMatches := mapAccessRegex.FindStringSubmatch(s) + if len(mapMatches) > 0 { + if _, err := strconv.Atoi(mapMatches[2]); err != nil { + thisSel = mapMatches[1] + nextSel = "[" + mapMatches[2] + "]" + mapMatches[3] + + if thisSel == "" { + thisSel = mapMatches[2] + nextSel = mapMatches[3] + } + + if nextSel == "" { + selSegs = []string{"", ""} + } else if nextSel[0] == '.' { + nextSel = nextSel[1:] + } + } + } + + return thisSel, nextSel +} + +// access accesses the object using the selector and performs the +// appropriate action. +func access(current interface{}, selector string, value interface{}, isSet bool) interface{} { + thisSel, nextSel := getKey(selector) + + indexes := []int{} + for strings.Contains(thisSel, "[") { + prevSel := thisSel + index := -1 + index, thisSel = getIndex(thisSel) + indexes = append(indexes, index) + if prevSel == thisSel { + break + } + } + + if curMap, ok := current.(Map); ok { + current = map[string]interface{}(curMap) + } + // get the object in question + switch current.(type) { + case map[string]interface{}: + curMSI := current.(map[string]interface{}) + if nextSel == "" && isSet { + curMSI[thisSel] = value + return nil + } + + _, ok := curMSI[thisSel].(map[string]interface{}) + if !ok { + _, ok = curMSI[thisSel].(Map) + } + + if (curMSI[thisSel] == nil || !ok) && len(indexes) == 0 && isSet { + curMSI[thisSel] = map[string]interface{}{} + } + + current = curMSI[thisSel] + default: + current = nil + } + + // do we need to access the item of an array? + if len(indexes) > 0 { + num := len(indexes) + for num > 0 { + num-- + index := indexes[num] + indexes = indexes[:num] + if array, ok := interSlice(current); ok { + if index < len(array) { + current = array[index] + } else { + current = nil + break + } + } + } + } + + if nextSel != "" { + current = access(current, nextSel, value, isSet) + } + return current +} + +func interSlice(slice interface{}) ([]interface{}, bool) { + if array, ok := slice.([]interface{}); ok { + return array, ok + } + + s := reflect.ValueOf(slice) + if s.Kind() != reflect.Slice { + return nil, false + } + + ret := make([]interface{}, s.Len()) + + for i := 0; i < s.Len(); i++ { + ret[i] = s.Index(i).Interface() + } + + return ret, true +} diff --git a/vendor/github.com/stretchr/objx/conversions.go b/vendor/github.com/stretchr/objx/conversions.go new file mode 100644 index 0000000..01c63d7 --- /dev/null +++ b/vendor/github.com/stretchr/objx/conversions.go @@ -0,0 +1,280 @@ +package objx + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/url" + "strconv" +) + +// SignatureSeparator is the character that is used to +// separate the Base64 string from the security signature. +const SignatureSeparator = "_" + +// URLValuesSliceKeySuffix is the character that is used to +// specify a suffix for slices parsed by URLValues. +// If the suffix is set to "[i]", then the index of the slice +// is used in place of i +// Ex: Suffix "[]" would have the form a[]=b&a[]=c +// OR Suffix "[i]" would have the form a[0]=b&a[1]=c +// OR Suffix "" would have the form a=b&a=c +var urlValuesSliceKeySuffix = "[]" + +const ( + URLValuesSliceKeySuffixEmpty = "" + URLValuesSliceKeySuffixArray = "[]" + URLValuesSliceKeySuffixIndex = "[i]" +) + +// SetURLValuesSliceKeySuffix sets the character that is used to +// specify a suffix for slices parsed by URLValues. +// If the suffix is set to "[i]", then the index of the slice +// is used in place of i +// Ex: Suffix "[]" would have the form a[]=b&a[]=c +// OR Suffix "[i]" would have the form a[0]=b&a[1]=c +// OR Suffix "" would have the form a=b&a=c +func SetURLValuesSliceKeySuffix(s string) error { + if s == URLValuesSliceKeySuffixEmpty || s == URLValuesSliceKeySuffixArray || s == URLValuesSliceKeySuffixIndex { + urlValuesSliceKeySuffix = s + return nil + } + + return errors.New("objx: Invalid URLValuesSliceKeySuffix provided.") +} + +// JSON converts the contained object to a JSON string +// representation +func (m Map) JSON() (string, error) { + for k, v := range m { + m[k] = cleanUp(v) + } + + result, err := json.Marshal(m) + if err != nil { + err = errors.New("objx: JSON encode failed with: " + err.Error()) + } + return string(result), err +} + +func cleanUpInterfaceArray(in []interface{}) []interface{} { + result := make([]interface{}, len(in)) + for i, v := range in { + result[i] = cleanUp(v) + } + return result +} + +func cleanUpInterfaceMap(in map[interface{}]interface{}) Map { + result := Map{} + for k, v := range in { + result[fmt.Sprintf("%v", k)] = cleanUp(v) + } + return result +} + +func cleanUpStringMap(in map[string]interface{}) Map { + result := Map{} + for k, v := range in { + result[k] = cleanUp(v) + } + return result +} + +func cleanUpMSIArray(in []map[string]interface{}) []Map { + result := make([]Map, len(in)) + for i, v := range in { + result[i] = cleanUpStringMap(v) + } + return result +} + +func cleanUpMapArray(in []Map) []Map { + result := make([]Map, len(in)) + for i, v := range in { + result[i] = cleanUpStringMap(v) + } + return result +} + +func cleanUp(v interface{}) interface{} { + switch v := v.(type) { + case []interface{}: + return cleanUpInterfaceArray(v) + case []map[string]interface{}: + return cleanUpMSIArray(v) + case map[interface{}]interface{}: + return cleanUpInterfaceMap(v) + case Map: + return cleanUpStringMap(v) + case []Map: + return cleanUpMapArray(v) + default: + return v + } +} + +// MustJSON converts the contained object to a JSON string +// representation and panics if there is an error +func (m Map) MustJSON() string { + result, err := m.JSON() + if err != nil { + panic(err.Error()) + } + return result +} + +// Base64 converts the contained object to a Base64 string +// representation of the JSON string representation +func (m Map) Base64() (string, error) { + var buf bytes.Buffer + + jsonData, err := m.JSON() + if err != nil { + return "", err + } + + encoder := base64.NewEncoder(base64.StdEncoding, &buf) + _, _ = encoder.Write([]byte(jsonData)) + _ = encoder.Close() + + return buf.String(), nil +} + +// MustBase64 converts the contained object to a Base64 string +// representation of the JSON string representation and panics +// if there is an error +func (m Map) MustBase64() string { + result, err := m.Base64() + if err != nil { + panic(err.Error()) + } + return result +} + +// SignedBase64 converts the contained object to a Base64 string +// representation of the JSON string representation and signs it +// using the provided key. +func (m Map) SignedBase64(key string) (string, error) { + base64, err := m.Base64() + if err != nil { + return "", err + } + + sig := HashWithKey(base64, key) + return base64 + SignatureSeparator + sig, nil +} + +// MustSignedBase64 converts the contained object to a Base64 string +// representation of the JSON string representation and signs it +// using the provided key and panics if there is an error +func (m Map) MustSignedBase64(key string) string { + result, err := m.SignedBase64(key) + if err != nil { + panic(err.Error()) + } + return result +} + +/* + URL Query + ------------------------------------------------ +*/ + +// URLValues creates a url.Values object from an Obj. This +// function requires that the wrapped object be a map[string]interface{} +func (m Map) URLValues() url.Values { + vals := make(url.Values) + + m.parseURLValues(m, vals, "") + + return vals +} + +func (m Map) parseURLValues(queryMap Map, vals url.Values, key string) { + useSliceIndex := false + if urlValuesSliceKeySuffix == "[i]" { + useSliceIndex = true + } + + for k, v := range queryMap { + val := &Value{data: v} + switch { + case val.IsObjxMap(): + if key == "" { + m.parseURLValues(val.ObjxMap(), vals, k) + } else { + m.parseURLValues(val.ObjxMap(), vals, key+"["+k+"]") + } + case val.IsObjxMapSlice(): + sliceKey := k + if key != "" { + sliceKey = key + "[" + k + "]" + } + + if useSliceIndex { + for i, sv := range val.MustObjxMapSlice() { + sk := sliceKey + "[" + strconv.FormatInt(int64(i), 10) + "]" + m.parseURLValues(sv, vals, sk) + } + } else { + sliceKey = sliceKey + urlValuesSliceKeySuffix + for _, sv := range val.MustObjxMapSlice() { + m.parseURLValues(sv, vals, sliceKey) + } + } + case val.IsMSISlice(): + sliceKey := k + if key != "" { + sliceKey = key + "[" + k + "]" + } + + if useSliceIndex { + for i, sv := range val.MustMSISlice() { + sk := sliceKey + "[" + strconv.FormatInt(int64(i), 10) + "]" + m.parseURLValues(New(sv), vals, sk) + } + } else { + sliceKey = sliceKey + urlValuesSliceKeySuffix + for _, sv := range val.MustMSISlice() { + m.parseURLValues(New(sv), vals, sliceKey) + } + } + case val.IsStrSlice(), val.IsBoolSlice(), + val.IsFloat32Slice(), val.IsFloat64Slice(), + val.IsIntSlice(), val.IsInt8Slice(), val.IsInt16Slice(), val.IsInt32Slice(), val.IsInt64Slice(), + val.IsUintSlice(), val.IsUint8Slice(), val.IsUint16Slice(), val.IsUint32Slice(), val.IsUint64Slice(): + + sliceKey := k + if key != "" { + sliceKey = key + "[" + k + "]" + } + + if useSliceIndex { + for i, sv := range val.StringSlice() { + sk := sliceKey + "[" + strconv.FormatInt(int64(i), 10) + "]" + vals.Set(sk, sv) + } + } else { + sliceKey = sliceKey + urlValuesSliceKeySuffix + vals[sliceKey] = val.StringSlice() + } + + default: + if key == "" { + vals.Set(k, val.String()) + } else { + vals.Set(key+"["+k+"]", val.String()) + } + } + } +} + +// URLQuery gets an encoded URL query representing the given +// Obj. This function requires that the wrapped object be a +// map[string]interface{} +func (m Map) URLQuery() (string, error) { + return m.URLValues().Encode(), nil +} diff --git a/vendor/github.com/stretchr/objx/doc.go b/vendor/github.com/stretchr/objx/doc.go new file mode 100644 index 0000000..b170af7 --- /dev/null +++ b/vendor/github.com/stretchr/objx/doc.go @@ -0,0 +1,66 @@ +/* +Package objx provides utilities for dealing with maps, slices, JSON and other data. + +# Overview + +Objx provides the `objx.Map` type, which is a `map[string]interface{}` that exposes +a powerful `Get` method (among others) that allows you to easily and quickly get +access to data within the map, without having to worry too much about type assertions, +missing data, default values etc. + +# Pattern + +Objx uses a predictable pattern to make access data from within `map[string]interface{}` easy. +Call one of the `objx.` functions to create your `objx.Map` to get going: + + m, err := objx.FromJSON(json) + +NOTE: Any methods or functions with the `Must` prefix will panic if something goes wrong, +the rest will be optimistic and try to figure things out without panicking. + +Use `Get` to access the value you're interested in. You can use dot and array +notation too: + + m.Get("places[0].latlng") + +Once you have sought the `Value` you're interested in, you can use the `Is*` methods to determine its type. + + if m.Get("code").IsStr() { // Your code... } + +Or you can just assume the type, and use one of the strong type methods to extract the real value: + + m.Get("code").Int() + +If there's no value there (or if it's the wrong type) then a default value will be returned, +or you can be explicit about the default value. + + Get("code").Int(-1) + +If you're dealing with a slice of data as a value, Objx provides many useful methods for iterating, +manipulating and selecting that data. You can find out more by exploring the index below. + +# Reading data + +A simple example of how to use Objx: + + // Use MustFromJSON to make an objx.Map from some JSON + m := objx.MustFromJSON(`{"name": "Mat", "age": 30}`) + + // Get the details + name := m.Get("name").Str() + age := m.Get("age").Int() + + // Get their nickname (or use their name if they don't have one) + nickname := m.Get("nickname").Str(name) + +# Ranging + +Since `objx.Map` is a `map[string]interface{}` you can treat it as such. +For example, to `range` the data, do what you would expect: + + m := objx.MustFromJSON(json) + for key, value := range m { + // Your code... + } +*/ +package objx diff --git a/vendor/github.com/stretchr/objx/map.go b/vendor/github.com/stretchr/objx/map.go new file mode 100644 index 0000000..ab9f9ae --- /dev/null +++ b/vendor/github.com/stretchr/objx/map.go @@ -0,0 +1,214 @@ +package objx + +import ( + "encoding/base64" + "encoding/json" + "errors" + "io/ioutil" + "net/url" + "strings" +) + +// MSIConvertable is an interface that defines methods for converting your +// custom types to a map[string]interface{} representation. +type MSIConvertable interface { + // MSI gets a map[string]interface{} (msi) representing the + // object. + MSI() map[string]interface{} +} + +// Map provides extended functionality for working with +// untyped data, in particular map[string]interface (msi). +type Map map[string]interface{} + +// Value returns the internal value instance +func (m Map) Value() *Value { + return &Value{data: m} +} + +// Nil represents a nil Map. +var Nil = New(nil) + +// New creates a new Map containing the map[string]interface{} in the data argument. +// If the data argument is not a map[string]interface, New attempts to call the +// MSI() method on the MSIConvertable interface to create one. +func New(data interface{}) Map { + if _, ok := data.(map[string]interface{}); !ok { + if converter, ok := data.(MSIConvertable); ok { + data = converter.MSI() + } else { + return nil + } + } + return Map(data.(map[string]interface{})) +} + +// MSI creates a map[string]interface{} and puts it inside a new Map. +// +// The arguments follow a key, value pattern. +// +// Returns nil if any key argument is non-string or if there are an odd number of arguments. +// +// # Example +// +// To easily create Maps: +// +// m := objx.MSI("name", "Mat", "age", 29, "subobj", objx.MSI("active", true)) +// +// // creates an Map equivalent to +// m := objx.Map{"name": "Mat", "age": 29, "subobj": objx.Map{"active": true}} +func MSI(keyAndValuePairs ...interface{}) Map { + newMap := Map{} + keyAndValuePairsLen := len(keyAndValuePairs) + if keyAndValuePairsLen%2 != 0 { + return nil + } + for i := 0; i < keyAndValuePairsLen; i = i + 2 { + key := keyAndValuePairs[i] + value := keyAndValuePairs[i+1] + + // make sure the key is a string + keyString, keyStringOK := key.(string) + if !keyStringOK { + return nil + } + newMap[keyString] = value + } + return newMap +} + +// ****** Conversion Constructors + +// MustFromJSON creates a new Map containing the data specified in the +// jsonString. +// +// Panics if the JSON is invalid. +func MustFromJSON(jsonString string) Map { + o, err := FromJSON(jsonString) + if err != nil { + panic("objx: MustFromJSON failed with error: " + err.Error()) + } + return o +} + +// MustFromJSONSlice creates a new slice of Map containing the data specified in the +// jsonString. Works with jsons with a top level array +// +// Panics if the JSON is invalid. +func MustFromJSONSlice(jsonString string) []Map { + slice, err := FromJSONSlice(jsonString) + if err != nil { + panic("objx: MustFromJSONSlice failed with error: " + err.Error()) + } + return slice +} + +// FromJSON creates a new Map containing the data specified in the +// jsonString. +// +// Returns an error if the JSON is invalid. +func FromJSON(jsonString string) (Map, error) { + var m Map + err := json.Unmarshal([]byte(jsonString), &m) + if err != nil { + return Nil, err + } + return m, nil +} + +// FromJSONSlice creates a new slice of Map containing the data specified in the +// jsonString. Works with jsons with a top level array +// +// Returns an error if the JSON is invalid. +func FromJSONSlice(jsonString string) ([]Map, error) { + var slice []Map + err := json.Unmarshal([]byte(jsonString), &slice) + if err != nil { + return nil, err + } + return slice, nil +} + +// FromBase64 creates a new Obj containing the data specified +// in the Base64 string. +// +// The string is an encoded JSON string returned by Base64 +func FromBase64(base64String string) (Map, error) { + decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(base64String)) + decoded, err := ioutil.ReadAll(decoder) + if err != nil { + return nil, err + } + return FromJSON(string(decoded)) +} + +// MustFromBase64 creates a new Obj containing the data specified +// in the Base64 string and panics if there is an error. +// +// The string is an encoded JSON string returned by Base64 +func MustFromBase64(base64String string) Map { + result, err := FromBase64(base64String) + if err != nil { + panic("objx: MustFromBase64 failed with error: " + err.Error()) + } + return result +} + +// FromSignedBase64 creates a new Obj containing the data specified +// in the Base64 string. +// +// The string is an encoded JSON string returned by SignedBase64 +func FromSignedBase64(base64String, key string) (Map, error) { + parts := strings.Split(base64String, SignatureSeparator) + if len(parts) != 2 { + return nil, errors.New("objx: Signed base64 string is malformed") + } + + sig := HashWithKey(parts[0], key) + if parts[1] != sig { + return nil, errors.New("objx: Signature for base64 data does not match") + } + return FromBase64(parts[0]) +} + +// MustFromSignedBase64 creates a new Obj containing the data specified +// in the Base64 string and panics if there is an error. +// +// The string is an encoded JSON string returned by Base64 +func MustFromSignedBase64(base64String, key string) Map { + result, err := FromSignedBase64(base64String, key) + if err != nil { + panic("objx: MustFromSignedBase64 failed with error: " + err.Error()) + } + return result +} + +// FromURLQuery generates a new Obj by parsing the specified +// query. +// +// For queries with multiple values, the first value is selected. +func FromURLQuery(query string) (Map, error) { + vals, err := url.ParseQuery(query) + if err != nil { + return nil, err + } + m := Map{} + for k, vals := range vals { + m[k] = vals[0] + } + return m, nil +} + +// MustFromURLQuery generates a new Obj by parsing the specified +// query. +// +// For queries with multiple values, the first value is selected. +// +// Panics if it encounters an error +func MustFromURLQuery(query string) Map { + o, err := FromURLQuery(query) + if err != nil { + panic("objx: MustFromURLQuery failed with error: " + err.Error()) + } + return o +} diff --git a/vendor/github.com/stretchr/objx/mutations.go b/vendor/github.com/stretchr/objx/mutations.go new file mode 100644 index 0000000..c3400a3 --- /dev/null +++ b/vendor/github.com/stretchr/objx/mutations.go @@ -0,0 +1,77 @@ +package objx + +// Exclude returns a new Map with the keys in the specified []string +// excluded. +func (m Map) Exclude(exclude []string) Map { + excluded := make(Map) + for k, v := range m { + if !contains(exclude, k) { + excluded[k] = v + } + } + return excluded +} + +// Copy creates a shallow copy of the Obj. +func (m Map) Copy() Map { + copied := Map{} + for k, v := range m { + copied[k] = v + } + return copied +} + +// Merge blends the specified map with a copy of this map and returns the result. +// +// Keys that appear in both will be selected from the specified map. +// This method requires that the wrapped object be a map[string]interface{} +func (m Map) Merge(merge Map) Map { + return m.Copy().MergeHere(merge) +} + +// MergeHere blends the specified map with this map and returns the current map. +// +// Keys that appear in both will be selected from the specified map. The original map +// will be modified. This method requires that +// the wrapped object be a map[string]interface{} +func (m Map) MergeHere(merge Map) Map { + for k, v := range merge { + m[k] = v + } + return m +} + +// Transform builds a new Obj giving the transformer a chance +// to change the keys and values as it goes. This method requires that +// the wrapped object be a map[string]interface{} +func (m Map) Transform(transformer func(key string, value interface{}) (string, interface{})) Map { + newMap := Map{} + for k, v := range m { + modifiedKey, modifiedVal := transformer(k, v) + newMap[modifiedKey] = modifiedVal + } + return newMap +} + +// TransformKeys builds a new map using the specified key mapping. +// +// Unspecified keys will be unaltered. +// This method requires that the wrapped object be a map[string]interface{} +func (m Map) TransformKeys(mapping map[string]string) Map { + return m.Transform(func(key string, value interface{}) (string, interface{}) { + if newKey, ok := mapping[key]; ok { + return newKey, value + } + return key, value + }) +} + +// Checks if a string slice contains a string +func contains(s []string, e string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +} diff --git a/vendor/github.com/stretchr/objx/security.go b/vendor/github.com/stretchr/objx/security.go new file mode 100644 index 0000000..692be8e --- /dev/null +++ b/vendor/github.com/stretchr/objx/security.go @@ -0,0 +1,12 @@ +package objx + +import ( + "crypto/sha1" + "encoding/hex" +) + +// HashWithKey hashes the specified string using the security key +func HashWithKey(data, key string) string { + d := sha1.Sum([]byte(data + ":" + key)) + return hex.EncodeToString(d[:]) +} diff --git a/vendor/github.com/stretchr/objx/tests.go b/vendor/github.com/stretchr/objx/tests.go new file mode 100644 index 0000000..d9e0b47 --- /dev/null +++ b/vendor/github.com/stretchr/objx/tests.go @@ -0,0 +1,17 @@ +package objx + +// Has gets whether there is something at the specified selector +// or not. +// +// If m is nil, Has will always return false. +func (m Map) Has(selector string) bool { + if m == nil { + return false + } + return !m.Get(selector).IsNil() +} + +// IsNil gets whether the data is nil or not. +func (v *Value) IsNil() bool { + return v == nil || v.data == nil +} diff --git a/vendor/github.com/stretchr/objx/type_specific.go b/vendor/github.com/stretchr/objx/type_specific.go new file mode 100644 index 0000000..80f88d9 --- /dev/null +++ b/vendor/github.com/stretchr/objx/type_specific.go @@ -0,0 +1,346 @@ +package objx + +/* + MSI (map[string]interface{} and []map[string]interface{}) +*/ + +// MSI gets the value as a map[string]interface{}, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) MSI(optionalDefault ...map[string]interface{}) map[string]interface{} { + if s, ok := v.data.(map[string]interface{}); ok { + return s + } + if s, ok := v.data.(Map); ok { + return map[string]interface{}(s) + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustMSI gets the value as a map[string]interface{}. +// +// Panics if the object is not a map[string]interface{}. +func (v *Value) MustMSI() map[string]interface{} { + if s, ok := v.data.(Map); ok { + return map[string]interface{}(s) + } + return v.data.(map[string]interface{}) +} + +// MSISlice gets the value as a []map[string]interface{}, returns the optionalDefault +// value or nil if the value is not a []map[string]interface{}. +func (v *Value) MSISlice(optionalDefault ...[]map[string]interface{}) []map[string]interface{} { + if s, ok := v.data.([]map[string]interface{}); ok { + return s + } + + s := v.ObjxMapSlice() + if s == nil { + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil + } + + result := make([]map[string]interface{}, len(s)) + for i := range s { + result[i] = s[i].Value().MSI() + } + return result +} + +// MustMSISlice gets the value as a []map[string]interface{}. +// +// Panics if the object is not a []map[string]interface{}. +func (v *Value) MustMSISlice() []map[string]interface{} { + if s := v.MSISlice(); s != nil { + return s + } + + return v.data.([]map[string]interface{}) +} + +// IsMSI gets whether the object contained is a map[string]interface{} or not. +func (v *Value) IsMSI() bool { + _, ok := v.data.(map[string]interface{}) + if !ok { + _, ok = v.data.(Map) + } + return ok +} + +// IsMSISlice gets whether the object contained is a []map[string]interface{} or not. +func (v *Value) IsMSISlice() bool { + _, ok := v.data.([]map[string]interface{}) + if !ok { + _, ok = v.data.([]Map) + if !ok { + s, ok := v.data.([]interface{}) + if ok { + for i := range s { + switch s[i].(type) { + case Map: + case map[string]interface{}: + default: + return false + } + } + return true + } + } + } + return ok +} + +// EachMSI calls the specified callback for each object +// in the []map[string]interface{}. +// +// Panics if the object is the wrong type. +func (v *Value) EachMSI(callback func(int, map[string]interface{}) bool) *Value { + for index, val := range v.MustMSISlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereMSI uses the specified decider function to select items +// from the []map[string]interface{}. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereMSI(decider func(int, map[string]interface{}) bool) *Value { + var selected []map[string]interface{} + v.EachMSI(func(index int, val map[string]interface{}) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupMSI uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]map[string]interface{}. +func (v *Value) GroupMSI(grouper func(int, map[string]interface{}) string) *Value { + groups := make(map[string][]map[string]interface{}) + v.EachMSI(func(index int, val map[string]interface{}) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]map[string]interface{}, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceMSI uses the specified function to replace each map[string]interface{}s +// by iterating each item. The data in the returned result will be a +// []map[string]interface{} containing the replaced items. +func (v *Value) ReplaceMSI(replacer func(int, map[string]interface{}) map[string]interface{}) *Value { + arr := v.MustMSISlice() + replaced := make([]map[string]interface{}, len(arr)) + v.EachMSI(func(index int, val map[string]interface{}) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectMSI uses the specified collector function to collect a value +// for each of the map[string]interface{}s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectMSI(collector func(int, map[string]interface{}) interface{}) *Value { + arr := v.MustMSISlice() + collected := make([]interface{}, len(arr)) + v.EachMSI(func(index int, val map[string]interface{}) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + ObjxMap ((Map) and [](Map)) +*/ + +// ObjxMap gets the value as a (Map), returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) ObjxMap(optionalDefault ...(Map)) Map { + if s, ok := v.data.((Map)); ok { + return s + } + if s, ok := v.data.(map[string]interface{}); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return New(nil) +} + +// MustObjxMap gets the value as a (Map). +// +// Panics if the object is not a (Map). +func (v *Value) MustObjxMap() Map { + if s, ok := v.data.(map[string]interface{}); ok { + return s + } + return v.data.((Map)) +} + +// ObjxMapSlice gets the value as a [](Map), returns the optionalDefault +// value or nil if the value is not a [](Map). +func (v *Value) ObjxMapSlice(optionalDefault ...[](Map)) [](Map) { + if s, ok := v.data.([]Map); ok { + return s + } + + if s, ok := v.data.([]map[string]interface{}); ok { + result := make([]Map, len(s)) + for i := range s { + result[i] = s[i] + } + return result + } + + s, ok := v.data.([]interface{}) + if !ok { + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil + } + + result := make([]Map, len(s)) + for i := range s { + switch s[i].(type) { + case Map: + result[i] = s[i].(Map) + case map[string]interface{}: + result[i] = New(s[i]) + default: + return nil + } + } + return result +} + +// MustObjxMapSlice gets the value as a [](Map). +// +// Panics if the object is not a [](Map). +func (v *Value) MustObjxMapSlice() [](Map) { + if s := v.ObjxMapSlice(); s != nil { + return s + } + return v.data.([](Map)) +} + +// IsObjxMap gets whether the object contained is a (Map) or not. +func (v *Value) IsObjxMap() bool { + _, ok := v.data.((Map)) + if !ok { + _, ok = v.data.(map[string]interface{}) + } + return ok +} + +// IsObjxMapSlice gets whether the object contained is a [](Map) or not. +func (v *Value) IsObjxMapSlice() bool { + _, ok := v.data.([](Map)) + if !ok { + _, ok = v.data.([]map[string]interface{}) + if !ok { + s, ok := v.data.([]interface{}) + if ok { + for i := range s { + switch s[i].(type) { + case Map: + case map[string]interface{}: + default: + return false + } + } + return true + } + } + } + + return ok +} + +// EachObjxMap calls the specified callback for each object +// in the [](Map). +// +// Panics if the object is the wrong type. +func (v *Value) EachObjxMap(callback func(int, Map) bool) *Value { + for index, val := range v.MustObjxMapSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereObjxMap uses the specified decider function to select items +// from the [](Map). The object contained in the result will contain +// only the selected items. +func (v *Value) WhereObjxMap(decider func(int, Map) bool) *Value { + var selected [](Map) + v.EachObjxMap(func(index int, val Map) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupObjxMap uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][](Map). +func (v *Value) GroupObjxMap(grouper func(int, Map) string) *Value { + groups := make(map[string][](Map)) + v.EachObjxMap(func(index int, val Map) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([](Map), 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceObjxMap uses the specified function to replace each (Map)s +// by iterating each item. The data in the returned result will be a +// [](Map) containing the replaced items. +func (v *Value) ReplaceObjxMap(replacer func(int, Map) Map) *Value { + arr := v.MustObjxMapSlice() + replaced := make([](Map), len(arr)) + v.EachObjxMap(func(index int, val Map) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectObjxMap uses the specified collector function to collect a value +// for each of the (Map)s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectObjxMap(collector func(int, Map) interface{}) *Value { + arr := v.MustObjxMapSlice() + collected := make([]interface{}, len(arr)) + v.EachObjxMap(func(index int, val Map) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} diff --git a/vendor/github.com/stretchr/objx/type_specific_codegen.go b/vendor/github.com/stretchr/objx/type_specific_codegen.go new file mode 100644 index 0000000..4585045 --- /dev/null +++ b/vendor/github.com/stretchr/objx/type_specific_codegen.go @@ -0,0 +1,2261 @@ +package objx + +/* + Inter (interface{} and []interface{}) +*/ + +// Inter gets the value as a interface{}, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Inter(optionalDefault ...interface{}) interface{} { + if s, ok := v.data.(interface{}); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInter gets the value as a interface{}. +// +// Panics if the object is not a interface{}. +func (v *Value) MustInter() interface{} { + return v.data.(interface{}) +} + +// InterSlice gets the value as a []interface{}, returns the optionalDefault +// value or nil if the value is not a []interface{}. +func (v *Value) InterSlice(optionalDefault ...[]interface{}) []interface{} { + if s, ok := v.data.([]interface{}); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInterSlice gets the value as a []interface{}. +// +// Panics if the object is not a []interface{}. +func (v *Value) MustInterSlice() []interface{} { + return v.data.([]interface{}) +} + +// IsInter gets whether the object contained is a interface{} or not. +func (v *Value) IsInter() bool { + _, ok := v.data.(interface{}) + return ok +} + +// IsInterSlice gets whether the object contained is a []interface{} or not. +func (v *Value) IsInterSlice() bool { + _, ok := v.data.([]interface{}) + return ok +} + +// EachInter calls the specified callback for each object +// in the []interface{}. +// +// Panics if the object is the wrong type. +func (v *Value) EachInter(callback func(int, interface{}) bool) *Value { + for index, val := range v.MustInterSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInter uses the specified decider function to select items +// from the []interface{}. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInter(decider func(int, interface{}) bool) *Value { + var selected []interface{} + v.EachInter(func(index int, val interface{}) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInter uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]interface{}. +func (v *Value) GroupInter(grouper func(int, interface{}) string) *Value { + groups := make(map[string][]interface{}) + v.EachInter(func(index int, val interface{}) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]interface{}, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInter uses the specified function to replace each interface{}s +// by iterating each item. The data in the returned result will be a +// []interface{} containing the replaced items. +func (v *Value) ReplaceInter(replacer func(int, interface{}) interface{}) *Value { + arr := v.MustInterSlice() + replaced := make([]interface{}, len(arr)) + v.EachInter(func(index int, val interface{}) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInter uses the specified collector function to collect a value +// for each of the interface{}s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInter(collector func(int, interface{}) interface{}) *Value { + arr := v.MustInterSlice() + collected := make([]interface{}, len(arr)) + v.EachInter(func(index int, val interface{}) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Bool (bool and []bool) +*/ + +// Bool gets the value as a bool, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Bool(optionalDefault ...bool) bool { + if s, ok := v.data.(bool); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return false +} + +// MustBool gets the value as a bool. +// +// Panics if the object is not a bool. +func (v *Value) MustBool() bool { + return v.data.(bool) +} + +// BoolSlice gets the value as a []bool, returns the optionalDefault +// value or nil if the value is not a []bool. +func (v *Value) BoolSlice(optionalDefault ...[]bool) []bool { + if s, ok := v.data.([]bool); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustBoolSlice gets the value as a []bool. +// +// Panics if the object is not a []bool. +func (v *Value) MustBoolSlice() []bool { + return v.data.([]bool) +} + +// IsBool gets whether the object contained is a bool or not. +func (v *Value) IsBool() bool { + _, ok := v.data.(bool) + return ok +} + +// IsBoolSlice gets whether the object contained is a []bool or not. +func (v *Value) IsBoolSlice() bool { + _, ok := v.data.([]bool) + return ok +} + +// EachBool calls the specified callback for each object +// in the []bool. +// +// Panics if the object is the wrong type. +func (v *Value) EachBool(callback func(int, bool) bool) *Value { + for index, val := range v.MustBoolSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereBool uses the specified decider function to select items +// from the []bool. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereBool(decider func(int, bool) bool) *Value { + var selected []bool + v.EachBool(func(index int, val bool) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupBool uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]bool. +func (v *Value) GroupBool(grouper func(int, bool) string) *Value { + groups := make(map[string][]bool) + v.EachBool(func(index int, val bool) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]bool, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceBool uses the specified function to replace each bools +// by iterating each item. The data in the returned result will be a +// []bool containing the replaced items. +func (v *Value) ReplaceBool(replacer func(int, bool) bool) *Value { + arr := v.MustBoolSlice() + replaced := make([]bool, len(arr)) + v.EachBool(func(index int, val bool) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectBool uses the specified collector function to collect a value +// for each of the bools in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectBool(collector func(int, bool) interface{}) *Value { + arr := v.MustBoolSlice() + collected := make([]interface{}, len(arr)) + v.EachBool(func(index int, val bool) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Str (string and []string) +*/ + +// Str gets the value as a string, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Str(optionalDefault ...string) string { + if s, ok := v.data.(string); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return "" +} + +// MustStr gets the value as a string. +// +// Panics if the object is not a string. +func (v *Value) MustStr() string { + return v.data.(string) +} + +// StrSlice gets the value as a []string, returns the optionalDefault +// value or nil if the value is not a []string. +func (v *Value) StrSlice(optionalDefault ...[]string) []string { + if s, ok := v.data.([]string); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustStrSlice gets the value as a []string. +// +// Panics if the object is not a []string. +func (v *Value) MustStrSlice() []string { + return v.data.([]string) +} + +// IsStr gets whether the object contained is a string or not. +func (v *Value) IsStr() bool { + _, ok := v.data.(string) + return ok +} + +// IsStrSlice gets whether the object contained is a []string or not. +func (v *Value) IsStrSlice() bool { + _, ok := v.data.([]string) + return ok +} + +// EachStr calls the specified callback for each object +// in the []string. +// +// Panics if the object is the wrong type. +func (v *Value) EachStr(callback func(int, string) bool) *Value { + for index, val := range v.MustStrSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereStr uses the specified decider function to select items +// from the []string. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereStr(decider func(int, string) bool) *Value { + var selected []string + v.EachStr(func(index int, val string) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupStr uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]string. +func (v *Value) GroupStr(grouper func(int, string) string) *Value { + groups := make(map[string][]string) + v.EachStr(func(index int, val string) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]string, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceStr uses the specified function to replace each strings +// by iterating each item. The data in the returned result will be a +// []string containing the replaced items. +func (v *Value) ReplaceStr(replacer func(int, string) string) *Value { + arr := v.MustStrSlice() + replaced := make([]string, len(arr)) + v.EachStr(func(index int, val string) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectStr uses the specified collector function to collect a value +// for each of the strings in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectStr(collector func(int, string) interface{}) *Value { + arr := v.MustStrSlice() + collected := make([]interface{}, len(arr)) + v.EachStr(func(index int, val string) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int (int and []int) +*/ + +// Int gets the value as a int, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int(optionalDefault ...int) int { + if s, ok := v.data.(int); ok { + return s + } + if s, ok := v.data.(float64); ok { + if float64(int(s)) == s { + return int(s) + } + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt gets the value as a int. +// +// Panics if the object is not a int. +func (v *Value) MustInt() int { + if s, ok := v.data.(float64); ok { + if float64(int(s)) == s { + return int(s) + } + } + return v.data.(int) +} + +// IntSlice gets the value as a []int, returns the optionalDefault +// value or nil if the value is not a []int. +func (v *Value) IntSlice(optionalDefault ...[]int) []int { + if s, ok := v.data.([]int); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustIntSlice gets the value as a []int. +// +// Panics if the object is not a []int. +func (v *Value) MustIntSlice() []int { + return v.data.([]int) +} + +// IsInt gets whether the object contained is a int or not. +func (v *Value) IsInt() bool { + _, ok := v.data.(int) + return ok +} + +// IsIntSlice gets whether the object contained is a []int or not. +func (v *Value) IsIntSlice() bool { + _, ok := v.data.([]int) + return ok +} + +// EachInt calls the specified callback for each object +// in the []int. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt(callback func(int, int) bool) *Value { + for index, val := range v.MustIntSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt uses the specified decider function to select items +// from the []int. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt(decider func(int, int) bool) *Value { + var selected []int + v.EachInt(func(index int, val int) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int. +func (v *Value) GroupInt(grouper func(int, int) string) *Value { + groups := make(map[string][]int) + v.EachInt(func(index int, val int) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt uses the specified function to replace each ints +// by iterating each item. The data in the returned result will be a +// []int containing the replaced items. +func (v *Value) ReplaceInt(replacer func(int, int) int) *Value { + arr := v.MustIntSlice() + replaced := make([]int, len(arr)) + v.EachInt(func(index int, val int) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt uses the specified collector function to collect a value +// for each of the ints in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt(collector func(int, int) interface{}) *Value { + arr := v.MustIntSlice() + collected := make([]interface{}, len(arr)) + v.EachInt(func(index int, val int) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int8 (int8 and []int8) +*/ + +// Int8 gets the value as a int8, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int8(optionalDefault ...int8) int8 { + if s, ok := v.data.(int8); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt8 gets the value as a int8. +// +// Panics if the object is not a int8. +func (v *Value) MustInt8() int8 { + return v.data.(int8) +} + +// Int8Slice gets the value as a []int8, returns the optionalDefault +// value or nil if the value is not a []int8. +func (v *Value) Int8Slice(optionalDefault ...[]int8) []int8 { + if s, ok := v.data.([]int8); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInt8Slice gets the value as a []int8. +// +// Panics if the object is not a []int8. +func (v *Value) MustInt8Slice() []int8 { + return v.data.([]int8) +} + +// IsInt8 gets whether the object contained is a int8 or not. +func (v *Value) IsInt8() bool { + _, ok := v.data.(int8) + return ok +} + +// IsInt8Slice gets whether the object contained is a []int8 or not. +func (v *Value) IsInt8Slice() bool { + _, ok := v.data.([]int8) + return ok +} + +// EachInt8 calls the specified callback for each object +// in the []int8. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt8(callback func(int, int8) bool) *Value { + for index, val := range v.MustInt8Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt8 uses the specified decider function to select items +// from the []int8. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt8(decider func(int, int8) bool) *Value { + var selected []int8 + v.EachInt8(func(index int, val int8) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt8 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int8. +func (v *Value) GroupInt8(grouper func(int, int8) string) *Value { + groups := make(map[string][]int8) + v.EachInt8(func(index int, val int8) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int8, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt8 uses the specified function to replace each int8s +// by iterating each item. The data in the returned result will be a +// []int8 containing the replaced items. +func (v *Value) ReplaceInt8(replacer func(int, int8) int8) *Value { + arr := v.MustInt8Slice() + replaced := make([]int8, len(arr)) + v.EachInt8(func(index int, val int8) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt8 uses the specified collector function to collect a value +// for each of the int8s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt8(collector func(int, int8) interface{}) *Value { + arr := v.MustInt8Slice() + collected := make([]interface{}, len(arr)) + v.EachInt8(func(index int, val int8) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int16 (int16 and []int16) +*/ + +// Int16 gets the value as a int16, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int16(optionalDefault ...int16) int16 { + if s, ok := v.data.(int16); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt16 gets the value as a int16. +// +// Panics if the object is not a int16. +func (v *Value) MustInt16() int16 { + return v.data.(int16) +} + +// Int16Slice gets the value as a []int16, returns the optionalDefault +// value or nil if the value is not a []int16. +func (v *Value) Int16Slice(optionalDefault ...[]int16) []int16 { + if s, ok := v.data.([]int16); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInt16Slice gets the value as a []int16. +// +// Panics if the object is not a []int16. +func (v *Value) MustInt16Slice() []int16 { + return v.data.([]int16) +} + +// IsInt16 gets whether the object contained is a int16 or not. +func (v *Value) IsInt16() bool { + _, ok := v.data.(int16) + return ok +} + +// IsInt16Slice gets whether the object contained is a []int16 or not. +func (v *Value) IsInt16Slice() bool { + _, ok := v.data.([]int16) + return ok +} + +// EachInt16 calls the specified callback for each object +// in the []int16. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt16(callback func(int, int16) bool) *Value { + for index, val := range v.MustInt16Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt16 uses the specified decider function to select items +// from the []int16. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt16(decider func(int, int16) bool) *Value { + var selected []int16 + v.EachInt16(func(index int, val int16) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt16 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int16. +func (v *Value) GroupInt16(grouper func(int, int16) string) *Value { + groups := make(map[string][]int16) + v.EachInt16(func(index int, val int16) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int16, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt16 uses the specified function to replace each int16s +// by iterating each item. The data in the returned result will be a +// []int16 containing the replaced items. +func (v *Value) ReplaceInt16(replacer func(int, int16) int16) *Value { + arr := v.MustInt16Slice() + replaced := make([]int16, len(arr)) + v.EachInt16(func(index int, val int16) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt16 uses the specified collector function to collect a value +// for each of the int16s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt16(collector func(int, int16) interface{}) *Value { + arr := v.MustInt16Slice() + collected := make([]interface{}, len(arr)) + v.EachInt16(func(index int, val int16) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int32 (int32 and []int32) +*/ + +// Int32 gets the value as a int32, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int32(optionalDefault ...int32) int32 { + if s, ok := v.data.(int32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt32 gets the value as a int32. +// +// Panics if the object is not a int32. +func (v *Value) MustInt32() int32 { + return v.data.(int32) +} + +// Int32Slice gets the value as a []int32, returns the optionalDefault +// value or nil if the value is not a []int32. +func (v *Value) Int32Slice(optionalDefault ...[]int32) []int32 { + if s, ok := v.data.([]int32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInt32Slice gets the value as a []int32. +// +// Panics if the object is not a []int32. +func (v *Value) MustInt32Slice() []int32 { + return v.data.([]int32) +} + +// IsInt32 gets whether the object contained is a int32 or not. +func (v *Value) IsInt32() bool { + _, ok := v.data.(int32) + return ok +} + +// IsInt32Slice gets whether the object contained is a []int32 or not. +func (v *Value) IsInt32Slice() bool { + _, ok := v.data.([]int32) + return ok +} + +// EachInt32 calls the specified callback for each object +// in the []int32. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt32(callback func(int, int32) bool) *Value { + for index, val := range v.MustInt32Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt32 uses the specified decider function to select items +// from the []int32. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt32(decider func(int, int32) bool) *Value { + var selected []int32 + v.EachInt32(func(index int, val int32) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt32 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int32. +func (v *Value) GroupInt32(grouper func(int, int32) string) *Value { + groups := make(map[string][]int32) + v.EachInt32(func(index int, val int32) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int32, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt32 uses the specified function to replace each int32s +// by iterating each item. The data in the returned result will be a +// []int32 containing the replaced items. +func (v *Value) ReplaceInt32(replacer func(int, int32) int32) *Value { + arr := v.MustInt32Slice() + replaced := make([]int32, len(arr)) + v.EachInt32(func(index int, val int32) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt32 uses the specified collector function to collect a value +// for each of the int32s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt32(collector func(int, int32) interface{}) *Value { + arr := v.MustInt32Slice() + collected := make([]interface{}, len(arr)) + v.EachInt32(func(index int, val int32) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int64 (int64 and []int64) +*/ + +// Int64 gets the value as a int64, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int64(optionalDefault ...int64) int64 { + if s, ok := v.data.(int64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt64 gets the value as a int64. +// +// Panics if the object is not a int64. +func (v *Value) MustInt64() int64 { + return v.data.(int64) +} + +// Int64Slice gets the value as a []int64, returns the optionalDefault +// value or nil if the value is not a []int64. +func (v *Value) Int64Slice(optionalDefault ...[]int64) []int64 { + if s, ok := v.data.([]int64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInt64Slice gets the value as a []int64. +// +// Panics if the object is not a []int64. +func (v *Value) MustInt64Slice() []int64 { + return v.data.([]int64) +} + +// IsInt64 gets whether the object contained is a int64 or not. +func (v *Value) IsInt64() bool { + _, ok := v.data.(int64) + return ok +} + +// IsInt64Slice gets whether the object contained is a []int64 or not. +func (v *Value) IsInt64Slice() bool { + _, ok := v.data.([]int64) + return ok +} + +// EachInt64 calls the specified callback for each object +// in the []int64. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt64(callback func(int, int64) bool) *Value { + for index, val := range v.MustInt64Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt64 uses the specified decider function to select items +// from the []int64. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt64(decider func(int, int64) bool) *Value { + var selected []int64 + v.EachInt64(func(index int, val int64) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt64 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int64. +func (v *Value) GroupInt64(grouper func(int, int64) string) *Value { + groups := make(map[string][]int64) + v.EachInt64(func(index int, val int64) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int64, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt64 uses the specified function to replace each int64s +// by iterating each item. The data in the returned result will be a +// []int64 containing the replaced items. +func (v *Value) ReplaceInt64(replacer func(int, int64) int64) *Value { + arr := v.MustInt64Slice() + replaced := make([]int64, len(arr)) + v.EachInt64(func(index int, val int64) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt64 uses the specified collector function to collect a value +// for each of the int64s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt64(collector func(int, int64) interface{}) *Value { + arr := v.MustInt64Slice() + collected := make([]interface{}, len(arr)) + v.EachInt64(func(index int, val int64) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint (uint and []uint) +*/ + +// Uint gets the value as a uint, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint(optionalDefault ...uint) uint { + if s, ok := v.data.(uint); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint gets the value as a uint. +// +// Panics if the object is not a uint. +func (v *Value) MustUint() uint { + return v.data.(uint) +} + +// UintSlice gets the value as a []uint, returns the optionalDefault +// value or nil if the value is not a []uint. +func (v *Value) UintSlice(optionalDefault ...[]uint) []uint { + if s, ok := v.data.([]uint); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUintSlice gets the value as a []uint. +// +// Panics if the object is not a []uint. +func (v *Value) MustUintSlice() []uint { + return v.data.([]uint) +} + +// IsUint gets whether the object contained is a uint or not. +func (v *Value) IsUint() bool { + _, ok := v.data.(uint) + return ok +} + +// IsUintSlice gets whether the object contained is a []uint or not. +func (v *Value) IsUintSlice() bool { + _, ok := v.data.([]uint) + return ok +} + +// EachUint calls the specified callback for each object +// in the []uint. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint(callback func(int, uint) bool) *Value { + for index, val := range v.MustUintSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint uses the specified decider function to select items +// from the []uint. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint(decider func(int, uint) bool) *Value { + var selected []uint + v.EachUint(func(index int, val uint) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint. +func (v *Value) GroupUint(grouper func(int, uint) string) *Value { + groups := make(map[string][]uint) + v.EachUint(func(index int, val uint) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint uses the specified function to replace each uints +// by iterating each item. The data in the returned result will be a +// []uint containing the replaced items. +func (v *Value) ReplaceUint(replacer func(int, uint) uint) *Value { + arr := v.MustUintSlice() + replaced := make([]uint, len(arr)) + v.EachUint(func(index int, val uint) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint uses the specified collector function to collect a value +// for each of the uints in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint(collector func(int, uint) interface{}) *Value { + arr := v.MustUintSlice() + collected := make([]interface{}, len(arr)) + v.EachUint(func(index int, val uint) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint8 (uint8 and []uint8) +*/ + +// Uint8 gets the value as a uint8, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint8(optionalDefault ...uint8) uint8 { + if s, ok := v.data.(uint8); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint8 gets the value as a uint8. +// +// Panics if the object is not a uint8. +func (v *Value) MustUint8() uint8 { + return v.data.(uint8) +} + +// Uint8Slice gets the value as a []uint8, returns the optionalDefault +// value or nil if the value is not a []uint8. +func (v *Value) Uint8Slice(optionalDefault ...[]uint8) []uint8 { + if s, ok := v.data.([]uint8); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUint8Slice gets the value as a []uint8. +// +// Panics if the object is not a []uint8. +func (v *Value) MustUint8Slice() []uint8 { + return v.data.([]uint8) +} + +// IsUint8 gets whether the object contained is a uint8 or not. +func (v *Value) IsUint8() bool { + _, ok := v.data.(uint8) + return ok +} + +// IsUint8Slice gets whether the object contained is a []uint8 or not. +func (v *Value) IsUint8Slice() bool { + _, ok := v.data.([]uint8) + return ok +} + +// EachUint8 calls the specified callback for each object +// in the []uint8. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint8(callback func(int, uint8) bool) *Value { + for index, val := range v.MustUint8Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint8 uses the specified decider function to select items +// from the []uint8. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint8(decider func(int, uint8) bool) *Value { + var selected []uint8 + v.EachUint8(func(index int, val uint8) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint8 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint8. +func (v *Value) GroupUint8(grouper func(int, uint8) string) *Value { + groups := make(map[string][]uint8) + v.EachUint8(func(index int, val uint8) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint8, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint8 uses the specified function to replace each uint8s +// by iterating each item. The data in the returned result will be a +// []uint8 containing the replaced items. +func (v *Value) ReplaceUint8(replacer func(int, uint8) uint8) *Value { + arr := v.MustUint8Slice() + replaced := make([]uint8, len(arr)) + v.EachUint8(func(index int, val uint8) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint8 uses the specified collector function to collect a value +// for each of the uint8s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint8(collector func(int, uint8) interface{}) *Value { + arr := v.MustUint8Slice() + collected := make([]interface{}, len(arr)) + v.EachUint8(func(index int, val uint8) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint16 (uint16 and []uint16) +*/ + +// Uint16 gets the value as a uint16, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint16(optionalDefault ...uint16) uint16 { + if s, ok := v.data.(uint16); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint16 gets the value as a uint16. +// +// Panics if the object is not a uint16. +func (v *Value) MustUint16() uint16 { + return v.data.(uint16) +} + +// Uint16Slice gets the value as a []uint16, returns the optionalDefault +// value or nil if the value is not a []uint16. +func (v *Value) Uint16Slice(optionalDefault ...[]uint16) []uint16 { + if s, ok := v.data.([]uint16); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUint16Slice gets the value as a []uint16. +// +// Panics if the object is not a []uint16. +func (v *Value) MustUint16Slice() []uint16 { + return v.data.([]uint16) +} + +// IsUint16 gets whether the object contained is a uint16 or not. +func (v *Value) IsUint16() bool { + _, ok := v.data.(uint16) + return ok +} + +// IsUint16Slice gets whether the object contained is a []uint16 or not. +func (v *Value) IsUint16Slice() bool { + _, ok := v.data.([]uint16) + return ok +} + +// EachUint16 calls the specified callback for each object +// in the []uint16. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint16(callback func(int, uint16) bool) *Value { + for index, val := range v.MustUint16Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint16 uses the specified decider function to select items +// from the []uint16. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint16(decider func(int, uint16) bool) *Value { + var selected []uint16 + v.EachUint16(func(index int, val uint16) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint16 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint16. +func (v *Value) GroupUint16(grouper func(int, uint16) string) *Value { + groups := make(map[string][]uint16) + v.EachUint16(func(index int, val uint16) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint16, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint16 uses the specified function to replace each uint16s +// by iterating each item. The data in the returned result will be a +// []uint16 containing the replaced items. +func (v *Value) ReplaceUint16(replacer func(int, uint16) uint16) *Value { + arr := v.MustUint16Slice() + replaced := make([]uint16, len(arr)) + v.EachUint16(func(index int, val uint16) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint16 uses the specified collector function to collect a value +// for each of the uint16s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint16(collector func(int, uint16) interface{}) *Value { + arr := v.MustUint16Slice() + collected := make([]interface{}, len(arr)) + v.EachUint16(func(index int, val uint16) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint32 (uint32 and []uint32) +*/ + +// Uint32 gets the value as a uint32, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint32(optionalDefault ...uint32) uint32 { + if s, ok := v.data.(uint32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint32 gets the value as a uint32. +// +// Panics if the object is not a uint32. +func (v *Value) MustUint32() uint32 { + return v.data.(uint32) +} + +// Uint32Slice gets the value as a []uint32, returns the optionalDefault +// value or nil if the value is not a []uint32. +func (v *Value) Uint32Slice(optionalDefault ...[]uint32) []uint32 { + if s, ok := v.data.([]uint32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUint32Slice gets the value as a []uint32. +// +// Panics if the object is not a []uint32. +func (v *Value) MustUint32Slice() []uint32 { + return v.data.([]uint32) +} + +// IsUint32 gets whether the object contained is a uint32 or not. +func (v *Value) IsUint32() bool { + _, ok := v.data.(uint32) + return ok +} + +// IsUint32Slice gets whether the object contained is a []uint32 or not. +func (v *Value) IsUint32Slice() bool { + _, ok := v.data.([]uint32) + return ok +} + +// EachUint32 calls the specified callback for each object +// in the []uint32. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint32(callback func(int, uint32) bool) *Value { + for index, val := range v.MustUint32Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint32 uses the specified decider function to select items +// from the []uint32. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint32(decider func(int, uint32) bool) *Value { + var selected []uint32 + v.EachUint32(func(index int, val uint32) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint32 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint32. +func (v *Value) GroupUint32(grouper func(int, uint32) string) *Value { + groups := make(map[string][]uint32) + v.EachUint32(func(index int, val uint32) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint32, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint32 uses the specified function to replace each uint32s +// by iterating each item. The data in the returned result will be a +// []uint32 containing the replaced items. +func (v *Value) ReplaceUint32(replacer func(int, uint32) uint32) *Value { + arr := v.MustUint32Slice() + replaced := make([]uint32, len(arr)) + v.EachUint32(func(index int, val uint32) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint32 uses the specified collector function to collect a value +// for each of the uint32s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint32(collector func(int, uint32) interface{}) *Value { + arr := v.MustUint32Slice() + collected := make([]interface{}, len(arr)) + v.EachUint32(func(index int, val uint32) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint64 (uint64 and []uint64) +*/ + +// Uint64 gets the value as a uint64, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint64(optionalDefault ...uint64) uint64 { + if s, ok := v.data.(uint64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint64 gets the value as a uint64. +// +// Panics if the object is not a uint64. +func (v *Value) MustUint64() uint64 { + return v.data.(uint64) +} + +// Uint64Slice gets the value as a []uint64, returns the optionalDefault +// value or nil if the value is not a []uint64. +func (v *Value) Uint64Slice(optionalDefault ...[]uint64) []uint64 { + if s, ok := v.data.([]uint64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUint64Slice gets the value as a []uint64. +// +// Panics if the object is not a []uint64. +func (v *Value) MustUint64Slice() []uint64 { + return v.data.([]uint64) +} + +// IsUint64 gets whether the object contained is a uint64 or not. +func (v *Value) IsUint64() bool { + _, ok := v.data.(uint64) + return ok +} + +// IsUint64Slice gets whether the object contained is a []uint64 or not. +func (v *Value) IsUint64Slice() bool { + _, ok := v.data.([]uint64) + return ok +} + +// EachUint64 calls the specified callback for each object +// in the []uint64. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint64(callback func(int, uint64) bool) *Value { + for index, val := range v.MustUint64Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint64 uses the specified decider function to select items +// from the []uint64. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint64(decider func(int, uint64) bool) *Value { + var selected []uint64 + v.EachUint64(func(index int, val uint64) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint64 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint64. +func (v *Value) GroupUint64(grouper func(int, uint64) string) *Value { + groups := make(map[string][]uint64) + v.EachUint64(func(index int, val uint64) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint64, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint64 uses the specified function to replace each uint64s +// by iterating each item. The data in the returned result will be a +// []uint64 containing the replaced items. +func (v *Value) ReplaceUint64(replacer func(int, uint64) uint64) *Value { + arr := v.MustUint64Slice() + replaced := make([]uint64, len(arr)) + v.EachUint64(func(index int, val uint64) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint64 uses the specified collector function to collect a value +// for each of the uint64s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint64(collector func(int, uint64) interface{}) *Value { + arr := v.MustUint64Slice() + collected := make([]interface{}, len(arr)) + v.EachUint64(func(index int, val uint64) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uintptr (uintptr and []uintptr) +*/ + +// Uintptr gets the value as a uintptr, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uintptr(optionalDefault ...uintptr) uintptr { + if s, ok := v.data.(uintptr); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUintptr gets the value as a uintptr. +// +// Panics if the object is not a uintptr. +func (v *Value) MustUintptr() uintptr { + return v.data.(uintptr) +} + +// UintptrSlice gets the value as a []uintptr, returns the optionalDefault +// value or nil if the value is not a []uintptr. +func (v *Value) UintptrSlice(optionalDefault ...[]uintptr) []uintptr { + if s, ok := v.data.([]uintptr); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUintptrSlice gets the value as a []uintptr. +// +// Panics if the object is not a []uintptr. +func (v *Value) MustUintptrSlice() []uintptr { + return v.data.([]uintptr) +} + +// IsUintptr gets whether the object contained is a uintptr or not. +func (v *Value) IsUintptr() bool { + _, ok := v.data.(uintptr) + return ok +} + +// IsUintptrSlice gets whether the object contained is a []uintptr or not. +func (v *Value) IsUintptrSlice() bool { + _, ok := v.data.([]uintptr) + return ok +} + +// EachUintptr calls the specified callback for each object +// in the []uintptr. +// +// Panics if the object is the wrong type. +func (v *Value) EachUintptr(callback func(int, uintptr) bool) *Value { + for index, val := range v.MustUintptrSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUintptr uses the specified decider function to select items +// from the []uintptr. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUintptr(decider func(int, uintptr) bool) *Value { + var selected []uintptr + v.EachUintptr(func(index int, val uintptr) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUintptr uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uintptr. +func (v *Value) GroupUintptr(grouper func(int, uintptr) string) *Value { + groups := make(map[string][]uintptr) + v.EachUintptr(func(index int, val uintptr) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uintptr, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUintptr uses the specified function to replace each uintptrs +// by iterating each item. The data in the returned result will be a +// []uintptr containing the replaced items. +func (v *Value) ReplaceUintptr(replacer func(int, uintptr) uintptr) *Value { + arr := v.MustUintptrSlice() + replaced := make([]uintptr, len(arr)) + v.EachUintptr(func(index int, val uintptr) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUintptr uses the specified collector function to collect a value +// for each of the uintptrs in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUintptr(collector func(int, uintptr) interface{}) *Value { + arr := v.MustUintptrSlice() + collected := make([]interface{}, len(arr)) + v.EachUintptr(func(index int, val uintptr) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Float32 (float32 and []float32) +*/ + +// Float32 gets the value as a float32, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Float32(optionalDefault ...float32) float32 { + if s, ok := v.data.(float32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustFloat32 gets the value as a float32. +// +// Panics if the object is not a float32. +func (v *Value) MustFloat32() float32 { + return v.data.(float32) +} + +// Float32Slice gets the value as a []float32, returns the optionalDefault +// value or nil if the value is not a []float32. +func (v *Value) Float32Slice(optionalDefault ...[]float32) []float32 { + if s, ok := v.data.([]float32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustFloat32Slice gets the value as a []float32. +// +// Panics if the object is not a []float32. +func (v *Value) MustFloat32Slice() []float32 { + return v.data.([]float32) +} + +// IsFloat32 gets whether the object contained is a float32 or not. +func (v *Value) IsFloat32() bool { + _, ok := v.data.(float32) + return ok +} + +// IsFloat32Slice gets whether the object contained is a []float32 or not. +func (v *Value) IsFloat32Slice() bool { + _, ok := v.data.([]float32) + return ok +} + +// EachFloat32 calls the specified callback for each object +// in the []float32. +// +// Panics if the object is the wrong type. +func (v *Value) EachFloat32(callback func(int, float32) bool) *Value { + for index, val := range v.MustFloat32Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereFloat32 uses the specified decider function to select items +// from the []float32. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereFloat32(decider func(int, float32) bool) *Value { + var selected []float32 + v.EachFloat32(func(index int, val float32) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupFloat32 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]float32. +func (v *Value) GroupFloat32(grouper func(int, float32) string) *Value { + groups := make(map[string][]float32) + v.EachFloat32(func(index int, val float32) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]float32, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceFloat32 uses the specified function to replace each float32s +// by iterating each item. The data in the returned result will be a +// []float32 containing the replaced items. +func (v *Value) ReplaceFloat32(replacer func(int, float32) float32) *Value { + arr := v.MustFloat32Slice() + replaced := make([]float32, len(arr)) + v.EachFloat32(func(index int, val float32) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectFloat32 uses the specified collector function to collect a value +// for each of the float32s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectFloat32(collector func(int, float32) interface{}) *Value { + arr := v.MustFloat32Slice() + collected := make([]interface{}, len(arr)) + v.EachFloat32(func(index int, val float32) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Float64 (float64 and []float64) +*/ + +// Float64 gets the value as a float64, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Float64(optionalDefault ...float64) float64 { + if s, ok := v.data.(float64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustFloat64 gets the value as a float64. +// +// Panics if the object is not a float64. +func (v *Value) MustFloat64() float64 { + return v.data.(float64) +} + +// Float64Slice gets the value as a []float64, returns the optionalDefault +// value or nil if the value is not a []float64. +func (v *Value) Float64Slice(optionalDefault ...[]float64) []float64 { + if s, ok := v.data.([]float64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustFloat64Slice gets the value as a []float64. +// +// Panics if the object is not a []float64. +func (v *Value) MustFloat64Slice() []float64 { + return v.data.([]float64) +} + +// IsFloat64 gets whether the object contained is a float64 or not. +func (v *Value) IsFloat64() bool { + _, ok := v.data.(float64) + return ok +} + +// IsFloat64Slice gets whether the object contained is a []float64 or not. +func (v *Value) IsFloat64Slice() bool { + _, ok := v.data.([]float64) + return ok +} + +// EachFloat64 calls the specified callback for each object +// in the []float64. +// +// Panics if the object is the wrong type. +func (v *Value) EachFloat64(callback func(int, float64) bool) *Value { + for index, val := range v.MustFloat64Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereFloat64 uses the specified decider function to select items +// from the []float64. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereFloat64(decider func(int, float64) bool) *Value { + var selected []float64 + v.EachFloat64(func(index int, val float64) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupFloat64 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]float64. +func (v *Value) GroupFloat64(grouper func(int, float64) string) *Value { + groups := make(map[string][]float64) + v.EachFloat64(func(index int, val float64) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]float64, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceFloat64 uses the specified function to replace each float64s +// by iterating each item. The data in the returned result will be a +// []float64 containing the replaced items. +func (v *Value) ReplaceFloat64(replacer func(int, float64) float64) *Value { + arr := v.MustFloat64Slice() + replaced := make([]float64, len(arr)) + v.EachFloat64(func(index int, val float64) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectFloat64 uses the specified collector function to collect a value +// for each of the float64s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectFloat64(collector func(int, float64) interface{}) *Value { + arr := v.MustFloat64Slice() + collected := make([]interface{}, len(arr)) + v.EachFloat64(func(index int, val float64) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Complex64 (complex64 and []complex64) +*/ + +// Complex64 gets the value as a complex64, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Complex64(optionalDefault ...complex64) complex64 { + if s, ok := v.data.(complex64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustComplex64 gets the value as a complex64. +// +// Panics if the object is not a complex64. +func (v *Value) MustComplex64() complex64 { + return v.data.(complex64) +} + +// Complex64Slice gets the value as a []complex64, returns the optionalDefault +// value or nil if the value is not a []complex64. +func (v *Value) Complex64Slice(optionalDefault ...[]complex64) []complex64 { + if s, ok := v.data.([]complex64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustComplex64Slice gets the value as a []complex64. +// +// Panics if the object is not a []complex64. +func (v *Value) MustComplex64Slice() []complex64 { + return v.data.([]complex64) +} + +// IsComplex64 gets whether the object contained is a complex64 or not. +func (v *Value) IsComplex64() bool { + _, ok := v.data.(complex64) + return ok +} + +// IsComplex64Slice gets whether the object contained is a []complex64 or not. +func (v *Value) IsComplex64Slice() bool { + _, ok := v.data.([]complex64) + return ok +} + +// EachComplex64 calls the specified callback for each object +// in the []complex64. +// +// Panics if the object is the wrong type. +func (v *Value) EachComplex64(callback func(int, complex64) bool) *Value { + for index, val := range v.MustComplex64Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereComplex64 uses the specified decider function to select items +// from the []complex64. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereComplex64(decider func(int, complex64) bool) *Value { + var selected []complex64 + v.EachComplex64(func(index int, val complex64) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupComplex64 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]complex64. +func (v *Value) GroupComplex64(grouper func(int, complex64) string) *Value { + groups := make(map[string][]complex64) + v.EachComplex64(func(index int, val complex64) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]complex64, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceComplex64 uses the specified function to replace each complex64s +// by iterating each item. The data in the returned result will be a +// []complex64 containing the replaced items. +func (v *Value) ReplaceComplex64(replacer func(int, complex64) complex64) *Value { + arr := v.MustComplex64Slice() + replaced := make([]complex64, len(arr)) + v.EachComplex64(func(index int, val complex64) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectComplex64 uses the specified collector function to collect a value +// for each of the complex64s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectComplex64(collector func(int, complex64) interface{}) *Value { + arr := v.MustComplex64Slice() + collected := make([]interface{}, len(arr)) + v.EachComplex64(func(index int, val complex64) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Complex128 (complex128 and []complex128) +*/ + +// Complex128 gets the value as a complex128, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Complex128(optionalDefault ...complex128) complex128 { + if s, ok := v.data.(complex128); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustComplex128 gets the value as a complex128. +// +// Panics if the object is not a complex128. +func (v *Value) MustComplex128() complex128 { + return v.data.(complex128) +} + +// Complex128Slice gets the value as a []complex128, returns the optionalDefault +// value or nil if the value is not a []complex128. +func (v *Value) Complex128Slice(optionalDefault ...[]complex128) []complex128 { + if s, ok := v.data.([]complex128); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustComplex128Slice gets the value as a []complex128. +// +// Panics if the object is not a []complex128. +func (v *Value) MustComplex128Slice() []complex128 { + return v.data.([]complex128) +} + +// IsComplex128 gets whether the object contained is a complex128 or not. +func (v *Value) IsComplex128() bool { + _, ok := v.data.(complex128) + return ok +} + +// IsComplex128Slice gets whether the object contained is a []complex128 or not. +func (v *Value) IsComplex128Slice() bool { + _, ok := v.data.([]complex128) + return ok +} + +// EachComplex128 calls the specified callback for each object +// in the []complex128. +// +// Panics if the object is the wrong type. +func (v *Value) EachComplex128(callback func(int, complex128) bool) *Value { + for index, val := range v.MustComplex128Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereComplex128 uses the specified decider function to select items +// from the []complex128. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereComplex128(decider func(int, complex128) bool) *Value { + var selected []complex128 + v.EachComplex128(func(index int, val complex128) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupComplex128 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]complex128. +func (v *Value) GroupComplex128(grouper func(int, complex128) string) *Value { + groups := make(map[string][]complex128) + v.EachComplex128(func(index int, val complex128) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]complex128, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceComplex128 uses the specified function to replace each complex128s +// by iterating each item. The data in the returned result will be a +// []complex128 containing the replaced items. +func (v *Value) ReplaceComplex128(replacer func(int, complex128) complex128) *Value { + arr := v.MustComplex128Slice() + replaced := make([]complex128, len(arr)) + v.EachComplex128(func(index int, val complex128) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectComplex128 uses the specified collector function to collect a value +// for each of the complex128s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectComplex128(collector func(int, complex128) interface{}) *Value { + arr := v.MustComplex128Slice() + collected := make([]interface{}, len(arr)) + v.EachComplex128(func(index int, val complex128) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} diff --git a/vendor/github.com/stretchr/objx/value.go b/vendor/github.com/stretchr/objx/value.go new file mode 100644 index 0000000..4e5f9b7 --- /dev/null +++ b/vendor/github.com/stretchr/objx/value.go @@ -0,0 +1,159 @@ +package objx + +import ( + "fmt" + "strconv" +) + +// Value provides methods for extracting interface{} data in various +// types. +type Value struct { + // data contains the raw data being managed by this Value + data interface{} +} + +// Data returns the raw data contained by this Value +func (v *Value) Data() interface{} { + return v.data +} + +// String returns the value always as a string +func (v *Value) String() string { + switch { + case v.IsNil(): + return "" + case v.IsStr(): + return v.Str() + case v.IsBool(): + return strconv.FormatBool(v.Bool()) + case v.IsFloat32(): + return strconv.FormatFloat(float64(v.Float32()), 'f', -1, 32) + case v.IsFloat64(): + return strconv.FormatFloat(v.Float64(), 'f', -1, 64) + case v.IsInt(): + return strconv.FormatInt(int64(v.Int()), 10) + case v.IsInt8(): + return strconv.FormatInt(int64(v.Int8()), 10) + case v.IsInt16(): + return strconv.FormatInt(int64(v.Int16()), 10) + case v.IsInt32(): + return strconv.FormatInt(int64(v.Int32()), 10) + case v.IsInt64(): + return strconv.FormatInt(v.Int64(), 10) + case v.IsUint(): + return strconv.FormatUint(uint64(v.Uint()), 10) + case v.IsUint8(): + return strconv.FormatUint(uint64(v.Uint8()), 10) + case v.IsUint16(): + return strconv.FormatUint(uint64(v.Uint16()), 10) + case v.IsUint32(): + return strconv.FormatUint(uint64(v.Uint32()), 10) + case v.IsUint64(): + return strconv.FormatUint(v.Uint64(), 10) + } + return fmt.Sprintf("%#v", v.Data()) +} + +// StringSlice returns the value always as a []string +func (v *Value) StringSlice(optionalDefault ...[]string) []string { + switch { + case v.IsStrSlice(): + return v.MustStrSlice() + case v.IsBoolSlice(): + slice := v.MustBoolSlice() + vals := make([]string, len(slice)) + for i, iv := range slice { + vals[i] = strconv.FormatBool(iv) + } + return vals + case v.IsFloat32Slice(): + slice := v.MustFloat32Slice() + vals := make([]string, len(slice)) + for i, iv := range slice { + vals[i] = strconv.FormatFloat(float64(iv), 'f', -1, 32) + } + return vals + case v.IsFloat64Slice(): + slice := v.MustFloat64Slice() + vals := make([]string, len(slice)) + for i, iv := range slice { + vals[i] = strconv.FormatFloat(iv, 'f', -1, 64) + } + return vals + case v.IsIntSlice(): + slice := v.MustIntSlice() + vals := make([]string, len(slice)) + for i, iv := range slice { + vals[i] = strconv.FormatInt(int64(iv), 10) + } + return vals + case v.IsInt8Slice(): + slice := v.MustInt8Slice() + vals := make([]string, len(slice)) + for i, iv := range slice { + vals[i] = strconv.FormatInt(int64(iv), 10) + } + return vals + case v.IsInt16Slice(): + slice := v.MustInt16Slice() + vals := make([]string, len(slice)) + for i, iv := range slice { + vals[i] = strconv.FormatInt(int64(iv), 10) + } + return vals + case v.IsInt32Slice(): + slice := v.MustInt32Slice() + vals := make([]string, len(slice)) + for i, iv := range slice { + vals[i] = strconv.FormatInt(int64(iv), 10) + } + return vals + case v.IsInt64Slice(): + slice := v.MustInt64Slice() + vals := make([]string, len(slice)) + for i, iv := range slice { + vals[i] = strconv.FormatInt(iv, 10) + } + return vals + case v.IsUintSlice(): + slice := v.MustUintSlice() + vals := make([]string, len(slice)) + for i, iv := range slice { + vals[i] = strconv.FormatUint(uint64(iv), 10) + } + return vals + case v.IsUint8Slice(): + slice := v.MustUint8Slice() + vals := make([]string, len(slice)) + for i, iv := range slice { + vals[i] = strconv.FormatUint(uint64(iv), 10) + } + return vals + case v.IsUint16Slice(): + slice := v.MustUint16Slice() + vals := make([]string, len(slice)) + for i, iv := range slice { + vals[i] = strconv.FormatUint(uint64(iv), 10) + } + return vals + case v.IsUint32Slice(): + slice := v.MustUint32Slice() + vals := make([]string, len(slice)) + for i, iv := range slice { + vals[i] = strconv.FormatUint(uint64(iv), 10) + } + return vals + case v.IsUint64Slice(): + slice := v.MustUint64Slice() + vals := make([]string, len(slice)) + for i, iv := range slice { + vals[i] = strconv.FormatUint(iv, 10) + } + return vals + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + + return []string{} +} diff --git a/vendor/github.com/stretchr/testify/LICENSE b/vendor/github.com/stretchr/testify/LICENSE new file mode 100644 index 0000000..4b0421c --- /dev/null +++ b/vendor/github.com/stretchr/testify/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2012-2020 Mat Ryer, Tyler Bunnell and contributors. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/stretchr/testify/assert/assertion_compare.go b/vendor/github.com/stretchr/testify/assert/assertion_compare.go new file mode 100644 index 0000000..7e19eba --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/assertion_compare.go @@ -0,0 +1,489 @@ +package assert + +import ( + "bytes" + "fmt" + "reflect" + "time" +) + +// Deprecated: CompareType has only ever been for internal use and has accidentally been published since v1.6.0. Do not use it. +type CompareType = compareResult + +type compareResult int + +const ( + compareLess compareResult = iota - 1 + compareEqual + compareGreater +) + +var ( + intType = reflect.TypeOf(int(1)) + int8Type = reflect.TypeOf(int8(1)) + int16Type = reflect.TypeOf(int16(1)) + int32Type = reflect.TypeOf(int32(1)) + int64Type = reflect.TypeOf(int64(1)) + + uintType = reflect.TypeOf(uint(1)) + uint8Type = reflect.TypeOf(uint8(1)) + uint16Type = reflect.TypeOf(uint16(1)) + uint32Type = reflect.TypeOf(uint32(1)) + uint64Type = reflect.TypeOf(uint64(1)) + + uintptrType = reflect.TypeOf(uintptr(1)) + + float32Type = reflect.TypeOf(float32(1)) + float64Type = reflect.TypeOf(float64(1)) + + stringType = reflect.TypeOf("") + + timeType = reflect.TypeOf(time.Time{}) + bytesType = reflect.TypeOf([]byte{}) +) + +func compare(obj1, obj2 interface{}, kind reflect.Kind) (compareResult, bool) { + obj1Value := reflect.ValueOf(obj1) + obj2Value := reflect.ValueOf(obj2) + + // throughout this switch we try and avoid calling .Convert() if possible, + // as this has a pretty big performance impact + switch kind { + case reflect.Int: + { + intobj1, ok := obj1.(int) + if !ok { + intobj1 = obj1Value.Convert(intType).Interface().(int) + } + intobj2, ok := obj2.(int) + if !ok { + intobj2 = obj2Value.Convert(intType).Interface().(int) + } + if intobj1 > intobj2 { + return compareGreater, true + } + if intobj1 == intobj2 { + return compareEqual, true + } + if intobj1 < intobj2 { + return compareLess, true + } + } + case reflect.Int8: + { + int8obj1, ok := obj1.(int8) + if !ok { + int8obj1 = obj1Value.Convert(int8Type).Interface().(int8) + } + int8obj2, ok := obj2.(int8) + if !ok { + int8obj2 = obj2Value.Convert(int8Type).Interface().(int8) + } + if int8obj1 > int8obj2 { + return compareGreater, true + } + if int8obj1 == int8obj2 { + return compareEqual, true + } + if int8obj1 < int8obj2 { + return compareLess, true + } + } + case reflect.Int16: + { + int16obj1, ok := obj1.(int16) + if !ok { + int16obj1 = obj1Value.Convert(int16Type).Interface().(int16) + } + int16obj2, ok := obj2.(int16) + if !ok { + int16obj2 = obj2Value.Convert(int16Type).Interface().(int16) + } + if int16obj1 > int16obj2 { + return compareGreater, true + } + if int16obj1 == int16obj2 { + return compareEqual, true + } + if int16obj1 < int16obj2 { + return compareLess, true + } + } + case reflect.Int32: + { + int32obj1, ok := obj1.(int32) + if !ok { + int32obj1 = obj1Value.Convert(int32Type).Interface().(int32) + } + int32obj2, ok := obj2.(int32) + if !ok { + int32obj2 = obj2Value.Convert(int32Type).Interface().(int32) + } + if int32obj1 > int32obj2 { + return compareGreater, true + } + if int32obj1 == int32obj2 { + return compareEqual, true + } + if int32obj1 < int32obj2 { + return compareLess, true + } + } + case reflect.Int64: + { + int64obj1, ok := obj1.(int64) + if !ok { + int64obj1 = obj1Value.Convert(int64Type).Interface().(int64) + } + int64obj2, ok := obj2.(int64) + if !ok { + int64obj2 = obj2Value.Convert(int64Type).Interface().(int64) + } + if int64obj1 > int64obj2 { + return compareGreater, true + } + if int64obj1 == int64obj2 { + return compareEqual, true + } + if int64obj1 < int64obj2 { + return compareLess, true + } + } + case reflect.Uint: + { + uintobj1, ok := obj1.(uint) + if !ok { + uintobj1 = obj1Value.Convert(uintType).Interface().(uint) + } + uintobj2, ok := obj2.(uint) + if !ok { + uintobj2 = obj2Value.Convert(uintType).Interface().(uint) + } + if uintobj1 > uintobj2 { + return compareGreater, true + } + if uintobj1 == uintobj2 { + return compareEqual, true + } + if uintobj1 < uintobj2 { + return compareLess, true + } + } + case reflect.Uint8: + { + uint8obj1, ok := obj1.(uint8) + if !ok { + uint8obj1 = obj1Value.Convert(uint8Type).Interface().(uint8) + } + uint8obj2, ok := obj2.(uint8) + if !ok { + uint8obj2 = obj2Value.Convert(uint8Type).Interface().(uint8) + } + if uint8obj1 > uint8obj2 { + return compareGreater, true + } + if uint8obj1 == uint8obj2 { + return compareEqual, true + } + if uint8obj1 < uint8obj2 { + return compareLess, true + } + } + case reflect.Uint16: + { + uint16obj1, ok := obj1.(uint16) + if !ok { + uint16obj1 = obj1Value.Convert(uint16Type).Interface().(uint16) + } + uint16obj2, ok := obj2.(uint16) + if !ok { + uint16obj2 = obj2Value.Convert(uint16Type).Interface().(uint16) + } + if uint16obj1 > uint16obj2 { + return compareGreater, true + } + if uint16obj1 == uint16obj2 { + return compareEqual, true + } + if uint16obj1 < uint16obj2 { + return compareLess, true + } + } + case reflect.Uint32: + { + uint32obj1, ok := obj1.(uint32) + if !ok { + uint32obj1 = obj1Value.Convert(uint32Type).Interface().(uint32) + } + uint32obj2, ok := obj2.(uint32) + if !ok { + uint32obj2 = obj2Value.Convert(uint32Type).Interface().(uint32) + } + if uint32obj1 > uint32obj2 { + return compareGreater, true + } + if uint32obj1 == uint32obj2 { + return compareEqual, true + } + if uint32obj1 < uint32obj2 { + return compareLess, true + } + } + case reflect.Uint64: + { + uint64obj1, ok := obj1.(uint64) + if !ok { + uint64obj1 = obj1Value.Convert(uint64Type).Interface().(uint64) + } + uint64obj2, ok := obj2.(uint64) + if !ok { + uint64obj2 = obj2Value.Convert(uint64Type).Interface().(uint64) + } + if uint64obj1 > uint64obj2 { + return compareGreater, true + } + if uint64obj1 == uint64obj2 { + return compareEqual, true + } + if uint64obj1 < uint64obj2 { + return compareLess, true + } + } + case reflect.Float32: + { + float32obj1, ok := obj1.(float32) + if !ok { + float32obj1 = obj1Value.Convert(float32Type).Interface().(float32) + } + float32obj2, ok := obj2.(float32) + if !ok { + float32obj2 = obj2Value.Convert(float32Type).Interface().(float32) + } + if float32obj1 > float32obj2 { + return compareGreater, true + } + if float32obj1 == float32obj2 { + return compareEqual, true + } + if float32obj1 < float32obj2 { + return compareLess, true + } + } + case reflect.Float64: + { + float64obj1, ok := obj1.(float64) + if !ok { + float64obj1 = obj1Value.Convert(float64Type).Interface().(float64) + } + float64obj2, ok := obj2.(float64) + if !ok { + float64obj2 = obj2Value.Convert(float64Type).Interface().(float64) + } + if float64obj1 > float64obj2 { + return compareGreater, true + } + if float64obj1 == float64obj2 { + return compareEqual, true + } + if float64obj1 < float64obj2 { + return compareLess, true + } + } + case reflect.String: + { + stringobj1, ok := obj1.(string) + if !ok { + stringobj1 = obj1Value.Convert(stringType).Interface().(string) + } + stringobj2, ok := obj2.(string) + if !ok { + stringobj2 = obj2Value.Convert(stringType).Interface().(string) + } + if stringobj1 > stringobj2 { + return compareGreater, true + } + if stringobj1 == stringobj2 { + return compareEqual, true + } + if stringobj1 < stringobj2 { + return compareLess, true + } + } + // Check for known struct types we can check for compare results. + case reflect.Struct: + { + // All structs enter here. We're not interested in most types. + if !obj1Value.CanConvert(timeType) { + break + } + + // time.Time can be compared! + timeObj1, ok := obj1.(time.Time) + if !ok { + timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time) + } + + timeObj2, ok := obj2.(time.Time) + if !ok { + timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time) + } + + if timeObj1.Before(timeObj2) { + return compareLess, true + } + if timeObj1.Equal(timeObj2) { + return compareEqual, true + } + return compareGreater, true + } + case reflect.Slice: + { + // We only care about the []byte type. + if !obj1Value.CanConvert(bytesType) { + break + } + + // []byte can be compared! + bytesObj1, ok := obj1.([]byte) + if !ok { + bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte) + + } + bytesObj2, ok := obj2.([]byte) + if !ok { + bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte) + } + + return compareResult(bytes.Compare(bytesObj1, bytesObj2)), true + } + case reflect.Uintptr: + { + uintptrObj1, ok := obj1.(uintptr) + if !ok { + uintptrObj1 = obj1Value.Convert(uintptrType).Interface().(uintptr) + } + uintptrObj2, ok := obj2.(uintptr) + if !ok { + uintptrObj2 = obj2Value.Convert(uintptrType).Interface().(uintptr) + } + if uintptrObj1 > uintptrObj2 { + return compareGreater, true + } + if uintptrObj1 == uintptrObj2 { + return compareEqual, true + } + if uintptrObj1 < uintptrObj2 { + return compareLess, true + } + } + } + + return compareEqual, false +} + +// Greater asserts that the first element is greater than the second +// +// assert.Greater(t, 2, 1) +// assert.Greater(t, float64(2), float64(1)) +// assert.Greater(t, "b", "a") +func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return compareTwoValues(t, e1, e2, []compareResult{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...) +} + +// GreaterOrEqual asserts that the first element is greater than or equal to the second +// +// assert.GreaterOrEqual(t, 2, 1) +// assert.GreaterOrEqual(t, 2, 2) +// assert.GreaterOrEqual(t, "b", "a") +// assert.GreaterOrEqual(t, "b", "b") +func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return compareTwoValues(t, e1, e2, []compareResult{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...) +} + +// Less asserts that the first element is less than the second +// +// assert.Less(t, 1, 2) +// assert.Less(t, float64(1), float64(2)) +// assert.Less(t, "a", "b") +func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return compareTwoValues(t, e1, e2, []compareResult{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...) +} + +// LessOrEqual asserts that the first element is less than or equal to the second +// +// assert.LessOrEqual(t, 1, 2) +// assert.LessOrEqual(t, 2, 2) +// assert.LessOrEqual(t, "a", "b") +// assert.LessOrEqual(t, "b", "b") +func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return compareTwoValues(t, e1, e2, []compareResult{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...) +} + +// Positive asserts that the specified element is positive +// +// assert.Positive(t, 1) +// assert.Positive(t, 1.23) +func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + zero := reflect.Zero(reflect.TypeOf(e)) + return compareTwoValues(t, e, zero.Interface(), []compareResult{compareGreater}, "\"%v\" is not positive", msgAndArgs...) +} + +// Negative asserts that the specified element is negative +// +// assert.Negative(t, -1) +// assert.Negative(t, -1.23) +func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + zero := reflect.Zero(reflect.TypeOf(e)) + return compareTwoValues(t, e, zero.Interface(), []compareResult{compareLess}, "\"%v\" is not negative", msgAndArgs...) +} + +func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []compareResult, failMessage string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + e1Kind := reflect.ValueOf(e1).Kind() + e2Kind := reflect.ValueOf(e2).Kind() + if e1Kind != e2Kind { + return Fail(t, "Elements should be the same type", msgAndArgs...) + } + + compareResult, isComparable := compare(e1, e2, e1Kind) + if !isComparable { + return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...) + } + + if !containsValue(allowedComparesResults, compareResult) { + return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...) + } + + return true +} + +func containsValue(values []compareResult, value compareResult) bool { + for _, v := range values { + if v == value { + return true + } + } + + return false +} diff --git a/vendor/github.com/stretchr/testify/assert/assertion_format.go b/vendor/github.com/stretchr/testify/assert/assertion_format.go new file mode 100644 index 0000000..1906341 --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/assertion_format.go @@ -0,0 +1,841 @@ +// Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT. + +package assert + +import ( + http "net/http" + url "net/url" + time "time" +) + +// Conditionf uses a Comparison to assert a complex condition. +func Conditionf(t TestingT, comp Comparison, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Condition(t, comp, append([]interface{}{msg}, args...)...) +} + +// Containsf asserts that the specified string, list(array, slice...) or map contains the +// specified substring or element. +// +// assert.Containsf(t, "Hello World", "World", "error message %s", "formatted") +// assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted") +// assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted") +func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Contains(t, s, contains, append([]interface{}{msg}, args...)...) +} + +// DirExistsf checks whether a directory exists in the given path. It also fails +// if the path is a file rather a directory or there is an error checking whether it exists. +func DirExistsf(t TestingT, path string, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return DirExists(t, path, append([]interface{}{msg}, args...)...) +} + +// ElementsMatchf asserts that the specified listA(array, slice...) is equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should match. +// +// assert.ElementsMatchf(t, [1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted") +func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return ElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...) +} + +// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either +// a slice or a channel with len == 0. +// +// assert.Emptyf(t, obj, "error message %s", "formatted") +func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Empty(t, object, append([]interface{}{msg}, args...)...) +} + +// Equalf asserts that two objects are equal. +// +// assert.Equalf(t, 123, 123, "error message %s", "formatted") +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). Function equality +// cannot be determined and will always fail. +func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Equal(t, expected, actual, append([]interface{}{msg}, args...)...) +} + +// EqualErrorf asserts that a function returned an error (i.e. not `nil`) +// and that it is equal to the provided error. +// +// actualObj, err := SomeFunction() +// assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted") +func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return EqualError(t, theError, errString, append([]interface{}{msg}, args...)...) +} + +// EqualExportedValuesf asserts that the types of two objects are equal and their public +// fields are also equal. This is useful for comparing structs that have private fields +// that could potentially differ. +// +// type S struct { +// Exported int +// notExported int +// } +// assert.EqualExportedValuesf(t, S{1, 2}, S{1, 3}, "error message %s", "formatted") => true +// assert.EqualExportedValuesf(t, S{1, 2}, S{2, 3}, "error message %s", "formatted") => false +func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...) +} + +// EqualValuesf asserts that two objects are equal or convertible to the larger +// type and equal. +// +// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") +func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return EqualValues(t, expected, actual, append([]interface{}{msg}, args...)...) +} + +// Errorf asserts that a function returned an error (i.e. not `nil`). +// +// actualObj, err := SomeFunction() +// if assert.Errorf(t, err, "error message %s", "formatted") { +// assert.Equal(t, expectedErrorf, err) +// } +func Errorf(t TestingT, err error, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Error(t, err, append([]interface{}{msg}, args...)...) +} + +// ErrorAsf asserts that at least one of the errors in err's chain matches target, and if so, sets target to that error value. +// This is a wrapper for errors.As. +func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return ErrorAs(t, err, target, append([]interface{}{msg}, args...)...) +} + +// ErrorContainsf asserts that a function returned an error (i.e. not `nil`) +// and that the error contains the specified substring. +// +// actualObj, err := SomeFunction() +// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted") +func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return ErrorContains(t, theError, contains, append([]interface{}{msg}, args...)...) +} + +// ErrorIsf asserts that at least one of the errors in err's chain matches target. +// This is a wrapper for errors.Is. +func ErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return ErrorIs(t, err, target, append([]interface{}{msg}, args...)...) +} + +// Eventuallyf asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. +// +// assert.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") +func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Eventually(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...) +} + +// EventuallyWithTf asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. In contrast to Eventually, +// it supplies a CollectT to the condition function, so that the condition +// function can use the CollectT to call other assertions. +// The condition is considered "met" if no errors are raised in a tick. +// The supplied CollectT collects all errors from one tick (if there are any). +// If the condition is not met before waitFor, the collected errors of +// the last tick are copied to t. +// +// externalValue := false +// go func() { +// time.Sleep(8*time.Second) +// externalValue = true +// }() +// assert.EventuallyWithTf(t, func(c *assert.CollectT, "error message %s", "formatted") { +// // add assertions as needed; any assertion failure will fail the current tick +// assert.True(c, externalValue, "expected 'externalValue' to be true") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") +func EventuallyWithTf(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return EventuallyWithT(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...) +} + +// Exactlyf asserts that two objects are equal in value and type. +// +// assert.Exactlyf(t, int32(123), int64(123), "error message %s", "formatted") +func Exactlyf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Exactly(t, expected, actual, append([]interface{}{msg}, args...)...) +} + +// Failf reports a failure through +func Failf(t TestingT, failureMessage string, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, failureMessage, append([]interface{}{msg}, args...)...) +} + +// FailNowf fails test +func FailNowf(t TestingT, failureMessage string, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return FailNow(t, failureMessage, append([]interface{}{msg}, args...)...) +} + +// Falsef asserts that the specified value is false. +// +// assert.Falsef(t, myBool, "error message %s", "formatted") +func Falsef(t TestingT, value bool, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return False(t, value, append([]interface{}{msg}, args...)...) +} + +// FileExistsf checks whether a file exists in the given path. It also fails if +// the path points to a directory or there is an error when trying to check the file. +func FileExistsf(t TestingT, path string, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return FileExists(t, path, append([]interface{}{msg}, args...)...) +} + +// Greaterf asserts that the first element is greater than the second +// +// assert.Greaterf(t, 2, 1, "error message %s", "formatted") +// assert.Greaterf(t, float64(2), float64(1), "error message %s", "formatted") +// assert.Greaterf(t, "b", "a", "error message %s", "formatted") +func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Greater(t, e1, e2, append([]interface{}{msg}, args...)...) +} + +// GreaterOrEqualf asserts that the first element is greater than or equal to the second +// +// assert.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted") +// assert.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted") +// assert.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted") +// assert.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted") +func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return GreaterOrEqual(t, e1, e2, append([]interface{}{msg}, args...)...) +} + +// HTTPBodyContainsf asserts that a specified handler returns a +// body that contains a string. +// +// assert.HTTPBodyContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPBodyContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return HTTPBodyContains(t, handler, method, url, values, str, append([]interface{}{msg}, args...)...) +} + +// HTTPBodyNotContainsf asserts that a specified handler returns a +// body that does not contain a string. +// +// assert.HTTPBodyNotContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return HTTPBodyNotContains(t, handler, method, url, values, str, append([]interface{}{msg}, args...)...) +} + +// HTTPErrorf asserts that a specified handler returns an error status code. +// +// assert.HTTPErrorf(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return HTTPError(t, handler, method, url, values, append([]interface{}{msg}, args...)...) +} + +// HTTPRedirectf asserts that a specified handler returns a redirect status code. +// +// assert.HTTPRedirectf(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return HTTPRedirect(t, handler, method, url, values, append([]interface{}{msg}, args...)...) +} + +// HTTPStatusCodef asserts that a specified handler returns a specified status code. +// +// assert.HTTPStatusCodef(t, myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted") +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPStatusCodef(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return HTTPStatusCode(t, handler, method, url, values, statuscode, append([]interface{}{msg}, args...)...) +} + +// HTTPSuccessf asserts that a specified handler returns a success status code. +// +// assert.HTTPSuccessf(t, myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted") +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return HTTPSuccess(t, handler, method, url, values, append([]interface{}{msg}, args...)...) +} + +// Implementsf asserts that an object is implemented by the specified interface. +// +// assert.Implementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted") +func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Implements(t, interfaceObject, object, append([]interface{}{msg}, args...)...) +} + +// InDeltaf asserts that the two numerals are within delta of each other. +// +// assert.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted") +func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return InDelta(t, expected, actual, delta, append([]interface{}{msg}, args...)...) +} + +// InDeltaMapValuesf is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys. +func InDeltaMapValuesf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return InDeltaMapValues(t, expected, actual, delta, append([]interface{}{msg}, args...)...) +} + +// InDeltaSlicef is the same as InDelta, except it compares two slices. +func InDeltaSlicef(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return InDeltaSlice(t, expected, actual, delta, append([]interface{}{msg}, args...)...) +} + +// InEpsilonf asserts that expected and actual have a relative error less than epsilon +func InEpsilonf(t TestingT, expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return InEpsilon(t, expected, actual, epsilon, append([]interface{}{msg}, args...)...) +} + +// InEpsilonSlicef is the same as InEpsilon, except it compares each value from two slices. +func InEpsilonSlicef(t TestingT, expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return InEpsilonSlice(t, expected, actual, epsilon, append([]interface{}{msg}, args...)...) +} + +// IsDecreasingf asserts that the collection is decreasing +// +// assert.IsDecreasingf(t, []int{2, 1, 0}, "error message %s", "formatted") +// assert.IsDecreasingf(t, []float{2, 1}, "error message %s", "formatted") +// assert.IsDecreasingf(t, []string{"b", "a"}, "error message %s", "formatted") +func IsDecreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return IsDecreasing(t, object, append([]interface{}{msg}, args...)...) +} + +// IsIncreasingf asserts that the collection is increasing +// +// assert.IsIncreasingf(t, []int{1, 2, 3}, "error message %s", "formatted") +// assert.IsIncreasingf(t, []float{1, 2}, "error message %s", "formatted") +// assert.IsIncreasingf(t, []string{"a", "b"}, "error message %s", "formatted") +func IsIncreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return IsIncreasing(t, object, append([]interface{}{msg}, args...)...) +} + +// IsNonDecreasingf asserts that the collection is not decreasing +// +// assert.IsNonDecreasingf(t, []int{1, 1, 2}, "error message %s", "formatted") +// assert.IsNonDecreasingf(t, []float{1, 2}, "error message %s", "formatted") +// assert.IsNonDecreasingf(t, []string{"a", "b"}, "error message %s", "formatted") +func IsNonDecreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return IsNonDecreasing(t, object, append([]interface{}{msg}, args...)...) +} + +// IsNonIncreasingf asserts that the collection is not increasing +// +// assert.IsNonIncreasingf(t, []int{2, 1, 1}, "error message %s", "formatted") +// assert.IsNonIncreasingf(t, []float{2, 1}, "error message %s", "formatted") +// assert.IsNonIncreasingf(t, []string{"b", "a"}, "error message %s", "formatted") +func IsNonIncreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return IsNonIncreasing(t, object, append([]interface{}{msg}, args...)...) +} + +// IsTypef asserts that the specified objects are of the same type. +func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return IsType(t, expectedType, object, append([]interface{}{msg}, args...)...) +} + +// JSONEqf asserts that two JSON strings are equivalent. +// +// assert.JSONEqf(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted") +func JSONEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return JSONEq(t, expected, actual, append([]interface{}{msg}, args...)...) +} + +// Lenf asserts that the specified object has specific length. +// Lenf also fails if the object has a type that len() not accept. +// +// assert.Lenf(t, mySlice, 3, "error message %s", "formatted") +func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Len(t, object, length, append([]interface{}{msg}, args...)...) +} + +// Lessf asserts that the first element is less than the second +// +// assert.Lessf(t, 1, 2, "error message %s", "formatted") +// assert.Lessf(t, float64(1), float64(2), "error message %s", "formatted") +// assert.Lessf(t, "a", "b", "error message %s", "formatted") +func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Less(t, e1, e2, append([]interface{}{msg}, args...)...) +} + +// LessOrEqualf asserts that the first element is less than or equal to the second +// +// assert.LessOrEqualf(t, 1, 2, "error message %s", "formatted") +// assert.LessOrEqualf(t, 2, 2, "error message %s", "formatted") +// assert.LessOrEqualf(t, "a", "b", "error message %s", "formatted") +// assert.LessOrEqualf(t, "b", "b", "error message %s", "formatted") +func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return LessOrEqual(t, e1, e2, append([]interface{}{msg}, args...)...) +} + +// Negativef asserts that the specified element is negative +// +// assert.Negativef(t, -1, "error message %s", "formatted") +// assert.Negativef(t, -1.23, "error message %s", "formatted") +func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Negative(t, e, append([]interface{}{msg}, args...)...) +} + +// Neverf asserts that the given condition doesn't satisfy in waitFor time, +// periodically checking the target function each tick. +// +// assert.Neverf(t, func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") +func Neverf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Never(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...) +} + +// Nilf asserts that the specified object is nil. +// +// assert.Nilf(t, err, "error message %s", "formatted") +func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Nil(t, object, append([]interface{}{msg}, args...)...) +} + +// NoDirExistsf checks whether a directory does not exist in the given path. +// It fails if the path points to an existing _directory_ only. +func NoDirExistsf(t TestingT, path string, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NoDirExists(t, path, append([]interface{}{msg}, args...)...) +} + +// NoErrorf asserts that a function returned no error (i.e. `nil`). +// +// actualObj, err := SomeFunction() +// if assert.NoErrorf(t, err, "error message %s", "formatted") { +// assert.Equal(t, expectedObj, actualObj) +// } +func NoErrorf(t TestingT, err error, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NoError(t, err, append([]interface{}{msg}, args...)...) +} + +// NoFileExistsf checks whether a file does not exist in a given path. It fails +// if the path points to an existing _file_ only. +func NoFileExistsf(t TestingT, path string, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NoFileExists(t, path, append([]interface{}{msg}, args...)...) +} + +// NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the +// specified substring or element. +// +// assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted") +// assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted") +// assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted") +func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotContains(t, s, contains, append([]interface{}{msg}, args...)...) +} + +// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// assert.NotElementsMatchf(t, [1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false +// +// assert.NotElementsMatchf(t, [1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true +// +// assert.NotElementsMatchf(t, [1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true +func NotElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...) +} + +// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either +// a slice or a channel with len == 0. +// +// if assert.NotEmptyf(t, obj, "error message %s", "formatted") { +// assert.Equal(t, "two", obj[1]) +// } +func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotEmpty(t, object, append([]interface{}{msg}, args...)...) +} + +// NotEqualf asserts that the specified values are NOT equal. +// +// assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted") +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). +func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotEqual(t, expected, actual, append([]interface{}{msg}, args...)...) +} + +// NotEqualValuesf asserts that two objects are not equal even when converted to the same type +// +// assert.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted") +func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotEqualValues(t, expected, actual, append([]interface{}{msg}, args...)...) +} + +// NotErrorAsf asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func NotErrorAsf(t TestingT, err error, target interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotErrorAs(t, err, target, append([]interface{}{msg}, args...)...) +} + +// NotErrorIsf asserts that none of the errors in err's chain matches target. +// This is a wrapper for errors.Is. +func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotErrorIs(t, err, target, append([]interface{}{msg}, args...)...) +} + +// NotImplementsf asserts that an object does not implement the specified interface. +// +// assert.NotImplementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted") +func NotImplementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotImplements(t, interfaceObject, object, append([]interface{}{msg}, args...)...) +} + +// NotNilf asserts that the specified object is not nil. +// +// assert.NotNilf(t, err, "error message %s", "formatted") +func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotNil(t, object, append([]interface{}{msg}, args...)...) +} + +// NotPanicsf asserts that the code inside the specified PanicTestFunc does NOT panic. +// +// assert.NotPanicsf(t, func(){ RemainCalm() }, "error message %s", "formatted") +func NotPanicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotPanics(t, f, append([]interface{}{msg}, args...)...) +} + +// NotRegexpf asserts that a specified regexp does not match a string. +// +// assert.NotRegexpf(t, regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted") +// assert.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted") +func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotRegexp(t, rx, str, append([]interface{}{msg}, args...)...) +} + +// NotSamef asserts that two pointers do not reference the same object. +// +// assert.NotSamef(t, ptr1, ptr2, "error message %s", "formatted") +// +// Both arguments must be pointer variables. Pointer variable sameness is +// determined based on the equality of both type and value. +func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotSame(t, expected, actual, append([]interface{}{msg}, args...)...) +} + +// NotSubsetf asserts that the specified list(array, slice...) or map does NOT +// contain all elements given in the specified subset list(array, slice...) or +// map. +// +// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted") +// assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted") +func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotSubset(t, list, subset, append([]interface{}{msg}, args...)...) +} + +// NotZerof asserts that i is not the zero value for its type. +func NotZerof(t TestingT, i interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotZero(t, i, append([]interface{}{msg}, args...)...) +} + +// Panicsf asserts that the code inside the specified PanicTestFunc panics. +// +// assert.Panicsf(t, func(){ GoCrazy() }, "error message %s", "formatted") +func Panicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Panics(t, f, append([]interface{}{msg}, args...)...) +} + +// PanicsWithErrorf asserts that the code inside the specified PanicTestFunc +// panics, and that the recovered panic value is an error that satisfies the +// EqualError comparison. +// +// assert.PanicsWithErrorf(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") +func PanicsWithErrorf(t TestingT, errString string, f PanicTestFunc, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return PanicsWithError(t, errString, f, append([]interface{}{msg}, args...)...) +} + +// PanicsWithValuef asserts that the code inside the specified PanicTestFunc panics, and that +// the recovered panic value equals the expected panic value. +// +// assert.PanicsWithValuef(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") +func PanicsWithValuef(t TestingT, expected interface{}, f PanicTestFunc, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return PanicsWithValue(t, expected, f, append([]interface{}{msg}, args...)...) +} + +// Positivef asserts that the specified element is positive +// +// assert.Positivef(t, 1, "error message %s", "formatted") +// assert.Positivef(t, 1.23, "error message %s", "formatted") +func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Positive(t, e, append([]interface{}{msg}, args...)...) +} + +// Regexpf asserts that a specified regexp matches a string. +// +// assert.Regexpf(t, regexp.MustCompile("start"), "it's starting", "error message %s", "formatted") +// assert.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted") +func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Regexp(t, rx, str, append([]interface{}{msg}, args...)...) +} + +// Samef asserts that two pointers reference the same object. +// +// assert.Samef(t, ptr1, ptr2, "error message %s", "formatted") +// +// Both arguments must be pointer variables. Pointer variable sameness is +// determined based on the equality of both type and value. +func Samef(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Same(t, expected, actual, append([]interface{}{msg}, args...)...) +} + +// Subsetf asserts that the specified list(array, slice...) or map contains all +// elements given in the specified subset list(array, slice...) or map. +// +// assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted") +// assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted") +func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Subset(t, list, subset, append([]interface{}{msg}, args...)...) +} + +// Truef asserts that the specified value is true. +// +// assert.Truef(t, myBool, "error message %s", "formatted") +func Truef(t TestingT, value bool, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return True(t, value, append([]interface{}{msg}, args...)...) +} + +// WithinDurationf asserts that the two times are within duration delta of each other. +// +// assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") +func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return WithinDuration(t, expected, actual, delta, append([]interface{}{msg}, args...)...) +} + +// WithinRangef asserts that a time is within a time range (inclusive). +// +// assert.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted") +func WithinRangef(t TestingT, actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return WithinRange(t, actual, start, end, append([]interface{}{msg}, args...)...) +} + +// YAMLEqf asserts that two YAML strings are equivalent. +func YAMLEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return YAMLEq(t, expected, actual, append([]interface{}{msg}, args...)...) +} + +// Zerof asserts that i is the zero value for its type. +func Zerof(t TestingT, i interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Zero(t, i, append([]interface{}{msg}, args...)...) +} diff --git a/vendor/github.com/stretchr/testify/assert/assertion_forward.go b/vendor/github.com/stretchr/testify/assert/assertion_forward.go new file mode 100644 index 0000000..2162908 --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/assertion_forward.go @@ -0,0 +1,1673 @@ +// Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT. + +package assert + +import ( + http "net/http" + url "net/url" + time "time" +) + +// Condition uses a Comparison to assert a complex condition. +func (a *Assertions) Condition(comp Comparison, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Condition(a.t, comp, msgAndArgs...) +} + +// Conditionf uses a Comparison to assert a complex condition. +func (a *Assertions) Conditionf(comp Comparison, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Conditionf(a.t, comp, msg, args...) +} + +// Contains asserts that the specified string, list(array, slice...) or map contains the +// specified substring or element. +// +// a.Contains("Hello World", "World") +// a.Contains(["Hello", "World"], "World") +// a.Contains({"Hello": "World"}, "Hello") +func (a *Assertions) Contains(s interface{}, contains interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Contains(a.t, s, contains, msgAndArgs...) +} + +// Containsf asserts that the specified string, list(array, slice...) or map contains the +// specified substring or element. +// +// a.Containsf("Hello World", "World", "error message %s", "formatted") +// a.Containsf(["Hello", "World"], "World", "error message %s", "formatted") +// a.Containsf({"Hello": "World"}, "Hello", "error message %s", "formatted") +func (a *Assertions) Containsf(s interface{}, contains interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Containsf(a.t, s, contains, msg, args...) +} + +// DirExists checks whether a directory exists in the given path. It also fails +// if the path is a file rather a directory or there is an error checking whether it exists. +func (a *Assertions) DirExists(path string, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return DirExists(a.t, path, msgAndArgs...) +} + +// DirExistsf checks whether a directory exists in the given path. It also fails +// if the path is a file rather a directory or there is an error checking whether it exists. +func (a *Assertions) DirExistsf(path string, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return DirExistsf(a.t, path, msg, args...) +} + +// ElementsMatch asserts that the specified listA(array, slice...) is equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should match. +// +// a.ElementsMatch([1, 3, 2, 3], [1, 3, 3, 2]) +func (a *Assertions) ElementsMatch(listA interface{}, listB interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return ElementsMatch(a.t, listA, listB, msgAndArgs...) +} + +// ElementsMatchf asserts that the specified listA(array, slice...) is equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should match. +// +// a.ElementsMatchf([1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted") +func (a *Assertions) ElementsMatchf(listA interface{}, listB interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return ElementsMatchf(a.t, listA, listB, msg, args...) +} + +// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either +// a slice or a channel with len == 0. +// +// a.Empty(obj) +func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Empty(a.t, object, msgAndArgs...) +} + +// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either +// a slice or a channel with len == 0. +// +// a.Emptyf(obj, "error message %s", "formatted") +func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Emptyf(a.t, object, msg, args...) +} + +// Equal asserts that two objects are equal. +// +// a.Equal(123, 123) +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). Function equality +// cannot be determined and will always fail. +func (a *Assertions) Equal(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Equal(a.t, expected, actual, msgAndArgs...) +} + +// EqualError asserts that a function returned an error (i.e. not `nil`) +// and that it is equal to the provided error. +// +// actualObj, err := SomeFunction() +// a.EqualError(err, expectedErrorString) +func (a *Assertions) EqualError(theError error, errString string, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return EqualError(a.t, theError, errString, msgAndArgs...) +} + +// EqualErrorf asserts that a function returned an error (i.e. not `nil`) +// and that it is equal to the provided error. +// +// actualObj, err := SomeFunction() +// a.EqualErrorf(err, expectedErrorString, "error message %s", "formatted") +func (a *Assertions) EqualErrorf(theError error, errString string, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return EqualErrorf(a.t, theError, errString, msg, args...) +} + +// EqualExportedValues asserts that the types of two objects are equal and their public +// fields are also equal. This is useful for comparing structs that have private fields +// that could potentially differ. +// +// type S struct { +// Exported int +// notExported int +// } +// a.EqualExportedValues(S{1, 2}, S{1, 3}) => true +// a.EqualExportedValues(S{1, 2}, S{2, 3}) => false +func (a *Assertions) EqualExportedValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return EqualExportedValues(a.t, expected, actual, msgAndArgs...) +} + +// EqualExportedValuesf asserts that the types of two objects are equal and their public +// fields are also equal. This is useful for comparing structs that have private fields +// that could potentially differ. +// +// type S struct { +// Exported int +// notExported int +// } +// a.EqualExportedValuesf(S{1, 2}, S{1, 3}, "error message %s", "formatted") => true +// a.EqualExportedValuesf(S{1, 2}, S{2, 3}, "error message %s", "formatted") => false +func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return EqualExportedValuesf(a.t, expected, actual, msg, args...) +} + +// EqualValues asserts that two objects are equal or convertible to the larger +// type and equal. +// +// a.EqualValues(uint32(123), int32(123)) +func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return EqualValues(a.t, expected, actual, msgAndArgs...) +} + +// EqualValuesf asserts that two objects are equal or convertible to the larger +// type and equal. +// +// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted") +func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return EqualValuesf(a.t, expected, actual, msg, args...) +} + +// Equalf asserts that two objects are equal. +// +// a.Equalf(123, 123, "error message %s", "formatted") +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). Function equality +// cannot be determined and will always fail. +func (a *Assertions) Equalf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Equalf(a.t, expected, actual, msg, args...) +} + +// Error asserts that a function returned an error (i.e. not `nil`). +// +// actualObj, err := SomeFunction() +// if a.Error(err) { +// assert.Equal(t, expectedError, err) +// } +func (a *Assertions) Error(err error, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Error(a.t, err, msgAndArgs...) +} + +// ErrorAs asserts that at least one of the errors in err's chain matches target, and if so, sets target to that error value. +// This is a wrapper for errors.As. +func (a *Assertions) ErrorAs(err error, target interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return ErrorAs(a.t, err, target, msgAndArgs...) +} + +// ErrorAsf asserts that at least one of the errors in err's chain matches target, and if so, sets target to that error value. +// This is a wrapper for errors.As. +func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return ErrorAsf(a.t, err, target, msg, args...) +} + +// ErrorContains asserts that a function returned an error (i.e. not `nil`) +// and that the error contains the specified substring. +// +// actualObj, err := SomeFunction() +// a.ErrorContains(err, expectedErrorSubString) +func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return ErrorContains(a.t, theError, contains, msgAndArgs...) +} + +// ErrorContainsf asserts that a function returned an error (i.e. not `nil`) +// and that the error contains the specified substring. +// +// actualObj, err := SomeFunction() +// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted") +func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return ErrorContainsf(a.t, theError, contains, msg, args...) +} + +// ErrorIs asserts that at least one of the errors in err's chain matches target. +// This is a wrapper for errors.Is. +func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return ErrorIs(a.t, err, target, msgAndArgs...) +} + +// ErrorIsf asserts that at least one of the errors in err's chain matches target. +// This is a wrapper for errors.Is. +func (a *Assertions) ErrorIsf(err error, target error, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return ErrorIsf(a.t, err, target, msg, args...) +} + +// Errorf asserts that a function returned an error (i.e. not `nil`). +// +// actualObj, err := SomeFunction() +// if a.Errorf(err, "error message %s", "formatted") { +// assert.Equal(t, expectedErrorf, err) +// } +func (a *Assertions) Errorf(err error, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Errorf(a.t, err, msg, args...) +} + +// Eventually asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. +// +// a.Eventually(func() bool { return true; }, time.Second, 10*time.Millisecond) +func (a *Assertions) Eventually(condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Eventually(a.t, condition, waitFor, tick, msgAndArgs...) +} + +// EventuallyWithT asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. In contrast to Eventually, +// it supplies a CollectT to the condition function, so that the condition +// function can use the CollectT to call other assertions. +// The condition is considered "met" if no errors are raised in a tick. +// The supplied CollectT collects all errors from one tick (if there are any). +// If the condition is not met before waitFor, the collected errors of +// the last tick are copied to t. +// +// externalValue := false +// go func() { +// time.Sleep(8*time.Second) +// externalValue = true +// }() +// a.EventuallyWithT(func(c *assert.CollectT) { +// // add assertions as needed; any assertion failure will fail the current tick +// assert.True(c, externalValue, "expected 'externalValue' to be true") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") +func (a *Assertions) EventuallyWithT(condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return EventuallyWithT(a.t, condition, waitFor, tick, msgAndArgs...) +} + +// EventuallyWithTf asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. In contrast to Eventually, +// it supplies a CollectT to the condition function, so that the condition +// function can use the CollectT to call other assertions. +// The condition is considered "met" if no errors are raised in a tick. +// The supplied CollectT collects all errors from one tick (if there are any). +// If the condition is not met before waitFor, the collected errors of +// the last tick are copied to t. +// +// externalValue := false +// go func() { +// time.Sleep(8*time.Second) +// externalValue = true +// }() +// a.EventuallyWithTf(func(c *assert.CollectT, "error message %s", "formatted") { +// // add assertions as needed; any assertion failure will fail the current tick +// assert.True(c, externalValue, "expected 'externalValue' to be true") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") +func (a *Assertions) EventuallyWithTf(condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return EventuallyWithTf(a.t, condition, waitFor, tick, msg, args...) +} + +// Eventuallyf asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. +// +// a.Eventuallyf(func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") +func (a *Assertions) Eventuallyf(condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Eventuallyf(a.t, condition, waitFor, tick, msg, args...) +} + +// Exactly asserts that two objects are equal in value and type. +// +// a.Exactly(int32(123), int64(123)) +func (a *Assertions) Exactly(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Exactly(a.t, expected, actual, msgAndArgs...) +} + +// Exactlyf asserts that two objects are equal in value and type. +// +// a.Exactlyf(int32(123), int64(123), "error message %s", "formatted") +func (a *Assertions) Exactlyf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Exactlyf(a.t, expected, actual, msg, args...) +} + +// Fail reports a failure through +func (a *Assertions) Fail(failureMessage string, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Fail(a.t, failureMessage, msgAndArgs...) +} + +// FailNow fails test +func (a *Assertions) FailNow(failureMessage string, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return FailNow(a.t, failureMessage, msgAndArgs...) +} + +// FailNowf fails test +func (a *Assertions) FailNowf(failureMessage string, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return FailNowf(a.t, failureMessage, msg, args...) +} + +// Failf reports a failure through +func (a *Assertions) Failf(failureMessage string, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Failf(a.t, failureMessage, msg, args...) +} + +// False asserts that the specified value is false. +// +// a.False(myBool) +func (a *Assertions) False(value bool, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return False(a.t, value, msgAndArgs...) +} + +// Falsef asserts that the specified value is false. +// +// a.Falsef(myBool, "error message %s", "formatted") +func (a *Assertions) Falsef(value bool, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Falsef(a.t, value, msg, args...) +} + +// FileExists checks whether a file exists in the given path. It also fails if +// the path points to a directory or there is an error when trying to check the file. +func (a *Assertions) FileExists(path string, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return FileExists(a.t, path, msgAndArgs...) +} + +// FileExistsf checks whether a file exists in the given path. It also fails if +// the path points to a directory or there is an error when trying to check the file. +func (a *Assertions) FileExistsf(path string, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return FileExistsf(a.t, path, msg, args...) +} + +// Greater asserts that the first element is greater than the second +// +// a.Greater(2, 1) +// a.Greater(float64(2), float64(1)) +// a.Greater("b", "a") +func (a *Assertions) Greater(e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Greater(a.t, e1, e2, msgAndArgs...) +} + +// GreaterOrEqual asserts that the first element is greater than or equal to the second +// +// a.GreaterOrEqual(2, 1) +// a.GreaterOrEqual(2, 2) +// a.GreaterOrEqual("b", "a") +// a.GreaterOrEqual("b", "b") +func (a *Assertions) GreaterOrEqual(e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return GreaterOrEqual(a.t, e1, e2, msgAndArgs...) +} + +// GreaterOrEqualf asserts that the first element is greater than or equal to the second +// +// a.GreaterOrEqualf(2, 1, "error message %s", "formatted") +// a.GreaterOrEqualf(2, 2, "error message %s", "formatted") +// a.GreaterOrEqualf("b", "a", "error message %s", "formatted") +// a.GreaterOrEqualf("b", "b", "error message %s", "formatted") +func (a *Assertions) GreaterOrEqualf(e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return GreaterOrEqualf(a.t, e1, e2, msg, args...) +} + +// Greaterf asserts that the first element is greater than the second +// +// a.Greaterf(2, 1, "error message %s", "formatted") +// a.Greaterf(float64(2), float64(1), "error message %s", "formatted") +// a.Greaterf("b", "a", "error message %s", "formatted") +func (a *Assertions) Greaterf(e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Greaterf(a.t, e1, e2, msg, args...) +} + +// HTTPBodyContains asserts that a specified handler returns a +// body that contains a string. +// +// a.HTTPBodyContains(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPBodyContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return HTTPBodyContains(a.t, handler, method, url, values, str, msgAndArgs...) +} + +// HTTPBodyContainsf asserts that a specified handler returns a +// body that contains a string. +// +// a.HTTPBodyContainsf(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPBodyContainsf(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return HTTPBodyContainsf(a.t, handler, method, url, values, str, msg, args...) +} + +// HTTPBodyNotContains asserts that a specified handler returns a +// body that does not contain a string. +// +// a.HTTPBodyNotContains(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPBodyNotContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return HTTPBodyNotContains(a.t, handler, method, url, values, str, msgAndArgs...) +} + +// HTTPBodyNotContainsf asserts that a specified handler returns a +// body that does not contain a string. +// +// a.HTTPBodyNotContainsf(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPBodyNotContainsf(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return HTTPBodyNotContainsf(a.t, handler, method, url, values, str, msg, args...) +} + +// HTTPError asserts that a specified handler returns an error status code. +// +// a.HTTPError(myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPError(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return HTTPError(a.t, handler, method, url, values, msgAndArgs...) +} + +// HTTPErrorf asserts that a specified handler returns an error status code. +// +// a.HTTPErrorf(myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPErrorf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return HTTPErrorf(a.t, handler, method, url, values, msg, args...) +} + +// HTTPRedirect asserts that a specified handler returns a redirect status code. +// +// a.HTTPRedirect(myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPRedirect(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return HTTPRedirect(a.t, handler, method, url, values, msgAndArgs...) +} + +// HTTPRedirectf asserts that a specified handler returns a redirect status code. +// +// a.HTTPRedirectf(myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPRedirectf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return HTTPRedirectf(a.t, handler, method, url, values, msg, args...) +} + +// HTTPStatusCode asserts that a specified handler returns a specified status code. +// +// a.HTTPStatusCode(myHandler, "GET", "/notImplemented", nil, 501) +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPStatusCode(handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return HTTPStatusCode(a.t, handler, method, url, values, statuscode, msgAndArgs...) +} + +// HTTPStatusCodef asserts that a specified handler returns a specified status code. +// +// a.HTTPStatusCodef(myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPStatusCodef(handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return HTTPStatusCodef(a.t, handler, method, url, values, statuscode, msg, args...) +} + +// HTTPSuccess asserts that a specified handler returns a success status code. +// +// a.HTTPSuccess(myHandler, "POST", "http://www.google.com", nil) +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPSuccess(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return HTTPSuccess(a.t, handler, method, url, values, msgAndArgs...) +} + +// HTTPSuccessf asserts that a specified handler returns a success status code. +// +// a.HTTPSuccessf(myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPSuccessf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return HTTPSuccessf(a.t, handler, method, url, values, msg, args...) +} + +// Implements asserts that an object is implemented by the specified interface. +// +// a.Implements((*MyInterface)(nil), new(MyObject)) +func (a *Assertions) Implements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Implements(a.t, interfaceObject, object, msgAndArgs...) +} + +// Implementsf asserts that an object is implemented by the specified interface. +// +// a.Implementsf((*MyInterface)(nil), new(MyObject), "error message %s", "formatted") +func (a *Assertions) Implementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Implementsf(a.t, interfaceObject, object, msg, args...) +} + +// InDelta asserts that the two numerals are within delta of each other. +// +// a.InDelta(math.Pi, 22/7.0, 0.01) +func (a *Assertions) InDelta(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return InDelta(a.t, expected, actual, delta, msgAndArgs...) +} + +// InDeltaMapValues is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys. +func (a *Assertions) InDeltaMapValues(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return InDeltaMapValues(a.t, expected, actual, delta, msgAndArgs...) +} + +// InDeltaMapValuesf is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys. +func (a *Assertions) InDeltaMapValuesf(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return InDeltaMapValuesf(a.t, expected, actual, delta, msg, args...) +} + +// InDeltaSlice is the same as InDelta, except it compares two slices. +func (a *Assertions) InDeltaSlice(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return InDeltaSlice(a.t, expected, actual, delta, msgAndArgs...) +} + +// InDeltaSlicef is the same as InDelta, except it compares two slices. +func (a *Assertions) InDeltaSlicef(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return InDeltaSlicef(a.t, expected, actual, delta, msg, args...) +} + +// InDeltaf asserts that the two numerals are within delta of each other. +// +// a.InDeltaf(math.Pi, 22/7.0, 0.01, "error message %s", "formatted") +func (a *Assertions) InDeltaf(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return InDeltaf(a.t, expected, actual, delta, msg, args...) +} + +// InEpsilon asserts that expected and actual have a relative error less than epsilon +func (a *Assertions) InEpsilon(expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return InEpsilon(a.t, expected, actual, epsilon, msgAndArgs...) +} + +// InEpsilonSlice is the same as InEpsilon, except it compares each value from two slices. +func (a *Assertions) InEpsilonSlice(expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return InEpsilonSlice(a.t, expected, actual, epsilon, msgAndArgs...) +} + +// InEpsilonSlicef is the same as InEpsilon, except it compares each value from two slices. +func (a *Assertions) InEpsilonSlicef(expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return InEpsilonSlicef(a.t, expected, actual, epsilon, msg, args...) +} + +// InEpsilonf asserts that expected and actual have a relative error less than epsilon +func (a *Assertions) InEpsilonf(expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return InEpsilonf(a.t, expected, actual, epsilon, msg, args...) +} + +// IsDecreasing asserts that the collection is decreasing +// +// a.IsDecreasing([]int{2, 1, 0}) +// a.IsDecreasing([]float{2, 1}) +// a.IsDecreasing([]string{"b", "a"}) +func (a *Assertions) IsDecreasing(object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return IsDecreasing(a.t, object, msgAndArgs...) +} + +// IsDecreasingf asserts that the collection is decreasing +// +// a.IsDecreasingf([]int{2, 1, 0}, "error message %s", "formatted") +// a.IsDecreasingf([]float{2, 1}, "error message %s", "formatted") +// a.IsDecreasingf([]string{"b", "a"}, "error message %s", "formatted") +func (a *Assertions) IsDecreasingf(object interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return IsDecreasingf(a.t, object, msg, args...) +} + +// IsIncreasing asserts that the collection is increasing +// +// a.IsIncreasing([]int{1, 2, 3}) +// a.IsIncreasing([]float{1, 2}) +// a.IsIncreasing([]string{"a", "b"}) +func (a *Assertions) IsIncreasing(object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return IsIncreasing(a.t, object, msgAndArgs...) +} + +// IsIncreasingf asserts that the collection is increasing +// +// a.IsIncreasingf([]int{1, 2, 3}, "error message %s", "formatted") +// a.IsIncreasingf([]float{1, 2}, "error message %s", "formatted") +// a.IsIncreasingf([]string{"a", "b"}, "error message %s", "formatted") +func (a *Assertions) IsIncreasingf(object interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return IsIncreasingf(a.t, object, msg, args...) +} + +// IsNonDecreasing asserts that the collection is not decreasing +// +// a.IsNonDecreasing([]int{1, 1, 2}) +// a.IsNonDecreasing([]float{1, 2}) +// a.IsNonDecreasing([]string{"a", "b"}) +func (a *Assertions) IsNonDecreasing(object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return IsNonDecreasing(a.t, object, msgAndArgs...) +} + +// IsNonDecreasingf asserts that the collection is not decreasing +// +// a.IsNonDecreasingf([]int{1, 1, 2}, "error message %s", "formatted") +// a.IsNonDecreasingf([]float{1, 2}, "error message %s", "formatted") +// a.IsNonDecreasingf([]string{"a", "b"}, "error message %s", "formatted") +func (a *Assertions) IsNonDecreasingf(object interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return IsNonDecreasingf(a.t, object, msg, args...) +} + +// IsNonIncreasing asserts that the collection is not increasing +// +// a.IsNonIncreasing([]int{2, 1, 1}) +// a.IsNonIncreasing([]float{2, 1}) +// a.IsNonIncreasing([]string{"b", "a"}) +func (a *Assertions) IsNonIncreasing(object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return IsNonIncreasing(a.t, object, msgAndArgs...) +} + +// IsNonIncreasingf asserts that the collection is not increasing +// +// a.IsNonIncreasingf([]int{2, 1, 1}, "error message %s", "formatted") +// a.IsNonIncreasingf([]float{2, 1}, "error message %s", "formatted") +// a.IsNonIncreasingf([]string{"b", "a"}, "error message %s", "formatted") +func (a *Assertions) IsNonIncreasingf(object interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return IsNonIncreasingf(a.t, object, msg, args...) +} + +// IsType asserts that the specified objects are of the same type. +func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return IsType(a.t, expectedType, object, msgAndArgs...) +} + +// IsTypef asserts that the specified objects are of the same type. +func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return IsTypef(a.t, expectedType, object, msg, args...) +} + +// JSONEq asserts that two JSON strings are equivalent. +// +// a.JSONEq(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) +func (a *Assertions) JSONEq(expected string, actual string, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return JSONEq(a.t, expected, actual, msgAndArgs...) +} + +// JSONEqf asserts that two JSON strings are equivalent. +// +// a.JSONEqf(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted") +func (a *Assertions) JSONEqf(expected string, actual string, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return JSONEqf(a.t, expected, actual, msg, args...) +} + +// Len asserts that the specified object has specific length. +// Len also fails if the object has a type that len() not accept. +// +// a.Len(mySlice, 3) +func (a *Assertions) Len(object interface{}, length int, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Len(a.t, object, length, msgAndArgs...) +} + +// Lenf asserts that the specified object has specific length. +// Lenf also fails if the object has a type that len() not accept. +// +// a.Lenf(mySlice, 3, "error message %s", "formatted") +func (a *Assertions) Lenf(object interface{}, length int, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Lenf(a.t, object, length, msg, args...) +} + +// Less asserts that the first element is less than the second +// +// a.Less(1, 2) +// a.Less(float64(1), float64(2)) +// a.Less("a", "b") +func (a *Assertions) Less(e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Less(a.t, e1, e2, msgAndArgs...) +} + +// LessOrEqual asserts that the first element is less than or equal to the second +// +// a.LessOrEqual(1, 2) +// a.LessOrEqual(2, 2) +// a.LessOrEqual("a", "b") +// a.LessOrEqual("b", "b") +func (a *Assertions) LessOrEqual(e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return LessOrEqual(a.t, e1, e2, msgAndArgs...) +} + +// LessOrEqualf asserts that the first element is less than or equal to the second +// +// a.LessOrEqualf(1, 2, "error message %s", "formatted") +// a.LessOrEqualf(2, 2, "error message %s", "formatted") +// a.LessOrEqualf("a", "b", "error message %s", "formatted") +// a.LessOrEqualf("b", "b", "error message %s", "formatted") +func (a *Assertions) LessOrEqualf(e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return LessOrEqualf(a.t, e1, e2, msg, args...) +} + +// Lessf asserts that the first element is less than the second +// +// a.Lessf(1, 2, "error message %s", "formatted") +// a.Lessf(float64(1), float64(2), "error message %s", "formatted") +// a.Lessf("a", "b", "error message %s", "formatted") +func (a *Assertions) Lessf(e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Lessf(a.t, e1, e2, msg, args...) +} + +// Negative asserts that the specified element is negative +// +// a.Negative(-1) +// a.Negative(-1.23) +func (a *Assertions) Negative(e interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Negative(a.t, e, msgAndArgs...) +} + +// Negativef asserts that the specified element is negative +// +// a.Negativef(-1, "error message %s", "formatted") +// a.Negativef(-1.23, "error message %s", "formatted") +func (a *Assertions) Negativef(e interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Negativef(a.t, e, msg, args...) +} + +// Never asserts that the given condition doesn't satisfy in waitFor time, +// periodically checking the target function each tick. +// +// a.Never(func() bool { return false; }, time.Second, 10*time.Millisecond) +func (a *Assertions) Never(condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Never(a.t, condition, waitFor, tick, msgAndArgs...) +} + +// Neverf asserts that the given condition doesn't satisfy in waitFor time, +// periodically checking the target function each tick. +// +// a.Neverf(func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") +func (a *Assertions) Neverf(condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Neverf(a.t, condition, waitFor, tick, msg, args...) +} + +// Nil asserts that the specified object is nil. +// +// a.Nil(err) +func (a *Assertions) Nil(object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Nil(a.t, object, msgAndArgs...) +} + +// Nilf asserts that the specified object is nil. +// +// a.Nilf(err, "error message %s", "formatted") +func (a *Assertions) Nilf(object interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Nilf(a.t, object, msg, args...) +} + +// NoDirExists checks whether a directory does not exist in the given path. +// It fails if the path points to an existing _directory_ only. +func (a *Assertions) NoDirExists(path string, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NoDirExists(a.t, path, msgAndArgs...) +} + +// NoDirExistsf checks whether a directory does not exist in the given path. +// It fails if the path points to an existing _directory_ only. +func (a *Assertions) NoDirExistsf(path string, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NoDirExistsf(a.t, path, msg, args...) +} + +// NoError asserts that a function returned no error (i.e. `nil`). +// +// actualObj, err := SomeFunction() +// if a.NoError(err) { +// assert.Equal(t, expectedObj, actualObj) +// } +func (a *Assertions) NoError(err error, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NoError(a.t, err, msgAndArgs...) +} + +// NoErrorf asserts that a function returned no error (i.e. `nil`). +// +// actualObj, err := SomeFunction() +// if a.NoErrorf(err, "error message %s", "formatted") { +// assert.Equal(t, expectedObj, actualObj) +// } +func (a *Assertions) NoErrorf(err error, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NoErrorf(a.t, err, msg, args...) +} + +// NoFileExists checks whether a file does not exist in a given path. It fails +// if the path points to an existing _file_ only. +func (a *Assertions) NoFileExists(path string, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NoFileExists(a.t, path, msgAndArgs...) +} + +// NoFileExistsf checks whether a file does not exist in a given path. It fails +// if the path points to an existing _file_ only. +func (a *Assertions) NoFileExistsf(path string, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NoFileExistsf(a.t, path, msg, args...) +} + +// NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the +// specified substring or element. +// +// a.NotContains("Hello World", "Earth") +// a.NotContains(["Hello", "World"], "Earth") +// a.NotContains({"Hello": "World"}, "Earth") +func (a *Assertions) NotContains(s interface{}, contains interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotContains(a.t, s, contains, msgAndArgs...) +} + +// NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the +// specified substring or element. +// +// a.NotContainsf("Hello World", "Earth", "error message %s", "formatted") +// a.NotContainsf(["Hello", "World"], "Earth", "error message %s", "formatted") +// a.NotContainsf({"Hello": "World"}, "Earth", "error message %s", "formatted") +func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotContainsf(a.t, s, contains, msg, args...) +} + +// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// a.NotElementsMatch([1, 1, 2, 3], [1, 1, 2, 3]) -> false +// +// a.NotElementsMatch([1, 1, 2, 3], [1, 2, 3]) -> true +// +// a.NotElementsMatch([1, 2, 3], [1, 2, 4]) -> true +func (a *Assertions) NotElementsMatch(listA interface{}, listB interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotElementsMatch(a.t, listA, listB, msgAndArgs...) +} + +// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// a.NotElementsMatchf([1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false +// +// a.NotElementsMatchf([1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true +// +// a.NotElementsMatchf([1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true +func (a *Assertions) NotElementsMatchf(listA interface{}, listB interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotElementsMatchf(a.t, listA, listB, msg, args...) +} + +// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either +// a slice or a channel with len == 0. +// +// if a.NotEmpty(obj) { +// assert.Equal(t, "two", obj[1]) +// } +func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotEmpty(a.t, object, msgAndArgs...) +} + +// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either +// a slice or a channel with len == 0. +// +// if a.NotEmptyf(obj, "error message %s", "formatted") { +// assert.Equal(t, "two", obj[1]) +// } +func (a *Assertions) NotEmptyf(object interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotEmptyf(a.t, object, msg, args...) +} + +// NotEqual asserts that the specified values are NOT equal. +// +// a.NotEqual(obj1, obj2) +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). +func (a *Assertions) NotEqual(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotEqual(a.t, expected, actual, msgAndArgs...) +} + +// NotEqualValues asserts that two objects are not equal even when converted to the same type +// +// a.NotEqualValues(obj1, obj2) +func (a *Assertions) NotEqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotEqualValues(a.t, expected, actual, msgAndArgs...) +} + +// NotEqualValuesf asserts that two objects are not equal even when converted to the same type +// +// a.NotEqualValuesf(obj1, obj2, "error message %s", "formatted") +func (a *Assertions) NotEqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotEqualValuesf(a.t, expected, actual, msg, args...) +} + +// NotEqualf asserts that the specified values are NOT equal. +// +// a.NotEqualf(obj1, obj2, "error message %s", "formatted") +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). +func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotEqualf(a.t, expected, actual, msg, args...) +} + +// NotErrorAs asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func (a *Assertions) NotErrorAs(err error, target interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotErrorAs(a.t, err, target, msgAndArgs...) +} + +// NotErrorAsf asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func (a *Assertions) NotErrorAsf(err error, target interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotErrorAsf(a.t, err, target, msg, args...) +} + +// NotErrorIs asserts that none of the errors in err's chain matches target. +// This is a wrapper for errors.Is. +func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotErrorIs(a.t, err, target, msgAndArgs...) +} + +// NotErrorIsf asserts that none of the errors in err's chain matches target. +// This is a wrapper for errors.Is. +func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotErrorIsf(a.t, err, target, msg, args...) +} + +// NotImplements asserts that an object does not implement the specified interface. +// +// a.NotImplements((*MyInterface)(nil), new(MyObject)) +func (a *Assertions) NotImplements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotImplements(a.t, interfaceObject, object, msgAndArgs...) +} + +// NotImplementsf asserts that an object does not implement the specified interface. +// +// a.NotImplementsf((*MyInterface)(nil), new(MyObject), "error message %s", "formatted") +func (a *Assertions) NotImplementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotImplementsf(a.t, interfaceObject, object, msg, args...) +} + +// NotNil asserts that the specified object is not nil. +// +// a.NotNil(err) +func (a *Assertions) NotNil(object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotNil(a.t, object, msgAndArgs...) +} + +// NotNilf asserts that the specified object is not nil. +// +// a.NotNilf(err, "error message %s", "formatted") +func (a *Assertions) NotNilf(object interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotNilf(a.t, object, msg, args...) +} + +// NotPanics asserts that the code inside the specified PanicTestFunc does NOT panic. +// +// a.NotPanics(func(){ RemainCalm() }) +func (a *Assertions) NotPanics(f PanicTestFunc, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotPanics(a.t, f, msgAndArgs...) +} + +// NotPanicsf asserts that the code inside the specified PanicTestFunc does NOT panic. +// +// a.NotPanicsf(func(){ RemainCalm() }, "error message %s", "formatted") +func (a *Assertions) NotPanicsf(f PanicTestFunc, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotPanicsf(a.t, f, msg, args...) +} + +// NotRegexp asserts that a specified regexp does not match a string. +// +// a.NotRegexp(regexp.MustCompile("starts"), "it's starting") +// a.NotRegexp("^start", "it's not starting") +func (a *Assertions) NotRegexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotRegexp(a.t, rx, str, msgAndArgs...) +} + +// NotRegexpf asserts that a specified regexp does not match a string. +// +// a.NotRegexpf(regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted") +// a.NotRegexpf("^start", "it's not starting", "error message %s", "formatted") +func (a *Assertions) NotRegexpf(rx interface{}, str interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotRegexpf(a.t, rx, str, msg, args...) +} + +// NotSame asserts that two pointers do not reference the same object. +// +// a.NotSame(ptr1, ptr2) +// +// Both arguments must be pointer variables. Pointer variable sameness is +// determined based on the equality of both type and value. +func (a *Assertions) NotSame(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotSame(a.t, expected, actual, msgAndArgs...) +} + +// NotSamef asserts that two pointers do not reference the same object. +// +// a.NotSamef(ptr1, ptr2, "error message %s", "formatted") +// +// Both arguments must be pointer variables. Pointer variable sameness is +// determined based on the equality of both type and value. +func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotSamef(a.t, expected, actual, msg, args...) +} + +// NotSubset asserts that the specified list(array, slice...) or map does NOT +// contain all elements given in the specified subset list(array, slice...) or +// map. +// +// a.NotSubset([1, 3, 4], [1, 2]) +// a.NotSubset({"x": 1, "y": 2}, {"z": 3}) +func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotSubset(a.t, list, subset, msgAndArgs...) +} + +// NotSubsetf asserts that the specified list(array, slice...) or map does NOT +// contain all elements given in the specified subset list(array, slice...) or +// map. +// +// a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted") +// a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted") +func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotSubsetf(a.t, list, subset, msg, args...) +} + +// NotZero asserts that i is not the zero value for its type. +func (a *Assertions) NotZero(i interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotZero(a.t, i, msgAndArgs...) +} + +// NotZerof asserts that i is not the zero value for its type. +func (a *Assertions) NotZerof(i interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotZerof(a.t, i, msg, args...) +} + +// Panics asserts that the code inside the specified PanicTestFunc panics. +// +// a.Panics(func(){ GoCrazy() }) +func (a *Assertions) Panics(f PanicTestFunc, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Panics(a.t, f, msgAndArgs...) +} + +// PanicsWithError asserts that the code inside the specified PanicTestFunc +// panics, and that the recovered panic value is an error that satisfies the +// EqualError comparison. +// +// a.PanicsWithError("crazy error", func(){ GoCrazy() }) +func (a *Assertions) PanicsWithError(errString string, f PanicTestFunc, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return PanicsWithError(a.t, errString, f, msgAndArgs...) +} + +// PanicsWithErrorf asserts that the code inside the specified PanicTestFunc +// panics, and that the recovered panic value is an error that satisfies the +// EqualError comparison. +// +// a.PanicsWithErrorf("crazy error", func(){ GoCrazy() }, "error message %s", "formatted") +func (a *Assertions) PanicsWithErrorf(errString string, f PanicTestFunc, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return PanicsWithErrorf(a.t, errString, f, msg, args...) +} + +// PanicsWithValue asserts that the code inside the specified PanicTestFunc panics, and that +// the recovered panic value equals the expected panic value. +// +// a.PanicsWithValue("crazy error", func(){ GoCrazy() }) +func (a *Assertions) PanicsWithValue(expected interface{}, f PanicTestFunc, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return PanicsWithValue(a.t, expected, f, msgAndArgs...) +} + +// PanicsWithValuef asserts that the code inside the specified PanicTestFunc panics, and that +// the recovered panic value equals the expected panic value. +// +// a.PanicsWithValuef("crazy error", func(){ GoCrazy() }, "error message %s", "formatted") +func (a *Assertions) PanicsWithValuef(expected interface{}, f PanicTestFunc, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return PanicsWithValuef(a.t, expected, f, msg, args...) +} + +// Panicsf asserts that the code inside the specified PanicTestFunc panics. +// +// a.Panicsf(func(){ GoCrazy() }, "error message %s", "formatted") +func (a *Assertions) Panicsf(f PanicTestFunc, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Panicsf(a.t, f, msg, args...) +} + +// Positive asserts that the specified element is positive +// +// a.Positive(1) +// a.Positive(1.23) +func (a *Assertions) Positive(e interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Positive(a.t, e, msgAndArgs...) +} + +// Positivef asserts that the specified element is positive +// +// a.Positivef(1, "error message %s", "formatted") +// a.Positivef(1.23, "error message %s", "formatted") +func (a *Assertions) Positivef(e interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Positivef(a.t, e, msg, args...) +} + +// Regexp asserts that a specified regexp matches a string. +// +// a.Regexp(regexp.MustCompile("start"), "it's starting") +// a.Regexp("start...$", "it's not starting") +func (a *Assertions) Regexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Regexp(a.t, rx, str, msgAndArgs...) +} + +// Regexpf asserts that a specified regexp matches a string. +// +// a.Regexpf(regexp.MustCompile("start"), "it's starting", "error message %s", "formatted") +// a.Regexpf("start...$", "it's not starting", "error message %s", "formatted") +func (a *Assertions) Regexpf(rx interface{}, str interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Regexpf(a.t, rx, str, msg, args...) +} + +// Same asserts that two pointers reference the same object. +// +// a.Same(ptr1, ptr2) +// +// Both arguments must be pointer variables. Pointer variable sameness is +// determined based on the equality of both type and value. +func (a *Assertions) Same(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Same(a.t, expected, actual, msgAndArgs...) +} + +// Samef asserts that two pointers reference the same object. +// +// a.Samef(ptr1, ptr2, "error message %s", "formatted") +// +// Both arguments must be pointer variables. Pointer variable sameness is +// determined based on the equality of both type and value. +func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Samef(a.t, expected, actual, msg, args...) +} + +// Subset asserts that the specified list(array, slice...) or map contains all +// elements given in the specified subset list(array, slice...) or map. +// +// a.Subset([1, 2, 3], [1, 2]) +// a.Subset({"x": 1, "y": 2}, {"x": 1}) +func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Subset(a.t, list, subset, msgAndArgs...) +} + +// Subsetf asserts that the specified list(array, slice...) or map contains all +// elements given in the specified subset list(array, slice...) or map. +// +// a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted") +// a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted") +func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Subsetf(a.t, list, subset, msg, args...) +} + +// True asserts that the specified value is true. +// +// a.True(myBool) +func (a *Assertions) True(value bool, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return True(a.t, value, msgAndArgs...) +} + +// Truef asserts that the specified value is true. +// +// a.Truef(myBool, "error message %s", "formatted") +func (a *Assertions) Truef(value bool, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Truef(a.t, value, msg, args...) +} + +// WithinDuration asserts that the two times are within duration delta of each other. +// +// a.WithinDuration(time.Now(), time.Now(), 10*time.Second) +func (a *Assertions) WithinDuration(expected time.Time, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return WithinDuration(a.t, expected, actual, delta, msgAndArgs...) +} + +// WithinDurationf asserts that the two times are within duration delta of each other. +// +// a.WithinDurationf(time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") +func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return WithinDurationf(a.t, expected, actual, delta, msg, args...) +} + +// WithinRange asserts that a time is within a time range (inclusive). +// +// a.WithinRange(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) +func (a *Assertions) WithinRange(actual time.Time, start time.Time, end time.Time, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return WithinRange(a.t, actual, start, end, msgAndArgs...) +} + +// WithinRangef asserts that a time is within a time range (inclusive). +// +// a.WithinRangef(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted") +func (a *Assertions) WithinRangef(actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return WithinRangef(a.t, actual, start, end, msg, args...) +} + +// YAMLEq asserts that two YAML strings are equivalent. +func (a *Assertions) YAMLEq(expected string, actual string, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return YAMLEq(a.t, expected, actual, msgAndArgs...) +} + +// YAMLEqf asserts that two YAML strings are equivalent. +func (a *Assertions) YAMLEqf(expected string, actual string, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return YAMLEqf(a.t, expected, actual, msg, args...) +} + +// Zero asserts that i is the zero value for its type. +func (a *Assertions) Zero(i interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Zero(a.t, i, msgAndArgs...) +} + +// Zerof asserts that i is the zero value for its type. +func (a *Assertions) Zerof(i interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return Zerof(a.t, i, msg, args...) +} diff --git a/vendor/github.com/stretchr/testify/assert/assertion_order.go b/vendor/github.com/stretchr/testify/assert/assertion_order.go new file mode 100644 index 0000000..1d2f718 --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/assertion_order.go @@ -0,0 +1,81 @@ +package assert + +import ( + "fmt" + "reflect" +) + +// isOrdered checks that collection contains orderable elements. +func isOrdered(t TestingT, object interface{}, allowedComparesResults []compareResult, failMessage string, msgAndArgs ...interface{}) bool { + objKind := reflect.TypeOf(object).Kind() + if objKind != reflect.Slice && objKind != reflect.Array { + return false + } + + objValue := reflect.ValueOf(object) + objLen := objValue.Len() + + if objLen <= 1 { + return true + } + + value := objValue.Index(0) + valueInterface := value.Interface() + firstValueKind := value.Kind() + + for i := 1; i < objLen; i++ { + prevValue := value + prevValueInterface := valueInterface + + value = objValue.Index(i) + valueInterface = value.Interface() + + compareResult, isComparable := compare(prevValueInterface, valueInterface, firstValueKind) + + if !isComparable { + return Fail(t, fmt.Sprintf("Can not compare type \"%s\" and \"%s\"", reflect.TypeOf(value), reflect.TypeOf(prevValue)), msgAndArgs...) + } + + if !containsValue(allowedComparesResults, compareResult) { + return Fail(t, fmt.Sprintf(failMessage, prevValue, value), msgAndArgs...) + } + } + + return true +} + +// IsIncreasing asserts that the collection is increasing +// +// assert.IsIncreasing(t, []int{1, 2, 3}) +// assert.IsIncreasing(t, []float{1, 2}) +// assert.IsIncreasing(t, []string{"a", "b"}) +func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + return isOrdered(t, object, []compareResult{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...) +} + +// IsNonIncreasing asserts that the collection is not increasing +// +// assert.IsNonIncreasing(t, []int{2, 1, 1}) +// assert.IsNonIncreasing(t, []float{2, 1}) +// assert.IsNonIncreasing(t, []string{"b", "a"}) +func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + return isOrdered(t, object, []compareResult{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...) +} + +// IsDecreasing asserts that the collection is decreasing +// +// assert.IsDecreasing(t, []int{2, 1, 0}) +// assert.IsDecreasing(t, []float{2, 1}) +// assert.IsDecreasing(t, []string{"b", "a"}) +func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + return isOrdered(t, object, []compareResult{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...) +} + +// IsNonDecreasing asserts that the collection is not decreasing +// +// assert.IsNonDecreasing(t, []int{1, 1, 2}) +// assert.IsNonDecreasing(t, []float{1, 2}) +// assert.IsNonDecreasing(t, []string{"a", "b"}) +func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + return isOrdered(t, object, []compareResult{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...) +} diff --git a/vendor/github.com/stretchr/testify/assert/assertions.go b/vendor/github.com/stretchr/testify/assert/assertions.go new file mode 100644 index 0000000..4e91332 --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/assertions.go @@ -0,0 +1,2184 @@ +package assert + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "math" + "os" + "reflect" + "regexp" + "runtime" + "runtime/debug" + "strings" + "time" + "unicode" + "unicode/utf8" + + "github.com/davecgh/go-spew/spew" + "github.com/pmezard/go-difflib/difflib" + + // Wrapper around gopkg.in/yaml.v3 + "github.com/stretchr/testify/assert/yaml" +) + +//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_format.go.tmpl" + +// TestingT is an interface wrapper around *testing.T +type TestingT interface { + Errorf(format string, args ...interface{}) +} + +// ComparisonAssertionFunc is a common function prototype when comparing two values. Can be useful +// for table driven tests. +type ComparisonAssertionFunc func(TestingT, interface{}, interface{}, ...interface{}) bool + +// ValueAssertionFunc is a common function prototype when validating a single value. Can be useful +// for table driven tests. +type ValueAssertionFunc func(TestingT, interface{}, ...interface{}) bool + +// BoolAssertionFunc is a common function prototype when validating a bool value. Can be useful +// for table driven tests. +type BoolAssertionFunc func(TestingT, bool, ...interface{}) bool + +// ErrorAssertionFunc is a common function prototype when validating an error value. Can be useful +// for table driven tests. +type ErrorAssertionFunc func(TestingT, error, ...interface{}) bool + +// PanicAssertionFunc is a common function prototype when validating a panic value. Can be useful +// for table driven tests. +type PanicAssertionFunc = func(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool + +// Comparison is a custom function that returns true on success and false on failure +type Comparison func() (success bool) + +/* + Helper functions +*/ + +// ObjectsAreEqual determines if two objects are considered equal. +// +// This function does no assertion of any kind. +func ObjectsAreEqual(expected, actual interface{}) bool { + if expected == nil || actual == nil { + return expected == actual + } + + exp, ok := expected.([]byte) + if !ok { + return reflect.DeepEqual(expected, actual) + } + + act, ok := actual.([]byte) + if !ok { + return false + } + if exp == nil || act == nil { + return exp == nil && act == nil + } + return bytes.Equal(exp, act) +} + +// copyExportedFields iterates downward through nested data structures and creates a copy +// that only contains the exported struct fields. +func copyExportedFields(expected interface{}) interface{} { + if isNil(expected) { + return expected + } + + expectedType := reflect.TypeOf(expected) + expectedKind := expectedType.Kind() + expectedValue := reflect.ValueOf(expected) + + switch expectedKind { + case reflect.Struct: + result := reflect.New(expectedType).Elem() + for i := 0; i < expectedType.NumField(); i++ { + field := expectedType.Field(i) + isExported := field.IsExported() + if isExported { + fieldValue := expectedValue.Field(i) + if isNil(fieldValue) || isNil(fieldValue.Interface()) { + continue + } + newValue := copyExportedFields(fieldValue.Interface()) + result.Field(i).Set(reflect.ValueOf(newValue)) + } + } + return result.Interface() + + case reflect.Ptr: + result := reflect.New(expectedType.Elem()) + unexportedRemoved := copyExportedFields(expectedValue.Elem().Interface()) + result.Elem().Set(reflect.ValueOf(unexportedRemoved)) + return result.Interface() + + case reflect.Array, reflect.Slice: + var result reflect.Value + if expectedKind == reflect.Array { + result = reflect.New(reflect.ArrayOf(expectedValue.Len(), expectedType.Elem())).Elem() + } else { + result = reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len()) + } + for i := 0; i < expectedValue.Len(); i++ { + index := expectedValue.Index(i) + if isNil(index) { + continue + } + unexportedRemoved := copyExportedFields(index.Interface()) + result.Index(i).Set(reflect.ValueOf(unexportedRemoved)) + } + return result.Interface() + + case reflect.Map: + result := reflect.MakeMap(expectedType) + for _, k := range expectedValue.MapKeys() { + index := expectedValue.MapIndex(k) + unexportedRemoved := copyExportedFields(index.Interface()) + result.SetMapIndex(k, reflect.ValueOf(unexportedRemoved)) + } + return result.Interface() + + default: + return expected + } +} + +// ObjectsExportedFieldsAreEqual determines if the exported (public) fields of two objects are +// considered equal. This comparison of only exported fields is applied recursively to nested data +// structures. +// +// This function does no assertion of any kind. +// +// Deprecated: Use [EqualExportedValues] instead. +func ObjectsExportedFieldsAreEqual(expected, actual interface{}) bool { + expectedCleaned := copyExportedFields(expected) + actualCleaned := copyExportedFields(actual) + return ObjectsAreEqualValues(expectedCleaned, actualCleaned) +} + +// ObjectsAreEqualValues gets whether two objects are equal, or if their +// values are equal. +func ObjectsAreEqualValues(expected, actual interface{}) bool { + if ObjectsAreEqual(expected, actual) { + return true + } + + expectedValue := reflect.ValueOf(expected) + actualValue := reflect.ValueOf(actual) + if !expectedValue.IsValid() || !actualValue.IsValid() { + return false + } + + expectedType := expectedValue.Type() + actualType := actualValue.Type() + if !expectedType.ConvertibleTo(actualType) { + return false + } + + if !isNumericType(expectedType) || !isNumericType(actualType) { + // Attempt comparison after type conversion + return reflect.DeepEqual( + expectedValue.Convert(actualType).Interface(), actual, + ) + } + + // If BOTH values are numeric, there are chances of false positives due + // to overflow or underflow. So, we need to make sure to always convert + // the smaller type to a larger type before comparing. + if expectedType.Size() >= actualType.Size() { + return actualValue.Convert(expectedType).Interface() == expected + } + + return expectedValue.Convert(actualType).Interface() == actual +} + +// isNumericType returns true if the type is one of: +// int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, +// float32, float64, complex64, complex128 +func isNumericType(t reflect.Type) bool { + return t.Kind() >= reflect.Int && t.Kind() <= reflect.Complex128 +} + +/* CallerInfo is necessary because the assert functions use the testing object +internally, causing it to print the file:line of the assert method, rather than where +the problem actually occurred in calling code.*/ + +// CallerInfo returns an array of strings containing the file and line number +// of each stack frame leading from the current test to the assert call that +// failed. +func CallerInfo() []string { + + var pc uintptr + var ok bool + var file string + var line int + var name string + + callers := []string{} + for i := 0; ; i++ { + pc, file, line, ok = runtime.Caller(i) + if !ok { + // The breaks below failed to terminate the loop, and we ran off the + // end of the call stack. + break + } + + // This is a huge edge case, but it will panic if this is the case, see #180 + if file == "" { + break + } + + f := runtime.FuncForPC(pc) + if f == nil { + break + } + name = f.Name() + + // testing.tRunner is the standard library function that calls + // tests. Subtests are called directly by tRunner, without going through + // the Test/Benchmark/Example function that contains the t.Run calls, so + // with subtests we should break when we hit tRunner, without adding it + // to the list of callers. + if name == "testing.tRunner" { + break + } + + parts := strings.Split(file, "/") + if len(parts) > 1 { + filename := parts[len(parts)-1] + dir := parts[len(parts)-2] + if (dir != "assert" && dir != "mock" && dir != "require") || filename == "mock_test.go" { + callers = append(callers, fmt.Sprintf("%s:%d", file, line)) + } + } + + // Drop the package + segments := strings.Split(name, ".") + name = segments[len(segments)-1] + if isTest(name, "Test") || + isTest(name, "Benchmark") || + isTest(name, "Example") { + break + } + } + + return callers +} + +// Stolen from the `go test` tool. +// isTest tells whether name looks like a test (or benchmark, according to prefix). +// It is a Test (say) if there is a character after Test that is not a lower-case letter. +// We don't want TesticularCancer. +func isTest(name, prefix string) bool { + if !strings.HasPrefix(name, prefix) { + return false + } + if len(name) == len(prefix) { // "Test" is ok + return true + } + r, _ := utf8.DecodeRuneInString(name[len(prefix):]) + return !unicode.IsLower(r) +} + +func messageFromMsgAndArgs(msgAndArgs ...interface{}) string { + if len(msgAndArgs) == 0 || msgAndArgs == nil { + return "" + } + if len(msgAndArgs) == 1 { + msg := msgAndArgs[0] + if msgAsStr, ok := msg.(string); ok { + return msgAsStr + } + return fmt.Sprintf("%+v", msg) + } + if len(msgAndArgs) > 1 { + return fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...) + } + return "" +} + +// Aligns the provided message so that all lines after the first line start at the same location as the first line. +// Assumes that the first line starts at the correct location (after carriage return, tab, label, spacer and tab). +// The longestLabelLen parameter specifies the length of the longest label in the output (required because this is the +// basis on which the alignment occurs). +func indentMessageLines(message string, longestLabelLen int) string { + outBuf := new(bytes.Buffer) + + for i, scanner := 0, bufio.NewScanner(strings.NewReader(message)); scanner.Scan(); i++ { + // no need to align first line because it starts at the correct location (after the label) + if i != 0 { + // append alignLen+1 spaces to align with "{{longestLabel}}:" before adding tab + outBuf.WriteString("\n\t" + strings.Repeat(" ", longestLabelLen+1) + "\t") + } + outBuf.WriteString(scanner.Text()) + } + + return outBuf.String() +} + +type failNower interface { + FailNow() +} + +// FailNow fails test +func FailNow(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + Fail(t, failureMessage, msgAndArgs...) + + // We cannot extend TestingT with FailNow() and + // maintain backwards compatibility, so we fallback + // to panicking when FailNow is not available in + // TestingT. + // See issue #263 + + if t, ok := t.(failNower); ok { + t.FailNow() + } else { + panic("test failed and t is missing `FailNow()`") + } + return false +} + +// Fail reports a failure through +func Fail(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + content := []labeledContent{ + {"Error Trace", strings.Join(CallerInfo(), "\n\t\t\t")}, + {"Error", failureMessage}, + } + + // Add test name if the Go version supports it + if n, ok := t.(interface { + Name() string + }); ok { + content = append(content, labeledContent{"Test", n.Name()}) + } + + message := messageFromMsgAndArgs(msgAndArgs...) + if len(message) > 0 { + content = append(content, labeledContent{"Messages", message}) + } + + t.Errorf("\n%s", ""+labeledOutput(content...)) + + return false +} + +type labeledContent struct { + label string + content string +} + +// labeledOutput returns a string consisting of the provided labeledContent. Each labeled output is appended in the following manner: +// +// \t{{label}}:{{align_spaces}}\t{{content}}\n +// +// The initial carriage return is required to undo/erase any padding added by testing.T.Errorf. The "\t{{label}}:" is for the label. +// If a label is shorter than the longest label provided, padding spaces are added to make all the labels match in length. Once this +// alignment is achieved, "\t{{content}}\n" is added for the output. +// +// If the content of the labeledOutput contains line breaks, the subsequent lines are aligned so that they start at the same location as the first line. +func labeledOutput(content ...labeledContent) string { + longestLabel := 0 + for _, v := range content { + if len(v.label) > longestLabel { + longestLabel = len(v.label) + } + } + var output string + for _, v := range content { + output += "\t" + v.label + ":" + strings.Repeat(" ", longestLabel-len(v.label)) + "\t" + indentMessageLines(v.content, longestLabel) + "\n" + } + return output +} + +// Implements asserts that an object is implemented by the specified interface. +// +// assert.Implements(t, (*MyInterface)(nil), new(MyObject)) +func Implements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + interfaceType := reflect.TypeOf(interfaceObject).Elem() + + if object == nil { + return Fail(t, fmt.Sprintf("Cannot check if nil implements %v", interfaceType), msgAndArgs...) + } + if !reflect.TypeOf(object).Implements(interfaceType) { + return Fail(t, fmt.Sprintf("%T must implement %v", object, interfaceType), msgAndArgs...) + } + + return true +} + +// NotImplements asserts that an object does not implement the specified interface. +// +// assert.NotImplements(t, (*MyInterface)(nil), new(MyObject)) +func NotImplements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + interfaceType := reflect.TypeOf(interfaceObject).Elem() + + if object == nil { + return Fail(t, fmt.Sprintf("Cannot check if nil does not implement %v", interfaceType), msgAndArgs...) + } + if reflect.TypeOf(object).Implements(interfaceType) { + return Fail(t, fmt.Sprintf("%T implements %v", object, interfaceType), msgAndArgs...) + } + + return true +} + +// IsType asserts that the specified objects are of the same type. +func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + if !ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) { + return Fail(t, fmt.Sprintf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)), msgAndArgs...) + } + + return true +} + +// Equal asserts that two objects are equal. +// +// assert.Equal(t, 123, 123) +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). Function equality +// cannot be determined and will always fail. +func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if err := validateEqualArgs(expected, actual); err != nil { + return Fail(t, fmt.Sprintf("Invalid operation: %#v == %#v (%s)", + expected, actual, err), msgAndArgs...) + } + + if !ObjectsAreEqual(expected, actual) { + diff := diff(expected, actual) + expected, actual = formatUnequalValues(expected, actual) + return Fail(t, fmt.Sprintf("Not equal: \n"+ + "expected: %s\n"+ + "actual : %s%s", expected, actual, diff), msgAndArgs...) + } + + return true + +} + +// validateEqualArgs checks whether provided arguments can be safely used in the +// Equal/NotEqual functions. +func validateEqualArgs(expected, actual interface{}) error { + if expected == nil && actual == nil { + return nil + } + + if isFunction(expected) || isFunction(actual) { + return errors.New("cannot take func type as argument") + } + return nil +} + +// Same asserts that two pointers reference the same object. +// +// assert.Same(t, ptr1, ptr2) +// +// Both arguments must be pointer variables. Pointer variable sameness is +// determined based on the equality of both type and value. +func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + same, ok := samePointers(expected, actual) + if !ok { + return Fail(t, "Both arguments must be pointers", msgAndArgs...) + } + + if !same { + // both are pointers but not the same type & pointing to the same address + return Fail(t, fmt.Sprintf("Not same: \n"+ + "expected: %p %#v\n"+ + "actual : %p %#v", expected, expected, actual, actual), msgAndArgs...) + } + + return true +} + +// NotSame asserts that two pointers do not reference the same object. +// +// assert.NotSame(t, ptr1, ptr2) +// +// Both arguments must be pointer variables. Pointer variable sameness is +// determined based on the equality of both type and value. +func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + same, ok := samePointers(expected, actual) + if !ok { + //fails when the arguments are not pointers + return !(Fail(t, "Both arguments must be pointers", msgAndArgs...)) + } + + if same { + return Fail(t, fmt.Sprintf( + "Expected and actual point to the same object: %p %#v", + expected, expected), msgAndArgs...) + } + return true +} + +// samePointers checks if two generic interface objects are pointers of the same +// type pointing to the same object. It returns two values: same indicating if +// they are the same type and point to the same object, and ok indicating that +// both inputs are pointers. +func samePointers(first, second interface{}) (same bool, ok bool) { + firstPtr, secondPtr := reflect.ValueOf(first), reflect.ValueOf(second) + if firstPtr.Kind() != reflect.Ptr || secondPtr.Kind() != reflect.Ptr { + return false, false //not both are pointers + } + + firstType, secondType := reflect.TypeOf(first), reflect.TypeOf(second) + if firstType != secondType { + return false, true // both are pointers, but of different types + } + + // compare pointer addresses + return first == second, true +} + +// formatUnequalValues takes two values of arbitrary types and returns string +// representations appropriate to be presented to the user. +// +// If the values are not of like type, the returned strings will be prefixed +// with the type name, and the value will be enclosed in parentheses similar +// to a type conversion in the Go grammar. +func formatUnequalValues(expected, actual interface{}) (e string, a string) { + if reflect.TypeOf(expected) != reflect.TypeOf(actual) { + return fmt.Sprintf("%T(%s)", expected, truncatingFormat(expected)), + fmt.Sprintf("%T(%s)", actual, truncatingFormat(actual)) + } + switch expected.(type) { + case time.Duration: + return fmt.Sprintf("%v", expected), fmt.Sprintf("%v", actual) + } + return truncatingFormat(expected), truncatingFormat(actual) +} + +// truncatingFormat formats the data and truncates it if it's too long. +// +// This helps keep formatted error messages lines from exceeding the +// bufio.MaxScanTokenSize max line length that the go testing framework imposes. +func truncatingFormat(data interface{}) string { + value := fmt.Sprintf("%#v", data) + max := bufio.MaxScanTokenSize - 100 // Give us some space the type info too if needed. + if len(value) > max { + value = value[0:max] + "<... truncated>" + } + return value +} + +// EqualValues asserts that two objects are equal or convertible to the larger +// type and equal. +// +// assert.EqualValues(t, uint32(123), int32(123)) +func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + if !ObjectsAreEqualValues(expected, actual) { + diff := diff(expected, actual) + expected, actual = formatUnequalValues(expected, actual) + return Fail(t, fmt.Sprintf("Not equal: \n"+ + "expected: %s\n"+ + "actual : %s%s", expected, actual, diff), msgAndArgs...) + } + + return true + +} + +// EqualExportedValues asserts that the types of two objects are equal and their public +// fields are also equal. This is useful for comparing structs that have private fields +// that could potentially differ. +// +// type S struct { +// Exported int +// notExported int +// } +// assert.EqualExportedValues(t, S{1, 2}, S{1, 3}) => true +// assert.EqualExportedValues(t, S{1, 2}, S{2, 3}) => false +func EqualExportedValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + aType := reflect.TypeOf(expected) + bType := reflect.TypeOf(actual) + + if aType != bType { + return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...) + } + + expected = copyExportedFields(expected) + actual = copyExportedFields(actual) + + if !ObjectsAreEqualValues(expected, actual) { + diff := diff(expected, actual) + expected, actual = formatUnequalValues(expected, actual) + return Fail(t, fmt.Sprintf("Not equal (comparing only exported fields): \n"+ + "expected: %s\n"+ + "actual : %s%s", expected, actual, diff), msgAndArgs...) + } + + return true +} + +// Exactly asserts that two objects are equal in value and type. +// +// assert.Exactly(t, int32(123), int64(123)) +func Exactly(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + aType := reflect.TypeOf(expected) + bType := reflect.TypeOf(actual) + + if aType != bType { + return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...) + } + + return Equal(t, expected, actual, msgAndArgs...) + +} + +// NotNil asserts that the specified object is not nil. +// +// assert.NotNil(t, err) +func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + if !isNil(object) { + return true + } + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, "Expected value not to be nil.", msgAndArgs...) +} + +// isNil checks if a specified object is nil or not, without Failing. +func isNil(object interface{}) bool { + if object == nil { + return true + } + + value := reflect.ValueOf(object) + switch value.Kind() { + case + reflect.Chan, reflect.Func, + reflect.Interface, reflect.Map, + reflect.Ptr, reflect.Slice, reflect.UnsafePointer: + + return value.IsNil() + } + + return false +} + +// Nil asserts that the specified object is nil. +// +// assert.Nil(t, err) +func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + if isNil(object) { + return true + } + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, fmt.Sprintf("Expected nil, but got: %#v", object), msgAndArgs...) +} + +// isEmpty gets whether the specified object is considered empty or not. +func isEmpty(object interface{}) bool { + + // get nil case out of the way + if object == nil { + return true + } + + objValue := reflect.ValueOf(object) + + switch objValue.Kind() { + // collection types are empty when they have no element + case reflect.Chan, reflect.Map, reflect.Slice: + return objValue.Len() == 0 + // pointers are empty if nil or if the value they point to is empty + case reflect.Ptr: + if objValue.IsNil() { + return true + } + deref := objValue.Elem().Interface() + return isEmpty(deref) + // for all other types, compare against the zero value + // array types are empty when they match their zero-initialized state + default: + zero := reflect.Zero(objValue.Type()) + return reflect.DeepEqual(object, zero.Interface()) + } +} + +// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either +// a slice or a channel with len == 0. +// +// assert.Empty(t, obj) +func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + pass := isEmpty(object) + if !pass { + if h, ok := t.(tHelper); ok { + h.Helper() + } + Fail(t, fmt.Sprintf("Should be empty, but was %v", object), msgAndArgs...) + } + + return pass + +} + +// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either +// a slice or a channel with len == 0. +// +// if assert.NotEmpty(t, obj) { +// assert.Equal(t, "two", obj[1]) +// } +func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + pass := !isEmpty(object) + if !pass { + if h, ok := t.(tHelper); ok { + h.Helper() + } + Fail(t, fmt.Sprintf("Should NOT be empty, but was %v", object), msgAndArgs...) + } + + return pass + +} + +// getLen tries to get the length of an object. +// It returns (0, false) if impossible. +func getLen(x interface{}) (length int, ok bool) { + v := reflect.ValueOf(x) + defer func() { + ok = recover() == nil + }() + return v.Len(), true +} + +// Len asserts that the specified object has specific length. +// Len also fails if the object has a type that len() not accept. +// +// assert.Len(t, mySlice, 3) +func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + l, ok := getLen(object) + if !ok { + return Fail(t, fmt.Sprintf("\"%v\" could not be applied builtin len()", object), msgAndArgs...) + } + + if l != length { + return Fail(t, fmt.Sprintf("\"%v\" should have %d item(s), but has %d", object, length, l), msgAndArgs...) + } + return true +} + +// True asserts that the specified value is true. +// +// assert.True(t, myBool) +func True(t TestingT, value bool, msgAndArgs ...interface{}) bool { + if !value { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, "Should be true", msgAndArgs...) + } + + return true + +} + +// False asserts that the specified value is false. +// +// assert.False(t, myBool) +func False(t TestingT, value bool, msgAndArgs ...interface{}) bool { + if value { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, "Should be false", msgAndArgs...) + } + + return true + +} + +// NotEqual asserts that the specified values are NOT equal. +// +// assert.NotEqual(t, obj1, obj2) +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). +func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if err := validateEqualArgs(expected, actual); err != nil { + return Fail(t, fmt.Sprintf("Invalid operation: %#v != %#v (%s)", + expected, actual, err), msgAndArgs...) + } + + if ObjectsAreEqual(expected, actual) { + return Fail(t, fmt.Sprintf("Should not be: %#v\n", actual), msgAndArgs...) + } + + return true + +} + +// NotEqualValues asserts that two objects are not equal even when converted to the same type +// +// assert.NotEqualValues(t, obj1, obj2) +func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + if ObjectsAreEqualValues(expected, actual) { + return Fail(t, fmt.Sprintf("Should not be: %#v\n", actual), msgAndArgs...) + } + + return true +} + +// containsElement try loop over the list check if the list includes the element. +// return (false, false) if impossible. +// return (true, false) if element was not found. +// return (true, true) if element was found. +func containsElement(list interface{}, element interface{}) (ok, found bool) { + + listValue := reflect.ValueOf(list) + listType := reflect.TypeOf(list) + if listType == nil { + return false, false + } + listKind := listType.Kind() + defer func() { + if e := recover(); e != nil { + ok = false + found = false + } + }() + + if listKind == reflect.String { + elementValue := reflect.ValueOf(element) + return true, strings.Contains(listValue.String(), elementValue.String()) + } + + if listKind == reflect.Map { + mapKeys := listValue.MapKeys() + for i := 0; i < len(mapKeys); i++ { + if ObjectsAreEqual(mapKeys[i].Interface(), element) { + return true, true + } + } + return true, false + } + + for i := 0; i < listValue.Len(); i++ { + if ObjectsAreEqual(listValue.Index(i).Interface(), element) { + return true, true + } + } + return true, false + +} + +// Contains asserts that the specified string, list(array, slice...) or map contains the +// specified substring or element. +// +// assert.Contains(t, "Hello World", "World") +// assert.Contains(t, ["Hello", "World"], "World") +// assert.Contains(t, {"Hello": "World"}, "Hello") +func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + ok, found := containsElement(s, contains) + if !ok { + return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...) + } + if !found { + return Fail(t, fmt.Sprintf("%#v does not contain %#v", s, contains), msgAndArgs...) + } + + return true + +} + +// NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the +// specified substring or element. +// +// assert.NotContains(t, "Hello World", "Earth") +// assert.NotContains(t, ["Hello", "World"], "Earth") +// assert.NotContains(t, {"Hello": "World"}, "Earth") +func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + ok, found := containsElement(s, contains) + if !ok { + return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...) + } + if found { + return Fail(t, fmt.Sprintf("%#v should not contain %#v", s, contains), msgAndArgs...) + } + + return true + +} + +// Subset asserts that the specified list(array, slice...) or map contains all +// elements given in the specified subset list(array, slice...) or map. +// +// assert.Subset(t, [1, 2, 3], [1, 2]) +// assert.Subset(t, {"x": 1, "y": 2}, {"x": 1}) +func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if subset == nil { + return true // we consider nil to be equal to the nil set + } + + listKind := reflect.TypeOf(list).Kind() + if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map { + return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...) + } + + subsetKind := reflect.TypeOf(subset).Kind() + if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map { + return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...) + } + + if subsetKind == reflect.Map && listKind == reflect.Map { + subsetMap := reflect.ValueOf(subset) + actualMap := reflect.ValueOf(list) + + for _, k := range subsetMap.MapKeys() { + ev := subsetMap.MapIndex(k) + av := actualMap.MapIndex(k) + + if !av.IsValid() { + return Fail(t, fmt.Sprintf("%#v does not contain %#v", list, subset), msgAndArgs...) + } + if !ObjectsAreEqual(ev.Interface(), av.Interface()) { + return Fail(t, fmt.Sprintf("%#v does not contain %#v", list, subset), msgAndArgs...) + } + } + + return true + } + + subsetList := reflect.ValueOf(subset) + for i := 0; i < subsetList.Len(); i++ { + element := subsetList.Index(i).Interface() + ok, found := containsElement(list, element) + if !ok { + return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", list), msgAndArgs...) + } + if !found { + return Fail(t, fmt.Sprintf("%#v does not contain %#v", list, element), msgAndArgs...) + } + } + + return true +} + +// NotSubset asserts that the specified list(array, slice...) or map does NOT +// contain all elements given in the specified subset list(array, slice...) or +// map. +// +// assert.NotSubset(t, [1, 3, 4], [1, 2]) +// assert.NotSubset(t, {"x": 1, "y": 2}, {"z": 3}) +func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if subset == nil { + return Fail(t, "nil is the empty set which is a subset of every set", msgAndArgs...) + } + + listKind := reflect.TypeOf(list).Kind() + if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map { + return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...) + } + + subsetKind := reflect.TypeOf(subset).Kind() + if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map { + return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...) + } + + if subsetKind == reflect.Map && listKind == reflect.Map { + subsetMap := reflect.ValueOf(subset) + actualMap := reflect.ValueOf(list) + + for _, k := range subsetMap.MapKeys() { + ev := subsetMap.MapIndex(k) + av := actualMap.MapIndex(k) + + if !av.IsValid() { + return true + } + if !ObjectsAreEqual(ev.Interface(), av.Interface()) { + return true + } + } + + return Fail(t, fmt.Sprintf("%q is a subset of %q", subset, list), msgAndArgs...) + } + + subsetList := reflect.ValueOf(subset) + for i := 0; i < subsetList.Len(); i++ { + element := subsetList.Index(i).Interface() + ok, found := containsElement(list, element) + if !ok { + return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...) + } + if !found { + return true + } + } + + return Fail(t, fmt.Sprintf("%q is a subset of %q", subset, list), msgAndArgs...) +} + +// ElementsMatch asserts that the specified listA(array, slice...) is equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should match. +// +// assert.ElementsMatch(t, [1, 3, 2, 3], [1, 3, 3, 2]) +func ElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if isEmpty(listA) && isEmpty(listB) { + return true + } + + if !isList(t, listA, msgAndArgs...) || !isList(t, listB, msgAndArgs...) { + return false + } + + extraA, extraB := diffLists(listA, listB) + + if len(extraA) == 0 && len(extraB) == 0 { + return true + } + + return Fail(t, formatListDiff(listA, listB, extraA, extraB), msgAndArgs...) +} + +// isList checks that the provided value is array or slice. +func isList(t TestingT, list interface{}, msgAndArgs ...interface{}) (ok bool) { + kind := reflect.TypeOf(list).Kind() + if kind != reflect.Array && kind != reflect.Slice { + return Fail(t, fmt.Sprintf("%q has an unsupported type %s, expecting array or slice", list, kind), + msgAndArgs...) + } + return true +} + +// diffLists diffs two arrays/slices and returns slices of elements that are only in A and only in B. +// If some element is present multiple times, each instance is counted separately (e.g. if something is 2x in A and +// 5x in B, it will be 0x in extraA and 3x in extraB). The order of items in both lists is ignored. +func diffLists(listA, listB interface{}) (extraA, extraB []interface{}) { + aValue := reflect.ValueOf(listA) + bValue := reflect.ValueOf(listB) + + aLen := aValue.Len() + bLen := bValue.Len() + + // Mark indexes in bValue that we already used + visited := make([]bool, bLen) + for i := 0; i < aLen; i++ { + element := aValue.Index(i).Interface() + found := false + for j := 0; j < bLen; j++ { + if visited[j] { + continue + } + if ObjectsAreEqual(bValue.Index(j).Interface(), element) { + visited[j] = true + found = true + break + } + } + if !found { + extraA = append(extraA, element) + } + } + + for j := 0; j < bLen; j++ { + if visited[j] { + continue + } + extraB = append(extraB, bValue.Index(j).Interface()) + } + + return +} + +func formatListDiff(listA, listB interface{}, extraA, extraB []interface{}) string { + var msg bytes.Buffer + + msg.WriteString("elements differ") + if len(extraA) > 0 { + msg.WriteString("\n\nextra elements in list A:\n") + msg.WriteString(spewConfig.Sdump(extraA)) + } + if len(extraB) > 0 { + msg.WriteString("\n\nextra elements in list B:\n") + msg.WriteString(spewConfig.Sdump(extraB)) + } + msg.WriteString("\n\nlistA:\n") + msg.WriteString(spewConfig.Sdump(listA)) + msg.WriteString("\n\nlistB:\n") + msg.WriteString(spewConfig.Sdump(listB)) + + return msg.String() +} + +// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// assert.NotElementsMatch(t, [1, 1, 2, 3], [1, 1, 2, 3]) -> false +// +// assert.NotElementsMatch(t, [1, 1, 2, 3], [1, 2, 3]) -> true +// +// assert.NotElementsMatch(t, [1, 2, 3], [1, 2, 4]) -> true +func NotElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if isEmpty(listA) && isEmpty(listB) { + return Fail(t, "listA and listB contain the same elements", msgAndArgs) + } + + if !isList(t, listA, msgAndArgs...) { + return Fail(t, "listA is not a list type", msgAndArgs...) + } + if !isList(t, listB, msgAndArgs...) { + return Fail(t, "listB is not a list type", msgAndArgs...) + } + + extraA, extraB := diffLists(listA, listB) + if len(extraA) == 0 && len(extraB) == 0 { + return Fail(t, "listA and listB contain the same elements", msgAndArgs) + } + + return true +} + +// Condition uses a Comparison to assert a complex condition. +func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + result := comp() + if !result { + Fail(t, "Condition failed!", msgAndArgs...) + } + return result +} + +// PanicTestFunc defines a func that should be passed to the assert.Panics and assert.NotPanics +// methods, and represents a simple func that takes no arguments, and returns nothing. +type PanicTestFunc func() + +// didPanic returns true if the function passed to it panics. Otherwise, it returns false. +func didPanic(f PanicTestFunc) (didPanic bool, message interface{}, stack string) { + didPanic = true + + defer func() { + message = recover() + if didPanic { + stack = string(debug.Stack()) + } + }() + + // call the target function + f() + didPanic = false + + return +} + +// Panics asserts that the code inside the specified PanicTestFunc panics. +// +// assert.Panics(t, func(){ GoCrazy() }) +func Panics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + if funcDidPanic, panicValue, _ := didPanic(f); !funcDidPanic { + return Fail(t, fmt.Sprintf("func %#v should panic\n\tPanic value:\t%#v", f, panicValue), msgAndArgs...) + } + + return true +} + +// PanicsWithValue asserts that the code inside the specified PanicTestFunc panics, and that +// the recovered panic value equals the expected panic value. +// +// assert.PanicsWithValue(t, "crazy error", func(){ GoCrazy() }) +func PanicsWithValue(t TestingT, expected interface{}, f PanicTestFunc, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + funcDidPanic, panicValue, panickedStack := didPanic(f) + if !funcDidPanic { + return Fail(t, fmt.Sprintf("func %#v should panic\n\tPanic value:\t%#v", f, panicValue), msgAndArgs...) + } + if panicValue != expected { + return Fail(t, fmt.Sprintf("func %#v should panic with value:\t%#v\n\tPanic value:\t%#v\n\tPanic stack:\t%s", f, expected, panicValue, panickedStack), msgAndArgs...) + } + + return true +} + +// PanicsWithError asserts that the code inside the specified PanicTestFunc +// panics, and that the recovered panic value is an error that satisfies the +// EqualError comparison. +// +// assert.PanicsWithError(t, "crazy error", func(){ GoCrazy() }) +func PanicsWithError(t TestingT, errString string, f PanicTestFunc, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + funcDidPanic, panicValue, panickedStack := didPanic(f) + if !funcDidPanic { + return Fail(t, fmt.Sprintf("func %#v should panic\n\tPanic value:\t%#v", f, panicValue), msgAndArgs...) + } + panicErr, ok := panicValue.(error) + if !ok || panicErr.Error() != errString { + return Fail(t, fmt.Sprintf("func %#v should panic with error message:\t%#v\n\tPanic value:\t%#v\n\tPanic stack:\t%s", f, errString, panicValue, panickedStack), msgAndArgs...) + } + + return true +} + +// NotPanics asserts that the code inside the specified PanicTestFunc does NOT panic. +// +// assert.NotPanics(t, func(){ RemainCalm() }) +func NotPanics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + if funcDidPanic, panicValue, panickedStack := didPanic(f); funcDidPanic { + return Fail(t, fmt.Sprintf("func %#v should not panic\n\tPanic value:\t%v\n\tPanic stack:\t%s", f, panicValue, panickedStack), msgAndArgs...) + } + + return true +} + +// WithinDuration asserts that the two times are within duration delta of each other. +// +// assert.WithinDuration(t, time.Now(), time.Now(), 10*time.Second) +func WithinDuration(t TestingT, expected, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + dt := expected.Sub(actual) + if dt < -delta || dt > delta { + return Fail(t, fmt.Sprintf("Max difference between %v and %v allowed is %v, but difference was %v", expected, actual, delta, dt), msgAndArgs...) + } + + return true +} + +// WithinRange asserts that a time is within a time range (inclusive). +// +// assert.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) +func WithinRange(t TestingT, actual, start, end time.Time, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + if end.Before(start) { + return Fail(t, "Start should be before end", msgAndArgs...) + } + + if actual.Before(start) { + return Fail(t, fmt.Sprintf("Time %v expected to be in time range %v to %v, but is before the range", actual, start, end), msgAndArgs...) + } else if actual.After(end) { + return Fail(t, fmt.Sprintf("Time %v expected to be in time range %v to %v, but is after the range", actual, start, end), msgAndArgs...) + } + + return true +} + +func toFloat(x interface{}) (float64, bool) { + var xf float64 + xok := true + + switch xn := x.(type) { + case uint: + xf = float64(xn) + case uint8: + xf = float64(xn) + case uint16: + xf = float64(xn) + case uint32: + xf = float64(xn) + case uint64: + xf = float64(xn) + case int: + xf = float64(xn) + case int8: + xf = float64(xn) + case int16: + xf = float64(xn) + case int32: + xf = float64(xn) + case int64: + xf = float64(xn) + case float32: + xf = float64(xn) + case float64: + xf = xn + case time.Duration: + xf = float64(xn) + default: + xok = false + } + + return xf, xok +} + +// InDelta asserts that the two numerals are within delta of each other. +// +// assert.InDelta(t, math.Pi, 22/7.0, 0.01) +func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + af, aok := toFloat(expected) + bf, bok := toFloat(actual) + + if !aok || !bok { + return Fail(t, "Parameters must be numerical", msgAndArgs...) + } + + if math.IsNaN(af) && math.IsNaN(bf) { + return true + } + + if math.IsNaN(af) { + return Fail(t, "Expected must not be NaN", msgAndArgs...) + } + + if math.IsNaN(bf) { + return Fail(t, fmt.Sprintf("Expected %v with delta %v, but was NaN", expected, delta), msgAndArgs...) + } + + dt := af - bf + if dt < -delta || dt > delta { + return Fail(t, fmt.Sprintf("Max difference between %v and %v allowed is %v, but difference was %v", expected, actual, delta, dt), msgAndArgs...) + } + + return true +} + +// InDeltaSlice is the same as InDelta, except it compares two slices. +func InDeltaSlice(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if expected == nil || actual == nil || + reflect.TypeOf(actual).Kind() != reflect.Slice || + reflect.TypeOf(expected).Kind() != reflect.Slice { + return Fail(t, "Parameters must be slice", msgAndArgs...) + } + + actualSlice := reflect.ValueOf(actual) + expectedSlice := reflect.ValueOf(expected) + + for i := 0; i < actualSlice.Len(); i++ { + result := InDelta(t, actualSlice.Index(i).Interface(), expectedSlice.Index(i).Interface(), delta, msgAndArgs...) + if !result { + return result + } + } + + return true +} + +// InDeltaMapValues is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys. +func InDeltaMapValues(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if expected == nil || actual == nil || + reflect.TypeOf(actual).Kind() != reflect.Map || + reflect.TypeOf(expected).Kind() != reflect.Map { + return Fail(t, "Arguments must be maps", msgAndArgs...) + } + + expectedMap := reflect.ValueOf(expected) + actualMap := reflect.ValueOf(actual) + + if expectedMap.Len() != actualMap.Len() { + return Fail(t, "Arguments must have the same number of keys", msgAndArgs...) + } + + for _, k := range expectedMap.MapKeys() { + ev := expectedMap.MapIndex(k) + av := actualMap.MapIndex(k) + + if !ev.IsValid() { + return Fail(t, fmt.Sprintf("missing key %q in expected map", k), msgAndArgs...) + } + + if !av.IsValid() { + return Fail(t, fmt.Sprintf("missing key %q in actual map", k), msgAndArgs...) + } + + if !InDelta( + t, + ev.Interface(), + av.Interface(), + delta, + msgAndArgs..., + ) { + return false + } + } + + return true +} + +func calcRelativeError(expected, actual interface{}) (float64, error) { + af, aok := toFloat(expected) + bf, bok := toFloat(actual) + if !aok || !bok { + return 0, fmt.Errorf("Parameters must be numerical") + } + if math.IsNaN(af) && math.IsNaN(bf) { + return 0, nil + } + if math.IsNaN(af) { + return 0, errors.New("expected value must not be NaN") + } + if af == 0 { + return 0, fmt.Errorf("expected value must have a value other than zero to calculate the relative error") + } + if math.IsNaN(bf) { + return 0, errors.New("actual value must not be NaN") + } + + return math.Abs(af-bf) / math.Abs(af), nil +} + +// InEpsilon asserts that expected and actual have a relative error less than epsilon +func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if math.IsNaN(epsilon) { + return Fail(t, "epsilon must not be NaN", msgAndArgs...) + } + actualEpsilon, err := calcRelativeError(expected, actual) + if err != nil { + return Fail(t, err.Error(), msgAndArgs...) + } + if math.IsNaN(actualEpsilon) { + return Fail(t, "relative error is NaN", msgAndArgs...) + } + if actualEpsilon > epsilon { + return Fail(t, fmt.Sprintf("Relative error is too high: %#v (expected)\n"+ + " < %#v (actual)", epsilon, actualEpsilon), msgAndArgs...) + } + + return true +} + +// InEpsilonSlice is the same as InEpsilon, except it compares each value from two slices. +func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + if expected == nil || actual == nil { + return Fail(t, "Parameters must be slice", msgAndArgs...) + } + + expectedSlice := reflect.ValueOf(expected) + actualSlice := reflect.ValueOf(actual) + + if expectedSlice.Type().Kind() != reflect.Slice { + return Fail(t, "Expected value must be slice", msgAndArgs...) + } + + expectedLen := expectedSlice.Len() + if !IsType(t, expected, actual) || !Len(t, actual, expectedLen) { + return false + } + + for i := 0; i < expectedLen; i++ { + if !InEpsilon(t, expectedSlice.Index(i).Interface(), actualSlice.Index(i).Interface(), epsilon, "at index %d", i) { + return false + } + } + + return true +} + +/* + Errors +*/ + +// NoError asserts that a function returned no error (i.e. `nil`). +// +// actualObj, err := SomeFunction() +// if assert.NoError(t, err) { +// assert.Equal(t, expectedObj, actualObj) +// } +func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool { + if err != nil { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, fmt.Sprintf("Received unexpected error:\n%+v", err), msgAndArgs...) + } + + return true +} + +// Error asserts that a function returned an error (i.e. not `nil`). +// +// actualObj, err := SomeFunction() +// if assert.Error(t, err) { +// assert.Equal(t, expectedError, err) +// } +func Error(t TestingT, err error, msgAndArgs ...interface{}) bool { + if err == nil { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, "An error is expected but got nil.", msgAndArgs...) + } + + return true +} + +// EqualError asserts that a function returned an error (i.e. not `nil`) +// and that it is equal to the provided error. +// +// actualObj, err := SomeFunction() +// assert.EqualError(t, err, expectedErrorString) +func EqualError(t TestingT, theError error, errString string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if !Error(t, theError, msgAndArgs...) { + return false + } + expected := errString + actual := theError.Error() + // don't need to use deep equals here, we know they are both strings + if expected != actual { + return Fail(t, fmt.Sprintf("Error message not equal:\n"+ + "expected: %q\n"+ + "actual : %q", expected, actual), msgAndArgs...) + } + return true +} + +// ErrorContains asserts that a function returned an error (i.e. not `nil`) +// and that the error contains the specified substring. +// +// actualObj, err := SomeFunction() +// assert.ErrorContains(t, err, expectedErrorSubString) +func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if !Error(t, theError, msgAndArgs...) { + return false + } + + actual := theError.Error() + if !strings.Contains(actual, contains) { + return Fail(t, fmt.Sprintf("Error %#v does not contain %#v", actual, contains), msgAndArgs...) + } + + return true +} + +// matchRegexp return true if a specified regexp matches a string. +func matchRegexp(rx interface{}, str interface{}) bool { + var r *regexp.Regexp + if rr, ok := rx.(*regexp.Regexp); ok { + r = rr + } else { + r = regexp.MustCompile(fmt.Sprint(rx)) + } + + switch v := str.(type) { + case []byte: + return r.Match(v) + case string: + return r.MatchString(v) + default: + return r.MatchString(fmt.Sprint(v)) + } + +} + +// Regexp asserts that a specified regexp matches a string. +// +// assert.Regexp(t, regexp.MustCompile("start"), "it's starting") +// assert.Regexp(t, "start...$", "it's not starting") +func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + match := matchRegexp(rx, str) + + if !match { + Fail(t, fmt.Sprintf("Expect \"%v\" to match \"%v\"", str, rx), msgAndArgs...) + } + + return match +} + +// NotRegexp asserts that a specified regexp does not match a string. +// +// assert.NotRegexp(t, regexp.MustCompile("starts"), "it's starting") +// assert.NotRegexp(t, "^start", "it's not starting") +func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + match := matchRegexp(rx, str) + + if match { + Fail(t, fmt.Sprintf("Expect \"%v\" to NOT match \"%v\"", str, rx), msgAndArgs...) + } + + return !match + +} + +// Zero asserts that i is the zero value for its type. +func Zero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if i != nil && !reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) { + return Fail(t, fmt.Sprintf("Should be zero, but was %v", i), msgAndArgs...) + } + return true +} + +// NotZero asserts that i is not the zero value for its type. +func NotZero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) { + return Fail(t, fmt.Sprintf("Should not be zero, but was %v", i), msgAndArgs...) + } + return true +} + +// FileExists checks whether a file exists in the given path. It also fails if +// the path points to a directory or there is an error when trying to check the file. +func FileExists(t TestingT, path string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + info, err := os.Lstat(path) + if err != nil { + if os.IsNotExist(err) { + return Fail(t, fmt.Sprintf("unable to find file %q", path), msgAndArgs...) + } + return Fail(t, fmt.Sprintf("error when running os.Lstat(%q): %s", path, err), msgAndArgs...) + } + if info.IsDir() { + return Fail(t, fmt.Sprintf("%q is a directory", path), msgAndArgs...) + } + return true +} + +// NoFileExists checks whether a file does not exist in a given path. It fails +// if the path points to an existing _file_ only. +func NoFileExists(t TestingT, path string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + info, err := os.Lstat(path) + if err != nil { + return true + } + if info.IsDir() { + return true + } + return Fail(t, fmt.Sprintf("file %q exists", path), msgAndArgs...) +} + +// DirExists checks whether a directory exists in the given path. It also fails +// if the path is a file rather a directory or there is an error checking whether it exists. +func DirExists(t TestingT, path string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + info, err := os.Lstat(path) + if err != nil { + if os.IsNotExist(err) { + return Fail(t, fmt.Sprintf("unable to find file %q", path), msgAndArgs...) + } + return Fail(t, fmt.Sprintf("error when running os.Lstat(%q): %s", path, err), msgAndArgs...) + } + if !info.IsDir() { + return Fail(t, fmt.Sprintf("%q is a file", path), msgAndArgs...) + } + return true +} + +// NoDirExists checks whether a directory does not exist in the given path. +// It fails if the path points to an existing _directory_ only. +func NoDirExists(t TestingT, path string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + info, err := os.Lstat(path) + if err != nil { + if os.IsNotExist(err) { + return true + } + return true + } + if !info.IsDir() { + return true + } + return Fail(t, fmt.Sprintf("directory %q exists", path), msgAndArgs...) +} + +// JSONEq asserts that two JSON strings are equivalent. +// +// assert.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) +func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + var expectedJSONAsInterface, actualJSONAsInterface interface{} + + if err := json.Unmarshal([]byte(expected), &expectedJSONAsInterface); err != nil { + return Fail(t, fmt.Sprintf("Expected value ('%s') is not valid json.\nJSON parsing error: '%s'", expected, err.Error()), msgAndArgs...) + } + + if err := json.Unmarshal([]byte(actual), &actualJSONAsInterface); err != nil { + return Fail(t, fmt.Sprintf("Input ('%s') needs to be valid json.\nJSON parsing error: '%s'", actual, err.Error()), msgAndArgs...) + } + + return Equal(t, expectedJSONAsInterface, actualJSONAsInterface, msgAndArgs...) +} + +// YAMLEq asserts that two YAML strings are equivalent. +func YAMLEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + var expectedYAMLAsInterface, actualYAMLAsInterface interface{} + + if err := yaml.Unmarshal([]byte(expected), &expectedYAMLAsInterface); err != nil { + return Fail(t, fmt.Sprintf("Expected value ('%s') is not valid yaml.\nYAML parsing error: '%s'", expected, err.Error()), msgAndArgs...) + } + + if err := yaml.Unmarshal([]byte(actual), &actualYAMLAsInterface); err != nil { + return Fail(t, fmt.Sprintf("Input ('%s') needs to be valid yaml.\nYAML error: '%s'", actual, err.Error()), msgAndArgs...) + } + + return Equal(t, expectedYAMLAsInterface, actualYAMLAsInterface, msgAndArgs...) +} + +func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) { + t := reflect.TypeOf(v) + k := t.Kind() + + if k == reflect.Ptr { + t = t.Elem() + k = t.Kind() + } + return t, k +} + +// diff returns a diff of both values as long as both are of the same type and +// are a struct, map, slice, array or string. Otherwise it returns an empty string. +func diff(expected interface{}, actual interface{}) string { + if expected == nil || actual == nil { + return "" + } + + et, ek := typeAndKind(expected) + at, _ := typeAndKind(actual) + + if et != at { + return "" + } + + if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array && ek != reflect.String { + return "" + } + + var e, a string + + switch et { + case reflect.TypeOf(""): + e = reflect.ValueOf(expected).String() + a = reflect.ValueOf(actual).String() + case reflect.TypeOf(time.Time{}): + e = spewConfigStringerEnabled.Sdump(expected) + a = spewConfigStringerEnabled.Sdump(actual) + default: + e = spewConfig.Sdump(expected) + a = spewConfig.Sdump(actual) + } + + diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ + A: difflib.SplitLines(e), + B: difflib.SplitLines(a), + FromFile: "Expected", + FromDate: "", + ToFile: "Actual", + ToDate: "", + Context: 1, + }) + + return "\n\nDiff:\n" + diff +} + +func isFunction(arg interface{}) bool { + if arg == nil { + return false + } + return reflect.TypeOf(arg).Kind() == reflect.Func +} + +var spewConfig = spew.ConfigState{ + Indent: " ", + DisablePointerAddresses: true, + DisableCapacities: true, + SortKeys: true, + DisableMethods: true, + MaxDepth: 10, +} + +var spewConfigStringerEnabled = spew.ConfigState{ + Indent: " ", + DisablePointerAddresses: true, + DisableCapacities: true, + SortKeys: true, + MaxDepth: 10, +} + +type tHelper = interface { + Helper() +} + +// Eventually asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. +// +// assert.Eventually(t, func() bool { return true; }, time.Second, 10*time.Millisecond) +func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + ch := make(chan bool, 1) + + timer := time.NewTimer(waitFor) + defer timer.Stop() + + ticker := time.NewTicker(tick) + defer ticker.Stop() + + for tick := ticker.C; ; { + select { + case <-timer.C: + return Fail(t, "Condition never satisfied", msgAndArgs...) + case <-tick: + tick = nil + go func() { ch <- condition() }() + case v := <-ch: + if v { + return true + } + tick = ticker.C + } + } +} + +// CollectT implements the TestingT interface and collects all errors. +type CollectT struct { + // A slice of errors. Non-nil slice denotes a failure. + // If it's non-nil but len(c.errors) == 0, this is also a failure + // obtained by direct c.FailNow() call. + errors []error +} + +// Errorf collects the error. +func (c *CollectT) Errorf(format string, args ...interface{}) { + c.errors = append(c.errors, fmt.Errorf(format, args...)) +} + +// FailNow stops execution by calling runtime.Goexit. +func (c *CollectT) FailNow() { + c.fail() + runtime.Goexit() +} + +// Deprecated: That was a method for internal usage that should not have been published. Now just panics. +func (*CollectT) Reset() { + panic("Reset() is deprecated") +} + +// Deprecated: That was a method for internal usage that should not have been published. Now just panics. +func (*CollectT) Copy(TestingT) { + panic("Copy() is deprecated") +} + +func (c *CollectT) fail() { + if !c.failed() { + c.errors = []error{} // Make it non-nil to mark a failure. + } +} + +func (c *CollectT) failed() bool { + return c.errors != nil +} + +// EventuallyWithT asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. In contrast to Eventually, +// it supplies a CollectT to the condition function, so that the condition +// function can use the CollectT to call other assertions. +// The condition is considered "met" if no errors are raised in a tick. +// The supplied CollectT collects all errors from one tick (if there are any). +// If the condition is not met before waitFor, the collected errors of +// the last tick are copied to t. +// +// externalValue := false +// go func() { +// time.Sleep(8*time.Second) +// externalValue = true +// }() +// assert.EventuallyWithT(t, func(c *assert.CollectT) { +// // add assertions as needed; any assertion failure will fail the current tick +// assert.True(c, externalValue, "expected 'externalValue' to be true") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") +func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + var lastFinishedTickErrs []error + ch := make(chan *CollectT, 1) + + timer := time.NewTimer(waitFor) + defer timer.Stop() + + ticker := time.NewTicker(tick) + defer ticker.Stop() + + for tick := ticker.C; ; { + select { + case <-timer.C: + for _, err := range lastFinishedTickErrs { + t.Errorf("%v", err) + } + return Fail(t, "Condition never satisfied", msgAndArgs...) + case <-tick: + tick = nil + go func() { + collect := new(CollectT) + defer func() { + ch <- collect + }() + condition(collect) + }() + case collect := <-ch: + if !collect.failed() { + return true + } + // Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached. + lastFinishedTickErrs = collect.errors + tick = ticker.C + } + } +} + +// Never asserts that the given condition doesn't satisfy in waitFor time, +// periodically checking the target function each tick. +// +// assert.Never(t, func() bool { return false; }, time.Second, 10*time.Millisecond) +func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + ch := make(chan bool, 1) + + timer := time.NewTimer(waitFor) + defer timer.Stop() + + ticker := time.NewTicker(tick) + defer ticker.Stop() + + for tick := ticker.C; ; { + select { + case <-timer.C: + return true + case <-tick: + tick = nil + go func() { ch <- condition() }() + case v := <-ch: + if v { + return Fail(t, "Condition satisfied", msgAndArgs...) + } + tick = ticker.C + } + } +} + +// ErrorIs asserts that at least one of the errors in err's chain matches target. +// This is a wrapper for errors.Is. +func ErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if errors.Is(err, target) { + return true + } + + var expectedText string + if target != nil { + expectedText = target.Error() + } + + chain := buildErrorChainString(err) + + return Fail(t, fmt.Sprintf("Target error should be in err chain:\n"+ + "expected: %q\n"+ + "in chain: %s", expectedText, chain, + ), msgAndArgs...) +} + +// NotErrorIs asserts that none of the errors in err's chain matches target. +// This is a wrapper for errors.Is. +func NotErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if !errors.Is(err, target) { + return true + } + + var expectedText string + if target != nil { + expectedText = target.Error() + } + + chain := buildErrorChainString(err) + + return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+ + "found: %q\n"+ + "in chain: %s", expectedText, chain, + ), msgAndArgs...) +} + +// ErrorAs asserts that at least one of the errors in err's chain matches target, and if so, sets target to that error value. +// This is a wrapper for errors.As. +func ErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if errors.As(err, target) { + return true + } + + chain := buildErrorChainString(err) + + return Fail(t, fmt.Sprintf("Should be in error chain:\n"+ + "expected: %q\n"+ + "in chain: %s", target, chain, + ), msgAndArgs...) +} + +// NotErrorAs asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func NotErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if !errors.As(err, target) { + return true + } + + chain := buildErrorChainString(err) + + return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+ + "found: %q\n"+ + "in chain: %s", target, chain, + ), msgAndArgs...) +} + +func buildErrorChainString(err error) string { + if err == nil { + return "" + } + + e := errors.Unwrap(err) + chain := fmt.Sprintf("%q", err.Error()) + for e != nil { + chain += fmt.Sprintf("\n\t%q", e.Error()) + e = errors.Unwrap(e) + } + return chain +} diff --git a/vendor/github.com/stretchr/testify/assert/doc.go b/vendor/github.com/stretchr/testify/assert/doc.go new file mode 100644 index 0000000..4953981 --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/doc.go @@ -0,0 +1,46 @@ +// Package assert provides a set of comprehensive testing tools for use with the normal Go testing system. +// +// # Example Usage +// +// The following is a complete example using assert in a standard test function: +// +// import ( +// "testing" +// "github.com/stretchr/testify/assert" +// ) +// +// func TestSomething(t *testing.T) { +// +// var a string = "Hello" +// var b string = "Hello" +// +// assert.Equal(t, a, b, "The two words should be the same.") +// +// } +// +// if you assert many times, use the format below: +// +// import ( +// "testing" +// "github.com/stretchr/testify/assert" +// ) +// +// func TestSomething(t *testing.T) { +// assert := assert.New(t) +// +// var a string = "Hello" +// var b string = "Hello" +// +// assert.Equal(a, b, "The two words should be the same.") +// } +// +// # Assertions +// +// Assertions allow you to easily write test code, and are global funcs in the `assert` package. +// All assertion functions take, as the first argument, the `*testing.T` object provided by the +// testing framework. This allows the assertion funcs to write the failings and other details to +// the correct place. +// +// Every assertion function also takes an optional string message as the final argument, +// allowing custom error messages to be appended to the message the assertion method outputs. +package assert diff --git a/vendor/github.com/stretchr/testify/assert/errors.go b/vendor/github.com/stretchr/testify/assert/errors.go new file mode 100644 index 0000000..ac9dc9d --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/errors.go @@ -0,0 +1,10 @@ +package assert + +import ( + "errors" +) + +// AnError is an error instance useful for testing. If the code does not care +// about error specifics, and only needs to return the error for example, this +// error should be used to make the test code more readable. +var AnError = errors.New("assert.AnError general error for testing") diff --git a/vendor/github.com/stretchr/testify/assert/forward_assertions.go b/vendor/github.com/stretchr/testify/assert/forward_assertions.go new file mode 100644 index 0000000..df189d2 --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/forward_assertions.go @@ -0,0 +1,16 @@ +package assert + +// Assertions provides assertion methods around the +// TestingT interface. +type Assertions struct { + t TestingT +} + +// New makes a new Assertions object for the specified TestingT. +func New(t TestingT) *Assertions { + return &Assertions{ + t: t, + } +} + +//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_forward.go.tmpl -include-format-funcs" diff --git a/vendor/github.com/stretchr/testify/assert/http_assertions.go b/vendor/github.com/stretchr/testify/assert/http_assertions.go new file mode 100644 index 0000000..861ed4b --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/http_assertions.go @@ -0,0 +1,165 @@ +package assert + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" +) + +// httpCode is a helper that returns HTTP code of the response. It returns -1 and +// an error if building a new request fails. +func httpCode(handler http.HandlerFunc, method, url string, values url.Values) (int, error) { + w := httptest.NewRecorder() + req, err := http.NewRequest(method, url, http.NoBody) + if err != nil { + return -1, err + } + req.URL.RawQuery = values.Encode() + handler(w, req) + return w.Code, nil +} + +// HTTPSuccess asserts that a specified handler returns a success status code. +// +// assert.HTTPSuccess(t, myHandler, "POST", "http://www.google.com", nil) +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + code, err := httpCode(handler, method, url, values) + if err != nil { + Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...) + } + + isSuccessCode := code >= http.StatusOK && code <= http.StatusPartialContent + if !isSuccessCode { + Fail(t, fmt.Sprintf("Expected HTTP success status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...) + } + + return isSuccessCode +} + +// HTTPRedirect asserts that a specified handler returns a redirect status code. +// +// assert.HTTPRedirect(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + code, err := httpCode(handler, method, url, values) + if err != nil { + Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...) + } + + isRedirectCode := code >= http.StatusMultipleChoices && code <= http.StatusTemporaryRedirect + if !isRedirectCode { + Fail(t, fmt.Sprintf("Expected HTTP redirect status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...) + } + + return isRedirectCode +} + +// HTTPError asserts that a specified handler returns an error status code. +// +// assert.HTTPError(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + code, err := httpCode(handler, method, url, values) + if err != nil { + Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...) + } + + isErrorCode := code >= http.StatusBadRequest + if !isErrorCode { + Fail(t, fmt.Sprintf("Expected HTTP error status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...) + } + + return isErrorCode +} + +// HTTPStatusCode asserts that a specified handler returns a specified status code. +// +// assert.HTTPStatusCode(t, myHandler, "GET", "/notImplemented", nil, 501) +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, statuscode int, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + code, err := httpCode(handler, method, url, values) + if err != nil { + Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...) + } + + successful := code == statuscode + if !successful { + Fail(t, fmt.Sprintf("Expected HTTP status code %d for %q but received %d", statuscode, url+"?"+values.Encode(), code), msgAndArgs...) + } + + return successful +} + +// HTTPBody is a helper that returns HTTP body of the response. It returns +// empty string if building a new request fails. +func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) string { + w := httptest.NewRecorder() + if len(values) > 0 { + url += "?" + values.Encode() + } + req, err := http.NewRequest(method, url, http.NoBody) + if err != nil { + return "" + } + handler(w, req) + return w.Body.String() +} + +// HTTPBodyContains asserts that a specified handler returns a +// body that contains a string. +// +// assert.HTTPBodyContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + body := HTTPBody(handler, method, url, values) + + contains := strings.Contains(body, fmt.Sprint(str)) + if !contains { + Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...) + } + + return contains +} + +// HTTPBodyNotContains asserts that a specified handler returns a +// body that does not contain a string. +// +// assert.HTTPBodyNotContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + body := HTTPBody(handler, method, url, values) + + contains := strings.Contains(body, fmt.Sprint(str)) + if contains { + Fail(t, fmt.Sprintf("Expected response body for \"%s\" to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...) + } + + return !contains +} diff --git a/vendor/github.com/stretchr/testify/assert/yaml/yaml_custom.go b/vendor/github.com/stretchr/testify/assert/yaml/yaml_custom.go new file mode 100644 index 0000000..baa0cc7 --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/yaml/yaml_custom.go @@ -0,0 +1,25 @@ +//go:build testify_yaml_custom && !testify_yaml_fail && !testify_yaml_default +// +build testify_yaml_custom,!testify_yaml_fail,!testify_yaml_default + +// Package yaml is an implementation of YAML functions that calls a pluggable implementation. +// +// This implementation is selected with the testify_yaml_custom build tag. +// +// go test -tags testify_yaml_custom +// +// This implementation can be used at build time to replace the default implementation +// to avoid linking with [gopkg.in/yaml.v3]. +// +// In your test package: +// +// import assertYaml "github.com/stretchr/testify/assert/yaml" +// +// func init() { +// assertYaml.Unmarshal = func (in []byte, out interface{}) error { +// // ... +// return nil +// } +// } +package yaml + +var Unmarshal func(in []byte, out interface{}) error diff --git a/vendor/github.com/stretchr/testify/assert/yaml/yaml_default.go b/vendor/github.com/stretchr/testify/assert/yaml/yaml_default.go new file mode 100644 index 0000000..b83c6cf --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/yaml/yaml_default.go @@ -0,0 +1,37 @@ +//go:build !testify_yaml_fail && !testify_yaml_custom +// +build !testify_yaml_fail,!testify_yaml_custom + +// Package yaml is just an indirection to handle YAML deserialization. +// +// This package is just an indirection that allows the builder to override the +// indirection with an alternative implementation of this package that uses +// another implementation of YAML deserialization. This allows to not either not +// use YAML deserialization at all, or to use another implementation than +// [gopkg.in/yaml.v3] (for example for license compatibility reasons, see [PR #1120]). +// +// Alternative implementations are selected using build tags: +// +// - testify_yaml_fail: [Unmarshal] always fails with an error +// - testify_yaml_custom: [Unmarshal] is a variable. Caller must initialize it +// before calling any of [github.com/stretchr/testify/assert.YAMLEq] or +// [github.com/stretchr/testify/assert.YAMLEqf]. +// +// Usage: +// +// go test -tags testify_yaml_fail +// +// You can check with "go list" which implementation is linked: +// +// go list -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml +// go list -tags testify_yaml_fail -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml +// go list -tags testify_yaml_custom -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml +// +// [PR #1120]: https://github.com/stretchr/testify/pull/1120 +package yaml + +import goyaml "gopkg.in/yaml.v3" + +// Unmarshal is just a wrapper of [gopkg.in/yaml.v3.Unmarshal]. +func Unmarshal(in []byte, out interface{}) error { + return goyaml.Unmarshal(in, out) +} diff --git a/vendor/github.com/stretchr/testify/assert/yaml/yaml_fail.go b/vendor/github.com/stretchr/testify/assert/yaml/yaml_fail.go new file mode 100644 index 0000000..e78f7df --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/yaml/yaml_fail.go @@ -0,0 +1,18 @@ +//go:build testify_yaml_fail && !testify_yaml_custom && !testify_yaml_default +// +build testify_yaml_fail,!testify_yaml_custom,!testify_yaml_default + +// Package yaml is an implementation of YAML functions that always fail. +// +// This implementation can be used at build time to replace the default implementation +// to avoid linking with [gopkg.in/yaml.v3]: +// +// go test -tags testify_yaml_fail +package yaml + +import "errors" + +var errNotImplemented = errors.New("YAML functions are not available (see https://pkg.go.dev/github.com/stretchr/testify/assert/yaml)") + +func Unmarshal([]byte, interface{}) error { + return errNotImplemented +} diff --git a/vendor/github.com/stretchr/testify/mock/doc.go b/vendor/github.com/stretchr/testify/mock/doc.go new file mode 100644 index 0000000..d6b3c84 --- /dev/null +++ b/vendor/github.com/stretchr/testify/mock/doc.go @@ -0,0 +1,44 @@ +// Package mock provides a system by which it is possible to mock your objects +// and verify calls are happening as expected. +// +// # Example Usage +// +// The mock package provides an object, Mock, that tracks activity on another object. It is usually +// embedded into a test object as shown below: +// +// type MyTestObject struct { +// // add a Mock object instance +// mock.Mock +// +// // other fields go here as normal +// } +// +// When implementing the methods of an interface, you wire your functions up +// to call the Mock.Called(args...) method, and return the appropriate values. +// +// For example, to mock a method that saves the name and age of a person and returns +// the year of their birth or an error, you might write this: +// +// func (o *MyTestObject) SavePersonDetails(firstname, lastname string, age int) (int, error) { +// args := o.Called(firstname, lastname, age) +// return args.Int(0), args.Error(1) +// } +// +// The Int, Error and Bool methods are examples of strongly typed getters that take the argument +// index position. Given this argument list: +// +// (12, true, "Something") +// +// You could read them out strongly typed like this: +// +// args.Int(0) +// args.Bool(1) +// args.String(2) +// +// For objects of your own type, use the generic Arguments.Get(index) method and make a type assertion: +// +// return args.Get(0).(*MyObject), args.Get(1).(*AnotherObjectOfMine) +// +// This may cause a panic if the object you are getting is nil (the type assertion will fail), in those +// cases you should check for nil first. +package mock diff --git a/vendor/github.com/stretchr/testify/mock/mock.go b/vendor/github.com/stretchr/testify/mock/mock.go new file mode 100644 index 0000000..eb5682d --- /dev/null +++ b/vendor/github.com/stretchr/testify/mock/mock.go @@ -0,0 +1,1288 @@ +package mock + +import ( + "errors" + "fmt" + "path" + "reflect" + "regexp" + "runtime" + "strings" + "sync" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/pmezard/go-difflib/difflib" + "github.com/stretchr/objx" + + "github.com/stretchr/testify/assert" +) + +// regex for GCCGO functions +var gccgoRE = regexp.MustCompile(`\.pN\d+_`) + +// TestingT is an interface wrapper around *testing.T +type TestingT interface { + Logf(format string, args ...interface{}) + Errorf(format string, args ...interface{}) + FailNow() +} + +/* + Call +*/ + +// Call represents a method call and is used for setting expectations, +// as well as recording activity. +type Call struct { + Parent *Mock + + // The name of the method that was or will be called. + Method string + + // Holds the arguments of the method. + Arguments Arguments + + // Holds the arguments that should be returned when + // this method is called. + ReturnArguments Arguments + + // Holds the caller info for the On() call + callerInfo []string + + // The number of times to return the return arguments when setting + // expectations. 0 means to always return the value. + Repeatability int + + // Amount of times this call has been called + totalCalls int + + // Call to this method can be optional + optional bool + + // Holds a channel that will be used to block the Return until it either + // receives a message or is closed. nil means it returns immediately. + WaitFor <-chan time.Time + + waitTime time.Duration + + // Holds a handler used to manipulate arguments content that are passed by + // reference. It's useful when mocking methods such as unmarshalers or + // decoders. + RunFn func(Arguments) + + // PanicMsg holds msg to be used to mock panic on the function call + // if the PanicMsg is set to a non nil string the function call will panic + // irrespective of other settings + PanicMsg *string + + // Calls which must be satisfied before this call can be + requires []*Call +} + +func newCall(parent *Mock, methodName string, callerInfo []string, methodArguments Arguments, returnArguments Arguments) *Call { + return &Call{ + Parent: parent, + Method: methodName, + Arguments: methodArguments, + ReturnArguments: returnArguments, + callerInfo: callerInfo, + Repeatability: 0, + WaitFor: nil, + RunFn: nil, + PanicMsg: nil, + } +} + +func (c *Call) lock() { + c.Parent.mutex.Lock() +} + +func (c *Call) unlock() { + c.Parent.mutex.Unlock() +} + +// Return specifies the return arguments for the expectation. +// +// Mock.On("DoSomething").Return(errors.New("failed")) +func (c *Call) Return(returnArguments ...interface{}) *Call { + c.lock() + defer c.unlock() + + c.ReturnArguments = returnArguments + + return c +} + +// Panic specifies if the function call should fail and the panic message +// +// Mock.On("DoSomething").Panic("test panic") +func (c *Call) Panic(msg string) *Call { + c.lock() + defer c.unlock() + + c.PanicMsg = &msg + + return c +} + +// Once indicates that the mock should only return the value once. +// +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once() +func (c *Call) Once() *Call { + return c.Times(1) +} + +// Twice indicates that the mock should only return the value twice. +// +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice() +func (c *Call) Twice() *Call { + return c.Times(2) +} + +// Times indicates that the mock should only return the indicated number +// of times. +// +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5) +func (c *Call) Times(i int) *Call { + c.lock() + defer c.unlock() + c.Repeatability = i + return c +} + +// WaitUntil sets the channel that will block the mock's return until its closed +// or a message is received. +// +// Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second)) +func (c *Call) WaitUntil(w <-chan time.Time) *Call { + c.lock() + defer c.unlock() + c.WaitFor = w + return c +} + +// After sets how long to block until the call returns +// +// Mock.On("MyMethod", arg1, arg2).After(time.Second) +func (c *Call) After(d time.Duration) *Call { + c.lock() + defer c.unlock() + c.waitTime = d + return c +} + +// Run sets a handler to be called before returning. It can be used when +// mocking a method (such as an unmarshaler) that takes a pointer to a struct and +// sets properties in such struct +// +// Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}")).Return().Run(func(args Arguments) { +// arg := args.Get(0).(*map[string]interface{}) +// arg["foo"] = "bar" +// }) +func (c *Call) Run(fn func(args Arguments)) *Call { + c.lock() + defer c.unlock() + c.RunFn = fn + return c +} + +// Maybe allows the method call to be optional. Not calling an optional method +// will not cause an error while asserting expectations +func (c *Call) Maybe() *Call { + c.lock() + defer c.unlock() + c.optional = true + return c +} + +// On chains a new expectation description onto the mocked interface. This +// allows syntax like. +// +// Mock. +// On("MyMethod", 1).Return(nil). +// On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error")) +// +//go:noinline +func (c *Call) On(methodName string, arguments ...interface{}) *Call { + return c.Parent.On(methodName, arguments...) +} + +// Unset removes a mock handler from being called. +// +// test.On("func", mock.Anything).Unset() +func (c *Call) Unset() *Call { + var unlockOnce sync.Once + + for _, arg := range c.Arguments { + if v := reflect.ValueOf(arg); v.Kind() == reflect.Func { + panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg)) + } + } + + c.lock() + defer unlockOnce.Do(c.unlock) + + foundMatchingCall := false + + // in-place filter slice for calls to be removed - iterate from 0'th to last skipping unnecessary ones + var index int // write index + for _, call := range c.Parent.ExpectedCalls { + if call.Method == c.Method { + _, diffCount := call.Arguments.Diff(c.Arguments) + if diffCount == 0 { + foundMatchingCall = true + // Remove from ExpectedCalls - just skip it + continue + } + } + c.Parent.ExpectedCalls[index] = call + index++ + } + // trim slice up to last copied index + c.Parent.ExpectedCalls = c.Parent.ExpectedCalls[:index] + + if !foundMatchingCall { + unlockOnce.Do(c.unlock) + c.Parent.fail("\n\nmock: Could not find expected call\n-----------------------------\n\n%s\n\n", + callString(c.Method, c.Arguments, true), + ) + } + + return c +} + +// NotBefore indicates that the mock should only be called after the referenced +// calls have been called as expected. The referenced calls may be from the +// same mock instance and/or other mock instances. +// +// Mock.On("Do").Return(nil).NotBefore( +// Mock.On("Init").Return(nil) +// ) +func (c *Call) NotBefore(calls ...*Call) *Call { + c.lock() + defer c.unlock() + + for _, call := range calls { + if call.Parent == nil { + panic("not before calls must be created with Mock.On()") + } + } + + c.requires = append(c.requires, calls...) + return c +} + +// InOrder defines the order in which the calls should be made +// +// For example: +// +// InOrder( +// Mock.On("init").Return(nil), +// Mock.On("Do").Return(nil), +// ) +func InOrder(calls ...*Call) { + for i := 1; i < len(calls); i++ { + calls[i].NotBefore(calls[i-1]) + } +} + +// Mock is the workhorse used to track activity on another object. +// For an example of its usage, refer to the "Example Usage" section at the top +// of this document. +type Mock struct { + // Represents the calls that are expected of + // an object. + ExpectedCalls []*Call + + // Holds the calls that were made to this mocked object. + Calls []Call + + // test is An optional variable that holds the test struct, to be used when an + // invalid mock call was made. + test TestingT + + // TestData holds any data that might be useful for testing. Testify ignores + // this data completely allowing you to do whatever you like with it. + testData objx.Map + + mutex sync.Mutex +} + +// String provides a %v format string for Mock. +// Note: this is used implicitly by Arguments.Diff if a Mock is passed. +// It exists because go's default %v formatting traverses the struct +// without acquiring the mutex, which is detected by go test -race. +func (m *Mock) String() string { + return fmt.Sprintf("%[1]T<%[1]p>", m) +} + +// TestData holds any data that might be useful for testing. Testify ignores +// this data completely allowing you to do whatever you like with it. +func (m *Mock) TestData() objx.Map { + if m.testData == nil { + m.testData = make(objx.Map) + } + + return m.testData +} + +/* + Setting expectations +*/ + +// Test sets the test struct variable of the mock object +func (m *Mock) Test(t TestingT) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.test = t +} + +// fail fails the current test with the given formatted format and args. +// In case that a test was defined, it uses the test APIs for failing a test, +// otherwise it uses panic. +func (m *Mock) fail(format string, args ...interface{}) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.test == nil { + panic(fmt.Sprintf(format, args...)) + } + m.test.Errorf(format, args...) + m.test.FailNow() +} + +// On starts a description of an expectation of the specified method +// being called. +// +// Mock.On("MyMethod", arg1, arg2) +func (m *Mock) On(methodName string, arguments ...interface{}) *Call { + for _, arg := range arguments { + if v := reflect.ValueOf(arg); v.Kind() == reflect.Func { + panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg)) + } + } + + m.mutex.Lock() + defer m.mutex.Unlock() + + c := newCall(m, methodName, assert.CallerInfo(), arguments, make([]interface{}, 0)) + m.ExpectedCalls = append(m.ExpectedCalls, c) + return c +} + +// /* +// Recording and responding to activity +// */ + +func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) { + var expectedCall *Call + + for i, call := range m.ExpectedCalls { + if call.Method == method { + _, diffCount := call.Arguments.Diff(arguments) + if diffCount == 0 { + expectedCall = call + if call.Repeatability > -1 { + return i, call + } + } + } + } + + return -1, expectedCall +} + +type matchCandidate struct { + call *Call + mismatch string + diffCount int +} + +func (c matchCandidate) isBetterMatchThan(other matchCandidate) bool { + if c.call == nil { + return false + } + if other.call == nil { + return true + } + + if c.diffCount > other.diffCount { + return false + } + if c.diffCount < other.diffCount { + return true + } + + if c.call.Repeatability > 0 && other.call.Repeatability <= 0 { + return true + } + return false +} + +func (m *Mock) findClosestCall(method string, arguments ...interface{}) (*Call, string) { + var bestMatch matchCandidate + + for _, call := range m.expectedCalls() { + if call.Method == method { + + errInfo, tempDiffCount := call.Arguments.Diff(arguments) + tempCandidate := matchCandidate{ + call: call, + mismatch: errInfo, + diffCount: tempDiffCount, + } + if tempCandidate.isBetterMatchThan(bestMatch) { + bestMatch = tempCandidate + } + } + } + + return bestMatch.call, bestMatch.mismatch +} + +func callString(method string, arguments Arguments, includeArgumentValues bool) string { + var argValsString string + if includeArgumentValues { + var argVals []string + for argIndex, arg := range arguments { + if _, ok := arg.(*FunctionalOptionsArgument); ok { + argVals = append(argVals, fmt.Sprintf("%d: %s", argIndex, arg)) + continue + } + argVals = append(argVals, fmt.Sprintf("%d: %#v", argIndex, arg)) + } + argValsString = fmt.Sprintf("\n\t\t%s", strings.Join(argVals, "\n\t\t")) + } + + return fmt.Sprintf("%s(%s)%s", method, arguments.String(), argValsString) +} + +// Called tells the mock object that a method has been called, and gets an array +// of arguments to return. Panics if the call is unexpected (i.e. not preceded by +// appropriate .On .Return() calls) +// If Call.WaitFor is set, blocks until the channel is closed or receives a message. +func (m *Mock) Called(arguments ...interface{}) Arguments { + // get the calling function's name + pc, _, _, ok := runtime.Caller(1) + if !ok { + panic("Couldn't get the caller information") + } + functionPath := runtime.FuncForPC(pc).Name() + // Next four lines are required to use GCCGO function naming conventions. + // For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock + // uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree + // With GCCGO we need to remove interface information starting from pN

. + if gccgoRE.MatchString(functionPath) { + functionPath = gccgoRE.Split(functionPath, -1)[0] + } + parts := strings.Split(functionPath, ".") + functionName := parts[len(parts)-1] + return m.MethodCalled(functionName, arguments...) +} + +// MethodCalled tells the mock object that the given method has been called, and gets +// an array of arguments to return. Panics if the call is unexpected (i.e. not preceded +// by appropriate .On .Return() calls) +// If Call.WaitFor is set, blocks until the channel is closed or receives a message. +func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Arguments { + m.mutex.Lock() + // TODO: could combine expected and closes in single loop + found, call := m.findExpectedCall(methodName, arguments...) + + if found < 0 { + // expected call found, but it has already been called with repeatable times + if call != nil { + m.mutex.Unlock() + m.fail("\nassert: mock: The method has been called over %d times.\n\tEither do one more Mock.On(\"%s\").Return(...), or remove extra call.\n\tThis call was unexpected:\n\t\t%s\n\tat: %s", call.totalCalls, methodName, callString(methodName, arguments, true), assert.CallerInfo()) + } + // we have to fail here - because we don't know what to do + // as the return arguments. This is because: + // + // a) this is a totally unexpected call to this method, + // b) the arguments are not what was expected, or + // c) the developer has forgotten to add an accompanying On...Return pair. + closestCall, mismatch := m.findClosestCall(methodName, arguments...) + m.mutex.Unlock() + + if closestCall != nil { + m.fail("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n\n%s\nDiff: %s\nat: %s\n", + callString(methodName, arguments, true), + callString(methodName, closestCall.Arguments, true), + diffArguments(closestCall.Arguments, arguments), + strings.TrimSpace(mismatch), + assert.CallerInfo(), + ) + } else { + m.fail("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", methodName, methodName, callString(methodName, arguments, true), assert.CallerInfo()) + } + } + + for _, requirement := range call.requires { + if satisfied, _ := requirement.Parent.checkExpectation(requirement); !satisfied { + m.mutex.Unlock() + m.fail("mock: Unexpected Method Call\n-----------------------------\n\n%s\n\nMust not be called before%s:\n\n%s", + callString(call.Method, call.Arguments, true), + func() (s string) { + if requirement.totalCalls > 0 { + s = " another call of" + } + if call.Parent != requirement.Parent { + s += " method from another mock instance" + } + return + }(), + callString(requirement.Method, requirement.Arguments, true), + ) + } + } + + if call.Repeatability == 1 { + call.Repeatability = -1 + } else if call.Repeatability > 1 { + call.Repeatability-- + } + call.totalCalls++ + + // add the call + m.Calls = append(m.Calls, *newCall(m, methodName, assert.CallerInfo(), arguments, call.ReturnArguments)) + m.mutex.Unlock() + + // block if specified + if call.WaitFor != nil { + <-call.WaitFor + } else { + time.Sleep(call.waitTime) + } + + m.mutex.Lock() + panicMsg := call.PanicMsg + m.mutex.Unlock() + if panicMsg != nil { + panic(*panicMsg) + } + + m.mutex.Lock() + runFn := call.RunFn + m.mutex.Unlock() + + if runFn != nil { + runFn(arguments) + } + + m.mutex.Lock() + returnArgs := call.ReturnArguments + m.mutex.Unlock() + + return returnArgs +} + +/* + Assertions +*/ + +type assertExpectationiser interface { + AssertExpectations(TestingT) bool +} + +// AssertExpectationsForObjects asserts that everything specified with On and Return +// of the specified objects was in fact called as expected. +// +// Calls may have occurred in any order. +func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + for _, obj := range testObjects { + if m, ok := obj.(*Mock); ok { + t.Logf("Deprecated mock.AssertExpectationsForObjects(myMock.Mock) use mock.AssertExpectationsForObjects(myMock)") + obj = m + } + m := obj.(assertExpectationiser) + if !m.AssertExpectations(t) { + t.Logf("Expectations didn't match for Mock: %+v", reflect.TypeOf(m)) + return false + } + } + return true +} + +// AssertExpectations asserts that everything specified with On and Return was +// in fact called as expected. Calls may have occurred in any order. +func (m *Mock) AssertExpectations(t TestingT) bool { + if s, ok := t.(interface{ Skipped() bool }); ok && s.Skipped() { + return true + } + if h, ok := t.(tHelper); ok { + h.Helper() + } + + m.mutex.Lock() + defer m.mutex.Unlock() + var failedExpectations int + + // iterate through each expectation + expectedCalls := m.expectedCalls() + for _, expectedCall := range expectedCalls { + satisfied, reason := m.checkExpectation(expectedCall) + if !satisfied { + failedExpectations++ + t.Logf(reason) + } + } + + if failedExpectations != 0 { + t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo()) + } + + return failedExpectations == 0 +} + +func (m *Mock) checkExpectation(call *Call) (bool, string) { + if !call.optional && !m.methodWasCalled(call.Method, call.Arguments) && call.totalCalls == 0 { + return false, fmt.Sprintf("FAIL:\t%s(%s)\n\t\tat: %s", call.Method, call.Arguments.String(), call.callerInfo) + } + if call.Repeatability > 0 { + return false, fmt.Sprintf("FAIL:\t%s(%s)\n\t\tat: %s", call.Method, call.Arguments.String(), call.callerInfo) + } + return true, fmt.Sprintf("PASS:\t%s(%s)", call.Method, call.Arguments.String()) +} + +// AssertNumberOfCalls asserts that the method was called expectedCalls times. +func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + m.mutex.Lock() + defer m.mutex.Unlock() + var actualCalls int + for _, call := range m.calls() { + if call.Method == methodName { + actualCalls++ + } + } + return assert.Equal(t, expectedCalls, actualCalls, fmt.Sprintf("Expected number of calls (%d) does not match the actual number of calls (%d).", expectedCalls, actualCalls)) +} + +// AssertCalled asserts that the method was called. +// It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method. +func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + m.mutex.Lock() + defer m.mutex.Unlock() + if !m.methodWasCalled(methodName, arguments) { + var calledWithArgs []string + for _, call := range m.calls() { + calledWithArgs = append(calledWithArgs, fmt.Sprintf("%v", call.Arguments)) + } + if len(calledWithArgs) == 0 { + return assert.Fail(t, "Should have called with given arguments", + fmt.Sprintf("Expected %q to have been called with:\n%v\nbut no actual calls happened", methodName, arguments)) + } + return assert.Fail(t, "Should have called with given arguments", + fmt.Sprintf("Expected %q to have been called with:\n%v\nbut actual calls were:\n %v", methodName, arguments, strings.Join(calledWithArgs, "\n"))) + } + return true +} + +// AssertNotCalled asserts that the method was not called. +// It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method. +func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + m.mutex.Lock() + defer m.mutex.Unlock() + if m.methodWasCalled(methodName, arguments) { + return assert.Fail(t, "Should not have called with given arguments", + fmt.Sprintf("Expected %q to not have been called with:\n%v\nbut actually it was.", methodName, arguments)) + } + return true +} + +// IsMethodCallable checking that the method can be called +// If the method was called more than `Repeatability` return false +func (m *Mock) IsMethodCallable(t TestingT, methodName string, arguments ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + m.mutex.Lock() + defer m.mutex.Unlock() + + for _, v := range m.ExpectedCalls { + if v.Method != methodName { + continue + } + if len(arguments) != len(v.Arguments) { + continue + } + if v.Repeatability < v.totalCalls { + continue + } + if isArgsEqual(v.Arguments, arguments) { + return true + } + } + return false +} + +// isArgsEqual compares arguments +func isArgsEqual(expected Arguments, args []interface{}) bool { + if len(expected) != len(args) { + return false + } + for i, v := range args { + if !reflect.DeepEqual(expected[i], v) { + return false + } + } + return true +} + +func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool { + for _, call := range m.calls() { + if call.Method == methodName { + + _, differences := Arguments(expected).Diff(call.Arguments) + + if differences == 0 { + // found the expected call + return true + } + + } + } + // we didn't find the expected call + return false +} + +func (m *Mock) expectedCalls() []*Call { + return append([]*Call{}, m.ExpectedCalls...) +} + +func (m *Mock) calls() []Call { + return append([]Call{}, m.Calls...) +} + +/* + Arguments +*/ + +// Arguments holds an array of method arguments or return values. +type Arguments []interface{} + +const ( + // Anything is used in Diff and Assert when the argument being tested + // shouldn't be taken into consideration. + Anything = "mock.Anything" +) + +// AnythingOfTypeArgument contains the type of an argument +// for use when type checking. Used in [Arguments.Diff] and [Arguments.Assert]. +// +// Deprecated: this is an implementation detail that must not be used. Use the [AnythingOfType] constructor instead, example: +// +// m.On("Do", mock.AnythingOfType("string")) +// +// All explicit type declarations can be replaced with interface{} as is expected by [Mock.On], example: +// +// func anyString interface{} { +// return mock.AnythingOfType("string") +// } +type AnythingOfTypeArgument = anythingOfTypeArgument + +// anythingOfTypeArgument is a string that contains the type of an argument +// for use when type checking. Used in Diff and Assert. +type anythingOfTypeArgument string + +// AnythingOfType returns a special value containing the +// name of the type to check for. The type name will be matched against the type name returned by [reflect.Type.String]. +// +// Used in Diff and Assert. +// +// For example: +// +// args.Assert(t, AnythingOfType("string"), AnythingOfType("int")) +func AnythingOfType(t string) AnythingOfTypeArgument { + return anythingOfTypeArgument(t) +} + +// IsTypeArgument is a struct that contains the type of an argument +// for use when type checking. This is an alternative to [AnythingOfType]. +// Used in [Arguments.Diff] and [Arguments.Assert]. +type IsTypeArgument struct { + t reflect.Type +} + +// IsType returns an IsTypeArgument object containing the type to check for. +// You can provide a zero-value of the type to check. This is an +// alternative to [AnythingOfType]. Used in [Arguments.Diff] and [Arguments.Assert]. +// +// For example: +// +// args.Assert(t, IsType(""), IsType(0)) +func IsType(t interface{}) *IsTypeArgument { + return &IsTypeArgument{t: reflect.TypeOf(t)} +} + +// FunctionalOptionsArgument contains a list of functional options arguments +// expected for use when matching a list of arguments. +type FunctionalOptionsArgument struct { + values []interface{} +} + +// String returns the string representation of FunctionalOptionsArgument +func (f *FunctionalOptionsArgument) String() string { + var name string + if len(f.values) > 0 { + name = "[]" + reflect.TypeOf(f.values[0]).String() + } + + return strings.Replace(fmt.Sprintf("%#v", f.values), "[]interface {}", name, 1) +} + +// FunctionalOptions returns an [FunctionalOptionsArgument] object containing +// the expected functional-options to check for. +// +// For example: +// +// args.Assert(t, FunctionalOptions(foo.Opt1("strValue"), foo.Opt2(613))) +func FunctionalOptions(values ...interface{}) *FunctionalOptionsArgument { + return &FunctionalOptionsArgument{ + values: values, + } +} + +// argumentMatcher performs custom argument matching, returning whether or +// not the argument is matched by the expectation fixture function. +type argumentMatcher struct { + // fn is a function which accepts one argument, and returns a bool. + fn reflect.Value +} + +func (f argumentMatcher) Matches(argument interface{}) bool { + expectType := f.fn.Type().In(0) + expectTypeNilSupported := false + switch expectType.Kind() { + case reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice, reflect.Ptr: + expectTypeNilSupported = true + } + + argType := reflect.TypeOf(argument) + var arg reflect.Value + if argType == nil { + arg = reflect.New(expectType).Elem() + } else { + arg = reflect.ValueOf(argument) + } + + if argType == nil && !expectTypeNilSupported { + panic(errors.New("attempting to call matcher with nil for non-nil expected type")) + } + if argType == nil || argType.AssignableTo(expectType) { + result := f.fn.Call([]reflect.Value{arg}) + return result[0].Bool() + } + return false +} + +func (f argumentMatcher) String() string { + return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).String()) +} + +// MatchedBy can be used to match a mock call based on only certain properties +// from a complex struct or some calculation. It takes a function that will be +// evaluated with the called argument and will return true when there's a match +// and false otherwise. +// +// Example: +// +// m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" })) +// +// fn must be a function accepting a single argument (of the expected type) +// which returns a bool. If fn doesn't match the required signature, +// MatchedBy() panics. +func MatchedBy(fn interface{}) argumentMatcher { + fnType := reflect.TypeOf(fn) + + if fnType.Kind() != reflect.Func { + panic(fmt.Sprintf("assert: arguments: %s is not a func", fn)) + } + if fnType.NumIn() != 1 { + panic(fmt.Sprintf("assert: arguments: %s does not take exactly one argument", fn)) + } + if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool { + panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn)) + } + + return argumentMatcher{fn: reflect.ValueOf(fn)} +} + +// Get Returns the argument at the specified index. +func (args Arguments) Get(index int) interface{} { + if index+1 > len(args) { + panic(fmt.Sprintf("assert: arguments: Cannot call Get(%d) because there are %d argument(s).", index, len(args))) + } + return args[index] +} + +// Is gets whether the objects match the arguments specified. +func (args Arguments) Is(objects ...interface{}) bool { + for i, obj := range args { + if obj != objects[i] { + return false + } + } + return true +} + +// Diff gets a string describing the differences between the arguments +// and the specified objects. +// +// Returns the diff string and number of differences found. +func (args Arguments) Diff(objects []interface{}) (string, int) { + // TODO: could return string as error and nil for No difference + + output := "\n" + var differences int + + maxArgCount := len(args) + if len(objects) > maxArgCount { + maxArgCount = len(objects) + } + + for i := 0; i < maxArgCount; i++ { + var actual, expected interface{} + var actualFmt, expectedFmt string + + if len(objects) <= i { + actual = "(Missing)" + actualFmt = "(Missing)" + } else { + actual = objects[i] + actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual) + } + + if len(args) <= i { + expected = "(Missing)" + expectedFmt = "(Missing)" + } else { + expected = args[i] + expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected) + } + + if matcher, ok := expected.(argumentMatcher); ok { + var matches bool + func() { + defer func() { + if r := recover(); r != nil { + actualFmt = fmt.Sprintf("panic in argument matcher: %v", r) + } + }() + matches = matcher.Matches(actual) + }() + if matches { + output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher) + } else { + differences++ + output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher) + } + } else { + switch expected := expected.(type) { + case anythingOfTypeArgument: + // type checking + if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) { + // not match + differences++ + output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt) + } + case *IsTypeArgument: + actualT := reflect.TypeOf(actual) + if actualT != expected.t { + differences++ + output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected.t.Name(), actualT.Name(), actualFmt) + } + case *FunctionalOptionsArgument: + var name string + if len(expected.values) > 0 { + name = "[]" + reflect.TypeOf(expected.values[0]).String() + } + + const tName = "[]interface{}" + if name != reflect.TypeOf(actual).String() && len(expected.values) != 0 { + differences++ + output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, tName, reflect.TypeOf(actual).Name(), actualFmt) + } else { + if ef, af := assertOpts(expected.values, actual); ef == "" && af == "" { + // match + output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, tName, tName) + } else { + // not match + differences++ + output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, af, ef) + } + } + + default: + if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) { + // match + output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt) + } else { + // not match + differences++ + output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt) + } + } + } + + } + + if differences == 0 { + return "No differences.", differences + } + + return output, differences +} + +// Assert compares the arguments with the specified objects and fails if +// they do not exactly match. +func (args Arguments) Assert(t TestingT, objects ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + // get the differences + diff, diffCount := args.Diff(objects) + + if diffCount == 0 { + return true + } + + // there are differences... report them... + t.Logf(diff) + t.Errorf("%sArguments do not match.", assert.CallerInfo()) + + return false +} + +// String gets the argument at the specified index. Panics if there is no argument, or +// if the argument is of the wrong type. +// +// If no index is provided, String() returns a complete string representation +// of the arguments. +func (args Arguments) String(indexOrNil ...int) string { + if len(indexOrNil) == 0 { + // normal String() method - return a string representation of the args + var argsStr []string + for _, arg := range args { + argsStr = append(argsStr, fmt.Sprintf("%T", arg)) // handles nil nicely + } + return strings.Join(argsStr, ",") + } else if len(indexOrNil) == 1 { + // Index has been specified - get the argument at that index + index := indexOrNil[0] + var s string + var ok bool + if s, ok = args.Get(index).(string); !ok { + panic(fmt.Sprintf("assert: arguments: String(%d) failed because object wasn't correct type: %s", index, args.Get(index))) + } + return s + } + + panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String. Must be 0 or 1, not %d", len(indexOrNil))) +} + +// Int gets the argument at the specified index. Panics if there is no argument, or +// if the argument is of the wrong type. +func (args Arguments) Int(index int) int { + var s int + var ok bool + if s, ok = args.Get(index).(int); !ok { + panic(fmt.Sprintf("assert: arguments: Int(%d) failed because object wasn't correct type: %v", index, args.Get(index))) + } + return s +} + +// Error gets the argument at the specified index. Panics if there is no argument, or +// if the argument is of the wrong type. +func (args Arguments) Error(index int) error { + obj := args.Get(index) + var s error + var ok bool + if obj == nil { + return nil + } + if s, ok = obj.(error); !ok { + panic(fmt.Sprintf("assert: arguments: Error(%d) failed because object wasn't correct type: %v", index, obj)) + } + return s +} + +// Bool gets the argument at the specified index. Panics if there is no argument, or +// if the argument is of the wrong type. +func (args Arguments) Bool(index int) bool { + var s bool + var ok bool + if s, ok = args.Get(index).(bool); !ok { + panic(fmt.Sprintf("assert: arguments: Bool(%d) failed because object wasn't correct type: %v", index, args.Get(index))) + } + return s +} + +func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) { + t := reflect.TypeOf(v) + k := t.Kind() + + if k == reflect.Ptr { + t = t.Elem() + k = t.Kind() + } + return t, k +} + +func diffArguments(expected Arguments, actual Arguments) string { + if len(expected) != len(actual) { + return fmt.Sprintf("Provided %v arguments, mocked for %v arguments", len(expected), len(actual)) + } + + for x := range expected { + if diffString := diff(expected[x], actual[x]); diffString != "" { + return fmt.Sprintf("Difference found in argument %v:\n\n%s", x, diffString) + } + } + + return "" +} + +// diff returns a diff of both values as long as both are of the same type and +// are a struct, map, slice or array. Otherwise it returns an empty string. +func diff(expected interface{}, actual interface{}) string { + if expected == nil || actual == nil { + return "" + } + + et, ek := typeAndKind(expected) + at, _ := typeAndKind(actual) + + if et != at { + return "" + } + + if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array { + return "" + } + + e := spewConfig.Sdump(expected) + a := spewConfig.Sdump(actual) + + diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ + A: difflib.SplitLines(e), + B: difflib.SplitLines(a), + FromFile: "Expected", + FromDate: "", + ToFile: "Actual", + ToDate: "", + Context: 1, + }) + + return diff +} + +var spewConfig = spew.ConfigState{ + Indent: " ", + DisablePointerAddresses: true, + DisableCapacities: true, + SortKeys: true, +} + +type tHelper interface { + Helper() +} + +func assertOpts(expected, actual interface{}) (expectedFmt, actualFmt string) { + expectedOpts := reflect.ValueOf(expected) + actualOpts := reflect.ValueOf(actual) + + var expectedFuncs []*runtime.Func + var expectedNames []string + for i := 0; i < expectedOpts.Len(); i++ { + f := runtimeFunc(expectedOpts.Index(i).Interface()) + expectedFuncs = append(expectedFuncs, f) + expectedNames = append(expectedNames, funcName(f)) + } + var actualFuncs []*runtime.Func + var actualNames []string + for i := 0; i < actualOpts.Len(); i++ { + f := runtimeFunc(actualOpts.Index(i).Interface()) + actualFuncs = append(actualFuncs, f) + actualNames = append(actualNames, funcName(f)) + } + + if expectedOpts.Len() != actualOpts.Len() { + expectedFmt = fmt.Sprintf("%v", expectedNames) + actualFmt = fmt.Sprintf("%v", actualNames) + return + } + + for i := 0; i < expectedOpts.Len(); i++ { + if !isFuncSame(expectedFuncs[i], actualFuncs[i]) { + expectedFmt = expectedNames[i] + actualFmt = actualNames[i] + return + } + + expectedOpt := expectedOpts.Index(i).Interface() + actualOpt := actualOpts.Index(i).Interface() + + ot := reflect.TypeOf(expectedOpt) + var expectedValues []reflect.Value + var actualValues []reflect.Value + if ot.NumIn() == 0 { + return + } + + for i := 0; i < ot.NumIn(); i++ { + vt := ot.In(i).Elem() + expectedValues = append(expectedValues, reflect.New(vt)) + actualValues = append(actualValues, reflect.New(vt)) + } + + reflect.ValueOf(expectedOpt).Call(expectedValues) + reflect.ValueOf(actualOpt).Call(actualValues) + + for i := 0; i < ot.NumIn(); i++ { + if expectedArg, actualArg := expectedValues[i].Interface(), actualValues[i].Interface(); !assert.ObjectsAreEqual(expectedArg, actualArg) { + expectedFmt = fmt.Sprintf("%s(%T) -> %#v", expectedNames[i], expectedArg, expectedArg) + actualFmt = fmt.Sprintf("%s(%T) -> %#v", expectedNames[i], actualArg, actualArg) + return + } + } + } + + return "", "" +} + +func runtimeFunc(opt interface{}) *runtime.Func { + return runtime.FuncForPC(reflect.ValueOf(opt).Pointer()) +} + +func funcName(f *runtime.Func) string { + name := f.Name() + trimmed := strings.TrimSuffix(path.Base(name), path.Ext(name)) + splitted := strings.Split(trimmed, ".") + + if len(splitted) == 0 { + return trimmed + } + + return splitted[len(splitted)-1] +} + +func isFuncSame(f1, f2 *runtime.Func) bool { + f1File, f1Loc := f1.FileLine(f1.Entry()) + f2File, f2Loc := f2.FileLine(f2.Entry()) + + return f1File == f2File && f1Loc == f2Loc +} diff --git a/vendor/golang.org/x/text/language/parse.go b/vendor/golang.org/x/text/language/parse.go index 4d57222..053336e 100644 --- a/vendor/golang.org/x/text/language/parse.go +++ b/vendor/golang.org/x/text/language/parse.go @@ -59,7 +59,7 @@ func (c CanonType) Parse(s string) (t Tag, err error) { if changed { tt.RemakeString() } - return makeTag(tt), err + return makeTag(tt), nil } // Compose creates a Tag from individual parts, which may be of type Tag, Base, diff --git a/vendor/modules.txt b/vendor/modules.txt index b00d7b9..38d611c 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,3 +1,6 @@ +# github.com/davecgh/go-spew v1.1.1 +## explicit +github.com/davecgh/go-spew/spew # github.com/fatih/color v1.18.0 ## explicit; go 1.17 github.com/fatih/color @@ -10,18 +13,29 @@ github.com/mattn/go-colorable # github.com/mattn/go-isatty v0.0.20 ## explicit; go 1.15 github.com/mattn/go-isatty +# github.com/pmezard/go-difflib v1.0.0 +## explicit +github.com/pmezard/go-difflib/difflib # github.com/spf13/cobra v1.9.1 ## explicit; go 1.15 github.com/spf13/cobra # github.com/spf13/pflag v1.0.6 ## explicit; go 1.12 github.com/spf13/pflag +# github.com/stretchr/objx v0.5.2 +## explicit; go 1.20 +github.com/stretchr/objx +# github.com/stretchr/testify v1.10.0 +## explicit; go 1.17 +github.com/stretchr/testify/assert +github.com/stretchr/testify/assert/yaml +github.com/stretchr/testify/mock # golang.org/x/sys v0.30.0 ## explicit; go 1.18 golang.org/x/sys/unix golang.org/x/sys/windows -# golang.org/x/text v0.22.0 -## explicit; go 1.18 +# golang.org/x/text v0.23.0 +## explicit; go 1.23.0 golang.org/x/text/cases golang.org/x/text/internal golang.org/x/text/internal/language