Skip to content

Commit

Permalink
dns: fix constant 30s backoff for re-resolution (#7262) (#7311)
Browse files Browse the repository at this point in the history
  • Loading branch information
purnesh42H authored Jun 7, 2024
1 parent 6d23620 commit 9b970fd
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 5 deletions.
10 changes: 6 additions & 4 deletions internal/resolver/dns/dns_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ var (
func init() {
resolver.Register(NewBuilder())
internal.TimeAfterFunc = time.After
internal.TimeNowFunc = time.Now
internal.TimeUntilFunc = time.Until
internal.NewNetResolver = newNetResolver
internal.AddressDialer = addressDialer
}
Expand Down Expand Up @@ -209,12 +211,12 @@ func (d *dnsResolver) watcher() {
err = d.cc.UpdateState(*state)
}

var waitTime time.Duration
var nextResolutionTime time.Time
if err == nil {
// Success resolving, wait for the next ResolveNow. However, also wait 30
// seconds at the very least to prevent constantly re-resolving.
backoffIndex = 1
waitTime = MinResolutionInterval
nextResolutionTime = internal.TimeNowFunc().Add(MinResolutionInterval)
select {
case <-d.ctx.Done():
return
Expand All @@ -223,13 +225,13 @@ func (d *dnsResolver) watcher() {
} else {
// Poll on an error found in DNS Resolver or an error received from
// ClientConn.
waitTime = backoff.DefaultExponential.Backoff(backoffIndex)
nextResolutionTime = internal.TimeNowFunc().Add(backoff.DefaultExponential.Backoff(backoffIndex))
backoffIndex++
}
select {
case <-d.ctx.Done():
return
case <-internal.TimeAfterFunc(waitTime):
case <-internal.TimeAfterFunc(internal.TimeUntilFunc(nextResolutionTime)):
}
}
}
Expand Down
78 changes: 78 additions & 0 deletions internal/resolver/dns/dns_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,27 @@ func overrideTimeAfterFuncWithChannel(t *testing.T) (durChan chan time.Duration,
return durChan, timeChan
}

// Override the current time used by the DNS resolver.
func overrideTimeNowFunc(t *testing.T, now time.Time) {
origTimeNowFunc := dnsinternal.TimeNowFunc
dnsinternal.TimeNowFunc = func() time.Time { return now }
t.Cleanup(func() { dnsinternal.TimeNowFunc = origTimeNowFunc })
}

// Override the remaining wait time to allow re-resolution by DNS resolver.
// Use the timeChan to read the time until resolver needs to wait for
// and return 0 wait time.
func overrideTimeUntilFuncWithChannel(t *testing.T) (timeChan chan time.Time) {
timeCh := make(chan time.Time, 1)
origTimeUntil := dnsinternal.TimeUntilFunc
dnsinternal.TimeUntilFunc = func(t time.Time) time.Duration {
timeCh <- t
return 0
}
t.Cleanup(func() { dnsinternal.TimeUntilFunc = origTimeUntil })
return timeCh
}

func enableSRVLookups(t *testing.T) {
origEnableSRVLookups := dns.EnableSRVLookups
dns.EnableSRVLookups = true
Expand Down Expand Up @@ -1290,3 +1311,60 @@ func (s) TestMinResolutionInterval(t *testing.T) {
r.ResolveNow(resolver.ResolveNowOptions{})
}
}

// TestMinResolutionInterval_NoExtraDelay verifies that there is no extra delay
// between two resolution requests apart from [MinResolutionInterval].
func (s) TestMinResolutionInterval_NoExtraDelay(t *testing.T) {
tr := &testNetResolver{
hostLookupTable: map[string][]string{
"foo.bar.com": {"1.2.3.4", "5.6.7.8"},
},
txtLookupTable: map[string][]string{
"_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood),
},
}
overrideNetResolver(t, tr)
// Override time.Now() to return a zero value for time. This will allow us
// to verify that the call to time.Until is made with the exact
// [MinResolutionInterval] that we expect.
overrideTimeNowFunc(t, time.Time{})
// Override time.Until() to read the time passed to it
// and return immediately without any delay
timeCh := overrideTimeUntilFuncWithChannel(t)

r, stateCh, errorCh := buildResolverWithTestClientConn(t, "foo.bar.com")

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

// Ensure that the first resolution happens.
select {
case <-ctx.Done():
t.Fatal("Timeout when waiting for DNS resolver")
case err := <-errorCh:
t.Fatalf("Unexpected error from resolver, %v", err)
case <-stateCh:
}

// Request re-resolution and verify that the resolver waits for
// [MinResolutionInterval].
r.ResolveNow(resolver.ResolveNowOptions{})
select {
case <-ctx.Done():
t.Fatal("Timeout when waiting for DNS resolver")
case gotTime := <-timeCh:
wantTime := time.Time{}.Add(dns.MinResolutionInterval)
if !gotTime.Equal(wantTime) {
t.Fatalf("DNS resolver waits for %v time before re-resolution, want %v", gotTime, wantTime)
}
}

// Ensure that the re-resolution request actually happens.
select {
case <-ctx.Done():
t.Fatal("Timeout when waiting for an error from the resolver")
case err := <-errorCh:
t.Fatalf("Unexpected error from resolver, %v", err)
case <-stateCh:
}
}
13 changes: 12 additions & 1 deletion internal/resolver/dns/internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,22 @@ var (
// The following vars are overridden from tests.
var (
// TimeAfterFunc is used by the DNS resolver to wait for the given duration
// to elapse. In non-test code, this is implemented by time.After. In test
// to elapse. In non-test code, this is implemented by time.After. In test
// code, this can be used to control the amount of time the resolver is
// blocked waiting for the duration to elapse.
TimeAfterFunc func(time.Duration) <-chan time.Time

// TimeNowFunc is used by the DNS resolver to get the current time.
// In non-test code, this is implemented by time.Now. In test code,
// this can be used to control the current time for the resolver.
TimeNowFunc func() time.Time

// TimeUntilFunc is used by the DNS resolver to calculate the remaining
// wait time for re-resolution. In non-test code, this is implemented by
// time.Until. In test code, this can be used to control the remaining
// time for resolver to wait for re-resolution.
TimeUntilFunc func(time.Time) time.Duration

// NewNetResolver returns the net.Resolver instance for the given target.
NewNetResolver func(string) (NetResolver, error)

Expand Down

0 comments on commit 9b970fd

Please sign in to comment.