Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the lock a thread local variable #219

Merged
merged 8 commits into from
Apr 6, 2023
37 changes: 15 additions & 22 deletions src/filelock/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
import warnings
from abc import ABC, abstractmethod
from threading import Lock
from threading import local
from types import TracebackType
from typing import Any

Expand Down Expand Up @@ -36,7 +36,7 @@ def __exit__(
self.lock.release()


class BaseFileLock(ABC, contextlib.ContextDecorator):
class BaseFileLock(ABC, contextlib.ContextDecorator, local):
"""Abstract base class for a file lock object."""

def __init__(
Expand Down Expand Up @@ -67,9 +67,6 @@ def __init__(
# The mode for the lock files
self._mode: int = mode

# We use this lock primarily for the lock counter.
self._thread_lock: Lock = Lock()

# The lock counter is used for implementing the nested locking mechanism. Whenever the lock is acquired, the
# counter is increased and the lock is only released, when this value is 0 again.
self._lock_counter: int = 0
Expand Down Expand Up @@ -168,18 +165,16 @@ def acquire(
poll_interval = poll_intervall

# Increment the number right at the beginning. We can still undo it, if something fails.
with self._thread_lock:
self._lock_counter += 1
self._lock_counter += 1

lock_id = id(self)
lock_filename = self._lock_file
start_time = time.perf_counter()
try:
while True:
with self._thread_lock:
if not self.is_locked:
_LOGGER.debug("Attempting to acquire lock %s on %s", lock_id, lock_filename)
self._acquire()
if not self.is_locked:
_LOGGER.debug("Attempting to acquire lock %s on %s", lock_id, lock_filename)
self._acquire()
if self.is_locked:
_LOGGER.debug("Lock %s acquired on %s", lock_id, lock_filename)
break
Expand All @@ -194,8 +189,7 @@ def acquire(
_LOGGER.debug(msg, lock_id, lock_filename, poll_interval)
time.sleep(poll_interval)
except BaseException: # Something did go wrong, so decrement the counter.
with self._thread_lock:
self._lock_counter = max(0, self._lock_counter - 1)
self._lock_counter = max(0, self._lock_counter - 1)
raise
return AcquireReturnProxy(lock=self)

Expand All @@ -206,17 +200,16 @@ def release(self, force: bool = False) -> None:

:param force: If true, the lock counter is ignored and the lock is released in every case/
"""
with self._thread_lock:
if self.is_locked:
self._lock_counter -= 1
if self.is_locked:
self._lock_counter -= 1

if self._lock_counter == 0 or force:
lock_id, lock_filename = id(self), self._lock_file
if self._lock_counter == 0 or force:
lock_id, lock_filename = id(self), self._lock_file

_LOGGER.debug("Attempting to release lock %s on %s", lock_id, lock_filename)
self._release()
self._lock_counter = 0
_LOGGER.debug("Lock %s released on %s", lock_id, lock_filename)
_LOGGER.debug("Attempting to release lock %s on %s", lock_id, lock_filename)
self._release()
self._lock_counter = 0
_LOGGER.debug("Lock %s released on %s", lock_id, lock_filename)

def __enter__(self) -> BaseFileLock:
"""
Expand Down
67 changes: 67 additions & 0 deletions tests/test_filelock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import os
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from errno import ENOSYS
from inspect import getframeinfo, stack
from pathlib import Path, PurePath
from stat import S_IWGRP, S_IWOTH, S_IWUSR, filemode
from types import TracebackType
from typing import Callable, Iterator, Tuple, Type, Union
from uuid import uuid4

import pytest
from _pytest.logging import LogCaptureFixture
Expand Down Expand Up @@ -81,6 +83,10 @@ def tmp_path_ro(tmp_path: Path) -> Iterator[Path]:

@pytest.mark.parametrize("lock_type", [FileLock, SoftFileLock])
@pytest.mark.skipif(sys.platform == "win32", reason="Windows does not have read only folders")
@pytest.mark.skipif(
sys.platform != "win32" and os.geteuid() == 0, # noqa: SC200
reason="Cannot make a read only file (that the current user: root can't read)",
)
def test_ro_folder(lock_type: type[BaseFileLock], tmp_path_ro: Path) -> None:
lock = lock_type(str(tmp_path_ro / "a"))
with pytest.raises(PermissionError, match="Permission denied"):
Expand All @@ -96,6 +102,10 @@ def tmp_file_ro(tmp_path: Path) -> Iterator[Path]:


@pytest.mark.parametrize("lock_type", [FileLock, SoftFileLock])
@pytest.mark.skipif(
sys.platform != "win32" and os.geteuid() == 0, # noqa: SC200
reason="Cannot make a read only file (that the current user: root can't read)",
)
def test_ro_file(lock_type: type[BaseFileLock], tmp_file_ro: Path) -> None:
lock = lock_type(str(tmp_file_ro))
with pytest.raises(PermissionError, match="Permission denied"):
Expand Down Expand Up @@ -509,3 +519,60 @@ def test_soft_errors(tmp_path: Path, mocker: MockerFixture) -> None:
mocker.patch("os.open", side_effect=OSError(ENOSYS, "mock error"))
with pytest.raises(OSError, match="mock error"):
SoftFileLock(tmp_path / "a.lock").acquire()


def _check_file_read_write(txt_file: Path) -> None:
for _ in range(3):
uuid = str(uuid4())
txt_file.write_text(uuid)
assert txt_file.read_text() == uuid


@pytest.mark.parametrize("lock_type", [FileLock, SoftFileLock])
def test_thrashing_with_thread_pool_passing_lock_to_threads(tmp_path: Path, lock_type: type[BaseFileLock]) -> None:
def mess_with_file(lock_: BaseFileLock) -> None:
with lock_:
_check_file_read_write(txt_file)

lock_file, txt_file = tmp_path / "test.txt.lock", tmp_path / "test.txt"
lock = lock_type(lock_file)
results = []
with ThreadPoolExecutor() as executor:
for _ in range(100):
results.append(executor.submit(mess_with_file, lock))

assert all(r.result() is None for r in results)


@pytest.mark.parametrize("lock_type", [FileLock, SoftFileLock])
def test_thrashing_with_thread_pool_global_lock(tmp_path: Path, lock_type: type[BaseFileLock]) -> None:
def mess_with_file() -> None:
with lock:
_check_file_read_write(txt_file)

lock_file, txt_file = tmp_path / "test.txt.lock", tmp_path / "test.txt"
lock = lock_type(lock_file)
results = []
with ThreadPoolExecutor() as executor:
for _ in range(100):
results.append(executor.submit(mess_with_file))

assert all(r.result() is None for r in results)


@pytest.mark.parametrize("lock_type", [FileLock, SoftFileLock])
def test_thrashing_with_thread_pool_lock_recreated_in_each_thread(
tmp_path: Path,
lock_type: type[BaseFileLock],
) -> None:
def mess_with_file() -> None:
with lock_type(lock_file):
_check_file_read_write(txt_file)

lock_file, txt_file = tmp_path / "test.txt.lock", tmp_path / "test.txt"
results = []
with ThreadPoolExecutor() as executor:
for _ in range(100):
results.append(executor.submit(mess_with_file))

assert all(r.result() is None for r in results)