From 4a6c009bbf00a52ee66f454ddac5d6cb075a6fe8 Mon Sep 17 00:00:00 2001
From: Michał D
Date: Tue, 24 Jul 2018 16:25:52 +0200
Subject: [PATCH 01/13] Add Data Revenue to the `blogged` list (#2472)

---
 README.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.rst b/README.rst
index 3a76eb03a4..f3a02646ab 100644
--- a/README.rst
+++ b/README.rst
@@ -149,6 +149,7 @@ or held presentations about Luigi:
 * `Leipzig University Library `_ `(presentation, 2016) `__ / `(project) `__
 * `Synetiq `_ `(presentation, 2017) `__
 * `Glossier `_ `(blog, 2018) `__
+* `Data Revenue `_ `(blog, 2018) `_
 
 Some more companies are using Luigi but haven't had a chance yet to write about it:
 
From f9a99dce22e2887406c6d156d5d669660547d257 Mon Sep 17 00:00:00 2001
From: Dillon Stadther
Date: Wed, 25 Jul 2018 08:26:02 -0400
Subject: [PATCH 02/13] Add codeowners file with default and specific example (#2465)

---
 .github/CODEOWNERS | 12 ++++++++++++
 1 file changed, 12 insertions(+)
 create mode 100644 .github/CODEOWNERS

diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
new file mode 100644
index 0000000000..f83f67ab75
--- /dev/null
+++ b/.github/CODEOWNERS
@@ -0,0 +1,12 @@
+# The following patterns are used to auto-assign review requests
+# to specific individuals. Order is important; the last matching
+# pattern takes the most precedence.
+
+# These owners will be the default owners for everything in
+# the repo, unless a later match takes precedence.
+* @dlstadther @Tarrasch @ulzha
+
+# Specific files, directories, paths, or file types can be
+# assigned more specifically.
+contrib/redshift*.py @dlstadther
+
From f0fdca73097b55947983c4ef5a68593df686fbb1 Mon Sep 17 00:00:00 2001
From: Greg Roberts
Date: Thu, 26 Jul 2018 14:36:59 +0100
Subject: [PATCH 03/13] Added default port behaviour for Redshift (#2474)

* PostgresTarget default port defined as class variable, so Targets which subclass PostgresTarget (e.g. 
RedshiftTarget) can define their own default port --- luigi/contrib/postgres.py | 7 +++++-- luigi/contrib/redshift.py | 3 +++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/luigi/contrib/postgres.py b/luigi/contrib/postgres.py index d06b2a1f69..ec8b72df16 100644 --- a/luigi/contrib/postgres.py +++ b/luigi/contrib/postgres.py @@ -106,11 +106,14 @@ class PostgresTarget(luigi.Target): """ marker_table = luigi.configuration.get_config().get('postgres', 'marker-table', 'table_updates') + # if not supplied, fall back to default Postgres port + DEFAULT_DB_PORT = 5432 + # Use DB side timestamps or client side timestamps in the marker_table use_db_timestamps = True def __init__( - self, host, database, user, password, table, update_id, port=5432 + self, host, database, user, password, table, update_id, port=None ): """ Args: @@ -126,7 +129,7 @@ def __init__( self.host, self.port = host.split(':') else: self.host = host - self.port = port + self.port = port or self.DEFAULT_DB_PORT self.database = database self.user = user self.password = password diff --git a/luigi/contrib/redshift.py b/luigi/contrib/redshift.py index 5302c3d5f0..a1c5c43172 100644 --- a/luigi/contrib/redshift.py +++ b/luigi/contrib/redshift.py @@ -135,6 +135,9 @@ class RedshiftTarget(postgres.PostgresTarget): 'marker-table', 'table_updates') + # if not supplied, fall back to default Redshift port + DEFAULT_DB_PORT = 5439 + use_db_timestamps = False From 7d2c5574cb53106044a25afdddf151165471a741 Mon Sep 17 00:00:00 2001 From: Gram Date: Sat, 28 Jul 2018 21:55:35 +0300 Subject: [PATCH 04/13] Optional TOML configs support (#2457) See the added docs for usage. --- doc/configuration.rst | 40 +++++++-- luigi/configuration/__init__.py | 27 ++++++ luigi/configuration/base_parser.py | 41 ++++++++++ .../cfg_parser.py} | 34 +------- luigi/configuration/core.py | 79 ++++++++++++++++++ luigi/configuration/toml_parser.py | 82 +++++++++++++++++++ luigi/contrib/s3.py | 2 +- luigi/parameter.py | 2 +- setup.py | 4 + test/config_toml_test.py | 65 +++++++++++++++ test/testconfig/luigi.toml | 7 ++ test/testconfig/luigi_local.toml | 3 + tox.ini | 1 + 13 files changed, 349 insertions(+), 38 deletions(-) create mode 100644 luigi/configuration/__init__.py create mode 100644 luigi/configuration/base_parser.py rename luigi/{configuration.py => configuration/cfg_parser.py} (80%) create mode 100644 luigi/configuration/core.py create mode 100644 luigi/configuration/toml_parser.py create mode 100644 test/config_toml_test.py create mode 100644 test/testconfig/luigi.toml create mode 100644 test/testconfig/luigi_local.toml diff --git a/doc/configuration.rst b/doc/configuration.rst index 5cf649d8bb..dec55c3679 100644 --- a/doc/configuration.rst +++ b/doc/configuration.rst @@ -1,18 +1,35 @@ Configuration ============= -All configuration can be done by adding configuration files. They are looked for in: +All configuration can be done by adding configuration files. - * ``/etc/luigi/client.cfg`` - * ``luigi.cfg`` (or its legacy name ``client.cfg``) in your current working directory - * ``LUIGI_CONFIG_PATH`` environment variable +Supported config parsers: +* ``cfg`` (default) +* ``toml`` -in increasing order of preference. The order only matters in case of key conflicts (see docs for ConfigParser.read_). These files are meant for both the client and ``luigid``. If you decide to specify your own configuration you should make sure that both the client and ``luigid`` load it properly. 
+You can choose the parser via the ``LUIGI_CONFIG_PARSER`` environment variable. For example, ``LUIGI_CONFIG_PARSER=toml``.
+
+Config files for the default (``cfg``) parser are looked for in:
+
+* ``/etc/luigi/client.cfg`` (deprecated)
+* ``/etc/luigi/luigi.cfg``
+* ``client.cfg`` (deprecated)
+* ``luigi.cfg``
+* ``LUIGI_CONFIG_PATH`` environment variable
+
+Config files for the `TOML `_ parser are looked for in:
+
+* ``/etc/luigi/luigi.toml``
+* ``luigi.toml``
+* ``LUIGI_CONFIG_PATH`` environment variable
+
+Both lists are ordered from lowest to highest priority. The order only matters in case of key conflicts (see docs for ConfigParser.read_). These files are meant for both the client and ``luigid``. If you decide to specify your own configuration you should make sure that both the client and ``luigid`` load it properly.
 
 .. _ConfigParser.read: https://docs.python.org/3.6/library/configparser.html#configparser.ConfigParser.read
 
-The config file is broken into sections, each controlling a different part of the config. Example configuration file:
+The config file is broken into sections, each controlling a different part of the config.
+Example cfg config:
 
 .. code:: ini
 
     [hadoop]
     version=cdh4
     streaming-jar=/usr/lib/hadoop-xyz/hadoop-streaming-xyz-123.jar
 
@@ -23,6 +40,17 @@
     [core]
     scheduler_host=luigi-host.mycompany.foo
 
+Example toml config:
+
+.. code:: toml
+
+    [hadoop]
+    version = "cdh4"
+    streaming-jar = "/usr/lib/hadoop-xyz/hadoop-streaming-xyz-123.jar"
+
+    [core]
+    scheduler_host = "luigi-host.mycompany.foo"
+
 .. _ParamConfigIngestion:
 
diff --git a/luigi/configuration/__init__.py b/luigi/configuration/__init__.py
new file mode 100644
index 0000000000..21ff657fd8
--- /dev/null
+++ b/luigi/configuration/__init__.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright 2012-2015 Spotify AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from .cfg_parser import LuigiConfigParser
+from .core import get_config, add_config_path
+from .toml_parser import LuigiTomlParser
+
+
+__all__ = [
+    'add_config_path',
+    'get_config',
+    'LuigiConfigParser',
+    'LuigiTomlParser',
+]
diff --git a/luigi/configuration/base_parser.py b/luigi/configuration/base_parser.py
new file mode 100644
index 0000000000..9b70a78155
--- /dev/null
+++ b/luigi/configuration/base_parser.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright 2012-2015 Spotify AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+
+
+# IMPORTANT: don't inherit from `object`!
+# ConfigParser has some troubles in this case.
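+# (In Python 2, ConfigParser is an old-style class; adding `object` as a base
+# here can break subclasses that also inherit from ConfigParser.)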
+# More info: https://stackoverflow.com/a/19323238 +class BaseParser: + @classmethod + def instance(cls, *args, **kwargs): + """ Singleton getter """ + if cls._instance is None: + cls._instance = cls(*args, **kwargs) + loaded = cls._instance.reload() + logging.getLogger('luigi-interface').info('Loaded %r', loaded) + + return cls._instance + + @classmethod + def add_config_path(cls, path): + cls._config_paths.append(path) + cls.reload() + + @classmethod + def reload(cls): + return cls.instance().read(cls._config_paths) diff --git a/luigi/configuration.py b/luigi/configuration/cfg_parser.py similarity index 80% rename from luigi/configuration.py rename to luigi/configuration/cfg_parser.py index 6d3ddc2a33..e0df87f10a 100644 --- a/luigi/configuration.py +++ b/luigi/configuration/cfg_parser.py @@ -29,7 +29,6 @@ See :doc:`/configuration` for more info. """ -import logging import os import warnings @@ -38,9 +37,12 @@ except ImportError: from configparser import ConfigParser, NoOptionError, NoSectionError +from .base_parser import BaseParser -class LuigiConfigParser(ConfigParser): + +class LuigiConfigParser(BaseParser, ConfigParser): NO_DEFAULT = object() + enabled = True _instance = None _config_paths = [ '/etc/luigi/client.cfg', # Deprecated old-style global luigi config @@ -48,27 +50,6 @@ class LuigiConfigParser(ConfigParser): 'client.cfg', # Deprecated old-style local luigi config 'luigi.cfg', ] - if 'LUIGI_CONFIG_PATH' in os.environ: - config_file = os.environ['LUIGI_CONFIG_PATH'] - if not os.path.isfile(config_file): - warnings.warn("LUIGI_CONFIG_PATH points to a file which does not exist. Invalid file: {path}".format(path=config_file)) - else: - _config_paths.append(config_file) - - @classmethod - def add_config_path(cls, path): - cls._config_paths.append(path) - cls.reload() - - @classmethod - def instance(cls, *args, **kwargs): - """ Singleton getter """ - if cls._instance is None: - cls._instance = cls(*args, **kwargs) - loaded = cls._instance.reload() - logging.getLogger('luigi-interface').info('Loaded %r', loaded) - - return cls._instance @classmethod def reload(cls): @@ -124,10 +105,3 @@ def set(self, section, option, value=None): ConfigParser.add_section(self, section) return ConfigParser.set(self, section, option, value) - - -def get_config(): - """ - Convenience method (for backwards compatibility) for accessing config singleton. - """ - return LuigiConfigParser.instance() diff --git a/luigi/configuration/core.py b/luigi/configuration/core.py new file mode 100644 index 0000000000..7ca0d6e673 --- /dev/null +++ b/luigi/configuration/core.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+import logging
+import os
+import warnings
+
+from .cfg_parser import LuigiConfigParser
+from .toml_parser import LuigiTomlParser
+
+
+logger = logging.getLogger('luigi-interface')
+
+
+PARSERS = {
+    'cfg': LuigiConfigParser,
+    'conf': LuigiConfigParser,
+    'ini': LuigiConfigParser,
+    'toml': LuigiTomlParser,
+}
+
+# select parser via env var
+DEFAULT_PARSER = 'cfg'
+PARSER = os.environ.get('LUIGI_CONFIG_PARSER', DEFAULT_PARSER)
+if PARSER not in PARSERS:
+    warnings.warn("Invalid parser: {parser}".format(parser=PARSER))
+    PARSER = DEFAULT_PARSER
+
+
+def get_config(parser=PARSER):
+    """Get the config-parser singleton for the given parser type.
+    """
+
+    parser_class = PARSERS[parser]
+    if not parser_class.enabled:
+        logger.error((
+            "Parser not installed yet. "
+            "Please install luigi with the required parser:\n"
+            "pip install luigi[{parser}]"
+        ).format(parser=parser))
+
+    return parser_class.instance()
+
+
+def add_config_path(path):
+    """Select a config parser by file extension and add the path to that parser.
+    """
+    if not os.path.isfile(path):
+        warnings.warn("Config file does not exist: {path}".format(path=path))
+        return False
+
+    # select parser by file extension
+    _base, ext = os.path.splitext(path)
+    if ext and ext[1:] in PARSERS:
+        parser_class = PARSERS[ext[1:]]
+    else:
+        parser_class = PARSERS[PARSER]
+
+    # add config path to parser
+    parser_class.add_config_path(path)
+    return True
+
+
+if 'LUIGI_CONFIG_PATH' in os.environ:
+    add_config_path(os.environ['LUIGI_CONFIG_PATH'])
diff --git a/luigi/configuration/toml_parser.py b/luigi/configuration/toml_parser.py
new file mode 100644
index 0000000000..8e6fa3923b
--- /dev/null
+++ b/luigi/configuration/toml_parser.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright 2018 Cindicator Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# +import os.path + +try: + import toml +except ImportError: + toml = False + +from .base_parser import BaseParser + + +class LuigiTomlParser(BaseParser): + NO_DEFAULT = object() + enabled = bool(toml) + data = dict() + _instance = None + _config_paths = [ + '/etc/luigi/luigi.toml', + 'luigi.toml', + ] + + @staticmethod + def _update_data(data, new_data): + if not new_data: + return data + if not data: + return new_data + for section, content in new_data.items(): + if section not in data: + data[section] = dict() + data[section].update(content) + return data + + def read(self, config_paths): + self.data = dict() + for path in config_paths: + if os.path.isfile(path): + self.data = self._update_data(self.data, toml.load(path)) + return self.data + + def get(self, section, option, default=NO_DEFAULT, **kwargs): + try: + return self.data[section][option] + except KeyError: + if default is self.NO_DEFAULT: + raise + return default + + def getboolean(self, section, option, default=NO_DEFAULT): + return self.get(section, option, default) + + def getint(self, section, option, default=NO_DEFAULT): + return self.get(section, option, default) + + def getfloat(self, section, option, default=NO_DEFAULT): + return self.get(section, option, default) + + def getintdict(self, section): + return self.data.get(section, {}) + + def set(self, section, option, value=None): + if section not in self.data: + self.data[section] = {} + self.data[section][option] = value + + def __getitem__(self, name): + return self.data[name] diff --git a/luigi/contrib/s3.py b/luigi/contrib/s3.py index fb5fbbb83f..a457808bc3 100644 --- a/luigi/contrib/s3.py +++ b/luigi/contrib/s3.py @@ -486,7 +486,7 @@ def _get_s3_config(self, key=None): defaults = dict(configuration.get_config().defaults()) try: config = dict(configuration.get_config().items('s3')) - except NoSectionError: + except (NoSectionError, KeyError): return {} # So what ports etc can be read without us having to specify all dtypes for k, v in six.iteritems(config): diff --git a/luigi/parameter.py b/luigi/parameter.py index 7485d09f61..f619864090 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -168,7 +168,7 @@ def _get_value_from_config(self, section, name): try: value = conf.get(section, name) - except (NoSectionError, NoOptionError): + except (NoSectionError, NoOptionError, KeyError): return _no_value return self.parse(value) diff --git a/setup.py b/setup.py index 85f7dba8fa..176d671636 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,7 @@ def get_static_files(path): license='Apache License 2.0', packages=[ 'luigi', + 'luigi.configuration', 'luigi.contrib', 'luigi.contrib.hdfs', 'luigi.tools' @@ -75,6 +76,9 @@ def get_static_files(path): ] }, install_requires=install_requires, + extras_require={ + 'toml': ['toml<2.0.0'], + }, classifiers=[ 'Development Status :: 5 - Production/Stable', 'Environment :: Console', diff --git a/test/config_toml_test.py b/test/config_toml_test.py new file mode 100644 index 0000000000..e0211c60ff --- /dev/null +++ b/test/config_toml_test.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2018 Cindicator Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from luigi.configuration import LuigiTomlParser, get_config, add_config_path + + +from helpers import LuigiTestCase + + +class TomlConfigParserTest(LuigiTestCase): + @classmethod + def setUpClass(cls): + add_config_path('test/testconfig/luigi.toml') + add_config_path('test/testconfig/luigi_local.toml') + + def setUp(self): + LuigiTomlParser._instance = None + super(TomlConfigParserTest, self).setUp() + + def test_get_config(self): + config = get_config('toml') + self.assertIsInstance(config, LuigiTomlParser) + + def test_file_reading(self): + config = get_config('toml') + self.assertIn('hdfs', config.data) + + def test_get(self): + config = get_config('toml') + + # test getting + self.assertEqual(config.get('hdfs', 'client'), 'hadoopcli') + self.assertEqual(config.get('hdfs', 'client', 'test'), 'hadoopcli') + + # test default + self.assertEqual(config.get('hdfs', 'test', 'check'), 'check') + with self.assertRaises(KeyError): + config.get('hdfs', 'test') + + # test override + self.assertEqual(config.get('hdfs', 'namenode_host'), 'localhost') + # test non-string values + self.assertEqual(config.get('hdfs', 'namenode_port'), 50030) + + def test_set(self): + config = get_config('toml') + + self.assertEqual(config.get('hdfs', 'client'), 'hadoopcli') + config.set('hdfs', 'client', 'test') + self.assertEqual(config.get('hdfs', 'client'), 'test') + config.set('hdfs', 'check', 'test me') + self.assertEqual(config.get('hdfs', 'check'), 'test me') diff --git a/test/testconfig/luigi.toml b/test/testconfig/luigi.toml new file mode 100644 index 0000000000..6c8e3409a3 --- /dev/null +++ b/test/testconfig/luigi.toml @@ -0,0 +1,7 @@ +[core] +logging_conf_file = "test/testconfig/logging.cfg" + +[hdfs] +client = "hadoopcli" +snakebite_autoconfig = false +namenode_host = "must be overridden in local config" diff --git a/test/testconfig/luigi_local.toml b/test/testconfig/luigi_local.toml new file mode 100644 index 0000000000..21330c1040 --- /dev/null +++ b/test/testconfig/luigi_local.toml @@ -0,0 +1,3 @@ +[hdfs] +namenode_host = "localhost" +namenode_port = 50030 diff --git a/tox.ini b/tox.ini index ecf04f8d13..4876423fba 100644 --- a/tox.ini +++ b/tox.ini @@ -35,6 +35,7 @@ deps= hypothesis[datetime] selenium==3.0.2 pymongo==3.4.0 + toml<2.0.0 passenv = USER JAVA_HOME POSTGRES_USER DATAPROC_TEST_PROJECT_ID GCS_TEST_PROJECT_ID GCS_TEST_BUCKET GOOGLE_APPLICATION_CREDENTIALS TRAVIS_BUILD_ID TRAVIS TRAVIS_BRANCH TRAVIS_JOB_NUMBER TRAVIS_PULL_REQUEST TRAVIS_JOB_ID TRAVIS_REPO_SLUG TRAVIS_COMMIT CI setenv = From a97a7b7e4f8f2d4fd879afff8da885aeffc44b43 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 31 Jul 2018 00:33:07 +0300 Subject: [PATCH 05/13] tests: Use RunOnceTask where possible (#2476) --- test/scheduler_visualisation_test.py | 50 +++++++++++++--------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/test/scheduler_visualisation_test.py b/test/scheduler_visualisation_test.py index 1d35f69ffc..4edb668e51 100644 --- a/test/scheduler_visualisation_test.py +++ b/test/scheduler_visualisation_test.py @@ -19,7 +19,7 @@ import os import tempfile import time -from helpers import 
unittest +from helpers import unittest, RunOnceTask import luigi import luigi.notifications @@ -33,7 +33,7 @@ class DummyTask(luigi.Task): - task_id = luigi.Parameter() + task_id = luigi.IntParameter() def run(self): f = self.output().open('w') @@ -44,7 +44,7 @@ def output(self): class FactorTask(luigi.Task): - product = luigi.Parameter() + product = luigi.IntParameter() def requires(self): for factor in range(2, self.product): @@ -77,7 +77,10 @@ def complete(self): class FailingTask(luigi.Task): task_namespace = __name__ - task_id = luigi.Parameter() + task_id = luigi.IntParameter() + + def complete(self): + return False def run(self): raise Exception("Error Message") @@ -100,7 +103,6 @@ def run(self): class SchedulerVisualisationTest(unittest.TestCase): - def setUp(self): self.scheduler = luigi.scheduler.Scheduler() @@ -190,7 +192,7 @@ def complete(self): six.assertCountEqual(self, expected_nodes, graph) def test_truncate_graph_with_full_levels(self): - class BinaryTreeTask(luigi.Task): + class BinaryTreeTask(RunOnceTask): idx = luigi.IntParameter() def requires(self): @@ -226,7 +228,7 @@ def complete(self): graph = self.scheduler.dep_graph(root_task.task_id) self.assertEqual(10, len(graph)) - expected_nodes = [LinearTask(i).task_id for i in range(100, 91, -1)] +\ + expected_nodes = [LinearTask(i).task_id for i in range(100, 91, -1)] + \ [LinearTask(0).task_id] self.maxDiff = None six.assertCountEqual(self, expected_nodes, graph) @@ -387,30 +389,29 @@ def test_task_list_failed(self): def test_task_list_upstream_status(self): class A(luigi.ExternalTask): - pass + def complete(self): + return False class B(luigi.ExternalTask): - def complete(self): return True - class C(luigi.Task): - + class C(RunOnceTask): def requires(self): return [A(), B()] class F(luigi.Task): + def complete(self): + return False def run(self): raise Exception() - class D(luigi.Task): - + class D(RunOnceTask): def requires(self): return [F()] - class E(luigi.Task): - + class E(RunOnceTask): def requires(self): return [C(), D()] @@ -478,22 +479,20 @@ def test_fetch_error(self): self.assertTrue("Traceback" in error["error"]) def test_inverse_deps(self): - class X(luigi.Task): + class X(RunOnceTask): pass - class Y(luigi.Task): - + class Y(RunOnceTask): def requires(self): return [X()] - class Z(luigi.Task): - id = luigi.Parameter() + class Z(RunOnceTask): + id = luigi.IntParameter() def requires(self): return [Y()] - class ZZ(luigi.Task): - + class ZZ(RunOnceTask): def requires(self): return [Z(1), Z(2)] @@ -513,7 +512,6 @@ def assert_has_deps(task_id, deps): def test_simple_worker_list(self): class X(luigi.Task): - def run(self): self._complete = True @@ -536,12 +534,10 @@ def complete(self): def test_worker_list_pending_uniques(self): class X(luigi.Task): - def complete(self): return False class Y(X): - def requires(self): return X() @@ -562,7 +558,7 @@ class Z(Y): self.assertEqual(0, worker['num_running']) def test_worker_list_running(self): - class X(luigi.Task): + class X(RunOnceTask): n = luigi.IntParameter() w = luigi.worker.Worker(worker_id='w', scheduler=self.scheduler, worker_processes=3) @@ -584,7 +580,7 @@ class X(luigi.Task): self.assertEqual(1, worker['num_uniques']) def test_worker_list_disabled_worker(self): - class X(luigi.Task): + class X(RunOnceTask): pass with luigi.worker.Worker(worker_id='w', scheduler=self.scheduler) as w: From 29c97a6ddc505c3611ef3b8bc01e639f91267ce8 Mon Sep 17 00:00:00 2001 From: Stas Glubokiy Date: Tue, 31 Jul 2018 00:37:14 +0300 Subject: [PATCH 06/13] Fix race condition in 
luigi.lock.acquire_for (#2357) (#2477)

---
 luigi/lock.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/luigi/lock.py b/luigi/lock.py
index 1b31ed0c90..e1a604f540 100644
--- a/luigi/lock.py
+++ b/luigi/lock.py
@@ -21,6 +21,7 @@
 """
 from __future__ import print_function
 
+import errno
 import hashlib
 import os
 import sys
@@ -102,10 +103,14 @@ def acquire_for(pid_dir, num_available=1, kill_signal=None):
 
     my_pid, my_cmd, pid_file = get_info(pid_dir)
 
-    # Check if there is a pid file corresponding to this name
-    if not os.path.exists(pid_dir):
+    # Create the pid_dir if it does not already exist
+    try:
         os.mkdir(pid_dir)
         os.chmod(pid_dir, 0o777)
+    except OSError as exc:
+        if exc.errno != errno.EEXIST:
+            raise
+        pass
 
     # Let variable "pids" be all pids who exist in the .pid-file who are still
     # about running the same command.
From 0a63cc0a360a4b7dfcb0203d4a41719373a292da Mon Sep 17 00:00:00 2001
From: Charles-André Bouffard
Date: Thu, 2 Aug 2018 16:16:33 -0400
Subject: [PATCH 07/13] Add metadata columns to the RDBMS contrib (#2440)

* Add metadata column feature to Redshift

The goal of this feature is to allow metadata columns to exist for specific
tables created by the Redshift contrib related tasks.

Given the scenario where we always have to have a `created_tz` column at the
end of every table generated by that contrib, we could do the following:

```python
class UpdateTableTask(redshift.S3CopyToTable):
    def metadata_columns(self):
        return [('created_tz', 'TIMESTAMP')]

    def post_queries(self):
        query = 'UPDATE {0} ' \
                'SET created_tz = CURRENT_TIMESTAMP ' \
                'WHERE created_tz IS NULL'.format(self.table)
        return [query]
```

By adding a layer of abstraction over this feature, you could easily add
default behaviors for specific tables, version tables, and more.

This feature is opt-in since we don't want it to break other people's
pipelines after it is integrated.

* Move the Metadata Columns implementation to the RDBMS contrib

As suggested in the code review, there are multiple other DBs that could
benefit from this change. Currently, only PSQL and Redshift implement RDBMS,
but other contribs may implement this class and inherit the new behavior.

* Add tests for the Metadata behaviors

We've been using this feature internally for Redshift, but moving it to the
RDBMS contrib and adding the behavior to PSQL could have unexpected side
effects; this takes care of testing that the feature works correctly under
both Redshift and PSQL.

* Add additional documentation on how to use the new mixin

* Raise ValueError on invalid metadata_columns for RDBMS

If the count of metadata_columns is 0 and we're expecting to add them to the
table, we raise an error because that is an invalid flow. The contributor is
required to provide metadata_columns values if the columns are to be added to
the table.
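Concretely, a task opts in through the three properties the mixin exposes; a
minimal sketch (the task name and column names here are illustrative, not part
of this patch):

```python
class MyCopyTask(redshift.S3CopyToTable):
    table = 'my_table'

    @property
    def enable_metadata_columns(self):
        # opt in; defaults to False
        return True

    @property
    def metadata_columns(self):
        # (name, type[, encoding]) tuples, added to the table when missing
        return [('created_tz', 'TIMESTAMP')]

    @property
    def metadata_queries(self):
        # run after the copy to fill the metadata columns
        return ['UPDATE {0} '
                'SET created_tz = CURRENT_TIMESTAMP '
                'WHERE created_tz IS NULL'.format(self.table)]
```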
---
 luigi/contrib/postgres.py     |   2 +
 luigi/contrib/rdbms.py        | 123 +++++++++++++++-
 luigi/contrib/redshift.py     |  14 ++
 test/contrib/postgres_test.py |  52 +++++++
 test/contrib/rdbms_test.py    | 255 ++++++++++++++++++++++++++++++++++
 test/contrib/redshift_test.py |  92 ++++++++++++
 6 files changed, 537 insertions(+), 1 deletion(-)
 create mode 100644 test/contrib/rdbms_test.py

diff --git a/luigi/contrib/postgres.py b/luigi/contrib/postgres.py
index ec8b72df16..363cde70b0 100644
--- a/luigi/contrib/postgres.py
+++ b/luigi/contrib/postgres.py
@@ -326,6 +326,8 @@ def run(self):
             self.init_copy(connection)
             self.copy(cursor, tmp_file)
             self.post_copy(connection)
+            if self.enable_metadata_columns:
+                self.post_copy_metacolumns(cursor)
         except psycopg2.ProgrammingError as e:
             if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE and attempt == 0:
                 # if first attempt fails with "relation not found", try creating table
diff --git a/luigi/contrib/rdbms.py b/luigi/contrib/rdbms.py
index d7b82d4c9d..a955c87275 100644
--- a/luigi/contrib/rdbms.py
+++ b/luigi/contrib/rdbms.py
@@ -27,7 +27,125 @@
 logger = logging.getLogger('luigi-interface')
 
 
-class CopyToTable(luigi.task.MixinNaiveBulkComplete, luigi.Task):
+class _MetadataColumnsMixin(object):
+    """Provide an additional behavior that adds columns and values to tables
+
+    This mixin is used to provide an additional behavior that allows a task to
+    add generic metadata columns to every table created, for both PSQL and
+    Redshift.
+
+    Example:
+
+    This is a use-case example of how this mixin could come in handy and how
+    to use it.
+
+    .. code:: python
+
+        class CommonMetaColumnsBehavior(object):
+            def update_report_execution_date_query(self):
+                query = "UPDATE {0} " \
+                        "SET date_param = DATE '{1}' " \
+                        "WHERE date_param IS NULL".format(self.table, self.date)
+
+                return query
+
+            @property
+            def metadata_columns(self):
+                cols = []
+                if self.date:
+                    cols.append(('date_param', 'VARCHAR'))
+
+                return cols
+
+            @property
+            def metadata_queries(self):
+                queries = [self.update_created_tz_query()]
+                if self.date:
+                    queries.append(self.update_report_execution_date_query())
+
+                return queries
+
+
+        class RedshiftCopyCSVToTableFromS3(CommonMetaColumnsBehavior, redshift.S3CopyToTable):
+            "We have some business override here that would only add noise to the
+            example, so let's assume that this is only a shell."
+            pass
+
+
+        class UpdateTableA(RedshiftCopyCSVToTableFromS3):
+            date = luigi.Parameter()
+            table = 'tableA'
+
+            def queries(self):
+                return [query_content_for('/queries/deduplicate_dupes.sql')]
+
+
+        class UpdateTableB(RedshiftCopyCSVToTableFromS3):
+            date = luigi.Parameter()
+            table = 'tableB'
+    """
+    @property
+    def metadata_columns(self):
+        """Returns the default metadata columns.
+
+        These are the columns that we want each table to have by default.
+        """
+        return []
+
+    @property
+    def metadata_queries(self):
+        return []
+
+    @property
+    def enable_metadata_columns(self):
+        return False
+
+    def _add_metadata_columns(self, connection):
+        cursor = connection.cursor()
+
+        for column in self.metadata_columns:
+            if len(column) == 0:
+                raise ValueError("_add_metadata_columns is unable to infer column information from column {column} for {table}".format(column=column,
+                                                                                                                                       table=self.table))
+
+            column_name = column[0]
+            if not self._column_exists(cursor, column_name):
+                logger.info('Adding missing metadata column {column} to {table}'.format(column=column, table=self.table))
+                self._add_column_to_table(cursor, column)
+
+    def _column_exists(self, cursor, column_name):
+        if '.' 
in self.table:
+            schema, table = self.table.split('.')
+            query = "SELECT 1 AS column_exists " \
+                    "FROM information_schema.columns " \
+                    "WHERE table_schema = LOWER('{0}') AND table_name = LOWER('{1}') AND column_name = LOWER('{2}') LIMIT 1;".format(schema, table, column_name)
+        else:
+            query = "SELECT 1 AS column_exists " \
+                    "FROM information_schema.columns " \
+                    "WHERE table_name = LOWER('{0}') AND column_name = LOWER('{1}') LIMIT 1;".format(self.table, column_name)
+
+        cursor.execute(query)
+        result = cursor.fetchone()
+        return bool(result)
+
+    def _add_column_to_table(self, cursor, column):
+        if len(column) == 1:
+            raise ValueError("_add_column_to_table() column type not specified for {column}".format(column=column[0]))
+        elif len(column) == 2:
+            query = "ALTER TABLE {table} ADD COLUMN {column};".format(table=self.table, column=' '.join(column))
+        elif len(column) == 3:
+            query = "ALTER TABLE {table} ADD COLUMN {column} ENCODE {encoding};".format(table=self.table, column=' '.join(column[0:2]), encoding=column[2])
+        else:
+            raise ValueError("_add_column_to_table() found no matching behavior for {column}".format(column=column))
+
+        cursor.execute(query)
+
+    def post_copy_metacolumns(self, cursor):
+        logger.info('Executing post copy metadata queries')
+        for query in self.metadata_queries:
+            cursor.execute(query)
+
+
+class CopyToTable(luigi.task.MixinNaiveBulkComplete, _MetadataColumnsMixin, luigi.Task):
     """
     An abstract task for inserting a data set into RDBMS.
 
@@ -120,6 +238,9 @@ def init_copy(self, connection):
         if hasattr(self, "clear_table"):
             raise Exception("The clear_table attribute has been removed. Override init_copy instead!")
 
+        if self.enable_metadata_columns:
+            self._add_metadata_columns(connection)
+
     def post_copy(self, connection):
         """
         Override to perform custom queries.
diff --git a/luigi/contrib/redshift.py b/luigi/contrib/redshift.py
index a1c5c43172..6792efff09 100644
--- a/luigi/contrib/redshift.py
+++ b/luigi/contrib/redshift.py
@@ -373,6 +373,9 @@ def run(self):
         self.copy(cursor, path)
         self.post_copy(cursor)
 
+        if self.enable_metadata_columns:
+            self.post_copy_metacolumns(cursor)
+
         # update marker table
         output.touch(connection)
         connection.commit()
@@ -472,6 +475,9 @@ def init_copy(self, connection):
             logger.info("Creating table %s", self.table)
             self.create_table(connection)
 
+        if self.enable_metadata_columns:
+            self._add_metadata_columns(connection)
+
         if self.do_truncate_table:
             logger.info("Truncating table %s", self.table)
             self.truncate_table(connection)
@@ -488,6 +494,14 @@ def post_copy(self, cursor):
         for query in self.queries:
             cursor.execute(query)
 
+    def post_copy_metacolumns(self, cursor):
+        """
+        Performs post-copy queries to fill the metadata columns.
+        """
+        logger.info('Executing post copy metadata queries')
+        for query in self.metadata_queries:
+            cursor.execute(query)
+
 
 class S3CopyJSONToTable(S3CopyToTable, _CredentialsMixin):
     """
diff --git a/test/contrib/postgres_test.py b/test/contrib/postgres_test.py
index 5df6888343..eadd6d1018 100644
--- a/test/contrib/postgres_test.py
+++ b/test/contrib/postgres_test.py
@@ -121,3 +121,55 @@ def test_bulk_complete(self, mock_connect):
             'DummyPostgresQuery_2015_01_06_f91a47ec40',
         ])
         self.assertFalse(task.complete())
+
+
+@attr('postgres')
+class TestCopyToTableWithMetaColumns(unittest.TestCase):
+    @mock.patch("luigi.contrib.postgres.CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True)
+    @mock.patch("luigi.contrib.postgres.CopyToTable._add_metadata_columns")
+    @mock.patch("luigi.contrib.postgres.CopyToTable.post_copy_metacolumns")
+    @mock.patch("luigi.contrib.postgres.CopyToTable.rows", return_value=['row1', 'row2'])
+    @mock.patch("luigi.contrib.postgres.PostgresTarget")
+    @mock.patch('psycopg2.connect')
+    def test_copy_with_metadata_columns_enabled(self,
+                                                mock_connect,
+                                                mock_redshift_target,
+                                                mock_rows,
+                                                mock_add_columns,
+                                                mock_update_columns,
+                                                mock_metadata_columns_enabled):
+
+        task = DummyPostgresImporter(date=datetime.datetime(1991, 3, 24))
+
+        mock_cursor = MockPostgresCursor([task.task_id])
+        mock_connect.return_value.cursor.return_value = mock_cursor
+
+        task = DummyPostgresImporter(date=datetime.datetime(1991, 3, 24))
+        task.run()
+
+        self.assertTrue(mock_add_columns.called)
+        self.assertTrue(mock_update_columns.called)
+
+    @mock.patch("luigi.contrib.postgres.CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=False)
+    @mock.patch("luigi.contrib.postgres.CopyToTable._add_metadata_columns")
+    @mock.patch("luigi.contrib.postgres.CopyToTable.post_copy_metacolumns")
+    @mock.patch("luigi.contrib.postgres.CopyToTable.rows", return_value=['row1', 'row2'])
+    @mock.patch("luigi.contrib.postgres.PostgresTarget")
+    @mock.patch('psycopg2.connect')
+    def test_copy_with_metadata_columns_disabled(self,
+                                                 mock_connect,
+                                                 mock_redshift_target,
+                                                 mock_rows,
+                                                 mock_add_columns,
+                                                 mock_update_columns,
+                                                 mock_metadata_columns_enabled):
+
+        task = DummyPostgresImporter(date=datetime.datetime(1991, 3, 24))
+
+        mock_cursor = MockPostgresCursor([task.task_id])
+        mock_connect.return_value.cursor.return_value = mock_cursor
+
+        task.run()
+
+        self.assertFalse(mock_add_columns.called)
+        self.assertFalse(mock_update_columns.called)
diff --git a/test/contrib/rdbms_test.py b/test/contrib/rdbms_test.py
new file mode 100644
index 0000000000..3127cb2e8d
--- /dev/null
+++ b/test/contrib/rdbms_test.py
@@ -0,0 +1,255 @@
+# Copyright 2012-2015 Spotify AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+We're using Redshift as the test bed since Redshift implements RDBMS. We could
+have opted for PSQL but we're less familiar with that contrib and there are
+fewer examples on how to test it. 
+""" + +import luigi +import luigi.contrib.redshift +import mock + +import unittest + + +# Fake AWS and S3 credentials taken from `../redshift_test.py`. +AWS_ACCESS_KEY = 'key' +AWS_SECRET_KEY = 'secret' + +AWS_ACCOUNT_ID = '0123456789012' +AWS_ROLE_NAME = 'MyRedshiftRole' + +BUCKET = 'bucket' +KEY = 'key' + + +class DummyS3CopyToTableBase(luigi.contrib.redshift.S3CopyToTable): + # Class attributes taken from `DummyPostgresImporter` in + # `../postgres_test.py`. + host = 'dummy_host' + database = 'dummy_database' + user = 'dummy_user' + password = 'dummy_password' + table = luigi.Parameter(default='dummy_table') + columns = luigi.TupleParameter( + default=( + ('some_text', 'varchar(255)'), + ('some_int', 'int'), + ) + ) + + copy_options = '' + prune_table = '' + prune_column = '' + prune_date = '' + + def s3_load_path(self): + return 's3://%s/%s' % (BUCKET, KEY) + + +class DummyS3CopyToTableKey(DummyS3CopyToTableBase): + aws_access_key_id = AWS_ACCESS_KEY + aws_secret_access_key = AWS_SECRET_KEY + + +class TestS3CopyToTableWithMetaColumns(unittest.TestCase): + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[('created_tz', 'TIMESTAMP')]) + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_check_meta_columns_to_table_if_exists(self, + mock_redshift_target, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey(table='my_test_table') + task.run() + + mock_cursor = (mock_redshift_target.return_value + .connect + .return_value + .cursor + .return_value) + + executed_query = mock_cursor.execute.call_args_list[1][0][0] + + expected_output = "SELECT 1 AS column_exists FROM information_schema.columns " \ + "WHERE table_name = LOWER('{table}') " \ + "AND column_name = LOWER('{column}') " \ + "LIMIT 1;".format(table='my_test_table', column='created_tz') + + self.assertEqual(executed_query, expected_output) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[('created_tz', 'TIMESTAMP')]) + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_check_meta_columns_to_schematable_if_exists(self, + mock_redshift_target, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey(table='test.my_test_table') + task.run() + + mock_cursor = (mock_redshift_target.return_value + .connect + .return_value + .cursor + .return_value) + + executed_query = mock_cursor.execute.call_args_list[2][0][0] + + expected_output = "SELECT 1 AS column_exists FROM information_schema.columns " \ + "WHERE table_schema = LOWER('{schema}') " \ + "AND table_name = LOWER('{table}') " \ + "AND column_name = LOWER('{column}') " \ + "LIMIT 1;".format(schema='test', table='my_test_table', column='created_tz') + + self.assertEqual(executed_query, expected_output) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[('created_tz', 'TIMESTAMP')]) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=True) + 
@mock.patch("luigi.contrib.redshift.S3CopyToTable._add_column_to_table") + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_not_add_if_meta_columns_already_exists(self, + mock_redshift_target, + mock_add_to_table, + mock_columns_exists, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey() + task.run() + + self.assertFalse(mock_add_to_table.called) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[('created_tz', 'TIMESTAMP')]) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_column_to_table") + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_add_if_meta_columns_not_already_exists(self, + mock_redshift_target, + mock_add_to_table, + mock_columns_exists, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey() + task.run() + + self.assertTrue(mock_add_to_table.called) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[('created_tz', 'TIMESTAMP')]) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_add_regular_column(self, + mock_redshift_target, + mock_columns_exists, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey(table='my_test_table') + task.run() + + mock_cursor = (mock_redshift_target.return_value + .connect + .return_value + .cursor + .return_value) + + executed_query = mock_cursor.execute.call_args_list[1][0][0] + + expected_output = "ALTER TABLE {table} " \ + "ADD COLUMN {column} {type};".format(table='my_test_table', column='created_tz', type='TIMESTAMP') + + self.assertEqual(executed_query, expected_output) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[('created_tz', 'TIMESTAMP', 'bytedict')]) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_add_encoded_column(self, + mock_redshift_target, + mock_columns_exists, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey(table='my_test_table') + task.run() + + mock_cursor = (mock_redshift_target.return_value + .connect + .return_value + .cursor + .return_value) + + executed_query = mock_cursor.execute.call_args_list[1][0][0] + + expected_output = "ALTER TABLE {table} " \ + "ADD COLUMN {column} {type} ENCODE {encoding};".format(table='my_test_table', column='created_tz', + type='TIMESTAMP', + encoding='bytedict') + + self.assertEqual(executed_query, expected_output) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[('created_tz')]) + 
@mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_raise_error_on_no_column_type(self, + mock_redshift_target, + mock_columns_exists, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey() + + with self.assertRaises(ValueError): + task.run() + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, + return_value=[('created_tz', 'TIMESTAMP', 'bytedict', '42')]) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_raise_error_on_invalid_column(self, + mock_redshift_target, + mock_columns_exists, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey() + + with self.assertRaises(ValueError): + task.run() + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_queries", new_callable=mock.PropertyMock, return_value=['SELECT 1 FROM X', 'SELECT 2 FROM Y']) + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_post_copy_metacolumns(self, + mock_redshift_target, + mock_metadata_queries, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey() + task.run() + + mock_cursor = (mock_redshift_target.return_value + .connect + .return_value + .cursor + .return_value) + + executed_query = mock_cursor.execute.call_args_list[2][0][0] + expected_output = "SELECT 1 FROM X" + self.assertEqual(executed_query, expected_output) + + executed_query = mock_cursor.execute.call_args_list[3][0][0] + expected_output = "SELECT 2 FROM Y" + self.assertEqual(executed_query, expected_output) diff --git a/test/contrib/redshift_test.py b/test/contrib/redshift_test.py index c6b23bf2b1..5433c6d186 100644 --- a/test/contrib/redshift_test.py +++ b/test/contrib/redshift_test.py @@ -80,6 +80,36 @@ def s3_load_path(self): return 's3://%s/%s' % (BUCKET, KEY) +class DummyS3CopyJSONToTableBase(luigi.contrib.redshift.S3CopyJSONToTable): + # Class attributes taken from `DummyPostgresImporter` in + # `../postgres_test.py`. 
+ aws_access_key_id = AWS_ACCESS_KEY + aws_secret_access_key = AWS_SECRET_KEY + + host = 'dummy_host' + database = 'dummy_database' + user = 'dummy_user' + password = 'dummy_password' + table = luigi.Parameter(default='dummy_table') + columns = luigi.TupleParameter( + default=( + ('some_text', 'varchar(255)'), + ('some_int', 'int'), + ) + ) + + copy_options = '' + prune_table = '' + prune_column = '' + prune_date = '' + + jsonpath = '' + copy_json_options = '' + + def s3_load_path(self): + return 's3://%s/%s' % (BUCKET, KEY) + + class DummyS3CopyToTableKey(DummyS3CopyToTableBase): aws_access_key_id = AWS_ACCESS_KEY aws_secret_access_key = AWS_SECRET_KEY @@ -130,6 +160,68 @@ def test_from_config(self): self.assertEqual(self.aws_secret_access_key, "config_secret") +class TestS3CopyToTableWithMetaColumns(unittest.TestCase): + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_metadata_columns") + @mock.patch("luigi.contrib.redshift.S3CopyToTable.post_copy_metacolumns") + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_with_metadata_columns_enabled(self, + mock_redshift_target, + mock_add_columns, + mock_update_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey() + task.run() + + self.assertTrue(mock_add_columns.called) + self.assertTrue(mock_update_columns.called) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=False) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_metadata_columns") + @mock.patch("luigi.contrib.redshift.S3CopyToTable.post_copy_metacolumns") + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_with_metadata_columns_disabled(self, + mock_redshift_target, + mock_add_columns, + mock_update_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey() + task.run() + + self.assertFalse(mock_add_columns.called) + self.assertFalse(mock_update_columns.called) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_metadata_columns") + @mock.patch("luigi.contrib.redshift.S3CopyToTable.post_copy_metacolumns") + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_json_copy_with_metadata_columns_enabled(self, + mock_redshift_target, + mock_add_columns, + mock_update_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyJSONToTableBase() + task.run() + + self.assertTrue(mock_add_columns.called) + self.assertTrue(mock_update_columns.called) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=False) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_metadata_columns") + @mock.patch("luigi.contrib.redshift.S3CopyToTable.post_copy_metacolumns") + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_json_copy_with_metadata_columns_disabled(self, + mock_redshift_target, + mock_add_columns, + mock_update_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyJSONToTableBase() + task.run() + + self.assertFalse(mock_add_columns.called) + self.assertFalse(mock_update_columns.called) + + class TestS3CopyToTable(unittest.TestCase): @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_copy_missing_creds(self, mock_redshift_target): From 
bd55c28d2572ee53018e8276e10913658ae44027 Mon Sep 17 00:00:00 2001
From: Stas Glubokiy
Date: Mon, 6 Aug 2018 14:02:36 +0300
Subject: [PATCH 08/13] Add support for multiple requires and inherits arguments (#2475)

Add support for multiple task arguments for requires and inherits decorators

---
 luigi/util.py          | 83 ++++++++++++++++++++++++++++------------
 test/decorator_test.py | 87 ++++++++++++++++++++++++++++------------
 2 files changed, 121 insertions(+), 49 deletions(-)

diff --git a/luigi/util.py b/luigi/util.py
index f6be20021b..784a778602 100644
--- a/luigi/util.py
+++ b/luigi/util.py
@@ -52,7 +52,7 @@ def requires(self):
 more burdensome than the last. Refactoring becomes more difficult. There
 are several ways one might try and avoid the problem.
 
-**Approach 1**: Parameters via command line or config instead of ``requires``.
+**Approach 1**: Parameters via command line or config instead of :func:`~luigi.task.Task.requires`.
 
 .. code-block:: python
 
@@ -132,13 +132,13 @@ def requires(self):
 specified in the wrong order. This contrived example is easy to fix (by
 swapping the ordering of the parents of ``TaskA``), but real world cases
 can be more difficult to both spot and fix. Inheriting from multiple classes
-derived from ``luigi.Task`` should be undertaken with caution and avoided
+derived from :class:`~luigi.task.Task` should be undertaken with caution and avoided
 where possible.
 
-**Approach 3**: Use ``inherits`` and ``requires``
+**Approach 3**: Use :class:`~luigi.util.inherits` and :class:`~luigi.util.requires`
 
-The ``inherits`` class decorator in this module copies parameters (and
+The :class:`~luigi.util.inherits` class decorator in this module copies parameters (and
 nothing else) from one task class to another, and avoids direct pythonic
 inheritance.
 
@@ -185,11 +185,12 @@ def requires(self):
 issues, and keeps the task command line interface as simple (as it can be,
 anyway). Refactoring task parameters is also much easier.
 
-The ``requires`` helper function can reduce this pattern even further. It
-does everything ``inherits`` does, and also attaches a ``requires`` method
+The :class:`~luigi.util.requires` helper function can reduce this pattern even further. It
+does everything :class:`~luigi.util.inherits` does,
+and also attaches a :class:`~luigi.util.requires` method
 to your task (still all without pythonic inheritance).
 
-But how does it know how to invoke the upstream task? It uses ``clone``
+But how does it know how to invoke the upstream task? It uses :func:`~luigi.task.Task.clone`
 behind the scenes!
 
 .. code-block:: python
 
@@ -251,59 +252,91 @@ class inherits(object):
     """
     Task inheritance.
 
+    *New after Luigi 2.7.6:* multiple arguments support.
+
    Usage:
 
    .. code-block:: python
 
        class AnotherTask(luigi.Task):
+            m = luigi.IntParameter()
+
+        class YetAnotherTask(luigi.Task):
            n = luigi.IntParameter()
-            # ...
 
        @inherits(AnotherTask)
-        class MyTask(luigi.Task):
+        class MyFirstTask(luigi.Task):
            def requires(self):
               return self.clone_parent()
 
+            def run(self):
+                print self.m # this will be defined
+                # ...
+
+        @inherits(AnotherTask, YetAnotherTask)
+        class MySecondTask(luigi.Task):
+            def requires(self):
+                return self.clone_parents()
+
            def run(self):
               print self.n # this will be defined
               # ... 
""" - def __init__(self, task_to_inherit): + def __init__(self, *tasks_to_inherit): super(inherits, self).__init__() - self.task_to_inherit = task_to_inherit + if not tasks_to_inherit: + raise TypeError("tasks_to_inherit cannot be empty") + + self.tasks_to_inherit = tasks_to_inherit def __call__(self, task_that_inherits): - # Get all parameter objects from the underlying task - for param_name, param_obj in self.task_to_inherit.get_params(): - # Check if the parameter exists in the inheriting task - if not hasattr(task_that_inherits, param_name): - # If not, add it to the inheriting task - setattr(task_that_inherits, param_name, param_obj) + # Get all parameter objects from each of the underlying tasks + for task_to_inherit in self.tasks_to_inherit: + for param_name, param_obj in task_to_inherit.get_params(): + # Check if the parameter exists in the inheriting task + if not hasattr(task_that_inherits, param_name): + # If not, add it to the inheriting task + setattr(task_that_inherits, param_name, param_obj) # Modify task_that_inherits by adding methods - def clone_parent(_self, **args): - return _self.clone(cls=self.task_to_inherit, **args) + def clone_parent(_self, **kwargs): + return _self.clone(cls=self.tasks_to_inherit[0], **kwargs) task_that_inherits.clone_parent = clone_parent + def clone_parents(_self, **kwargs): + return [ + _self.clone(cls=task_to_inherit, **kwargs) + for task_to_inherit in self.tasks_to_inherit + ] + task_that_inherits.clone_parents = clone_parents + return task_that_inherits class requires(object): """ - Same as @inherits, but also auto-defines the requires method. + Same as :class:`~luigi.util.inherits`, but also auto-defines the requires method. + + *New after Luigi 2.7.6:* multiple arguments support. + """ - def __init__(self, task_to_require): + def __init__(self, *tasks_to_require): super(requires, self).__init__() - self.inherit_decorator = inherits(task_to_require) + if not tasks_to_require: + raise TypeError("tasks_to_require cannot be empty") + + self.tasks_to_require = tasks_to_require def __call__(self, task_that_requires): - task_that_requires = self.inherit_decorator(task_that_requires) + task_that_requires = inherits(*self.tasks_to_require)(task_that_requires) - # Modify task_that_requres by adding methods + # Modify task_that_requires by adding requires method. + # If only one task is required, this single task is returned. 
+ # Otherwise, list of tasks is returned def requires(_self): - return _self.clone_parent() + return _self.clone_parent() if len(self.tasks_to_require) == 1 else _self.clone_parents() task_that_requires.requires = requires return task_that_requires diff --git a/test/decorator_test.py b/test/decorator_test.py index 0e113caaf6..e9851a269e 100644 --- a/test/decorator_test.py +++ b/test/decorator_test.py @@ -53,9 +53,14 @@ class D_null(luigi.Task): param1 = None +@inherits(A, B) +class E(luigi.Task): + param4 = luigi.Parameter("class E-specific default") + + @inherits(A) @inherits(B) -class E(luigi.Task): +class E_stacked(luigi.Task): param4 = luigi.Parameter("class E-specific default") @@ -69,6 +74,7 @@ def setUp(self): self.d = D() self.d_null = D_null() self.e = E() + self.e_stacked = E_stacked() def test_has_param(self): b_params = dict(self.b.get_params()).keys() @@ -91,11 +97,22 @@ def test_overwriting_defaults(self): self.assertNotEqual(self.d.param1, self.a.param1) self.assertEqual(self.d.param1, "class D overwriting class A's default") - def test_stacked_inheritance(self): + def test_multiple_inheritance(self): self.assertEqual(self.e.param1, self.a.param1) self.assertEqual(self.e.param1, self.b.param1) self.assertEqual(self.e.param2, self.b.param2) + def test_stacked_inheritance(self): + self.assertEqual(self.e_stacked.param1, self.a.param1) + self.assertEqual(self.e_stacked.param1, self.b.param1) + self.assertEqual(self.e_stacked.param2, self.b.param2) + + def test_empty_inheritance(self): + with self.assertRaises(TypeError): + @inherits() + class shouldfail(luigi.Task): + pass + def test_removing_parameter(self): self.assertFalse("param1" in dict(self.d_null.get_params()).keys()) @@ -226,53 +243,75 @@ def test_wrong_common_params_order(self): self.assertRaises(TypeError, self.k_wrongparamsorder.requires) -class X(luigi.Task): +class V(luigi.Task): n = luigi.IntParameter(default=42) -@inherits(X) -class Y(luigi.Task): +@inherits(V) +class W(luigi.Task): def requires(self): return self.clone_parent() -@requires(X) -class Y2(luigi.Task): +@requires(V) +class W2(luigi.Task): pass -@requires(X) -class Y3(luigi.Task): +@requires(V) +class W3(luigi.Task): n = luigi.IntParameter(default=43) +class X(luigi.Task): + m = luigi.IntParameter(default=56) + + +@requires(V, X) +class Y(luigi.Task): + pass + + class CloneParentTest(unittest.TestCase): def test_clone_parent(self): - y = Y() - x = X() - self.assertEqual(y.requires(), x) - self.assertEqual(y.n, 42) + w = W() + v = V() + self.assertEqual(w.requires(), v) + self.assertEqual(w.n, 42) def test_requires(self): - y2 = Y2() - x = X() - self.assertEqual(y2.requires(), x) - self.assertEqual(y2.n, 42) + w2 = W2() + v = V() + self.assertEqual(w2.requires(), v) + self.assertEqual(w2.n, 42) def test_requires_override_default(self): - y3 = Y3() + w3 = W3() + v = V() + self.assertNotEqual(w3.requires(), v) + self.assertEqual(w3.n, 43) + self.assertEqual(w3.requires().n, 43) + + def test_multiple_requires(self): + y = Y() + v = V() x = X() - self.assertNotEqual(y3.requires(), x) - self.assertEqual(y3.n, 43) - self.assertEqual(y3.requires().n, 43) + self.assertEqual(y.requires()[0], v) + self.assertEqual(y.requires()[1], x) + + def test_empty_requires(self): + with self.assertRaises(TypeError): + @requires() + class shouldfail(luigi.Task): + pass def test_names(self): # Just make sure the decorators retain the original class names - x = X() - self.assertEqual(str(x), 'X(n=42)') - self.assertEqual(x.__class__.__name__, 'X') + v = V() + 
self.assertEqual(str(v), 'V(n=42)')
+        self.assertEqual(v.__class__.__name__, 'V')
 
 
 class P(luigi.Task):

From c9ed761a42aef0c7dfcd526cc70d5bf73bfed0e3 Mon Sep 17 00:00:00 2001
From: Nick
Date: Wed, 8 Aug 2018 08:58:26 +0300
Subject: [PATCH 09/13] Add a visibility level for luigi.Parameters (#2278)

See the docs for usage.
---
 doc/parameters.rst | 19 +++
 luigi/parameter.py | 42 +++++-
 luigi/scheduler.py | 30 +++--
 luigi/task.py | 18 ++-
 luigi/worker.py | 3 +
 setup.py | 4 +
 test/db_task_history_test.py | 4 +-
 test/scheduler_parameter_visibilities_test.py | 120 ++++++++++++++++++
 test/visible_parameters_test.py | 95 ++++++++++++++
 tox.ini | 1 +
 10 files changed, 319 insertions(+), 17 deletions(-)
 create mode 100644 test/scheduler_parameter_visibilities_test.py
 create mode 100644 test/visible_parameters_test.py

diff --git a/doc/parameters.rst b/doc/parameters.rst
index 1a4a8a721b..6dca716c30 100644
--- a/doc/parameters.rst
+++ b/doc/parameters.rst
@@ -88,6 +88,25 @@ are not the same instance:
 >>> hash(c) == hash(d)
 True
 
+Parameter visibility
+^^^^^^^^^^^^^^^^^^^^
+
+Using :class:`~luigi.parameter.ParameterVisibility`, you can configure parameter visibility. By default, all
+parameters are public, but you can also set them to hidden or private.
+
+.. code:: python
+
+    >>> import luigi
+    >>> from luigi.parameter import ParameterVisibility
+
+    >>> luigi.Parameter(visibility=ParameterVisibility.PRIVATE)
+
+``ParameterVisibility.PUBLIC`` (default) - visible everywhere
+
+``ParameterVisibility.HIDDEN`` - hidden in the web UI, but saved to the task history database if task history is enabled
+
+``ParameterVisibility.PRIVATE`` - visible only inside the task.
+
 Parameter types
 ^^^^^^^^^^^^^^^
 
diff --git a/luigi/parameter.py b/luigi/parameter.py
index f619864090..4c4c3853a0 100644
--- a/luigi/parameter.py
+++ b/luigi/parameter.py
@@ -23,6 +23,7 @@
 import abc
 import datetime
 import warnings
+from enum import IntEnum
 import json
 from json import JSONEncoder
 from collections import OrderedDict, Mapping
@@ -40,10 +41,26 @@
 from luigi import configuration
 from luigi.cmdline_parser import CmdlineParser
 
-
 _no_value = object()
 
 
+class ParameterVisibility(IntEnum):
+    """
+    Possible values for the parameter visibility option. Public is the default.
+    See :doc:`/parameters` for more info.
+    """
+    PUBLIC = 0
+    HIDDEN = 1
+    PRIVATE = 2
+
+    @classmethod
+    def has_value(cls, value):
+        return any(value == item.value for item in cls)
+
+    def serialize(self):
+        return self.value
+
+
 class ParameterException(Exception):
     """
     Base exception.
@@ -113,7 +130,8 @@ def run(self):
     _counter = 0  # non-atomically increasing counter used for ordering parameters.
 
     def __init__(self, default=_no_value, is_global=False, significant=True, description=None,
-                 config_path=None, positional=True, always_in_help=False, batch_method=None):
+                 config_path=None, positional=True, always_in_help=False, batch_method=None,
+                 visibility=ParameterVisibility.PUBLIC):
         """
         :param default: the default value for this parameter. This should match the type of the
                         Parameter, i.e. ``datetime.date`` for ``DateParameter`` or ``int`` for
@@ -140,6 +158,10 @@ def __init__(self, default=_no_value, is_global=False, significant=True, descrip
                              parameter values into a single value. Used
                              when receiving batched parameter lists from
                              the scheduler. See :ref:`batch_method`
+
+        :param visibility: A :py:class:`~luigi.parameter.ParameterVisibility` value specifying the visibility of this parameter.
+ Default value is ParameterVisibility.PUBLIC + """ self._default = default self._batch_method = batch_method @@ -150,6 +172,7 @@ def __init__(self, default=_no_value, is_global=False, significant=True, descrip positional = False self.significant = significant # Whether different values for this parameter will differentiate otherwise equal tasks self.positional = positional + self.visibility = visibility if ParameterVisibility.has_value(visibility) else ParameterVisibility.PUBLIC self.description = description self.always_in_help = always_in_help @@ -195,11 +218,11 @@ def _value_iterator(self, task_name, param_name): yield (self._get_value_from_config(task_name, param_name), None) yield (self._get_value_from_config(task_name, param_name.replace('_', '-')), 'Configuration [{}] {} (with dashes) should be avoided. Please use underscores.'.format( - task_name, param_name)) + task_name, param_name)) if self._config_path: yield (self._get_value_from_config(self._config_path['section'], self._config_path['name']), 'The use of the configuration [{}] {} is deprecated. Please use [{}] {}'.format( - self._config_path['section'], self._config_path['name'], task_name, param_name)) + self._config_path['section'], self._config_path['name'], task_name, param_name)) yield (self._default, None) def has_task_value(self, task_name, param_name): @@ -689,6 +712,7 @@ class DateIntervalParameter(Parameter): (eg. "2015-W35"). In addition, it also supports arbitrary date intervals provided as two dates separated with a dash (eg. "2015-11-04-2015-12-04"). """ + def parse(self, s): """ Parses a :py:class:`~luigi.date_interval.DateInterval` from the input. @@ -740,8 +764,10 @@ def field(key): def optional_field(key): return "(%s)?" % field(key) + # A little loose: ISO 8601 does not allow weeks in combination with other fields, but this regex does (as does python timedelta) - regex = "P(%s|%s(T%s)?)" % (field("weeks"), optional_field("days"), "".join([optional_field(key) for key in ["hours", "minutes", "seconds"]])) + regex = "P(%s|%s(T%s)?)" % (field("weeks"), optional_field("days"), + "".join([optional_field(key) for key in ["hours", "minutes", "seconds"]])) return self._apply_regex(regex, input) def _parseSimple(self, input): @@ -905,6 +931,7 @@ class _DictParamEncoder(JSONEncoder): """ JSON encoder for :py:class:`~DictParameter`, which makes :py:class:`~_FrozenOrderedDict` JSON serializable. """ + def default(self, obj): if isinstance(obj, _FrozenOrderedDict): return obj.get_wrapped() @@ -943,6 +970,7 @@ def run(self): tags, that are dynamically constructed outside Luigi), or you have a complex parameter containing logically related values (like a database connection config). """ + def normalize(self, value): """ Ensure that dictionary parameter is converted to a _FrozenOrderedDict so it can be hashed. @@ -996,6 +1024,7 @@ def run(self): $ luigi --module my_tasks MyTask --grades '[100,70]' """ + def normalize(self, x): """ Ensure that struct is recursively converted to a tuple so it can be hashed. @@ -1053,6 +1082,7 @@ def run(self): $ luigi --module my_tasks MyTask --book_locations '((12,3),(4,15),(52,1))' """ + def parse(self, x): """ Parse an individual value from the input. @@ -1100,6 +1130,7 @@ class MyTask(luigi.Task): $ luigi --module my_tasks MyTask --my-param-1 -3 --my-param-2 -2 """ + def __init__(self, left_op=operator.le, right_op=operator.lt, *args, **kwargs): """ :param function var_type: The type of the input variable, e.g. int or float. 
@@ -1178,6 +1209,7 @@ class MyTask(luigi.Task): same type and transparency of parameter value on the command line is desired. """ + def __init__(self, var_type=str, *args, **kwargs): """ :param function var_type: The type of the input variable, e.g. str, int, diff --git a/luigi/scheduler.py b/luigi/scheduler.py index b7993c760b..fbc01a838d 100644 --- a/luigi/scheduler.py +++ b/luigi/scheduler.py @@ -49,6 +49,7 @@ from luigi.task_status import DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN, \ BATCH_RUNNING from luigi.task import Config +from luigi.parameter import ParameterVisibility logger = logging.getLogger(__name__) @@ -280,7 +281,7 @@ def __eq__(self, other): class Task(object): def __init__(self, task_id, status, deps, resources=None, priority=0, family='', module=None, - params=None, accepts_messages=False, tracking_url=None, status_message=None, + params=None, param_visibilities=None, accepts_messages=False, tracking_url=None, status_message=None, progress_percentage=None, retry_policy='notoptional'): self.id = task_id self.stakeholders = set() # workers ids that are somehow related to this task (i.e. don't prune while any of these workers are still active) @@ -301,8 +302,11 @@ def __init__(self, task_id, status, deps, resources=None, priority=0, family='', self.resources = _get_default(resources, {}) self.family = family self.module = module - self.params = _get_default(params, {}) - + self.param_visibilities = _get_default(param_visibilities, {}) + self.params = {} + self.public_params = {} + self.hidden_params = {} + self.set_params(params) self.accepts_messages = accepts_messages self.retry_policy = retry_policy self.failures = Failures(self.retry_policy.disable_window) @@ -318,6 +322,13 @@ def __init__(self, task_id, status, deps, resources=None, priority=0, family='', def __repr__(self): return "Task(%r)" % vars(self) + def set_params(self, params): + self.params = _get_default(params, {}) + self.public_params = {key: value for key, value in self.params.items() if + self.param_visibilities.get(key, ParameterVisibility.PUBLIC) == ParameterVisibility.PUBLIC} + self.hidden_params = {key: value for key, value in self.params.items() if + self.param_visibilities.get(key, ParameterVisibility.PUBLIC) == ParameterVisibility.HIDDEN} + # TODO(2017-08-10) replace this function with direct calls to batchable # this only exists for backward compatibility def is_batchable(self): @@ -343,7 +354,7 @@ def has_excessive_failures(self): @property def pretty_id(self): - param_str = ', '.join(u'{}={}'.format(key, value) for key, value in sorted(self.params.items())) + param_str = ', '.join(u'{}={}'.format(key, value) for key, value in sorted(self.public_params.items())) return u'{}({})'.format(self.family, param_str) @@ -778,7 +789,7 @@ def forgive_failures(self, task_id=None): @rpc_method() def add_task(self, task_id=None, status=PENDING, runnable=True, deps=None, new_deps=None, expl=None, resources=None, - priority=0, family='', module=None, params=None, accepts_messages=False, + priority=0, family='', module=None, params=None, param_visibilities=None, accepts_messages=False, assistant=False, tracking_url=None, worker=None, batchable=None, batch_id=None, retry_policy_dict=None, owners=None, **kwargs): """ @@ -802,7 +813,7 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, if worker.enabled: _default_task = self._make_task( task_id=task_id, status=PENDING, deps=deps, resources=resources, - priority=priority, family=family, module=module, params=params, + 
priority=priority, family=family, module=module, params=params, param_visibilities=param_visibilities, ) else: _default_task = None @@ -817,8 +828,10 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, task.family = family if not getattr(task, 'module', None): task.module = module + if not task.param_visibilities: + task.param_visibilities = _get_default(param_visibilities, {}) if not task.params: - task.params = _get_default(params, {}) + task.set_params(params) if batch_id is not None: task.batch_id = batch_id @@ -1272,6 +1285,7 @@ def _upstream_status(self, task_id, upstream_status_table): def _serialize_task(self, task_id, include_deps=True, deps=None): task = self._state.get_task(task_id) + ret = { 'display_name': task.pretty_id, 'status': task.status, @@ -1280,7 +1294,7 @@ def _serialize_task(self, task_id, include_deps=True, deps=None): 'time_running': getattr(task, "time_running", None), 'start_time': task.time, 'last_updated': getattr(task, "updated", task.time), - 'params': task.params, + 'params': task.public_params, 'name': task.family, 'priority': task.priority, 'resources': task.resources, diff --git a/luigi/task.py b/luigi/task.py index 4340e513dc..08f27b8179 100644 --- a/luigi/task.py +++ b/luigi/task.py @@ -39,6 +39,7 @@ from luigi import parameter from luigi.task_register import Register +from luigi.parameter import ParameterVisibility Parameter = parameter.Parameter logger = logging.getLogger('luigi-interface') @@ -441,7 +442,7 @@ def __init__(self, *args, **kwargs): self.param_kwargs = dict(param_values) self._warn_on_wrong_param_types() - self.task_id = task_id_str(self.get_task_family(), self.to_str_params(only_significant=True)) + self.task_id = task_id_str(self.get_task_family(), self.to_str_params(only_significant=True, only_public=True)) self.__hash = hash(self.task_id) self.set_tracking_url = None @@ -482,18 +483,29 @@ def from_str_params(cls, params_str): return cls(**kwargs) - def to_str_params(self, only_significant=False): + def to_str_params(self, only_significant=False, only_public=False): """ Convert all parameters to a str->str hash. """ params_str = {} params = dict(self.get_params()) for param_name, param_value in six.iteritems(self.param_kwargs): - if (not only_significant) or params[param_name].significant: + if (((not only_significant) or params[param_name].significant) + and ((not only_public) or params[param_name].visibility == ParameterVisibility.PUBLIC) + and params[param_name].visibility != ParameterVisibility.PRIVATE): params_str[param_name] = params[param_name].serialize(param_value) return params_str + def _get_param_visibilities(self): + param_visibilities = {} + params = dict(self.get_params()) + for param_name, param_value in six.iteritems(self.param_kwargs): + if params[param_name].visibility != ParameterVisibility.PRIVATE: + param_visibilities[param_name] = params[param_name].visibility.serialize() + + return param_visibilities + def clone(self, cls=None, **kwargs): """ Creates a new instance from an existing instance where some of the args have changed. 
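
Taken together, the task.py changes above mean that PRIVATE parameters never
appear in the serialized parameter dict, and HIDDEN parameters are dropped as
well whenever only_public=True is passed, which is exactly what the task_id
computation now does. A minimal sketch of the resulting behaviour (the Demo
task and its parameter values are illustrative only, not part of this patch):

    import luigi
    from luigi.parameter import ParameterVisibility

    class Demo(luigi.Task):
        user = luigi.Parameter(default="alice")  # PUBLIC is the default visibility
        session = luigi.Parameter(default="s1", visibility=ParameterVisibility.HIDDEN)
        token = luigi.Parameter(default="xyz", visibility=ParameterVisibility.PRIVATE)

    # token (PRIVATE) is always dropped; session (HIDDEN) is dropped only when only_public=True:
    # Demo().to_str_params()                 -> {'user': 'alice', 'session': 's1'}
    # Demo().to_str_params(only_public=True) -> {'user': 'alice'}

One consequence, exercised by the new tests below, is that two instances that
differ only in HIDDEN or PRIVATE parameter values now share the same task_id.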
diff --git a/luigi/worker.py b/luigi/worker.py index 5c76bbc3de..54954765e2 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -565,6 +565,9 @@ def _add_task(self, *args, **kwargs): for batch_task in self._batch_running_tasks.pop(task_id): self._add_task_history.append((batch_task, status, True)) + if task and kwargs.get('params'): + kwargs['param_visibilities'] = task._get_param_visibilities() + self._scheduler.add_task(*args, **kwargs) logger.info('Informed scheduler that task %s has status %s', task_id, status) diff --git a/setup.py b/setup.py index 176d671636..89cffbcb1d 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ # the License. import os +import sys from setuptools import setup @@ -48,6 +49,9 @@ def get_static_files(path): install_requires.remove('python-daemon<3.0') install_requires.append('sphinx>=1.4.4') # Value mirrored in doc/conf.py +if sys.version_info < (3, 4): + install_requires.append('enum34>1.1.0') + setup( name='luigi', version='2.7.6', diff --git a/test/db_task_history_test.py b/test/db_task_history_test.py index 8b162d282e..d302bed292 100644 --- a/test/db_task_history_test.py +++ b/test/db_task_history_test.py @@ -24,6 +24,7 @@ from luigi.db_task_history import DbTaskHistory from luigi.task_status import DONE, PENDING, RUNNING import luigi.scheduler +from luigi.parameter import ParameterVisibility class DummyTask(luigi.Task): @@ -32,7 +33,8 @@ class DummyTask(luigi.Task): class ParamTask(luigi.Task): param1 = luigi.Parameter() - param2 = luigi.IntParameter() + param2 = luigi.IntParameter(visibility=ParameterVisibility.HIDDEN) + param3 = luigi.Parameter(default="empty", visibility=ParameterVisibility.PRIVATE) class DbTaskHistoryTest(unittest.TestCase): diff --git a/test/scheduler_parameter_visibilities_test.py b/test/scheduler_parameter_visibilities_test.py new file mode 100644 index 0000000000..b3cae1f579 --- /dev/null +++ b/test/scheduler_parameter_visibilities_test.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from helpers import LuigiTestCase, RunOnceTask +import server_test + +import luigi +import luigi.scheduler +import luigi.worker +from luigi.parameter import ParameterVisibility +import json +import time + + +class SchedulerParameterVisibilitiesTest(LuigiTestCase): + def test_task_with_deps(self): + s = luigi.scheduler.Scheduler(send_messages=True) + with luigi.worker.Worker(scheduler=s) as w: + class DynamicTask(RunOnceTask): + dynamic_public = luigi.Parameter(default="dynamic_public") + dynamic_hidden = luigi.Parameter(default="dynamic_hidden", visibility=ParameterVisibility.HIDDEN) + dynamic_private = luigi.Parameter(default="dynamic_private", visibility=ParameterVisibility.PRIVATE) + + class RequiredTask(RunOnceTask): + required_public = luigi.Parameter(default="required_param") + required_hidden = luigi.Parameter(default="required_hidden", visibility=ParameterVisibility.HIDDEN) + required_private = luigi.Parameter(default="required_private", visibility=ParameterVisibility.PRIVATE) + + class Task(RunOnceTask): + a = luigi.Parameter(default="a") + b = luigi.Parameter(default="b", visibility=ParameterVisibility.HIDDEN) + c = luigi.Parameter(default="c", visibility=ParameterVisibility.PRIVATE) + d = luigi.Parameter(default="d", visibility=ParameterVisibility.PUBLIC) + + def requires(self): + return required_task + + def run(self): + yield dynamic_task + + dynamic_task = DynamicTask() + required_task = RequiredTask() + task = Task() + + w.add(task) + w.run() + + time.sleep(1) + task_deps = s.dep_graph(task_id=task.task_id) + required_task_deps = s.dep_graph(task_id=required_task.task_id) + dynamic_task_deps = s.dep_graph(task_id=dynamic_task.task_id) + + self.assertEqual('Task(a=a, d=d)', task_deps[task.task_id]['display_name']) + self.assertEqual('RequiredTask(required_public=required_param)', + required_task_deps[required_task.task_id]['display_name']) + self.assertEqual('DynamicTask(dynamic_public=dynamic_public)', + dynamic_task_deps[dynamic_task.task_id]['display_name']) + + self.assertEqual({'a': 'a', 'd': 'd'}, task_deps[task.task_id]['params']) + self.assertEqual({'required_public': 'required_param'}, + required_task_deps[required_task.task_id]['params']) + self.assertEqual({'dynamic_public': 'dynamic_public'}, + dynamic_task_deps[dynamic_task.task_id]['params']) + + def test_public_and_hidden_params(self): + s = luigi.scheduler.Scheduler(send_messages=True) + with luigi.worker.Worker(scheduler=s) as w: + class Task(RunOnceTask): + a = luigi.Parameter(default="a") + b = luigi.Parameter(default="b", visibility=ParameterVisibility.HIDDEN) + c = luigi.Parameter(default="c", visibility=ParameterVisibility.PRIVATE) + d = luigi.Parameter(default="d", visibility=ParameterVisibility.PUBLIC) + + task = Task() + + w.add(task) + w.run() + + time.sleep(1) + t = s._state.get_task(task.task_id) + self.assertEqual({'b': 'b'}, t.hidden_params) + self.assertEqual({'a': 'a', 'd': 'd'}, t.public_params) + self.assertEqual({'a': 0, 'b': 1, 'd': 0}, t.param_visibilities) + + +class Task(RunOnceTask): + a = luigi.Parameter(default="a") + b = luigi.Parameter(default="b", visibility=ParameterVisibility.HIDDEN) + c = luigi.Parameter(default="c", visibility=ParameterVisibility.PRIVATE) + d = luigi.Parameter(default="d", visibility=ParameterVisibility.PUBLIC) + + +class RemoteSchedulerParameterVisibilitiesTest(server_test.ServerTestBase): + def test_public_params(self): + task = Task() + luigi.build(tasks=[task], workers=2, scheduler_port=self.get_http_port()) + + time.sleep(1) + + response = 
self.fetch('/api/graph') + + body = response.body + decoded = body.decode('utf8').replace("'", '"') + data = json.loads(decoded) + + self.assertEqual({'a': 'a', 'd': 'd'}, data['response'][task.task_id]['params']) diff --git a/test/visible_parameters_test.py b/test/visible_parameters_test.py new file mode 100644 index 0000000000..e644aa7cb0 --- /dev/null +++ b/test/visible_parameters_test.py @@ -0,0 +1,95 @@ +import luigi +from luigi.parameter import ParameterVisibility +from helpers import unittest +import json + + +class TestTask1(luigi.Task): + param_one = luigi.Parameter(default='1', visibility=ParameterVisibility.HIDDEN, significant=True) + param_two = luigi.Parameter(default='2', significant=True) + param_three = luigi.Parameter(default='3', visibility=ParameterVisibility.PRIVATE, significant=True) + + +class TestTask2(luigi.Task): + param_one = luigi.Parameter(default='1', visibility=ParameterVisibility.PRIVATE) + param_two = luigi.Parameter(default='2', visibility=ParameterVisibility.PRIVATE) + param_three = luigi.Parameter(default='3', visibility=ParameterVisibility.PRIVATE) + + +class TestTask3(luigi.Task): + param_one = luigi.Parameter(default='1', visibility=ParameterVisibility.HIDDEN, significant=True) + param_two = luigi.Parameter(default='2', visibility=ParameterVisibility.HIDDEN, significant=False) + param_three = luigi.Parameter(default='3', visibility=ParameterVisibility.HIDDEN, significant=True) + + +class TestTask4(luigi.Task): + param_one = luigi.Parameter(default='1', visibility=ParameterVisibility.PUBLIC, significant=True) + param_two = luigi.Parameter(default='2', visibility=ParameterVisibility.PUBLIC, significant=False) + param_three = luigi.Parameter(default='3', visibility=ParameterVisibility.PUBLIC, significant=True) + + +class Test(unittest.TestCase): + def test_to_str_params(self): + task = TestTask1() + + self.assertEqual(task.to_str_params(), {'param_one': '1', 'param_two': '2'}) + + task = TestTask2() + + self.assertEqual(task.to_str_params(), {}) + + task = TestTask3() + + self.assertEqual(task.to_str_params(), {'param_one': '1', 'param_two': '2', 'param_three': '3'}) + + def test_all_public_equals_all_hidden(self): + hidden = TestTask3() + public = TestTask4() + + self.assertEqual(public.to_str_params(), hidden.to_str_params()) + + def test_all_public_equals_all_hidden_using_significant(self): + hidden = TestTask3() + public = TestTask4() + + self.assertEqual(public.to_str_params(only_significant=True), hidden.to_str_params(only_significant=True)) + + def test_private_params_and_significant(self): + task = TestTask1() + + self.assertEqual(task.to_str_params(), task.to_str_params(only_significant=True)) + + def test_param_visibilities(self): + task = TestTask1() + + self.assertEqual(task._get_param_visibilities(), {'param_one': 1, 'param_two': 0}) + + def test_incorrect_visibility_value(self): + class Task(luigi.Task): + a = luigi.Parameter(default='val', visibility=5) + + task = Task() + + self.assertEqual(task._get_param_visibilities(), {'a': 0}) + + def test_task_id_exclude_hidden_and_private_params(self): + task = TestTask1() + + self.assertEqual({'param_two': '2'}, task.to_str_params(only_public=True)) + + def test_json_dumps(self): + public = json.dumps(ParameterVisibility.PUBLIC.serialize()) + hidden = json.dumps(ParameterVisibility.HIDDEN.serialize()) + private = json.dumps(ParameterVisibility.PRIVATE.serialize()) + + self.assertEqual('0', public) + self.assertEqual('1', hidden) + self.assertEqual('2', private) + + public = json.loads(public) + 
hidden = json.loads(hidden) + private = json.loads(private) + + self.assertEqual(0, public) + self.assertEqual(1, hidden) + self.assertEqual(2, private) diff --git a/tox.ini b/tox.ini index 4876423fba..827000fb12 100644 --- a/tox.ini +++ b/tox.ini @@ -110,6 +110,7 @@ deps = boto3 Sphinx>=1.4.4,<1.5 sphinx_rtd_theme + enum34>1.1.0 commands = # build API docs sphinx-apidoc -o doc/api -T luigi --separate From de4629a4c7c562b53fee2045aabd84189f897593 Mon Sep 17 00:00:00 2001 From: Marcel R Date: Wed, 8 Aug 2018 15:50:15 +0200 Subject: [PATCH 10/13] Fix attribute forwarding for tasks with dynamic dependencies (#2478) * Add test case to check fowarded task attributes. * Fix attribute forwarding for tasks yielding deps. * Address review comments on code style and comments. --- luigi/worker.py | 24 +++++--- test/task_forwarded_attributes_test.py | 85 ++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 9 deletions(-) create mode 100644 test/task_forwarded_attributes_test.py diff --git a/luigi/worker.py b/luigi/worker.py index 54954765e2..6cdaff1884 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -37,6 +37,7 @@ import signal import subprocess import sys +import contextlib try: import Queue @@ -135,16 +136,8 @@ def __init__(self, task, worker_id, result_queue, status_reporter, self.check_unfulfilled_deps = check_unfulfilled_deps def _run_get_new_deps(self): - # forward some attributes before running - for reporter_attr, task_attr in six.iteritems(self.forward_reporter_attributes): - setattr(self.task, task_attr, getattr(self.status_reporter, reporter_attr)) - task_gen = self.task.run() - # reset attributes again - for reporter_attr, task_attr in six.iteritems(self.forward_reporter_attributes): - setattr(self.task, task_attr, None) - if not isinstance(task_gen, types.GeneratorType): return None @@ -202,7 +195,8 @@ def run(self): expl = 'Task is an external data dependency ' \ 'and data does not exist (yet?).' else: - new_deps = self._run_get_new_deps() + with self._forward_attributes(): + new_deps = self._run_get_new_deps() status = DONE if not new_deps else PENDING if new_deps: @@ -258,6 +252,18 @@ def terminate(self): except ImportError: return super(TaskProcess, self).terminate() + @contextlib.contextmanager + def _forward_attributes(self): + # forward configured attributes to the task + for reporter_attr, task_attr in six.iteritems(self.forward_reporter_attributes): + setattr(self.task, task_attr, getattr(self.status_reporter, reporter_attr)) + try: + yield self + finally: + # reset attributes again + for reporter_attr, task_attr in six.iteritems(self.forward_reporter_attributes): + setattr(self.task, task_attr, None) + # This code and the task_process_context config key currently feels a bit ad-hoc. # Discussion on generalizing it into a plugin system: https://github.com/spotify/luigi/issues/1897 diff --git a/test/task_forwarded_attributes_test.py b/test/task_forwarded_attributes_test.py new file mode 100644 index 0000000000..48ef319136 --- /dev/null +++ b/test/task_forwarded_attributes_test.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from helpers import LuigiTestCase, RunOnceTask
+
+import luigi
+import luigi.scheduler
+import luigi.worker
+
+
+FORWARDED_ATTRIBUTES = set(luigi.worker.TaskProcess.forward_reporter_attributes.values())
+
+
+class NonYieldingTask(RunOnceTask):
+
+    # need to accept messages in order for the "scheduler_message" attribute to be not None
+    accepts_messages = True
+
+    def gather_forwarded_attributes(self):
+        """
+        Returns a set of names of attributes that are forwarded by the TaskProcess and that are not
+        *None*. The tests in this file check whether and which attributes are present at different
+        times, e.g. while running, or before and after a dynamic dependency was yielded.
+        """
+        attrs = set()
+        for attr in FORWARDED_ATTRIBUTES:
+            if getattr(self, attr, None) is not None:
+                attrs.add(attr)
+        return attrs
+
+    def run(self):
+        # store names of forwarded attributes which are only available within the run method
+        self.attributes_while_running = self.gather_forwarded_attributes()
+
+        # invoke the run method of the RunOnceTask, which marks this task as complete
+        RunOnceTask.run(self)
+
+
+class YieldingTask(NonYieldingTask):
+
+    def run(self):
+        # as TaskProcess._run_get_new_deps handles generators in a specific way, store names of
+        # forwarded attributes before and after yielding a dynamic dependency, so we can explicitly
+        # validate the attribute forwarding implementation
+        self.attributes_before_yield = self.gather_forwarded_attributes()
+        yield RunOnceTask()
+        self.attributes_after_yield = self.gather_forwarded_attributes()
+
+        # invoke the run method of the RunOnceTask, which marks this task as complete
+        RunOnceTask.run(self)
+
+
+class TaskForwardedAttributesTest(LuigiTestCase):
+
+    def run_task(self, task):
+        sch = luigi.scheduler.Scheduler()
+        with luigi.worker.Worker(scheduler=sch) as w:
+            w.add(task)
+            w.run()
+        return task
+
+    def test_non_yielding_task(self):
+        task = self.run_task(NonYieldingTask())
+
+        self.assertEqual(task.attributes_while_running, FORWARDED_ATTRIBUTES)
+
+    def test_yielding_task(self):
+        task = self.run_task(YieldingTask())
+
+        self.assertEqual(task.attributes_before_yield, FORWARDED_ATTRIBUTES)
+        self.assertEqual(task.attributes_after_yield, FORWARDED_ATTRIBUTES)

From 14b154c2dd0b529fc52bf4d0e99f970c68f99d08 Mon Sep 17 00:00:00 2001
From: Uldis Barbans
Date: Fri, 10 Aug 2018 17:12:37 +0200
Subject: [PATCH 11/13] Factor log_exceptions into a configuration parameter

This simplifies away unnecessary propagation and cleans up argument lists.
As far as I can tell, log_exceptions is True in all execution paths. It seems
that 6dfe9af removed the last occurrence of it being (spuriously) set to False.
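
For illustration, after this change the logging behaviour is controlled from
configuration rather than threaded through call sites. A sketch of a luigi.cfg
that silences these warnings (the key defaults to true, and the next commit
renames it to rpc-log-retries):

    [core]
    log-exceptions: false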
--- luigi/rpc.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/luigi/rpc.py b/luigi/rpc.py index a18bd58ded..9c49e0cd3b 100644 --- a/luigi/rpc.py +++ b/luigi/rpc.py @@ -116,6 +116,7 @@ def __init__(self, url='http://localhost:8082/', connect_timeout=None): self._rpc_retry_attempts = config.getint('core', 'rpc-retry-attempts', 3) self._rpc_retry_wait = config.getint('core', 'rpc-retry-wait', 30) + self._log_exceptions = config.getboolean('core', 'log-exceptions', True) if HAS_REQUESTS: self._fetcher = RequestsFetcher(requests.Session()) @@ -126,7 +127,7 @@ def _wait(self): logger.info("Wait for %d seconds" % self._rpc_retry_wait) time.sleep(self._rpc_retry_wait) - def _fetch(self, url_suffix, body, log_exceptions=True): + def _fetch(self, url_suffix, body): full_url = _urljoin(self._url, url_suffix) last_exception = None attempt = 0 @@ -140,7 +141,7 @@ def _fetch(self, url_suffix, body, log_exceptions=True): break except self._fetcher.raises as e: last_exception = e - if log_exceptions: + if self._log_exceptions: logger.warning("Failed connecting to remote scheduler %r", self._url, exc_info=True) continue @@ -152,11 +153,11 @@ def _fetch(self, url_suffix, body, log_exceptions=True): ) return response - def _request(self, url, data, log_exceptions=True, attempts=3, allow_null=True): + def _request(self, url, data, attempts=3, allow_null=True): body = {'data': json.dumps(data)} for _ in range(attempts): - page = self._fetch(url, body, log_exceptions) + page = self._fetch(url, body) response = json.loads(page)["response"] if allow_null or response is not None: return response From d2d3944c251f267f5fbb9586445433cec53866c3 Mon Sep 17 00:00:00 2001 From: Uldis Barbans Date: Fri, 10 Aug 2018 19:02:14 +0200 Subject: [PATCH 12/13] Rename to rpc_log_retries, and make it apply to all the logging involved --- luigi/rpc.py | 10 ++++++---- test/rpc_test.py | 34 ++++++++++++++++++++++++++++++++-- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/luigi/rpc.py b/luigi/rpc.py index 9c49e0cd3b..1c4580a46e 100644 --- a/luigi/rpc.py +++ b/luigi/rpc.py @@ -116,7 +116,7 @@ def __init__(self, url='http://localhost:8082/', connect_timeout=None): self._rpc_retry_attempts = config.getint('core', 'rpc-retry-attempts', 3) self._rpc_retry_wait = config.getint('core', 'rpc-retry-wait', 30) - self._log_exceptions = config.getboolean('core', 'log-exceptions', True) + self._rpc_log_retries = config.getboolean('core', 'rpc-log-retries', True) if HAS_REQUESTS: self._fetcher = RequestsFetcher(requests.Session()) @@ -124,7 +124,8 @@ def __init__(self, url='http://localhost:8082/', connect_timeout=None): self._fetcher = URLLibFetcher() def _wait(self): - logger.info("Wait for %d seconds" % self._rpc_retry_wait) + if self._rpc_log_retries: + logger.info("Wait for %d seconds" % self._rpc_retry_wait) time.sleep(self._rpc_retry_wait) def _fetch(self, url_suffix, body): @@ -134,14 +135,15 @@ def _fetch(self, url_suffix, body): while attempt < self._rpc_retry_attempts: attempt += 1 if last_exception: - logger.info("Retrying attempt %r of %r (max)" % (attempt, self._rpc_retry_attempts)) + if self._rpc_log_retries: + logger.info("Retrying attempt %r of %r (max)" % (attempt, self._rpc_retry_attempts)) self._wait() # wait for a bit and retry try: response = self._fetcher.fetch(full_url, body, self._connect_timeout) break except self._fetcher.raises as e: last_exception = e - if self._log_exceptions: + if self._rpc_log_retries: logger.warning("Failed connecting to remote scheduler %r", self._url, 
exc_info=True) continue diff --git a/test/rpc_test.py b/test/rpc_test.py index 044e1c14fe..cfb55a1ca1 100644 --- a/test/rpc_test.py +++ b/test/rpc_test.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from helpers import unittest +from helpers import unittest, with_config try: from unittest import mock except ImportError: @@ -52,7 +52,7 @@ def _wait(self): scheduler = ShorterWaitRemoteScheduler('http://zorg.com', 42) with mock.patch.object(scheduler, '_fetcher') as fetcher: - fetcher.raises = socket.timeout + fetcher.raises = socket.timeout, socket.gaierror fetcher.fetch.side_effect = fetcher_side_effect return scheduler.get_work("fake_worker") @@ -72,6 +72,36 @@ def test_retry_rpc_limited(self): fetch_results = [socket.timeout, socket.timeout, socket.timeout] self.assertRaises(luigi.rpc.RPCError, self.get_work, fetch_results) + @mock.patch('luigi.rpc.logger') + def test_log_rpc_retries_enabled(self, mock_logger): + """ + Tests that each retry of an RPC method is logged + """ + + fetch_results = [socket.timeout, socket.timeout, '{"response":{}}'] + self.get_work(fetch_results) + self.assertEqual([ + mock.call.warning('Failed connecting to remote scheduler %r', 'http://zorg.com', exc_info=True), + mock.call.info('Retrying attempt 2 of 3 (max)'), + mock.call.warning('Failed connecting to remote scheduler %r', 'http://zorg.com', exc_info=True), + mock.call.info('Retrying attempt 3 of 3 (max)'), + ], mock_logger.mock_calls) + + @with_config({'core': {'rpc-log-retries': 'false'}}) + @mock.patch('luigi.rpc.logger') + def test_log_rpc_retries_disabled(self, mock_logger): + """ + Tests that retries of an RPC method are not logged + """ + + fetch_results = [socket.timeout, socket.timeout, socket.gaierror] + try: + self.get_work(fetch_results) + self.fail("get_work should have thrown RPCError") + except luigi.rpc.RPCError as e: + self.assertTrue(isinstance(e.sub_exception, socket.gaierror)) + self.assertEqual([], mock_logger.mock_calls) + def test_get_work_retries_on_null(self): """ Tests that get_work will retry if the response is null From f604ce495a9a0e83780a7e1e04fe5fb63c46072c Mon Sep 17 00:00:00 2001 From: jorge jardines Date: Mon, 13 Aug 2018 12:59:40 +0200 Subject: [PATCH 13/13] S3 client refactor (#2482) * remove unneeded variable s3_key * Switch some methods to static in S3Client * Encapsulate deprecation check in a function * Remove duplicated method `create_bucket` * Remove unused mock * Enhance documentation * Use keyword arguments and docstring change --- luigi/contrib/s3.py | 51 ++++++++++++++++++++------------- test/contrib/s3_test.py | 63 +++++++++++++++++++---------------------- 2 files changed, 60 insertions(+), 54 deletions(-) diff --git a/luigi/contrib/s3.py b/luigi/contrib/s3.py index a457808bc3..56fc655b0d 100644 --- a/luigi/contrib/s3.py +++ b/luigi/contrib/s3.py @@ -226,6 +226,8 @@ def remove(self, path, recursive=True): def move(self, source_path, destination_path, **kwargs): """ Rename/move an object from one S3 location to another. + :param source_path: The `s3://` path of the directory or key to copy from + :param destination_path: The `s3://` path of the directory or key to copy to :param kwargs: Keyword arguments are passed to the boto3 function `copy` """ self.copy(source_path, destination_path, **kwargs) @@ -243,12 +245,11 @@ def get_key(self, path): def put(self, local_path, destination_s3_path, **kwargs): """ Put an object stored locally to an S3 path. 
- + :param local_path: Path to source local file + :param destination_s3_path: URL for target S3 location :param kwargs: Keyword arguments are passed to the boto function `put_object` """ - if 'encrypt_key' in kwargs: - raise DeprecatedBotoClientException( - 'encrypt_key deprecated in boto3. Please refer to boto3 documentation for encryption details.') + self._check_deprecated_argument(**kwargs) # put the file self.put_multipart(local_path, destination_s3_path, **kwargs) @@ -256,11 +257,11 @@ def put(self, local_path, destination_s3_path, **kwargs): def put_string(self, content, destination_s3_path, **kwargs): """ Put a string to an S3 path. + :param content: Data str + :param destination_s3_path: URL for target S3 location :param kwargs: Keyword arguments are passed to the boto3 function `put_object` """ - if 'encrypt_key' in kwargs: - raise DeprecatedBotoClientException( - 'encrypt_key deprecated in boto3. Please refer to boto3 documentation for encryption details.') + self._check_deprecated_argument(**kwargs) (bucket, key) = self._path_to_bucket_and_key(destination_s3_path) # validate the bucket @@ -279,9 +280,7 @@ def put_multipart(self, local_path, destination_s3_path, part_size=8388608, **kw :param part_size: Part size in bytes. Default: 8388608 (8MB) :param kwargs: Keyword arguments are passed to the boto function `upload_fileobj` as ExtraArgs """ - if 'encrypt_key' in kwargs: - raise DeprecatedBotoClientException( - 'encrypt_key deprecated in boto3. Please refer to boto3 documentation for encryption details.') + self._check_deprecated_argument(**kwargs) from boto3.s3.transfer import TransferConfig # default part size for boto3 is 8Mb, changing it to fit part_size @@ -446,6 +445,7 @@ def listdir(self, path, start_time=None, end_time=None, return_key=False): """ Get an iterable with S3 folder contents. Iterable contains paths relative to queried path. + :param path: URL for target S3 location :param start_time: Optional argument to list files with modified (offset aware) datetime after start_time :param end_time: Optional argument to list files with modified (offset aware) datetime before end_time :param return_key: Optional argument, when set to True will return boto3's ObjectSummary (instead of the filename) @@ -482,7 +482,8 @@ def list(self, path, start_time=None, end_time=None, return_key=False): # backw else: yield item[key_path_len:] - def _get_s3_config(self, key=None): + @staticmethod + def _get_s3_config(key=None): defaults = dict(configuration.get_config().defaults()) try: config = dict(configuration.get_config().items('s3')) @@ -500,17 +501,30 @@ def _get_s3_config(self, key=None): return section_only - def _path_to_bucket_and_key(self, path): + @staticmethod + def _path_to_bucket_and_key(path): (scheme, netloc, path, query, fragment) = urlsplit(path) path_without_initial_slash = path[1:] return netloc, path_without_initial_slash - def _is_root(self, key): + @staticmethod + def _is_root(key): return (len(key) == 0) or (key == '/') - def _add_path_delimiter(self, key): + @staticmethod + def _add_path_delimiter(key): return key if key[-1:] == '/' or key == '' else key + '/' + @staticmethod + def _check_deprecated_argument(**kwargs): + """ + If `encrypt_key` is part of the arguments raise an exception + :return: None + """ + if 'encrypt_key' in kwargs: + raise DeprecatedBotoClientException( + 'encrypt_key deprecated in boto3. 
Please refer to boto3 documentation for encryption details.') + def _validate_bucket(self, bucket_name): exists = True @@ -525,18 +539,15 @@ def _validate_bucket(self, bucket_name): return exists def _exists(self, bucket, key): - s3_key = False try: self.s3.Object(bucket, key).load() except botocore.exceptions.ClientError as e: if e.response['Error']['Code'] in ['NoSuchKey', '404']: - s3_key = False + return False else: raise - else: - s3_key = True - if s3_key: - return True + + return True class AtomicS3File(AtomicLocalFile): diff --git a/test/contrib/s3_test.py b/test/contrib/s3_test.py index 93b24c4c69..97ea5cfc2e 100644 --- a/test/contrib/s3_test.py +++ b/test/contrib/s3_test.py @@ -41,6 +41,13 @@ AWS_SECRET_KEY = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" +def create_bucket(): + conn = boto3.resource('s3', region_name='us-east-1') + # We need to create the bucket since this is all in Moto's 'virtual' AWS account + conn.create_bucket(Bucket='mybucket') + return conn + + class TestS3Target(unittest.TestCase, FileSystemTargetTestMixin): def setUp(self): @@ -57,20 +64,14 @@ def setUp(self): self.mock_s3.start() self.addCleanup(self.mock_s3.stop) - def create_bucket(self): - conn = boto3.resource('s3', region_name='us-east-1') - # We need to create the bucket since this is all in Moto's 'virtual' AWS account - conn.create_bucket(Bucket='mybucket') - return conn - def create_target(self, format=None, **kwargs): client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) - self.create_bucket() + create_bucket() return S3Target('s3://mybucket/test_file', client=client, format=format, **kwargs) def test_read(self): client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) - self.create_bucket() + create_bucket() client.put(self.tempFilePath, 's3://mybucket/tempfile') t = S3Target('s3://mybucket/tempfile', client=client) read_file = t.open() @@ -99,7 +100,7 @@ def test_read_iterator_long(self): tempf.close() client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) - self.create_bucket() + create_bucket() client.put(temppath, 's3://mybucket/largetempfile') t = S3Target('s3://mybucket/largetempfile', client=client) with t.open() as read_file: @@ -167,33 +168,27 @@ def test_init_with_config_and_roles(self, sts_mock, s3_mock): sts_mock.client.assume_role.called_with( RoleArn='role', RoleSessionName='name') - def create_bucket(self): - conn = boto3.resource('s3', region_name='us-east-1') - # We need to create the bucket since this is all in Moto's 'virtual' AWS account - conn.create_bucket(Bucket='mybucket') - return conn - def test_put(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put(self.tempFilePath, 's3://mybucket/putMe') self.assertTrue(s3_client.exists('s3://mybucket/putMe')) def test_put_sse_deprecated(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) with self.assertRaises(DeprecatedBotoClientException): s3_client.put(self.tempFilePath, 's3://mybucket/putMe', encrypt_key=True) def test_put_string(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put_string("SOMESTRING", 's3://mybucket/putString') self.assertTrue(s3_client.exists('s3://mybucket/putString')) def test_put_string_sse_deprecated(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) with self.assertRaises(DeprecatedBotoClientException): s3_client.put('SOMESTRING', @@ -243,7 +238,7 @@ def test_put_multipart_less_than_split_size(self): 
self._run_multipart_test(part_size, file_size) def test_exists(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) self.assertTrue(s3_client.exists('s3://mybucket/')) @@ -266,7 +261,7 @@ def test_exists(self): self.assertFalse(s3_client.exists('s3://mybucket/tempdir')) def test_get(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put(self.tempFilePath, 's3://mybucket/putMe') @@ -280,7 +275,7 @@ def test_get(self): tmp_file.close() def test_get_as_string(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put(self.tempFilePath, 's3://mybucket/putMe') @@ -289,14 +284,14 @@ def test_get_as_string(self): self.assertEquals(contents, self.tempFileContents.decode("utf-8")) def test_get_key(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put(self.tempFilePath, 's3://mybucket/key_to_find') self.assertTrue(s3_client.get_key('s3://mybucket/key_to_find').key) self.assertFalse(s3_client.get_key('s3://mybucket/does_not_exist')) def test_isdir(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) self.assertTrue(s3_client.isdir('s3://mybucket')) @@ -310,7 +305,7 @@ def test_isdir(self): self.assertFalse(s3_client.isdir('s3://mybucket/key')) def test_mkdir(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) self.assertTrue(s3_client.isdir('s3://mybucket')) s3_client.mkdir('s3://mybucket') @@ -324,7 +319,7 @@ def test_mkdir(self): self.assertFalse(s3_client.isdir('s3://mybucket/dir/foo/bar')) def test_listdir(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put_string("", 's3://mybucket/hello/frank') @@ -334,7 +329,7 @@ def test_listdir(self): list(s3_client.listdir('s3://mybucket/hello'))) def test_list(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put_string("", 's3://mybucket/hello/frank') @@ -344,7 +339,7 @@ def test_list(self): list(s3_client.list('s3://mybucket/hello'))) def test_listdir_key(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put_string("", 's3://mybucket/hello/frank') @@ -354,7 +349,7 @@ def test_listdir_key(self): [s3_client.exists('s3://' + x.bucket_name + '/' + x.key) for x in s3_client.listdir('s3://mybucket/hello', return_key=True)]) def test_list_key(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put_string("", 's3://mybucket/hello/frank') @@ -364,7 +359,7 @@ def test_list_key(self): [s3_client.exists('s3://' + x.bucket_name + '/' + x.key) for x in s3_client.listdir('s3://mybucket/hello', return_key=True)]) def test_remove(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) self.assertRaises( @@ -436,7 +431,7 @@ def test_copy_dir(self): """ Test copying 20 files from one folder to another """ - self.create_bucket() + create_bucket() n = 20 copy_part_size = (1024 ** 2) * 5 @@ -468,7 +463,7 @@ def test_copy_dir(self): @mock_s3 def _run_multipart_copy_test(self, put_method): - self.create_bucket() + create_bucket() # Run the method to put the file into s3 into the first place put_method() @@ -493,7 +488,7 @@ def _run_multipart_copy_test(self, put_method): @mock_s3 
def _run_copy_test(self, put_method): - self.create_bucket() + create_bucket() # Run the method to put the file into s3 into the first place put_method() @@ -514,7 +509,7 @@ def _run_copy_test(self, put_method): @mock_s3 def _run_multipart_test(self, part_size, file_size, **kwargs): - self.create_bucket() + create_bucket() file_contents = b"a" * file_size s3_path = 's3://mybucket/putMe'