diff --git a/allocation_test.go b/allocation_test.go index e57c3afd44..45f10e5731 100644 --- a/allocation_test.go +++ b/allocation_test.go @@ -49,7 +49,7 @@ func TestAllocationClient(t *testing.T) { go s.Serve(ln) c := &Client{} - url := "http://" + ln.Addr().String() + url := "http://test:test@" + ln.Addr().String() + "/foo?bar=baz" n := testing.AllocsPerRun(100, func() { req := AcquireRequest() @@ -68,3 +68,17 @@ func TestAllocationClient(t *testing.T) { t.Fatalf("expected 0 allocations, got %f", n) } } + +func TestAllocationURI(t *testing.T) { + uri := []byte("http://username:password@example.com/some/path?foo=bar#test") + + n := testing.AllocsPerRun(100, func() { + u := AcquireURI() + u.Parse(nil, uri) + ReleaseURI(u) + }) + + if n != 0 { + t.Fatalf("expected 0 allocations, got %f", n) + } +} diff --git a/client_test.go b/client_test.go index d602b7d093..8cd3679541 100644 --- a/client_test.go +++ b/client_test.go @@ -19,6 +19,44 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) +func TestClientURLAuth(t *testing.T) { + cases := map[string]string{ + "user:pass@": "dXNlcjpwYXNz", + "foo:@": "Zm9vOg==", + ":@": "", + "@": "", + "": "", + } + + ch := make(chan string, 1) + ln := fasthttputil.NewInmemoryListener() + s := &Server{ + Handler: func(ctx *RequestCtx) { + ch <- string(ctx.Request.Header.Peek(HeaderAuthorization)) + }, + } + go s.Serve(ln) + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + for up, expected := range cases { + req := AcquireRequest() + req.Header.SetMethod(MethodGet) + req.SetRequestURI("http://" + up + "example.com") + if err := c.Do(req, nil); err != nil { + t.Fatal(err) + } + + val := <-ch + + if val != expected { + t.Fatalf("wrong %s header: %s expected %s", HeaderAuthorization, val, expected) + } + } +} + func TestClientNilResp(t *testing.T) { ln := fasthttputil.NewInmemoryListener() s := &Server{ diff --git a/http.go b/http.go index 737673a2d1..99782a852c 100644 --- a/http.go +++ b/http.go @@ -3,6 +3,7 @@ package fasthttp import ( "bufio" "bytes" + "encoding/base64" "errors" "fmt" "io" @@ -1148,6 +1149,24 @@ func (req *Request) Write(w *bufio.Writer) error { } req.Header.SetHostBytes(host) req.Header.SetRequestURIBytes(uri.RequestURI()) + + if len(uri.username) > 0 { + // RequestHeader.SetBytesKV only uses RequestHeader.bufKV.key + // So we are free to use RequestHeader.bufKV.value as a scratch pad for + // the base64 encoding. + nl := len(uri.username) + len(uri.password) + 1 + tl := nl + base64.StdEncoding.EncodedLen(nl) + if tl > cap(req.Header.bufKV.value) { + req.Header.bufKV.value = make([]byte, 0, tl) + } + buf := req.Header.bufKV.value[:0] + buf = append(buf, uri.username...) + buf = append(buf, strColon...) + buf = append(buf, uri.password...) + buf = buf[:tl] + base64.StdEncoding.Encode(buf[nl:], buf[:nl]) + req.Header.SetBytesKV(strAuthorization, buf[nl:]) + } } if req.bodyStream != nil { diff --git a/strings.go b/strings.go index f654f958a0..343544ad18 100644 --- a/strings.go +++ b/strings.go @@ -16,9 +16,11 @@ var ( strHTTP = []byte("http") strHTTPS = []byte("https") strHTTP11 = []byte("HTTP/1.1") + strColon = []byte(":") strColonSlashSlash = []byte("://") strColonSpace = []byte(": ") strGMT = []byte("GMT") + strAt = []byte("@") strResponseContinue = []byte("HTTP/1.1 100 Continue\r\n\r\n") @@ -52,6 +54,7 @@ var ( strAcceptRanges = []byte(HeaderAcceptRanges) strRange = []byte(HeaderRange) strContentRange = []byte(HeaderContentRange) + strAuthorization = []byte(HeaderAuthorization) strCookieExpires = []byte("expires") strCookieDomain = []byte("domain") diff --git a/uri.go b/uri.go index d536f5934b..e9cd4b13a1 100644 --- a/uri.go +++ b/uri.go @@ -51,6 +51,9 @@ type URI struct { fullURI []byte requestURI []byte + username []byte + password []byte + h *RequestHeader } @@ -63,6 +66,8 @@ func (u *URI) CopyTo(dst *URI) { dst.queryString = append(dst.queryString[:0], u.queryString...) dst.hash = append(dst.hash[:0], u.hash...) dst.host = append(dst.host[:0], u.host...) + dst.username = append(dst.username[:0], u.username...) + dst.password = append(dst.password[:0], u.password...) u.queryArgs.CopyTo(&dst.queryArgs) dst.parsedQueryArgs = u.parsedQueryArgs @@ -89,6 +94,36 @@ func (u *URI) SetHashBytes(hash []byte) { u.hash = append(u.hash[:0], hash...) } +// Username returns URI username +func (u *URI) Username() []byte { + return u.username +} + +// SetUsername sets URI username. +func (u *URI) SetUsername(username string) { + u.username = append(u.username[:0], username...) +} + +// SetUsernameBytes sets URI username. +func (u *URI) SetUsernameBytes(username []byte) { + u.username = append(u.username[:0], username...) +} + +// Password returns URI password +func (u *URI) Password() []byte { + return u.password +} + +// SetPassword sets URI password. +func (u *URI) SetPassword(password string) { + u.password = append(u.password[:0], password...) +} + +// SetPasswordBytes sets URI password. +func (u *URI) SetPasswordBytes(password []byte) { + u.password = append(u.password[:0], password...) +} + // QueryString returns URI query string, // i.e. baz=123 of http://aaa.com/foo/bar?baz=123#qwe . // @@ -174,6 +209,8 @@ func (u *URI) Reset() { u.path = u.path[:0] u.queryString = u.queryString[:0] u.hash = u.hash[:0] + u.username = u.username[:0] + u.password = u.password[:0] u.host = u.host[:0] u.queryArgs.Reset() @@ -236,6 +273,20 @@ func (u *URI) parse(host, uri []byte, h *RequestHeader) { scheme, host, uri := splitHostURI(host, uri) u.scheme = append(u.scheme, scheme...) lowercaseBytes(u.scheme) + + if n := bytes.Index(host, strAt); n >= 0 { + auth := host[:n] + host = host[n+1:] + + if n := bytes.Index(auth, strColon); n >= 0 { + u.username = auth[:n] + u.password = auth[n+1:] + } else { + u.username = auth + u.password = auth[:0] // Make sure it's not nil + } + } + u.host = append(u.host, host...) lowercaseBytes(u.host)