diff --git a/esi/choose.go b/esi/choose.go index 572fdb3..6b1005f 100644 --- a/esi/choose.go +++ b/esi/choose.go @@ -44,8 +44,7 @@ func (c *chooseTag) Process(b []byte, req *http.Request) ([]byte, int) { for _, v := range tagIdxs { if validateTest(v[1], req) { res = Parse(v[2], req) - - break + return res, c.length } } diff --git a/esi/include.go b/esi/include.go index 2893e39..309598d 100644 --- a/esi/include.go +++ b/esi/include.go @@ -15,6 +15,18 @@ var ( altAttribute = regexp.MustCompile(`alt="?(.+?)"?( |/>)`) ) +// safe to pass to any origin +var headersSafe = []string{ + "Accept", + "Accept-Language", +} + +// safe to pass only to same-origin (same scheme, same host, same port) +var headersUnsafe = []string{ + "Cookie", + "Authorization", +} + type includeTag struct { *baseTag src string @@ -42,6 +54,15 @@ func sanitizeURL(u string, reqUrl *url.URL) string { return reqUrl.ResolveReference(parsed).String() } +func addHeaders(headers []string, req *http.Request, rq *http.Request) { + for _, h := range headers { + v := req.Header.Get(h) + if v != "" { + rq.Header.Add(h, v) + } + } +} + // Input (e.g. include src="https://domain.com/esi-include" alt="https://domain.com/alt-esi-include" />) // With or without the alt // With or without a space separator before the closing @@ -59,11 +80,20 @@ func (i *includeTag) Process(b []byte, req *http.Request) ([]byte, int) { } rq, _ := http.NewRequest(http.MethodGet, sanitizeURL(i.src, req.URL), nil) + addHeaders(headersSafe, req, rq) + if rq.URL.Scheme == req.URL.Scheme && rq.URL.Host == req.URL.Host { + addHeaders(headersUnsafe, req, rq) + } + client := &http.Client{} response, err := client.Do(rq) - if err != nil || response.StatusCode >= 400 { + if (err != nil || response.StatusCode >= 400) && i.alt != "" { rq, _ = http.NewRequest(http.MethodGet, sanitizeURL(i.alt, req.URL), nil) + addHeaders(headersSafe, req, rq) + if rq.URL.Scheme == req.URL.Scheme && rq.URL.Host == req.URL.Host { + addHeaders(headersUnsafe, req, rq) + } response, err = client.Do(rq) if err != nil || response.StatusCode >= 400 { @@ -71,6 +101,10 @@ func (i *includeTag) Process(b []byte, req *http.Request) ([]byte, int) { } } + if response == nil { + return nil, i.length + } + defer response.Body.Close() x, _ := io.ReadAll(response.Body) b = Parse(x, req) diff --git a/fixtures/full.html b/fixtures/full.html index 1773ee8..7a040c6 100644 --- a/fixtures/full.html +++ b/fixtures/full.html @@ -11,5 +11,18 @@ --> + + + + + + + + +