diff --git a/cmd/merge.go b/cmd/merge.go index 47bbd183..dbf3634b 100644 --- a/cmd/merge.go +++ b/cmd/merge.go @@ -19,6 +19,7 @@ func MergeCmd() *cobra.Command { } cmd.Flags().StringP("branch", "B", "multi-gitter-branch", "The name of the branch where changes are committed.") + cmd.Flags().StringSliceP("merge-type", "", []string{"merge", "squash", "rebase"}, "The type of merge that should be done (GitHub). Multiple types can be used as backup strategies if the first one is not allowed.") cmd.Flags().AddFlagSet(platformFlags()) cmd.Flags().AddFlagSet(logFlags("-")) diff --git a/cmd/root.go b/cmd/root.go index 7bb14fd4..cdb6c2b6 100755 --- a/cmd/root.go +++ b/cmd/root.go @@ -7,6 +7,7 @@ import ( "os" "time" + "github.com/lindell/multi-gitter/internal/domain" "github.com/lindell/multi-gitter/internal/github" "github.com/lindell/multi-gitter/internal/gitlab" "github.com/lindell/multi-gitter/internal/multigitter" @@ -138,6 +139,7 @@ func createGithubClient(flag *flag.FlagSet) (multigitter.VersionController, erro orgs, _ := flag.GetStringSlice("org") users, _ := flag.GetStringSlice("user") repos, _ := flag.GetStringSlice("repo") + mergeTypeStrs, _ := flag.GetStringSlice("merge-type") // Only used for the merge command if len(orgs) == 0 && len(users) == 0 && len(repos) == 0 { return nil, errors.New("no organization or user set") @@ -156,11 +158,20 @@ func createGithubClient(flag *flag.FlagSet) (multigitter.VersionController, erro } } + // Convert all defined merge types (if any) + mergeTypes := make([]domain.MergeType, len(mergeTypeStrs)) + for i, mt := range mergeTypeStrs { + mergeTypes[i], err = domain.ParseMergeType(mt) + if err != nil { + return nil, err + } + } + vc, err := github.New(token, ghBaseURL, github.RepositoryListing{ Organizations: orgs, Users: users, Repositories: repoRefs, - }) + }, mergeTypes) if err != nil { return nil, err } diff --git a/internal/domain/pr.go b/internal/domain/pr.go index 645b1bd4..d685c813 100755 --- a/internal/domain/pr.go +++ b/internal/domain/pr.go @@ -1,5 +1,10 @@ package domain +import ( + "fmt" + "strings" +) + // NewPullRequest is the data needed to create a new pull request type NewPullRequest struct { Title string @@ -46,3 +51,41 @@ type PullRequest interface { Status() PullRequestStatus String() string } + +// MergeType is the way a pull request is "merged" into the base branch +type MergeType int + +// All MergeTypes +const ( + MergeTypeUnknown MergeType = iota + MergeTypeMerge + MergeTypeRebase + MergeTypeSquash +) + +// ParseMergeType parses a merge type +func ParseMergeType(typ string) (MergeType, error) { + switch strings.ToLower(typ) { + case "merge": + return MergeTypeMerge, nil + case "rebase": + return MergeTypeRebase, nil + case "squash": + return MergeTypeSquash, nil + } + return MergeTypeUnknown, fmt.Errorf(`not a valid merge type: "%s"`, typ) +} + +// MergeTypeIntersection calculates the intersection of two merge type slices, +// The order of the first slice will be preserved +func MergeTypeIntersection(mergeTypes1, mergeTypes2 []MergeType) []MergeType { + res := []MergeType{} + for _, mt := range mergeTypes1 { + for _, mt2 := range mergeTypes2 { + if mt == mt2 { + res = append(res, mt) + } + } + } + return res +} diff --git a/internal/github/github.go b/internal/github/github.go index 89b6d242..3cabd597 100755 --- a/internal/github/github.go +++ b/internal/github/github.go @@ -2,6 +2,7 @@ package github import ( "context" + "errors" "fmt" "net/url" "sort" @@ -16,7 +17,7 @@ import ( ) // New create a new Github client -func New(token, baseURL string, repoListing RepositoryListing) (*Github, error) { +func New(token, baseURL string, repoListing RepositoryListing, mergeTypes []domain.MergeType) (*Github, error) { ctx := context.Background() ts := oauth2.StaticTokenSource( &oauth2.Token{AccessToken: token}, @@ -39,6 +40,7 @@ func New(token, baseURL string, repoListing RepositoryListing) (*Github, error) return &Github{ RepositoryListing: repoListing, + MergeTypes: mergeTypes, ghClient: client, }, nil } @@ -46,7 +48,8 @@ func New(token, baseURL string, repoListing RepositoryListing) (*Github, error) // Github contain github configuration type Github struct { RepositoryListing - ghClient *github.Client + MergeTypes []domain.MergeType + ghClient *github.Client } // RepositoryListing contains information about which repositories that should be fetched @@ -365,7 +368,21 @@ func (g Github) GetPullRequestStatuses(ctx context.Context, branchName string) ( func (g Github) MergePullRequest(ctx context.Context, pullReq domain.PullRequest) error { pr := pullReq.(pullRequest) - _, _, err := g.ghClient.PullRequests.Merge(ctx, pr.ownerName, pr.repoName, pr.number, "", nil) + // We need to fetch the repo again since no AllowXMerge is present in listings of repositories + repo, _, err := g.ghClient.Repositories.Get(ctx, pr.ownerName, pr.repoName) + if err != nil { + return err + } + + // Filter out all merge types to only the allowed ones, but keep the order of the ones left + mergeTypes := domain.MergeTypeIntersection(g.MergeTypes, repoMergeTypes(repo)) + if len(mergeTypes) == 0 { + return errors.New("none of the configured merge types was permitted") + } + + _, _, err = g.ghClient.PullRequests.Merge(ctx, pr.ownerName, pr.repoName, pr.number, "", &github.PullRequestOptions{ + MergeMethod: mergeTypeGhName[mergeTypes[0]], + }) if err != nil { return err } diff --git a/internal/github/util.go b/internal/github/util.go new file mode 100644 index 00000000..acae802b --- /dev/null +++ b/internal/github/util.go @@ -0,0 +1,28 @@ +package github + +import ( + "github.com/google/go-github/v32/github" + "github.com/lindell/multi-gitter/internal/domain" +) + +// maps merge types to what they are called in the github api +var mergeTypeGhName = map[domain.MergeType]string{ + domain.MergeTypeMerge: "merge", + domain.MergeTypeRebase: "rebase", + domain.MergeTypeSquash: "squash", +} + +// repoMergeTypes returns a list of all allowed merge types +func repoMergeTypes(repo *github.Repository) []domain.MergeType { + ret := []domain.MergeType{} + if repo.GetAllowMergeCommit() { + ret = append(ret, domain.MergeTypeMerge) + } + if repo.GetAllowRebaseMerge() { + ret = append(ret, domain.MergeTypeRebase) + } + if repo.GetAllowSquashMerge() { + ret = append(ret, domain.MergeTypeSquash) + } + return ret +}