diff --git a/esi/choose.go b/esi/choose.go index 572fdb3..6b1005f 100644 --- a/esi/choose.go +++ b/esi/choose.go @@ -44,8 +44,7 @@ func (c *chooseTag) Process(b []byte, req *http.Request) ([]byte, int) { for _, v := range tagIdxs { if validateTest(v[1], req) { res = Parse(v[2], req) - - break + return res, c.length } } diff --git a/esi/include.go b/esi/include.go index 2893e39..309598d 100644 --- a/esi/include.go +++ b/esi/include.go @@ -15,6 +15,18 @@ var ( altAttribute = regexp.MustCompile(`alt="?(.+?)"?( |/>)`) ) +// safe to pass to any origin +var headersSafe = []string{ + "Accept", + "Accept-Language", +} + +// safe to pass only to same-origin (same scheme, same host, same port) +var headersUnsafe = []string{ + "Cookie", + "Authorization", +} + type includeTag struct { *baseTag src string @@ -42,6 +54,15 @@ func sanitizeURL(u string, reqUrl *url.URL) string { return reqUrl.ResolveReference(parsed).String() } +func addHeaders(headers []string, req *http.Request, rq *http.Request) { + for _, h := range headers { + v := req.Header.Get(h) + if v != "" { + rq.Header.Add(h, v) + } + } +} + // Input (e.g. include src="https://domain.com/esi-include" alt="https://domain.com/alt-esi-include" />) // With or without the alt // With or without a space separator before the closing @@ -59,11 +80,20 @@ func (i *includeTag) Process(b []byte, req *http.Request) ([]byte, int) { } rq, _ := http.NewRequest(http.MethodGet, sanitizeURL(i.src, req.URL), nil) + addHeaders(headersSafe, req, rq) + if rq.URL.Scheme == req.URL.Scheme && rq.URL.Host == req.URL.Host { + addHeaders(headersUnsafe, req, rq) + } + client := &http.Client{} response, err := client.Do(rq) - if err != nil || response.StatusCode >= 400 { + if (err != nil || response.StatusCode >= 400) && i.alt != "" { rq, _ = http.NewRequest(http.MethodGet, sanitizeURL(i.alt, req.URL), nil) + addHeaders(headersSafe, req, rq) + if rq.URL.Scheme == req.URL.Scheme && rq.URL.Host == req.URL.Host { + addHeaders(headersUnsafe, req, rq) + } response, err = client.Do(rq) if err != nil || response.StatusCode >= 400 { @@ -71,6 +101,10 @@ func (i *includeTag) Process(b []byte, req *http.Request) ([]byte, int) { } } + if response == nil { + return nil, i.length + } + defer response.Body.Close() x, _ := io.ReadAll(response.Body) b = Parse(x, req) diff --git a/fixtures/full.html b/fixtures/full.html index 1773ee8..7a040c6 100644 --- a/fixtures/full.html +++ b/fixtures/full.html @@ -11,5 +11,18 @@ --> + + + + + + + + +
+ +
+ + diff --git a/middleware/caddy/esi.go b/middleware/caddy/esi.go index f0831a8..c643d8f 100644 --- a/middleware/caddy/esi.go +++ b/middleware/caddy/esi.go @@ -46,18 +46,14 @@ func (e *ESI) ServeHTTP(rw http.ResponseWriter, r *http.Request, next caddyhttp. defer bufPool.Put(buf) cw := writer.NewWriter(buf, rw, r) go func(w *writer.Writer) { - w.Header().Del("Content-Length") - if w.Rq.ProtoMajor == 1 { - w.Header().Set("Content-Encoding", "chunked") - } var i = 0 for { - if len(cw.AsyncBuf) <= i { + if len(w.AsyncBuf) <= i { continue } - rs := <-cw.AsyncBuf[i] + rs := <-w.AsyncBuf[i] if rs == nil { - cw.Done <- true + w.Done <- true break } _, _ = rw.Write(rs) @@ -65,6 +61,10 @@ func (e *ESI) ServeHTTP(rw http.ResponseWriter, r *http.Request, next caddyhttp. } }(cw) next.ServeHTTP(cw, r) + cw.Header().Del("Content-Length") + if cw.Rq.ProtoMajor == 1 { + cw.Header().Set("Content-Encoding", "chunked") + } cw.AsyncBuf = append(cw.AsyncBuf, make(chan []byte)) go func(w *writer.Writer, iteration int) { w.AsyncBuf[iteration] <- nil diff --git a/middleware/caddy/esi_test.go b/middleware/caddy/esi_test.go index 95c9589..8a0ffb5 100644 --- a/middleware/caddy/esi_test.go +++ b/middleware/caddy/esi_test.go @@ -19,6 +19,11 @@ const expectedOutput = `

CHAINED 2

ALTERNATE ESI INCLUDE

+ +
+ +
+ `