diff --git a/webbpsf/tests/test_errorhandling.py b/webbpsf/tests/test_errorhandling.py index badd592d..3a3e8a74 100644 --- a/webbpsf/tests/test_errorhandling.py +++ b/webbpsf/tests/test_errorhandling.py @@ -4,6 +4,7 @@ import logging import os import os.path +from pathlib import Path import pytest @@ -48,7 +49,7 @@ def test_invalid_masks(): assert _exception_message_starts_with(excinfo, "Instrument NIRCam doesn't have a pupil mask called 'JUNK'.") -def test_get_webbpsf_data_path_invalid(monkeypatch): +def test_get_webbpsf_data_path_invalid(monkeypatch, tmp_path): real_env_webbpsf_path = os.getenv('WEBBPSF_PATH') real_conf_webbpsf_path = conf.WEBBPSF_PATH real_webbpsf_path = real_env_webbpsf_path or real_conf_webbpsf_path @@ -57,9 +58,19 @@ def test_get_webbpsf_data_path_invalid(monkeypatch): # config says to (and env var has been unset) monkeypatch.delenv('WEBBPSF_PATH') monkeypatch.setattr(conf, 'WEBBPSF_PATH', 'from_environment_variable') - with pytest.raises(EnvironmentError) as excinfo: + + # Patch the function that gets the home directory so we don't overwrite the + # what is on the system + def mockreturn(): + return Path(tmp_path) + + monkeypatch.setattr(Path, "home", mockreturn) + + with pytest.warns(UserWarning, match=r"Environment variable \$WEBBPSF_PATH is not set!\n.*") as excinfo: _ = utils.get_webbpsf_data_path() - assert 'Environment variable $WEBBPSF_PATH is not set!' in str(excinfo) + + # Check that the data was downloaded + assert any((tmp_path / "data" / "webbpsf-data").iterdir()) # Test that we can override the WEBBPSF_PATH setting here through # the config object even though the environment var is deleted diff --git a/webbpsf/utils.py b/webbpsf/utils.py index 29c6ddc7..529e68f7 100644 --- a/webbpsf/utils.py +++ b/webbpsf/utils.py @@ -180,6 +180,39 @@ def setup_logging(level='INFO', filename=None): """ +def auto_download_webbpsf_data(): + import os + import tarfile + from pathlib import Path + from tempfile import TemporaryDirectory + from urllib.request import urlretrieve + + # Create a default directory for the data files + default_path = Path.home() / "data" / "webbpsf-data" + default_path.mkdir(parents=True, exist_ok=True) + + os.environ["WEBBPSF_PATH"] = str(default_path) + + # Download the data files if the directory is empty + if not any(default_path.iterdir()): + warnings.warn("WebbPSF data files not found in default location, attempting to download them now...") + + with TemporaryDirectory() as tmpdir: + # Download the data files to a temporary directory + url = "https://stsci.box.com/shared/static/qxpiaxsjwo15ml6m4pkhtk36c9jgj70k.gz" + filename = Path(tmpdir) / "webbpsf-data-LATEST.tar.gz" + urlretrieve(url, filename) + + # Extract the tarball + with tarfile.open(filename, "r:gz") as tar: + tar.extractall(default_path.parent, filter="fully_trusted") + + if not any(default_path.iterdir()): + raise IOError(f"Failed to get and extract WebbPSF data files to {default_path}") + + return default_path + + def get_webbpsf_data_path(data_version_min=None, return_version=False): """Get the WebbPSF data path @@ -201,7 +234,12 @@ def get_webbpsf_data_path(data_version_min=None, return_version=False): if path_from_config == 'from_environment_variable': path = os.getenv('WEBBPSF_PATH') if path is None: - raise EnvironmentError(f'Environment variable $WEBBPSF_PATH is not set!\n{MISSING_WEBBPSF_DATA_MESSAGE}') + message = ( + 'Environment variable $WEBBPSF_PATH is not set!\n' + f'{MISSING_WEBBPSF_DATA_MESSAGE} searching default location..s' + ) + warnings.warn(message) + path = auto_download_webbpsf_data() else: path = path_from_config