diff --git a/README.md b/README.md index 98496c7..4be91df 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,6 @@ [![Test](https://github.com/bsm/redislock/actions/workflows/test.yml/badge.svg)](https://github.com/bsm/redislock/actions/workflows/test.yml) [![GoDoc](https://godoc.org/github.com/bsm/redislock?status.png)](http://godoc.org/github.com/bsm/redislock) -[![Go Report Card](https://goreportcard.com/badge/github.com/bsm/redislock)](https://goreportcard.com/report/github.com/bsm/redislock) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) Simplified distributed locking implementation using [Redis](http://redis.io/topics/distlock). diff --git a/README.md.tpl b/README.md.tpl index c1be07e..92a04fe 100644 --- a/README.md.tpl +++ b/README.md.tpl @@ -2,7 +2,6 @@ [![Test](https://github.com/bsm/redislock/actions/workflows/test.yml/badge.svg)](https://github.com/bsm/redislock/actions/workflows/test.yml) [![GoDoc](https://godoc.org/github.com/bsm/redislock?status.png)](http://godoc.org/github.com/bsm/redislock) -[![Go Report Card](https://goreportcard.com/badge/github.com/bsm/redislock)](https://goreportcard.com/report/github.com/bsm/redislock) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) Simplified distributed locking implementation using [Redis](http://redis.io/topics/distlock). diff --git a/go.mod b/go.mod index 918d505..c6e66f8 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,4 @@ module github.com/bsm/redislock go 1.13 -require ( - github.com/bsm/ginkgo v1.16.4 - github.com/bsm/gomega v1.16.0 - github.com/go-redis/redis/v8 v8.11.4 -) +require github.com/go-redis/redis/v8 v8.11.4 diff --git a/go.sum b/go.sum index 6428d8e..5e86c5e 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,3 @@ -github.com/bsm/ginkgo v1.16.4 h1:pkHpo2VJRvI0NGlxCYi8qovww76L7+g82MgM+UBvH4A= -github.com/bsm/ginkgo v1.16.4/go.mod h1:RabIZLzOCPghgHJKUqHZpqrQETA5AnF4aCSIYy5C1bk= -github.com/bsm/gomega v1.16.0 h1:LEoRGHyYl3MqAcXgczKX/C3bxlxjl3gjP37PGvPNplw= -github.com/bsm/gomega v1.16.0/go.mod h1:JifAceMQ4crZIWYUKrlGcmbN3bqHogVTADMD2ATsbwk= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/redislock.go b/redislock.go index b1b00b5..3279bbc 100644 --- a/redislock.go +++ b/redislock.go @@ -61,12 +61,16 @@ func (c *Client) Obtain(ctx context.Context, key string, ttl time.Duration, opt value := token + opt.getMetadata() retry := opt.getRetryStrategy() - deadlinectx, cancel := context.WithDeadline(ctx, time.Now().Add(ttl)) - defer cancel() + // make sure we don't retry forever + if _, ok := ctx.Deadline(); !ok { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, time.Now().Add(ttl)) + defer cancel() + } var timer *time.Timer for { - ok, err := c.obtain(deadlinectx, key, value, ttl) + ok, err := c.obtain(ctx, key, value, ttl) if err != nil { return nil, err } else if ok { @@ -86,7 +90,7 @@ func (c *Client) Obtain(ctx context.Context, key string, ttl time.Duration, opt } select { - case <-deadlinectx.Done(): + case <-ctx.Done(): return nil, ErrNotObtained case <-timer.C: } diff --git a/redislock_test.go b/redislock_test.go index 52c3c0c..9ca0244 100644 --- a/redislock_test.go +++ b/redislock_test.go @@ -2,234 +2,272 @@ package redislock_test import ( "context" - "fmt" + "errors" "math/rand" "sync" "sync/atomic" "testing" "time" - . "github.com/bsm/ginkgo" - . "github.com/bsm/gomega" . "github.com/bsm/redislock" "github.com/go-redis/redis/v8" ) const lockKey = "__bsm_redislock_unit_test__" -var _ = Describe("Client", func() { - var subject *Client - var ctx = context.Background() +var redisOpts = &redis.Options{ + Network: "tcp", + Addr: "127.0.0.1:6379", DB: 9, +} - BeforeEach(func() { - subject = New(redisClient) - }) +func TestClient(t *testing.T) { + ctx := context.Background() + rc := redis.NewClient(redisOpts) + defer teardown(t, rc) - AfterEach(func() { - Expect(redisClient.Del(ctx, lockKey).Err()).To(Succeed()) - }) + // init client + client := New(rc) - It("obtains once with TTL", func() { - lock1, err := subject.Obtain(ctx, lockKey, time.Hour, nil) - Expect(err).NotTo(HaveOccurred()) - Expect(lock1.Token()).To(HaveLen(22)) - Expect(lock1.TTL(ctx)).To(BeNumerically("~", time.Hour, time.Second)) - defer lock1.Release(ctx) - - _, err = subject.Obtain(ctx, lockKey, time.Hour, nil) - Expect(err).To(Equal(ErrNotObtained)) - Expect(lock1.Release(ctx)).To(Succeed()) - - lock2, err := subject.Obtain(ctx, lockKey, time.Minute, nil) - Expect(err).NotTo(HaveOccurred()) - Expect(lock2.TTL(ctx)).To(BeNumerically("~", time.Minute, time.Second)) - Expect(lock2.Release(ctx)).To(Succeed()) - }) + // obtain + lock, err := client.Obtain(ctx, lockKey, time.Hour, nil) + if err != nil { + t.Fatal(err) + } + defer lock.Release(ctx) - It("obtains through short-cut", func() { - lock, err := Obtain(ctx, redisClient, lockKey, time.Hour, nil) - Expect(err).NotTo(HaveOccurred()) - Expect(lock.Release(ctx)).To(Succeed()) - }) + if exp, got := 22, len(lock.Token()); exp != got { + t.Fatalf("expected %v, got %v", exp, got) + } - It("supports custom metadata", func() { - lock, err := Obtain(ctx, redisClient, lockKey, time.Hour, &Options{Metadata: "my-data"}) - Expect(err).NotTo(HaveOccurred()) - Expect(lock.Metadata()).To(Equal("my-data")) - Expect(lock.Release(ctx)).To(Succeed()) - }) + // check TTL + assertTTL(t, lock, time.Hour) - It("refreshes", func() { - lock, err := Obtain(ctx, redisClient, lockKey, time.Minute, nil) - Expect(err).NotTo(HaveOccurred()) - Expect(lock.TTL(ctx)).To(BeNumerically("~", time.Minute, time.Second)) - Expect(lock.Refresh(ctx, time.Hour, nil)).To(Succeed()) - Expect(lock.TTL(ctx)).To(BeNumerically("~", time.Hour, time.Second)) - Expect(lock.Release(ctx)).To(Succeed()) - }) + // try to obtain again + _, err = client.Obtain(ctx, lockKey, time.Hour, nil) + if exp, got := ErrNotObtained, err; !errors.Is(got, exp) { + t.Fatalf("expected %v, got %v", exp, got) + } - It("fails to release if expired", func() { - lock, err := Obtain(ctx, redisClient, lockKey, time.Millisecond, nil) - Expect(err).NotTo(HaveOccurred()) - time.Sleep(5 * time.Millisecond) - Expect(lock.Release(ctx)).To(MatchError(ErrLockNotHeld)) - }) + // manually unlock + if err := lock.Release(ctx); err != nil { + t.Fatal(err) + } - It("fails to release if ontained by someone else", func() { - lock, err := Obtain(ctx, redisClient, lockKey, time.Minute, nil) - Expect(err).NotTo(HaveOccurred()) + // lock again + lock, err = client.Obtain(ctx, lockKey, time.Hour, nil) + if err != nil { + t.Fatal(err) + } + defer lock.Release(ctx) +} - Expect(redisClient.Set(ctx, lockKey, "ABCD", 0).Err()).NotTo(HaveOccurred()) - Expect(lock.Release(ctx)).To(MatchError(ErrLockNotHeld)) - }) +func TestObtain(t *testing.T) { + ctx := context.Background() + rc := redis.NewClient(redisOpts) + defer teardown(t, rc) - It("fails to refresh if expired", func() { - lock, err := Obtain(ctx, redisClient, lockKey, time.Millisecond, nil) - Expect(err).NotTo(HaveOccurred()) - time.Sleep(5 * time.Millisecond) - Expect(lock.Refresh(ctx, time.Hour, nil)).To(MatchError(ErrNotObtained)) - }) + lock := quickObtain(t, rc, time.Hour) + if err := lock.Release(ctx); err != nil { + t.Fatal(err) + } +} - It("retries if enabled", func() { - // retry, succeed - Expect(redisClient.Set(ctx, lockKey, "ABCD", 0).Err()).NotTo(HaveOccurred()) - Expect(redisClient.PExpire(ctx, lockKey, 20*time.Millisecond).Err()).NotTo(HaveOccurred()) +func TestObtain_metadata(t *testing.T) { + ctx := context.Background() + rc := redis.NewClient(redisOpts) + defer teardown(t, rc) - lock, err := Obtain(ctx, redisClient, lockKey, time.Hour, &Options{ - RetryStrategy: LimitRetry(LinearBackoff(100*time.Millisecond), 3), - }) - Expect(err).NotTo(HaveOccurred()) - Expect(lock.Release(ctx)).To(Succeed()) + meta := "my-data" + lock, err := Obtain(ctx, rc, lockKey, time.Hour, &Options{Metadata: meta}) + if err != nil { + t.Fatal(err) + } + defer lock.Release(ctx) - // no retry, fail - Expect(redisClient.Set(ctx, lockKey, "ABCD", 0).Err()).NotTo(HaveOccurred()) - Expect(redisClient.PExpire(ctx, lockKey, 50*time.Millisecond).Err()).NotTo(HaveOccurred()) + if exp, got := meta, lock.Metadata(); exp != got { + t.Fatalf("expected %v, got %v", exp, got) + } +} - _, err = Obtain(ctx, redisClient, lockKey, time.Hour, nil) - Expect(err).To(MatchError(ErrNotObtained)) +func TestObtain_retry_success(t *testing.T) { + ctx := context.Background() + rc := redis.NewClient(redisOpts) + defer teardown(t, rc) - // retry 2x, give up & fail - Expect(redisClient.Set(ctx, lockKey, "ABCD", 0).Err()).NotTo(HaveOccurred()) - Expect(redisClient.PExpire(ctx, lockKey, 50*time.Millisecond).Err()).NotTo(HaveOccurred()) + // obtain for 20ms + lock1 := quickObtain(t, rc, 20*time.Millisecond) + defer lock1.Release(ctx) - _, err = Obtain(ctx, redisClient, lockKey, time.Hour, &Options{ - RetryStrategy: LimitRetry(LinearBackoff(time.Millisecond), 2), - }) - Expect(err).To(MatchError(ErrNotObtained)) + // lock again with linar retry - 3x for 20ms + lock2, err := Obtain(ctx, rc, lockKey, time.Hour, &Options{ + RetryStrategy: LimitRetry(LinearBackoff(20*time.Millisecond), 3), }) + if err != nil { + t.Fatal(err) + } + defer lock2.Release(ctx) +} - It("prevents multiple locks (fuzzing)", func() { - numLocks := int32(0) - wg := new(sync.WaitGroup) - for i := 0; i < 100; i++ { - wg.Add(1) +func TestObtain_retry_failure(t *testing.T) { + ctx := context.Background() + rc := redis.NewClient(redisOpts) + defer teardown(t, rc) - go func() { - defer GinkgoRecover() - defer wg.Done() + // obtain for 50ms + lock1 := quickObtain(t, rc, 50*time.Millisecond) + defer lock1.Release(ctx) + + // lock again with linar retry - 2x for 5ms + _, err := Obtain(ctx, rc, lockKey, time.Hour, &Options{ + RetryStrategy: LimitRetry(LinearBackoff(5*time.Millisecond), 2), + }) + if exp, got := ErrNotObtained, err; !errors.Is(got, exp) { + t.Fatalf("expected %v, got %v", exp, got) + } +} - wait := rand.Int63n(int64(10 * time.Millisecond)) - time.Sleep(time.Duration(wait)) +func TestObtain_concurrent(t *testing.T) { + ctx := context.Background() + rc := redis.NewClient(redisOpts) + defer teardown(t, rc) - _, err := subject.Obtain(ctx, lockKey, time.Minute, nil) - if err == ErrNotObtained { - return - } - Expect(err).NotTo(HaveOccurred()) + numLocks := int32(0) + numThreads := 100 + wg := new(sync.WaitGroup) + errs := make(chan error, numThreads) + for i := 0; i < numThreads; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + wait := rand.Int63n(int64(10 * time.Millisecond)) + time.Sleep(time.Duration(wait)) + + _, err := Obtain(ctx, rc, lockKey, time.Minute, nil) + if err == ErrNotObtained { + return + } else if err != nil { + errs <- err + } else { atomic.AddInt32(&numLocks, 1) - }() - } - wg.Wait() - Expect(numLocks).To(Equal(int32(1))) - }) -}) + } + }() + } + wg.Wait() -var _ = Describe("RetryStrategy", func() { - It("supports no-retry", func() { - subject := NoRetry() - Expect(subject.NextBackoff()).To(Equal(time.Duration(0))) - }) + close(errs) + for err := range errs { + t.Fatal(err) + } + if exp, got := 1, int(numLocks); exp != got { + t.Fatalf("expected %v, got %v", exp, got) + } +} - It("supports linear backoff", func() { - subject := LinearBackoff(time.Second) - Expect(subject.NextBackoff()).To(Equal(time.Second)) - Expect(subject.NextBackoff()).To(Equal(time.Second)) - Expect(subject).To(beThreadSafe{}) - }) +func TestLock_Refresh(t *testing.T) { + ctx := context.Background() + rc := redis.NewClient(redisOpts) + defer teardown(t, rc) - It("supports limits", func() { - subject := LimitRetry(LinearBackoff(time.Second), 2) - Expect(subject.NextBackoff()).To(Equal(time.Second)) - Expect(subject.NextBackoff()).To(Equal(time.Second)) - Expect(subject.NextBackoff()).To(Equal(time.Duration(0))) - Expect(subject).To(beThreadSafe{}) - }) + lock := quickObtain(t, rc, time.Hour) + defer lock.Release(ctx) - It("supports exponential backoff", func() { - subject := ExponentialBackoff(10*time.Millisecond, 300*time.Millisecond) - Expect(subject.NextBackoff()).To(Equal(10 * time.Millisecond)) - Expect(subject.NextBackoff()).To(Equal(10 * time.Millisecond)) - Expect(subject.NextBackoff()).To(Equal(16 * time.Millisecond)) - Expect(subject.NextBackoff()).To(Equal(32 * time.Millisecond)) - Expect(subject.NextBackoff()).To(Equal(64 * time.Millisecond)) - Expect(subject.NextBackoff()).To(Equal(128 * time.Millisecond)) - Expect(subject.NextBackoff()).To(Equal(256 * time.Millisecond)) - Expect(subject.NextBackoff()).To(Equal(300 * time.Millisecond)) - Expect(subject.NextBackoff()).To(Equal(300 * time.Millisecond)) - Expect(subject.NextBackoff()).To(Equal(300 * time.Millisecond)) - Expect(subject).To(beThreadSafe{}) - }) -}) + // check TTL + assertTTL(t, lock, time.Hour) -// -------------------------------------------------------------------- + // update TTL + if err := lock.Refresh(ctx, time.Minute, nil); err != nil { + t.Fatal(err) + } -func TestSuite(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "redislock") + // check TTL again + assertTTL(t, lock, time.Minute) } -var redisClient *redis.Client +func TestLock_Refresh_expired(t *testing.T) { + ctx := context.Background() + rc := redis.NewClient(redisOpts) + defer teardown(t, rc) -var _ = BeforeSuite(func() { - redisClient = redis.NewClient(&redis.Options{ - Network: "tcp", - Addr: "127.0.0.1:6379", DB: 9, - }) - Expect(redisClient.Ping(context.Background()).Err()).To(Succeed()) -}) + lock := quickObtain(t, rc, 5*time.Millisecond) + defer lock.Release(ctx) -var _ = AfterSuite(func() { - Expect(redisClient.Close()).To(Succeed()) -}) + // try releasing + time.Sleep(10 * time.Millisecond) + if exp, got := ErrNotObtained, lock.Refresh(ctx, time.Minute, nil); !errors.Is(got, exp) { + t.Fatalf("expected %v, got %v", exp, got) + } +} -type beThreadSafe struct{} +func TestLock_Release_expired(t *testing.T) { + ctx := context.Background() + rc := redis.NewClient(redisOpts) + defer teardown(t, rc) -func (beThreadSafe) Match(actual interface{}) (bool, error) { - strategy, ok := actual.(RetryStrategy) - if !ok { - return false, fmt.Errorf("beThreadSafe matcher expects a RetryStrategy") + lock := quickObtain(t, rc, 5*time.Millisecond) + defer lock.Release(ctx) + + // try releasing + time.Sleep(10 * time.Millisecond) + if exp, got := ErrLockNotHeld, lock.Release(ctx); !errors.Is(got, exp) { + t.Fatalf("expected %v, got %v", exp, got) } +} - wg := new(sync.WaitGroup) - for i := 0; i < 1000; i++ { - wg.Add(1) +func TestLock_Release_not_own(t *testing.T) { + ctx := context.Background() + rc := redis.NewClient(redisOpts) + defer teardown(t, rc) - go func() { - defer GinkgoRecover() - defer wg.Done() + lock := quickObtain(t, rc, time.Hour) + defer lock.Release(ctx) - strategy.NextBackoff() - }() + // manually transfer ownership + if err := rc.Set(ctx, lockKey, "ABCD", 0).Err(); err != nil { + t.Fatal(err) } - wg.Wait() - return true, nil + // try releasing + if exp, got := ErrLockNotHeld, lock.Release(ctx); !errors.Is(got, exp) { + t.Fatalf("expected %v, got %v", exp, got) + } } -func (beThreadSafe) FailureMessage(actual interface{}) (message string) { - return fmt.Sprintf("Expected\n\t%T\nto be thread-safe", actual) +func quickObtain(t *testing.T, rc *redis.Client, ttl time.Duration) *Lock { + t.Helper() + + lock, err := Obtain(context.Background(), rc, lockKey, ttl, nil) + if err != nil { + t.Fatal(err) + } + return lock } -func (beThreadSafe) NegatedFailureMessage(actual interface{}) (message string) { - return fmt.Sprintf("Expected\n\t%T\nnot to be thread-safe", actual) +func assertTTL(t *testing.T, lock *Lock, exp time.Duration) { + t.Helper() + + ttl, err := lock.TTL(context.Background()) + if err != nil { + t.Fatal(err) + } + + delta := ttl - exp + if delta < 0 { + delta = 1 - delta + } + if delta > time.Second { + t.Fatalf("expected ~%v, got %v", exp, ttl) + } +} + +func teardown(t *testing.T, rc *redis.Client) { + t.Helper() + + if err := rc.Del(context.Background(), lockKey).Err(); err != nil { + t.Fatal(err) + } + if err := rc.Close(); err != nil { + t.Fatal(err) + } } diff --git a/retry_strategy_test.go b/retry_strategy_test.go new file mode 100644 index 0000000..fc44fb4 --- /dev/null +++ b/retry_strategy_test.go @@ -0,0 +1,62 @@ +package redislock_test + +import ( + "testing" + "time" + + . "github.com/bsm/redislock" +) + +func TestNoRetry(t *testing.T) { + retry := NoRetry() + for i, exp := range []time.Duration{0, 0, 0} { + if got := retry.NextBackoff(); exp != got { + t.Fatalf("expected %d to be %v, got %v", i, exp, got) + } + } +} + +func TestLinearBackoff(t *testing.T) { + retry := LinearBackoff(time.Second) + for i, exp := range []time.Duration{ + time.Second, + time.Second, + time.Second, + } { + if got := retry.NextBackoff(); exp != got { + t.Fatalf("expected %d to be %v, got %v", i, exp, got) + } + } +} + +func TestExponentialBackoff(t *testing.T) { + retry := ExponentialBackoff(10*time.Millisecond, 300*time.Millisecond) + for i, exp := range []time.Duration{ + 10 * time.Millisecond, + 10 * time.Millisecond, + 16 * time.Millisecond, + 32 * time.Millisecond, + 64 * time.Millisecond, + 128 * time.Millisecond, + 256 * time.Millisecond, + 300 * time.Millisecond, + 300 * time.Millisecond, + } { + if got := retry.NextBackoff(); exp != got { + t.Fatalf("expected %d to be %v, got %v", i, exp, got) + } + } +} + +func TestLimitRetry(t *testing.T) { + retry := LimitRetry(LinearBackoff(time.Second), 2) + for i, exp := range []time.Duration{ + time.Second, + time.Second, + 0, + } { + if got := retry.NextBackoff(); exp != got { + t.Fatalf("expected %d to be %v, got %v", i, exp, got) + } + } +}