Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Allow configuring pretty errors when creating the Typer instance #416

Merged
merged 3 commits into from
Jul 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tests/assets/type_error_no_rich_short_disable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import typer
import typer.main

typer.main.rich = None


app = typer.Typer(pretty_errors_short=False)


@app.command()
def main(name: str = "morty"):
print(name + 3)


if __name__ == "__main__":
app()
12 changes: 12 additions & 0 deletions tests/assets/type_error_rich_pretty_disable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import typer

app = typer.Typer(pretty_errors_enable=False)


@app.command()
def main(name: str = "morty"):
print(name + 3)


if __name__ == "__main__":
app()
12 changes: 12 additions & 0 deletions tests/assets/type_error_rich_short_disable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import typer

app = typer.Typer(pretty_errors_short=False)


@app.command()
def main(name: str = "morty"):
print(name + 3)


if __name__ == "__main__":
app()
61 changes: 60 additions & 1 deletion tests/test_tracebacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,28 @@ def test_traceback_rich():
)
assert "return get_command(self)(*args, **kwargs)" not in result.stderr

assert "typer.run(main)" in result.stderr
assert "typer.run(main)" not in result.stderr
assert "print(name + 3)" in result.stderr

# TODO: when deprecating Python 3.6, remove second option
assert (
'TypeError: can only concatenate str (not "int") to str' in result.stderr
or "TypeError: must be str, not int" in result.stderr
)
assert "name = 'morty'" in result.stderr


def test_traceback_rich_pretty_short_disable():
file_path = Path(__file__).parent / "assets/type_error_rich_short_disable.py"
result = subprocess.run(
["coverage", "run", str(file_path)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
)
assert "return get_command(self)(*args, **kwargs)" not in result.stderr

assert "app()" in result.stderr
assert "print(name + 3)" in result.stderr

# TODO: when deprecating Python 3.6, remove second option
Expand Down Expand Up @@ -42,6 +63,25 @@ def test_traceback_no_rich():
)


