diff --git a/internal/resolver/dns/dns_resolver.go b/internal/resolver/dns/dns_resolver.go index 50afcadc0fd5..4552db16b028 100644 --- a/internal/resolver/dns/dns_resolver.go +++ b/internal/resolver/dns/dns_resolver.go @@ -63,6 +63,8 @@ var ( func init() { resolver.Register(NewBuilder()) internal.TimeAfterFunc = time.After + internal.TimeNowFunc = time.Now + internal.TimeUntilFunc = time.Until internal.NewNetResolver = newNetResolver internal.AddressDialer = addressDialer } @@ -209,12 +211,12 @@ func (d *dnsResolver) watcher() { err = d.cc.UpdateState(*state) } - var waitTime time.Duration + var nextResolutionTime time.Time if err == nil { // Success resolving, wait for the next ResolveNow. However, also wait 30 // seconds at the very least to prevent constantly re-resolving. backoffIndex = 1 - waitTime = MinResolutionInterval + nextResolutionTime = internal.TimeNowFunc().Add(MinResolutionInterval) select { case <-d.ctx.Done(): return @@ -223,13 +225,13 @@ func (d *dnsResolver) watcher() { } else { // Poll on an error found in DNS Resolver or an error received from // ClientConn. - waitTime = backoff.DefaultExponential.Backoff(backoffIndex) + nextResolutionTime = internal.TimeNowFunc().Add(backoff.DefaultExponential.Backoff(backoffIndex)) backoffIndex++ } select { case <-d.ctx.Done(): return - case <-internal.TimeAfterFunc(waitTime): + case <-internal.TimeAfterFunc(internal.TimeUntilFunc(nextResolutionTime)): } } } diff --git a/internal/resolver/dns/dns_resolver_test.go b/internal/resolver/dns/dns_resolver_test.go index 95fd4b5eeeb5..57780f4d68dc 100644 --- a/internal/resolver/dns/dns_resolver_test.go +++ b/internal/resolver/dns/dns_resolver_test.go @@ -102,6 +102,27 @@ func overrideTimeAfterFuncWithChannel(t *testing.T) (durChan chan time.Duration, return durChan, timeChan } +// Override the current time used by the DNS resolver. +func overrideTimeNowFunc(t *testing.T, now time.Time) { + origTimeNowFunc := dnsinternal.TimeNowFunc + dnsinternal.TimeNowFunc = func() time.Time { return now } + t.Cleanup(func() { dnsinternal.TimeNowFunc = origTimeNowFunc }) +} + +// Override the remaining wait time to allow re-resolution by DNS resolver. +// Use the timeChan to read the time until resolver needs to wait for +// and return 0 wait time. +func overrideTimeUntilFuncWithChannel(t *testing.T) (timeChan chan time.Time) { + timeCh := make(chan time.Time, 1) + origTimeUntil := dnsinternal.TimeUntilFunc + dnsinternal.TimeUntilFunc = func(t time.Time) time.Duration { + timeCh <- t + return 0 + } + t.Cleanup(func() { dnsinternal.TimeUntilFunc = origTimeUntil }) + return timeCh +} + func enableSRVLookups(t *testing.T) { origEnableSRVLookups := dns.EnableSRVLookups dns.EnableSRVLookups = true @@ -1290,3 +1311,60 @@ func (s) TestMinResolutionInterval(t *testing.T) { r.ResolveNow(resolver.ResolveNowOptions{}) } } + +// TestMinResolutionInterval_NoExtraDelay verifies that there is no extra delay +// between two resolution requests apart from [MinResolutionInterval]. +func (s) TestMinResolutionInterval_NoExtraDelay(t *testing.T) { + tr := &testNetResolver{ + hostLookupTable: map[string][]string{ + "foo.bar.com": {"1.2.3.4", "5.6.7.8"}, + }, + txtLookupTable: map[string][]string{ + "_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood), + }, + } + overrideNetResolver(t, tr) + // Override time.Now() to return a zero value for time. This will allow us + // to verify that the call to time.Until is made with the exact + // [MinResolutionInterval] that we expect. + overrideTimeNowFunc(t, time.Time{}) + // Override time.Until() to read the time passed to it + // and return immediately without any delay + timeCh := overrideTimeUntilFuncWithChannel(t) + + r, stateCh, errorCh := buildResolverWithTestClientConn(t, "foo.bar.com") + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Ensure that the first resolution happens. + select { + case <-ctx.Done(): + t.Fatal("Timeout when waiting for DNS resolver") + case err := <-errorCh: + t.Fatalf("Unexpected error from resolver, %v", err) + case <-stateCh: + } + + // Request re-resolution and verify that the resolver waits for + // [MinResolutionInterval]. + r.ResolveNow(resolver.ResolveNowOptions{}) + select { + case <-ctx.Done(): + t.Fatal("Timeout when waiting for DNS resolver") + case gotTime := <-timeCh: + wantTime := time.Time{}.Add(dns.MinResolutionInterval) + if !gotTime.Equal(wantTime) { + t.Fatalf("DNS resolver waits for %v time before re-resolution, want %v", gotTime, wantTime) + } + } + + // Ensure that the re-resolution request actually happens. + select { + case <-ctx.Done(): + t.Fatal("Timeout when waiting for an error from the resolver") + case err := <-errorCh: + t.Fatalf("Unexpected error from resolver, %v", err) + case <-stateCh: + } +} diff --git a/internal/resolver/dns/internal/internal.go b/internal/resolver/dns/internal/internal.go index a7ecaf8d5223..c0eae4f5f83f 100644 --- a/internal/resolver/dns/internal/internal.go +++ b/internal/resolver/dns/internal/internal.go @@ -51,11 +51,22 @@ var ( // The following vars are overridden from tests. var ( // TimeAfterFunc is used by the DNS resolver to wait for the given duration - // to elapse. In non-test code, this is implemented by time.After. In test + // to elapse. In non-test code, this is implemented by time.After. In test // code, this can be used to control the amount of time the resolver is // blocked waiting for the duration to elapse. TimeAfterFunc func(time.Duration) <-chan time.Time + // TimeNowFunc is used by the DNS resolver to get the current time. + // In non-test code, this is implemented by time.Now. In test code, + // this can be used to control the current time for the resolver. + TimeNowFunc func() time.Time + + // TimeUntilFunc is used by the DNS resolver to calculate the remaining + // wait time for re-resolution. In non-test code, this is implemented by + // time.Until. In test code, this can be used to control the remaining + // time for resolver to wait for re-resolution. + TimeUntilFunc func(time.Time) time.Duration + // NewNetResolver returns the net.Resolver instance for the given target. NewNetResolver func(string) (NetResolver, error)