Skip to content

Commit

Permalink
Implement SourceUnitEnumChunker for GitHub (#3298)
Browse files Browse the repository at this point in the history
* Implement SourceUnitEnumChunker for GitHub

This change refactors the internal scan method to introduce a scanRepo
method to perform the actual scan.

* Export unit fields so the values are captured in the report

* Add comment for scanRepo

* Break out ensureRepoInfoCache into a method

* Update comments and check errors

* Ensure that the repoInfoCache contains the repo during ChunkUnit

* Add integration test for ChunkUnit

* Move s.scanOptions initialization to Init()
  • Loading branch information
mcastorina authored Sep 23, 2024
1 parent 764db68 commit 2f3a410
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 113 deletions.
259 changes: 149 additions & 110 deletions pkg/sources/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,20 @@ var _ sources.SourceUnit = (*RepoUnit)(nil)
var _ sources.SourceUnit = (*GistUnit)(nil)

type RepoUnit struct {
name string
url string
Name string `json:"name"`
URL string `json:"url"`
}

func (r RepoUnit) SourceUnitID() (string, sources.SourceUnitKind) { return r.url, "repo" }
func (r RepoUnit) Display() string { return r.name }
func (r RepoUnit) SourceUnitID() (string, sources.SourceUnitKind) { return r.URL, "repo" }
func (r RepoUnit) Display() string { return r.Name }

type GistUnit struct {
name string
url string
Name string `json:"name"`
URL string `json:"url"`
}

func (g GistUnit) SourceUnitID() (string, sources.SourceUnitKind) { return g.url, "gist" }
func (g GistUnit) Display() string { return g.name }
func (g GistUnit) SourceUnitID() (string, sources.SourceUnitKind) { return g.URL, "gist" }
func (g GistUnit) Display() string { return g.Name }

// --------------------------------------------------------------------------------

Expand All @@ -118,6 +118,7 @@ func (s *Source) setScanOptions(base, head string) {
// Ensure the Source satisfies the interfaces at compile time
var _ sources.Source = (*Source)(nil)
var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
var _ sources.SourceUnitEnumChunker = (*Source)(nil)

var endsWithGithub = regexp.MustCompile(`github\.com/?$`)

Expand Down Expand Up @@ -212,6 +213,11 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
s.jobPool = &errgroup.Group{}
s.jobPool.SetLimit(concurrency)

// Setup scan options if it wasn't provided.
if s.scanOptions == nil {
s.scanOptions = &git.ScanOptions{}
}

var conn sourcespb.GitHub
err = anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{})
if err != nil {
Expand Down Expand Up @@ -338,19 +344,19 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, tar
return nil
},
}
err := s.enumerate(ctx, noopReporter)
err := s.Enumerate(ctx, noopReporter)
if err != nil {
return fmt.Errorf("error enumerating: %w", err)
}

return s.scan(ctx, chunksReporter)
}

// enumerate enumerates the GitHub source based on authentication method and
// Enumerate enumerates the GitHub source based on authentication method and
// user configuration. It populates s.filteredRepoCache, s.repoInfoCache,
// s.memberCache, s.totalRepoSize, s.orgsCache, and s.repos. Additionally,
// repositories and gists are reported to the provided UnitReporter.
func (s *Source) enumerate(ctx context.Context, reporter sources.UnitReporter) error {
func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) error {
seenUnits := make(map[sources.SourceUnit]struct{})
// Wrapper reporter to deduplicate and filter found units.
dedupeReporter := sources.VisitorReporter{
Expand All @@ -370,9 +376,19 @@ func (s *Source) enumerate(ctx context.Context, reporter sources.UnitReporter) e
VisitErr: reporter.UnitErr,
}
// Report any values that were already configured.
// This compensates for differences in enumeration logic between `--org` and `--repo`.
// See: https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788
for _, name := range s.filteredRepoCache.Keys() {
url, _ := s.filteredRepoCache.Get(name)
_ = dedupeReporter.UnitOk(ctx, RepoUnit{name: name, url: url})
url, err := s.ensureRepoInfoCache(ctx, url)
if err != nil {
if err := dedupeReporter.UnitErr(ctx, err); err != nil {
return err
}
}
if err := dedupeReporter.UnitOk(ctx, RepoUnit{Name: name, URL: url}); err != nil {
return err
}
}

// I'm not wild about switching on the connector type here (as opposed to dispatching to the connector itself) but
Expand All @@ -395,55 +411,14 @@ func (s *Source) enumerate(ctx context.Context, reporter sources.UnitReporter) e
}
s.repos = make([]string, 0, s.filteredRepoCache.Count())

