Skip to content

Commit

Permalink
feat: s3/transfermanager (v2): round-robin DNS and multi-NIC (#2975)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucix-aws authored and wty-Bryant committed Jan 30, 2025
1 parent 4af4827 commit 21af651
Show file tree
Hide file tree
Showing 3 changed files with 387 additions and 0 deletions.
61 changes: 61 additions & 0 deletions feature/s3/transfermanager/dns_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package transfermanager

import (
"sync"
"time"

"github.com/aws/smithy-go/container/private/cache"
"github.com/aws/smithy-go/container/private/cache/lru"
)

// dnsCache implements an LRU cache of DNS query results by host.
//
// Cache retrievals will automatically rotate between IP addresses for
// multi-value query results.
type dnsCache struct {
mu sync.Mutex
addrs cache.Cache
}

// newDNSCache returns an initialized dnsCache with given capacity.
func newDNSCache(cap int) *dnsCache {
return &dnsCache{
addrs: lru.New(cap),
}
}

// GetAddr returns the next IP address for the given host if present in the
// cache.
func (c *dnsCache) GetAddr(host string) (string, bool) {
c.mu.Lock()
defer c.mu.Unlock()

v, ok := c.addrs.Get(host)
if !ok {
return "", false
}

record := v.(*dnsCacheEntry)
if timeNow().After(record.expires) {
return "", false
}

addr := record.addrs[record.index]
record.index = (record.index + 1) % len(record.addrs)
return addr, true
}

// PutAddrs stores a DNS query result in the cache, overwriting any present
// entry for the host if it exists.
func (c *dnsCache) PutAddrs(host string, addrs []string, expires time.Time) {
c.mu.Lock()
defer c.mu.Unlock()

c.addrs.Put(host, &dnsCacheEntry{addrs, expires, 0})
}

type dnsCacheEntry struct {
addrs []string
expires time.Time
index int
}
160 changes: 160 additions & 0 deletions feature/s3/transfermanager/rrdns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package transfermanager

import (
"context"
"fmt"
"net"
"net/http"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
)

var timeNow = time.Now

// WithRoundRobinDNS configures an http.Transport to spread HTTP connections
// across multiple IP addresses for a given host.
//
// This is recommended by the [S3 performance guide] in high-concurrency
// application environments.
//
// WithRoundRobinDNS wraps the underlying DialContext hook on http.Transport.
// Future modifications to this hook MUST preserve said wrapping in order for
// round-robin DNS to operate.
//
// [S3 performance guide]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance-design-patterns.html
func WithRoundRobinDNS(opts ...func(*RoundRobinDNSOptions)) func(*http.Transport) {
options := &RoundRobinDNSOptions{
TTL: 30 * time.Second,
MaxHosts: 100,
}
for _, opt := range opts {
opt(options)
}

return func(t *http.Transport) {
rr := &rrDNS{
cache: newDNSCache(options.MaxHosts),
expiry: options.TTL,
resolver: &net.Resolver{},
dialContext: t.DialContext,
}
t.DialContext = rr.DialContext
}
}

// RoundRobinDNSOptions configures use of round-robin DNS.
type RoundRobinDNSOptions struct {
// The length of time for which the results of a DNS query are valid.
TTL time.Duration

// A limit to the number of DNS query results, cached by hostname, which are
// stored. Round-robin DNS uses an LRU cache.
MaxHosts int
}

type resolver interface {
LookupHost(context.Context, string) ([]string, error)
}

type rrDNS struct {
sf singleflight.Group
cache *dnsCache

expiry time.Duration
resolver resolver

dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
}

// DialContext implements the DialContext hook used by http.Transport,
// pre-caching IP addresses for a given host and distributing them evenly
// across new connections.
func (r *rrDNS) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("rrdns split host/port: %w", err)
}

ipaddr, err := r.getAddr(ctx, host)
if err != nil {
return nil, fmt.Errorf("rrdns lookup host: %w", err)
}

return r.dialContext(ctx, network, net.JoinHostPort(ipaddr, port))
}

func (r *rrDNS) getAddr(ctx context.Context, host string) (string, error) {
addr, ok := r.cache.GetAddr(host)
if ok {
return addr, nil
}
return r.lookupHost(ctx, host)
}

func (r *rrDNS) lookupHost(ctx context.Context, host string) (string, error) {
ch := r.sf.DoChan(host, func() (interface{}, error) {
addrs, err := r.resolver.LookupHost(ctx, host)
if err != nil {
return nil, err
}

expires := timeNow().Add(r.expiry)
r.cache.PutAddrs(host, addrs, expires)
return nil, nil
})

select {
case result := <-ch:
if result.Err != nil {
return "", result.Err
}

addr, _ := r.cache.GetAddr(host)
return addr, nil
case <-ctx.Done():
return "", ctx.Err()
}
}

