Skip to content

Commit

Permalink
Merge pull request #171 from tobymao/toby/fix_solo_cron
Browse files Browse the repository at this point in the history
fix: concurrency 1 needs to queue cron jobs
  • Loading branch information
tobymao authored Oct 5, 2024
2 parents b0f7a61 + 99d23c5 commit f38422b
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 57 deletions.
117 changes: 61 additions & 56 deletions saq/queue/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

CHANNEL = "saq:{}"
ENQUEUE = "saq:enqueue"
DEQUEUE = "saq:dequeue"
JOBS_TABLE = "saq_jobs"
STATS_TABLE = "saq_stats"

Expand Down Expand Up @@ -126,14 +127,11 @@ async def init_db(self) -> None:
)

async def connect(self) -> None:
if self._dequeue_conn:
# If connection exists, connect() was already called
if self.pool._opened:
return

await self.pool.open()
await self.pool.resize(min_size=self.min_size, max_size=self.max_size)
# Reserve a connection for dequeue and advisory locks
self._dequeue_conn = await self.pool.getconn()
await self.init_db()

def serialize(self, job: Job) -> bytes | str:
Expand Down Expand Up @@ -531,8 +529,9 @@ async def dequeue(self, timeout: float = 0) -> Job | None:
)
else:
async with self._listen_lock:
async for _ in self._listener.listen(ENQUEUE, timeout=timeout):
await self._dequeue()
async for payload in self._listener.listen(ENQUEUE, DEQUEUE, timeout=timeout):
if payload["key"] == ENQUEUE:
await self._dequeue()

if not self._job_queue.empty():
job = self._job_queue.get_nowait()
Expand All @@ -547,6 +546,53 @@ async def dequeue(self, timeout: float = 0) -> Job | None:

return job

async def _dequeue(self) -> None:
if self._dequeue_lock.locked():
return

async with self._dequeue_lock:
async with self._get_dequeue_conn() as conn, conn.cursor() as cursor, conn.transaction():
if not self._waiting:
return
await cursor.execute(
SQL(
dedent(
"""
WITH locked_job AS (
SELECT key, lock_key
FROM {jobs_table}
WHERE status = 'queued'
AND queue = %(queue)s
AND %(now)s >= scheduled
ORDER BY scheduled
LIMIT %(limit)s
FOR UPDATE SKIP LOCKED
)
UPDATE {jobs_table} SET status = 'active'
FROM locked_job
WHERE {jobs_table}.key = locked_job.key
AND pg_try_advisory_lock({job_lock_keyspace}, locked_job.lock_key)
RETURNING job
"""
)
).format(
jobs_table=self.jobs_table,
job_lock_keyspace=self.job_lock_keyspace,
),
{
"queue": self.name,
"now": math.ceil(seconds(now())),
"limit": self._waiting,
},
)
results = await cursor.fetchall()

for result in results:
self._job_queue.put_nowait(self.deserialize(result[0]))

if results:
await self._notify(DEQUEUE)

async def _enqueue(self, job: Job) -> Job | None:
async with self.pool.connection() as conn, conn.cursor() as cursor:
await cursor.execute(
Expand Down Expand Up @@ -676,49 +722,6 @@ async def _finish(
await self.notify(job, conn)
await self._release_job(key)

async def _dequeue(self) -> None:
if self._dequeue_lock.locked():
return

async with self._dequeue_lock:
async with self._get_dequeue_conn() as conn, conn.cursor() as cursor, conn.transaction():
if not self._waiting:
return
await cursor.execute(
SQL(
dedent(
"""
WITH locked_job AS (
SELECT key, lock_key
FROM {jobs_table}
WHERE status = 'queued'
AND queue = %(queue)s
AND %(now)s >= scheduled
ORDER BY scheduled
LIMIT %(limit)s
FOR UPDATE SKIP LOCKED
)
UPDATE {jobs_table} SET status = 'active'
FROM locked_job
WHERE {jobs_table}.key = locked_job.key
AND pg_try_advisory_lock({job_lock_keyspace}, locked_job.lock_key)
RETURNING job
"""
)
).format(
jobs_table=self.jobs_table,
job_lock_keyspace=self.job_lock_keyspace,
),
{
"queue": self.name,
"now": math.ceil(seconds(now())),
"limit": self._waiting,
},
)
results = await cursor.fetchall()
for result in results:
self._job_queue.put_nowait(self.deserialize(result[0]))

async def _notify(
self, key: str, data: t.Any | None = None, connection: AsyncConnection | None = None
) -> None:
Expand All @@ -736,14 +739,16 @@ async def _notify(

@asynccontextmanager
async def _get_dequeue_conn(self) -> t.AsyncGenerator:
assert self._dequeue_conn
async with self._connection_lock:
try:
# Pool normally performs this check when getting a connection.
await self.pool.check_connection(self._dequeue_conn)
except OperationalError:
# The connection is bad so return it to the pool and get a new one.
await self.pool.putconn(self._dequeue_conn)
if self._dequeue_conn:
try:
# Pool normally performs this check when getting a connection.
await self.pool.check_connection(self._dequeue_conn)
except OperationalError:
# The connection is bad so return it to the pool and get a new one.
await self.pool.putconn(self._dequeue_conn)
self._dequeue_conn = await self.pool.getconn()
else:
self._dequeue_conn = await self.pool.getconn()
yield self._dequeue_conn

Expand Down
5 changes: 4 additions & 1 deletion tests/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,10 @@ async def test_finish_ttl_negative(self, mock_time: MagicMock) -> None:

async def test_bad_connection(self) -> None:
job = await self.enqueue("test")
original_connection = self.queue._dequeue_conn

async with self.queue._get_dequeue_conn() as original_connection:
pass

await original_connection.close()
# Test dequeue still works
self.assertEqual((await self.dequeue()), job)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,22 @@ async def handler(_ctx: Context) -> None:
await asyncio.sleep(6)
self.assertEqual(state["counter"], 0)

async def test_cron_solo_worker(self) -> None:
state = {"counter": 0}

async def handler(_ctx: Context) -> None:
state["counter"] += 1

self.worker = Worker(
self.queue,
functions=[],
cron_jobs=[CronJob(handler, cron="* * * * * */1")],
concurrency=1,
)
asyncio.create_task(self.worker.start())
await asyncio.sleep(2)
self.assertGreater(state["counter"], 0)


class TestWorkerRedisQueue(TestWorker):
async def asyncSetUp(self) -> None:
Expand Down

0 comments on commit f38422b

Please sign in to comment.