Skip to content

Commit

Permalink
Fix flag completion
Browse files Browse the repository at this point in the history
The flag completion functions should not be stored in the root cmd.
There is no requirement that the root cmd should be the same when
`RegisterFlagCompletionFunc` was called. Storing the flags there does
not work when you add the the flags to your cmd struct before you add the
cmd to the parent/root cmd. The flags can no longer be found in the rigth
place when the completion command is called and thus the flag completion
does not work.

Also spf13#1423 claims that this would be thread safe but we still have a map
which will fail when accessed concurrently. To truly fix this issue use a
RWMutex.

Fixes spf13#1437
Fixes spf13#1320

Signed-off-by: Paul Holzinger <pholzing@redhat.com>
  • Loading branch information
Luap99 committed Jul 2, 2021
1 parent 07861c8 commit 46777a7
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 11 deletions.
4 changes: 3 additions & 1 deletion bash_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,9 @@ func writeLocalNonPersistentFlag(buf io.StringWriter, flag *pflag.Flag) {

// Setup annotations for go completions for registered flags
func prepareCustomAnnotationsForFlags(cmd *Command) {
for flag := range cmd.Root().flagCompletionFunctions {
flagCompletionMutex.RLock()
defer flagCompletionMutex.RUnlock()
for flag := range flagCompletionFunctions {
// Make sure the completion script calls the __*_go_custom_completion function for
// every registered flag. We need to do this here (and not when the flag was registered
// for completion) so that we can know the root command name for the prefix
Expand Down
3 changes: 0 additions & 3 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,6 @@ type Command struct {
// that we can use on every pflag set and children commands
globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName

//flagCompletionFunctions is map of flag completion functions.
flagCompletionFunctions map[*flag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective)

// usageFunc is usage func defined by user.
usageFunc func(*Command) error
// usageTemplate is usage template defined by user.
Expand Down
21 changes: 14 additions & 7 deletions completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"os"
"strings"
"sync"

"github.com/spf13/pflag"
)
Expand All @@ -17,6 +18,12 @@ const (
ShellCompNoDescRequestCmd = "__completeNoDesc"
)

// Global map of flag completion functions. Make sure to use flagCompletionMutex before you try to read and write from it.
var flagCompletionFunctions = map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective){}

// lock for reading and writing from flagCompletionFunctions
var flagCompletionMutex = &sync.RWMutex{}

// ShellCompDirective is a bit map representing the different behaviors the shell
// can be instructed to have once completions have been provided.
type ShellCompDirective int
Expand Down Expand Up @@ -100,15 +107,13 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman
if flag == nil {
return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' does not exist", flagName)
}
flagCompletionMutex.Lock()
defer flagCompletionMutex.Unlock()

root := c.Root()
if _, exists := root.flagCompletionFunctions[flag]; exists {
if _, exists := flagCompletionFunctions[flag]; exists {
return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' already registered", flagName)
}
if root.flagCompletionFunctions == nil {
root.flagCompletionFunctions = map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective){}
}
root.flagCompletionFunctions[flag] = f
flagCompletionFunctions[flag] = f
return nil
}

Expand Down Expand Up @@ -402,7 +407,9 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi
// Find the completion function for the flag or command
var completionFn func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective)
if flag != nil && flagCompletion {
completionFn = c.Root().flagCompletionFunctions[flag]
flagCompletionMutex.RLock()
completionFn = flagCompletionFunctions[flag]
flagCompletionMutex.RUnlock()
} else {
completionFn = finalCmd.ValidArgsFunction
}
Expand Down
50 changes: 50 additions & 0 deletions completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1969,6 +1969,56 @@ func TestFlagCompletionWithNotInterspersedArgs(t *testing.T) {
if output != expected {
t.Errorf("expected: %q, got: %q", expected, output)
}

// Test that no flag completion works on a subcmd
output, err = executeCommand(rootCmd, ShellCompRequestCmd, "child", "--string", "")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

expected = strings.Join([]string{
"myval",
":0",
"Completion ended with directive: ShellCompDirectiveDefault", ""}, "\n")

if output != expected {
t.Errorf("expected: %q, got: %q", expected, output)
}
}

func TestFlagCompletionWorksRootCommandAddedAfterFlags(t *testing.T) {
rootCmd := &Command{Use: "root", Run: emptyRun}
childCmd := &Command{
Use: "child",
Run: emptyRun,
ValidArgsFunction: func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
return []string{"--validarg", "test"}, ShellCompDirectiveDefault
},
}
childCmd.Flags().Bool("bool", false, "test bool flag")
childCmd.Flags().String("string", "", "test string flag")
_ = childCmd.RegisterFlagCompletionFunc("string", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
return []string{"myval"}, ShellCompDirectiveDefault
})

// Important: This is a test for https://github.com/spf13/cobra/issues/1437
// Only add the subcommand after RegisterFlagCompletionFunc was called, do not change this order!
rootCmd.AddCommand(childCmd)

// Test that flag completion works for the subcmd
output, err := executeCommand(rootCmd, ShellCompRequestCmd, "child", "--string", "")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

expected := strings.Join([]string{
"myval",
":0",
"Completion ended with directive: ShellCompDirectiveDefault", ""}, "\n")

if output != expected {
t.Errorf("expected: %q, got: %q", expected, output)
}
}

func TestFlagCompletionInGoWithDesc(t *testing.T) {
Expand Down

0 comments on commit 46777a7

Please sign in to comment.