diff --git a/args.go b/args.go index 86dc7b9bf7..47a97c6ac2 100644 --- a/args.go +++ b/args.go @@ -623,3 +623,13 @@ func decodeArgAppendNoPlus(dst, src []byte) []byte { } return dst } + +func peekAllArgBytesToDst(dst [][]byte, h []argsKV, k []byte) [][]byte { + for i, n := 0, len(h); i < n; i++ { + kv := &h[i] + if bytes.Equal(kv.key, k) { + dst = append(dst, kv.value) + } + } + return dst +} diff --git a/bytesconv_table_gen.go b/bytesconv_table_gen.go index e92136ec72..abf69d4ab3 100644 --- a/bytesconv_table_gen.go +++ b/bytesconv_table_gen.go @@ -7,6 +7,7 @@ import ( "bytes" "fmt" "log" + "os" ) const ( diff --git a/fasthttpadaptor/request_test.go b/fasthttpadaptor/request_test.go index 3b6ba54ca7..1f214c2d06 100644 --- a/fasthttpadaptor/request_test.go +++ b/fasthttpadaptor/request_test.go @@ -1,9 +1,10 @@ package fasthttpadaptor import ( - "github.com/valyala/fasthttp" "net/http" "testing" + + "github.com/valyala/fasthttp" ) func BenchmarkConvertRequest(b *testing.B) { diff --git a/header.go b/header.go index 4c3c0eafa8..83f0d404e9 100644 --- a/header.go +++ b/header.go @@ -42,6 +42,7 @@ type ResponseHeader struct { contentType []byte contentEncoding []byte server []byte + mulHeader [][]byte h []argsKV trailer []argsKV @@ -79,6 +80,7 @@ type RequestHeader struct { host []byte contentType []byte userAgent []byte + mulHeader [][]byte h []argsKV trailer []argsKV @@ -974,6 +976,7 @@ func (h *ResponseHeader) resetSkipNormalize() { h.h = h.h[:0] h.cookies = h.cookies[:0] h.trailer = h.trailer[:0] + h.mulHeader = h.mulHeader[:0] } // SetNoDefaultContentType allows you to control if a default Content-Type header will be set (false) or not (true). @@ -1002,6 +1005,7 @@ func (h *RequestHeader) resetSkipNormalize() { h.contentType = h.contentType[:0] h.userAgent = h.userAgent[:0] h.trailer = h.trailer[:0] + h.mulHeader = h.mulHeader[:0] h.h = h.h[:0] h.cookies = h.cookies[:0] @@ -1793,6 +1797,85 @@ func (h *RequestHeader) peek(key []byte) []byte { } } +// PeekAll returns all header value for the given key. +// +// The returned value is valid until the request is released, +// either though ReleaseRequest or your request handler returning. +// Do not store references to returned value. Make copies instead. +func (h *RequestHeader) PeekAll(key string) [][]byte { + k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) + return h.peekAll(k) +} + +func (h *RequestHeader) peekAll(key []byte) [][]byte { + h.mulHeader = h.mulHeader[:0] + switch string(key) { + case HeaderHost: + h.mulHeader = append(h.mulHeader, h.Host()) + case HeaderContentType: + h.mulHeader = append(h.mulHeader, h.ContentType()) + case HeaderUserAgent: + h.mulHeader = append(h.mulHeader, h.UserAgent()) + case HeaderConnection: + if h.ConnectionClose() { + h.mulHeader = append(h.mulHeader, strClose) + } else { + h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) + } + case HeaderContentLength: + h.mulHeader = append(h.mulHeader, h.contentLengthBytes) + case HeaderCookie: + if h.cookiesCollected { + h.mulHeader = append(h.mulHeader, appendRequestCookieBytes(nil, h.cookies)) + return [][]byte{appendRequestCookieBytes(nil, h.cookies)} + } else { + h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) + } + case HeaderTrailer: + h.mulHeader = append(h.mulHeader, appendArgsKeyBytes(nil, h.trailer, strCommaSpace)) + default: + h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) + } + return h.mulHeader +} + +// PeekAll returns all header value for the given key. +// +// The returned value is valid until the request is released, +// either though ReleaseRequest or your request handler returning. +// Do not store references to returned value. Make copies instead. +func (h *ResponseHeader) PeekAll(key string) [][]byte { + k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) + return h.peekAll(k) +} + +func (h *ResponseHeader) peekAll(key []byte) [][]byte { + h.mulHeader = h.mulHeader[:0] + switch string(key) { + case HeaderContentType: + h.mulHeader = append(h.mulHeader, h.ContentType()) + case HeaderContentEncoding: + h.mulHeader = append(h.mulHeader, h.ContentEncoding()) + case HeaderServer: + h.mulHeader = append(h.mulHeader, h.Server()) + case HeaderConnection: + if h.ConnectionClose() { + h.mulHeader = append(h.mulHeader, strClose) + } else { + h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) + } + case HeaderContentLength: + h.mulHeader = append(h.mulHeader, h.contentLengthBytes) + case HeaderSetCookie: + h.mulHeader = append(h.mulHeader, appendResponseCookieBytes(nil, h.cookies)) + case HeaderTrailer: + h.mulHeader = append(h.mulHeader, appendArgsKeyBytes(nil, h.trailer, strCommaSpace)) + default: + h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) + } + return h.mulHeader +} + // Cookie returns cookie for the given key. func (h *RequestHeader) Cookie(key string) []byte { h.collectCookies() diff --git a/header_test.go b/header_test.go index 3ea7706007..643e8f3956 100644 --- a/header_test.go +++ b/header_test.go @@ -2861,3 +2861,65 @@ func verifyTrailer(t *testing.T, r *bufio.Reader, expectedTrailers map[string]st } verifyResponseTrailer(t, &resp.Header, expectedTrailers) } + +func TestRequestHeader_PeekAll(t *testing.T) { + t.Parallel() + h := &RequestHeader{} + h.Add(HeaderConnection, "keep-alive") + h.Add("Content-Type", "aaa") + h.Add(HeaderHost, "aaabbb") + h.Add("User-Agent", "asdfas") + h.Add("Content-Length", "1123") + h.Add("Cookie", "foobar=baz") + h.Add(HeaderTrailer, "foo, bar") + h.Add("aaa", "aaa") + h.Add("aaa", "bbb") + + expectRequestHeaderAll(t, h, HeaderConnection, [][]byte{s2b("keep-alive")}) + expectRequestHeaderAll(t, h, "Content-Type", [][]byte{s2b("aaa")}) + expectRequestHeaderAll(t, h, HeaderHost, [][]byte{s2b("aaabbb")}) + expectRequestHeaderAll(t, h, "User-Agent", [][]byte{s2b("asdfas")}) + expectRequestHeaderAll(t, h, "Content-Length", [][]byte{s2b("1123")}) + expectRequestHeaderAll(t, h, "Cookie", [][]byte{s2b("foobar=baz")}) + expectRequestHeaderAll(t, h, HeaderTrailer, [][]byte{s2b("Foo, Bar")}) + expectRequestHeaderAll(t, h, "aaa", [][]byte{s2b("aaa"), s2b("bbb")}) +} +func expectRequestHeaderAll(t *testing.T, h *RequestHeader, key string, expectedValue [][]byte) { + if len(h.PeekAll(key)) != len(expectedValue) { + t.Fatalf("Unexpected size for key %q: %d. Expected %d", key, len(h.PeekAll(key)), len(expectedValue)) + } + if !reflect.DeepEqual(h.PeekAll(key), expectedValue) { + t.Fatalf("Unexpected value for key %q: %q. Expected %q", key, h.PeekAll(key), expectedValue) + } +} + +func TestResponseHeader_PeekAll(t *testing.T) { + t.Parallel() + + h := &ResponseHeader{} + h.Add(HeaderContentType, "aaa/bbb") + h.Add(HeaderContentEncoding, "gzip") + h.Add(HeaderConnection, "close") + h.Add(HeaderContentLength, "1234") + h.Add(HeaderServer, "aaaa") + h.Add(HeaderSetCookie, "cccc") + h.Add("aaa", "aaa") + h.Add("aaa", "bbb") + + expectResponseHeaderAll(t, h, HeaderContentType, [][]byte{s2b("aaa/bbb")}) + expectResponseHeaderAll(t, h, HeaderContentEncoding, [][]byte{s2b("gzip")}) + expectResponseHeaderAll(t, h, HeaderConnection, [][]byte{s2b("close")}) + expectResponseHeaderAll(t, h, HeaderContentLength, [][]byte{s2b("1234")}) + expectResponseHeaderAll(t, h, HeaderServer, [][]byte{s2b("aaaa")}) + expectResponseHeaderAll(t, h, HeaderSetCookie, [][]byte{s2b("cccc")}) + expectResponseHeaderAll(t, h, "aaa", [][]byte{s2b("aaa"), s2b("bbb")}) +} + +func expectResponseHeaderAll(t *testing.T, h *ResponseHeader, key string, expectedValue [][]byte) { + if len(h.PeekAll(key)) != len(expectedValue) { + t.Fatalf("Unexpected size for key %q: %d. Expected %d", key, len(h.PeekAll(key)), len(expectedValue)) + } + if !reflect.DeepEqual(h.PeekAll(key), expectedValue) { + t.Fatalf("Unexpected value for key %q: %q. Expected %q", key, h.PeekAll(key), expectedValue) + } +}