Skip to content

Commit

Permalink
Fix typed arguments issue. #google#300
Browse files Browse the repository at this point in the history
  • Loading branch information
loynoir committed Feb 15, 2021
1 parent 8e9b1d5 commit fb87f7c
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions fire/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ def _ParseFn(args):
# Note: _ParseArgs modifies kwargs.
parsed_args, kwargs, remaining_args, capacity = _ParseArgs(
fn_spec.args, fn_spec.defaults, num_required_args, kwargs,
remaining_args, metadata)
remaining_args, metadata, fn_spec.annotations)

if fn_spec.varargs or fn_spec.varkw:
# If we're allowed *varargs or **kwargs, there's always capacity.
Expand All @@ -740,7 +740,7 @@ def _ParseFn(args):
varargs = []

for index, value in enumerate(varargs):
varargs[index] = _ParseValue(value, None, None, metadata)
varargs[index] = _ParseValue(value, None, None, metadata, fn_spec.annotations)

varargs = parsed_args + varargs
remaining_args += remaining_kwargs
Expand All @@ -752,7 +752,7 @@ def _ParseFn(args):


def _ParseArgs(fn_args, fn_defaults, num_required_args, kwargs,
remaining_args, metadata):
remaining_args, metadata, annotations):
"""Parses the positional and named arguments from the available supplied args.
Modifies kwargs, removing args as they are used.
Expand Down Expand Up @@ -786,13 +786,13 @@ def _ParseArgs(fn_args, fn_defaults, num_required_args, kwargs,
for index, arg in enumerate(fn_args):
value = kwargs.pop(arg, None)
if value is not None: # A value is specified at the command line.
value = _ParseValue(value, index, arg, metadata)
value = _ParseValue(value, index, arg, metadata, annotations)
parsed_args.append(value)
else: # No value has been explicitly specified.
if remaining_args and accepts_positional_args:
# Use a positional arg.
value = remaining_args.pop(0)
value = _ParseValue(value, index, arg, metadata)
value = _ParseValue(value, index, arg, metadata, annotations)
parsed_args.append(value)
elif index < num_required_args:
raise FireError(
Expand All @@ -805,7 +805,7 @@ def _ParseArgs(fn_args, fn_defaults, num_required_args, kwargs,
parsed_args.append(fn_defaults[default_index])

for key, value in kwargs.items():
kwargs[key] = _ParseValue(value, None, key, metadata)
kwargs[key] = _ParseValue(value, None, key, metadata, annotations)

return parsed_args, kwargs, remaining_args, capacity

Expand Down Expand Up @@ -951,7 +951,7 @@ def _IsMultiCharFlag(argument):
return argument.startswith('--') or re.match('^-[a-zA-Z]', argument)


def _ParseValue(value, index, arg, metadata):
def _ParseValue(value, index, arg, metadata, annotations):
"""Parses value, a string, into the appropriate type.
The function used to parse value is determined by the remaining arguments.
Expand All @@ -964,7 +964,10 @@ def _ParseValue(value, index, arg, metadata):
Returns:
value, parsed into the appropriate type for calling a function.
"""
parse_fn = parser.DefaultParseValue
if arg in annotations and annotations[arg] in (str, int, float):
parse_fn = annotations[arg]
else:
parse_fn = parser.DefaultParseValue

# We check to see if any parse function from the fn metadata applies here.
parse_fns = metadata.get(decorators.FIRE_PARSE_FNS)
Expand Down

0 comments on commit fb87f7c

Please sign in to comment.