Skip to content

Commit

Permalink
flag: support mutually exclusive flags
Browse files Browse the repository at this point in the history
  • Loading branch information
rsteube committed Dec 18, 2022
1 parent 02239e5 commit c4810d4
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 18 deletions.
2 changes: 1 addition & 1 deletion complete.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func lookupFlag(cmd *cobra.Command, arg string) (flag *pflag.Flag) {
if strings.HasPrefix(arg, "--") {
flag = cmd.Flags().Lookup(nameOrShorthand)
} else if strings.HasPrefix(arg, "-") && len(nameOrShorthand) > 0 {
if pflagfork.FlagSet(cmd.Flags()).IsPosix() {
if (pflagfork.FlagSet{FlagSet: cmd.Flags()}).IsPosix() {
flag = cmd.Flags().ShorthandLookup(string(nameOrShorthand[len(nameOrShorthand)-1]))
} else {
flag = cmd.Flags().ShorthandLookup(nameOrShorthand)
Expand Down
4 changes: 4 additions & 0 deletions example/cmd/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ func init() {
flagCmd.Flags().CountP("count", "c", "count flag")
flagCmd.Flags().StringP("optarg", "o", "", "optional argument")

flagCmd.Flags().Bool("exclusive1", false, "mutually exclusive flag")
flagCmd.Flags().Bool("exclusive2", false, "mutually exclusive flag")

flagCmd.Flag("optarg").NoOptDefVal = " "
flagCmd.MarkFlagsMutuallyExclusive("exclusive1", "exclusive2")

carapace.Gen(flagCmd).FlagCompletion(carapace.ActionMap{
"optarg": carapace.ActionValues("optarg1", "optarg2", "optarg3"),
Expand Down
14 changes: 10 additions & 4 deletions internal/pflagfork/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pflagfork

import (
"reflect"
"strings"

"github.com/spf13/pflag"
)
Expand All @@ -15,17 +16,22 @@ const (
NameAsShorthand // non-posix style where the name is also added as shorthand (single `-` prefix)
)

type flag struct {
type Flag struct {
*pflag.Flag
}

func (f flag) Style() style {
func (f Flag) Style() style {
if field := reflect.ValueOf(f.Flag).Elem().FieldByName("Style"); field.IsValid() && field.Kind() == reflect.Int {
return style(field.Int())
}
return Default
}

func Flag(f *pflag.Flag) *flag {
return &flag{Flag: f}
func (f Flag) IsRepeatable() bool {
if strings.Contains(f.Value.Type(), "Slice") ||
strings.Contains(f.Value.Type(), "Array") ||
f.Value.Type() == "count" {
return true
}
return false
}
25 changes: 21 additions & 4 deletions internal/pflagfork/flagset.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package pflagfork

import (
"reflect"
"strings"

"github.com/spf13/pflag"
)

type flagSet struct {
type FlagSet struct {
*pflag.FlagSet
}

func (f flagSet) IsPosix() bool {
func (f FlagSet) IsPosix() bool {
if method := reflect.ValueOf(f.FlagSet).MethodByName("IsPosix"); method.IsValid() {
if values := method.Call([]reflect.Value{}); len(values) == 1 && values[0].Kind() == reflect.Bool {
return values[0].Bool()
Expand All @@ -19,6 +20,22 @@ func (f flagSet) IsPosix() bool {
return true
}

func FlagSet(f *pflag.FlagSet) *flagSet {
return &flagSet{FlagSet: f}
func (f FlagSet) IsMutuallyExclusive(flag *pflag.Flag) bool {
if groups, ok := flag.Annotations["cobra_annotation_mutually_exclusive"]; ok {
for _, group := range groups {
for _, name := range strings.Split(group, " ") {
if other := f.Lookup(name); other != nil && other.Changed {
return true
}
}
}
}
return false
}

func (f *FlagSet) VisitAll(fn func(*Flag)) {
f.FlagSet.VisitAll(func(flag *pflag.Flag) {
fn(&Flag{flag})
})

}
2 changes: 1 addition & 1 deletion internal/shell/export/snippet.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type flag struct {

func convertFlag(f *pflag.Flag) flag {
longhand := ""
if pflagfork.Flag(f).Style() != pflagfork.ShorthandOnly {
if (pflagfork.Flag{Flag: f}).Style() != pflagfork.ShorthandOnly {
longhand = f.Name
}

Expand Down
17 changes: 9 additions & 8 deletions internalActions.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/rsteube/carapace/internal/pflagfork"
"github.com/rsteube/carapace/pkg/style"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)

func actionPath(fileSuffixes []string, dirOnly bool) Action {
Expand Down Expand Up @@ -72,22 +71,24 @@ func actionPath(fileSuffixes []string, dirOnly bool) Action {

func actionFlags(cmd *cobra.Command) Action {
return ActionCallback(func(c Context) Action {
flagSet := pflagfork.FlagSet{FlagSet: cmd.Flags()}
re := regexp.MustCompile("^-(?P<shorthand>[^-=]+)")
isShorthandSeries := re.MatchString(c.CallbackValue) && pflagfork.FlagSet(cmd.Flags()).IsPosix()
isShorthandSeries := re.MatchString(c.CallbackValue) && flagSet.IsPosix()

vals := make([]string, 0)
cmd.Flags().VisitAll(func(f *pflag.Flag) {
flagSet.VisitAll(func(f *pflagfork.Flag) {
if f.Deprecated != "" {
return // skip deprecated flags
}

if f.Changed &&
!strings.Contains(f.Value.Type(), "Slice") &&
!strings.Contains(f.Value.Type(), "Array") &&
f.Value.Type() != "count" {
if f.Changed && !f.IsRepeatable() {
return // don't repeat flag
}

if flagSet.IsMutuallyExclusive(f.Flag) {
return // skip flag of group already set
}

if isShorthandSeries {
if f.Shorthand != "" && f.ShorthandDeprecated == "" {
for _, shorthand := range c.CallbackValue[1:] {
Expand All @@ -98,7 +99,7 @@ func actionFlags(cmd *cobra.Command) Action {
vals = append(vals, f.Shorthand, f.Usage)
}
} else {
if flagstyle := pflagfork.Flag(f).Style(); flagstyle != pflagfork.ShorthandOnly {
if flagstyle := f.Style(); flagstyle != pflagfork.ShorthandOnly {
if flagstyle == pflagfork.NameAsShorthand {
vals = append(vals, "-"+f.Name, f.Usage)
} else {
Expand Down

0 comments on commit c4810d4

Please sign in to comment.