diff --git a/cmd/templ/lspcmd/lsp_test.go b/cmd/templ/lspcmd/lsp_test.go index 1587346ed..2f3ef995d 100644 --- a/cmd/templ/lspcmd/lsp_test.go +++ b/cmd/templ/lspcmd/lsp_test.go @@ -2,6 +2,7 @@ package lspcmd import ( "context" + "encoding/json" "fmt" "io" "os" @@ -14,6 +15,7 @@ import ( "github.com/a-h/templ/cmd/templ/generatecmd/modcheck" "github.com/a-h/templ/cmd/templ/lspcmd/lspdiff" "github.com/a-h/templ/cmd/templ/testproject" + "github.com/google/go-cmp/cmp" "go.lsp.dev/jsonrpc2" "go.lsp.dev/uri" "go.uber.org/zap" @@ -603,6 +605,139 @@ func TestCodeAction(t *testing.T) { } } +func TestDocumentSymbol(t *testing.T) { + if testing.Short() { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + log, _ := zap.NewProduction() + + ctx, appDir, _, server, teardown, err := Setup(ctx, log) + if err != nil { + t.Fatalf("failed to setup test: %v", err) + } + defer teardown(t) + defer cancel() + + tests := []struct { + uri string + expect []any + }{ + { + uri: "file://" + appDir + "/templates.templ", + expect: []any{ + protocol.SymbolInformation{ + Name: "Page", + Kind: protocol.SymbolKindFunction, + Location: protocol.Location{ + Range: protocol.Range{ + Start: protocol.Position{Line: 11, Character: 0}, + End: protocol.Position{Line: 50, Character: 1}, + }, + }, + }, + protocol.SymbolInformation{ + Name: "nihao", + Kind: protocol.SymbolKindVariable, + Location: protocol.Location{ + Range: protocol.Range{ + Start: protocol.Position{Line: 18, Character: 4}, + End: protocol.Position{Line: 18, Character: 16}, + }, + }, + }, + protocol.SymbolInformation{ + Name: "Struct", + Kind: protocol.SymbolKindStruct, + Location: protocol.Location{ + Range: protocol.Range{ + Start: protocol.Position{Line: 20, Character: 5}, + End: protocol.Position{Line: 22, Character: 1}, + }, + }, + }, + protocol.SymbolInformation{ + Name: "s", + Kind: protocol.SymbolKindVariable, + Location: protocol.Location{ + Range: protocol.Range{ + Start: protocol.Position{Line: 24, Character: 4}, + End: protocol.Position{Line: 24, Character: 16}, + }, + }, + }, + }, + }, + { + uri: "file://" + appDir + "/remoteparent.templ", + expect: []any{ + protocol.SymbolInformation{ + Name: "RemoteInclusionTest", + Kind: protocol.SymbolKindFunction, + Location: protocol.Location{ + Range: protocol.Range{ + Start: protocol.Position{Line: 9, Character: 0}, + End: protocol.Position{Line: 35, Character: 1}, + }, + }, + }, + protocol.SymbolInformation{ + Name: "Remote2", + Kind: protocol.SymbolKindFunction, + Location: protocol.Location{ + Range: protocol.Range{ + Start: protocol.Position{Line: 37, Character: 0}, + End: protocol.Position{Line: 63, Character: 1}, + }, + }, + }, + }, + }, + } + + for i, test := range tests { + t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { + actual, err := server.DocumentSymbol(ctx, &protocol.DocumentSymbolParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: uri.URI(test.uri), + }, + }) + if err != nil { + t.Errorf("failed to get document symbol: %v", err) + } + + // set expected URI + for i := range test.expect { + switch v := test.expect[i].(type) { + case protocol.SymbolInformation: + v.Location.URI = uri.URI(test.uri) + test.expect[i] = v + } + } + + expectdSlice, err := sliceToAnySlice(test.expect) + if err != nil { + t.Errorf("failed to convert expect to any slice: %v", err) + } + diff := cmp.Diff(expectdSlice, actual) + if diff != "" { + t.Errorf("unexpected document symbol: %v", diff) + } + }) + } +} + +func sliceToAnySlice(in []any) ([]any, error) { + b, err := json.Marshal(in) + if err != nil { + return nil, err + } + out := make([]any, 0, len(in)) + err = json.Unmarshal(b, &out) + return out, err +} + func runeIndexToUTF8ByteIndex(s string, runeIndex int) (lspChar uint32, err error) { for i, r := range []rune(s) { if i == runeIndex { diff --git a/cmd/templ/lspcmd/proxy/server.go b/cmd/templ/lspcmd/proxy/server.go index 02fc83571..6599e17ce 100644 --- a/cmd/templ/lspcmd/proxy/server.go +++ b/cmd/templ/lspcmd/proxy/server.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "encoding/json" "fmt" "os" "path/filepath" @@ -825,9 +826,46 @@ func (p *Server) DocumentLinkResolve(ctx context.Context, params *lsp.DocumentLi func (p *Server) DocumentSymbol(ctx context.Context, params *lsp.DocumentSymbolParams) (result []interface{} /* []SymbolInformation | []DocumentSymbol */, err error) { p.Log.Info("client -> server: DocumentSymbol") defer p.Log.Info("client -> server: DocumentSymbol end") - // TODO: Rewrite the request and response, but for now, ignore it. - // return p.Target.DocumentSymbol(ctx params) - return + isTemplFile, goURI := convertTemplToGoURI(params.TextDocument.URI) + if !isTemplFile { + return p.Target.DocumentSymbol(ctx, params) + } + templURI := params.TextDocument.URI + params.TextDocument.URI = goURI + symbols, err := p.Target.DocumentSymbol(ctx, params) + if err != nil { + return nil, err + } + + // recursively convert the ranges of the symbols and their children + var convertRange func(s *lsp.DocumentSymbol) + convertRange = func(s *lsp.DocumentSymbol) { + s.Range = p.convertGoRangeToTemplRange(templURI, s.Range) + s.SelectionRange = p.convertGoRangeToTemplRange(templURI, s.SelectionRange) + for i := 0; i < len(s.Children); i++ { + convertRange(&s.Children[i]) + } + } + + for _, s := range symbols { + if m, ok := s.(map[string]interface{}); ok { + s, err = mapToSymbol(m) + if err != nil { + return nil, err + } + } + switch s := s.(type) { + case lsp.DocumentSymbol: + convertRange(&s) + result = append(result, s) + case lsp.SymbolInformation: + s.Location.URI = templURI + s.Location.Range = p.convertGoRangeToTemplRange(templURI, s.Location.Range) + result = append(result, s) + } + } + + return result, err } func (p *Server) ExecuteCommand(ctx context.Context, params *lsp.ExecuteCommandParams) (result interface{}, err error) { @@ -1216,3 +1254,24 @@ func (p *Server) Request(ctx context.Context, method string, params interface{}) defer p.Log.Info("client -> server: Request end") return p.Target.Request(ctx, method, params) } + +func mapToSymbol(m map[string]interface{}) (interface{}, error) { + b, err := json.Marshal(m) + if err != nil { + return nil, err + } + + if _, ok := m["selectionRange"]; ok { + var s lsp.DocumentSymbol + if err := json.Unmarshal(b, &s); err != nil { + return nil, err + } + return s, nil + } + + var s lsp.SymbolInformation + if err := json.Unmarshal(b, &s); err != nil { + return nil, err + } + return s, nil +}