diff --git a/cmd/templ/lspcmd/lsp_test.go b/cmd/templ/lspcmd/lsp_test.go index 6e9f914f1..2331e49df 100644 --- a/cmd/templ/lspcmd/lsp_test.go +++ b/cmd/templ/lspcmd/lsp_test.go @@ -319,6 +319,142 @@ func TestHover(t *testing.T) { } } +func TestReferences(t *testing.T) { + if testing.Short() { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + log, _ := zap.NewProduction() + + ctx, appDir, _, server, teardown, err := Setup(ctx, log) + if err != nil { + t.Fatalf("failed to setup test: %v", err) + return + } + defer teardown(t) + defer cancel() + + templFile, err := os.ReadFile(appDir + "/templates.templ") + if err != nil { + t.Fatalf("failed to read file %q: %v", appDir+"/templates.templ", err) + return + + } + err = server.DidOpen(ctx, &protocol.DidOpenTextDocumentParams{ + TextDocument: protocol.TextDocumentItem{ + URI: uri.URI("file://" + appDir + "/templates.templ"), + LanguageID: "templ", + Version: 1, + Text: string(templFile), + }, + }) + if err != nil { + t.Errorf("failed to register open file: %v", err) + return + } + log.Info("Calling References") + + tests := []struct { + line int + character int + assert func(t *testing.T, l []protocol.Location) (msg string, ok bool) + }{ + { + // this is the definition of the templ function in the templates.templ file. + line: 5, + character: 9, + assert: func(t *testing.T, actual []protocol.Location) (msg string, ok bool) { + expectedReference := []protocol.Location{ + { + // This is the useage of the templ function in the main.go file. + URI: uri.URI("file://" + appDir + "/main.go"), + Range: protocol.Range{ + Start: protocol.Position{ + Line: uint32(24), + Character: uint32(7), + }, + End: protocol.Position{ + Line: uint32(24), + Character: uint32(11), + }, + }, + }, + } + if diff := lspdiff.References(expectedReference, actual); diff != "" { + return fmt.Sprintf("Expected: %+v\nActual: %+v", expectedReference, actual), false + } + return "", true + }, + }, + { + // this is the definition of the struct in the templates.templ file. + line: 21, + character: 9, + assert: func(t *testing.T, actual []protocol.Location) (msg string, ok bool) { + expectedReference := []protocol.Location{ + { + // This is the useage of the struct in the templates.templ file. + URI: uri.URI("file://" + appDir + "/templates.templ"), + Range: protocol.Range{ + Start: protocol.Position{ + Line: uint32(24), + Character: uint32(8), + }, + End: protocol.Position{ + Line: uint32(24), + Character: uint32(14), + }, + }, + }, + } + if diff := lspdiff.References(expectedReference, actual); diff != "" { + return fmt.Sprintf("Expected: %+v\nActual: %+v", expectedReference, actual), false + } + return "", true + }, + }, + } + + for i, test := range tests { + t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { + // Give CI/CD pipeline executors some time because they're often quite slow. + var ok bool + var msg string + for i := 0; i < 3; i++ { + if err != nil { + t.Error(err) + return + } + actual, err := server.References(ctx, &protocol.ReferenceParams{ + TextDocumentPositionParams: protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: uri.URI("file://" + appDir + "/templates.templ"), + }, + // Positions are zero indexed. + Position: protocol.Position{ + Line: uint32(test.line - 1), + Character: uint32(test.character - 1), + }, + }, + }) + if err != nil { + t.Errorf("failed to get references: %v", err) + return + } + msg, ok = test.assert(t, actual) + if !ok { + break + } + time.Sleep(time.Millisecond * 500) + } + if !ok { + t.Error(msg) + } + }) + } +} + func TestCodeAction(t *testing.T) { if testing.Short() { return diff --git a/cmd/templ/lspcmd/lspdiff/lspdiff.go b/cmd/templ/lspcmd/lspdiff/lspdiff.go index af67f8635..653ee9671 100644 --- a/cmd/templ/lspcmd/lspdiff/lspdiff.go +++ b/cmd/templ/lspcmd/lspdiff/lspdiff.go @@ -25,6 +25,10 @@ func CompletionList(expected, actual *protocol.CompletionList) string { ) } +func References(expected, actual []protocol.Location) string { + return cmp.Diff(expected, actual) +} + func CompletionListContainsText(cl *protocol.CompletionList, text string) bool { if cl == nil { return false diff --git a/cmd/templ/lspcmd/proxy/server.go b/cmd/templ/lspcmd/proxy/server.go index 8ac3597b2..fe6689454 100644 --- a/cmd/templ/lspcmd/proxy/server.go +++ b/cmd/templ/lspcmd/proxy/server.go @@ -939,7 +939,6 @@ func (p *Server) RangeFormatting(ctx context.Context, params *lsp.DocumentRangeF func (p *Server) References(ctx context.Context, params *lsp.ReferenceParams) (result []lsp.Location, err error) { p.Log.Info("client -> server: References") defer p.Log.Info("client -> server: References end") - templURI := params.TextDocument.URI // Rewrite the request. var ok bool ok, params.TextDocument.URI, params.Position = p.updatePosition(params.TextDocument.URI, params.Position) @@ -954,8 +953,12 @@ func (p *Server) References(ctx context.Context, params *lsp.ReferenceParams) (r // Rewrite the response. for i := 0; i < len(result); i++ { r := result[i] - r.URI = templURI - r.Range = p.convertGoRangeToTemplRange(templURI, r.Range) + isTemplURI, templURI := convertTemplGoToTemplURI(r.URI) + if isTemplURI { + p.Log.Info(fmt.Sprintf("references-%d - range conversion for %s", i, r.URI)) + r.URI, r.Range = templURI, p.convertGoRangeToTemplRange(templURI, r.Range) + } + p.Log.Info(fmt.Sprintf("references-%d: %+v", i, r)) result[i] = r } return result, err