RepoLoop:
// Double make sure that all enumerated repositories in the
// filteredRepoCache have an entry in the repoInfoCache.
for _, repo := range s.filteredRepoCache.Values() {
repoCtx := context.WithValue(ctx, "repo", repo)

// Ensure that |s.repoInfoCache| contains an entry for |repo|.
// This compensates for differences in enumeration logic between `--org` and `--repo`.
// See: https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788
if _, ok := s.repoInfoCache.get(repo); !ok {
repoCtx.Logger().V(2).Info("Caching repository info")
ctx := context.WithValue(ctx, "repo", repo)

_, urlParts, err := getRepoURLParts(repo)
if err != nil {
repoCtx.Logger().Error(err, "Failed to parse repository URL")
continue
}

if isGistUrl(urlParts) {
// Cache gist info.
for {
gistID := extractGistID(urlParts)
gist, _, err := s.connector.APIClient().Gists.Get(repoCtx, gistID)
// Normalize the URL to the Gist's pull URL.
// See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937
repo = gist.GetGitPullURL()
if s.handleRateLimit(repoCtx, err) {
continue
}
if err != nil {
repoCtx.Logger().Error(err, "Failed to fetch gist")
continue RepoLoop
}
s.cacheGistInfo(gist)
break
}
} else {
// Cache repository info.
for {
ghRepo, _, err := s.connector.APIClient().Repositories.Get(repoCtx, urlParts[1], urlParts[2])
if s.handleRateLimit(repoCtx, err) {
continue
}
if err != nil {
repoCtx.Logger().Error(err, "Failed to fetch repository")
continue RepoLoop
}
s.cacheRepoInfo(ghRepo)
break
}
}
repo, err := s.ensureRepoInfoCache(ctx, repo)
if err != nil {
ctx.Logger().Error(err, "error caching repo info")
}
s.repos = append(s.repos, repo)
}
Expand All @@ -454,6 +429,55 @@ RepoLoop:
return nil
}

// ensureRepoInfoCache checks that s.repoInfoCache has an entry for the
// provided repository URL. If not, it fetches and stores the metadata for the
// repository. In some cases, the gist URL needs to be normalized, which is
// returned by this function.
func (s *Source) ensureRepoInfoCache(ctx context.Context, repo string) (string, error) {
if _, ok := s.repoInfoCache.get(repo); ok {
return repo, nil
}
ctx.Logger().V(2).Info("Caching repository info")

_, urlParts, err := getRepoURLParts(repo)
if err != nil {
return repo, fmt.Errorf("failed to parse repository URL: %w", err)
}

if isGistUrl(urlParts) {
// Cache gist info.
for {
gistID := extractGistID(urlParts)
gist, _, err := s.connector.APIClient().Gists.Get(ctx, gistID)
// Normalize the URL to the Gist's pull URL.
// See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937
repo = gist.GetGitPullURL()
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
return repo, fmt.Errorf("failed to fetch gist")
}
s.cacheGistInfo(gist)
break
}
} else {
// Cache repository info.
for {
ghRepo, _, err := s.connector.APIClient().Repositories.Get(ctx, urlParts[1], urlParts[2])
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
return repo, fmt.Errorf("failed to fetch repository")
}
s.cacheRepoInfo(ghRepo)
break
}
}
return repo, nil
}

