diff --git a/cmd/templ/generatecmd/proxy/proxy.go b/cmd/templ/generatecmd/proxy/proxy.go index 9b942be18..a129888c9 100644 --- a/cmd/templ/generatecmd/proxy/proxy.go +++ b/cmd/templ/generatecmd/proxy/proxy.go @@ -16,6 +16,7 @@ import ( "strings" "time" + "github.com/PuerkitoBio/goquery" "github.com/a-h/templ/cmd/templ/generatecmd/sse" "github.com/andybalholm/brotli" @@ -36,7 +37,16 @@ type Handler struct { } func insertScriptTagIntoBody(body string) (updated string) { - return strings.Replace(body, "", scriptTag+"", -1) + doc, err := goquery.NewDocumentFromReader(strings.NewReader(body)) + if err != nil { + return strings.Replace(body, "", scriptTag+"", -1) + } + doc.Find("body").AppendHtml(scriptTag) + r, err := doc.Html() + if err != nil { + return strings.Replace(body, "", scriptTag+"", -1) + } + return r } type passthroughWriteCloser struct { diff --git a/cmd/templ/generatecmd/proxy/proxy_test.go b/cmd/templ/generatecmd/proxy/proxy_test.go index b33478fca..d2bc91517 100644 --- a/cmd/templ/generatecmd/proxy/proxy_test.go +++ b/cmd/templ/generatecmd/proxy/proxy_test.go @@ -161,6 +161,49 @@ func TestProxy(t *testing.T) { t.Errorf("unexpected response body (-got +want):\n%s", diff) } }) + t.Run("plain: body tags get the script inserted ignoring js with body tags", func(t *testing.T) { + // Arrange + r := &http.Response{ + Body: io.NopCloser(strings.NewReader(``)), + Header: make(http.Header), + Request: &http.Request{ + URL: &url.URL{ + Scheme: "http", + Host: "example.com", + }, + }, + } + r.Header.Set("Content-Type", "text/html, charset=utf-8") + r.Header.Set("Content-Length", "26") + + expectedString := insertScriptTagIntoBody(``) + if !strings.Contains(expectedString, scriptTag) { + t.Fatalf("expected the script tag to be inserted, but it wasn't: %q", expectedString) + } + if !strings.Contains(expectedString, `console.log("")`) { + t.Fatalf("expected the script tag to be inserted, but mangled the html: %q", expectedString) + } + + // Act + log := slog.New(slog.NewJSONHandler(io.Discard, nil)) + h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"}) + err := h.modifyResponse(r) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Assert + if r.Header.Get("Content-Length") != fmt.Sprintf("%d", len(expectedString)) { + t.Errorf("expected content length to be %d, got %v", len(expectedString), r.Header.Get("Content-Length")) + } + actualBody, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("unexpected error reading response: %v", err) + } + if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" { + t.Errorf("unexpected response body (-got +want):\n%s", diff) + } + }) t.Run("gzip: non-html content is not modified", func(t *testing.T) { // Arrange r := &http.Response{