Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions codeowners.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package main

import (
"context"

"github.com/google/go-github/v71/github"
)

var defaultCodeOwnerPaths = []string{
"CODEOWNERS",
".github/CODEOWNERS",
"docs/CODEOWNERS",
}

// FetchCodeOwners attempts to retrieve the CODEOWNERS file for the repository.
// If the file cannot be found or is inaccessible due to permissions, the function
// returns nil and logs the reason so policy evaluation can proceed with an empty value.
Copy link

Copilot AI Dec 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function comment at line 17 states that it "logs the reason" when the file cannot be found or is inaccessible, but the function doesn't actually log anything when returning nil (lines 29, 38). Consider either updating the comment to reflect the actual behavior, or add logging when permission errors occur or when no CODEOWNERS file is found.

Suggested change
// returns nil and logs the reason so policy evaluation can proceed with an empty value.
// returns nil so policy evaluation can proceed with an empty value.

Copilot uses AI. Check for mistakes.
func (l *GithubReposPlugin) FetchCodeOwners(ctx context.Context, repo *github.Repository) (*github.RepositoryContent, error) {
owner := repo.GetOwner().GetLogin()
name := repo.GetName()

for _, path := range defaultCodeOwnerPaths {
file, _, resp, err := l.githubClient.Repositories.GetContents(ctx, owner, name, path, nil)
if err != nil {
if resp != nil && resp.StatusCode == 404 {
continue
}
if isPermissionError(err) {
return nil, nil
}
return nil, err
}
if file != nil {
return file, nil
}
}

return nil, nil
}
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ require (
github.com/hashicorp/go-hclog v1.6.3
github.com/hashicorp/go-plugin v1.7.0
github.com/mitchellh/mapstructure v1.5.0
github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7
golang.org/x/oauth2 v0.34.0
)

