Skip to content

Commit

Permalink
Fix metavars for generics, more refactor + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jan 21, 2022
1 parent 7692ab3 commit 37480d1
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 85 deletions.
162 changes: 82 additions & 80 deletions dcargs/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,28 @@
T = TypeVar("T")


def _no_op_action(x: T) -> T:
return x
def instance_from_string(typ: Type[T], arg: str) -> T:
    """Given a type and a string from the command-line, reconstruct an object. Not
    intended to deal with containers; these are handled in the argument
    transformations.

    This is intended to replace all calls to `type(string)`, which can cause unexpected
    behavior. As an example, note that the following argparse code will always print
    `True`, because `bool("True") == bool("False") == bool("0") == True`.
    ```
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--flag", type=bool)
    print(parser.parse_args().flag)
    ```

    Args:
        typ: Target type. Must not carry type arguments, i.e. `get_args(typ)` must be
            empty.
        arg: Raw string received from the command-line.

    Returns:
        The result of `_strings.bool_from_string(arg)` when `typ` is `bool`, otherwise
        the result of calling `typ(arg)`.

    Raises:
        AssertionError: If `typ` has type arguments (for example `List[int]`).
    """
    # Anything that still wraps other types cannot be built by calling it on a single
    # string; those cases are expected to be unwrapped before reaching this function.
    assert len(get_args(typ)) == 0, f"Type {typ} cannot be instantiated."

    if typ is bool:
        # `bool(arg)` is True for every non-empty string, so booleans are delegated to
        # the string helper instead of being constructed directly.
        return _strings.bool_from_string(arg)  # type: ignore

    return typ(arg)  # type: ignore


@dataclasses.dataclass(frozen=True)
Expand All @@ -42,12 +62,16 @@ class ArgumentDefinition:
# Action that is called on parsed arguments. This handles conversions from strings
# to our desired types.
#
# There are 2 options:
# There are 3 options:
field_action: Union[
# Most standard fields: these are converted from strings from the CLI.
Callable[[str], Any],
# Sequence fields! This should be used whenever argparse's `nargs` field is set.
Callable[[List[str]], Any],
# Special case: the only time that argparse doesn't give us a string is when the
# argument action is set to `store_true` or `store_false`. In this case, we get
# a bool directly, and the field action can be a no-op.
Callable[[bool], bool],
]

