Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support optional subcommands under tyro.conf.ConsolidateSubcommandArgs #224

Merged
merged 6 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions src/tyro/_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,30 @@ def callable_with_args(
kwargs: Dict[str, Any] = {}
consumed_keywords: Set[str] = set()

def get_value_from_arg(prefixed_field_name: str) -> Any:
def get_value_from_arg(
prefixed_field_name: str, field_def: _fields.FieldDefinition
) -> tuple[Any, bool]:
"""Helper for getting values from `value_from_arg` + doing some extra
asserts."""
assert (
prefixed_field_name in value_from_prefixed_field_name
), f"{prefixed_field_name} not in {value_from_prefixed_field_name}"
return value_from_prefixed_field_name[prefixed_field_name]
asserts.

Returns:
- The value from `value_from_prefixed_field_name`.
- If the value was found. If True, we found the value (and it will
be returned as a string or list of strings). If False, we've just
returned the default.
"""

if prefixed_field_name not in value_from_prefixed_field_name:
# When would the value not be found? Only if we have
# `tyro.conf.ConslidateSubcommandArgs` for one of the contained
# subparsers.
assert (
parser_definition.subparsers is not None
and parser_definition.consolidate_subcommand_args
), "Field value is unexpectedly missing. This is likely a bug in tyro."
return field_def.default, False
else:
return value_from_prefixed_field_name[prefixed_field_name], True

arg_from_prefixed_field_name: Dict[str, _arguments.ArgumentDefinition] = {}
for arg in parser_definition.args:
Expand All @@ -79,7 +96,7 @@ def get_value_from_arg(prefixed_field_name: str) -> Any:
name_maybe_prefixed = prefixed_field_name
consumed_keywords.add(name_maybe_prefixed)
if not arg.lowered.is_fixed():
value = get_value_from_arg(name_maybe_prefixed)
value, value_found = get_value_from_arg(name_maybe_prefixed, field)

if value in _fields.MISSING_AND_MISSING_NONPROP:
value = arg.field.default
Expand All @@ -97,7 +114,8 @@ def get_value_from_arg(prefixed_field_name: str) -> Any:
and arg.lowered.nargs in ("?", "*")
):
value = []
else:
elif value_found:
# Value was found from the CLI, so we need to cast it with instance_from_str.
any_arguments_provided = True
if arg.lowered.nargs == "?":
# Special case for optional positional arguments: this is the
Expand Down Expand Up @@ -144,7 +162,10 @@ def get_value_from_arg(prefixed_field_name: str) -> Any:
subparser_dest = _strings.make_subparser_dest(name=prefixed_field_name)
consumed_keywords.add(subparser_dest)
if subparser_dest in value_from_prefixed_field_name:
subparser_name = get_value_from_arg(subparser_dest)
subparser_name, subparser_name_found = get_value_from_arg(
subparser_dest, field
)
assert subparser_name_found
else:
assert (
subparser_def.default_instance
Expand Down
4 changes: 2 additions & 2 deletions src/tyro/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def _cli_impl(
if deprecated_kwargs.get("avoid_subparsers", False):
f = conf.AvoidSubcommands[f] # type: ignore
warnings.warn(
"`avoid_subparsers=` is deprecated! use `tyro.conf.AvoidSubparsers[]`"
"`avoid_subparsers=` is deprecated! use `tyro.conf.AvoidSubcommands[]`"
" instead.",
stacklevel=2,
)
Expand Down Expand Up @@ -398,7 +398,7 @@ def _cli_impl(
parser._parsing_known_args = return_unknown_args
parser._console_outputs = console_outputs
parser._args = args
parser_spec.apply(parser)
parser_spec.apply(parser, force_required_subparsers=False)

# Print help message when no arguments are passed in. (but arguments are
# expected)
Expand Down
103 changes: 64 additions & 39 deletions src/tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,24 @@ def from_callable_or_type(
)

def apply(
self, parser: argparse.ArgumentParser
self, parser: argparse.ArgumentParser, force_required_subparsers: bool
) -> Tuple[argparse.ArgumentParser, ...]:
"""Create defined arguments and subparsers."""

# Generate helptext.
parser.description = self.description

# `force_required_subparsers`: if we have required arguments and we're
# consolidating all arguments into the leaves of the subparser trees, a
# required argument in one node of this tree means that all of its
# descendants are required.
if self.consolidate_subcommand_args and self.has_required_args:
force_required_subparsers = True

# Create subparser tree.
subparser_group = None
if self.subparsers is not None:
leaves = self.subparsers.apply(parser)
leaves = self.subparsers.apply(parser, force_required_subparsers)
subparser_group = parser._action_groups.pop()
else:
leaves = (parser,)
Expand Down Expand Up @@ -408,6 +415,7 @@ class SubparsersSpecification:
name: str
description: str | None
parser_from_name: Dict[str, ParserSpecification]
default_name: str | None
default_parser: ParserSpecification | None
intern_prefix: str
required: bool
Expand Down Expand Up @@ -490,17 +498,24 @@ def from_field(
subcommand_type_from_name[subcommand_name] = cast(type, option)

# If a field default is provided, try to find a matching subcommand name.
if (
field.default is None
or field.default in _singleton.MISSING_AND_MISSING_NONPROP
):
default_name = None
else:
default_name = _subcommand_matching.match_subcommand(
field.default,
subcommand_config_from_name,
subcommand_type_from_name,
)
default_name = None
if field.default not in _singleton.MISSING_AND_MISSING_NONPROP:
# Subcommand matcher won't work with `none_proxy`.
if field.default is None:
default_name = next(
iter(
filter(
lambda pair: pair[1] is none_proxy,
subcommand_type_from_name.items(),
)
)
)[0]
else:
default_name = _subcommand_matching.match_subcommand(
field.default,
subcommand_config_from_name,
subcommand_type_from_name,
)

assert default_name is not None, (
f"`{extern_prefix}` was provided a default value of type"
Expand Down Expand Up @@ -581,13 +596,13 @@ def from_field(
)
parser_from_name[subcommand_name] = subparser

# Required if a default is missing.
required = field.default in _fields.MISSING_AND_MISSING_NONPROP

# Required if a default is passed in, but the default value has missing
# parameters.
default_parser = None
if default_name is not None:
if default_name is None:
required = True
else:
required = False
default_parser = parser_from_name[default_name]
if any(map(lambda arg: arg.lowered.required, default_parser.args)):
required = True
Expand All @@ -597,29 +612,14 @@ def from_field(
):
required = True

# Required if all args are pushed to the final subcommand.
if _markers.ConsolidateSubcommandArgs in field.markers:
required = True

# Make description.
description_parts = []
if field.helptext is not None:
description_parts.append(field.helptext)
if not required and field.default not in _fields.MISSING_AND_MISSING_NONPROP:
description_parts.append(f" (default: {default_name})")
description = (
# We use `None` instead of an empty string to prevent a line break from
# being created where the description would be.
" ".join(description_parts) if len(description_parts) > 0 else None
)

return SubparsersSpecification(
name=field.intern_name,
# If we wanted, we could add information about the default instance
# automatically, as is done for normal fields. But for now we just rely on
# the user to include it in the docstring.
description=description,
description=field.helptext,
parser_from_name=parser_from_name,
default_name=default_name,
default_parser=default_parser,
intern_prefix=intern_prefix,
required=required,
Expand All @@ -628,19 +628,42 @@ def from_field(
)

def apply(
self, parent_parser: argparse.ArgumentParser
self,
parent_parser: argparse.ArgumentParser,
force_required_subparsers: bool,
) -> Tuple[argparse.ArgumentParser, ...]:
title = "subcommands"
metavar = "{" + ",".join(self.parser_from_name.keys()) + "}"
if not self.required:

required = self.required or force_required_subparsers

if not required:
title = "optional " + title
metavar = f"[{metavar}]"

# Make description.
description_parts = []
if self.description is not None:
description_parts.append(self.description)
if not required and self.default_name is not None:
description_parts.append(f"(default: {self.default_name})")

# If this subparser is required because of a required argument in a
# parent (tyro.conf.ConsolidateSubcommandArgs).
if not self.required and force_required_subparsers:
description_parts.append("(required to specify parent argument)")

description = (
# We use `None` instead of an empty string to prevent a line break from
# being created where the description would be.
" ".join(description_parts) if len(description_parts) > 0 else None
)

# Add subparsers to every node in previous level of the tree.
argparse_subparsers = parent_parser.add_subparsers(
dest=_strings.make_subparser_dest(self.intern_prefix),
description=self.description,
required=self.required,
description=description,
required=required,
title=title,
metavar=metavar,
)
Expand All @@ -667,7 +690,9 @@ def apply(
subparser._console_outputs = parent_parser._console_outputs
subparser._args = parent_parser._args

subparser_tree_leaves.extend(subparser_def.apply(subparser))
subparser_tree_leaves.extend(
subparser_def.apply(subparser, force_required_subparsers)
)

return tuple(subparser_tree_leaves)

Expand Down
10 changes: 9 additions & 1 deletion src/tyro/conf/_markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

ConsolidateSubcommandArgs = Annotated[T, None]
"""Consolidate arguments applied to subcommands. Makes CLI less sensitive to argument
ordering, at the cost of support for optional subcommands.
ordering, with some tradeoffs.

By default, :mod:`tyro` will generate a traditional CLI interface where args are applied to
the directly preceding subcommand. When we have two subcommands ``s1`` and ``s2``:
Expand All @@ -87,6 +87,14 @@

This is more robust to reordering of options, ensuring that any new options can simply
be placed at the end of the command.

The tradeoff is in required arguments. In the above example, if any ``--root.*`` options
are required (no default is specified), all subcommands will need to be specified in order to
provide the required argument.

.. code-block:: bash

python x.py s1 s2 {required --root.* arguments}
"""

OmitSubcommandPrefixes = Annotated[T, None]
Expand Down
Loading
Loading