func (s *Source) enumerateBasicAuth(ctx context.Context, reporter sources.UnitReporter) error {
for _, org := range s.orgsCache.Keys() {
orgCtx := context.WithValue(ctx, "account", org)
Expand Down Expand Up @@ -603,18 +627,13 @@ func (s *Source) scan(ctx context.Context, reporter sources.ChunkReporter) error
reposToScan, progressIndexOffset := sources.FilterReposToResume(s.repos, s.GetProgress().EncodedResumeInfo)
s.repos = reposToScan

scanErrs := sources.NewScanErrors()
// Setup scan options if it wasn't provided.
if s.scanOptions == nil {
s.scanOptions = &git.ScanOptions{}
}

for i, repoURL := range s.repos {
i, repoURL := i, repoURL
s.jobPool.Go(func() error {
if common.IsDone(ctx) {
return nil
}
ctx := context.WithValue(ctx, "repo", repoURL)

// TODO: set progress complete is being called concurrently with i
s.setProgressCompleteWithRepo(i, progressIndexOffset, repoURL)
Expand All @@ -625,64 +644,72 @@ func (s *Source) scan(ctx context.Context, reporter sources.ChunkReporter) error
s.resumeInfoSlice = sources.RemoveRepoFromResumeInfo(s.resumeInfoSlice, repoURL)
}(s, repoURL)

if !strings.HasSuffix(repoURL, ".git") {
scanErrs.Add(fmt.Errorf("repo %s does not end in .git", repoURL))
if err := s.scanRepo(ctx, repoURL, reporter); err != nil {
ctx.Logger().Error(err, "error scanning repo")
return nil
}

// Scan the repository
repoInfo, ok := s.repoInfoCache.get(repoURL)
if !ok {
// This should never happen.
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
ctx.Logger().Error(err, "failed to scan repository")
return nil
}
repoCtx := context.WithValues(ctx, "repo", repoURL)
duration, err := s.cloneAndScanRepo(repoCtx, repoURL, repoInfo, reporter)
if err != nil {
scanErrs.Add(err)
return nil
}
atomic.AddUint64(&scannedCount, 1)
return nil
})
}

// Scan the wiki, if enabled, and the repo has one.
if s.conn.IncludeWikis && repoInfo.hasWiki && s.wikiIsReachable(ctx, repoURL) {
wikiURL := strings.TrimSuffix(repoURL, ".git") + ".wiki.git"
wikiCtx := context.WithValue(ctx, "repo", wikiURL)
_ = s.jobPool.Wait()
s.SetProgressComplete(len(s.repos), len(s.repos), "Completed GitHub scan", "")

_, err := s.cloneAndScanRepo(wikiCtx, wikiURL, repoInfo, reporter)
if err != nil {
// Ignore "Repository not found" errors.
// It's common for GitHub's API to say a repo has a wiki when it doesn't.
if !strings.Contains(err.Error(), "not found") {
scanErrs.Add(fmt.Errorf("error scanning wiki: %w", err))
}
return nil
}

// Don't return, it still might be possible to scan comments.
}
}
// scanRepo attempts to scan the provided URL and any associated wiki and
// comments if configured. An error is returned if we could not find necessary
// repository metadata or clone the repo, otherwise all errors are reported to
// the ChunkReporter.
func (s *Source) scanRepo(ctx context.Context, repoURL string, reporter sources.ChunkReporter) error {
if !strings.HasSuffix(repoURL, ".git") {
return fmt.Errorf("repo does not end in .git")
}
// Scan the repository
repoInfo, ok := s.repoInfoCache.get(repoURL)
if !ok {
// This should never happen.
return fmt.Errorf("no repoInfo for URL: %s", repoURL)
}
duration, err := s.cloneAndScanRepo(ctx, repoURL, repoInfo, reporter)
if err != nil {
return err
}

// Scan the wiki, if enabled, and the repo has one.
if s.conn.IncludeWikis && repoInfo.hasWiki && s.wikiIsReachable(ctx, repoURL) {
wikiURL := strings.TrimSuffix(repoURL, ".git") + ".wiki.git"
wikiCtx := context.WithValue(ctx, "repo", wikiURL)

// Scan comments, if enabled.
if s.includeGistComments || s.includeIssueComments || s.includePRComments {
if err = s.scanComments(repoCtx, repoURL, repoInfo, reporter); err != nil {
scanErrs.Add(fmt.Errorf("error scanning comments in repo %s: %w", repoURL, err))
return nil
_, err := s.cloneAndScanRepo(wikiCtx, wikiURL, repoInfo, reporter)
if err != nil {
// Ignore "Repository not found" errors.
// It's common for GitHub's API to say a repo has a wiki when it doesn't.
if !strings.Contains(err.Error(), "not found") {
if err := reporter.ChunkErr(ctx, fmt.Errorf("error scanning wiki: %w", err)); err != nil {
return err
}
}

repoCtx.Logger().V(2).Info(fmt.Sprintf("scanned %d/%d repos", scannedCount, len(s.repos)), "duration_seconds", duration)
githubReposScanned.WithLabelValues(s.name).Inc()
atomic.AddUint64(&scannedCount, 1)
return nil
})
// Don't return, it still might be possible to scan comments.
}
}

_ = s.jobPool.Wait()
if scanErrs.Count() > 0 {
ctx.Logger().Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs.String())
// Scan comments, if enabled.
if s.includeGistComments || s.includeIssueComments || s.includePRComments {
if err := s.scanComments(ctx, repoURL, repoInfo, reporter); err != nil {
err := fmt.Errorf("error scanning comments: %w", err)
if err := reporter.ChunkErr(ctx, err); err != nil {
return err
}
}
}
s.SetProgressComplete(len(s.repos), len(s.repos), "Completed GitHub scan", "")

ctx.Logger().V(2).Info("finished scanning repo", "duration_seconds", duration)
githubReposScanned.WithLabelValues(s.name).Inc()
return nil
}

Expand Down Expand Up @@ -815,7 +842,7 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string, reporter
for _, gist := range gists {
s.filteredRepoCache.Set(gist.GetID(), gist.GetGitPullURL())
s.cacheGistInfo(gist)
if err := reporter.UnitOk(ctx, GistUnit{name: gist.GetID(), url: gist.GetGitPullURL()}); err != nil {
if err := reporter.UnitOk(ctx, GistUnit{Name: gist.GetID(), URL: gist.GetGitPullURL()}); err != nil {
return err
}
}
Expand Down Expand Up @@ -1478,3 +1505,15 @@ func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget,
Verify: s.verify}
return handlers.HandleFile(ctx, readCloser, &chunkSkel, reporter)
}

func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error {
repoURL, _ := unit.SourceUnitID()
ctx = context.WithValue(ctx, "repo", repoURL)
// ChunkUnit is not guaranteed to be called from Enumerate, so we must
// check and fetch the repoInfoCache for this repo.
repoURL, err := s.ensureRepoInfoCache(ctx, repoURL)
if err != nil {
return err
}
return s.scanRepo(ctx, repoURL, reporter)
}
Loading

0 comments on commit 2f3a410

Please sign in to comment.