Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add framework for dynamic tab completions #883

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions bash_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
const (
BashCompFilenameExt = "cobra_annotation_bash_completion_filename_extensions"
BashCompCustom = "cobra_annotation_bash_completion_custom"
BashCompDynamic = "cobra_annotation_bash_completion_dynamic"
BashCompOneRequiredFlag = "cobra_annotation_bash_completion_one_required_flag"
BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir"
)
Expand Down Expand Up @@ -282,6 +283,34 @@ __%[1]s_handle_word()
`, name))
}

func writeDynamicFlagCompletionFunction(buf *bytes.Buffer, dynamicFlagCompletionFunc string) {
buf.WriteString(fmt.Sprintf(`
%s()
{
export COBRA_FLAG_COMPLETION="$1"

local output
if ! output="$(mktemp)" ; then
return $?
fi

if ! error_message="$("${COMP_WORDS[@]}" > "$output")" ; then
local st="$?"
echo "$error_message"
return "$st"
fi

while read -r -d '' line ; do
COMPREPLY+=("$line")
done < "$output"

unset COBRA_FLAG_COMPLETION
rm "$output"
}

`, dynamicFlagCompletionFunc))
}

func writePostscript(buf *bytes.Buffer, name string) {
name = strings.Replace(name, ":", "__", -1)
buf.WriteString(fmt.Sprintf("__start_%s()\n", name))
Expand Down Expand Up @@ -354,6 +383,9 @@ func writeFlagHandler(buf *bytes.Buffer, name string, annotations map[string][]s
} else {
buf.WriteString(" flags_completion+=(:)\n")
}
case BashCompDynamic:
buf.WriteString(fmt.Sprintf(" flags_with_completion+=(%q)\n", name))
buf.WriteString(fmt.Sprintf(" flags_completion+=(%q)\n", value[0]))
case BashCompSubdirsInDir:
buf.WriteString(fmt.Sprintf(" flags_with_completion+=(%q)\n", name))

Expand Down Expand Up @@ -526,11 +558,19 @@ func gen(buf *bytes.Buffer, cmd *Command) {

// GenBashCompletion generates bash completion file and writes to the passed writer.
func (c *Command) GenBashCompletion(w io.Writer) error {
dynamicFlagCompletionFunc := "__" + c.Root().Name() + "_handle_dynamic_flag_completion"
c.Root().visitAllFlagsWithCompletions(func(f *pflag.Flag) {
f.Annotations[BashCompDynamic] = []string{dynamicFlagCompletionFunc + " " + f.Name}
})

buf := new(bytes.Buffer)
writePreamble(buf, c.Name())
if len(c.BashCompletionFunction) > 0 {
buf.WriteString(c.BashCompletionFunction + "\n")
}
if c.HasDynamicCompletions() {
writeDynamicFlagCompletionFunction(buf, dynamicFlagCompletionFunc)
}
gen(buf, c)
writePostscript(buf, c.Name())

Expand Down
133 changes: 133 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"io"
"log"
"os"
"path/filepath"
"sort"
Expand Down Expand Up @@ -147,6 +148,10 @@ type Command struct {
// FParseErrWhitelist flag parse errors to be ignored
FParseErrWhitelist FParseErrWhitelist

// RunPreRunsDuringCompletion defines if the (Persistent)PreRun functions should be run before calling the
// completion functions
RunPreRunsDuringCompletion bool

ctx context.Context

// commands is the list of commands supported by this program.
Expand Down Expand Up @@ -206,6 +211,10 @@ type Command struct {
outWriter io.Writer
// errWriter is a writer defined by the user that replaces stderr
errWriter io.Writer

// dynamicFlagCompletions is a map of flag to a function that returns a list of values to suggest during tab
// completion for this flag
dynamicFlagCompletions map[*flag.Flag]DynamicFlagCompletion
}

// Context returns underlying command context. If command wasn't
Expand Down Expand Up @@ -750,6 +759,8 @@ func (c *Command) ArgsLenAtDash() int {
return c.Flags().ArgsLenAtDash()
}

const FlagCompletionEnvVar = "COBRA_FLAG_COMPLETION"

func (c *Command) execute(a []string) (err error) {
if c == nil {
return fmt.Errorf("Called Execute() on a nil Command")
Expand Down Expand Up @@ -865,6 +876,94 @@ func (c *Command) execute(a []string) (err error) {
return nil
}

func (c *Command) complete(flagName string, a []string) (err error) {
if c == nil {
return fmt.Errorf("Called Execute() on a nil Command")
}

// initialize help and version flag at the last point possible to allow for user
// overriding
c.InitDefaultHelpFlag()
c.InitDefaultVersionFlag()

var flagToComplete *flag.Flag
var currentCompletionValue string

oldFlags := c.Flags()
c.flags = nil
oldFlags.VisitAll(func(f *flag.Flag) {
if f.Name == flagName {
flagToComplete = f
} else {
c.Flags().AddFlag(f)
}
})
if flagToComplete == nil {
log.Panicln(flagName, "is not a known flag")
}

if flagToComplete.Shorthand != "" {
c.Flags().StringVarP(&currentCompletionValue, flagName, flagToComplete.Shorthand, "", "")
} else {
c.Flags().StringVar(&currentCompletionValue, flagName, "", "")
}

err = c.ParseFlags(a)
if err != nil {
return c.FlagErrorFunc()(c, err)
}

c.preRun()

currentCommand := c
completionFunc := currentCommand.dynamicFlagCompletions[flagToComplete]
for completionFunc == nil && currentCommand.HasParent() {
currentCommand = currentCommand.Parent()
completionFunc = currentCommand.dynamicFlagCompletions[flagToComplete]
}
if completionFunc == nil {
return fmt.Errorf("%s does not have completions enabled", flagName)
}

if c.RunPreRunsDuringCompletion {
argWoFlags := c.Flags().Args()
if c.DisableFlagParsing {
argWoFlags = a
}

for p := c; p != nil; p = p.Parent() {
if p.PersistentPreRunE != nil {
if err := p.PersistentPreRunE(c, argWoFlags); err != nil {
return err
}
break
} else if p.PersistentPreRun != nil {
p.PersistentPreRun(c, argWoFlags)
break
}
}
if c.PreRunE != nil {
if err := c.PreRunE(c, argWoFlags); err != nil {
return err
}
} else if c.PreRun != nil {
c.PreRun(c, argWoFlags)
}
}

values, err := completionFunc(currentCompletionValue)
if err != nil {
return err
}

for _, v := range values {
c.OutOrStdout()
fmt.Print(v + "\x00")
}

return nil
}

func (c *Command) preRun() {
for _, x := range initializers {
x()
Expand Down Expand Up @@ -936,6 +1035,16 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {
cmd.commandCalledAs.name = cmd.Name()
}

flagName, flagCompletionEnabled := os.LookupEnv(FlagCompletionEnvVar)
if flagCompletionEnabled {
err = cmd.complete(flagName, flags)
if err != nil {
c.Println("Error:", err.Error())
}

return cmd, err
}

// We have to pass global context to children command
// if context is present on the parent command.
if cmd.ctx == nil {
Expand Down Expand Up @@ -1631,3 +1740,27 @@ func (c *Command) updateParentsPflags() {
c.parentsPflags.AddFlagSet(parent.PersistentFlags())
})
}

func (c *Command) HasDynamicCompletions() bool {
hasCompletions := false
c.visitAllFlagsWithCompletions(func(*flag.Flag) { hasCompletions = true })
return hasCompletions
}

// visitAllFlagsWithCompletions recursively visits all flags and persistent flags that have dynamic completions enabled.
// Initializes the flag's Annotations map if nil before calling fn
func (c Command) visitAllFlagsWithCompletions(fn func(*flag.Flag)) {
filterFunc := func(f *flag.Flag) {
if _, ok := c.dynamicFlagCompletions[f]; ok {
if f.Annotations == nil {
f.Annotations = make(map[string][]string)
}
fn(f)
}
}
c.Flags().VisitAll(filterFunc)
c.PersistentFlags().VisitAll(filterFunc)
for _, sc := range c.Commands() {
sc.visitAllFlagsWithCompletions(fn)
}
}
27 changes: 27 additions & 0 deletions shell_completions.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package cobra

import (
"fmt"

"github.com/spf13/pflag"
)

Expand Down Expand Up @@ -83,3 +85,28 @@ func MarkFlagDirname(flags *pflag.FlagSet, name string) error {
zshPattern := "-(/)"
return flags.SetAnnotation(name, zshCompDirname, []string{zshPattern})
}

type DynamicFlagCompletion func(currentValue string) (suggestedValues []string, err error)

// MarkDynamicFlagCompletion provides cobra a function to dynamically suggest values to the user during tab completion
// for this flag. All (Persistent)PreRun(E) functions will be run accordingly before the provided function is called if
// RunPreRunsDuringCompletion is set to true. All flags other than the flag currently being completed will be parsed
// according to their type. The flag being completed is parsed as a raw string with no format requirements
//
// Shell Completion compatibility matrix: bash, zsh
func (c *Command) MarkDynamicFlagCompletion(name string, completion DynamicFlagCompletion) error {
flag := c.Flag(name)
if flag == nil {
return fmt.Errorf("no such flag %s", name)
}
if flag.NoOptDefVal != "" {
return fmt.Errorf("%s takes no parameters", name)
}

if c.dynamicFlagCompletions == nil {
c.dynamicFlagCompletions = make(map[*pflag.Flag]DynamicFlagCompletion)
}
c.dynamicFlagCompletions[flag] = completion

return nil
}
56 changes: 51 additions & 5 deletions zsh_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,26 @@ const (
zshCompArgumentFilenameComp = "cobra_annotations_zsh_completion_argument_file_completion"
zshCompArgumentWordComp = "cobra_annotations_zsh_completion_argument_word_completion"
zshCompDirname = "cobra_annotations_zsh_dirname"
zshCompDynamicCompletion = "cobra_annotations_zsh_completion_dynamic_completion"
)

var (
zshCompFuncMap = template.FuncMap{
"genZshFuncName": zshCompGenFuncName,
"extractFlags": zshCompExtractFlag,
"genFlagEntryForZshArguments": zshCompGenFlagEntryForArguments,
"extractArgsCompletions": zshCompExtractArgumentCompletionHintsForRendering,
"genZshFuncName": zshCompGenFuncName,
"extractFlags": zshCompExtractFlag,
"genFlagEntryForZshArguments": zshCompGenFlagEntryForArguments,
"extractArgsCompletions": zshCompExtractArgumentCompletionHintsForRendering,
"genZshFlagDynamicCompletionFuncName": zshCompGenDynamicFlagCompletionFuncName,
}
zshCompletionText = `
{{/* should accept Command (that contains subcommands) as parameter */}}
{{define "argumentsC" -}}
{{ $cmdPath := genZshFuncName .}}
function {{$cmdPath}} {
local -a commands
{{/* If we are at the root, save a copy of the $words array as it contains the full command, including any empty
strings and other parameters */}}
{{ if (not .HasParent) and .HasDynamicCompletions }} full_command=("${(@)words}"){{- end}}

_arguments -C \{{- range extractFlags .}}
{{genFlagEntryForZshArguments .}} \{{- end}}
Expand Down Expand Up @@ -79,6 +84,36 @@ function {{genZshFuncName .}} {
{{define "Main" -}}
#compdef _{{.Name}} {{.Name}}

{{if .HasDynamicCompletions -}}
function {{genZshFlagDynamicCompletionFuncName .}} {
export COBRA_FLAG_COMPLETION="$1"

local output
if ! output="$(mktemp)" ; then
return $?
fi

if ! error_message="$("${(@)full_command}" 2>&1 > "$output")" ; then
local st="$?"
_message "Exception occurred during completion: $error_message"
return "$st"
fi

local -a args
while read -r -d '' line ; do
args+="$line"
done < "$output"

if [[ $#args -gt 0 ]] ; then
_values "$1" "${(@)args}"
else
_message "No matching completion for $descr: $opt_args"
fi

unset COBRA_FLAG_COMPLETION
rm "$output"
}{{- end}}

{{template "selectCmdTemplate" .}}
{{end}}
`
Expand Down Expand Up @@ -112,6 +147,11 @@ func (c *Command) GenZshCompletionFile(filename string) error {
// writer. The completion always run on the root command regardless of the
// command it was called from.
func (c *Command) GenZshCompletion(w io.Writer) error {
dynamicFlagCompletionFuncName := zshCompGenDynamicFlagCompletionFuncName(c.Root())
c.Root().visitAllFlagsWithCompletions(func(f *pflag.Flag) {
f.Annotations[zshCompDynamicCompletion] = []string{dynamicFlagCompletionFuncName + " " + f.Name}
})

tmpl, err := template.New("Main").Funcs(zshCompFuncMap).Parse(zshCompletionText)
if err != nil {
return fmt.Errorf("error creating zsh completion template: %v", err)
Expand Down Expand Up @@ -310,7 +350,7 @@ func zshCompGenFlagEntryExtras(f *pflag.Flag) string {
return ""
}

extras := ":" // allow options for flag (even without assistance)
extras := ":" + f.Name // allow options for flag (even without assistance)
for key, values := range f.Annotations {
switch key {
case zshCompDirname:
Expand All @@ -320,6 +360,8 @@ func zshCompGenFlagEntryExtras(f *pflag.Flag) string {
for _, pattern := range values {
extras = extras + fmt.Sprintf(` -g "%s"`, pattern)
}
case zshCompDynamicCompletion:
extras += fmt.Sprintf(":{%s}", values[0])
}
}

Expand All @@ -334,3 +376,7 @@ func zshCompFlagCouldBeSpecifiedMoreThenOnce(f *pflag.Flag) bool {
func zshCompQuoteFlagDescription(s string) string {
return strings.Replace(s, "'", `'\''`, -1)
}

func zshCompGenDynamicFlagCompletionFuncName(c *Command) string {
return "_" + c.Root().Name() + "-handle-dynamic-flag-completion"
}
Loading