Skip to content

Commit

Permalink
Fix xfh handling #59
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel Ludwig committed Nov 9, 2020
1 parent 771b30c commit 0a5801c
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 21 deletions.
4 changes: 4 additions & 0 deletions config/runtime/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import (

type Port string

func (p Port) String() string {
return string(p)
}

type Server map[Port]*ServerMux

// ServerMux represents the ServerMux struct.
Expand Down
34 changes: 30 additions & 4 deletions server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net"
"net/http"
"strings"
"time"

"github.com/rs/xid"
Expand All @@ -27,6 +28,7 @@ type HTTPServer struct {
log logrus.FieldLogger
mux *Mux
name string
port string
shutdownCh chan struct{}
srv *http.Server
uidFn func() string
Expand Down Expand Up @@ -81,12 +83,13 @@ func New(cmdCtx context.Context, log logrus.FieldLogger, conf *runtime.HTTPConfi
log: log,
mux: mux,
name: name,
port: p.String(),
shutdownCh: shutdownCh,
uidFn: uidFn,
}

srv := &http.Server{
Addr: ":" + string(p),
Addr: ":" + p.String(),
Handler: httpSrv,
IdleTimeout: conf.Timings.IdleTimeout,
ReadHeaderTimeout: conf.Timings.ReadHeaderTimeout,
Expand Down Expand Up @@ -166,9 +169,32 @@ func (s *HTTPServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
ctx = context.WithValue(ctx, request.ServerName, s.name)
*req = *req.WithContext(ctx)

if s.config.UseXFH {
req.Host = req.Header.Get("X-Forwarded-Host")
}
req.Host = s.getHost(req)

h := s.mux.FindHandler(req)
s.accessLog.ServeHTTP(NewHeaderWriter(rw), req, h)
}

// getHost configures the host from the incoming request host based on
// the xfh setting and listener port to be prepared for the http multiplexer.
func (s *HTTPServer) getHost(req *http.Request) string {
host := req.Host
if s.config.UseXFH {
host = req.Header.Get("X-Forwarded-Host")
}

if !strings.Contains(host, ":") {
return s.cleanHostAppendPort(host)
}

h, _, err := net.SplitHostPort(host)
if err != nil {
return s.cleanHostAppendPort(host)
}

return s.cleanHostAppendPort(h)
}

func (s *HTTPServer) cleanHostAppendPort(host string) string {
return strings.TrimSuffix(host, ".") + ":" + s.port
}
107 changes: 90 additions & 17 deletions server/http_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,30 @@ func newCouper(file string, helper *test.Helper) (func(), *logrustest.Hook) {
return cancelFn, hook
}

func newClient() *http.Client {
dialer := &net.Dialer{}
return &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
_, port, _ := net.SplitHostPort(addr)
if port != "" {
return dialer.DialContext(ctx, "tcp4", "127.0.0.1:"+port)
}
return dialer.DialContext(ctx, "tcp4", "127.0.0.1")
},
},
}
}

func cleanup(shutdown func(), t *testing.T) {
shutdown()

err := os.Chdir(testWorkingDir)
if err != nil {
t.Fatal(err)
}
}

func TestHTTPServer_ServeHTTP(t *testing.T) {
type testRequest struct {
method, url string
Expand All @@ -111,18 +135,7 @@ func TestHTTPServer_ServeHTTP(t *testing.T) {
requests []requestCase
}

dialer := &net.Dialer{}
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
_, port, _ := net.SplitHostPort(addr)
if port != "" {
return dialer.DialContext(ctx, "tcp4", "127.0.0.1:"+port)
}
return dialer.DialContext(ctx, "tcp4", "127.0.0.1")
},
},
}
client := newClient()

for i, testcase := range []testCase{
{"spa/01_couper.hcl", []requestCase{
Expand Down Expand Up @@ -277,10 +290,70 @@ func TestHTTPServer_ServeHTTP(t *testing.T) {
}
})
}
shutdown()
err := os.Chdir(testWorkingDir)
if err != nil {
t.Fatal(err)
}

cleanup(shutdown, t)
}
}

func TestHTTPServer_HostHeader(t *testing.T) {
client := newClient()

confPath := path.Join("testdata/integration", "files/02_couper.hcl")
shutdown, logHook := newCouper(confPath, test.New(t))

t.Run("Test", func(subT *testing.T) {
helper := test.New(subT)
logHook.Reset()

req, err := http.NewRequest(http.MethodGet, "http://example.com:9898/b", nil)
helper.Must(err)

req.Host = "example.com."
res, err := client.Do(req)
helper.Must(err)

resBytes, err := ioutil.ReadAll(res.Body)
helper.Must(err)

_ = res.Body.Close()

if `<html lang="en">index B</html>` != string(resBytes) {
t.Errorf("%s", resBytes)
}
})

cleanup(shutdown, t)
}

func TestHTTPServer_XFHHeader(t *testing.T) {
client := newClient()

os.Setenv("COUPER_XFH", "true")
confPath := path.Join("testdata/integration", "files/02_couper.hcl")
shutdown, logHook := newCouper(confPath, test.New(t))
os.Setenv("COUPER_XFH", "")

t.Run("Test", func(subT *testing.T) {
helper := test.New(subT)
logHook.Reset()

req, err := http.NewRequest(http.MethodGet, "http://example.com:9898/b", nil)
helper.Must(err)

req.Host = "example.com"
req.Header.Set("X-Forwarded-Host", "example.com.")
res, err := client.Do(req)
helper.Must(err)

resBytes, err := ioutil.ReadAll(res.Body)
helper.Must(err)

_ = res.Body.Close()

if `<html lang="en">index B</html>` != string(resBytes) {
t.Errorf("%s", resBytes)
}
})

cleanup(shutdown, t)
}

0 comments on commit 0a5801c

Please sign in to comment.