From c028e387e8fc47be55c18f4df4a10b29ab85a527 Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Fri, 20 May 2022 19:12:47 -0500
Subject: [PATCH] Increase lock scope in ModuleCache.refresh

---
 aesara/link/c/cmodule.py     | 458 ++++++++++++++++++-----------------
 tests/link/c/test_cmodule.py |  47 ++++
 2 files changed, 279 insertions(+), 226 deletions(-)

diff --git a/aesara/link/c/cmodule.py b/aesara/link/c/cmodule.py
index c6024d3a3f..ecae473f9f 100644
--- a/aesara/link/c/cmodule.py
+++ b/aesara/link/c/cmodule.py
@@ -785,263 +785,269 @@ def rmtree_empty(*args, **kwargs):
         except OSError:
             # This can happen if the dir don't exist.
             subdirs = []
-        files, root = None, None  # To make sure the "del" below works
-        for subdirs_elem in subdirs:
-            # Never clean/remove lock_dir
-            if subdirs_elem == "lock_dir":
-                continue
-            root = os.path.join(self.dirname, subdirs_elem)
-            key_pkl = os.path.join(root, "key.pkl")
-            if key_pkl in self.loaded_key_pkl:
-                continue
-            if not os.path.isdir(root):
-                continue
-            files = os.listdir(root)
-            if not files:
-                rmtree_empty(root, ignore_nocleanup=True, msg="empty dir")
-                continue
-            if "delete.me" in files:
-                rmtree(root, ignore_nocleanup=True, msg="delete.me found in dir")
-                continue
-            elif "key.pkl" in files:
-                try:
-                    entry = module_name_from_dir(root, files=files)
-                except ValueError:  # there is a key but no dll!
-                    if not root.startswith("/tmp"):
-                        # Under /tmp, file are removed periodically by the
-                        # os. So it is normal that this happens from time
-                        # to time.
-                        _logger.warning(
-                            "ModuleCache.refresh() Found key "
-                            f"without dll in cache, deleting it. {key_pkl}",
-                        )
-                    rmtree(
-                        root,
-                        ignore_nocleanup=True,
-                        msg="missing module file",
-                        level=logging.INFO,
-                    )
-                    continue
-                if (time_now - last_access_time(entry)) < age_thresh_use:
-                    _logger.debug(f"refresh adding {key_pkl}")
-
-                    def unpickle_failure():
-                        _logger.info(
-                            f"ModuleCache.refresh() Failed to unpickle cache file {key_pkl}",
-                        )
-
-                    try:
-                        with open(key_pkl, "rb") as f:
-                            key_data = pickle.load(f)
-                    except EOFError:
-                        # Happened once... not sure why (would be worth
-                        # investigating if it ever happens again).
-                        unpickle_failure()
-                        rmtree(
-                            root,
-                            ignore_nocleanup=True,
-                            msg="broken cache directory [EOF]",
-                            level=logging.WARNING,
-                        )
-                        continue
-                    except Exception:
-                        unpickle_failure()
-                        if delete_if_problem:
-                            rmtree(
-                                root,
-                                ignore_nocleanup=True,
-                                msg="broken cache directory",
-                                level=logging.INFO,
-                            )
-                        else:
-                            # This exception is often triggered by keys
-                            # that contain references to classes that have
-                            # not yet been imported (e.g. when running two
-                            # different Aesara-based scripts). They are not
-                            # necessarily broken, but we cannot load them
-                            # now. They will be loaded later if needed.
-                            pass
-                        continue
-
-                    if not isinstance(key_data, KeyData):
-                        # This is some old cache data, that does not fit
-                        # the new cache format. It would be possible to
-                        # update it, but it is not entirely safe since we
-                        # do not know the config options that were used.
-                        # As a result, we delete it instead (which is also
-                        # simpler to implement).
-                        rmtree(
-                            root,
-                            ignore_nocleanup=True,
-                            msg=(
-                                "invalid cache entry format -- this "
-                                "should not happen unless your cache "
-                                "was really old"
-                            ),
-                            level=logging.WARN,
-                        )
-                        continue
-
-                    # Check the path to the module stored in the KeyData
-                    # object matches the path to `entry`. There may be
-                    # a mismatch e.g. due to symlinks, or some directory
-                    # being renamed since last time cache was created.
-                    kd_entry = key_data.get_entry()
-                    if kd_entry != entry:
-                        if is_same_entry(entry, kd_entry):
-                            # Update KeyData object. Note that we also need
-                            # to update the key_pkl field, because it is
-                            # likely to be incorrect if the entry itself
-                            # was wrong.
-                            key_data.entry = entry
-                            key_data.key_pkl = key_pkl
-                        else:
-                            # This is suspicious. Better get rid of it.
-                            rmtree(
-                                root,
-                                ignore_nocleanup=True,
-                                msg="module file path mismatch",
-                                level=logging.INFO,
-                            )
-                            continue
-
-                    # Find unversioned keys from other processes.
-                    # TODO: check if this can happen at all
-                    to_del = [key for key in key_data.keys if not key[0]]
-                    if to_del:
-                        _logger.warning(
-                            "ModuleCache.refresh() Found unversioned "
-                            f"key in cache, removing it. {key_pkl}",
-                        )
-                        # Since the version is in the module hash, all
-                        # keys should be unversioned.
-                        if len(to_del) != len(key_data.keys):
-                            _logger.warning(
-                                "Found a mix of unversioned and "
-                                "versioned keys for the same "
-                                f"module {key_pkl}",
-                            )
-                        rmtree(
-                            root,
-                            ignore_nocleanup=True,
-                            msg="unversioned key(s) in cache",
-                            level=logging.INFO,
-                        )
-                        continue
-
-                    mod_hash = key_data.module_hash
-                    if mod_hash in self.module_hash_to_key_data:
-                        # This may happen when two processes running
-                        # simultaneously compiled the same module, one
-                        # after the other. We delete one once it is old
-                        # enough (to be confident there is no other process
-                        # using it), or if `delete_if_problem` is True.
-                        # Note that it is important to walk through
-                        # directories in alphabetical order so as to make
-                        # sure all new processes only use the first one.
-                        if cleanup:
-                            age = time.time() - last_access_time(entry)
-                            if delete_if_problem or age > self.age_thresh_del:
-                                rmtree(
-                                    root,
-                                    ignore_nocleanup=True,
-                                    msg="duplicated module",
-                                    level=logging.DEBUG,
-                                )
-                            else:
-                                _logger.debug(
-                                    "Found duplicated module not "
-                                    "old enough yet to be deleted "
-                                    f"(age: {age}): {entry}",
-                                )
-                        continue
-
-                    # Remember the map from a module's hash to the KeyData
-                    # object associated with it.
-                    self.module_hash_to_key_data[mod_hash] = key_data
-
-                    for key in key_data.keys:
-                        if key not in self.entry_from_key:
-                            self.entry_from_key[key] = entry
-                            # Assert that we have not already got this
-                            # entry somehow.
-                            assert entry not in self.module_from_name
-                            # Store safe part of versioned keys.
-                            if key[0]:
-                                self.similar_keys.setdefault(
-                                    get_safe_part(key), []
-                                ).append(key)
-                        else:
-                            dir1 = os.path.dirname(self.entry_from_key[key])
-                            dir2 = os.path.dirname(entry)
-                            _logger.warning(
-                                "The same cache key is associated to "
-                                f"different modules ({dir1} and {dir2}). This "
-                                "is not supposed to happen! You may "
-                                "need to manually delete your cache "
-                                "directory to fix this.",
-                            )
-                    # Clean up the name space to prevent bug.
-                    if key_data.keys:
-                        del key
-                    self.loaded_key_pkl.add(key_pkl)
-                else:
-                    too_old_to_use.append(entry)
-
-            # If the compilation failed, no key.pkl is in that
-            # directory, but a mod.* should be there.
-            # We do nothing here.
-
-        # Clean up the name space to prevent bug.
-        del root, files, subdirs
-
-        # Remove entries that are not in the filesystem.
-        items_copy = list(self.module_hash_to_key_data.items())
-        for module_hash, key_data in items_copy:
-            entry = key_data.get_entry()
-            try:
-                # Test to see that the file is [present and] readable.
-                open(entry).close()
-                gone = False
-            except OSError:
-                gone = True
-
-            if gone:
-                # Assert that we did not have one of the deleted files
-                # loaded up and in use.
-                # If so, it should not have been deleted. This should be
-                # considered a failure of the OTHER process, that deleted
-                # it.
-                if entry in self.module_from_name:
-                    _logger.warning(
-                        "A module that was loaded by this "
-                        "ModuleCache can no longer be read from file "
-                        f"{entry}... this could lead to problems.",
-                    )
-                    del self.module_from_name[entry]
-
-                _logger.info(f"deleting ModuleCache entry {entry}")
-                key_data.delete_keys_from(self.entry_from_key)
-                del self.module_hash_to_key_data[module_hash]
-                if key_data.keys and list(key_data.keys)[0][0]:
-                    # this is a versioned entry, so should have been on
-                    # disk. Something weird happened to cause this, so we
-                    # are responding by printing a warning, removing
-                    # evidence that we ever saw this mystery key.
-                    pkl_file_to_remove = key_data.key_pkl
-                    if not key_data.key_pkl.startswith("/tmp"):
-                        # Under /tmp, file are removed periodically by the
-                        # os. So it is normal that this happen from time to
-                        # time.
-                        _logger.warning(
-                            f"Removing key file {pkl_file_to_remove} because the "
-                            "corresponding module is gone from the "
-                            "file system."
-                        )
-                    self.loaded_key_pkl.remove(pkl_file_to_remove)
-
-        if to_delete or to_delete_empty:
-            with lock_ctx():
+        files, root = None, None  # To make sure the "del" below works
+        # Collections used by external (and potentially asynchronous)
+        # compilation processes are modified in the following loop, so we need
+        # to lock on the compilation directory so that those processes don't
+        # work with stale/invalid data
+        with lock_ctx():
+            for subdirs_elem in subdirs:
+                # Never clean/remove lock_dir
+                if subdirs_elem == "lock_dir":
+                    continue
+                root = os.path.join(self.dirname, subdirs_elem)
+                key_pkl = os.path.join(root, "key.pkl")
+                if key_pkl in self.loaded_key_pkl:
+                    continue
+                if not os.path.isdir(root):
+                    continue
+                files = os.listdir(root)
+                if not files:
+                    rmtree_empty(root, ignore_nocleanup=True, msg="empty dir")
+                    continue
+                if "delete.me" in files:
+                    rmtree(root, ignore_nocleanup=True, msg="delete.me found in dir")
+                    continue
+                elif "key.pkl" in files:
+                    try:
+                        entry = module_name_from_dir(root, files=files)
+                    except ValueError:  # there is a key but no dll!
+                        if not root.startswith("/tmp"):
+                            # Under /tmp, file are removed periodically by the
+                            # os. So it is normal that this happens from time
+                            # to time.
+                            _logger.warning(
+                                "ModuleCache.refresh() Found key "
+                                f"without dll in cache, deleting it. {key_pkl}",
+                            )
+                        rmtree(
+                            root,
+                            ignore_nocleanup=True,
+                            msg="missing module file",
+                            level=logging.INFO,
+                        )
+                        continue
+                    if (time_now - last_access_time(entry)) < age_thresh_use:
+                        _logger.debug(f"refresh adding {key_pkl}")
+
+                        def unpickle_failure():
+                            _logger.info(
+                                f"ModuleCache.refresh() Failed to unpickle cache file {key_pkl}",
+                            )
+
+                        try:
+                            with open(key_pkl, "rb") as f:
+                                key_data = pickle.load(f)
+                        except EOFError:
+                            # Happened once... not sure why (would be worth
+                            # investigating if it ever happens again).
+                            unpickle_failure()
+                            rmtree(
+                                root,
+                                ignore_nocleanup=True,
+                                msg="broken cache directory [EOF]",
+                                level=logging.WARNING,
+                            )
+                            continue
+                        except Exception:
+                            unpickle_failure()
+                            if delete_if_problem:
+                                rmtree(
+                                    root,
+                                    ignore_nocleanup=True,
+                                    msg="broken cache directory",
+                                    level=logging.INFO,
+                                )
+                            else:
+                                # This exception is often triggered by keys
+                                # that contain references to classes that have
+                                # not yet been imported (e.g. when running two
+                                # different Aesara-based scripts). They are not
+                                # necessarily broken, but we cannot load them
+                                # now. They will be loaded later if needed.
+                                pass
+                            continue
+
+                        if not isinstance(key_data, KeyData):
+                            # This is some old cache data, that does not fit
+                            # the new cache format. It would be possible to
+                            # update it, but it is not entirely safe since we
+                            # do not know the config options that were used.
+                            # As a result, we delete it instead (which is also
+                            # simpler to implement).
+                            rmtree(
+                                root,
+                                ignore_nocleanup=True,
+                                msg=(
+                                    "invalid cache entry format -- this "
+                                    "should not happen unless your cache "
+                                    "was really old"
+                                ),
+                                level=logging.WARN,
+                            )
+                            continue
+
+                        # Check the path to the module stored in the KeyData
+                        # object matches the path to `entry`. There may be
+                        # a mismatch e.g. due to symlinks, or some directory
+                        # being renamed since last time cache was created.
+                        kd_entry = key_data.get_entry()
+                        if kd_entry != entry:
+                            if is_same_entry(entry, kd_entry):
+                                # Update KeyData object. Note that we also need
+                                # to update the key_pkl field, because it is
+                                # likely to be incorrect if the entry itself
+                                # was wrong.
+                                key_data.entry = entry
+                                key_data.key_pkl = key_pkl
+                            else:
+                                # This is suspicious. Better get rid of it.
+                                rmtree(
+                                    root,
+                                    ignore_nocleanup=True,
+                                    msg="module file path mismatch",
+                                    level=logging.INFO,
+                                )
+                                continue
+
+                        # Find unversioned keys from other processes.
+                        # TODO: check if this can happen at all
+                        to_del = [key for key in key_data.keys if not key[0]]
+                        if to_del:
+                            _logger.warning(
+                                "ModuleCache.refresh() Found unversioned "
+                                f"key in cache, removing it. {key_pkl}",
+                            )
+                            # Since the version is in the module hash, all
+                            # keys should be unversioned.
+                            if len(to_del) != len(key_data.keys):
+                                _logger.warning(
+                                    "Found a mix of unversioned and "
+                                    "versioned keys for the same "
+                                    f"module {key_pkl}",
+                                )
+                            rmtree(
+                                root,
+                                ignore_nocleanup=True,
+                                msg="unversioned key(s) in cache",
+                                level=logging.INFO,
+                            )
+                            continue
+
+                        mod_hash = key_data.module_hash
+                        if mod_hash in self.module_hash_to_key_data:
+                            # This may happen when two processes running
+                            # simultaneously compiled the same module, one
+                            # after the other. We delete one once it is old
+                            # enough (to be confident there is no other process
+                            # using it), or if `delete_if_problem` is True.
+                            # Note that it is important to walk through
+                            # directories in alphabetical order so as to make
+                            # sure all new processes only use the first one.
+                            if cleanup:
+                                age = time.time() - last_access_time(entry)
+                                if delete_if_problem or age > self.age_thresh_del:
+                                    rmtree(
+                                        root,
+                                        ignore_nocleanup=True,
+                                        msg="duplicated module",
+                                        level=logging.DEBUG,
+                                    )
+                                else:
+                                    _logger.debug(
+                                        "Found duplicated module not "
+                                        "old enough yet to be deleted "
+                                        f"(age: {age}): {entry}",
+                                    )
+                            continue
+
+                        # Remember the map from a module's hash to the KeyData
+                        # object associated with it.
+                        self.module_hash_to_key_data[mod_hash] = key_data
+
+                        for key in key_data.keys:
+                            if key not in self.entry_from_key:
+                                self.entry_from_key[key] = entry
+                                # Assert that we have not already got this
+                                # entry somehow.
+                                assert entry not in self.module_from_name
+                                # Store safe part of versioned keys.
+                                if key[0]:
+                                    self.similar_keys.setdefault(
+                                        get_safe_part(key), []
+                                    ).append(key)
+                            else:
+                                dir1 = os.path.dirname(self.entry_from_key[key])
+                                dir2 = os.path.dirname(entry)
+                                _logger.warning(
+                                    "The same cache key is associated to "
+                                    f"different modules ({dir1} and {dir2}). This "
+                                    "is not supposed to happen! You may "
+                                    "need to manually delete your cache "
+                                    "directory to fix this.",
+                                )
+                        # Clean up the name space to prevent bug.
+                        if key_data.keys:
+                            del key
+                        self.loaded_key_pkl.add(key_pkl)
+                    else:
+                        too_old_to_use.append(entry)
+
+                # If the compilation failed, no key.pkl is in that
+                # directory, but a mod.* should be there.
+                # We do nothing here.
+
+            # Clean up the name space to prevent bug.
+            del root, files, subdirs
+
+            # Remove entries that are not in the filesystem.
+            items_copy = list(self.module_hash_to_key_data.items())
+            for module_hash, key_data in items_copy:
+                entry = key_data.get_entry()
+                try:
+                    # Test to see that the file is [present and] readable.
+                    open(entry).close()
+                    gone = False
+                except OSError:
+                    gone = True
+
+                if gone:
+                    # Assert that we did not have one of the deleted files
+                    # loaded up and in use.
+                    # If so, it should not have been deleted. This should be
+                    # considered a failure of the OTHER process, that deleted
+                    # it.
+                    if entry in self.module_from_name:
+                        _logger.warning(
+                            "A module that was loaded by this "
+                            "ModuleCache can no longer be read from file "
+                            f"{entry}... this could lead to problems.",
+                        )
+                        del self.module_from_name[entry]
+
+                    _logger.info(f"deleting ModuleCache entry {entry}")
+                    key_data.delete_keys_from(self.entry_from_key)
+                    del self.module_hash_to_key_data[module_hash]
+                    if key_data.keys and list(key_data.keys)[0][0]:
+                        # this is a versioned entry, so should have been on
+                        # disk. Something weird happened to cause this, so we
+                        # are responding by printing a warning, removing
+                        # evidence that we ever saw this mystery key.
+                        pkl_file_to_remove = key_data.key_pkl
+                        if not key_data.key_pkl.startswith("/tmp"):
+                            # Under /tmp, file are removed periodically by the
+                            # os. So it is normal that this happen from time to
+                            # time.
+                            _logger.warning(
+                                f"Removing key file {pkl_file_to_remove} because the "
+                                "corresponding module is gone from the "
+                                "file system."
+                            )
+                        self.loaded_key_pkl.remove(pkl_file_to_remove)
+
+            if to_delete or to_delete_empty:
                 for a, kw in to_delete:
                     _rmtree(*a, **kw)
                 for a, kw in to_delete_empty:
