diff --git a/src/filelock/_api.py b/src/filelock/_api.py index 50f1c3d..958369a 100644 --- a/src/filelock/_api.py +++ b/src/filelock/_api.py @@ -6,7 +6,7 @@ import time import warnings from abc import ABC, abstractmethod -from threading import Lock +from threading import local from types import TracebackType from typing import Any @@ -36,7 +36,7 @@ def __exit__( self.lock.release() -class BaseFileLock(ABC, contextlib.ContextDecorator): +class BaseFileLock(ABC, contextlib.ContextDecorator, local): """Abstract base class for a file lock object.""" def __init__( @@ -67,9 +67,6 @@ def __init__( # The mode for the lock files self._mode: int = mode - # We use this lock primarily for the lock counter. - self._thread_lock: Lock = Lock() - # The lock counter is used for implementing the nested locking mechanism. Whenever the lock is acquired, the # counter is increased and the lock is only released, when this value is 0 again. self._lock_counter: int = 0 @@ -168,18 +165,16 @@ def acquire( poll_interval = poll_intervall # Increment the number right at the beginning. We can still undo it, if something fails. - with self._thread_lock: - self._lock_counter += 1 + self._lock_counter += 1 lock_id = id(self) lock_filename = self._lock_file start_time = time.perf_counter() try: while True: - with self._thread_lock: - if not self.is_locked: - _LOGGER.debug("Attempting to acquire lock %s on %s", lock_id, lock_filename) - self._acquire() + if not self.is_locked: + _LOGGER.debug("Attempting to acquire lock %s on %s", lock_id, lock_filename) + self._acquire() if self.is_locked: _LOGGER.debug("Lock %s acquired on %s", lock_id, lock_filename) break @@ -194,8 +189,7 @@ def acquire( _LOGGER.debug(msg, lock_id, lock_filename, poll_interval) time.sleep(poll_interval) except BaseException: # Something did go wrong, so decrement the counter. - with self._thread_lock: - self._lock_counter = max(0, self._lock_counter - 1) + self._lock_counter = max(0, self._lock_counter - 1) raise return AcquireReturnProxy(lock=self) @@ -206,17 +200,16 @@ def release(self, force: bool = False) -> None: :param force: If true, the lock counter is ignored and the lock is released in every case/ """ - with self._thread_lock: - if self.is_locked: - self._lock_counter -= 1 + if self.is_locked: + self._lock_counter -= 1 - if self._lock_counter == 0 or force: - lock_id, lock_filename = id(self), self._lock_file + if self._lock_counter == 0 or force: + lock_id, lock_filename = id(self), self._lock_file - _LOGGER.debug("Attempting to release lock %s on %s", lock_id, lock_filename) - self._release() - self._lock_counter = 0 - _LOGGER.debug("Lock %s released on %s", lock_id, lock_filename) + _LOGGER.debug("Attempting to release lock %s on %s", lock_id, lock_filename) + self._release() + self._lock_counter = 0 + _LOGGER.debug("Lock %s released on %s", lock_id, lock_filename) def __enter__(self) -> BaseFileLock: """ diff --git a/tests/test_filelock.py b/tests/test_filelock.py index 406a62d..f261b7c 100644 --- a/tests/test_filelock.py +++ b/tests/test_filelock.py @@ -5,6 +5,7 @@ import os import sys import threading +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from errno import ENOSYS from inspect import getframeinfo, stack @@ -12,6 +13,7 @@ from stat import S_IWGRP, S_IWOTH, S_IWUSR, filemode from types import TracebackType from typing import Callable, Iterator, Tuple, Type, Union +from uuid import uuid4 import pytest from _pytest.logging import LogCaptureFixture @@ -81,6 +83,10 @@ def tmp_path_ro(tmp_path: Path) -> Iterator[Path]: @pytest.mark.parametrize("lock_type", [FileLock, SoftFileLock]) @pytest.mark.skipif(sys.platform == "win32", reason="Windows does not have read only folders") +@pytest.mark.skipif( + sys.platform != "win32" and os.geteuid() == 0, # noqa: SC200 + reason="Cannot make a read only file (that the current user: root can't read)", +) def test_ro_folder(lock_type: type[BaseFileLock], tmp_path_ro: Path) -> None: lock = lock_type(str(tmp_path_ro / "a")) with pytest.raises(PermissionError, match="Permission denied"): @@ -96,6 +102,10 @@ def tmp_file_ro(tmp_path: Path) -> Iterator[Path]: @pytest.mark.parametrize("lock_type", [FileLock, SoftFileLock]) +@pytest.mark.skipif( + sys.platform != "win32" and os.geteuid() == 0, # noqa: SC200 + reason="Cannot make a read only file (that the current user: root can't read)", +) def test_ro_file(lock_type: type[BaseFileLock], tmp_file_ro: Path) -> None: lock = lock_type(str(tmp_file_ro)) with pytest.raises(PermissionError, match="Permission denied"): @@ -509,3 +519,60 @@ def test_soft_errors(tmp_path: Path, mocker: MockerFixture) -> None: mocker.patch("os.open", side_effect=OSError(ENOSYS, "mock error")) with pytest.raises(OSError, match="mock error"): SoftFileLock(tmp_path / "a.lock").acquire() + + +def _check_file_read_write(txt_file: Path) -> None: + for _ in range(3): + uuid = str(uuid4()) + txt_file.write_text(uuid) + assert txt_file.read_text() == uuid + + +@pytest.mark.parametrize("lock_type", [FileLock, SoftFileLock]) +def test_thrashing_with_thread_pool_passing_lock_to_threads(tmp_path: Path, lock_type: type[BaseFileLock]) -> None: + def mess_with_file(lock_: BaseFileLock) -> None: + with lock_: + _check_file_read_write(txt_file) + + lock_file, txt_file = tmp_path / "test.txt.lock", tmp_path / "test.txt" + lock = lock_type(lock_file) + results = [] + with ThreadPoolExecutor() as executor: + for _ in range(100): + results.append(executor.submit(mess_with_file, lock)) + + assert all(r.result() is None for r in results) + + +@pytest.mark.parametrize("lock_type", [FileLock, SoftFileLock]) +def test_thrashing_with_thread_pool_global_lock(tmp_path: Path, lock_type: type[BaseFileLock]) -> None: + def mess_with_file() -> None: + with lock: + _check_file_read_write(txt_file) + + lock_file, txt_file = tmp_path / "test.txt.lock", tmp_path / "test.txt" + lock = lock_type(lock_file) + results = [] + with ThreadPoolExecutor() as executor: + for _ in range(100): + results.append(executor.submit(mess_with_file)) + + assert all(r.result() is None for r in results) + + +@pytest.mark.parametrize("lock_type", [FileLock, SoftFileLock]) +def test_thrashing_with_thread_pool_lock_recreated_in_each_thread( + tmp_path: Path, + lock_type: type[BaseFileLock], +) -> None: + def mess_with_file() -> None: + with lock_type(lock_file): + _check_file_read_write(txt_file) + + lock_file, txt_file = tmp_path / "test.txt.lock", tmp_path / "test.txt" + results = [] + with ThreadPoolExecutor() as executor: + for _ in range(100): + results.append(executor.submit(mess_with_file)) + + assert all(r.result() is None for r in results)