diff --git a/walk/git.go b/walk/git.go index 73c299d3..20321857 100644 --- a/walk/git.go +++ b/walk/git.go @@ -9,6 +9,7 @@ import ( "os/exec" "path/filepath" "strconv" + "strings" "github.com/charmbracelet/log" "github.com/numtide/treefmt/v2/git" @@ -38,7 +39,7 @@ func (g *GitReader) Read(ctx context.Context, files []*File) (n int, err error) r, w := io.Pipe() // create a command which will execute from the specified sub path within root - cmd := exec.Command("git", "ls-files", "--cached", "--others", "--exclude-standard") + cmd := exec.Command("git", "ls-files", "--cached", "--others", "--exclude-standard", "--stage") cmd.Dir = filepath.Join(g.root, g.path) cmd.Stdout = w @@ -51,19 +52,39 @@ func (g *GitReader) Read(ctx context.Context, files []*File) (n int, err error) g.scanner = bufio.NewScanner(r) } - nextLine := func() (string, error) { - line := g.scanner.Text() + nextFile := func() (string, error) { + for line := g.scanner.Text(); len(line) > 0; line = g.scanner.Text() { + lineSplit := strings.Split(line, "\t") - if len(line) == 0 || line[0] != '"' { - return line, nil - } + var stage, file string + // Untracked files just show as ``, while tracked files show as ` ` + if len(lineSplit) == 1 { + stage, file = "", lineSplit[0] + } else { + stage, file = lineSplit[0], lineSplit[1] + } + + // 160000 is the mode for submodules, skip them because they are separate projects that may have their own + // formatting rules + if strings.HasPrefix(stage, "160000") { + g.scanner.Scan() + + continue + } + + if file[0] != '"' { + return file, nil + } + + unquoted, err := strconv.Unquote(file) + if err != nil { + return "", fmt.Errorf("failed to unquote file %s: %w", file, err) + } - unquoted, err := strconv.Unquote(line) - if err != nil { - return "", fmt.Errorf("failed to unquote line %s: %w", line, err) + return unquoted, nil } - return unquoted, nil + return "", io.EOF } LOOP: @@ -82,7 +103,7 @@ LOOP: default: // read the next file if g.scanner.Scan() { - entry, err := nextLine() + entry, err := nextFile() if err != nil { return n, err } diff --git a/walk/git_test.go b/walk/git_test.go index 959bf150..c9eda4ca 100644 --- a/walk/git_test.go +++ b/walk/git_test.go @@ -19,8 +19,21 @@ func TestGitReader(t *testing.T) { tempDir := test.TempExamples(t) + // configure git username and email + cmd := exec.Command("git", "config", "--global", "user.name", "testing") + cmd.Dir = tempDir + as.NoError(cmd.Run(), "failed to set git username") + cmd = exec.Command("git", "config", "--global", "user.email", "testing@example.com") + cmd.Dir = tempDir + as.NoError(cmd.Run(), "failed to set git email") + // https://github.blog/open-source/git/git-security-vulnerabilities-announced/#cve-2022-39253 + // We only use submodules we trust + cmd = exec.Command("git", "config", "--global", "protocol.file.allow", "always") + cmd.Dir = tempDir + as.NoError(cmd.Run(), "failed to allow file protocol") + // init a git repo - cmd := exec.Command("git", "init") + cmd = exec.Command("git", "init") cmd.Dir = tempDir as.NoError(cmd.Run(), "failed to init git repository") @@ -37,6 +50,27 @@ func TestGitReader(t *testing.T) { as.Equal(33, n) as.ErrorIs(err, io.EOF) + // add a git submodule + tempSubmoduleDir := test.TempExamples(t) + cmd = exec.Command("git", "init") + cmd.Dir = tempSubmoduleDir + as.NoError(cmd.Run(), "failed to init git submodule repository") + + // add everything to the submodule's git index + cmd = exec.Command("git", "add", ".") + cmd.Dir = tempSubmoduleDir + as.NoError(cmd.Run(), "failed to add everything to the submodule index") + + // commit the submodule + cmd = exec.Command("git", "commit", "-m", "submodule") + cmd.Dir = tempSubmoduleDir + as.NoError(cmd.Run(), "failed to commit the submodule") + + // add the submodule to the main git repository + cmd = exec.Command("git", "submodule", "add", tempSubmoduleDir) + cmd.Dir = tempDir + as.NoError(cmd.Run(), "failed to add the submodule to the main repository") + // add everything to the git index cmd = exec.Command("git", "add", ".") cmd.Dir = tempDir @@ -63,8 +97,8 @@ func TestGitReader(t *testing.T) { } } - as.Equal(33, count) - as.Equal(33, statz.Value(stats.Traversed)) + as.Equal(34, count) + as.Equal(34, statz.Value(stats.Traversed)) as.Equal(0, statz.Value(stats.Matched)) as.Equal(0, statz.Value(stats.Formatted)) as.Equal(0, statz.Value(stats.Changed))