diff --git a/zk/lock.go b/zk/lock.go index 3c35a427..8b1692b9 100644 --- a/zk/lock.go +++ b/zk/lock.go @@ -5,6 +5,7 @@ import ( "fmt" "strconv" "strings" + "time" ) var ( @@ -12,6 +13,8 @@ var ( ErrDeadlock = errors.New("zk: trying to acquire a lock twice") // ErrNotLocked is returned by Unlock when trying to release a lock that has not first be acquired. ErrNotLocked = errors.New("zk: not locked") + // ErrTimeout is returned by Lock when trying to lock and timeout is reached before lock is acquired. + ErrTimeout = errors.New("zk: acquire timeout") ) // Lock is a mutual exclusion lock. @@ -20,6 +23,7 @@ type Lock struct { path string acl []ACL lockPath string + locked bool seq int } @@ -39,6 +43,29 @@ func parseSeq(path string) (int, error) { return strconv.Atoi(parts[len(parts)-1]) } +// Lock attempts to acquire the lock. It will wait up to its timeout duration +// to return until the lock is acquired or an error occurs. +// If this instance already has the lock then ErrDeadlock is returned. If timeout +// reached return ErrTimeout is returned. +func (l *Lock) TryLock(timeout time.Duration) error { + var err error + var done = make(chan struct{}, 1) + go func() { + err = l.Lock() + close(done) + }() + select { + case <-time.After(timeout): + l.locked = false + if err := l.c.Delete(l.lockPath, -1); err != nil { + return err + } + return ErrTimeout + case <-done: + return err + } +} + // Lock attempts to acquire the lock. It will wait to return until the lock // is acquired or an error occurs. If this instance already has the lock // then ErrDeadlock is returned. @@ -87,6 +114,9 @@ func (l *Lock) Lock() error { return err } + l.seq = seq + l.lockPath = path + for { children, _, err := l.c.Children(l.path) if err != nil { @@ -130,15 +160,14 @@ func (l *Lock) Lock() error { } } - l.seq = seq - l.lockPath = path + l.locked = true return nil } // Unlock releases an acquired lock. If the lock is not currently acquired by // this Lock instance than ErrNotLocked is returned. func (l *Lock) Unlock() error { - if l.lockPath == "" { + if !l.locked || l.lockPath == "" { return ErrNotLocked } if err := l.c.Delete(l.lockPath, -1); err != nil { @@ -146,5 +175,6 @@ func (l *Lock) Unlock() error { } l.lockPath = "" l.seq = 0 + l.locked = false return nil } diff --git a/zk/lock_test.go b/zk/lock_test.go index 8a3478a3..11f28290 100644 --- a/zk/lock_test.go +++ b/zk/lock_test.go @@ -5,6 +5,36 @@ import ( "time" ) +func TestTryLock(t *testing.T) { + ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "}) + if err != nil { + t.Fatal(err) + } + defer ts.Stop() + zk, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("Connect returned error: %+v", err) + } + defer zk.Close() + acls := WorldACL(PermAll) + l := NewLock(zk, "/test", acls) + if err := l.TryLock(time.Second); err != nil { + t.Fatal(err) + } + if err := l.Unlock(); err != nil { + t.Fatal(err) + } + // + if err := l.Lock(); err != nil { + t.Fatal(err) + } + defer l.Unlock() + // should return timeout err since lock is not released + if err := l.TryLock(time.Second); err == nil { + t.Fatal(err) + } +} + func TestLock(t *testing.T) { ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "}) if err != nil {