Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent repeated context expired errors #228

Merged
merged 2 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ type Dialer struct {
// RSA keypair is generated will be faster.
func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
cfg := &dialerConfig{
refreshTimeout: 30 * time.Second,
refreshTimeout: alloydb.RefreshTimeout,
dialFunc: proxy.Dial,
useragents: []string{userAgent},
}
Expand Down
79 changes: 54 additions & 25 deletions internal/alloydb/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,26 @@ import (

"cloud.google.com/go/alloydbconn/errtype"
"cloud.google.com/go/alloydbconn/internal/alloydbapi"
"golang.org/x/time/rate"
)

// the refresh buffer is the amount of time before a refresh's result expires
// that a new refresh operation begins.
const refreshBuffer = 4 * time.Minute
const (
// the refresh buffer is the amount of time before a refresh's result
// expires that a new refresh operation begins.
refreshBuffer = 4 * time.Minute

// refreshInterval is the amount of time between refresh attempts as
// enforced by the rate limiter.
refreshInterval = 30 * time.Second

// RefreshTimeout is the maximum amount of time to wait for a refresh
// cycle to complete. This value should be greater than the
// refreshInterval.
RefreshTimeout = 60 * time.Second

// refreshBurst is the initial burst allowed by the rate limiter.
refreshBurst = 2
)

var (
// Instance URI is in the format:
Expand Down Expand Up @@ -117,7 +132,12 @@ type Instance struct {

instanceURI
key *rsa.PrivateKey
r refresher
// refreshTimeout sets the maximum duration a refresh cycle can run
// for.
refreshTimeout time.Duration
// l controls the rate at which refresh cycles are run.
l *rate.Limiter
r refresher

resultGuard sync.RWMutex
// cur represents the current refreshOperation that will be used to
Expand Down Expand Up @@ -148,17 +168,13 @@ func NewInstance(
}
ctx, cancel := context.WithCancel(context.Background())
i := &Instance{
instanceURI: cn,
key: key,
r: newRefresher(
client,
refreshTimeout,
30*time.Second,
2,
dialerID,
),
ctx: ctx,
cancel: cancel,
instanceURI: cn,
key: key,
l: rate.NewLimiter(rate.Every(refreshInterval), refreshBurst),
r: newRefresher(client, dialerID),
refreshTimeout: refreshTimeout,
ctx: ctx,
cancel: cancel,
}
// For the initial refresh operation, set cur = next so that connection
// requests block until the first refresh is complete.
Expand Down Expand Up @@ -234,20 +250,33 @@ func refreshDuration(now, certExpiry time.Time) time.Duration {

// scheduleRefresh schedules a refresh operation to be triggered after a given
// duration. The returned refreshOperation can be used to either Cancel or Wait
// for the operations result.
// for the operation's result.
func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
res := &refreshOperation{}
res.ready = make(chan struct{})
res.timer = time.AfterFunc(d, func() {
res.result, res.err = i.r.performRefresh(i.ctx, i.instanceURI, i.key)
close(res.ready)
r := &refreshOperation{}
r.ready = make(chan struct{})
r.timer = time.AfterFunc(d, func() {
ctx, cancel := context.WithTimeout(i.ctx, i.refreshTimeout)
defer cancel()

err := i.l.Wait(ctx)
if err != nil {
r.err = errtype.NewDialError(
"context was canceled or expired before refresh completed",
i.instanceURI.String(),
nil,
)
} else {
r.result, r.err = i.r.performRefresh(i.ctx, i.instanceURI, i.key)
}

close(r.ready)

// Once the refresh is complete, update "current" with working
// result and schedule a new refresh
i.resultGuard.Lock()
defer i.resultGuard.Unlock()
// if failed, scheduled the next refresh immediately
if res.err != nil {
if r.err != nil {
i.next = i.scheduleRefresh(0)
// If the latest result is bad, avoid replacing the
// used result while it's still valid and potentially
Expand All @@ -256,13 +285,13 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
// valid are surpressed. We should try to surface
// errors in a more meaningful way.
if !i.cur.isValid() {
i.cur = res
i.cur = r
}
return
}
// Update the current results, and schedule the next refresh in
// the future
i.cur = res
i.cur = r
select {
case <-i.ctx.Done():
// instance has been closed, don't schedule anything
Expand All @@ -272,7 +301,7 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
t := refreshDuration(time.Now(), i.cur.result.expiry)
i.next = i.scheduleRefresh(t)
})
return res
return r
}

// String returns the instance's URI.
Expand Down
3 changes: 2 additions & 1 deletion internal/alloydb/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"crypto/rand"
"crypto/rsa"
"errors"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -210,7 +211,7 @@ func TestClose(t *testing.T) {
im.Close()

_, _, err = im.ConnectInfo(ctx)
if !errors.Is(err, context.Canceled) {
if !strings.Contains(err.Error(), "context was canceled or expired") {
t.Fatalf("failed to retrieve connect info: %v", err)
}
}
Expand Down
32 changes: 2 additions & 30 deletions internal/alloydb/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"cloud.google.com/go/alloydbconn/errtype"
"cloud.google.com/go/alloydbconn/internal/alloydbapi"
"cloud.google.com/go/alloydbconn/internal/trace"
"golang.org/x/time/rate"
)

type connectInfo struct {
Expand Down Expand Up @@ -196,16 +195,11 @@ func createTLSConfig(inst instanceURI, cc certChain, info connectInfo, k *rsa.Pr
// newRefresher creates a Refresher.
func newRefresher(
client *alloydbapi.Client,
timeout time.Duration,
interval time.Duration,
burst int,
dialerID string,
) refresher {
return refresher{
client: client,
timeout: timeout,
clientLimiter: rate.NewLimiter(rate.Every(interval), burst),
dialerID: dialerID,
client: client,
dialerID: dialerID,
}
}

Expand All @@ -215,14 +209,8 @@ type refresher struct {
// client provides access to the AlloyDB Admin API
client *alloydbapi.Client

// timeout is the maximum amount of time a refresh operation should be allowed to take.
timeout time.Duration

// dialerID is the unique ID of the associated dialer.
dialerID string

// clientLimiter limits the number of refreshes.
clientLimiter *rate.Limiter
}

type refreshResult struct {
Expand All @@ -247,22 +235,6 @@ func (r refresher) performRefresh(ctx context.Context, cn instanceURI, k *rsa.Pr
refreshEnd(err)
}()

ctx, cancel := context.WithTimeout(ctx, r.timeout)
defer cancel()
if ctx.Err() == context.Canceled {
return refreshResult{}, ctx.Err()
}

// avoid refreshing too often to try not to tax the AlloyDB Admin API quotas
err = r.clientLimiter.Wait(ctx)
if err != nil {
return refreshResult{}, errtype.NewDialError(
"refresh was throttled until context expired",
cn.String(),
nil,
)
}

type mdRes struct {
info connectInfo
err error
Expand Down
17 changes: 4 additions & 13 deletions internal/alloydb/refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ import (
"testing"
"time"

"cloud.google.com/go/alloydbconn/errtype"
"cloud.google.com/go/alloydbconn/internal/alloydbapi"
"cloud.google.com/go/alloydbconn/internal/mock"
"google.golang.org/api/option"
)

const testDialerID = "some-dialer-id"

func TestRefresh(t *testing.T) {
wantIP := "10.0.0.1"
wantExpiry := time.Now().Add(time.Hour).UTC().Round(time.Second)
Expand Down Expand Up @@ -57,7 +58,7 @@ func TestRefresh(t *testing.T) {
if err != nil {
t.Fatalf("admin API client error: %v", err)
}
r := newRefresher(cl, time.Hour, 30*time.Second, 2, "some-id")
r := newRefresher(cl, testDialerID)
res, err := r.performRefresh(context.Background(), cn, RSAKey)
if err != nil {
t.Fatalf("performRefresh unexpectedly failed with error: %v", err)
Expand Down Expand Up @@ -98,7 +99,7 @@ func TestRefreshFailsFast(t *testing.T) {
if err != nil {
t.Fatalf("admin API client error: %v", err)
}
r := newRefresher(cl, time.Hour, 30*time.Second, 1, "some-id")
r := newRefresher(cl, testDialerID)

_, err = r.performRefresh(context.Background(), cn, RSAKey)
if err != nil {
Expand All @@ -112,14 +113,4 @@ func TestRefreshFailsFast(t *testing.T) {
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled error, got = %v", err)
}

// force the rate limiter to throttle with a timed out context
ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
_, err = r.performRefresh(ctx, cn, RSAKey)

var wantErr *errtype.DialError
if !errors.As(err, &wantErr) {
t.Fatalf("when refresh is throttled, want = %T, got = %v", wantErr, err)
}
}
3 changes: 2 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ func WithRSAKey(k *rsa.PrivateKey) Option {
}
}

// WithRefreshTimeout returns an Option that sets a timeout on refresh operations. Defaults to 30s.
// WithRefreshTimeout returns an Option that sets a timeout on refresh
// operations. Defaults to 60s.
func WithRefreshTimeout(t time.Duration) Option {
return func(d *dialerConfig) {
d.refreshTimeout = t
Expand Down