def test_traceback_no_rich_short_disable():
file_path = Path(__file__).parent / "assets/type_error_no_rich_short_disable.py"
result = subprocess.run(
["coverage", "run", str(file_path)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
)
assert "return get_command(self)(*args, **kwargs)" not in result.stderr

assert "app()" in result.stderr
assert "print(name + 3)" in result.stderr
# TODO: when deprecating Python 3.6, remove second option
assert (
'TypeError: can only concatenate str (not "int") to str' in result.stderr
or "TypeError: must be str, not int" in result.stderr
)


def test_unmodified_traceback():
file_path = Path(__file__).parent / "assets/type_error_normal_traceback.py"
result = subprocess.run(
Expand All @@ -62,3 +102,22 @@ def test_unmodified_traceback():
'TypeError: can only concatenate str (not "int") to str' in result.stderr
or "TypeError: must be str, not int" in result.stderr
)


def test_rich_pretty_errors_disable():
file_path = Path(__file__).parent / "assets/type_error_rich_pretty_disable.py"
result = subprocess.run(
["coverage", "run", str(file_path)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
)
assert "return get_command(self)(*args, **kwargs)" in result.stderr

assert "app()" in result.stderr
assert "print(name + 3)" in result.stderr
# TODO: when deprecating Python 3.6, remove second option
assert (
'TypeError: can only concatenate str (not "int") to str' in result.stderr
or "TypeError: must be str, not int" in result.stderr
)
70 changes: 53 additions & 17 deletions typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
CommandInfo,
Default,
DefaultPlaceholder,
DeveloperExceptionConfig,
FileBinaryRead,
FileBinaryWrite,
FileText,
Expand Down Expand Up @@ -52,7 +53,10 @@
def except_hook(
exc_type: Type[BaseException], exc_value: BaseException, tb: TracebackType
) -> None:
if not getattr(exc_value, _typer_developer_exception_attr_name, None):
exception_config: Union[DeveloperExceptionConfig, None] = getattr(
exc_value, _typer_developer_exception_attr_name, None
)
if not exception_config or not exception_config.pretty_errors_enable:
_original_except_hook(exc_type, exc_value, tb)
return
typer_path = os.path.dirname(__file__)
Expand All @@ -64,7 +68,7 @@ def except_hook(
type(exc),
exc,
exc.__traceback__,
show_locals=True,
show_locals=exception_config.pretty_errors_show_locals,
suppress=supress_internal_dir_names,
)
console_stderr.print(rich_tb)
Expand All @@ -75,15 +79,16 @@ def except_hook(
if any(
[frame.filename.startswith(path) for path in supress_internal_dir_names]
):
# Hide the line for internal libraries, Typer and Click
stack.append(
traceback.FrameSummary(
filename=frame.filename,
lineno=frame.lineno,
name=frame.name,
line="",
if not exception_config.pretty_errors_short:
# Hide the line for internal libraries, Typer and Click
stack.append(
traceback.FrameSummary(
filename=frame.filename,
lineno=frame.lineno,
name=frame.name,
line="",
)
)
)
else:
stack.append(frame)
# Type ignore ref: https://github.com/python/typeshed/pull/8244
Expand Down Expand Up @@ -123,8 +128,14 @@ def __init__(
hidden: bool = Default(False),
deprecated: bool = Default(False),
add_completion: bool = True,
pretty_errors_enable: bool = True,
pretty_errors_show_locals: bool = True,
pretty_errors_short: bool = True,
):
self._add_completion = add_completion
self.pretty_errors_enable = pretty_errors_enable
self.pretty_errors_show_locals = pretty_errors_show_locals
self.pretty_errors_short = pretty_errors_short
self.info = TyperInfo(
name=name,
cls=cls,
Expand Down Expand Up @@ -285,12 +296,23 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
# but that means the last error shown is the custom exception, not the
# actual error. This trick improves developer experience by showing the
# actual error last.
setattr(e, _typer_developer_exception_attr_name, True)
setattr(
e,
_typer_developer_exception_attr_name,
DeveloperExceptionConfig(
pretty_errors_enable=self.pretty_errors_enable,
pretty_errors_show_locals=self.pretty_errors_show_locals,
pretty_errors_short=self.pretty_errors_short,
),
)
raise e


def get_group(typer_instance: Typer) -> click.Command:
group = get_group_from_info(TyperInfo(typer_instance))
group = get_group_from_info(
TyperInfo(typer_instance),
pretty_errors_short=typer_instance.pretty_errors_short,
)
return group


Expand Down Expand Up @@ -318,7 +340,9 @@ def get_command(typer_instance: Typer) -> click.Command:
):
single_command.context_settings = typer_instance.info.context_settings

click_command = get_command_from_info(single_command)
click_command = get_command_from_info(
single_command, pretty_errors_short=typer_instance.pretty_errors_short
)
if typer_instance._add_completion:
click_command.params.append(click_install_param)
click_command.params.append(click_show_param)
Expand Down Expand Up @@ -422,17 +446,23 @@ def solve_typer_info_defaults(typer_info: TyperInfo) -> TyperInfo:
return TyperInfo(**values)


def get_group_from_info(group_info: TyperInfo) -> click.Command:
def get_group_from_info(
group_info: TyperInfo, *, pretty_errors_short: bool
) -> click.Command:
assert (
group_info.typer_instance
), "A Typer instance is needed to generate a Click Group"
commands: Dict[str, click.Command] = {}
for command_info in group_info.typer_instance.registered_commands:
command = get_command_from_info(command_info=command_info)
command = get_command_from_info(
command_info=command_info, pretty_errors_short=pretty_errors_short
)
if command.name:
commands[command.name] = command
for sub_group_info in group_info.typer_instance.registered_groups:
sub_group = get_group_from_info(sub_group_info)
sub_group = get_group_from_info(
sub_group_info, pretty_errors_short=pretty_errors_short
)
if sub_group.name:
commands[sub_group.name] = sub_group
solved_info = solve_typer_info_defaults(group_info)
Expand All @@ -456,6 +486,7 @@ def get_group_from_info(group_info: TyperInfo) -> click.Command:
params=params,
convertors=convertors,
context_param_name=context_param_name,
pretty_errors_short=pretty_errors_short,
),
params=params, # type: ignore
help=solved_info.help,
Expand Down Expand Up @@ -492,7 +523,9 @@ def get_params_convertors_ctx_param_name_from_function(
return params, convertors, context_param_name


def get_command_from_info(command_info: CommandInfo) -> click.Command:
def get_command_from_info(
command_info: CommandInfo, *, pretty_errors_short: bool
) -> click.Command:
assert command_info.callback, "A command must have a callback function"
name = command_info.name or get_command_name(command_info.callback.__name__)
use_help = command_info.help
Expand All @@ -514,6 +547,7 @@ def get_command_from_info(command_info: CommandInfo) -> click.Command:
params=params,
convertors=convertors,
context_param_name=context_param_name,
pretty_errors_short=pretty_errors_short,
),
params=params, # type: ignore
help=use_help,
Expand Down Expand Up @@ -585,6 +619,7 @@ def get_callback(
params: Sequence[click.Parameter] = [],
convertors: Dict[str, Callable[[str], Any]] = {},
context_param_name: Optional[str] = None,
pretty_errors_short: bool,
) -> Optional[Callable[..., Any]]:
if not callback:
return None
Expand All @@ -597,6 +632,7 @@ def get_callback(
use_params[param.name] = param.default

def wrapper(**kwargs: Any) -> Any:
_rich_traceback_guard = pretty_errors_short # noqa: F841
for k, v in kwargs.items():
if k in convertors:
use_params[k] = convertors[k](v)
Expand Down
13 changes: 13 additions & 0 deletions typer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,3 +454,16 @@ def __init__(
self.name = name
self.default = default
self.annotation = annotation


class DeveloperExceptionConfig:
def __init__(
self,
*,
pretty_errors_enable: bool = True,
pretty_errors_show_locals: bool = True,
pretty_errors_short: bool = True,
) -> None:
self.pretty_errors_enable = pretty_errors_enable
self.pretty_errors_show_locals = pretty_errors_show_locals
self.pretty_errors_short = pretty_errors_short