Skip to content

Commit

Permalink
Update shell completion to respect flag groups
Browse files Browse the repository at this point in the history
Signed-off-by: Marc Khouzam <marc.khouzam@montreal.ca>
  • Loading branch information
marckhouzam committed Apr 9, 2022
1 parent ea529ed commit 36eb005
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 10 deletions.
3 changes: 3 additions & 0 deletions completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi
var completions []string
var directive ShellCompDirective

// Enforce flag groups before doing flag completions
finalCmd.enforceFlagGroupsForCompletion()

// Note that we want to perform flagname completion even if finalCmd.DisableFlagParsing==true;
// doing this allows for completion of persistent flag names even for commands that disable flag parsing.
//
Expand Down
168 changes: 168 additions & 0 deletions completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2691,3 +2691,171 @@ func TestFixedCompletions(t *testing.T) {
t.Errorf("expected: %q, got: %q", expected, output)
}
}

func TestCompletionForGroupedFlags(t *testing.T) {
rootCmd := &Command{
Use: "root",
Run: emptyRun,
}
childCmd := &Command{
Use: "child",
ValidArgsFunction: func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
return []string{"subArg"}, ShellCompDirectiveNoFileComp
},
Run: emptyRun,
}
rootCmd.AddCommand(childCmd)

rootCmd.PersistentFlags().Int("group1-1", -1, "group1-1")
rootCmd.PersistentFlags().String("group1-2", "", "group1-2")

childCmd.Flags().Bool("group1-3", false, "group1-3")
childCmd.Flags().Bool("flag2", false, "flag2")

// Add flags to a group
childCmd.MarkFlagsRequiredTogether("group1-1", "group1-2", "group1-3")

// Test that flags in a group are not suggested without the - prefix
output, err := executeCommand(rootCmd, ShellCompNoDescRequestCmd, "child", "")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

