diff --git a/redis/pool.go b/redis/pool.go index d7bb71e0..8e22f66b 100644 --- a/redis/pool.go +++ b/redis/pool.go @@ -128,16 +128,24 @@ type Pool struct { // DialContext is an application supplied function for creating and configuring a // connection with the given context. // - // The connection returned from Dial must not be in a special state + // The connection returned from DialContext must not be in a special state // (subscribed to pubsub channel, transaction started, ...). DialContext func(ctx context.Context) (Conn, error) // TestOnBorrow is an optional application supplied function for checking // the health of an idle connection before the connection is used again by - // the application. Argument t is the time that the connection was returned + // the application. Argument lastUsed is the time when the connection was returned + // to the pool. If the function returns an error, then the connection is + // closed. + TestOnBorrow func(c Conn, lastUsed time.Time) error + + // TestOnBorrowContext is an optional application supplied function + // for checking the health of an idle connection with the given context + // before the connection is used again by the application. + // Argument lastUsed is the time when the connection was returned // to the pool. If the function returns an error, then the connection is // closed. - TestOnBorrow func(c Conn, t time.Time) error + TestOnBorrowContext func(ctx context.Context, c Conn, lastUsed time.Time) error // Maximum number of idle connections in the pool. MaxIdle int @@ -228,6 +236,7 @@ func (p *Pool) GetContext(ctx context.Context) (Conn, error) { p.idle.popFront() p.mu.Unlock() if (p.TestOnBorrow == nil || p.TestOnBorrow(pc.c, pc.t) == nil) && + (p.TestOnBorrowContext == nil || p.TestOnBorrowContext(ctx, pc.c, pc.t) == nil) && (p.MaxConnLifetime == 0 || nowFunc().Sub(pc.created) < p.MaxConnLifetime) { return &activeConn{p: p, pc: pc}, nil } diff --git a/redis/pool_test.go b/redis/pool_test.go index f420b124..2748d977 100644 --- a/redis/pool_test.go +++ b/redis/pool_test.go @@ -48,6 +48,25 @@ func (c *poolTestConn) Close() error { func (c *poolTestConn) Err() error { return c.err } func (c *poolTestConn) Do(commandName string, args ...interface{}) (interface{}, error) { + return c.do(c.Conn.Do, commandName, args...) +} + +func (c *poolTestConn) DoContext(ctx context.Context, commandName string, args ...interface{}) (interface{}, error) { + cwc, ok := c.Conn.(redis.ConnWithContext) + if !ok { + return nil, errors.New("redis: connection does not support ConnWithContext") + } + return c.do( + func(c string, a ...interface{}) (interface{}, error) { + return cwc.DoContext(ctx, c, a...) + }, + commandName, args) +} + +func (c *poolTestConn) do( + fn func(commandName string, args ...interface{}) (interface{}, error), + commandName string, args ...interface{}, +) (interface{}, error) { if commandName == "ERR" { c.err = args[0].(error) commandName = "PING" @@ -55,7 +74,7 @@ func (c *poolTestConn) Do(commandName string, args ...interface{}) (interface{}, if commandName != "" { c.d.commands = append(c.d.commands, commandName) } - return c.Conn.Do(commandName, args...) + return fn(commandName, args...) } func (c *poolTestConn) Send(commandName string, args ...interface{}) error { @@ -63,6 +82,14 @@ func (c *poolTestConn) Send(commandName string, args ...interface{}) error { return c.Conn.Send(commandName, args...) } +func (c *poolTestConn) ReceiveContext(ctx context.Context) (reply interface{}, err error) { + cwc, ok := c.Conn.(redis.ConnWithContext) + if !ok { + return nil, errors.New("redis: connection does not support ConnWithContext") + } + return cwc.ReceiveContext(ctx) +} + type poolDialer struct { mu sync.Mutex t *testing.T @@ -73,6 +100,10 @@ type poolDialer struct { } func (d *poolDialer) dial() (redis.Conn, error) { + return d.dialContext(context.Background()) +} + +func (d *poolDialer) dialContext(ctx context.Context) (redis.Conn, error) { d.mu.Lock() d.dialed += 1 dialErr := d.dialErr @@ -80,7 +111,7 @@ func (d *poolDialer) dial() (redis.Conn, error) { if dialErr != nil { return nil, d.dialErr } - c, err := redis.DialDefaultServer() + c, err := redis.DialDefaultServerContext(ctx) if err != nil { return nil, err } @@ -90,15 +121,14 @@ func (d *poolDialer) dial() (redis.Conn, error) { return &poolTestConn{d: d, Conn: c}, nil } -func (d *poolDialer) dialContext(ctx context.Context) (redis.Conn, error) { - return d.dial() -} - func (d *poolDialer) check(message string, p *redis.Pool, dialed, open, inuse int) { + d.t.Helper() d.checkAll(message, p, dialed, open, inuse, 0, 0) } func (d *poolDialer) checkAll(message string, p *redis.Pool, dialed, open, inuse int, waitCountMax int64, waitDurationMax time.Duration) { + d.t.Helper() + d.mu.Lock() defer d.mu.Unlock() @@ -368,21 +398,142 @@ func TestPoolConcurrenSendReceive(t *testing.T) { } func TestPoolBorrowCheck(t *testing.T) { - d := poolDialer{t: t} - p := &redis.Pool{ - MaxIdle: 2, - Dial: d.dial, - TestOnBorrow: func(redis.Conn, time.Time) error { return redis.Error("BLAH") }, + pingN := func(ctx context.Context, p *redis.Pool, n int) { + for i := 0; i < n; i++ { + func() { + c, err := p.GetContext(ctx) + require.NoError(t, err) + defer func() { + require.NoError(t, c.Close()) + }() + _, err = redis.DoContext(c, ctx, "PING") + require.NoError(t, err) + }() + } } - defer p.Close() - for i := 0; i < 10; i++ { - c := p.Get() - _, err := c.Do("PING") - require.NoError(t, err) - c.Close() + checkLastUsedTimes := func(lastUsedTimes []time.Time, startTime, endTime time.Time, wantLen int) { + require.Len(t, lastUsedTimes, wantLen) + for i, lastUsed := range lastUsedTimes { + if i == 0 { + require.True(t, lastUsed.After(startTime)) + } else { + require.True(t, lastUsed.After(lastUsedTimes[i-1])) + } + require.True(t, lastUsed.Before(endTime)) + } } - d.check("1", p, 10, 1, 0) + + t.Run("TestOnBorrow-error", func(t *testing.T) { + d := poolDialer{t: t} + p := &redis.Pool{ + MaxIdle: 2, + DialContext: d.dialContext, + TestOnBorrow: func(redis.Conn, time.Time) error { return redis.Error("BLAH") }, + } + defer p.Close() + pingN(context.Background(), p, 10) + d.check("1", p, 10, 1, 0) + }) + + t.Run("TestOnBorrow-nil-error", func(t *testing.T) { + d := poolDialer{t: t} + var borrowErrs []error + var lastUsedTimes []time.Time + p := &redis.Pool{ + MaxIdle: 2, + DialContext: d.dialContext, + TestOnBorrow: func(c redis.Conn, lastUsed time.Time) error { + lastUsedTimes = append(lastUsedTimes, lastUsed) + _, err := c.Do("PING") + if err != nil { + borrowErrs = append(borrowErrs, err) + } + return err + }, + } + defer p.Close() + + startTime := time.Now() + pingN(context.Background(), p, 10) + endTime := time.Now() + + require.Empty(t, borrowErrs) + checkLastUsedTimes(lastUsedTimes, startTime, endTime, 9) + d.check("1", p, 1, 1, 0) + }) + + t.Run("TestOnBorrowContext-error", func(t *testing.T) { + d := poolDialer{t: t} + p := &redis.Pool{ + MaxIdle: 2, + DialContext: d.dialContext, + TestOnBorrowContext: func(context.Context, redis.Conn, time.Time) error { return redis.Error("BLAH") }, + } + defer p.Close() + pingN(context.Background(), p, 10) + d.check("1", p, 10, 1, 0) + }) + + t.Run("TestOnBorrowContext-nil-error", func(t *testing.T) { + d := poolDialer{t: t} + var borrowErrs []error + var lastUsedTimes []time.Time + p := &redis.Pool{ + MaxIdle: 2, + DialContext: d.dialContext, + TestOnBorrowContext: func(ctx context.Context, c redis.Conn, lastUsed time.Time) error { + lastUsedTimes = append(lastUsedTimes, lastUsed) + _, err := redis.DoContext(c, ctx, "PING") + if err != nil { + borrowErrs = append(borrowErrs, err) + } + return err + }, + } + defer p.Close() + + startTime := time.Now() + pingN(context.Background(), p, 10) + endTime := time.Now() + + require.Empty(t, borrowErrs) + checkLastUsedTimes(lastUsedTimes, startTime, endTime, 9) + d.check("1", p, 1, 1, 0) + }) + + t.Run("TestOnBorrowContext-context.Canceled", func(t *testing.T) { + d := poolDialer{t: t} + var borrowErrs []error + p := &redis.Pool{ + MaxIdle: 2, + DialContext: d.dialContext, + TestOnBorrowContext: func(ctx context.Context, c redis.Conn, _ time.Time) error { + _, err := redis.DoContext(c, ctx, "PING") + if err != nil { + borrowErrs = append(borrowErrs, err) + } + return err + }, + } + defer p.Close() + + ctx, ctxCancel := context.WithCancel(context.Background()) + defer ctxCancel() + + pingN(ctx, p, 2) + d.check("1", p, 1, 1, 0) + require.Empty(t, borrowErrs) + + ctxCancel() + + _, err := p.GetContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + d.check("1", p, 2, 0, 0) + require.Len(t, borrowErrs, 1) + require.ErrorIs(t, borrowErrs[0], context.Canceled) + }) } func TestPoolMaxActive(t *testing.T) { @@ -757,7 +908,7 @@ func TestLocking_TestOnBorrowFails_PoolDoesntCrash(t *testing.T) { MaxIdle: count, MaxActive: count, Dial: d.dial, - TestOnBorrow: func(c redis.Conn, t time.Time) error { + TestOnBorrow: func(redis.Conn, time.Time) error { return errors.New("No way back into the real world.") }, } diff --git a/redis/test_test.go b/redis/test_test.go index 11c34734..f7598683 100644 --- a/redis/test_test.go +++ b/redis/test_test.go @@ -16,6 +16,7 @@ package redis import ( "bufio" + "context" "errors" "flag" "fmt" @@ -197,15 +198,21 @@ func DefaultServerAddr() (string, error) { // DialDefaultServer starts the test server if not already started and dials a // connection to the server. func DialDefaultServer(options ...DialOption) (Conn, error) { + return DialDefaultServerContext(context.Background(), options...) +} + +// DialDefaultServerContext starts the test server if not already started and +// dials a connection to the server with the given context. +func DialDefaultServerContext(ctx context.Context, options ...DialOption) (Conn, error) { addr, err := DefaultServerAddr() if err != nil { return nil, err } - c, err := Dial("tcp", addr, append([]DialOption{DialReadTimeout(1 * time.Second), DialWriteTimeout(1 * time.Second)}, options...)...) + c, err := DialContext(ctx, "tcp", addr, append([]DialOption{DialReadTimeout(1 * time.Second), DialWriteTimeout(1 * time.Second)}, options...)...) if err != nil { return nil, err } - if _, err = c.Do("FLUSHDB"); err != nil { + if _, err = DoContext(c, ctx, "FLUSHDB"); err != nil { return nil, err } return c, nil