@@ -1049,7 +1055,7 @@ def unpickle_failure():
                     if not files:
                         _rmtree(*a, **kw)
 
-        _logger.debug(f"Time needed to refresh cache: {time.time() - start_time}")
+            _logger.debug(f"Time needed to refresh cache: {time.time() - start_time}")
 
         return too_old_to_use
 
diff --git a/tests/link/c/test_cmodule.py b/tests/link/c/test_cmodule.py
index 027046d123..b1cf944c7a 100644
--- a/tests/link/c/test_cmodule.py
+++ b/tests/link/c/test_cmodule.py
@@ -5,6 +5,7 @@
 deterministic based on the input type and the op.
""" import logging +import multiprocessing import os import tempfile from unittest.mock import patch @@ -12,6 +13,8 @@ import numpy as np import pytest +import aesara +import aesara.tensor as at from aesara.compile.function import function from aesara.compile.ops import DeepCopyOp from aesara.configdefaults import config @@ -139,3 +142,47 @@ def test_linking_patch(listdir_mock, platform): "-lmkl_core", "-lmkl_rt", ] + + +def test_cache_race_condition(): + + with tempfile.TemporaryDirectory() as dir_name: + + @config.change_flags(on_opt_error="raise", on_shape_error="raise") + def f_build(factor): + # Some of the caching issues arise during constant folding within the + # optimization passes, so we need these config changes to prevent the + # exceptions from being caught + a = at.vector() + f = aesara.function([a], factor * a) + return f(np.array([1], dtype=config.floatX)) + + ctx = multiprocessing.get_context() + compiledir_prop = aesara.config._config_var_dict["compiledir"] + + # The module cache must (initially) be `None` for all processes so that + # `ModuleCache.refresh` is called + with patch.object(compiledir_prop, "val", dir_name, create=True), patch.object( + aesara.link.c.cmodule, "_module_cache", None + ): + + assert aesara.config.compiledir == dir_name + + num_procs = 30 + rng = np.random.default_rng(209) + + for i in range(10): + # A random, constant input to prevent caching between runs + factor = rng.random() + procs = [ + ctx.Process(target=f_build, args=(factor,)) + for i in range(num_procs) + ] + for proc in procs: + proc.start() + for proc in procs: + proc.join() + + assert not any( + exit_code != 0 for exit_code in [proc.exitcode for proc in procs] + )