Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for user:pass in URLs #614

Merged
merged 1 commit into from
Aug 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion allocation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestAllocationClient(t *testing.T) {
go s.Serve(ln)

c := &Client{}
url := "http://" + ln.Addr().String()
url := "http://test:test@" + ln.Addr().String() + "/foo?bar=baz"

n := testing.AllocsPerRun(100, func() {
req := AcquireRequest()
Expand All @@ -68,3 +68,17 @@ func TestAllocationClient(t *testing.T) {
t.Fatalf("expected 0 allocations, got %f", n)
}
}

func TestAllocationURI(t *testing.T) {
uri := []byte("http://username:password@example.com/some/path?foo=bar#test")

n := testing.AllocsPerRun(100, func() {
u := AcquireURI()
u.Parse(nil, uri)
ReleaseURI(u)
})

if n != 0 {
t.Fatalf("expected 0 allocations, got %f", n)
}
}
38 changes: 38 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,44 @@ import (
"github.com/valyala/fasthttp/fasthttputil"
)

func TestClientURLAuth(t *testing.T) {
cases := map[string]string{
"user:pass@": "dXNlcjpwYXNz",
"foo:@": "Zm9vOg==",
":@": "",
"@": "",
"": "",
}

ch := make(chan string, 1)
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ch <- string(ctx.Request.Header.Peek(HeaderAuthorization))
},
}
go s.Serve(ln)
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
for up, expected := range cases {
req := AcquireRequest()
req.Header.SetMethod(MethodGet)
req.SetRequestURI("http://" + up + "example.com")
if err := c.Do(req, nil); err != nil {
t.Fatal(err)
}

val := <-ch

if val != expected {
t.Fatalf("wrong %s header: %s expected %s", HeaderAuthorization, val, expected)
}
}
}

func TestClientNilResp(t *testing.T) {
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Expand Down
19 changes: 19 additions & 0 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package fasthttp
import (
"bufio"
"bytes"
"encoding/base64"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -1148,6 +1149,24 @@ func (req *Request) Write(w *bufio.Writer) error {
}
req.Header.SetHostBytes(host)
req.Header.SetRequestURIBytes(uri.RequestURI())

if len(uri.username) > 0 {
// RequestHeader.SetBytesKV only uses RequestHeader.bufKV.key
// So we are free to use RequestHeader.bufKV.value as a scratch pad for
// the base64 encoding.
nl := len(uri.username) + len(uri.password) + 1
tl := nl + base64.StdEncoding.EncodedLen(nl)
if tl > cap(req.Header.bufKV.value) {
req.Header.bufKV.value = make([]byte, 0, tl)
}
buf := req.Header.bufKV.value[:0]
buf = append(buf, uri.username...)
buf = append(buf, strColon...)
buf = append(buf, uri.password...)
buf = buf[:tl]
base64.StdEncoding.Encode(buf[nl:], buf[:nl])
req.Header.SetBytesKV(strAuthorization, buf[nl:])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this is all a bit dirty to use the scratch buffer of the req.Header. But the alternative would be to have another sync.Pool for []byte just for this.

Copy link
Contributor

@dgrr dgrr Jul 20, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would change these lines:

buf = buf[:tl]
base64.StdEncoding.Encode(buf[nl:], buf[:nl])

into:

buf = base64.StdEncoding.AppendEncode(buf, buf[:nl])

but it wouldn't help avoiding with the allocation.

}
}

if req.bodyStream != nil {
Expand Down
3 changes: 3 additions & 0 deletions strings.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ var (
strHTTP = []byte("http")
strHTTPS = []byte("https")
strHTTP11 = []byte("HTTP/1.1")
strColon = []byte(":")
strColonSlashSlash = []byte("://")
strColonSpace = []byte(": ")
strGMT = []byte("GMT")
strAt = []byte("@")

strResponseContinue = []byte("HTTP/1.1 100 Continue\r\n\r\n")

Expand Down Expand Up @@ -52,6 +54,7 @@ var (
strAcceptRanges = []byte(HeaderAcceptRanges)
strRange = []byte(HeaderRange)
strContentRange = []byte(HeaderContentRange)
strAuthorization = []byte(HeaderAuthorization)

strCookieExpires = []byte("expires")
strCookieDomain = []byte("domain")
Expand Down
51 changes: 51 additions & 0 deletions uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ type URI struct {
fullURI []byte
requestURI []byte

username []byte
password []byte

h *RequestHeader
}

Expand All @@ -63,6 +66,8 @@ func (u *URI) CopyTo(dst *URI) {
dst.queryString = append(dst.queryString[:0], u.queryString...)
dst.hash = append(dst.hash[:0], u.hash...)
dst.host = append(dst.host[:0], u.host...)
dst.username = append(dst.username[:0], u.username...)
dst.password = append(dst.password[:0], u.password...)

u.queryArgs.CopyTo(&dst.queryArgs)
dst.parsedQueryArgs = u.parsedQueryArgs
Expand All @@ -89,6 +94,36 @@ func (u *URI) SetHashBytes(hash []byte) {
u.hash = append(u.hash[:0], hash...)
}

// Username returns URI username
func (u *URI) Username() []byte {
return u.username
}

// SetUsername sets URI username.
func (u *URI) SetUsername(username string) {
u.username = append(u.username[:0], username...)
}

// SetUsernameBytes sets URI username.
func (u *URI) SetUsernameBytes(username []byte) {
u.username = append(u.username[:0], username...)
}

// Password returns URI password
func (u *URI) Password() []byte {
return u.password
}

// SetPassword sets URI password.
func (u *URI) SetPassword(password string) {
u.password = append(u.password[:0], password...)
}

// SetPasswordBytes sets URI password.
func (u *URI) SetPasswordBytes(password []byte) {
u.password = append(u.password[:0], password...)
}

// QueryString returns URI query string,
// i.e. baz=123 of http://aaa.com/foo/bar?baz=123#qwe .
//
Expand Down Expand Up @@ -174,6 +209,8 @@ func (u *URI) Reset() {
u.path = u.path[:0]
u.queryString = u.queryString[:0]
u.hash = u.hash[:0]
u.username = u.username[:0]
u.password = u.password[:0]

u.host = u.host[:0]
u.queryArgs.Reset()
Expand Down Expand Up @@ -236,6 +273,20 @@ func (u *URI) parse(host, uri []byte, h *RequestHeader) {
scheme, host, uri := splitHostURI(host, uri)
u.scheme = append(u.scheme, scheme...)
lowercaseBytes(u.scheme)

if n := bytes.Index(host, strAt); n >= 0 {
auth := host[:n]
host = host[n+1:]

if n := bytes.Index(auth, strColon); n >= 0 {
u.username = auth[:n]
u.password = auth[n+1:]
} else {
u.username = auth
u.password = auth[:0] // Make sure it's not nil
}
}

u.host = append(u.host, host...)
lowercaseBytes(u.host)

Expand Down