Skip to content

Commit

Permalink
feat: add brotli support to proxy, warn on unsupported encoding (#695)
Browse files Browse the repository at this point in the history
Co-authored-by: Joe Davidson <joe.davidson.21111@gmail.com>
  • Loading branch information
a-h and joerdav authored Apr 22, 2024
1 parent 41d5003 commit ef58c7a
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.668
0.2.668
2 changes: 1 addition & 1 deletion cmd/templ/generatecmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ func (cmd *Generate) StartProxy(ctx context.Context) (p *proxy.Handler, err erro
if cmd.Args.ProxyBind == "" {
cmd.Args.ProxyBind = "127.0.0.1"
}
p = proxy.New(cmd.Args.ProxyBind, cmd.Args.ProxyPort, target)
p = proxy.New(cmd.Log, cmd.Args.ProxyBind, cmd.Args.ProxyPort, target)
go func() {
cmd.Log.Info("Proxying", slog.String("from", p.URL), slog.String("to", p.Target.String()))
if err := http.ListenAndServe(fmt.Sprintf("%s:%d", cmd.Args.ProxyBind, cmd.Args.ProxyPort), p); err != nil {
Expand Down
116 changes: 73 additions & 43 deletions cmd/templ/generatecmd/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import (
"compress/gzip"
"fmt"
"io"
"log"
stdlog "log"
"log/slog"
"math"
"net/http"
"net/http/httputil"
Expand All @@ -16,6 +17,7 @@ import (
"time"

"github.com/a-h/templ/cmd/templ/generatecmd/sse"
"github.com/andybalholm/brotli"

_ "embed"
)
Expand All @@ -26,85 +28,113 @@ var script string
const scriptTag = `<script src="/_templ/reload/script.js"></script>`

type Handler struct {
log *slog.Logger
URL string
Target *url.URL
p *httputil.ReverseProxy
sse *sse.Handler
}

func updateGzipResponse(r *http.Response) error {
plainr, err := gzip.NewReader(r.Body)
func insertScriptTagIntoBody(body string) (updated string) {
return strings.Replace(body, "</body>", scriptTag+"</body>", -1)
}

type passthroughWriteCloser struct {
io.Writer
}

func (pwc passthroughWriteCloser) Close() error {
return nil
}

const unsupportedContentEncoding = "Unsupported content encoding, hot reload script not inserted."

func (h *Handler) modifyResponse(r *http.Response) error {
if r.Header.Get("templ-skip-modify") == "true" {
return nil
}
if contentType := r.Header.Get("Content-Type"); !strings.HasPrefix(contentType, "text/html") {
return nil
}

// Set up readers and writers.
newReader := func(in io.Reader) (out io.Reader, err error) {
return in, nil
}
newWriter := func(out io.Writer) io.WriteCloser {
return passthroughWriteCloser{out}
}
switch r.Header.Get("Content-Encoding") {
case "gzip":
newReader = func(in io.Reader) (out io.Reader, err error) {
return gzip.NewReader(in)
}
newWriter = func(out io.Writer) io.WriteCloser {
return gzip.NewWriter(out)
}
case "br":
newReader = func(in io.Reader) (out io.Reader, err error) {
return brotli.NewReader(in), nil
}
newWriter = func(out io.Writer) io.WriteCloser {
return brotli.NewWriter(out)
}
case "":
// No content encoding.
default:
h.log.Warn(unsupportedContentEncoding, slog.String("encoding", r.Header.Get("Content-Encoding")))
}

// Read the encoded body.
encr, err := newReader(r.Body)
if err != nil {
return err
}
defer plainr.Close()
body, err := io.ReadAll(plainr)
defer r.Body.Close()
body, err := io.ReadAll(encr)
if err != nil {
return err
}

// Update it.
updated := insertScriptTagIntoBody(string(body))

// Encode the response.
var buf bytes.Buffer
gzw := gzip.NewWriter(&buf)
defer gzw.Close()
_, err = gzw.Write([]byte(updated))
encw := newWriter(&buf)
_, err = encw.Write([]byte(updated))
if err != nil {
return err
}
err = gzw.Close()
err = encw.Close()
if err != nil {
return err
}

// Update the response.
r.Body = io.NopCloser(&buf)
r.ContentLength = int64(buf.Len())
r.Header.Set("Content-Length", strconv.Itoa(buf.Len()))
return nil
}

func updatePlainResponse(r *http.Response) error {
body, err := io.ReadAll(r.Body)
if err != nil {
return err
}
updated := insertScriptTagIntoBody(string(body))
r.Body = io.NopCloser(strings.NewReader(updated))
r.ContentLength = int64(len(updated))
r.Header.Set("Content-Length", strconv.Itoa(len(updated)))
return nil
}

func insertScriptTagIntoBody(body string) (updated string) {
return strings.Replace(body, "</body>", scriptTag+"</body>", -1)
}

func modifyResponse(r *http.Response) error {
if r.Header.Get("templ-skip-modify") == "true" {
return nil
}
if contentType := r.Header.Get("Content-Type"); !strings.HasPrefix(contentType, "text/html") {
return nil
}
modifier := updatePlainResponse
if r.Header.Get("Content-Encoding") == "gzip" {
modifier = updateGzipResponse
}
return modifier(r)
}

func New(bind string, port int, target *url.URL) *Handler {
func New(log *slog.Logger, bind string, port int, target *url.URL) (h *Handler) {
p := httputil.NewSingleHostReverseProxy(target)
p.ErrorLog = log.New(os.Stderr, "Proxy to target error: ", 0)
p.ErrorLog = stdlog.New(os.Stderr, "Proxy to target error: ", 0)
p.Transport = &roundTripper{
maxRetries: 10,
initialDelay: 100 * time.Millisecond,
backoffExponent: 1.5,
}
p.ModifyResponse = modifyResponse
return &Handler{
h = &Handler{
log: log,
URL: fmt.Sprintf("http://%s:%d", bind, port),
Target: target,
p: p,
sse: sse.New(),
}
p.ModifyResponse = h.modifyResponse
return h
}

func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand Down
138 changes: 132 additions & 6 deletions cmd/templ/generatecmd/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@ import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"sync"
"testing"
"time"

"github.com/andybalholm/brotli"
"github.com/google/go-cmp/cmp"
)

Expand Down Expand Up @@ -57,7 +60,9 @@ func TestProxy(t *testing.T) {
r.Header.Set("Content-Length", "16")

// Act
err := modifyResponse(r)
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -85,7 +90,9 @@ func TestProxy(t *testing.T) {
r.Header.Set("templ-skip-modify", "true")

// Act
err := modifyResponse(r)
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -117,7 +124,9 @@ func TestProxy(t *testing.T) {
}

// Act
err := modifyResponse(r)
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -147,7 +156,9 @@ func TestProxy(t *testing.T) {
r.Header.Set("Content-Length", "16")

// Act
err := modifyResponse(r)
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -195,7 +206,10 @@ func TestProxy(t *testing.T) {
r.Header.Set("Content-Length", fmt.Sprintf("%d", expectedLength))

// Act
if err = modifyResponse(r); err != nil {
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err = h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

Expand All @@ -216,7 +230,57 @@ func TestProxy(t *testing.T) {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("brotli: body tags get the script inserted", func(t *testing.T) {
// Arrange
body := `<html><body></body></html>`
var buf bytes.Buffer
brw := brotli.NewWriter(&buf)
_, err := brw.Write([]byte(body))
if err != nil {
t.Fatalf("unexpected error writing gzip: %v", err)
}
brw.Close()

expectedString := insertScriptTagIntoBody(body)

var expectedBytes bytes.Buffer
brw = brotli.NewWriter(&expectedBytes)
_, err = brw.Write([]byte(expectedString))
if err != nil {
t.Fatalf("unexpected error writing gzip: %v", err)
}
brw.Close()
expectedLength := len(expectedBytes.Bytes())

r := &http.Response{
Body: io.NopCloser(&buf),
Header: make(http.Header),
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Encoding", "br")
r.Header.Set("Content-Length", fmt.Sprintf("%d", expectedLength))

// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err = h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// Assert
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", expectedLength) {
t.Errorf("expected content length to be %d, got %v", expectedLength, r.Header.Get("Content-Length"))
}

actualBody, err := io.ReadAll(brotli.NewReader(r.Body))
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("notify-proxy: sending POST request to /_templ/reload/events should receive reload sse event", func(t *testing.T) {
// Arrange 1: create a test proxy server.
dummyHandler := func(w http.ResponseWriter, r *http.Request) {}
Expand All @@ -227,7 +291,8 @@ func TestProxy(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error parsing URL: %v", err)
}
handler := New("0.0.0.0", 0, u)
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
handler := New(log, "0.0.0.0", 0, u)
proxyServer := httptest.NewServer(handler)
defer proxyServer.Close()

Expand Down Expand Up @@ -305,4 +370,65 @@ func TestProxy(t *testing.T) {
t.Fatalf("timeout waiting for sse response")
}
})
t.Run("unsupported encodings result in a warning", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(bytes.NewReader([]byte("<p>Data</p>"))),
Header: make(http.Header),
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Encoding", "weird-encoding")

// Act
lh := newTestLogHandler()
log := slog.New(lh)
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// Assert
if len(lh.records) != 1 {
t.Fatalf("expected 1 log entry, but got %d", len(lh.records))
}
record := lh.records[0]
if record.Message != unsupportedContentEncoding {
t.Errorf("expected warning message %q, got %q", unsupportedContentEncoding, record.Message)
}
if record.Level != slog.LevelWarn {
t.Errorf("expected warning, got level %v", record.Level)
}
})
}

func newTestLogHandler() *testLogHandler {
return &testLogHandler{
m: new(sync.Mutex),
records: nil,
}
}

type testLogHandler struct {
m *sync.Mutex
records []slog.Record
}

func (h *testLogHandler) Enabled(context.Context, slog.Level) bool {
return true
}

func (h *testLogHandler) Handle(ctx context.Context, r slog.Record) error {
h.m.Lock()
defer h.m.Unlock()
h.records = append(h.records, r)
return nil
}

func (h *testLogHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return h
}

func (h *testLogHandler) WithGroup(name string) slog.Handler {
return h
}
Loading

0 comments on commit ef58c7a

Please sign in to comment.