From 4b2d950f601cd242cf1f54836d97d18971570f5e Mon Sep 17 00:00:00 2001 From: Jonathon Vandezande Date: Tue, 4 Oct 2022 08:41:50 -0400 Subject: [PATCH] Exits working directory if wrapped function throws an exception. (#175) * Exits working directory if wrapped function throws an exception * Adds jevandezande to README --- README.md | 1 + autode/utils.py | 37 +++++++++++++++++++++---------------- tests/test_utils.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index f5a817bf6..e28e06ec8 100644 --- a/README.md +++ b/README.md @@ -92,3 +92,4 @@ If **autodE** is used in a publication please consider citing the [paper](https: - Kjell Jorner ([@kjelljorner](https://github.com/kjelljorner)) - Thibault Lestang ([@tlestang](https://github.com/tlestang)) - Domen Pregeljc ([@dpregeljc](https://github.com/dpregeljc)) +- Jonathon Vandezande ([@jevandezande](https://github.com/jevandezande)) diff --git a/autode/utils.py b/autode/utils.py index 42b41514d..99521acdb 100644 --- a/autode/utils.py +++ b/autode/utils.py @@ -143,13 +143,15 @@ def wrapped_function(*args, **kwargs): os.mkdir(dir_path) os.chdir(dir_path) - result = func(*args, **kwargs) - os.chdir(here) + try: + result = func(*args, **kwargs) + finally: + os.chdir(here) - if len(os.listdir(dir_path)) == 0: - logger.warning(f'Worked in {dir_path} but made no files ' - f'- deleting') - os.rmdir(dir_path) + if len(os.listdir(dir_path)) == 0: + logger.warning(f'Worked in {dir_path} but made no files ' + f'- deleting') + os.rmdir(dir_path) return result @@ -204,20 +206,23 @@ def wrapped_function(*args, **kwargs): # Move directories and execute os.chdir(tmpdir_path) - logger.info('Function ...running') - result = func(*args, **kwargs) - logger.info(' ...done') + try: + logger.info('Function ...running') + result = func(*args, **kwargs) + logger.info(' ...done') + + for filename in os.listdir(tmpdir_path): - for filename in os.listdir(tmpdir_path): + if any([filename.endswith(ext) for ext in kept_file_exts]): + logger.info(f'Copying back {filename}') + shutil.copy(filename, here) - if any([filename.endswith(ext) for ext in kept_file_exts]): - logger.info(f'Copying back {filename}') - shutil.copy(filename, here) + finally: + os.chdir(here) - os.chdir(here) + logger.info('Removing temporary directory') + shutil.rmtree(tmpdir_path) - logger.info('Removing temporary directory') - shutil.rmtree(tmpdir_path) return result return wrapped_function diff --git a/tests/test_utils.py b/tests/test_utils.py index 6745f090c..aaa027970 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -29,6 +29,20 @@ def make_test_files(): os.rmdir('test') +def test_reset_dir_on_error(): + @utils.work_in("tmp_path") + def raise_error(): + assert 0 + + here = os.getcwd() + try: + raise_error() + except AssertionError: + pass + + assert here == os.getcwd() + + def test_monitored_external(): echo = ['echo', 'test'] @@ -77,6 +91,20 @@ def test(): os.remove('test.txt') +def test_reset_tmp_dir_on_error(): + @utils.work_in_tmp_dir() + def raise_error(): + assert 0 + + here = os.getcwd() + try: + raise_error() + except AssertionError: + pass + + assert here == os.getcwd() + + @work_in_tmp_dir(filenames_to_copy=[], kept_file_exts=[]) def test_calc_output():