-
Notifications
You must be signed in to change notification settings - Fork 669
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: s3/transfermanager (v2): round-robin DNS and multi-NIC (#2975)
- Loading branch information
1 parent
4af4827
commit 21af651
Showing
3 changed files
with
387 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
package transfermanager | ||
|
||
import ( | ||
"sync" | ||
"time" | ||
|
||
"github.com/aws/smithy-go/container/private/cache" | ||
"github.com/aws/smithy-go/container/private/cache/lru" | ||
) | ||
|
||
// dnsCache implements an LRU cache of DNS query results by host. | ||
// | ||
// Cache retrievals will automatically rotate between IP addresses for | ||
// multi-value query results. | ||
type dnsCache struct { | ||
mu sync.Mutex | ||
addrs cache.Cache | ||
} | ||
|
||
// newDNSCache returns an initialized dnsCache with given capacity. | ||
func newDNSCache(cap int) *dnsCache { | ||
return &dnsCache{ | ||
addrs: lru.New(cap), | ||
} | ||
} | ||
|
||
// GetAddr returns the next IP address for the given host if present in the | ||
// cache. | ||
func (c *dnsCache) GetAddr(host string) (string, bool) { | ||
c.mu.Lock() | ||
defer c.mu.Unlock() | ||
|
||
v, ok := c.addrs.Get(host) | ||
if !ok { | ||
return "", false | ||
} | ||
|
||
record := v.(*dnsCacheEntry) | ||
if timeNow().After(record.expires) { | ||
return "", false | ||
} | ||
|
||
addr := record.addrs[record.index] | ||
record.index = (record.index + 1) % len(record.addrs) | ||
return addr, true | ||
} | ||
|
||
// PutAddrs stores a DNS query result in the cache, overwriting any present | ||
// entry for the host if it exists. | ||
func (c *dnsCache) PutAddrs(host string, addrs []string, expires time.Time) { | ||
c.mu.Lock() | ||
defer c.mu.Unlock() | ||
|
||
c.addrs.Put(host, &dnsCacheEntry{addrs, expires, 0}) | ||
} | ||
|
||
type dnsCacheEntry struct { | ||
addrs []string | ||
expires time.Time | ||
index int | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
package transfermanager | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net" | ||
"net/http" | ||
"sync" | ||
"time" | ||
|
||
"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight" | ||
) | ||
|
||
var timeNow = time.Now | ||
|
||
// WithRoundRobinDNS configures an http.Transport to spread HTTP connections | ||
// across multiple IP addresses for a given host. | ||
// | ||
// This is recommended by the [S3 performance guide] in high-concurrency | ||
// application environments. | ||
// | ||
// WithRoundRobinDNS wraps the underlying DialContext hook on http.Transport. | ||
// Future modifications to this hook MUST preserve said wrapping in order for | ||
// round-robin DNS to operate. | ||
// | ||
// [S3 performance guide]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance-design-patterns.html | ||
func WithRoundRobinDNS(opts ...func(*RoundRobinDNSOptions)) func(*http.Transport) { | ||
options := &RoundRobinDNSOptions{ | ||
TTL: 30 * time.Second, | ||
MaxHosts: 100, | ||
} | ||
for _, opt := range opts { | ||
opt(options) | ||
} | ||
|
||
return func(t *http.Transport) { | ||
rr := &rrDNS{ | ||
cache: newDNSCache(options.MaxHosts), | ||
expiry: options.TTL, | ||
resolver: &net.Resolver{}, | ||
dialContext: t.DialContext, | ||
} | ||
t.DialContext = rr.DialContext | ||
} | ||
} | ||
|
||
// RoundRobinDNSOptions configures use of round-robin DNS. | ||
type RoundRobinDNSOptions struct { | ||
// The length of time for which the results of a DNS query are valid. | ||
TTL time.Duration | ||
|
||
// A limit to the number of DNS query results, cached by hostname, which are | ||
// stored. Round-robin DNS uses an LRU cache. | ||
MaxHosts int | ||
} | ||
|
||
type resolver interface { | ||
LookupHost(context.Context, string) ([]string, error) | ||
} | ||
|
||
type rrDNS struct { | ||
sf singleflight.Group | ||
cache *dnsCache | ||
|
||
expiry time.Duration | ||
resolver resolver | ||
|
||
dialContext func(ctx context.Context, network, addr string) (net.Conn, error) | ||
} | ||
|
||
// DialContext implements the DialContext hook used by http.Transport, | ||
// pre-caching IP addresses for a given host and distributing them evenly | ||
// across new connections. | ||
func (r *rrDNS) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { | ||
host, port, err := net.SplitHostPort(addr) | ||
if err != nil { | ||
return nil, fmt.Errorf("rrdns split host/port: %w", err) | ||
} | ||
|
||
ipaddr, err := r.getAddr(ctx, host) | ||
if err != nil { | ||
return nil, fmt.Errorf("rrdns lookup host: %w", err) | ||
} | ||
|
||
return r.dialContext(ctx, network, net.JoinHostPort(ipaddr, port)) | ||
} | ||
|
||
func (r *rrDNS) getAddr(ctx context.Context, host string) (string, error) { | ||
addr, ok := r.cache.GetAddr(host) | ||
if ok { | ||
return addr, nil | ||
} | ||
return r.lookupHost(ctx, host) | ||
} | ||
|
||
func (r *rrDNS) lookupHost(ctx context.Context, host string) (string, error) { | ||
ch := r.sf.DoChan(host, func() (interface{}, error) { | ||
addrs, err := r.resolver.LookupHost(ctx, host) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
expires := timeNow().Add(r.expiry) | ||
r.cache.PutAddrs(host, addrs, expires) | ||
return nil, nil | ||
}) | ||
|
||
select { | ||
case result := <-ch: | ||
if result.Err != nil { | ||
return "", result.Err | ||
} | ||
|
||
addr, _ := r.cache.GetAddr(host) | ||
return addr, nil | ||
case <-ctx.Done(): | ||
return "", ctx.Err() | ||
} | ||
} | ||
|
||
// WithRotoDialer configures an http.Transport to cycle through multiple local | ||
// network addresses when creating new HTTP connections. | ||
// | ||
// WithRotoDialer REPLACES the root DialContext hook on the underlying | ||
// Transport, thereby destroying any previously-applied wrappings around it. If | ||
// the caller needs to apply additional decorations to the DialContext hook, | ||
// they must do so after applying WithRotoDialer. | ||
func WithRotoDialer(addrs []net.Addr) func(*http.Transport) { | ||
return func(t *http.Transport) { | ||
var dialers []*net.Dialer | ||
for _, addr := range addrs { | ||
dialers = append(dialers, &net.Dialer{ | ||
LocalAddr: addr, | ||
}) | ||
} | ||
|
||
t.DialContext = (&rotoDialer{ | ||
dialers: dialers, | ||
}).DialContext | ||
} | ||
} | ||
|
||
type rotoDialer struct { | ||
mu sync.Mutex | ||
dialers []*net.Dialer | ||
index int | ||
} | ||
|
||
func (r *rotoDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { | ||
return r.next().DialContext(ctx, network, addr) | ||
} | ||
|
||
func (r *rotoDialer) next() *net.Dialer { | ||
r.mu.Lock() | ||
defer r.mu.Unlock() | ||
|
||
d := r.dialers[r.index] | ||
r.index = (r.index + 1) % len(r.dialers) | ||
return d | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
package transfermanager | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"net" | ||
"testing" | ||
"time" | ||
) | ||
|
||
// these tests also cover the cache impl (cycling+expiry+evict) | ||
|
||
type mockNow struct { | ||
now time.Time | ||
} | ||
|
||
func (m *mockNow) Now() time.Time { | ||
return m.now | ||
} | ||
|
||
func (m *mockNow) Add(d time.Duration) { | ||
m.now = m.now.Add(d) | ||
} | ||
|
||
func useMockNow(m *mockNow) func() { | ||
timeNow = m.Now | ||
return func() { | ||
timeNow = time.Now | ||
} | ||
} | ||
|
||
var errDialContextOK = errors.New("dial context ok") | ||
|
||
type mockResolver struct { | ||
addrs map[string][]string | ||
err error | ||
} | ||
|
||
func (m *mockResolver) LookupHost(ctx context.Context, host string) ([]string, error) { | ||
return m.addrs[host], m.err | ||
} | ||
|
||
type mockDialContext struct { | ||
calledWith string | ||
} | ||
|
||
func (m *mockDialContext) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { | ||
m.calledWith = addr | ||
return nil, errDialContextOK | ||
} | ||
|
||
func TestRoundRobinDNS_CycleIPs(t *testing.T) { | ||
restore := useMockNow(&mockNow{}) | ||
defer restore() | ||
|
||
addrs := []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"} | ||
r := &mockResolver{ | ||
addrs: map[string][]string{ | ||
"s3.us-east-1.amazonaws.com": addrs, | ||
}, | ||
} | ||
dc := &mockDialContext{} | ||
|
||
rr := &rrDNS{ | ||
cache: newDNSCache(1), | ||
resolver: r, | ||
dialContext: dc.DialContext, | ||
} | ||
|
||
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[0]) | ||
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[1]) | ||
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[2]) | ||
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[0]) | ||
} | ||
|
||
func TestRoundRobinDNS_MultiIP(t *testing.T) { | ||
restore := useMockNow(&mockNow{}) | ||
defer restore() | ||
|
||
r := &mockResolver{ | ||
addrs: map[string][]string{ | ||
"host1.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"}, | ||
"host2.com": []string{"1.0.0.1", "1.0.0.2", "1.0.0.3"}, | ||
}, | ||
} | ||
dc := &mockDialContext{} | ||
|
||
rr := &rrDNS{ | ||
cache: newDNSCache(2), | ||
resolver: r, | ||
dialContext: dc.DialContext, | ||
} | ||
|
||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0]) | ||
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][0]) | ||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][1]) | ||
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][1]) | ||
} | ||
|
||
func TestRoundRobinDNS_MaxHosts(t *testing.T) { | ||
restore := useMockNow(&mockNow{}) | ||
defer restore() | ||
|
||
r := &mockResolver{ | ||
addrs: map[string][]string{ | ||
"host1.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"}, | ||
"host2.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"}, | ||
}, | ||
} | ||
dc := &mockDialContext{} | ||
|
||
rr := &rrDNS{ | ||
cache: newDNSCache(1), | ||
resolver: r, | ||
dialContext: dc.DialContext, | ||
} | ||
|
||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0]) | ||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][1]) | ||
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][0]) // evicts host1 | ||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0]) // evicts host2 | ||
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][0]) | ||
} | ||
|
||
func TestRoundRobinDNS_Expires(t *testing.T) { | ||
now := &mockNow{time.Unix(0, 0)} | ||
restore := useMockNow(now) | ||
defer restore() | ||
|
||
r := &mockResolver{ | ||
addrs: map[string][]string{ | ||
"host1.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"}, | ||
}, | ||
} | ||
dc := &mockDialContext{} | ||
|
||
rr := &rrDNS{ | ||
cache: newDNSCache(2), | ||
expiry: 30, | ||
resolver: r, | ||
dialContext: dc.DialContext, | ||
} | ||
|
||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0]) | ||
now.Add(16) // hasn't expired | ||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][1]) | ||
now.Add(16) // expired, starts over | ||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0]) | ||
} | ||
|
||
func expectDialContext(t *testing.T, rr *rrDNS, dc *mockDialContext, host, expect string) { | ||
const port = "443" | ||
|
||
t.Helper() | ||
_, err := rr.DialContext(context.Background(), "", net.JoinHostPort(host, port)) | ||
if err != errDialContextOK { | ||
t.Errorf("expect sentinel err, got %v", err) | ||
} | ||
actual, _, err := net.SplitHostPort(dc.calledWith) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
if expect != actual { | ||
t.Errorf("expect addr %s, got %s", expect, actual) | ||
} | ||
} |