diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 57a52b649b..733a73a409 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -166,6 +166,7 @@ (?x)^( aiida/cmdline/commands/.*| aiida/cmdline/params/.*| + aiida/cmdline/params/types/.*| utils/validate_consistency.py| )$ pass_filenames: false diff --git a/aiida/cmdline/commands/cmd_import.py b/aiida/cmdline/commands/cmd_import.py index 1a9d3e1f9d..26665689c7 100644 --- a/aiida/cmdline/commands/cmd_import.py +++ b/aiida/cmdline/commands/cmd_import.py @@ -17,7 +17,7 @@ from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.params import options -from aiida.cmdline.params.types import GroupParamType, ImportPath +from aiida.cmdline.params.types import GroupParamType, PathOrUrl from aiida.cmdline.utils import decorators, echo EXTRAS_MODE_EXISTING = ['keep_existing', 'update_existing', 'mirror', 'none', 'ask'] @@ -186,7 +186,7 @@ def _migrate_archive(ctx, temp_folder, file_to_import, archive, non_interactive, @verdi.command('import') -@click.argument('archives', nargs=-1, type=ImportPath(exists=True, readable=True)) +@click.argument('archives', nargs=-1, type=PathOrUrl(exists=True, readable=True)) @click.option( '-w', '--webpages', diff --git a/aiida/cmdline/params/options/__init__.py b/aiida/cmdline/params/options/__init__.py index 0b6fe9645b..9e2f46d616 100644 --- a/aiida/cmdline/params/options/__init__.py +++ b/aiida/cmdline/params/options/__init__.py @@ -497,8 +497,8 @@ def decorator(command): CONFIG_FILE = ConfigFileOption( '--config', - type=click.Path(exists=True, dir_okay=False), - help='Load option values from configuration file in yaml format.' + type=types.FileOrUrl(), + help='Load option values from configuration file in yaml format (local path or URL).' ) IDENTIFIER = OverridableOption( diff --git a/aiida/cmdline/params/options/config.py b/aiida/cmdline/params/options/config.py index 0b25e13c0e..2cfcbd79bc 100644 --- a/aiida/cmdline/params/options/config.py +++ b/aiida/cmdline/params/options/config.py @@ -18,10 +18,9 @@ from .overridable import OverridableOption -def yaml_config_file_provider(file_path, cmd_name): # pylint: disable=unused-argument - """Read yaml config file.""" - with open(file_path, 'r') as handle: - return yaml.safe_load(handle) +def yaml_config_file_provider(handle, cmd_name): # pylint: disable=unused-argument + """Read yaml config file from file handle.""" + return yaml.safe_load(handle) class ConfigFileOption(OverridableOption): diff --git a/aiida/cmdline/params/types/__init__.py b/aiida/cmdline/params/types/__init__.py index 3b44d31358..cedb380572 100644 --- a/aiida/cmdline/params/types/__init__.py +++ b/aiida/cmdline/params/types/__init__.py @@ -21,7 +21,7 @@ from .node import NodeParamType from .process import ProcessParamType from .strings import (NonEmptyStringParamType, EmailType, HostnameType, EntryPointType, LabelStringType) -from .path import AbsolutePathParamType, ImportPath +from .path import AbsolutePathParamType, PathOrUrl, FileOrUrl from .plugin import PluginParamType from .profile import ProfileParamType from .user import UserParamType @@ -32,5 +32,6 @@ 'LazyChoice', 'IdentifierParamType', 'CalculationParamType', 'CodeParamType', 'ComputerParamType', 'ConfigOptionParamType', 'DataParamType', 'GroupParamType', 'NodeParamType', 'MpirunCommandParamType', 'MultipleValueParamType', 'NonEmptyStringParamType', 'PluginParamType', 'AbsolutePathParamType', 'ShebangParamType', - 'UserParamType', 'TestModuleParamType', 'ProfileParamType', 'WorkflowParamType', 'ProcessParamType', 'ImportPath' + 'UserParamType', 'TestModuleParamType', 'ProfileParamType', 'WorkflowParamType', 'ProcessParamType', 'PathOrUrl', + 'FileOrUrl' ) diff --git a/aiida/cmdline/params/types/path.py b/aiida/cmdline/params/types/path.py index 20a96cd436..55b2b08166 100644 --- a/aiida/cmdline/params/types/path.py +++ b/aiida/cmdline/params/types/path.py @@ -12,12 +12,25 @@ # See https://stackoverflow.com/a/41217363/1069467 import urllib.request import urllib.error - +from socket import timeout import click URL_TIMEOUT_SECONDS = 10 +def _check_timeout_seconds(timeout_seconds): + """Raise if timeout is not within range [0;60]""" + try: + timeout_seconds = int(timeout_seconds) + except ValueError: + raise TypeError('timeout_seconds should be an integer but got: {}'.format(type(timeout_seconds))) + + if timeout_seconds < 0 or timeout_seconds > 60: + raise ValueError('timeout_seconds needs to be in the range [0;60].') + + return timeout_seconds + + class AbsolutePathParamType(click.Path): """ The ParamType for identifying absolute Paths (derived from click.Path). @@ -52,22 +65,23 @@ def __repr__(self): return 'ABSOLUTEPATHEMPTY' -class ImportPath(click.Path): - """AiiDA extension of Click's Path-type to include URLs - An ImportPath can either be a `click.Path`-type or a URL. +class PathOrUrl(click.Path): + """Extension of click's Path-type to include URLs. + + A PathOrUrl can either be a `click.Path`-type or a URL. - :param timeout_seconds: Timeout time in seconds that a URL response is expected. - :value timeout_seconds: Must be an int in the range [0;60], extrema included. - If an int outside the range [0;60] is given, the value will be set to the respective extremum value. - If any other type than int is given a TypeError will be raised. + :param int timeout_seconds: Maximum timeout accepted for URL response. + Must be an integer in the range [0;60]. """ # pylint: disable=protected-access + name = 'PathOrUrl' + def __init__(self, timeout_seconds=URL_TIMEOUT_SECONDS, **kwargs): super().__init__(**kwargs) - self.timeout_seconds = timeout_seconds + self.timeout_seconds = _check_timeout_seconds(timeout_seconds) def convert(self, value, param, ctx): """Overwrite `convert` @@ -80,38 +94,57 @@ def convert(self, value, param, ctx): # Check if URL return self.checks_url(value, param, ctx) - def checks_url(self, value, param, ctx): - """Do checks for possible URL path""" - from socket import timeout - - url = value - + def checks_url(self, url, param, ctx): + """Check whether URL is reachable within timeout.""" try: - urllib.request.urlopen(url, data=None, timeout=self.timeout_seconds) + urllib.request.urlopen(url, timeout=self.timeout_seconds) except (urllib.error.URLError, urllib.error.HTTPError, timeout): self.fail( '{0} "{1}" could not be reached within {2} s.\n' - 'It may be neither a valid {3} nor a valid URL.'.format( + 'Is it a valid {3} or URL?'.format( self.path_type, click._compat.filename_to_ui(url), self.timeout_seconds, self.name ), param, ctx ) return url - @property - def timeout_seconds(self): - return self._timeout_seconds - # pylint: disable=attribute-defined-outside-init - @timeout_seconds.setter - def timeout_seconds(self, value): - try: - self._timeout_seconds = int(value) - except ValueError: - raise TypeError('timeout_seconds should be an integer but got: {}'.format(type(value))) +class FileOrUrl(click.File): + """Extension of click's File-type to include URLs. + + Returns handle either to local file or to remote file fetched from URL. + + :param int timeout_seconds: Maximum timeout accepted for URL response. + Must be an integer in the range [0;60]. + """ + + name = 'FileOrUrl' + + # pylint: disable=protected-access + + def __init__(self, timeout_seconds=URL_TIMEOUT_SECONDS, **kwargs): + super().__init__(**kwargs) - if self._timeout_seconds < 0: - self._timeout_seconds = 0 + self.timeout_seconds = _check_timeout_seconds(timeout_seconds) - if self._timeout_seconds > 60: - self._timeout_seconds = 60 + def convert(self, value, param, ctx): + """Return file handle. + """ + try: + # Check if `click.File`-type + return super().convert(value, param, ctx) + except click.exceptions.BadParameter: + # Check if URL + handle = self.get_url(value, param, ctx) + return handle + + def get_url(self, url, param, ctx): + """Retrieve file from URL.""" + try: + return urllib.request.urlopen(url, timeout=self.timeout_seconds) + except (urllib.error.URLError, urllib.error.HTTPError, timeout): + self.fail( + '"{0}" could not be reached within {1} s.\n' + 'Is it a valid {2} or URL?.'.format(click._compat.filename_to_ui(url), self.timeout_seconds, self.name), + param, ctx + ) diff --git a/aiida/common/extendeddicts.py b/aiida/common/extendeddicts.py index b916cb873d..616d907f12 100644 --- a/aiida/common/extendeddicts.py +++ b/aiida/common/extendeddicts.py @@ -15,7 +15,7 @@ __all__ = ('AttributeDict', 'FixedFieldsAttributeDict', 'DefaultFieldsAttributeDict') -class AttributeDict(dict): +class AttributeDict(dict): # pylint: disable=too-many-instance-attributes """ This class internally stores values in a dictionary, but exposes the keys also as attributes, i.e. asking for attrdict.key diff --git a/docs/requirements_for_rtd.txt b/docs/requirements_for_rtd.txt index 38e84b0382..d8ec7acd5b 100644 --- a/docs/requirements_for_rtd.txt +++ b/docs/requirements_for_rtd.txt @@ -5,7 +5,7 @@ alembic~=1.2 ase~=3.18 circus~=0.16.1 click-completion~=0.5.1 -click-config-file~=0.5.0 +click-config-file~=0.6.0 click-spinner~=0.1.8 click~=7.0 coverage<5.0 diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 46dc269e12..24dea5b000 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -34,6 +34,7 @@ py:class click.types.Choice py:class click.types.IntParamType py:class click.types.StringParamType py:class click.types.Path +py:class click.types.File py:meth click.Option.get_default py:class concurrent.futures._base.TimeoutError diff --git a/docs/source/reference/command_line.rst b/docs/source/reference/command_line.rst index 8a8f3a00bc..197701cfe8 100644 --- a/docs/source/reference/command_line.rst +++ b/docs/source/reference/command_line.rst @@ -499,8 +499,8 @@ Below is a list with all available subcommands. superuser. --repository DIRECTORY Absolute path to the file repository. - --config FILE Load option values from configuration file - in yaml format. + --config FILEORURL Load option values from configuration file + in yaml format (local path or URL). --help Show this message and exit. @@ -622,8 +622,8 @@ Below is a list with all available subcommands. --db-password TEXT Password of the database user. [required] --repository DIRECTORY Absolute path to the file repository. - --config FILE Load option values from configuration file - in yaml format. + --config FILEORURL Load option values from configuration file + in yaml format (local path or URL). --help Show this message and exit. diff --git a/environment.yml b/environment.yml index e9e7c7bd65..e7dc46ea32 100644 --- a/environment.yml +++ b/environment.yml @@ -10,7 +10,7 @@ dependencies: - alembic~=1.2 - circus~=0.16.1 - click-completion~=0.5.1 -- click-config-file~=0.5.0 +- click-config-file~=0.6.0 - click-spinner~=0.1.8 - click~=7.0 - django~=2.2 diff --git a/requirements/requirements-py-3.5.txt b/requirements/requirements-py-3.5.txt index 90c0d71ccd..036891c4cd 100644 --- a/requirements/requirements-py-3.5.txt +++ b/requirements/requirements-py-3.5.txt @@ -15,7 +15,7 @@ chardet==3.0.4 circus==0.16.1 Click==7.0 click-completion==0.5.2 -click-config-file==0.5.0 +click-config-file==0.6.0 click-spinner==0.1.8 configobj==5.0.6 coverage==4.5.4 diff --git a/requirements/requirements-py-3.6.txt b/requirements/requirements-py-3.6.txt index c30c8a3ff5..c1c93a9c3b 100644 --- a/requirements/requirements-py-3.6.txt +++ b/requirements/requirements-py-3.6.txt @@ -15,7 +15,7 @@ chardet==3.0.4 circus==0.16.1 Click==7.0 click-completion==0.5.2 -click-config-file==0.5.0 +click-config-file==0.6.0 click-spinner==0.1.8 configobj==5.0.6 coverage==4.5.4 diff --git a/requirements/requirements-py-3.7.txt b/requirements/requirements-py-3.7.txt index a8d7014e97..afeb877e17 100644 --- a/requirements/requirements-py-3.7.txt +++ b/requirements/requirements-py-3.7.txt @@ -15,7 +15,7 @@ chardet==3.0.4 circus==0.16.1 Click==7.0 click-completion==0.5.2 -click-config-file==0.5.0 +click-config-file==0.6.0 click-spinner==0.1.8 configobj==5.0.6 coverage==4.5.4 diff --git a/requirements/requirements-py-3.8.txt b/requirements/requirements-py-3.8.txt index faa2ba1347..be63d559d0 100644 --- a/requirements/requirements-py-3.8.txt +++ b/requirements/requirements-py-3.8.txt @@ -15,7 +15,7 @@ chardet==3.0.4 circus==0.16.1 Click==7.0 click-completion==0.5.2 -click-config-file==0.5.0 +click-config-file==0.6.0 click-spinner==0.1.8 configobj==5.0.6 coverage==4.5.4 diff --git a/setup.json b/setup.json index 62360864a0..b39bffa052 100644 --- a/setup.json +++ b/setup.json @@ -25,7 +25,7 @@ "alembic~=1.2", "circus~=0.16.1", "click-completion~=0.5.1", - "click-config-file~=0.5.0", + "click-config-file~=0.6.0", "click-spinner~=0.1.8", "click~=7.0", "django~=2.2", diff --git a/tests/cmdline/commands/test_code.py b/tests/cmdline/commands/test_code.py index d61c3194ee..c09899fd2b 100644 --- a/tests/cmdline/commands/test_code.py +++ b/tests/cmdline/commands/test_code.py @@ -13,6 +13,7 @@ import subprocess as sp from textwrap import dedent +from unittest import mock from click.testing import CliRunner import pytest @@ -69,32 +70,46 @@ def test_noninteractive_upload(self): self.assertIsInstance(orm.Code.get_from_string('{}'.format(label)), orm.Code) def test_from_config(self): - """Test setting up a code from a config file""" - import tempfile + """Test setting up a code from a config file. - label = 'noninteractive_config' + Try loading from local file and from URL. + """ + import tempfile - with tempfile.NamedTemporaryFile('w') as handle: - handle.write( - dedent( - """ + config_file_template = dedent( + """ --- label: {label} input_plugin: arithmetic.add computer: {computer} remote_abs_path: /remote/abs/path """ - ).format(label=label, computer=self.computer.name) - ) + ) + + # local file + label = 'noninteractive_config' + with tempfile.NamedTemporaryFile('w') as handle: + handle.write(config_file_template.format(label=label, computer=self.computer.name)) handle.flush() result = self.cli_runner.invoke( setup_code, ['--non-interactive', '--config', os.path.realpath(handle.name)] ) - self.assertClickResultNoException(result) self.assertIsInstance(orm.Code.get_from_string('{}'.format(label)), orm.Code) + # url + label = 'noninteractive_config_url' + fake_url = 'https://my.url.com' + with mock.patch( + 'urllib.request.urlopen', + return_value=config_file_template.format(label=label, computer=self.computer.name) + ): + result = self.cli_runner.invoke(setup_code, ['--non-interactive', '--config', fake_url]) + + self.assertClickResultNoException(result) + self.assertIsInstance(orm.Code.get_from_string('{}'.format(label)), orm.Code) + class TestVerdiCodeCommands(AiidaTestCase): """Testing verdi code commands. diff --git a/tests/cmdline/commands/test_import.py b/tests/cmdline/commands/test_import.py index cad98c783a..cd1668aded 100644 --- a/tests/cmdline/commands/test_import.py +++ b/tests/cmdline/commands/test_import.py @@ -209,11 +209,11 @@ def test_import_url_and_local_archives(self): def test_import_url_timeout(self): """Test a timeout to valid URL is correctly errored""" - from aiida.cmdline.params.types import ImportPath + from aiida.cmdline.params.types import PathOrUrl timeout_url = 'http://www.google.com:81' - test_timeout_path = ImportPath(exists=True, readable=True, timeout_seconds=0) + test_timeout_path = PathOrUrl(exists=True, readable=True, timeout_seconds=0) with self.assertRaises(BadParameter) as cmd_exc: test_timeout_path(timeout_url) @@ -229,7 +229,7 @@ def test_raise_malformed_url(self): self.assertIsNotNone(result.exception, result.output) self.assertNotEqual(result.exit_code, 0, result.output) - error_message = 'It may be neither a valid path nor a valid URL.' + error_message = 'Is it a valid path or URL?' self.assertIn(error_message, result.output, result.exception) def test_non_interactive_and_migration(self): diff --git a/tests/cmdline/params/types/test_path.py b/tests/cmdline/params/types/test_path.py index 3844c59d93..f8b5e09d60 100644 --- a/tests/cmdline/params/types/test_path.py +++ b/tests/cmdline/params/types/test_path.py @@ -10,60 +10,37 @@ """Tests for Path types""" from aiida.backends.testbase import AiidaTestCase -from aiida.cmdline.params.types import ImportPath +from aiida.cmdline.params.types.path import PathOrUrl, _check_timeout_seconds -class TestImportPath(AiidaTestCase): - """Tests `ImportPath`""" +class TestPath(AiidaTestCase): + """Tests for `PathOrUrl` and `FileOrUrl`""" def test_default_timeout(self): """Test the default timeout_seconds value is correct""" from aiida.cmdline.params.types.path import URL_TIMEOUT_SECONDS - import_path = ImportPath() + import_path = PathOrUrl() self.assertEqual(import_path.timeout_seconds, URL_TIMEOUT_SECONDS) - def test_valid_timeout(self): - """Test a valid timeout_seconds value""" + def test_timeout_checks(self): + """Test that timeout check handles different values. + * valid + * none + * wrong type + * outside range + """ valid_values = [42, '42'] for value in valid_values: - import_path = ImportPath(timeout_seconds=value) - - self.assertEqual(import_path.timeout_seconds, int(value)) - - def test_none_timeout(self): - """Test a TypeError is raised when a None value is given for timeout_seconds""" - - with self.assertRaises(TypeError): - ImportPath(timeout_seconds=None) - - def test_wrong_type_timeout(self): - """Test a TypeError is raised when wrong type is given for timeout_seconds""" - - with self.assertRaises(TypeError): - ImportPath(timeout_seconds='test') - - def test_range_timeout(self): - """Test timeout_seconds defines extrema when out of range - Range of timeout_seconds is [0;60], extrema included. - """ - - range_timeout = [0, 60] - lower = range_timeout[0] - 5 - upper = range_timeout[1] + 5 - - lower_path = ImportPath(timeout_seconds=lower) - upper_path = ImportPath(timeout_seconds=upper) + self.assertEqual(_check_timeout_seconds(value), int(value)) - msg_lower = "timeout_seconds should have been corrected to the lower bound: '{}', but instead it is {}".format( - range_timeout[0], lower_path.timeout_seconds - ) - self.assertEqual(lower_path.timeout_seconds, range_timeout[0], msg_lower) + for invalid in [None, 'test']: + with self.assertRaises(TypeError): + _check_timeout_seconds(invalid) - msg_upper = "timeout_seconds should have been corrected to the upper bound: '{}', but instead it is {}".format( - range_timeout[1], upper_path.timeout_seconds - ) - self.assertEqual(upper_path.timeout_seconds, range_timeout[1], msg_upper) + for invalid in [-5, 65]: + with self.assertRaises(ValueError): + _check_timeout_seconds(invalid)