# Fields that will be populated initially.
Expand All @@ -72,17 +96,18 @@ def add_argument(
"""Add a defined argument to a parser."""
kwargs = {k: v for k, v in vars(self).items() if v is not None}

# Apply prefix for nested dataclasses.
if "dest" in kwargs:
kwargs["dest"] = self.prefix + kwargs["dest"]

# Important: as far as argparse is concerned, all inputs are strings.
#
# Conversions from strings to our desired types happen in the "field action";
# this is a bit more flexible, and lets us handle more complex types like enums
# and multi-type tuples.
if "type" in kwargs:
kwargs["type"] = str

# Don't pass field action into argparse.
if "dest" in kwargs:
kwargs["dest"] = self.prefix + kwargs["dest"]

kwargs.pop("field")
kwargs.pop("parent_class")
kwargs.pop("prefix")
Expand All @@ -93,6 +118,7 @@ def add_argument(
parser.add_argument(self.get_flag(), **kwargs)

def get_flag(self) -> str:
"""Get --flag representation, with a prefix applied for nested dataclasses."""
return "--" + (self.prefix + self.name).replace("_", "-")

@staticmethod
Expand All @@ -107,47 +133,25 @@ def make_from_field(

assert field.init, "Field must be in class constructor"

# The default field action: this converts a string from argparse to the desired
# type of the argument.
def default_field_action(x: str) -> Any:
return instance_from_string(cast(Type, arg.type), x)

# Create initial argument.
arg = ArgumentDefinition(
prefix="",
field=field,
parent_class=parent_class,
field_action=_no_op_action,
field_action=default_field_action,
name=field.name,
type=field.type,
default=default_override,
)

# Propagate argument through transforms until stable.
prev_arg = arg

def instance_from_string(typ: Type[T], arg: str) -> T:
"""Given a type and and a string from the command-line, reconstruct an object. Not
intended to deal with containers; these are handled in the argument
transformations.
This is intended to replace all calls to `type(string)`, which can cause unexpected
behavior. As an example, note that the following argparse code will always print
`True`, because `bool("True") == bool("False") == bool("0") == True`.
```
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--flag", type=bool)
print(parser.parse_args().flag)
```
"""
if typ in type_from_typevar:
# The Type vs TypeVar annotations could be cleaned up...
typ = type_from_typevar[typ] # type: ignore

if typ is bool:
return _strings.bool_from_string(arg) # type: ignore
else:
return typ(arg) # type: ignore

argument_transforms = _get_argument_transforms(instance_from_string)
argument_transforms = _get_argument_transforms(type_from_typevar)
while True:
for transform in argument_transforms: # type: ignore
# Apply transform.
Expand All @@ -158,27 +162,28 @@ def instance_from_string(typ: Type[T], arg: str) -> T:
break
prev_arg = arg

if arg.field_action is _no_op_action and arg.type is not None:
cast_type = cast(Type, arg.type)
arg = dataclasses.replace(
arg,
field_action=lambda x: instance_from_string(
cast_type,
x,
),
)
elif arg.field_action is _no_op_action:
assert arg.action in ("store_true", "store_false")

return arg


def _get_argument_transforms(
instance_from_string: Callable[[Type[T], str], T]
type_from_typevar: Dict[TypeVar, Type]
) -> List[Callable[[ArgumentDefinition], ArgumentDefinition]]:
"""Get a list of argument transformations."""

def unwrap_final(arg: ArgumentDefinition) -> ArgumentDefinition:
def resolve_typevars(typ: Union[Type, TypeVar]) -> Type:
return type_from_typevar.get(cast(TypeVar, typ), cast(Type, typ))

# All transforms should start with `transform_`.

def transform_resolve_arg_typevars(arg: ArgumentDefinition) -> ArgumentDefinition:
if arg.type is not None:
return dataclasses.replace(
arg,
type=resolve_typevars(arg.type),
)
return arg

def transform_unwrap_final(arg: ArgumentDefinition) -> ArgumentDefinition:
"""Treat Final[T] as just T."""
if get_origin(arg.type) is Final:
(typ,) = get_args(arg.type)
Expand All @@ -189,7 +194,7 @@ def unwrap_final(arg: ArgumentDefinition) -> ArgumentDefinition:
else:
return arg

def unwrap_annotated(arg: ArgumentDefinition) -> ArgumentDefinition:
def transform_unwrap_annotated(arg: ArgumentDefinition) -> ArgumentDefinition:
"""Treat Annotated[T, annotation] as just T."""
if hasattr(arg.type, "__class__") and arg.type.__class__ == _AnnotatedAlias:
typ = get_origin(arg.type)
Expand All @@ -200,7 +205,7 @@ def unwrap_annotated(arg: ArgumentDefinition) -> ArgumentDefinition:
else:
return arg

def handle_optionals(arg: ArgumentDefinition) -> ArgumentDefinition:
def transform_handle_optionals(arg: ArgumentDefinition) -> ArgumentDefinition:
"""Transform for handling Optional[T] types. Sets default to None and marks arg as
not required."""
if get_origin(arg.type) is Union:
Expand All @@ -218,7 +223,7 @@ def handle_optionals(arg: ArgumentDefinition) -> ArgumentDefinition:
else:
return arg

def populate_defaults(arg: ArgumentDefinition) -> ArgumentDefinition:
def transform_populate_defaults(arg: ArgumentDefinition) -> ArgumentDefinition:
"""Populate default values."""
if arg.default is not None:
# Skip if another handler has already populated the default.
Expand All @@ -238,7 +243,7 @@ def populate_defaults(arg: ArgumentDefinition) -> ArgumentDefinition:

return dataclasses.replace(arg, default=default, required=required)

def bool_flags(arg: ArgumentDefinition) -> ArgumentDefinition:
def transform_bool_flags(arg: ArgumentDefinition) -> ArgumentDefinition:
"""For booleans, we use a `store_true` action."""
if arg.type != bool:
return arg
Expand Down Expand Up @@ -272,6 +277,7 @@ def bool_flags(arg: ArgumentDefinition) -> ArgumentDefinition:
arg,
action="store_true",
type=None,
field_action=lambda x: x, # argparse will directly give us a bool!
)
elif arg.default is True:
return dataclasses.replace(
Expand All @@ -280,11 +286,12 @@ def bool_flags(arg: ArgumentDefinition) -> ArgumentDefinition:
name="no_" + arg.name,
action="store_false",
type=None,
field_action=lambda x: x, # argparse will directly give us a bool!
)
else:
assert False, "Invalid default"

def nargs_from_sequences_lists_and_sets(
def transform_nargs_from_sequences_lists_and_sets(
arg: ArgumentDefinition,
) -> ArgumentDefinition:
"""Transform for handling Sequence[T] and list types."""
Expand All @@ -293,7 +300,8 @@ def nargs_from_sequences_lists_and_sets(
list, # different from typing.List!
set, # different from typing.Set!
):
(typ,) = get_args(arg.type)
assert arg.nargs is None, "Sequence types cannot be nested."
(typ,) = map(resolve_typevars, get_args(arg.type))
container_type = get_origin(arg.type)
if container_type is collections.abc.Sequence:
container_type = list
Expand All @@ -312,16 +320,17 @@ def nargs_from_sequences_lists_and_sets(
else:
return arg

def nargs_from_tuples(arg: ArgumentDefinition) -> ArgumentDefinition:
def transform_nargs_from_tuples(arg: ArgumentDefinition) -> ArgumentDefinition:
"""Transform for handling Tuple[T, T, ...] types."""

if arg.nargs is None and get_origin(arg.type) is tuple:
types = get_args(arg.type)
typeset = set(types)
typeset_no_ellipsis = typeset - {Ellipsis}
assert arg.nargs is None, "Sequence types cannot be nested."
types = tuple(map(resolve_typevars, get_args(arg.type)))
typeset = set(types) # Note that sets are unordered.
typeset_no_ellipsis = typeset - {Ellipsis} # type: ignore

if typeset_no_ellipsis != typeset:
# Ellipsis: variable argument counts
# Ellipsis: variable argument counts.
assert (
len(typeset_no_ellipsis) == 1
), "If ellipsis is used, tuples must contain only one type."
Expand All @@ -339,13 +348,13 @@ def nargs_from_tuples(arg: ArgumentDefinition) -> ArgumentDefinition:
),
)
else:
# Tuples with more than one type
# Tuples with more than one type.
assert arg.metavar is None

return dataclasses.replace(
arg,
nargs=len(types),
type=str, # Types will be converted in the dataclass reconstruction step.
type=str, # Types are converted in the field action.
metavar=tuple(
t.__name__.upper() if hasattr(t, "__name__") else "X"
for t in types
Expand All @@ -359,7 +368,7 @@ def nargs_from_tuples(arg: ArgumentDefinition) -> ArgumentDefinition:
else:
return arg

def choices_from_literals(arg: ArgumentDefinition) -> ArgumentDefinition:
def transform_choices_from_literals(arg: ArgumentDefinition) -> ArgumentDefinition:
"""For literal types, set choices."""
if get_origin(arg.type) is Literal:
choices = get_args(arg.type)
Expand All @@ -374,7 +383,7 @@ def choices_from_literals(arg: ArgumentDefinition) -> ArgumentDefinition:
else:
return arg

def enums_as_strings(arg: ArgumentDefinition) -> ArgumentDefinition:
def transform_enums_as_strings(arg: ArgumentDefinition) -> ArgumentDefinition:
"""For enums, use string representations."""
if isinstance(arg.type, type) and issubclass(arg.type, enum.Enum):
if arg.choices is None:
Expand All @@ -392,7 +401,7 @@ def enums_as_strings(arg: ArgumentDefinition) -> ArgumentDefinition:
else:
return arg

def generate_helptext(arg: ArgumentDefinition) -> ArgumentDefinition:
def transform_generate_helptext(arg: ArgumentDefinition) -> ArgumentDefinition:
"""Generate helptext from docstring and argument name."""
if arg.help is None:
help_parts = []
Expand All @@ -416,11 +425,16 @@ def generate_helptext(arg: ArgumentDefinition) -> ArgumentDefinition:
else:
return arg

def use_type_as_metavar(arg: ArgumentDefinition) -> ArgumentDefinition:
def transform_use_type_as_metavar(arg: ArgumentDefinition) -> ArgumentDefinition:
"""Communicate the argument type using the metavar."""
if (
hasattr(arg.type, "__name__")
# Don't generate metavar if target is still wrapping something, eg
# Optional[int] will have 1 arg.
and len(get_args(arg.type)) == 0
# If choices is set, they'll be used by default.
and arg.choices is None
# Don't generate metavar if one already exists.
and arg.metavar is None
):
return dataclasses.replace(
Expand All @@ -429,16 +443,4 @@ def use_type_as_metavar(arg: ArgumentDefinition) -> ArgumentDefinition:
else:
return arg

return [
unwrap_final,
unwrap_annotated,
handle_optionals,
populate_defaults,
bool_flags,
nargs_from_sequences_lists_and_sets,
nargs_from_tuples,
choices_from_literals,
enums_as_strings,
generate_helptext,
use_type_as_metavar,
]
return [v for k, v in locals().items() if k.startswith("transform_")]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="dcargs",
version="0.0.13",
version="0.0.14",
description="Portable, reusable, strongly typed CLIs from dataclass definitions",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
Loading

0 comments on commit 37480d1

Please sign in to comment.