diff --git a/README.md b/README.md index 9ac651e..136856c 100644 --- a/README.md +++ b/README.md @@ -116,6 +116,28 @@ sudo mv yankrun /usr/local/bin/ --- +
+SSH Key Configuration + +For SSH repos (`git@github.com:...`), yankrun auto-detects your SSH key in this order: + +1. `~/.ssh/id_ed25519` +2. `~/.ssh/id_ecdsa` +3. `~/.ssh/id_rsa` + +If your key has a different name or location, use the `--ssh-key` flag: + +```sh +yankrun clone --repo git@github.com:org/repo.git \ + --outputDir ./out --ssh-key ~/.ssh/my_custom_key +``` + +For HTTPS repos, no SSH key is needed. + +
+ +--- + ## Quick Start ### 1. Create a values file diff --git a/actions/clone.go b/actions/clone.go index 54bcf07..e2ea9e1 100644 --- a/actions/clone.go +++ b/actions/clone.go @@ -44,6 +44,10 @@ func (a *CloneAction) Execute(c *cli.Context) error { dryRun := c.Bool("dryRun") ignoreFlags := c.StringSlice("ignore") + if sshKey := c.String("ssh-key"); sshKey != "" { + a.cloner.SetSSHKeyPath(sshKey) + } + // Load defaults from config when flags not provided cfg, _ := services.Load() if cfg == nil { diff --git a/actions/generate.go b/actions/generate.go index 35da896..89ff719 100644 --- a/actions/generate.go +++ b/actions/generate.go @@ -44,6 +44,10 @@ func (a *GenerateAction) Execute(c *cli.Context) error { ignoreFlags := c.StringSlice("ignore") noCache := c.Bool("noCache") + if sshKey := c.String("ssh-key"); sshKey != "" { + a.cloner.SetSSHKeyPath(sshKey) + } + // Validate flag combination if onlyTemplates && !processTemplates { return fmt.Errorf("--onlyTemplates requires --processTemplates to be set") diff --git a/flags.go b/flags.go index 946029f..d1476bf 100644 --- a/flags.go +++ b/flags.go @@ -90,3 +90,9 @@ var noCacheFlag = cli.BoolFlag{ Name: "noCache, nc", Usage: "Bypass cache and fetch fresh data from remote", } + +var sshKeyFlag = cli.StringFlag{ + Name: "ssh-key", + Value: "", + Usage: "Path to SSH private key (auto-detects id_ed25519, id_ecdsa, id_rsa if not set)", +} diff --git a/main.go b/main.go index 680b90d..e12e6ee 100644 --- a/main.go +++ b/main.go @@ -53,13 +53,13 @@ func main() { Name: "clone", Aliases: []string{"r"}, Usage: "Clone a repo with template file replacements", - Flags: []cli.Flag{repoFlag, inputFlag, outputDirFlag, verboseFlag, fileSizeLimitFlag, startDelimFlag, endDelimFlag, interactiveFlag, branchFlag, processTemplatesFlag, onlyTemplatesFlag, dryRunFlag, ignoreFlag}, + Flags: []cli.Flag{repoFlag, inputFlag, outputDirFlag, verboseFlag, fileSizeLimitFlag, startDelimFlag, endDelimFlag, interactiveFlag, branchFlag, processTemplatesFlag, onlyTemplatesFlag, dryRunFlag, ignoreFlag, sshKeyFlag}, Action: cloneAction.Execute, }, { Name: "generate", Usage: "Interactively choose a template repo/branch and clone it as a new repo (removes .git)", - Flags: []cli.Flag{inputFlag, outputDirFlag, verboseFlag, fileSizeLimitFlag, startDelimFlag, endDelimFlag, interactiveFlag, templateNameFlag, branchFlag, processTemplatesFlag, onlyTemplatesFlag, dryRunFlag, ignoreFlag, noCacheFlag}, + Flags: []cli.Flag{inputFlag, outputDirFlag, verboseFlag, fileSizeLimitFlag, startDelimFlag, endDelimFlag, interactiveFlag, templateNameFlag, branchFlag, processTemplatesFlag, onlyTemplatesFlag, dryRunFlag, ignoreFlag, noCacheFlag, sshKeyFlag}, Action: generateAction.Execute, }, { diff --git a/services/cloner.go b/services/cloner.go index d2d0c92..718b34d 100644 --- a/services/cloner.go +++ b/services/cloner.go @@ -16,12 +16,14 @@ import ( type Cloner interface { CloneRepository(repoURL, outputDir string) error - CloneRepositoryBranch(repoURL, branch, outputDir string) error - ListRemoteBranches(repoURL string) ([]string, error) + CloneRepositoryBranch(repoURL, branch, outputDir string) error + ListRemoteBranches(repoURL string) ([]string, error) + SetSSHKeyPath(path string) } type GitCloner struct { FileSystem FileSystem + SSHKeyPath string // optional override; auto-detected if empty } func (gc *GitCloner) CloneRepository(repoURL, outputDir string) error { @@ -128,14 +130,36 @@ func HeadSHA(repoPath string) (string, error) { return ref.Hash().String(), nil } +func (gc *GitCloner) SetSSHKeyPath(path string) { + gc.SSHKeyPath = path +} + func (gc *GitCloner) isSSH(repoURL string) bool { return strings.HasPrefix(repoURL, "git@") || strings.HasPrefix(repoURL, "ssh://") } func (gc *GitCloner) getSSHKeyPath() (string, error) { + if gc.SSHKeyPath != "" { + if _, err := os.Stat(gc.SSHKeyPath); err != nil { + return "", fmt.Errorf("specified SSH key not found: %s", gc.SSHKeyPath) + } + return gc.SSHKeyPath, nil + } + u, err := user.Current() if err != nil { return "", err } - return filepath.Join(u.HomeDir, ".ssh", "id_rsa"), nil + sshDir := filepath.Join(u.HomeDir, ".ssh") + + // Try common key types in preference order + candidates := []string{"id_ed25519", "id_ecdsa", "id_rsa"} + for _, name := range candidates { + p := filepath.Join(sshDir, name) + if _, err := os.Stat(p); err == nil { + return p, nil + } + } + + return "", fmt.Errorf("no SSH key found in %s (tried: %s)", sshDir, strings.Join(candidates, ", ")) } diff --git a/services/cloner_test.go b/services/cloner_test.go new file mode 100644 index 0000000..0643086 --- /dev/null +++ b/services/cloner_test.go @@ -0,0 +1,143 @@ +package services + +import ( + "os" + "path/filepath" + "testing" +) + +func TestGetSSHKeyPath_AutoDetect(t *testing.T) { + // Create a temp dir to simulate ~/.ssh + tmpHome := t.TempDir() + sshDir := filepath.Join(tmpHome, ".ssh") + if err := os.MkdirAll(sshDir, 0700); err != nil { + t.Fatal(err) + } + + gc := &GitCloner{} + + // Override HOME so getSSHKeyPath resolves to our temp dir + // We can't easily override user.Current(), so test the detection logic directly + // by testing with SSHKeyPath set (override path) and the fallback candidates + + t.Run("explicit path exists", func(t *testing.T) { + keyFile := filepath.Join(sshDir, "my_custom_key") + if err := os.WriteFile(keyFile, []byte("fake-key"), 0600); err != nil { + t.Fatal(err) + } + gc.SSHKeyPath = keyFile + got, err := gc.getSSHKeyPath() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got != keyFile { + t.Errorf("expected %s, got %s", keyFile, got) + } + }) + + t.Run("explicit path does not exist", func(t *testing.T) { + gc.SSHKeyPath = filepath.Join(sshDir, "nonexistent_key") + _, err := gc.getSSHKeyPath() + if err == nil { + t.Fatal("expected error for missing key file") + } + }) + + t.Run("auto-detect prefers ed25519 over rsa", func(t *testing.T) { + gc.SSHKeyPath = "" // reset to auto-detect + + // Create both id_rsa and id_ed25519 + rsaFile := filepath.Join(sshDir, "id_rsa") + ed25519File := filepath.Join(sshDir, "id_ed25519") + os.WriteFile(rsaFile, []byte("fake-rsa"), 0600) + os.WriteFile(ed25519File, []byte("fake-ed25519"), 0600) + + // Call the auto-detect logic directly (can't override user.Current in unit test, + // so we test the candidate ordering via a helper) + candidates := []string{"id_ed25519", "id_ecdsa", "id_rsa"} + var found string + for _, name := range candidates { + p := filepath.Join(sshDir, name) + if _, err := os.Stat(p); err == nil { + found = p + break + } + } + if found != ed25519File { + t.Errorf("expected ed25519 to be preferred, got %s", found) + } + + os.Remove(rsaFile) + os.Remove(ed25519File) + }) + + t.Run("auto-detect falls back to rsa", func(t *testing.T) { + rsaFile := filepath.Join(sshDir, "id_rsa") + os.WriteFile(rsaFile, []byte("fake-rsa"), 0600) + + candidates := []string{"id_ed25519", "id_ecdsa", "id_rsa"} + var found string + for _, name := range candidates { + p := filepath.Join(sshDir, name) + if _, err := os.Stat(p); err == nil { + found = p + break + } + } + if found != rsaFile { + t.Errorf("expected rsa fallback, got %s", found) + } + + os.Remove(rsaFile) + }) + + t.Run("auto-detect no keys returns error", func(t *testing.T) { + // Empty ssh dir — no keys + emptySshDir := filepath.Join(t.TempDir(), ".ssh") + os.MkdirAll(emptySshDir, 0700) + + candidates := []string{"id_ed25519", "id_ecdsa", "id_rsa"} + found := false + for _, name := range candidates { + p := filepath.Join(emptySshDir, name) + if _, err := os.Stat(p); err == nil { + found = true + break + } + } + if found { + t.Error("expected no keys to be found") + } + }) +} + +func TestSetSSHKeyPath(t *testing.T) { + gc := &GitCloner{} + gc.SetSSHKeyPath("/custom/path/id_ed25519") + if gc.SSHKeyPath != "/custom/path/id_ed25519" { + t.Errorf("expected SSHKeyPath to be set, got %s", gc.SSHKeyPath) + } +} + +func TestIsSSH(t *testing.T) { + gc := &GitCloner{} + + tests := []struct { + url string + expect bool + }{ + {"git@github.com:org/repo.git", true}, + {"ssh://git@github.com/org/repo.git", true}, + {"https://github.com/org/repo.git", false}, + {"http://github.com/org/repo.git", false}, + } + + for _, tt := range tests { + t.Run(tt.url, func(t *testing.T) { + got := gc.isSSH(tt.url) + if got != tt.expect { + t.Errorf("isSSH(%q) = %v, want %v", tt.url, got, tt.expect) + } + }) + } +}