expected := strings.Join([]string{
"subArg",
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")

if output != expected {
t.Errorf("expected: %q, got: %q", expected, output)
}

// Test that flags in a group are suggested with the - prefix
output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "child", "-")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

expected = strings.Join([]string{
"--group1-1",
"--group1-2",
"--flag2",
"--group1-3",
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")

if output != expected {
t.Errorf("expected: %q, got: %q", expected, output)
}

// Test that when a flag in a group is present, the other flags in the group are suggested
// even without the - prefix
output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "child", "--group1-2", "value", "")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

expected = strings.Join([]string{
"--group1-1",
"--group1-3",
"subArg",
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")

if output != expected {
t.Errorf("expected: %q, got: %q", expected, output)
}

// Test that when all flags in a group are present, flags are not suggested without the - prefix
output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "child",
"--group1-1", "8",
"--group1-2", "value2",
"--group1-3",
"")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

expected = strings.Join([]string{
"subArg",
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")

if output != expected {
t.Errorf("expected: %q, got: %q", expected, output)
}
}

func TestCompletionForMutuallyExclusiveFlags(t *testing.T) {
rootCmd := &Command{
Use: "root",
Run: emptyRun,
}
childCmd := &Command{
Use: "child",
ValidArgsFunction: func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
return []string{"subArg"}, ShellCompDirectiveNoFileComp
},
Run: emptyRun,
}
rootCmd.AddCommand(childCmd)

rootCmd.PersistentFlags().IntSlice("group1-1", []int{1}, "group1-1")
rootCmd.PersistentFlags().String("group1-2", "", "group1-2")

childCmd.Flags().Bool("group1-3", false, "group1-3")
childCmd.Flags().Bool("flag2", false, "flag2")

// Add flags to a group
childCmd.MarkFlagsMutuallyExclusive("group1-1", "group1-2", "group1-3")

// Test that flags in a mutually exclusive group are not suggested without the - prefix
output, err := executeCommand(rootCmd, ShellCompNoDescRequestCmd, "child", "")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

expected := strings.Join([]string{
"subArg",
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")

if output != expected {
t.Errorf("expected: %q, got: %q", expected, output)
}

// Test that flags in a mutually exclusive group are suggested with the - prefix
output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "child", "-")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

expected = strings.Join([]string{
"--group1-1",
"--group1-2",
"--flag2",
"--group1-3",
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")

if output != expected {
t.Errorf("expected: %q, got: %q", expected, output)
}

// Test that when a flag in a mutually exclusive group is present, the other flags in the group are
// not suggested even with the - prefix
output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "child", "--group1-1", "8", "-")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

expected = strings.Join([]string{
"--group1-1", // Should be repeated since it is a slice
"--flag2",
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")

if output != expected {
t.Errorf("expected: %q, got: %q", expected, output)
}
}
56 changes: 52 additions & 4 deletions flag_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func processFlagForGroupAnnotation(pflag *flag.Flag, annotation string, groupSta

func validateRequiredFlagGroups(data map[string]map[string]bool) error {
keys := sortedKeys(data)
for _, flagList := range keys{
for _, flagList := range keys {
flagnameAndStatus := data[flagList]

unset := []string{}
Expand All @@ -123,7 +123,7 @@ func validateRequiredFlagGroups(data map[string]map[string]bool) error {

func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
keys := sortedKeys(data)
for _, flagList := range keys{
for _, flagList := range keys {
flagnameAndStatus := data[flagList]
var set []string
for flagname, isSet := range flagnameAndStatus {
Expand All @@ -142,7 +142,7 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
return nil
}

func sortedKeys(m map[string]map[string]bool) ([]string) {
func sortedKeys(m map[string]map[string]bool) []string {
keys := make([]string, len(m))
i := 0
for k := range m {
Expand All @@ -151,4 +151,52 @@ func sortedKeys(m map[string]map[string]bool) ([]string) {
}
sort.Strings(keys)
return keys
}
}

// enforceFlagGroupsForCompletion will do the following:
// - when a flag in a group is present, other flags in the group will be marked required
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
// This allows the standard completion logic to behave appropriately for flag groups
func (c *Command) enforceFlagGroupsForCompletion() {
if c.DisableFlagParsing {
return
}

groupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
c.Flags().VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
})

// If a flag that is part of a group is present, we make all the other flags
// of that group required so that the shell completion suggests them automatically
for flagList, flagnameAndStatus := range groupStatus {
for _, isSet := range flagnameAndStatus {
if isSet {
// One of the flags of the group is set, mark the other ones as required
for _, fName := range strings.Split(flagList, " ") {
c.MarkFlagRequired(fName)
}
}
}
}

// If a flag that is mutually exclusive to others is present, we hide the other
// flags of that group so the shell completion does not suggest them
for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
for flagName, isSet := range flagnameAndStatus {
if isSet {
// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
// Don't mark the flag that is already set as hidden because it may be an
// array or slice flag and therefore must continue being suggested
for _, fName := range strings.Split(flagList, " ") {
if fName != flagName {
flag := c.Flags().Lookup(fName)
flag.Hidden = true
}
}
}
}
}
}
12 changes: 6 additions & 6 deletions flag_groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,16 @@ func TestValidateFlagGroups(t *testing.T) {
args: []string{"testcmd", "--a=foo", "--c=foo", "--d=foo"},
expectErr: `if any flags in the group [a b c] are set none of the others can be; [a c] were all set`,
}, {
desc: "Validation of required groups occurs on groups in sorted order",
desc: "Validation of required groups occurs on groups in sorted order",
flagGroupsRequired: []string{"a d", "a b", "a c"},
args: []string{"testcmd", "--a=foo"},
expectErr: `if any flags in the group [a b] are set they must all be set; missing [b]`,
},{
args: []string{"testcmd", "--a=foo"},
expectErr: `if any flags in the group [a b] are set they must all be set; missing [b]`,
}, {
desc: "Validation of exclusive groups occurs on groups in sorted order",
flagGroupsExclusive: []string{"a d", "a b", "a c"},
args: []string{"testcmd", "--a=foo", "--b=foo", "--c=foo"},
expectErr: `if any flags in the group [a b] are set none of the others can be; [a b] were all set`,
},{
}, {
desc: "Persistent flags utilize both features and can fail required groups",
flagGroupsRequired: []string{"a e", "e f"},
flagGroupsExclusive: []string{"f g"},
Expand All @@ -88,7 +88,7 @@ func TestValidateFlagGroups(t *testing.T) {
desc: "Persistent flags utilize both features and can fail mutually exclusive groups",
flagGroupsRequired: []string{"a e", "e f"},
flagGroupsExclusive: []string{"f g"},
args: []string{"testcmd", "--a=foo", "--e=foo","--f=foo", "--g=foo"},
args: []string{"testcmd", "--a=foo", "--e=foo", "--f=foo", "--g=foo"},
expectErr: `if any flags in the group [f g] are set none of the others can be; [f g] were all set`,
}, {
desc: "Persistent flags utilize both features and can pass",
Expand Down

0 comments on commit 36eb005

Please sign in to comment.