require (
Expand Down Expand Up @@ -40,6 +42,7 @@ require (
github.com/prometheus/common v0.57.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect
github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/tchap/go-patricia/v2 v2.3.1 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
Expand Down
279 changes: 251 additions & 28 deletions go.sum

Large diffs are not rendered by default.

150 changes: 92 additions & 58 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ package main

import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"slices"
"strings"

Expand All @@ -16,6 +14,8 @@ import (
"github.com/hashicorp/go-hclog"
goplugin "github.com/hashicorp/go-plugin"
"github.com/mitchellh/mapstructure"
"github.com/shurcooL/githubv4"
"golang.org/x/oauth2"
)

type Validator interface {
Expand Down Expand Up @@ -51,18 +51,23 @@ type SaturatedRepository struct {
WorkflowRuns []*github.WorkflowRun `json:"workflow_runs"`
// ProtectedBranches is the list of protected branches in the repository
ProtectedBranches []string `json:"protected_branches"`
// BranchProtectionRules maps branch name -> full branch protection configuration
BranchProtectionRules map[string]*github.Protection `json:"branch_protection_rules"`
// RequiredStatusChecks maps branch name -> required status checks configuration
RequiredStatusChecks map[string]*github.RequiredStatusChecks `json:"required_status_checks"`
SBOM *github.SBOM `json:"sbom"`
LastRelease *github.RepositoryRelease `json:"last_release"`
OpenPullRequests []*github.PullRequest `json:"pull_requests"`
OpenPullRequests []*OpenPullRequest `json:"pull_requests"`
CodeOwners *github.RepositoryContent `json:"code_owners"`
OrgTeams []*OrgTeam `json:"org_teams"`
}

type GithubReposPlugin struct {
Logger hclog.Logger

config *PluginConfig
githubClient *github.Client
config *PluginConfig
githubClient *github.Client
graphqlClient *githubv4.Client
}

func (l *GithubReposPlugin) Configure(req *proto.ConfigureRequest) (*proto.ConfigureResponse, error) {
Expand All @@ -80,7 +85,11 @@ func (l *GithubReposPlugin) Configure(req *proto.ConfigureRequest) (*proto.Confi
}

l.config = config
l.githubClient = github.NewClient(nil).WithAuthToken(config.Token)
httpClient := oauth2.NewClient(context.Background(), oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: config.Token,
}))
l.githubClient = github.NewClient(httpClient)
l.graphqlClient = githubv4.NewClient(httpClient)

return &proto.ConfigureResponse{}, nil
}
Expand All @@ -90,6 +99,14 @@ func (l *GithubReposPlugin) Eval(req *proto.EvalRequest, apiHelper runner.ApiHel
repochan, errchan := l.FetchRepositories(ctx, req)
done := false

orgTeams, err := l.GatherOrgTeams(ctx)
if err != nil {
l.Logger.Error("Error gathering organization teams", "error", err)
return &proto.EvalResponse{
Status: proto.ExecutionStatus_FAILURE,
}, err
}

for !done {
select {
case err, ok := <-errchan:
Expand All @@ -106,8 +123,6 @@ func (l *GithubReposPlugin) Eval(req *proto.EvalRequest, apiHelper runner.ApiHel
done = true
continue
}
l.Logger.Debug("Processing repository:", "repo_name", repo.GetName())

workflows, err := l.GatherConfiguredWorkflows(ctx, repo)
if err != nil {
l.Logger.Error("Error gathering workflows", "error", err)
Expand All @@ -133,30 +148,38 @@ func (l *GithubReposPlugin) Eval(req *proto.EvalRequest, apiHelper runner.ApiHel
}, err
}
branchNames := make([]string, 0, len(branches))
branchProtectionRules := make(map[string]*github.Protection)
requiredChecks := make(map[string]*github.RequiredStatusChecks)
for _, b := range branches {
if b == nil || b.Name == nil {
continue
}
name := b.GetName()
l.Logger.Debug("Found protected branch", "branch", name)
branchNames = append(branchNames, name)
checks, err := l.GetRequiredStatusChecks(ctx, repo, name)
l.Logger.Debug("Fetched required status checks", "branch", name, "checks", checks)
protection, checks, err := l.GetBranchProtectionAndRequiredStatusCheck(ctx, repo, name)
if err != nil {
l.Logger.Trace("Branch required checks fetch failed", "repo", repo.GetFullName(), "branch", name, "error", err)
l.Logger.Trace("Branch protection fetch failed", "repo", repo.GetFullName(), "branch", name, "error", err)
continue
}
if protection != nil {
branchProtectionRules[name] = protection
}
if checks != nil {
requiredChecks[name] = checks
}
}
// Fallback to default branch if none collected
if len(requiredChecks) == 0 {
l.Logger.Debug("No protected branches with required status checks found, checking default branch", "repo", repo.GetFullName())
if def := repo.GetDefaultBranch(); def != "" {
if checks, err := l.GetRequiredStatusChecks(ctx, repo, def); err == nil && checks != nil {
requiredChecks[def] = checks
if protection, checks, err := l.GetBranchProtectionAndRequiredStatusCheck(ctx, repo, def); err == nil {
if protection != nil {
branchProtectionRules[def] = protection
}
if checks != nil {
requiredChecks[def] = checks
}
} else {
l.Logger.Trace("Default branch protection fetch failed", "repo", repo.GetFullName(), "branch", def, "error", err)
}
}
}
Expand All @@ -176,33 +199,43 @@ func (l *GithubReposPlugin) Eval(req *proto.EvalRequest, apiHelper runner.ApiHel
Status: proto.ExecutionStatus_FAILURE,
}, err
}
openPullRequests, err := l.GatherReviewsAndComments(ctx, repo, pullRequests)
if err != nil {
l.Logger.Error("error gathering pull request reviews/comments", "error", err)
return &proto.EvalResponse{
Status: proto.ExecutionStatus_FAILURE,
}, err
}
release, err := l.FecthLatestRelease(ctx, repo)
if err != nil {
l.Logger.Error("error gathering latest release", "error", err)
return &proto.EvalResponse{
Status: proto.ExecutionStatus_FAILURE,
}, err
}
codeOwners, err := l.FetchCodeOwners(ctx, repo)
if err != nil {
l.Logger.Error("error gathering CODEOWNERS", "error", err)
return &proto.EvalResponse{
Status: proto.ExecutionStatus_FAILURE,
}, err
}
data := &SaturatedRepository{
Settings: repo,
Workflows: workflows,
WorkflowRuns: workflowRuns,
ProtectedBranches: branchNames,
RequiredStatusChecks: requiredChecks,
LastRelease: release,
SBOM: sbom,
OpenPullRequests: pullRequests,
Settings: repo,
Workflows: workflows,
WorkflowRuns: workflowRuns,
ProtectedBranches: branchNames,
BranchProtectionRules: branchProtectionRules,
RequiredStatusChecks: requiredChecks,
LastRelease: release,
SBOM: sbom,
OpenPullRequests: openPullRequests,
CodeOwners: codeOwners,
OrgTeams: orgTeams,
}

// Uncomment to check the data that is being passed through from
// the client, as data formats are often slightly different than
// the raw API endpoints
jsonData, _ := json.MarshalIndent(data, "", " ")
err = os.WriteFile(fmt.Sprintf("./dist/%s.json", repo.GetName()), jsonData, 0o644)
if err != nil {
l.Logger.Error("failed to write file", "error", err)
}

evidences, err := l.EvaluatePolicies(ctx, data, req)
if err != nil {
l.Logger.Error("Error evaluating policies", "error", err)
Expand All @@ -218,7 +251,6 @@ func (l *GithubReposPlugin) Eval(req *proto.EvalRequest, apiHelper runner.ApiHel
}, err
}

l.Logger.Debug("Successfully processed repository:", "repo_name", repo.GetName())
}
}

Expand Down Expand Up @@ -256,7 +288,6 @@ func (l *GithubReposPlugin) FetchRepositories(ctx context.Context, req *proto.Ev
})
if err != nil {
errchan <- err
done = true
return
}

Expand All @@ -279,7 +310,7 @@ func (l *GithubReposPlugin) FetchRepositories(ctx context.Context, req *proto.Ev
repochan <- repo
}

if resp.NextPage == 0 {
if resp == nil || resp.NextPage == 0 {
done = true
} else {
paginationOpts.Page = resp.NextPage
Expand Down Expand Up @@ -331,7 +362,6 @@ func (l *GithubReposPlugin) GatherWorkflowRuns(ctx context.Context, repo *github
func (l *GithubReposPlugin) ListProtectedBranches(ctx context.Context, repo *github.Repository) ([]*github.Branch, error) {
owner := repo.GetOwner().GetLogin()
name := repo.GetName()

opts := &github.BranchListOptions{
Protected: github.Ptr(true),
ListOptions: github.ListOptions{PerPage: 100, Page: 1},
Expand All @@ -343,15 +373,15 @@ func (l *GithubReposPlugin) ListProtectedBranches(ctx context.Context, repo *git
return nil, err
}
out = append(out, branches...)
if resp.NextPage == 0 {
if resp == nil || resp.NextPage == 0 {
break
}
opts.ListOptions.Page = resp.NextPage
opts.Page = resp.NextPage
}
return out, nil
}

func (l *GithubReposPlugin) GetRequiredStatusChecks(ctx context.Context, repo *github.Repository, branch string) (*github.RequiredStatusChecks, error) {
func (l *GithubReposPlugin) GetBranchProtectionAndRequiredStatusCheck(ctx context.Context, repo *github.Repository, branch string) (*github.Protection, *github.RequiredStatusChecks, error) {
owner := repo.GetOwner().GetLogin()
name := repo.GetName()

Expand All @@ -365,30 +395,34 @@ func (l *GithubReposPlugin) GetRequiredStatusChecks(ctx context.Context, repo *g
checksSet := make(map[checkKey]struct{})

// 1) Legacy branch protection settings (if present).
var branchProtection *github.Protection
protection, _, err := l.githubClient.Repositories.GetBranchProtection(ctx, owner, name, branch)
if err == nil && protection != nil && protection.RequiredStatusChecks != nil {
strict = strict || protection.RequiredStatusChecks.Strict
// Normalize both Checks and Contexts into Checks entries to avoid dual population.
if protection.RequiredStatusChecks.Checks != nil {
for _, c := range *protection.RequiredStatusChecks.Checks {
if c == nil {
continue
}
key := checkKey{context: c.Context}
if c.AppID != nil {
key.hasAppID = true
key.appID = *c.AppID
if err == nil && protection != nil {
branchProtection = protection
if protection.RequiredStatusChecks != nil {
strict = strict || protection.RequiredStatusChecks.Strict
// Normalize both Checks and Contexts into Checks entries to avoid dual population.
if protection.RequiredStatusChecks.Checks != nil {
for _, c := range *protection.RequiredStatusChecks.Checks {
if c == nil {
continue
}
key := checkKey{context: c.Context}
if c.AppID != nil {
key.hasAppID = true
key.appID = *c.AppID
}
checksSet[key] = struct{}{}
}
checksSet[key] = struct{}{}
}
}
if protection.RequiredStatusChecks.Contexts != nil {
for _, ctxName := range *protection.RequiredStatusChecks.Contexts {
key := checkKey{context: ctxName}
checksSet[key] = struct{}{}
if protection.RequiredStatusChecks.Contexts != nil {
for _, ctxName := range *protection.RequiredStatusChecks.Contexts {
key := checkKey{context: ctxName}
checksSet[key] = struct{}{}
}
}
}
} else if err != nil {
} else if err != nil && !errors.Is(err, github.ErrBranchNotProtected) {
// Non-404s are significant; 404 just means no protection on this branch.
// We'll log at trace and continue to gather rules-based checks.
l.Logger.Trace("GetBranchProtection failed", "repo", repo.GetFullName(), "branch", branch, "error", err)
Expand Down Expand Up @@ -424,7 +458,7 @@ func (l *GithubReposPlugin) GetRequiredStatusChecks(ctx context.Context, repo *g
// If no checks found from either source, return nil to indicate absence.
if len(checksSet) == 0 {
if !strict {
return nil, nil
return branchProtection, nil, nil
}
// If strict is set without explicit checks (edge), still return an empty set with strict.
}
Expand All @@ -444,7 +478,7 @@ func (l *GithubReposPlugin) GetRequiredStatusChecks(ctx context.Context, repo *g
}
// Always prefer Checks representation to avoid populating both fields.
result.Checks = &outChecks
return result, nil
return branchProtection, result, nil
}

func (l *GithubReposPlugin) GatherSBOM(ctx context.Context, repo *github.Repository) (*github.SBOM, error) {
Expand Down
Loading