Skip to content

Commit

Permalink
Fix UnionType handling as with Union for Optional values
Browse files Browse the repository at this point in the history
  • Loading branch information
jonaslb committed Oct 3, 2023
1 parent e0b207f commit 0ddd20c
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 13 deletions.
23 changes: 23 additions & 0 deletions tests/test_type_conversion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from enum import Enum
from pathlib import Path
from typing import Any, List, Optional, Tuple
Expand Down Expand Up @@ -29,6 +30,28 @@ def opt(user: Optional[str] = None):
assert "User: Camila" in result.output


@pytest.mark.skipif(
sys.version_info < (3, 10), reason="The | operator for types was new in 3.10"
)
def test_union_type_optional():
app = typer.Typer()

@app.command()
def opt(user: str | None = None):
if user:
print(f"User: {user}")
else:
print("No user")

result = runner.invoke(app)
assert result.exit_code == 0
assert "No user" in result.output

result = runner.invoke(app, ["--user", "Camila"])
assert result.exit_code == 0
assert "User: Camila" in result.output


def test_no_type():
app = typer.Typer()

Expand Down
32 changes: 32 additions & 0 deletions typer/_compat_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,37 @@
import sys
from typing import Union

import click

if sys.version_info >= (3, 8):
from typing import get_args as _get_args
from typing import get_origin as _get_origin
elif sys.version_info >= (3, 7):
from typing_extensions import get_args as _get_args
from typing_extensions import get_origin as _get_origin
else:
# These methods do not handle all the same details as the imported ones.
# However on Python 3.6 they should be sufficient.
# typer <= 0.7.0 used this implementation on all Python versions.

def _get_origin(arg): # pragma: no cover
return getattr(arg, "__origin__", None)

def _get_args(arg): # pragma: no cover
return getattr(arg, "__args__", None)


# Assigning variables to mark them as exported with mypy
get_origin = _get_origin
get_args = _get_args

if sys.version_info >= (3, 10):
from types import UnionType

UNION_TYPES = (UnionType, Union)
else:
UNION_TYPES = (Union,)


def _get_click_major() -> int:
return int(click.__version__.split(".")[0])
28 changes: 15 additions & 13 deletions typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import click

from ._compat_utils import UNION_TYPES, get_args, get_origin
from .completion import get_completion_inspect_parameters
from .core import MarkupMode, TyperArgument, TyperCommand, TyperGroup, TyperOption
from .models import (
Expand Down Expand Up @@ -816,30 +817,31 @@ def get_click_param(
is_tuple = False
parameter_type: Any = None
is_flag = None
origin = getattr(main_type, "__origin__", None)
origin = get_origin(main_type)

if origin is not None:
# Handle Optional[SomeType]
if origin is Union:
# Handle SomeType | None and Optional[SomeType]
if origin in UNION_TYPES:
types = []
for type_ in main_type.__args__:
for type_ in get_args(main_type):
if type_ is NoneType:
continue
types.append(type_)
assert len(types) == 1, "Typer Currently doesn't support Union types"
main_type = types[0]
origin = getattr(main_type, "__origin__", None)
origin = get_origin(main_type)
# Handle Tuples and Lists
if lenient_issubclass(origin, List):
main_type = main_type.__args__[0]
assert not getattr(
main_type, "__origin__", None
main_type = get_args(main_type)[0]
assert not get_origin(
main_type
), "List types with complex sub-types are not currently supported"
is_list = True
elif lenient_issubclass(origin, Tuple): # type: ignore
types = []
for type_ in main_type.__args__:
assert not getattr(
type_, "__origin__", None
for type_ in get_args(main_type):
assert not get_origin(
type_
), "Tuple types with complex sub-types are not currently supported"
types.append(
get_click_type(annotation=type_, parameter_info=parameter_info)
Expand All @@ -854,7 +856,7 @@ def get_click_param(
if is_list:
convertor = generate_list_convertor(convertor)
if is_tuple:
convertor = generate_tuple_convertor(main_type.__args__)
convertor = generate_tuple_convertor(get_args(main_type))
if isinstance(parameter_info, OptionInfo):
if main_type is bool and not (parameter_info.is_flag is False):
is_flag = True
Expand Down Expand Up @@ -1008,7 +1010,7 @@ def get_param_completion(
incomplete_name = None
unassigned_params = [param for param in parameters.values()]
for param_sig in unassigned_params[:]:
origin = getattr(param_sig.annotation, "__origin__", None)
origin = get_origin(param_sig.annotation)
if lenient_issubclass(param_sig.annotation, click.Context):
ctx_name = param_sig.name
unassigned_params.remove(param_sig)
Expand Down

0 comments on commit 0ddd20c

Please sign in to comment.