Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify and improve the error argument of Rephraser #54

Merged
merged 4 commits into from
Jun 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions cloup/constraints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
AcceptBetween,
And,
Constraint,
ErrorRephraser,
HelpRephraser,
Operator,
Or,
Rephraser,
Expand All @@ -22,11 +24,6 @@
mutually_exclusive,
require_all,
)
from ._support import (
BoundConstraintSpec,
ConstraintMixin,
constraint,
constrained_params,
)
from ._support import (BoundConstraintSpec, ConstraintMixin, constrained_params, constraint)
from .conditions import AllSet, AnySet, Equal, IsSet, Not
from .exceptions import ConstraintViolated, UnsatisfiableConstraint
3 changes: 2 additions & 1 deletion cloup/constraints/_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def check_values(self, params: Sequence[Parameter], ctx: Context):
except ConstraintViolated as err:
desc = (condition.description(ctx) if condition_is_true
else condition.negated_description(ctx))
raise ConstraintViolated(f"when {desc}, {err}", ctx=ctx)
raise ConstraintViolated(
f"when {desc}, {err}", ctx=ctx, constraint=self, params=params)

def __repr__(self) -> str:
if self._else:
Expand Down
54 changes: 39 additions & 15 deletions cloup/constraints/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

Op = TypeVar('Op', bound='Operator')
HelpRephraser = Callable[[Context, 'Constraint'], str]
ErrorRephraser = Callable[[Context, 'Constraint', Sequence[Parameter]], str]
ErrorRephraser = Callable[[ConstraintViolated], str]


class Constraint(abc.ABC):
Expand Down Expand Up @@ -146,6 +146,23 @@ def rephrased(
help: Union[None, str, HelpRephraser] = None,
error: Union[None, str, ErrorRephraser] = None,
) -> 'Rephraser':
"""
Overrides the help string and/or the error message of this constraint
wrapping it with a :class:`Rephraser`.

:param help:
if provided, overrides the help string of this constraint. It can be
a string or a function ``(ctx: Context, constr: Constraint) -> str``.
:param error:
if provided, overrides the error message of this constraint.
It can be:

- a template string for the ``format`` built-in function
- or a function ``(err: ConstraintViolated) -> str``; note that
a :class:`ConstraintViolated` error has fields for ``ctx``,
``constraint`` and ``params``, so it's a complete description
of what happened.
"""
return Rephraser(self, help=help, error=error)

def hidden(self) -> 'Rephraser':
Expand Down Expand Up @@ -224,7 +241,9 @@ def check_values(self, params: Sequence[Parameter], ctx: Context):
return c.check_values(params, ctx)
except ConstraintViolated:
pass
raise ConstraintViolated.default(params, self.help(ctx), ctx=ctx)
raise ConstraintViolated.default(
self.help(ctx), ctx=ctx, constraint=self, params=params
)

def __or__(self, other) -> 'Or':
if isinstance(other, Or):
Expand All @@ -236,8 +255,8 @@ class Rephraser(Constraint):
"""A Constraint decorator that can override the help and/or the error
message of the wrapped constraint.

This is useful also for defining new constraints.
See also :class:`WrapperConstraint`.
.. seealso::
:class:`WrapperConstraint`.
"""

def __init__(
Expand All @@ -259,15 +278,16 @@ def help(self, ctx: Context) -> str:
else:
return self._help(ctx, self._constraint)

def _get_rephrased_error(
self, ctx: Context, params: Sequence[Parameter]
) -> Optional[str]:
def _get_rephrased_error(self, err: ConstraintViolated) -> Optional[str]:
if self._error is None:
return None
elif isinstance(self._error, str):
return self._error.format(param_list=format_param_list(params))
return self._error.format(
error=str(err),
param_list=format_param_list(err.params),
)
else:
return self._error(ctx, self._constraint, params)
return self._error(err)

def check_consistency(self, params: Sequence[Parameter]) -> None:
try:
Expand All @@ -279,10 +299,11 @@ def check_consistency(self, params: Sequence[Parameter]) -> None:
def check_values(self, params: Sequence[Parameter], ctx: Context):
try:
return self._constraint.check_values(params, ctx)
except ConstraintViolated:
rephrased_error = self._get_rephrased_error(ctx, params)
except ConstraintViolated as err:
rephrased_error = self._get_rephrased_error(err)
if rephrased_error:
raise ConstraintViolated(rephrased_error, ctx=ctx)
raise ConstraintViolated(
rephrased_error, ctx=ctx, constraint=self, params=params)
raise

def __repr__(self):
Expand Down Expand Up @@ -341,6 +362,8 @@ def check_values(self, params: Sequence[Parameter], ctx: Context):
many=f"the following parameters are required:\n"
f"{format_param_list(unset_params)}"),
ctx=ctx,
constraint=self,
params=params,
)


