diff --git a/gopls/internal/lsp/source/completion/postfix_snippets.go b/gopls/internal/lsp/source/completion/postfix_snippets.go index 1661709e5dc..0490b386161 100644 --- a/gopls/internal/lsp/source/completion/postfix_snippets.go +++ b/gopls/internal/lsp/source/completion/postfix_snippets.go @@ -68,6 +68,10 @@ type postfixTmplArgs struct { // Type is the type of "foo.bar" in "foo.bar.print!". Type types.Type + // FuncResult are results of the enclosed function + FuncResults []*types.Var + + sel *ast.SelectorExpr scope *types.Scope snip snippet.Builder importIfNeeded func(pkgPath string, scope *types.Scope) (name string, edits []protocol.TextEdit, err error) @@ -75,6 +79,7 @@ type postfixTmplArgs struct { qf types.Qualifier varNames map[string]bool placeholders bool + currentTabStop int } var postfixTmpls = []postfixTmpl{{ @@ -250,26 +255,119 @@ if {{.X}} != nil { body: `{{if (eq .Kind "slice" "map" "array" "chan") -}} len({{.X}}) {{- end}}`, +}, { + label: "iferr", + details: "check error and return", + body: `{{if and .StmtOK (eq (.TypeName .Type) "error") -}} +{{- $errName := (or (and .IsIdent .X) "err") -}} +if {{if not .IsIdent}}err := {{.X}}; {{end}}{{$errName}} != nil { + return {{$a := .}}{{range $i, $v := .FuncResults}} + {{- if $i}}, {{end -}} + {{- if eq ($a.TypeName $v.Type) "error" -}} + {{$a.Placeholder $errName}} + {{- else -}} + {{$a.Zero $v.Type}} + {{- end -}} + {{end}} +} +{{end}}`, +}, { + label: "iferr", + details: "check error and return", + body: `{{if and .StmtOK (eq .Kind "tuple") (len .Tuple) (eq (.TypeName .TupleLast.Type) "error") -}} +{{- $a := . -}} +if {{range $i, $v := .Tuple}}{{if $i}}, {{end}}{{if and (eq ($a.TypeName $v.Type) "error") (eq (inc $i) (len $a.Tuple))}}err{{else}}_{{end}}{{end}} := {{.X -}} +; err != nil { + return {{range $i, $v := .FuncResults}} + {{- if $i}}, {{end -}} + {{- if eq ($a.TypeName $v.Type) "error" -}} + {{$a.Placeholder "err"}} + {{- else -}} + {{$a.Zero $v.Type}} + {{- end -}} + {{end}} +} +{{end}}`, +}, { + // variferr snippets use nested placeholders, as described in + // https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#snippet_syntax, + // so that users can wrap the returned error without modifying the error + // variable name. + label: "variferr", + details: "assign variables and check error", + body: `{{if and .StmtOK (eq .Kind "tuple") (len .Tuple) (eq (.TypeName .TupleLast.Type) "error") -}} +{{- $a := . -}} +{{- $errName := "err" -}} +{{- range $i, $v := .Tuple -}} + {{- if $i}}, {{end -}} + {{- if and (eq ($a.TypeName $v.Type) "error") (eq (inc $i) (len $a.Tuple)) -}} + {{$errName | $a.SpecifiedPlaceholder (len $a.Tuple)}} + {{- else -}} + {{$a.VarName $v.Type $v.Name | $a.Placeholder}} + {{- end -}} +{{- end}} := {{.X}} +if {{$errName | $a.SpecifiedPlaceholder (len $a.Tuple)}} != nil { + return {{range $i, $v := .FuncResults}} + {{- if $i}}, {{end -}} + {{- if eq ($a.TypeName $v.Type) "error" -}} + {{$errName | $a.SpecifiedPlaceholder (len $a.Tuple) | + $a.SpecifiedPlaceholder (inc (len $a.Tuple))}} + {{- else -}} + {{$a.Zero $v.Type}} + {{- end -}} + {{end}} +} +{{end}}`, +}, { + label: "variferr", + details: "assign variables and check error", + body: `{{if and .StmtOK (eq (.TypeName .Type) "error") -}} +{{- $a := . -}} +{{- $errName := .VarName nil "err" -}} +{{$errName | $a.SpecifiedPlaceholder 1}} := {{.X}} +if {{$errName | $a.SpecifiedPlaceholder 1}} != nil { + return {{range $i, $v := .FuncResults}} + {{- if $i}}, {{end -}} + {{- if eq ($a.TypeName $v.Type) "error" -}} + {{$errName | $a.SpecifiedPlaceholder 1 | $a.SpecifiedPlaceholder 2}} + {{- else -}} + {{$a.Zero $v.Type}} + {{- end -}} + {{end}} +} +{{end}}`, }} // Cursor indicates where the client's cursor should end up after the // snippet is done. func (a *postfixTmplArgs) Cursor() string { - a.snip.WriteFinalTabstop() - return "" + return "$0" } -// Placeholder indicate a tab stops with the placeholder string, the order +// Placeholder indicate a tab stop with the placeholder string, the order // of tab stops is the same as the order of invocation -func (a *postfixTmplArgs) Placeholder(s string) string { - if a.placeholders { - a.snip.WritePlaceholder(func(b *snippet.Builder) { - b.WriteText(s) - }) - } else { - a.snip.WritePlaceholder(nil) +func (a *postfixTmplArgs) Placeholder(placeholder string) string { + if !a.placeholders { + placeholder = "" + } + return fmt.Sprintf("${%d:%s}", a.nextTabStop(), placeholder) +} + +// nextTabStop returns the next tab stop index for a new placeholder. +func (a *postfixTmplArgs) nextTabStop() int { + // Tab stops start from 1, so increment before returning. + a.currentTabStop++ + return a.currentTabStop +} + +// SpecifiedPlaceholder indicate a specified tab stop with the placeholder string. +// Sometimes the same tab stop appears in multiple places and their numbers +// need to be specified. e.g. variferr +func (a *postfixTmplArgs) SpecifiedPlaceholder(tabStop int, placeholder string) string { + if !a.placeholders { + placeholder = "" } - return "" + return fmt.Sprintf("${%d:%s}", tabStop, placeholder) } // Import makes sure the package corresponding to path is imported, @@ -309,7 +407,7 @@ func (a *postfixTmplArgs) KeyType() types.Type { return a.Type.Underlying().(*types.Map).Key() } -// Tuple returns the tuple result vars if X is a call expression. +// Tuple returns the tuple result vars if the type of X is tuple. func (a *postfixTmplArgs) Tuple() []*types.Var { tuple, _ := a.Type.(*types.Tuple) if tuple == nil { @@ -323,6 +421,18 @@ func (a *postfixTmplArgs) Tuple() []*types.Var { return typs } +// TupleLast returns the last tuple result vars if the type of X is tuple. +func (a *postfixTmplArgs) TupleLast() *types.Var { + tuple, _ := a.Type.(*types.Tuple) + if tuple == nil { + return nil + } + if tuple.Len() == 0 { + return nil + } + return tuple.At(tuple.Len() - 1) +} + // TypeName returns the textual representation of type t. func (a *postfixTmplArgs) TypeName(t types.Type) (string, error) { if t == nil || t == types.Typ[types.Invalid] { @@ -331,6 +441,16 @@ func (a *postfixTmplArgs) TypeName(t types.Type) (string, error) { return types.TypeString(t, a.qf), nil } +// Zero return the zero value representation of type t +func (a *postfixTmplArgs) Zero(t types.Type) string { + return formatZeroValue(t, a.qf) +} + +func (a *postfixTmplArgs) IsIdent() bool { + _, ok := a.sel.X.(*ast.Ident) + return ok +} + // VarName returns a suitable variable name for the type t. If t // implements the error interface, "err" is used. If t is not a named // type then nonNamedDefault is used. Otherwise a name is made by @@ -417,6 +537,17 @@ func (c *completer) addPostfixSnippetCandidates(ctx context.Context, sel *ast.Se } } + var funcResults []*types.Var + if c.enclosingFunc != nil { + results := c.enclosingFunc.sig.Results() + if results != nil { + funcResults = make([]*types.Var, results.Len()) + for i := 0; i < results.Len(); i++ { + funcResults[i] = results.At(i) + } + } + } + scope := c.pkg.GetTypes().Scope().Innermost(c.pos) if scope == nil { return @@ -455,6 +586,8 @@ func (c *completer) addPostfixSnippetCandidates(ctx context.Context, sel *ast.Se StmtOK: stmtOK, Obj: exprObj(c.pkg.GetTypesInfo(), sel.X), Type: selType, + FuncResults: funcResults, + sel: sel, qf: c.qf, importIfNeeded: c.importIfNeeded, scope: scope, @@ -497,7 +630,9 @@ func initPostfixRules() { var idx int for _, rule := range postfixTmpls { var err error - rule.tmpl, err = template.New("postfix_snippet").Parse(rule.body) + rule.tmpl, err = template.New("postfix_snippet").Funcs(template.FuncMap{ + "inc": inc, + }).Parse(rule.body) if err != nil { log.Panicf("error parsing postfix snippet template: %v", err) } @@ -508,6 +643,10 @@ func initPostfixRules() { }) } +func inc(i int) int { + return i + 1 +} + // importIfNeeded returns the package identifier and any necessary // edits to import package pkgPath. func (c *completer) importIfNeeded(pkgPath string, scope *types.Scope) (string, []protocol.TextEdit, error) { diff --git a/gopls/internal/test/integration/completion/postfix_snippet_test.go b/gopls/internal/test/integration/completion/postfix_snippet_test.go index 0677280c5ec..31ea2e02b3e 100644 --- a/gopls/internal/test/integration/completion/postfix_snippet_test.go +++ b/gopls/internal/test/integration/completion/postfix_snippet_test.go @@ -306,6 +306,7 @@ func _() { ${1:}, ${2:} := foo() } `, + allowMultipleItem: true, }, { name: "var_single_value", @@ -318,6 +319,7 @@ func _() { foo().var } `, + allowMultipleItem: true, after: ` package foo diff --git a/gopls/internal/test/marker/testdata/completion/postfix.txt b/gopls/internal/test/marker/testdata/completion/postfix.txt index 63661ee9228..cab097465d7 100644 --- a/gopls/internal/test/marker/testdata/completion/postfix.txt +++ b/gopls/internal/test/marker/testdata/completion/postfix.txt @@ -13,6 +13,10 @@ go 1.18 -- postfix.go -- package snippets +import ( + "strconv" +) + func _() { var foo []int foo.append //@rank(" //", postfixAppend) @@ -96,3 +100,32 @@ func _() { foo.fo //@snippet(" //", postfixForChannel, "for ${1:} := range foo {\n\t$0\n}") foo.rang //@snippet(" //", postfixRangeChannel, "for ${1:} := range foo {\n\t$0\n}") } + +type T struct { + Name string +} + +func _() (string, T, map[string]string, error) { + /* iferr! */ //@item(postfixIfErr, "iferr!", "check error and return", "snippet") + /* variferr! */ //@item(postfixVarIfErr, "variferr!", "assign variables and check error", "snippet") + /* var! */ //@item(postfixVars, "var!", "assign to variables", "snippet") + + strconv.Atoi("32"). //@complete(" //", postfixIfErr, postfixPrint, postfixVars, postfixVarIfErr) + + var err error + err.iferr //@snippet(" //", postfixIfErr, "if err != nil {\n\treturn \"\", T{}, nil, ${1:}\n}\n") + + strconv.Atoi("32").iferr //@snippet(" //", postfixIfErr, "if _, err := strconv.Atoi(\"32\"); err != nil {\n\treturn \"\", T{}, nil, ${1:}\n}\n") + + strconv.Atoi("32").variferr //@snippet(" //", postfixVarIfErr, "${1:}, ${2:} := strconv.Atoi(\"32\")\nif ${2:} != nil {\n\treturn \"\", T{}, nil, ${3:}\n}\n") + + // test function return multiple errors + var foo func() (error, error) + foo().iferr //@snippet(" //", postfixIfErr, "if _, err := foo(); err != nil {\n\treturn \"\", T{}, nil, ${1:}\n}\n") + foo().variferr //@snippet(" //", postfixVarIfErr, "${1:}, ${2:} := foo()\nif ${2:} != nil {\n\treturn \"\", T{}, nil, ${3:}\n}\n") + + // test function just return error + var bar func() error + bar().iferr //@snippet(" //", postfixIfErr, "if err := bar(); err != nil {\n\treturn \"\", T{}, nil, ${1:}\n}\n") + bar().variferr //@snippet(" //", postfixVarIfErr, "${1:} := bar()\nif ${1:} != nil {\n\treturn \"\", T{}, nil, ${2:}\n}\n") +} diff --git a/gopls/internal/test/marker/testdata/completion/postfix_placeholder.txt b/gopls/internal/test/marker/testdata/completion/postfix_placeholder.txt index 44dfbc96df1..7569f130466 100644 --- a/gopls/internal/test/marker/testdata/completion/postfix_placeholder.txt +++ b/gopls/internal/test/marker/testdata/completion/postfix_placeholder.txt @@ -16,6 +16,10 @@ go 1.18 -- postfix.go -- package snippets +import ( + "strconv" +) + func _() { /* for! */ //@item(postfixFor, "for!", "range over slice by index", "snippet") /* forr! */ //@item(postfixForr, "forr!", "range over slice by index and value", "snippet") @@ -51,3 +55,29 @@ func _() { foo.fo //@snippet(" //", postfixForChannel, "for ${1:e} := range foo {\n\t$0\n}") foo.rang //@snippet(" //", postfixRangeChannel, "for ${1:e} := range foo {\n\t$0\n}") } + +type T struct { + Name string +} + +func _() (string, T, map[string]string, error) { + /* iferr! */ //@item(postfixIfErr, "iferr!", "check error and return", "snippet") + /* variferr! */ //@item(postfixVarIfErr, "variferr!", "assign variables and check error", "snippet") + /* var! */ //@item(postfixVars, "var!", "assign to variables", "snippet") + + + var err error + err.iferr //@snippet(" //", postfixIfErr, "if err != nil {\n\treturn \"\", T{}, nil, ${1:err}\n}\n") + strconv.Atoi("32").iferr //@snippet(" //", postfixIfErr, "if _, err := strconv.Atoi(\"32\"); err != nil {\n\treturn \"\", T{}, nil, ${1:err}\n}\n") + strconv.Atoi("32").variferr //@snippet(" //", postfixVarIfErr, "${1:i}, ${2:err} := strconv.Atoi(\"32\")\nif ${2:err} != nil {\n\treturn \"\", T{}, nil, ${3:${2:err}}\n}\n") + + // test function return multiple errors + var foo func() (error, error) + foo().iferr //@snippet(" //", postfixIfErr, "if _, err := foo(); err != nil {\n\treturn \"\", T{}, nil, ${1:err}\n}\n") + foo().variferr //@snippet(" //", postfixVarIfErr, "${1:err2}, ${2:err} := foo()\nif ${2:err} != nil {\n\treturn \"\", T{}, nil, ${3:${2:err}}\n}\n") + + // test function just return error + var bar func() error + bar().iferr //@snippet(" //", postfixIfErr, "if err := bar(); err != nil {\n\treturn \"\", T{}, nil, ${1:err}\n}\n") + bar().variferr //@snippet(" //", postfixVarIfErr, "${1:err2} := bar()\nif ${1:err2} != nil {\n\treturn \"\", T{}, nil, ${2:${1:err2}}\n}\n") +}