// WithRotoDialer configures an http.Transport to cycle through multiple local
// network addresses when creating new HTTP connections.
//
// WithRotoDialer REPLACES the root DialContext hook on the underlying
// Transport, thereby destroying any previously-applied wrappings around it. If
// the caller needs to apply additional decorations to the DialContext hook,
// they must do so after applying WithRotoDialer.
func WithRotoDialer(addrs []net.Addr) func(*http.Transport) {
return func(t *http.Transport) {
var dialers []*net.Dialer
for _, addr := range addrs {
dialers = append(dialers, &net.Dialer{
LocalAddr: addr,
})
}

t.DialContext = (&rotoDialer{
dialers: dialers,
}).DialContext
}
}

type rotoDialer struct {
mu sync.Mutex
dialers []*net.Dialer
index int
}

func (r *rotoDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
return r.next().DialContext(ctx, network, addr)
}

func (r *rotoDialer) next() *net.Dialer {
r.mu.Lock()
defer r.mu.Unlock()

d := r.dialers[r.index]
r.index = (r.index + 1) % len(r.dialers)
return d
}
166 changes: 166 additions & 0 deletions feature/s3/transfermanager/rrdns_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
package transfermanager

import (
"context"
"errors"
"net"
"testing"
"time"
)

// these tests also cover the cache impl (cycling+expiry+evict)

type mockNow struct {
now time.Time
}

func (m *mockNow) Now() time.Time {
return m.now
}

func (m *mockNow) Add(d time.Duration) {
m.now = m.now.Add(d)
}

func useMockNow(m *mockNow) func() {
timeNow = m.Now
return func() {
timeNow = time.Now
}
}

var errDialContextOK = errors.New("dial context ok")

type mockResolver struct {
addrs map[string][]string
err error
}

func (m *mockResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
return m.addrs[host], m.err
}

type mockDialContext struct {
calledWith string
}

func (m *mockDialContext) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
m.calledWith = addr
return nil, errDialContextOK
}

func TestRoundRobinDNS_CycleIPs(t *testing.T) {
restore := useMockNow(&mockNow{})
defer restore()

addrs := []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"}
r := &mockResolver{
addrs: map[string][]string{
"s3.us-east-1.amazonaws.com": addrs,
},
}
dc := &mockDialContext{}

rr := &rrDNS{
cache: newDNSCache(1),
resolver: r,
dialContext: dc.DialContext,
}

expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[0])
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[1])
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[2])
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[0])
}

func TestRoundRobinDNS_MultiIP(t *testing.T) {
restore := useMockNow(&mockNow{})
defer restore()

r := &mockResolver{
addrs: map[string][]string{
"host1.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"},
"host2.com": []string{"1.0.0.1", "1.0.0.2", "1.0.0.3"},
},
}
dc := &mockDialContext{}

rr := &rrDNS{
cache: newDNSCache(2),
resolver: r,
dialContext: dc.DialContext,
}

expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0])
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][0])
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][1])
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][1])
}

func TestRoundRobinDNS_MaxHosts(t *testing.T) {
restore := useMockNow(&mockNow{})
defer restore()

r := &mockResolver{
addrs: map[string][]string{
"host1.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"},
"host2.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"},
},
}
dc := &mockDialContext{}

rr := &rrDNS{
cache: newDNSCache(1),
resolver: r,
dialContext: dc.DialContext,
}

expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0])
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][1])
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][0]) // evicts host1
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0]) // evicts host2
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][0])
}

func TestRoundRobinDNS_Expires(t *testing.T) {
now := &mockNow{time.Unix(0, 0)}
restore := useMockNow(now)
defer restore()

r := &mockResolver{
addrs: map[string][]string{
"host1.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"},
},
}
dc := &mockDialContext{}

rr := &rrDNS{
cache: newDNSCache(2),
expiry: 30,
resolver: r,
dialContext: dc.DialContext,
}

expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0])
now.Add(16) // hasn't expired
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][1])
now.Add(16) // expired, starts over
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0])
}

func expectDialContext(t *testing.T, rr *rrDNS, dc *mockDialContext, host, expect string) {
const port = "443"

t.Helper()
_, err := rr.DialContext(context.Background(), "", net.JoinHostPort(host, port))
if err != errDialContextOK {
t.Errorf("expect sentinel err, got %v", err)
}
actual, _, err := net.SplitHostPort(dc.calledWith)
if err != nil {
t.Fatal(err)
}
if expect != actual {
t.Errorf("expect addr %s, got %s", expect, actual)
}
}

0 comments on commit 21af651

Please sign in to comment.