Skip to content

Commit

Permalink
DNS resolving with timeout.
Browse files Browse the repository at this point in the history
  • Loading branch information
and1truong committed Jan 11, 2024
1 parent 6ce73bf commit 5b77a44
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 110 deletions.
8 changes: 8 additions & 0 deletions dialoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ type dialOptions struct {
defaultServiceConfig *ServiceConfig // defaultServiceConfig is parsed from defaultServiceConfigRawJSON.
defaultServiceConfigRawJSON *string
resolvers []resolver.Builder
resolveTimeout time.Duration
idleTimeout time.Duration
recvBufferPool SharedBufferPool
}
Expand Down Expand Up @@ -694,6 +695,13 @@ func WithIdleTimeout(d time.Duration) DialOption {
})
}

// WithResolveTimeout returns a DialOption that configures a DNS resolving timeout.
func WithResolveTimeout(d time.Duration) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.resolveTimeout = d
})

Check warning on line 702 in dialoptions.go

View check run for this annotation

Codecov / codecov/patch

dialoptions.go#L699-L702

Added lines #L699 - L702 were not covered by tests
}

// WithRecvBufferPool returns a DialOption that configures the ClientConn
// to use the provided shared buffer pool for parsing incoming messages. Depending
// on the application's workload, this could result in reduced memory allocation.
Expand Down
58 changes: 36 additions & 22 deletions internal/resolver/dns/dns_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
"strings"
"sync"
"time"

grpclbstate "google.golang.org/grpc/balancer/grpclb/state"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/backoff"
Expand Down Expand Up @@ -77,14 +77,14 @@ var newNetResolver = func(authority string) (internal.NetResolver, error) {
if authority == "" {
return net.DefaultResolver, nil
}

host, port, err := parseTarget(authority, defaultDNSSvrPort)
if err != nil {
return nil, err
}

authorityWithPort := net.JoinHostPort(host, port)

return &net.Resolver{
PreferGo: true,
Dial: internal.AddressDialer(authorityWithPort),
Expand All @@ -100,36 +100,43 @@ type dnsBuilder struct{}

// Build creates and starts a DNS resolver that watches the name resolution of
// the target.
func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (
resolver.Resolver, error,
) {
host, port, err := parseTarget(target.Endpoint(), defaultPort)
if err != nil {
return nil, err
}

// IP address.
if ipAddr, ok := formatIP(host); ok {
addr := []resolver.Address{{Addr: ipAddr + ":" + port}}
cc.UpdateState(resolver.State{Addresses: addr})
return deadResolver{}, nil
}

// DNS address (non-IP).
ctx, cancel := context.WithCancel(context.Background())
d := &dnsResolver{
host: host,
port: port,
timeout: opts.Timeout,
ctx: ctx,
cancel: cancel,
cc: cc,
rn: make(chan struct{}, 1),
disableServiceConfig: opts.DisableServiceConfig,
}


if d.timeout == 0 {
d.timeout = 1 * time.Minute
}

d.resolver, err = internal.NewNetResolver(target.URL.Host)
if err != nil {
return nil, err
}

d.wg.Add(1)
go d.watcher()
return d, nil
Expand All @@ -152,6 +159,7 @@ type dnsResolver struct {
host string
port string
resolver internal.NetResolver
timeout time.Duration
ctx context.Context
cancel context.CancelFunc
cc resolver.ClientConn
Expand Down Expand Up @@ -195,7 +203,7 @@ func (d *dnsResolver) watcher() {
} else {
err = d.cc.UpdateState(*state)
}

var waitTime time.Duration
if err == nil {
// Success resolving, wait for the next ResolveNow. However, also wait 30
Expand All @@ -221,18 +229,18 @@ func (d *dnsResolver) watcher() {
}
}

func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) {
func (d *dnsResolver) lookupSRV(ctx context.Context) ([]resolver.Address, error) {
if !EnableSRVLookups {
return nil, nil
}
var newAddrs []resolver.Address
_, srvs, err := d.resolver.LookupSRV(d.ctx, "grpclb", "tcp", d.host)
_, srvs, err := d.resolver.LookupSRV(ctx, "grpclb", "tcp", d.host)
if err != nil {
err = handleDNSError(err, "SRV") // may become nil
return nil, err
}
for _, s := range srvs {
lbAddrs, err := d.resolver.LookupHost(d.ctx, s.Target)
lbAddrs, err := d.resolver.LookupHost(ctx, s.Target)
if err != nil {
err = handleDNSError(err, "A") // may become nil
if err == nil {
Expand Down Expand Up @@ -269,8 +277,8 @@ func handleDNSError(err error, lookupType string) error {
return err
}

func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult {
ss, err := d.resolver.LookupTXT(d.ctx, txtPrefix+d.host)
func (d *dnsResolver) lookupTXT(ctx context.Context) *serviceconfig.ParseResult {
ss, err := d.resolver.LookupTXT(ctx, txtPrefix+d.host)
if err != nil {
if envconfig.TXTErrIgnore {
return nil
Expand All @@ -284,7 +292,7 @@ func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult {
for _, s := range ss {
res += s
}

// TXT record must have "grpc_config=" attribute in order to be used as
// service config.
if !strings.HasPrefix(res, txtAttribute) {
Expand All @@ -297,8 +305,8 @@ func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult {
return d.cc.ParseServiceConfig(sc)
}

func (d *dnsResolver) lookupHost() ([]resolver.Address, error) {
addrs, err := d.resolver.LookupHost(d.ctx, d.host)
func (d *dnsResolver) lookupHost(ctx context.Context) ([]resolver.Address, error) {
addrs, err := d.resolver.LookupHost(ctx, d.host)
if err != nil {
err = handleDNSError(err, "A")
return nil, err
Expand All @@ -316,18 +324,24 @@ func (d *dnsResolver) lookupHost() ([]resolver.Address, error) {
}

func (d *dnsResolver) lookup() (*resolver.State, error) {
srv, srvErr := d.lookupSRV()
addrs, hostErr := d.lookupHost()
ctxSRV, cancelSRV := context.WithTimeout(d.ctx, d.timeout)
defer cancelSRV()
srv, srvErr := d.lookupSRV(ctxSRV)
ctxHost, cancelHost := context.WithTimeout(d.ctx, d.timeout)
defer cancelHost()
addrs, hostErr := d.lookupHost(ctxHost)
if hostErr != nil && (srvErr != nil || len(srv) == 0) {
return nil, hostErr
}

state := resolver.State{Addresses: addrs}
if len(srv) > 0 {
state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv})
}
if !d.disableServiceConfig {
state.ServiceConfig = d.lookupTXT()
ctxTXT, cancelTXT := context.WithTimeout(d.ctx, d.timeout)
defer cancelTXT()
state.ServiceConfig = d.lookupTXT(ctxTXT)
}
return &state, nil
}
Expand Down
Loading

0 comments on commit 5b77a44

Please sign in to comment.