From 958ed361941dca1fcb353fae6d8ef12e0d794eff Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Sun, 29 Dec 2019 11:34:55 +0100 Subject: [PATCH] Allow no response to be send when a connection is hijacked (#712) * Allow no response to be send when a connection is hijacked At the moment there is always a HTTP response before the connection gets hijacked. This second option to Hijack() prevents this response from being send. Fixes: https://github.com/valyala/fasthttp/issues/698 * Add HijackSetNoResponse method instead --- server.go | 66 +++++++++++++++++++++++++++++++------------------- server_test.go | 45 ++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 25 deletions(-) diff --git a/server.go b/server.go index 1eb5f50535..5703578b15 100644 --- a/server.go +++ b/server.go @@ -508,7 +508,8 @@ type RequestCtx struct { timeoutCh chan struct{} timeoutTimer *time.Timer - hijackHandler HijackHandler + hijackHandler HijackHandler + hijackNoResponse bool } // HijackHandler must process the hijacked connection c. @@ -535,6 +536,7 @@ type HijackHandler func(c net.Conn) // * Unexpected error during response writing to the connection. // // The server stops processing requests from hijacked connections. +// // Server limits such as Concurrency, ReadTimeout, WriteTimeout, etc. // aren't applied to hijacked connections. // @@ -550,6 +552,15 @@ func (ctx *RequestCtx) Hijack(handler HijackHandler) { ctx.hijackHandler = handler } +// HijackSetNoResponse changes the behavior of hijacking a request. +// If HijackSetNoResponse is called with false fasthttp will send a response +// to the client before calling the HijackHandler (default). If HijackSetNoResponse +// is called with true no response is send back before calling the +// HijackHandler supplied in the Hijack function. +func (ctx *RequestCtx) HijackSetNoResponse(noResponse bool) { + ctx.hijackNoResponse = noResponse +} + // Hijacked returns true after Hijack is called. func (ctx *RequestCtx) Hijacked() bool { return ctx.hijackHandler != nil @@ -1869,9 +1880,10 @@ func (s *Server) serveConn(c net.Conn) error { br *bufio.Reader bw *bufio.Writer - err error - timeoutResponse *Response - hijackHandler HijackHandler + err error + timeoutResponse *Response + hijackHandler HijackHandler + hijackNoResponse bool connectionClose bool isHTTP11 bool @@ -2044,6 +2056,8 @@ func (s *Server) serveConn(c net.Conn) error { hijackHandler = ctx.hijackHandler ctx.hijackHandler = nil + hijackNoResponse = ctx.hijackNoResponse + ctx.hijackNoResponse = false ctx.userValues.Reset() @@ -2071,30 +2085,32 @@ func (s *Server) serveConn(c net.Conn) error { ctx.Response.Header.SetServerBytes(serverName) } - if bw == nil { - bw = acquireWriter(ctx) - } - if err = writeResponse(ctx, bw); err != nil { - break - } + if !hijackNoResponse { + if bw == nil { + bw = acquireWriter(ctx) + } + if err = writeResponse(ctx, bw); err != nil { + break + } - // Only flush the writer if we don't have another request in the pipeline. - // This is a big of an ugly optimization for https://www.techempower.com/benchmarks/ - // This benchmark will send 16 pipelined requests. It is faster to pack as many responses - // in a TCP packet and send it back at once than waiting for a flush every request. - // In real world circumstances this behaviour could be argued as being wrong. - if br == nil || br.Buffered() == 0 || connectionClose { - err = bw.Flush() - if err != nil { + // Only flush the writer if we don't have another request in the pipeline. + // This is a big of an ugly optimization for https://www.techempower.com/benchmarks/ + // This benchmark will send 16 pipelined requests. It is faster to pack as many responses + // in a TCP packet and send it back at once than waiting for a flush every request. + // In real world circumstances this behaviour could be argued as being wrong. + if br == nil || br.Buffered() == 0 || connectionClose { + err = bw.Flush() + if err != nil { + break + } + } + if connectionClose { break } - } - if connectionClose { - break - } - if s.ReduceMemoryUsage { - releaseWriter(s, bw) - bw = nil + if s.ReduceMemoryUsage && hijackHandler == nil { + releaseWriter(s, bw) + bw = nil + } } if hijackHandler != nil { diff --git a/server_test.go b/server_test.go index 0a4337a361..9821b948c3 100644 --- a/server_test.go +++ b/server_test.go @@ -2098,6 +2098,51 @@ func TestRequestCtxHijack(t *testing.T) { } } +func TestRequestCtxHijackNoResponse(t *testing.T) { + t.Parallel() + + hijackDone := make(chan error) + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.Hijack(func(c net.Conn) { + _, err := c.Write([]byte("test")) + hijackDone <- err + }) + ctx.HijackSetNoResponse(true) + }, + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 0\r\n\r\n") + + ch := make(chan error) + go func() { + ch <- s.ServeConn(rw) + }() + + select { + case err := <-ch: + if err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout") + } + + select { + case err := <-hijackDone: + if err != nil { + t.Fatalf("Unexpected error from hijack: %s", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout") + } + + if got := rw.w.String(); got != "test" { + t.Errorf(`expected "test", got %q`, got) + } +} + func TestRequestCtxInit(t *testing.T) { var ctx RequestCtx var logger testLogger