Skip to content

Commit

Permalink
Add a visiblity level for luigi.Parameters (#2278)
Browse files Browse the repository at this point in the history
See the docs for usage.
  • Loading branch information
nryanov authored and Tarrasch committed Aug 8, 2018
1 parent bd55c28 commit c9ed761
Show file tree
Hide file tree
Showing 10 changed files with 319 additions and 17 deletions.
19 changes: 19 additions & 0 deletions doc/parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,25 @@ are not the same instance:
>>> hash(c) == hash(d)
True
Parameter visibility
^^^^^^^^^^^^^^^^^^^^

Using :class:`~luigi.parameter.ParameterVisibility` you can configure parameter visibility. By default, all
parameters are public, but you can also set them hidden or private.

.. code:: python
>>> import luigi
>>> from luigi.parameter import ParameterVisibility
>>> luigi.Parameter(visibility=ParameterVisibility.PRIVATE)
``ParameterVisibility.PUBLIC`` (default) - visible everywhere

``ParameterVisibility.HIDDEN`` - ignored in WEB-view, but saved into database if save db_history is true

``ParameterVisibility.PRIVATE`` - visible only inside task.

Parameter types
^^^^^^^^^^^^^^^

Expand Down
42 changes: 37 additions & 5 deletions luigi/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import abc
import datetime
import warnings
from enum import IntEnum
import json
from json import JSONEncoder
from collections import OrderedDict, Mapping
Expand All @@ -40,10 +41,26 @@
from luigi import configuration
from luigi.cmdline_parser import CmdlineParser


_no_value = object()


class ParameterVisibility(IntEnum):
"""
Possible values for the parameter visibility option. Public is the default.
See :doc:`/parameters` for more info.
"""
PUBLIC = 0
HIDDEN = 1
PRIVATE = 2

@classmethod
def has_value(cls, value):
return any(value == item.value for item in cls)

def serialize(self):
return self.value


class ParameterException(Exception):
"""
Base exception.
Expand Down Expand Up @@ -113,7 +130,8 @@ def run(self):
_counter = 0 # non-atomically increasing counter used for ordering parameters.

def __init__(self, default=_no_value, is_global=False, significant=True, description=None,
config_path=None, positional=True, always_in_help=False, batch_method=None):
config_path=None, positional=True, always_in_help=False, batch_method=None,
visibility=ParameterVisibility.PUBLIC):
"""
:param default: the default value for this parameter. This should match the type of the
Parameter, i.e. ``datetime.date`` for ``DateParameter`` or ``int`` for
Expand All @@ -140,6 +158,10 @@ def __init__(self, default=_no_value, is_global=False, significant=True, descrip
parameter values into a single value. Used
when receiving batched parameter lists from
the scheduler. See :ref:`batch_method`
:param visibility: A Parameter whose value is a :py:class:`~luigi.parameter.ParameterVisibility`.
Default value is ParameterVisibility.PUBLIC
"""
self._default = default
self._batch_method = batch_method
Expand All @@ -150,6 +172,7 @@ def __init__(self, default=_no_value, is_global=False, significant=True, descrip
positional = False
self.significant = significant # Whether different values for this parameter will differentiate otherwise equal tasks
self.positional = positional
self.visibility = visibility if ParameterVisibility.has_value(visibility) else ParameterVisibility.PUBLIC

self.description = description
self.always_in_help = always_in_help
Expand Down Expand Up @@ -195,11 +218,11 @@ def _value_iterator(self, task_name, param_name):
yield (self._get_value_from_config(task_name, param_name), None)
yield (self._get_value_from_config(task_name, param_name.replace('_', '-')),
'Configuration [{}] {} (with dashes) should be avoided. Please use underscores.'.format(
task_name, param_name))
task_name, param_name))
if self._config_path:
yield (self._get_value_from_config(self._config_path['section'], self._config_path['name']),
'The use of the configuration [{}] {} is deprecated. Please use [{}] {}'.format(
self._config_path['section'], self._config_path['name'], task_name, param_name))
self._config_path['section'], self._config_path['name'], task_name, param_name))
yield (self._default, None)

def has_task_value(self, task_name, param_name):
Expand Down Expand Up @@ -689,6 +712,7 @@ class DateIntervalParameter(Parameter):
(eg. "2015-W35"). In addition, it also supports arbitrary date intervals
provided as two dates separated with a dash (eg. "2015-11-04-2015-12-04").
"""

def parse(self, s):
"""
Parses a :py:class:`~luigi.date_interval.DateInterval` from the input.
Expand Down Expand Up @@ -740,8 +764,10 @@ def field(key):

def optional_field(key):
return "(%s)?" % field(key)

# A little loose: ISO 8601 does not allow weeks in combination with other fields, but this regex does (as does python timedelta)
regex = "P(%s|%s(T%s)?)" % (field("weeks"), optional_field("days"), "".join([optional_field(key) for key in ["hours", "minutes", "seconds"]]))
regex = "P(%s|%s(T%s)?)" % (field("weeks"), optional_field("days"),
"".join([optional_field(key) for key in ["hours", "minutes", "seconds"]]))
return self._apply_regex(regex, input)

def _parseSimple(self, input):
Expand Down Expand Up @@ -905,6 +931,7 @@ class _DictParamEncoder(JSONEncoder):
"""
JSON encoder for :py:class:`~DictParameter`, which makes :py:class:`~_FrozenOrderedDict` JSON serializable.
"""

def default(self, obj):
if isinstance(obj, _FrozenOrderedDict):
return obj.get_wrapped()
Expand Down Expand Up @@ -943,6 +970,7 @@ def run(self):
tags, that are dynamically constructed outside Luigi), or you have a complex parameter containing logically related
values (like a database connection config).
"""

def normalize(self, value):
"""
Ensure that dictionary parameter is converted to a _FrozenOrderedDict so it can be hashed.
Expand Down Expand Up @@ -996,6 +1024,7 @@ def run(self):
$ luigi --module my_tasks MyTask --grades '[100,70]'
"""

def normalize(self, x):
"""
Ensure that struct is recursively converted to a tuple so it can be hashed.
Expand Down Expand Up @@ -1053,6 +1082,7 @@ def run(self):
$ luigi --module my_tasks MyTask --book_locations '((12,3),(4,15),(52,1))'
"""

def parse(self, x):
"""
Parse an individual value from the input.
Expand Down Expand Up @@ -1100,6 +1130,7 @@ class MyTask(luigi.Task):
$ luigi --module my_tasks MyTask --my-param-1 -3 --my-param-2 -2
"""

def __init__(self, left_op=operator.le, right_op=operator.lt, *args, **kwargs):
"""
:param function var_type: The type of the input variable, e.g. int or float.
Expand Down Expand Up @@ -1178,6 +1209,7 @@ class MyTask(luigi.Task):
same type and transparency of parameter value on the command line is
desired.
"""

def __init__(self, var_type=str, *args, **kwargs):
"""
:param function var_type: The type of the input variable, e.g. str, int,
Expand Down
30 changes: 22 additions & 8 deletions luigi/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from luigi.task_status import DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN, \
BATCH_RUNNING
from luigi.task import Config
from luigi.parameter import ParameterVisibility

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -280,7 +281,7 @@ def __eq__(self, other):

class Task(object):
def __init__(self, task_id, status, deps, resources=None, priority=0, family='', module=None,
params=None, accepts_messages=False, tracking_url=None, status_message=None,
params=None, param_visibilities=None, accepts_messages=False, tracking_url=None, status_message=None,
progress_percentage=None, retry_policy='notoptional'):
self.id = task_id
self.stakeholders = set() # workers ids that are somehow related to this task (i.e. don't prune while any of these workers are still active)
Expand All @@ -301,8 +302,11 @@ def __init__(self, task_id, status, deps, resources=None, priority=0, family='',
self.resources = _get_default(resources, {})
self.family = family
self.module = module
self.params = _get_default(params, {})

self.param_visibilities = _get_default(param_visibilities, {})
self.params = {}
self.public_params = {}
self.hidden_params = {}
self.set_params(params)
self.accepts_messages = accepts_messages
self.retry_policy = retry_policy
self.failures = Failures(self.retry_policy.disable_window)
Expand All @@ -318,6 +322,13 @@ def __init__(self, task_id, status, deps, resources=None, priority=0, family='',
def __repr__(self):
return "Task(%r)" % vars(self)

def set_params(self, params):
self.params = _get_default(params, {})
self.public_params = {key: value for key, value in self.params.items() if
self.param_visibilities.get(key, ParameterVisibility.PUBLIC) == ParameterVisibility.PUBLIC}
self.hidden_params = {key: value for key, value in self.params.items() if
self.param_visibilities.get(key, ParameterVisibility.PUBLIC) == ParameterVisibility.HIDDEN}

# TODO(2017-08-10) replace this function with direct calls to batchable
# this only exists for backward compatibility
def is_batchable(self):
Expand All @@ -343,7 +354,7 @@ def has_excessive_failures(self):

@property
def pretty_id(self):
param_str = ', '.join(u'{}={}'.format(key, value) for key, value in sorted(self.params.items()))
param_str = ', '.join(u'{}={}'.format(key, value) for key, value in sorted(self.public_params.items()))
return u'{}({})'.format(self.family, param_str)


Expand Down Expand Up @@ -778,7 +789,7 @@ def forgive_failures(self, task_id=None):
@rpc_method()
def add_task(self, task_id=None, status=PENDING, runnable=True,
deps=None, new_deps=None, expl=None, resources=None,
priority=0, family='', module=None, params=None, accepts_messages=False,
priority=0, family='', module=None, params=None, param_visibilities=None, accepts_messages=False,
assistant=False, tracking_url=None, worker=None, batchable=None,
batch_id=None, retry_policy_dict=None, owners=None, **kwargs):
"""
Expand All @@ -802,7 +813,7 @@ def add_task(self, task_id=None, status=PENDING, runnable=True,
if worker.enabled:
_default_task = self._make_task(
task_id=task_id, status=PENDING, deps=deps, resources=resources,
priority=priority, family=family, module=module, params=params,
priority=priority, family=family, module=module, params=params, param_visibilities=param_visibilities,
)
else:
_default_task = None
Expand All @@ -817,8 +828,10 @@ def add_task(self, task_id=None, status=PENDING, runnable=True,
task.family = family
if not getattr(task, 'module', None):
task.module = module
if not task.param_visibilities:
task.param_visibilities = _get_default(param_visibilities, {})
if not task.params:
task.params = _get_default(params, {})
task.set_params(params)

if batch_id is not None:
task.batch_id = batch_id
Expand Down Expand Up @@ -1272,6 +1285,7 @@ def _upstream_status(self, task_id, upstream_status_table):

def _serialize_task(self, task_id, include_deps=True, deps=None):
task = self._state.get_task(task_id)

ret = {
'display_name': task.pretty_id,
'status': task.status,
Expand All @@ -1280,7 +1294,7 @@ def _serialize_task(self, task_id, include_deps=True, deps=None):
'time_running': getattr(task, "time_running", None),
'start_time': task.time,
'last_updated': getattr(task, "updated", task.time),
'params': task.params,
'params': task.public_params,
'name': task.family,
'priority': task.priority,
'resources': task.resources,
Expand Down
18 changes: 15 additions & 3 deletions luigi/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

from luigi import parameter
from luigi.task_register import Register
from luigi.parameter import ParameterVisibility

Parameter = parameter.Parameter
logger = logging.getLogger('luigi-interface')
Expand Down Expand Up @@ -441,7 +442,7 @@ def __init__(self, *args, **kwargs):
self.param_kwargs = dict(param_values)

self._warn_on_wrong_param_types()
self.task_id = task_id_str(self.get_task_family(), self.to_str_params(only_significant=True))
self.task_id = task_id_str(self.get_task_family(), self.to_str_params(only_significant=True, only_public=True))
self.__hash = hash(self.task_id)

self.set_tracking_url = None
Expand Down Expand Up @@ -482,18 +483,29 @@ def from_str_params(cls, params_str):

return cls(**kwargs)

def to_str_params(self, only_significant=False):
def to_str_params(self, only_significant=False, only_public=False):
"""
Convert all parameters to a str->str hash.
"""
params_str = {}
params = dict(self.get_params())
for param_name, param_value in six.iteritems(self.param_kwargs):
if (not only_significant) or params[param_name].significant:
if (((not only_significant) or params[param_name].significant)
and ((not only_public) or params[param_name].visibility == ParameterVisibility.PUBLIC)
and params[param_name].visibility != ParameterVisibility.PRIVATE):
params_str[param_name] = params[param_name].serialize(param_value)

return params_str

def _get_param_visibilities(self):
param_visibilities = {}
params = dict(self.get_params())
for param_name, param_value in six.iteritems(self.param_kwargs):
if params[param_name].visibility != ParameterVisibility.PRIVATE:
param_visibilities[param_name] = params[param_name].visibility.serialize()

return param_visibilities

def clone(self, cls=None, **kwargs):
"""
Creates a new instance from an existing instance where some of the args have changed.
Expand Down
3 changes: 3 additions & 0 deletions luigi/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,9 @@ def _add_task(self, *args, **kwargs):
for batch_task in self._batch_running_tasks.pop(task_id):
self._add_task_history.append((batch_task, status, True))

if task and kwargs.get('params'):
kwargs['param_visibilities'] = task._get_param_visibilities()

self._scheduler.add_task(*args, **kwargs)

logger.info('Informed scheduler that task %s has status %s', task_id, status)
Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# the License.

import os
import sys

from setuptools import setup

Expand Down Expand Up @@ -48,6 +49,9 @@ def get_static_files(path):
install_requires.remove('python-daemon<3.0')
install_requires.append('sphinx>=1.4.4') # Value mirrored in doc/conf.py

if sys.version_info < (3, 4):
install_requires.append('enum34>1.1.0')

setup(
name='luigi',
version='2.7.6',
Expand Down
4 changes: 3 additions & 1 deletion test/db_task_history_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from luigi.db_task_history import DbTaskHistory
from luigi.task_status import DONE, PENDING, RUNNING
import luigi.scheduler
from luigi.parameter import ParameterVisibility


class DummyTask(luigi.Task):
Expand All @@ -32,7 +33,8 @@ class DummyTask(luigi.Task):

class ParamTask(luigi.Task):
param1 = luigi.Parameter()
param2 = luigi.IntParameter()
param2 = luigi.IntParameter(visibility=ParameterVisibility.HIDDEN)
param3 = luigi.Parameter(default="empty", visibility=ParameterVisibility.PRIVATE)


class DbTaskHistoryTest(unittest.TestCase):
Expand Down
Loading

0 comments on commit c9ed761

Please sign in to comment.