Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Race http fallback ping #1521

Merged
merged 1 commit into from
Dec 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions pkg/v1/remote/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,7 @@ const (
var DefaultTransport http.RoundTripper = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
// By default we wrap the transport in retries, so reduce the
// default dial timeout to 5s to avoid 5x 30s of connection
// timeouts when doing the "ping" on certain http registries.
Timeout: 5 * time.Second,
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
Expand Down
172 changes: 110 additions & 62 deletions pkg/v1/remote/transport/ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ import (
"io"
"net/http"
"strings"
"time"

authchallenge "github.com/docker/distribution/registry/client/auth/challenge"
"github.com/google/go-containerregistry/pkg/logs"
"github.com/google/go-containerregistry/pkg/name"
)

Expand All @@ -34,6 +36,9 @@ const (
bearer challenge = "bearer"
)

// fallbackDelay is how long pingParallel waits on the primary scheme's ping
// before it also starts the ping for the fallback scheme. If the primary
// ping fails earlier, the fallback is started immediately instead.
//
// 300ms mirrors the default net.Dialer FallbackDelay that Go uses for its
// dual-stack ("happy eyeballs") dialing, but we could make this configurable.
// It is a var rather than a const; the package tests assign a shorter value.
var fallbackDelay = 300 * time.Millisecond

type pingResp struct {
challenge challenge

Expand All @@ -49,82 +54,125 @@ func (c challenge) Canonical() challenge {
return challenge(strings.ToLower(string(c)))
}

// parseChallenge parses the comma-separated parameter list of an
// authentication challenge (the text that follows the auth scheme) into a
// key/value map.
//
// Each token is either `key=value` or a bare `key`; a bare key maps to the
// empty string. Spaces around a token are trimmed and double quotes
// surrounding a value are removed. Commas that appear inside a double-quoted
// value do not split the token, so `scope="repository:foo/bar:pull,push"`
// yields a single "scope" entry. Empty tokens (empty input, doubled or
// trailing commas) are skipped rather than recorded under an empty key.
//
// The returned map is never nil.
func parseChallenge(suffix string) map[string]string {
	kv := make(map[string]string)
	for _, token := range splitChallengeTokens(suffix) {
		// Trim any whitespace around each token.
		token = strings.Trim(token, " ")
		if token == "" {
			// Nothing between two separators; there is no key to record.
			continue
		}

		// Break the token into a key/value pair.
		if parts := strings.SplitN(token, "=", 2); len(parts) == 2 {
			// Unquote the value, if it is quoted.
			kv[parts[0]] = strings.Trim(parts[1], `"`)
		} else {
			// If there was only one part, treat it as a key with an empty value.
			kv[token] = ""
		}
	}
	return kv
}

// splitChallengeTokens splits s on commas that are outside double quotes.
// A backslash inside quotes protects the following byte, so an escaped quote
// does not end the quoted section. If a quote is never closed, the remainder
// of s is returned as a single token.
func splitChallengeTokens(s string) []string {
	var tokens []string
	inQuotes := false
	start := 0
	for i := 0; i < len(s); i++ {
		switch s[i] {
		case '\\':
			if inQuotes {
				// Skip the escaped byte so `\"` does not toggle quoting.
				i++
			}
		case '"':
			inQuotes = !inQuotes
		case ',':
			if !inQuotes {
				tokens = append(tokens, s[start:i])
				start = i + 1
			}
		}
	}
	return append(tokens, s[start:])
}

func ping(ctx context.Context, reg name.Registry, t http.RoundTripper) (*pingResp, error) {
client := http.Client{Transport: t}

// This first attempts to use "https" for every request, falling back to http
// if the registry matches our localhost heuristic or if it is intentionally
// set to insecure via name.NewInsecureRegistry.
schemes := []string{"https"}
if reg.Scheme() == "http" {
schemes = append(schemes, "http")
}
if len(schemes) == 1 {
return pingSingle(ctx, reg, t, schemes[0])
}
return pingParallel(ctx, reg, t, schemes)
}

var errs []error
for _, scheme := range schemes {
url := fmt.Sprintf("%s://%s/v2/", scheme, reg.Name())
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := client.Do(req.WithContext(ctx))
if err != nil {
errs = append(errs, err)
// Potentially retry with http.
continue
}
defer func() {
// By draining the body, make sure to reuse the connection made by
// the ping for the following access to the registry
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}()

switch resp.StatusCode {
case http.StatusOK:
// If we get a 200, then no authentication is needed.
func pingSingle(ctx context.Context, reg name.Registry, t http.RoundTripper, scheme string) (*pingResp, error) {
client := http.Client{Transport: t}
url := fmt.Sprintf("%s://%s/v2/", scheme, reg.Name())
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := client.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
defer func() {
// By draining the body, make sure to reuse the connection made by
// the ping for the following access to the registry
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}()

switch resp.StatusCode {
case http.StatusOK:
// If we get a 200, then no authentication is needed.
return &pingResp{
challenge: anonymous,
scheme: scheme,
}, nil
case http.StatusUnauthorized:
if challenges := authchallenge.ResponseChallenges(resp); len(challenges) != 0 {
// If we hit more than one, let's try to find one that we know how to handle.
wac := pickFromMultipleChallenges(challenges)
return &pingResp{
challenge: anonymous,
scheme: scheme,
challenge: challenge(wac.Scheme).Canonical(),
parameters: wac.Parameters,
scheme: scheme,
}, nil
case http.StatusUnauthorized:
if challenges := authchallenge.ResponseChallenges(resp); len(challenges) != 0 {
// If we hit more than one, let's try to find one that we know how to handle.
wac := pickFromMultipleChallenges(challenges)
return &pingResp{
challenge: challenge(wac.Scheme).Canonical(),
parameters: wac.Parameters,
scheme: scheme,
}, nil
}
// Otherwise, just return the challenge without parameters.
return &pingResp{
challenge: challenge(resp.Header.Get("WWW-Authenticate")).Canonical(),
scheme: scheme,
}, nil
default:
return nil, CheckError(resp, http.StatusOK, http.StatusUnauthorized)
}
}

// Based on the golang happy eyeballs dialParallel impl in net/dial.go.
func pingParallel(ctx context.Context, reg name.Registry, t http.RoundTripper, schemes []string) (*pingResp, error) {
returned := make(chan struct{})
defer close(returned)

type pingResult struct {
*pingResp
error
primary bool
done bool
}

results := make(chan pingResult)

startRacer := func(ctx context.Context, scheme string) {
pr, err := pingSingle(ctx, reg, t, scheme)
select {
case results <- pingResult{pingResp: pr, error: err, primary: scheme == "https", done: true}:
case <-returned:
if pr != nil {
logs.Debug.Printf("%s lost race", scheme)
}
}
}

var primary, fallback pingResult

primaryCtx, primaryCancel := context.WithCancel(ctx)
defer primaryCancel()
go startRacer(primaryCtx, schemes[0])

fallbackTimer := time.NewTimer(fallbackDelay)
defer fallbackTimer.Stop()

for {
select {
case <-fallbackTimer.C:
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
defer fallbackCancel()
go startRacer(fallbackCtx, schemes[1])

case res := <-results:
if res.error == nil {
return res.pingResp, nil
}
if res.primary {
primary = res
} else {
fallback = res
}
if primary.done && fallback.done {
return nil, multierrs([]error{primary.error, fallback.error})
}
if res.primary && fallbackTimer.Stop() {
// Primary failed and we haven't started the fallback,
// reset time to start fallback immediately.
fallbackTimer.Reset(0)
}
// Otherwise, just return the challenge without parameters.
return &pingResp{
challenge: challenge(resp.Header.Get("WWW-Authenticate")).Canonical(),
scheme: scheme,
}, nil
default:
return nil, CheckError(resp, http.StatusOK, http.StatusUnauthorized)
}
}
return nil, multierrs(errs)
}

func pickFromMultipleChallenges(challenges []authchallenge.Challenge) authchallenge.Challenge {
Expand Down
59 changes: 12 additions & 47 deletions pkg/v1/remote/transport/ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,59 +20,17 @@ import (
"net/http/httptest"
"net/url"
"strings"
"sync/atomic"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-containerregistry/pkg/name"
)

var (
// testRegistry is a name.Registry for "localhost:8080", built with strict
// validation. The error returned by name.NewRegistry is deliberately
// discarded here.
testRegistry, _ = name.NewRegistry("localhost:8080", name.StrictValidation)
)

// TestChallengeParsing feeds challenge parameter strings to parseChallenge
// and compares the resulting key/value map against the expected one.
func TestChallengeParsing(t *testing.T) {
	type testCase struct {
		input  string
		output map[string]string
	}

	cases := []testCase{
		{
			// A single quoted value has its quotes removed.
			input:  `foo="bar"`,
			output: map[string]string{"foo": "bar"},
		},
		{
			// A bare key maps to the empty string.
			input:  `foo`,
			output: map[string]string{"foo": ""},
		},
		{
			// Multiple pairs without whitespace.
			input:  `foo="bar",baz="blah"`,
			output: map[string]string{"foo": "bar", "baz": "blah"},
		},
		{
			// Order is irrelevant and whitespace after the comma is trimmed.
			input:  `baz="blah", foo="bar"`,
			output: map[string]string{"foo": "bar", "baz": "blah"},
		},
		{
			// A realistic bearer challenge parameter list.
			input: `realm="https://gcr.io/v2/token", service="gcr.io", scope="repository:foo/bar:pull"`,
			output: map[string]string{
				"realm":   "https://gcr.io/v2/token",
				"service": "gcr.io",
				"scope":   "repository:foo/bar:pull",
			},
		},
	}

	for _, tc := range cases {
		got := parseChallenge(tc.input)
		diff := cmp.Diff(tc.output, got)
		if diff == "" {
			continue
		}
		t.Errorf("parseChallenge(%s); (-want +got) %s", tc.input, diff)
	}
}

func TestPingNoChallenge(t *testing.T) {
server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -218,7 +176,7 @@ func TestUnsupportedStatus(t *testing.T) {
func TestPingHttpFallback(t *testing.T) {
tests := []struct {
reg name.Registry
wantCount int
wantCount int64
err string
contains []string
}{{
Expand All @@ -234,10 +192,15 @@ func TestPingHttpFallback(t *testing.T) {
contains: []string{"https://us.gcr.io/v2/", "http://us.gcr.io/v2/"},
}}

gotCount := 0
gotCount := int64(0)
server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotCount++
atomic.AddInt64(&gotCount, 1)
if r.URL.Scheme != "http" {
// Sleep a little bit so we can exercise the
// happy eyeballs race.
time.Sleep(5 * time.Millisecond)
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
Expand All @@ -248,6 +211,8 @@ func TestPingHttpFallback(t *testing.T) {
},
}

fallbackDelay = 2 * time.Millisecond

for _, test := range tests {
// This is the last one, fatal error it.
if strings.Contains(test.reg.String(), "us.gcr.io") {
Expand Down