Expand Down Expand Up @@ -370,7 +393,7 @@ def check_values(self, params: Sequence[Parameter], ctx: Context):
raise ConstraintViolated(
f"at least {n} of the following parameters must be set:\n"
f"{format_param_list(params)}",
ctx=ctx
ctx=ctx, constraint=self, params=params,
)

def __repr__(self):
Expand Down Expand Up @@ -400,7 +423,7 @@ def check_values(self, params: Sequence[Parameter], ctx: Context):
raise ConstraintViolated(
f"no more than {n} of the following parameters can be set:\n"
f"{format_param_list(params)}",
ctx=ctx,
ctx=ctx, constraint=self, params=params,
)

def __repr__(self):
Expand All @@ -427,7 +450,8 @@ def check_values(self, params: Sequence[Parameter], ctx: Context):
zero='none of the following parameters must be set:\n',
many=f'exactly {n} of the following parameters must be set:\n'
) + format_param_list(params)
raise ConstraintViolated(reason, ctx=ctx)
raise ConstraintViolated(
reason, ctx=ctx, constraint=self, params=params)


class AcceptBetween(WrapperConstraint):
Expand Down
20 changes: 16 additions & 4 deletions cloup/constraints/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Optional, TYPE_CHECKING
from typing import Iterable, Sequence, TYPE_CHECKING

import click
from click import Context, Parameter
Expand All @@ -18,16 +18,28 @@ def default_constraint_error(params: Iterable[Parameter], desc: str) -> str:

class ConstraintViolated(click.UsageError):
def __init__(
self, message: str, ctx: Optional[Context] = None
self, message: str,
ctx: Context,
constraint: 'Constraint',
params: Sequence[click.Parameter]
):
super().__init__(message, ctx=ctx)
self.ctx = ctx
self.constraint = constraint
self.params = params

@classmethod
def default(
cls, params: Iterable[Parameter], desc: str, ctx: Optional[Context] = None
cls,
desc: str,
ctx: Context,
constraint: 'Constraint',
params: Sequence[Parameter],
) -> 'ConstraintViolated':
return ConstraintViolated(
default_constraint_error(params, desc), ctx=ctx)
default_constraint_error(params, desc),
ctx=ctx, constraint=constraint, params=params,
)


class UnsatisfiableConstraint(Exception):
Expand Down
28 changes: 23 additions & 5 deletions tests/constraints/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def check_consistency(self, params: Sequence[Parameter]) -> None:
def check_values(self, params: Sequence[Parameter], ctx: Context):
self.check_values_calls.append(dict(params=params, ctx=ctx))
if not self.satisfied:
raise ConstraintViolated(self.error, ctx=ctx)
raise ConstraintViolated(self.error, ctx=ctx, constraint=self, params=params)


class TestBaseConstraint:
Expand Down Expand Up @@ -297,21 +297,39 @@ def test_help_override_with_function(self, dummy_ctx):

def test_error_is_overridden_passing_string(self):
fake_ctx = make_fake_context(make_options('abcd'))
wrapped = FakeConstraint(satisfied=False)
wrapped = FakeConstraint(satisfied=False, error='__error__')
rephrased = Rephraser(wrapped, error='error:\n{param_list}')
with pytest.raises(ConstraintViolated) as exc_info:
rephrased.check(['a', 'b'], ctx=fake_ctx)
assert exc_info.value.message == 'error:\n --a\n --b\n'

def test_error_template_key(self):
fake_ctx = make_fake_context(make_options('abcd'))
wrapped = FakeConstraint(satisfied=False, error='__error__')
rephrased = Rephraser(wrapped, error='{error}\nExtra info here.')
with pytest.raises(ConstraintViolated) as exc_info:
rephrased.check(['a', 'b'], ctx=fake_ctx)
assert str(exc_info.value) == '__error__\nExtra info here.'

def test_error_is_overridden_passing_function(self):
params = make_options('abc')
fake_ctx = make_fake_context(params)
wrapped = FakeConstraint(satisfied=False)
get_error = Mock(return_value='rephrased error')
rephrased = Rephraser(wrapped, error=get_error)
error_rephraser_mock = Mock(return_value='rephrased error')
rephrased = Rephraser(wrapped, error=error_rephraser_mock)
with pytest.raises(ConstraintViolated, match='rephrased error'):
rephrased.check(params, ctx=fake_ctx)
get_error.assert_called_once_with(fake_ctx, wrapped, params)

# Check the function is called with a single argument of type ConstraintViolated
error_rephraser_mock.assert_called_once()
args = error_rephraser_mock.call_args[0]
assert len(args) == 1
error = args[0]
assert isinstance(error, ConstraintViolated)
# Check the error has all fields set
assert isinstance(error.ctx, Context)
assert isinstance(error.constraint, Constraint)
assert len(error.params) == 3

def test_check_consistency_raises_if_wrapped_constraint_raises(self):
constraint = FakeConstraint(consistent=True)
Expand Down