diff --git a/middleware/concurrent/dequeuer.go b/middleware/concurrent/dequeuer.go index f04f590..bdd31a9 100644 --- a/middleware/concurrent/dequeuer.go +++ b/middleware/concurrent/dequeuer.go @@ -1,9 +1,12 @@ package concurrent import ( + "fmt" + "github.com/go-redis/redis/v7" "github.com/google/uuid" "github.com/taylorchu/work" + "github.com/taylorchu/work/redislock" ) // DequeuerOptions defines how many jobs in the same queue can be running at the same time. @@ -17,64 +20,29 @@ type DequeuerOptions struct { // Dequeuer limits running job count from a queue. func Dequeuer(copt *DequeuerOptions) work.DequeueMiddleware { - lockScript := redis.NewScript(` - local ns = ARGV[1] - local queue_id = ARGV[2] - local at = tonumber(ARGV[3]) - local invis_sec = tonumber(ARGV[4]) - local worker_id = ARGV[5] - local max = tonumber(ARGV[6]) - local lock_key = table.concat({ns, "lock", queue_id}, ":") - - -- refresh expiry - redis.call("expire", lock_key, invis_sec) - - -- remove stale entries - redis.call("zremrangebyscore", lock_key, "-inf", at) - - if redis.call("zcard", lock_key) < max then - return redis.call("zadd", lock_key, "nx", at + invis_sec, worker_id) - end - return 0 - `) - unlockScript := redis.NewScript(` - local ns = ARGV[1] - local queue_id = ARGV[2] - local worker_id = ARGV[3] - local lock_key = table.concat({ns, "lock", queue_id}, ":") - - return redis.call("zrem", lock_key, worker_id) - `) return func(f work.DequeueFunc) work.DequeueFunc { workerID := copt.workerID if workerID == "" { workerID = uuid.New().String() } return func(opt *work.DequeueOptions) (*work.Job, error) { - err := opt.Validate() - if err != nil { - return nil, err + lock := &redislock.Lock{ + Client: copt.Client, + Key: fmt.Sprintf("%s:lock:%s", opt.Namespace, opt.QueueID), + ID: workerID, + At: opt.At, + ExpireInSec: opt.InvisibleSec, + MaxAcquirers: copt.Max, } - acquired, err := lockScript.Run(copt.Client, nil, - opt.Namespace, - opt.QueueID, - opt.At.Unix(), - opt.InvisibleSec, - workerID, - copt.Max, - ).Int64() + acquired, err := lock.Acquire() if err != nil { return nil, err } - if acquired == 0 { + if !acquired { return nil, work.ErrEmptyQueue } if !copt.disableUnlock { - defer unlockScript.Run(copt.Client, nil, - opt.Namespace, - opt.QueueID, - workerID, - ) + defer lock.Release() } return f(opt) } diff --git a/redislock/lock.go b/redislock/lock.go new file mode 100644 index 0000000..bd83d95 --- /dev/null +++ b/redislock/lock.go @@ -0,0 +1,73 @@ +package redislock + +import ( + "time" + + "github.com/go-redis/redis/v7" +) + +// Lock supports expiring lock with multiple acquirers. +type Lock struct { + Client redis.UniversalClient + + Key string + ID string + + At time.Time + ExpireInSec int64 + MaxAcquirers int64 +} + +// Acquire creates the lock if possible. +// If it is acquired, true is returned. +// Call Release to unlock. +func (l *Lock) Acquire() (bool, error) { + lockScript := redis.NewScript(` + local lock_key = ARGV[1] + local lock_id = ARGV[2] + local at = tonumber(ARGV[3]) + local expire_in_sec = tonumber(ARGV[4]) + local max_acquirers = tonumber(ARGV[5]) + + -- refresh expiry + redis.call("expire", lock_key, expire_in_sec) + + -- remove stale entries + redis.call("zremrangebyscore", lock_key, "-inf", at) + + if redis.call("zcard", lock_key) < max_acquirers then + return redis.call("zadd", lock_key, "nx", at + expire_in_sec, lock_id) + end + return 0 + `) + + acquired, err := lockScript.Run(l.Client, nil, + l.Key, + l.ID, + l.At.Unix(), + l.ExpireInSec, + l.MaxAcquirers, + ).Int64() + if err != nil { + return false, err + } + return acquired > 0, nil +} + +// Release clears the lock. +func (l *Lock) Release() error { + unlockScript := redis.NewScript(` + local lock_key = ARGV[1] + local lock_id = ARGV[2] + + return redis.call("zrem", lock_key, lock_id) + `) + err := unlockScript.Run(l.Client, nil, + l.Key, + l.ID, + ).Err() + if err != nil { + return err + } + return nil +}