diff --git a/CHANGELOG.rst b/CHANGELOG.rst index abc23a3f910f..6a9f3968ff4e 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,26 @@ CHANGELOG ========= +1.5.0 +===== + +* bugfix:Response Parsing: Fix response parsing so that leading + and trailing spaces are preserved +* feature:Shared Credentials File: The ``aws configure`` and + ``aws configure set`` command now write out all credential + variables to the shared credentials file ``~/.aws/credentials`` + (`issue 847 `__) +* bugfix:``aws s3``: Write warnings and errors to standard error as + opposed to standard output. + (`issue 919 `__) +* feature:``aws s3``: Add ``--only-show-errors`` option that displays + errors and warnings but suppresses all other output. +* feature:``aws s3 cp``: Added ability to upload local + file streams from standard input to s3 and download s3 + objects as local file streams to standard output. + (`issue 903 `__) + + 1.4.4 ===== @@ -12,7 +32,7 @@ CHANGELOG ===== * feature:``aws iam``: Update ``aws iam`` command to latest version. -* feature:``aws cognito-sync``: Update ``aws cognito-sync`` command +* feature:``aws cognito-sync``: Update ``aws cognito-sync`` command to latest version. * feature:``aws opsworks``: Update ``aws opsworks`` command to latest version. diff --git a/awscli/__init__.py b/awscli/__init__.py index fa375e1e7de2..6a17a5c231e2 100644 --- a/awscli/__init__.py +++ b/awscli/__init__.py @@ -17,7 +17,7 @@ """ import os -__version__ = '1.4.4' +__version__ = '1.5.0' # # Get our data path to be added to botocore's search path diff --git a/awscli/argprocess.py b/awscli/argprocess.py index 839770d4cf3e..155a3b2b3ed9 100644 --- a/awscli/argprocess.py +++ b/awscli/argprocess.py @@ -15,22 +15,35 @@ import logging import six +from botocore import xform_name from botocore.compat import OrderedDict, json from awscli import utils from awscli import SCALAR_TYPES, COMPLEX_TYPES from awscli.paramfile import get_paramfile, ResourceLoadingError +from awscli.paramfile import PARAMFILE_DISABLED LOG = logging.getLogger('awscli.argprocess') class ParamError(Exception): - def __init__(self, param, message): + def __init__(self, cli_name, message): + """ + + :type cli_name: string + :param cli_name: The complete cli argument name, + e.g. "--foo-bar". It should include the leading + hyphens if that's how a user would specify the name. + + :type message: string + :param message: The error message to display to the user. + + """ full_message = ("Error parsing parameter '%s': %s" % - (param.cli_name, message)) + (cli_name, message)) super(ParamError, self).__init__(full_message) - self.param = param + self.cli_name = cli_name self.message = message @@ -43,11 +56,11 @@ def __init__(self, param, key, valid_keys): valid_keys = ', '.join(valid_keys) full_message = ( "Unknown key '%s' for parameter %s, valid choices " - "are: %s" % (key, param.cli_name, valid_keys)) + "are: %s" % (key, '--%s' % xform_name(param.name), valid_keys)) super(ParamUnknownKeyError, self).__init__(full_message) -def unpack_argument(session, service_name, operation_name, param, value): +def unpack_argument(session, service_name, operation_name, cli_argument, value): """ Unpack an argument's value from the commandline. This is part one of a two step process in handling commandline arguments. Emits the load-cli-arg @@ -56,13 +69,13 @@ def unpack_argument(session, service_name, operation_name, param, value): load-cli-arg.ec2.describe-instances.foo """ - param_name = getattr(param, 'name', 'anonymous') + param_name = getattr(cli_argument, 'name', 'anonymous') value_override = session.emit_first_non_none_response( 'load-cli-arg.%s.%s.%s' % (service_name, operation_name, param_name), - param=param, value=value, service_name=service_name, + param=cli_argument, value=value, service_name=service_name, operation_name=operation_name) if value_override is not None: @@ -71,16 +84,16 @@ def unpack_argument(session, service_name, operation_name, param, value): return value -def uri_param(param, value, **kwargs): +def uri_param(event_name, param, value, **kwargs): """Handler that supports param values from URIs. """ - # Some params have a 'no_paramfile' attribute in their JSON - # models which means that we should not allow any uri based params - # for this argument. - if getattr(param, 'no_paramfile', False): + cli_argument = param + qualified_param_name = '.'.join(event_name.split('.')[1:]) + if qualified_param_name in PARAMFILE_DISABLED or \ + getattr(cli_argument, 'no_paramfile', None): return else: - return _check_for_uri_param(param, value) + return _check_for_uri_param(cli_argument, value) def _check_for_uri_param(param, value): @@ -89,15 +102,15 @@ def _check_for_uri_param(param, value): try: return get_paramfile(value) except ResourceLoadingError as e: - raise ParamError(param, six.text_type(e)) + raise ParamError(param.cli_name, six.text_type(e)) def detect_shape_structure(param): - if param.type in SCALAR_TYPES: + if param.type_name in SCALAR_TYPES: return 'scalar' - elif param.type == 'structure': + elif param.type_name == 'structure': sub_types = [detect_shape_structure(p) - for p in param.members] + for p in param.members.values()] # We're distinguishing between structure(scalar) # and structure(scalars), because for the case of # a single scalar in a structure we can simplify @@ -108,13 +121,104 @@ def detect_shape_structure(param): return 'structure(scalars)' else: return 'structure(%s)' % ', '.join(sorted(set(sub_types))) - elif param.type == 'list': - return 'list-%s' % detect_shape_structure(param.members) - elif param.type == 'map': - if param.members.type in SCALAR_TYPES: + elif param.type_name == 'list': + return 'list-%s' % detect_shape_structure(param.member) + elif param.type_name == 'map': + if param.value.type_name in SCALAR_TYPES: return 'map-scalar' else: - return 'map-%s' % detect_shape_structure(param.members) + return 'map-%s' % detect_shape_structure(param.value) + + +def unpack_cli_arg(cli_argument, value): + """ + Parses and unpacks the encoded string command line parameter + and returns native Python data structures that can be passed + to the Operation. + + :type cli_argument: :class:`awscli.arguments.BaseCLIArgument` + :param cli_argument: The CLI argument object. + + :param value: The value of the parameter. This can be a number of + different python types (str, list, etc). This is the value as + it's specified on the command line. + + :return: The "unpacked" argument than can be sent to the `Operation` + object in python. + """ + return _unpack_cli_arg(cli_argument.argument_model, value, + cli_argument.cli_name) + + +def _unpack_cli_arg(argument_model, value, cli_name): + if argument_model.type_name in SCALAR_TYPES: + return unpack_scalar_cli_arg( + argument_model, value, cli_name) + elif argument_model.type_name in COMPLEX_TYPES: + return _unpack_complex_cli_arg( + argument_model, value, cli_name) + else: + return six.text_type(value) + + +def _unpack_complex_cli_arg(argument_model, value, cli_name): + type_name = argument_model.type_name + if type_name == 'structure' or type_name == 'map': + if value.lstrip()[0] == '{': + try: + return json.loads(value, object_pairs_hook=OrderedDict) + except ValueError as e: + raise ParamError( + cli_name, "Invalid JSON: %s\nJSON received: %s" + % (e, value)) + raise ParamError(cli_name, "Invalid JSON:\n%s" % value) + elif type_name == 'list': + if isinstance(value, six.string_types): + if value.lstrip()[0] == '[': + return json.loads(value, object_pairs_hook=OrderedDict) + elif isinstance(value, list) and len(value) == 1: + single_value = value[0].strip() + if single_value and single_value[0] == '[': + return json.loads(value[0], object_pairs_hook=OrderedDict) + try: + # There's a couple of cases remaining here. + # 1. It's possible that this is just a list of strings, i.e + # --security-group-ids sg-1 sg-2 sg-3 => ['sg-1', 'sg-2', 'sg-3'] + # 2. It's possible this is a list of json objects: + # --filters '{"Name": ..}' '{"Name": ...}' + member_shape_model = argument_model.member + return [_unpack_cli_arg(member_shape_model, v, cli_name) + for v in value] + except (ValueError, TypeError) as e: + # The list params don't have a name/cli_name attached to them + # so they will have bad error messages. We're going to + # attach the parent parameter to this error message to provide + # a more helpful error message. + raise ParamError(cli_name, value[0]) + + +def unpack_scalar_cli_arg(argument_model, value, cli_name=''): + # Note the cli_name is used strictly for error reporting. It's + # not required to use unpack_scalar_cli_arg + if argument_model.type_name == 'integer' or argument_model.type_name == 'long': + return int(value) + elif argument_model.type_name == 'float' or argument_model.type_name == 'double': + # TODO: losing precision on double types + return float(value) + elif argument_model.type_name == 'blob' and \ + argument_model.serialization.get('streaming'): + file_path = os.path.expandvars(value) + file_path = os.path.expanduser(file_path) + if not os.path.isfile(file_path): + msg = 'Blob values must be a path to a file.' + raise ParamError(cli_name, msg) + return open(file_path, 'rb') + elif argument_model.type_name == 'boolean': + if isinstance(value, six.string_types) and value.lower() == 'false': + return False + return bool(value) + else: + return value class ParamShorthand(object): @@ -140,7 +244,7 @@ class ParamShorthand(object): def __init__(self): pass - def __call__(self, param, value, **kwargs): + def __call__(self, cli_argument, value, **kwargs): """Attempt to parse shorthand syntax for values. This is intended to be hooked up as an event handler (hence the @@ -148,6 +252,9 @@ def __call__(self, param, value, **kwargs): figure out if we can parse it. If we can parse it, we return the parsed value (typically some sort of python dict). + :type cli_argument: :class:`awscli.arguments.BaseCLIArgument` + :param cli_argument: The CLI argument object. + :type param: :class:`botocore.parameters.Parameter` :param param: The parameter object (includes various metadata about the parameter). @@ -164,23 +271,28 @@ def __call__(self, param, value, **kwargs): be raised. """ - parse_method = self.get_parse_method_for_param(param, value) + parse_method = self.get_parse_method_for_param(cli_argument, value) if parse_method is None: return else: try: - LOG.debug("Using %s for param %s", parse_method, param) - parsed = getattr(self, parse_method)(param, value) + LOG.debug("Using %s for param %s", parse_method, + cli_argument.cli_name) + parsed = getattr(self, parse_method)( + cli_argument.argument_model, value) except ParamSyntaxError as e: - doc_fn = self._get_example_fn(param) - # Try to give them a helpful error message. - if doc_fn is None: - raise e - else: - raise ParamError(param, "should be: %s" % doc_fn(param)) + docgen = ParamShorthandDocGen() + example_usage = docgen.generate_shorthand_example(cli_argument) + raise ParamError(cli_argument.cli_name, "should be: %s" % example_usage) + except ParamError as e: + # The shorthand parse methods don't have the cli_name, + # so any ParamError won't have this value. To accomodate + # this, ParamErrors are caught and reraised with the cli_name + # injected. + raise ParamError(cli_argument.cli_name, e.message) return parsed - def get_parse_method_for_param(self, param, value=None): + def get_parse_method_for_param(self, cli_argument, value=None): # We first need to make sure this is a parameter that qualifies # for simplification. The first short-circuit case is if it looks # like json we immediately return. @@ -191,9 +303,9 @@ def get_parse_method_for_param(self, param, value=None): if isinstance(check_val, six.string_types) and check_val.strip().startswith( ('[', '{')): LOG.debug("Param %s looks like JSON, not considered for " - "param shorthand.", param.py_name) + "param shorthand.", cli_argument.py_name) return - structure = detect_shape_structure(param) + structure = detect_shape_structure(cli_argument.argument_model) # If this looks like shorthand then we log the detected structure # to help with debugging why the shorthand may not work, for # example list-structure(list-structure(scalars)) @@ -217,11 +329,11 @@ def add_example_fn(self, arg_name, help_command, **kwargs): not any of the ReST formatting that might be required in the docs. """ argument = help_command.arg_table[arg_name] - if hasattr(argument, 'argument_object') and argument.argument_object: - param = argument.argument_object - LOG.debug('Adding example fn for: %s' % param.name) - doc_fn = self._get_example_fn(param) - param.example_fn = doc_fn + model = argument.argument_model + LOG.debug('Adding example fn for: %s' % arg_name) + doc_fn = self._get_example_fn(model) + # XXX: fix this, don't set attributes on argument objects. + argument.example_fn = doc_fn def _list_scalar_list_parse(self, param, value): # Think something like ec2.DescribeInstances.Filters. @@ -229,7 +341,7 @@ def _list_scalar_list_parse(self, param, value): parsed = [] for v in value: - struct = self._struct_scalar_list_parse(param.members, v) + struct = self._struct_scalar_list_parse(param.member, v) parsed.append(struct) return parsed @@ -237,9 +349,9 @@ def _list_scalar_list_parse(self, param, value): def _struct_scalar_list_parse(self, param, value): # Create a mapping of argument name -> argument object args = {} - for arg in param.members: + for member_name, arg in param.members.items(): # Arg name -> arg object lookup - args[arg.name] = arg + args[member_name] = arg parts = self._split_on_commas(value) current_parsed = {} @@ -254,7 +366,7 @@ def _struct_scalar_list_parse(self, param, value): args.keys()) current_value = unpack_scalar_cli_arg(args[current_key], current[1].strip()) - if args[current_key].type == 'list': + if args[current_key].type_name == 'list': current_parsed[current_key] = current_value.split(',') else: current_parsed[current_key] = current_value @@ -273,17 +385,17 @@ def _struct_scalar_list_parse(self, param, value): return current_parsed def _list_scalar_parse(self, param, value): - single_param = param.members.members[0] + single_param_name = list(param.member.members.keys())[0] parsed = [] # We know that value is a list in this case. for v in value: - parsed.append({single_param.name: v}) + parsed.append({single_param_name: v}) return parsed def _list_key_value_parse(self, param, value): # param is a list param. # param.member is the struct param. - struct_param = param.members + struct_param = param.member parsed = [] for v in value: single_struct_param = self._key_value_parse(struct_param, v) @@ -295,8 +407,7 @@ def _special_key_value_parse(self, param, value): # key=value parsing, *but* supports a few additional conveniences # when working with a structure with a single element. # Precondition: param is a shape of structure(scalar) - if len(param.members) == 1 and param.members[0].name == 'Value' and \ - '=' not in value: + if self._is_special_case_key_value(param, value): # We have an even shorter shorthand syntax for structure # of scalars of a single element with a member name of # 'Value'. @@ -304,6 +415,14 @@ def _special_key_value_parse(self, param, value): else: return self._key_value_parse(param, value) + def _is_special_case_key_value(self, param, value): + members = param.members + if len(param.members) == 1: + if list(members.keys())[0] == 'Value' and \ + '=' not in value: + return True + return False + def _key_value_parse(self, param, value): # The expected structure is: # key=value,key2=value @@ -325,156 +444,142 @@ def _key_value_parse(self, param, value): if valid_names: sub_param = valid_names[key] if sub_param is not None: + # TODO: you are here. unpack_scalar_cli_arg takes + # the cli_name, but we don't have it. What are our + # options? value = unpack_scalar_cli_arg(sub_param, value) parsed[key] = value return parsed def _create_name_to_params(self, param): - if param.type == 'structure': - return dict([(p.name, p) for p in param.members]) - elif param.type == 'map' and hasattr(param.keys, 'enum'): - return dict([(v, None) for v in param.keys.enum]) - - def _struct_list_scalar_doc_helper(self, param, inner_params): - scalar_params = [p for p in inner_params if p.type in SCALAR_TYPES] - list_params = [p for p in inner_params if p.type == 'list'] - pair = '' - for param in scalar_params: - pair += '%s=%s1,' % (param.name, param.type) - for param in list_params[:-1]: - param_type = param.members.type - pair += '%s=%s1,%s2,' % (param.name, param_type, param_type) - last_param = list_params[-1] - param_type = last_param.members.type - pair += '%s=%s1,%s2' % (last_param.name, param_type, param_type) - return pair + if param.type_name == 'structure': + return dict([(member_name, p) for member_name, p + in param.members.items()]) + elif param.type_name == 'map' and hasattr(param.key, 'enum'): + return dict([(v, None) for v in param.key.enum]) - def _docs_list_scalar_list_parse(self, param): - s = ('Key value pairs, where values are separated by commas, ' - 'and multiple pairs are separated by spaces.\n') - s += '%s ' % param.cli_name - pair = self._struct_list_scalar_doc_helper(param, param.members.members) - pair += ' %s' % pair - s += pair - return s + def _split_on_commas(self, value): + try: + return utils.split_on_commas(value) + except ValueError as e: + raise ParamSyntaxError(six.text_type(e)) - def _docs_struct_scalar_list_parse(self, param): - s = ('Key value pairs, where values are separated by commas.\n') - s += '%s ' % param.cli_name - s += self._struct_list_scalar_doc_helper(param, param.members) - return s - def _docs_list_scalar_parse(self, param): - name = param.members.members[0].name - return '%s %s1 %s2 %s3' % (param.cli_name, name, name, name) +class ParamShorthandDocGen(object): + """Documentation generator for param shorthand syntax.""" - def _docs_list_key_value_parse(self, param): - s = "Key value pairs, with multiple values separated by a space.\n" - s += '%s ' % param.cli_name - pair = ','.join(['%s=%s' % (sub_param.name, sub_param.type) - for sub_param in param.members.members]) - pair += ' %s' % pair - s += pair - return s + SHORTHAND_SHAPES = ParamShorthand.SHORTHAND_SHAPES - def _docs_special_key_value_parse(self, param): - if len(param.members) == 1 and param.members[0].name == 'Value': - # Returning None will indicate that we don't have - # any examples to generate, and the entire examples section - # should be skipped for this arg. - return None - else: - return self._docs_key_value_parse(param) - - def _docs_key_value_parse(self, param): - s = '%s ' % param.cli_name - if param.type == 'structure': - s += ','.join(['%s=value' % sub_param.name - for sub_param in param.members]) - elif param.type == 'map': + def supports_shorthand(self, cli_argument): + """Checks if a CLI argument supports shorthand syntax.""" + if cli_argument.argument_model is not None: + structure = detect_shape_structure(cli_argument.argument_model) + return structure in self.SHORTHAND_SHAPES + return False + + def generate_shorthand_example(self, cli_argument): + """Generate documentation for a CLI argument. + + :type cli_argument: awscli.arguments.BaseCLIArgument + :param cli_argument: The CLI argument which to generate + documentation for. + """ + structure = detect_shape_structure(cli_argument.argument_model) + parse_method_name = self.SHORTHAND_SHAPES.get(structure) + doc_method_name = '_docs%s' % parse_method_name + method = getattr(self, doc_method_name) + doc_string = method(cli_argument) + return doc_string + + def _docs_list_scalar_parse(self, cli_argument): + cli_name = cli_argument.cli_name + structure_members = cli_argument.argument_model.member.members + # We know based on the SHORTHAND_SHAPES that this is a + # structure with a single member, so we can safely say: + member_name = list(structure_members.keys())[0] + return '%s %s1 %s2 %s3' % (cli_name, member_name, + member_name, member_name) + + def _docs_key_value_parse(self, cli_argument): + cli_name = cli_argument.cli_name + model = cli_argument.argument_model + s = '%s ' % cli_name + if model.type_name == 'structure': + members_dict = model.members + member_names = list(members_dict.keys()) + s += ','.join(['%s=value' % name for name in member_names]) + elif model.type_name == 'map': s += 'key_name=string,key_name2=string' - if param.keys.type == 'string' and hasattr(param.keys, 'enum'): + if self._has_enum_values(model.key): + enum_values = self._get_enum_values(model.key) s += '\nWhere valid key names are:\n' - for value in param.keys.enum: + for value in enum_values: s += ' %s\n' % value return s - def _split_on_commas(self, value): - try: - return utils.split_on_commas(value) - except ValueError as e: - raise ParamSyntaxError(six.text_type(e)) + def _docs_list_key_value_parse(self, cli_argument): + s = "Key value pairs, with multiple values separated by a space.\n" + s += '%s ' % cli_argument.cli_name + members = cli_argument.argument_model.member.members + pair = ','.join(['%s=%s' % (member_name, shape.type_name) + for member_name, shape in members.items()]) + pair += ' %s' % pair + s += pair + return s + def _docs_list_scalar_list_parse(self, cli_argument): + s = ('Key value pairs, where values are separated by commas, ' + 'and multiple pairs are separated by spaces.\n') + s += '%s ' % cli_argument.cli_name + pair = self._generate_struct_list_scalar_docs( + cli_argument.argument_model.member.members) + pair += ' %s' % pair + s += pair + return s -def unpack_cli_arg(parameter, value): - """ - Parses and unpacks the encoded string command line parameter - and returns native Python data structures that can be passed - to the Operation. + def _docs_struct_scalar_list_parse(self, cli_argument): + s = ('Key value pairs, where values are separated by commas.\n') + s += '%s ' % cli_argument.cli_name + s += self._generate_struct_list_scalar_docs( + cli_argument.argument_model.members) + return s - :type parameter: :class:`botocore.parameter.Parameter` - :param parameter: The parameter object containing metadata about - the parameter. + def _generate_struct_list_scalar_docs(self, members_dict): + scalar_params = list(self._get_scalar_params(members_dict)) + list_params = list(self._get_list_params(members_dict)) + pair = '' + for member_name, param in scalar_params: + pair += '%s=%s1,' % (member_name, param.type_name) + for member_name, param in list_params[:-1]: + param_type = param.member.type_name + pair += '%s=%s1,%s2,' % (member_name, param_type, param_type) + member_name, last_param = list_params[-1] + param_type = last_param.member.type_name + pair += '%s=%s1,%s2' % (member_name, param_type, param_type) + return pair - :param value: The value of the parameter. This can be a number of - different python types (str, list, etc). This is the value as - it's specified on the command line. - - :return: The "unpacked" argument than can be sent to the `Operation` - object in python. - """ - if parameter.type in SCALAR_TYPES: - return unpack_scalar_cli_arg(parameter, value) - elif parameter.type in COMPLEX_TYPES: - return unpack_complex_cli_arg(parameter, value) - else: - return six.text_type(value) + def _get_scalar_params(self, members_dict): + for key, value in members_dict.items(): + if value.type_name in SCALAR_TYPES: + yield (key, value) + def _get_list_params(self, members_dict): + for key, value in members_dict.items(): + if value.type_name == 'list': + yield (key, value) -def unpack_complex_cli_arg(parameter, value): - if parameter.type == 'structure' or parameter.type == 'map': - if value.lstrip()[0] == '{': - try: - return json.loads(value, object_pairs_hook=OrderedDict) - except ValueError as e: - raise ParamError( - parameter, "Invalid JSON: %s\nJSON received: %s" - % (e, value)) - raise ParamError(parameter, "Invalid JSON:\n%s" % value) - elif parameter.type == 'list': - if isinstance(value, six.string_types): - if value.lstrip()[0] == '[': - return json.loads(value, object_pairs_hook=OrderedDict) - elif isinstance(value, list) and len(value) == 1: - single_value = value[0].strip() - if single_value and single_value[0] == '[': - return json.loads(value[0], object_pairs_hook=OrderedDict) - try: - return [unpack_cli_arg(parameter.members, v) for v in value] - except ParamError as e: - # The list params don't have a name/cli_name attached to them - # so they will have bad error messages. We're going to - # attach the parent parameter to this error message to provide - # a more helpful error message. - raise ParamError(parameter, e.message) + def _has_enum_values(self, model): + return 'enum' in model.metadata + def _get_enum_values(self, model): + return model.metadata['enum'] -def unpack_scalar_cli_arg(parameter, value): - if parameter.type == 'integer' or parameter.type == 'long': - return int(value) - elif parameter.type == 'float' or parameter.type == 'double': - # TODO: losing precision on double types - return float(value) - elif parameter.type == 'blob' and parameter.payload and parameter.streaming: - file_path = os.path.expandvars(value) - file_path = os.path.expanduser(file_path) - if not os.path.isfile(file_path): - msg = 'Blob values must be a path to a file.' - raise ParamError(parameter, msg) - return open(file_path, 'rb') - elif parameter.type == 'boolean': - if isinstance(value, six.string_types) and value.lower() == 'false': - return False - return bool(value) - else: - return value + def _docs_special_key_value_parse(self, cli_argument): + members = cli_argument.argument_model.members + if len(members) == 1 and 'Value' in members: + # Returning None will indicate that we don't have + # any examples to generate, and the entire examples section + # should be skipped for this arg. + return None + else: + return self._docs_key_value_parse(cli_argument) diff --git a/awscli/arguments.py b/awscli/arguments.py index a695a144e82e..2f96cb4369ca 100644 --- a/awscli/arguments.py +++ b/awscli/arguments.py @@ -39,10 +39,11 @@ import logging from botocore import xform_name -from botocore.parameters import ListParameter, StructParameter +#from botocore.parameters import ListParameter, StructParameter from awscli.argprocess import unpack_cli_arg from awscli.schema import SchemaTransformer +from botocore import model LOG = logging.getLogger('awscli.arguments') @@ -52,6 +53,19 @@ class UnknownArgumentError(Exception): pass +def create_argument_model_from_schema(schema): + # Given a JSON schems (described in schema.py), convert it + # to a shape object from `botocore.model.Shape` that can be + # used as the argument_model for the Argument classes below. + transformer = SchemaTransformer() + shapes_map = transformer.transform(schema) + shape_resolver = model.ShapeResolver(shapes_map) + # The SchemaTransformer guarantees that the top level shape + # will always be named 'InputShape'. + arg_shape = shape_resolver.get_shape_by_name('InputShape') + return arg_shape + + class BaseCLIArgument(object): """Interface for CLI argument. @@ -192,7 +206,7 @@ class CustomArgument(BaseCLIArgument): def __init__(self, name, help_text='', dest=None, default=None, action=None, required=None, choices=None, nargs=None, cli_type_name=None, group_name=None, positional_arg=False, - no_paramfile=False, schema=None, synopsis=''): + no_paramfile=False, argument_model=None, synopsis=''): self._name = name self._help = help_text self._dest = dest @@ -206,20 +220,33 @@ def __init__(self, name, help_text='', dest=None, default=None, if choices is None: choices = [] self._choices = choices - self.no_paramfile = no_paramfile - self._schema = schema self._synopsis = synopsis + # These are public attributes that are ok to access from external + # objects. + self.no_paramfile = no_paramfile + self.argument_model = None + + if argument_model is None: + argument_model = self._create_scalar_argument_model() + self.argument_model = argument_model + # If the top level element is a list then set nargs to # accept multiple values seperated by a space. - if self._schema and self._schema.get('type', None) == 'array': + if self.argument_model is not None and \ + self.argument_model.type_name == 'list': self._nargs = '+' - # TODO: We should eliminate this altogether. - # You should not have to depend on an argument_object - # as part of the interface. Currently the argprocess - # and docs code relies on this object. - self.argument_object = None + def _create_scalar_argument_model(self): + if self._nargs is not None: + # If nargs is not None then argparse will parse the value + # as a list, so we don't create an argument_object so we don't + # go through param validation. + return None + # If no argument model is provided, we create a basic + # shape argument. + type_name = self.cli_type_name + return create_argument_model_from_schema({'type': type_name}) @property def cli_name(self): @@ -250,30 +277,6 @@ def add_to_parser(self, parser): kwargs['nargs'] = self._nargs parser.add_argument(cli_name, **kwargs) - def create_argument_object(self): - """ - Create an argument object based on the JSON schema if one is set. - After calling this method, ``parameter.argument_object`` is available - e.g. for generating docs. - """ - transformer = SchemaTransformer(self._schema) - transformed = transformer.transform() - - # Set the parameter name from the parsed arg key name - transformed.update({'name': self.name}) - - LOG.debug('Custom parameter schema for {0}: {1}'.format( - self.name, transformed)) - - # Select the correct top level type - if transformed['type'] == 'structure': - self.argument_object = StructParameter(None, **transformed) - elif transformed['type'] == 'list': - self.argument_object = ListParameter(None, **transformed) - else: - raise ValueError('Invalid top level type {0}!'.format( - transformed['type'])) - @property def required(self): if self._required is None: @@ -294,8 +297,8 @@ def cli_type_name(self): return self._cli_type_name elif self._action in ['store_true', 'store_false']: return 'boolean' - elif self.argument_object is not None: - return self.argument_object.type + elif self.argument_model is not None: + return self.argument_model.type_name else: # Default to 'string' type if we don't have any # other info. @@ -348,25 +351,36 @@ class CLIArgument(BaseCLIArgument): 'blob': str } - def __init__(self, name, argument_object, operation_object): + def __init__(self, name, argument_model, operation_object, + is_required=False, serialized_name=None): """ :type name: str :param name: The name of the argument in "cli" form (e.g. ``min-instances``). - :type argument_object: ``botocore.parameter.Parameter`` - :param argument_object: The parameter object to associate with - this object. + :type argument_model: ``botocore.model.Shape`` + :param argument_model: The shape object that models the argument. :type operation_object: ``botocore.operation.Operation`` :param operation_object: The operation object associated with this object. + :type is_required: boolean + :param is_required: Indicates if this parameter is required or not. + """ self._name = name - self.argument_object = argument_object + # This is the name we need to use when constructing the parameters + # dict we send to botocore. While we can change the .name attribute + # which is the name exposed in the CLI, the serialized name we use + # for botocore is invariant and should not be changed. + if serialized_name is None: + serialized_name = name + self._serialized_name = serialized_name + self.argument_model = argument_model self.operation_object = operation_object + self._required = is_required @property def py_name(self): @@ -374,23 +388,23 @@ def py_name(self): @property def required(self): - return self.argument_object.required + return self._required @required.setter def required(self, value): - self.argument_object.required = value + self._required = value @property def documentation(self): - return self.argument_object.documentation + return self.argument_model.documentation @property def cli_type_name(self): - return self.argument_object.type + return self.argument_model.type_name @property def cli_type(self): - return self.TYPE_MAP.get(self.argument_object.type, str) + return self.TYPE_MAP.get(self.argument_model.type_name, str) def add_to_parser(self, parser): """ @@ -422,23 +436,22 @@ def add_to_params(self, parameters, value): # can customize as they need. unpacked = self._unpack_argument(value) LOG.debug('Unpacked value of "%s" for parameter "%s": %s', value, - self.argument_object.py_name, unpacked) - parameters[self.argument_object.py_name] = unpacked + self.py_name, unpacked) + parameters[self._serialized_name] = unpacked def _unpack_argument(self, value): service_name = self.operation_object.service.endpoint_prefix operation_name = xform_name(self.operation_object.name, '-') override = self._emit_first_response('process-cli-arg.%s.%s' % ( - service_name, operation_name), param=self.argument_object, - value=value, - operation=self.operation_object) + service_name, operation_name), param=self.argument_model, + cli_argument=self, value=value, operation=self.operation_object) if override is not None: # A plugin supplied an alternate conversion, # use it instead. return override else: # Fall back to the default arg processing. - return unpack_cli_arg(self.argument_object, value) + return unpack_cli_arg(self, value) def _emit(self, name, **kwargs): session = self.operation_object.service.session @@ -476,11 +489,15 @@ class BooleanArgument(CLIArgument): """ - def __init__(self, name, argument_object, operation_object, - action='store_true', dest=None, group_name=None, - default=None): - super(BooleanArgument, self).__init__(name, argument_object, - operation_object) + def __init__(self, name, argument_model, operation_object, + is_required=False, action='store_true', dest=None, + group_name=None, default=None, + serialized_name=None): + super(BooleanArgument, self).__init__(name, + argument_model, + operation_object, + is_required, + serialized_name=serialized_name) self._mutex_group = None self._action = action if dest is None: @@ -499,7 +516,7 @@ def add_to_params(self, parameters, value): # If the value was not explicitly set (value is None) # we don't add it to the params dict. if value is not None: - parameters[self.py_name] = value + parameters[self._serialized_name] = value def add_to_arg_table(self, argument_table): # Boolean parameters are a bit tricky. For a single boolean parameter @@ -510,11 +527,10 @@ def add_to_arg_table(self, argument_table): # arg table. argument_table[self.name] = self negative_name = 'no-%s' % self.name - negative_version = self.__class__(negative_name, self.argument_object, - self.operation_object, - action='store_false', - dest=self._destination, - group_name=self.group_name) + negative_version = self.__class__( + negative_name, self.argument_model, self.operation_object, + action='store_false', dest=self._destination, + group_name=self.group_name, serialized_name=self._serialized_name) argument_table[negative_name] = negative_version def add_to_parser(self, parser): diff --git a/awscli/clidocs.py b/awscli/clidocs.py index b3c5f44626d8..088580cf55cb 100644 --- a/awscli/clidocs.py +++ b/awscli/clidocs.py @@ -14,6 +14,7 @@ from bcdoc.docevents import DOC_EVENTS from awscli import SCALAR_TYPES +from awscli.argprocess import ParamShorthandDocGen LOG = logging.getLogger(__name__) @@ -237,11 +238,11 @@ def doc_subitem(self, command_name, help_command, **kwargs): class OperationDocumentEventHandler(CLIDocumentEventHandler): def build_translation_map(self): - LOG.debug('build_translation_map') operation = self.help_command.obj d = {} - for param in operation.params: - d[param.name] = param.cli_name + for cli_name, cli_argument in self.help_command.arg_table.items(): + if cli_argument.argument_model is not None: + d[cli_argument.argument_model.name] = cli_name for operation in operation.service.operations: d[operation.name] = operation.cli_name return d @@ -268,89 +269,97 @@ def doc_description(self, help_command, **kwargs): doc.style.h2('Description') doc.include_doc_string(operation.documentation) - def _json_example_value_name(self, param, include_enum_values=True): + def _json_example_value_name(self, argument_model, include_enum_values=True): # If include_enum_values is True, then the valid enum values # are included as the sample JSON value. - if param.type == 'string': - if hasattr(param, 'enum') and include_enum_values: - choices = param.enum + if argument_model.type_name == 'string': + if 'enum' in argument_model.metadata and include_enum_values: + choices = argument_model.metadata['enum'] return '|'.join(['"%s"' % c for c in choices]) else: return '"string"' - elif param.type == 'boolean': + elif argument_model.type_name == 'boolean': return 'true|false' else: - return '%s' % param.type + return '%s' % argument_model.type_name - def _json_example(self, doc, param): - if param.type == 'list': + def _json_example(self, doc, argument_model): + if argument_model.type_name == 'list': doc.write('[') - if param.members.type in SCALAR_TYPES: - doc.write('%s, ...' % self._json_example_value_name(param.members)) + if argument_model.member.type_name in SCALAR_TYPES: + doc.write('%s, ...' % self._json_example_value_name(argument_model.member)) else: doc.style.indent() doc.style.new_line() - self._json_example(doc, param.members) + self._json_example(doc, argument_model.member) doc.style.new_line() doc.write('...') doc.style.dedent() doc.style.new_line() doc.write(']') - elif param.type == 'map': + elif argument_model.type_name == 'map': doc.write('{') doc.style.indent() - key_string = self._json_example_value_name(param.keys) + key_string = self._json_example_value_name(argument_model.key) doc.write('%s: ' % key_string) - if param.members.type in SCALAR_TYPES: - doc.write(self._json_example_value_name(param.members)) + if argument_model.value.type_name in SCALAR_TYPES: + doc.write(self._json_example_value_name(argument_model.value)) else: doc.style.indent() - self._json_example(doc, param.members) + self._json_example(doc, argument_model.value) doc.style.dedent() doc.style.new_line() doc.write('...') doc.style.dedent() doc.write('}') - elif param.type == 'structure': + elif argument_model.type_name == 'structure': doc.write('{') doc.style.indent() doc.style.new_line() - for i, member in enumerate(param.members): - if member.type in SCALAR_TYPES: - doc.write('"%s": %s' % (member.name, - self._json_example_value_name(member))) - elif member.type == 'structure': - doc.write('"%s": ' % member.name) - self._json_example(doc, member) - elif member.type == 'map': - doc.write('"%s": ' % member.name) - self._json_example(doc, member) - elif member.type == 'list': - doc.write('"%s": ' % member.name) - self._json_example(doc, member) - if i < len(param.members) - 1: - doc.write(',') - doc.style.new_line() - else: - doc.style.dedent() - doc.style.new_line() - doc.write('}') + self._doc_input_structure_members(doc, argument_model) + + def _doc_input_structure_members(self, doc, argument_model): + members = argument_model.members + for i, member_name in enumerate(members): + member_model = members[member_name] + member_type_name = member_model.type_name + if member_type_name in SCALAR_TYPES: + doc.write('"%s": %s' % (member_name, + self._json_example_value_name(member_model))) + elif member_type_name == 'structure': + doc.write('"%s": ' % member_name) + self._json_example(doc, member_model) + elif member_type_name == 'map': + doc.write('"%s": ' % member_name) + self._json_example(doc, member_model) + elif member_type_name == 'list': + doc.write('"%s": ' % member_name) + self._json_example(doc, member_model) + if i < len(members) - 1: + doc.write(',') + doc.style.new_line() + else: + doc.style.dedent() + doc.style.new_line() + doc.write('}') def doc_option_example(self, arg_name, help_command, **kwargs): doc = help_command.doc - argument = help_command.arg_table[arg_name] - if argument.group_name in self._arg_groups: - if argument.group_name in self._documented_arg_groups: + cli_argument = help_command.arg_table[arg_name] + if cli_argument.group_name in self._arg_groups: + if cli_argument.group_name in self._documented_arg_groups: # Args with group_names (boolean args) don't # need to generate example syntax. return - param = argument.argument_object - if param and param.example_fn: + argument_model = cli_argument.argument_model + docgen = ParamShorthandDocGen() + if docgen.supports_shorthand(cli_argument): # TODO: bcdoc should not know about shorthand syntax. This # should be pulled out into a separate handler in the # awscli.customizations package. - example_syntax = param.example_fn(param) - if example_syntax is None: + example_shorthand_syntax = docgen.generate_shorthand_example( + cli_argument) + if example_shorthand_syntax is None: # If the shorthand syntax returns a value of None, # this indicates to us that there is no example # needed for this param so we can immediately @@ -359,11 +368,11 @@ def doc_option_example(self, arg_name, help_command, **kwargs): doc.style.new_paragraph() doc.write('Shorthand Syntax') doc.style.start_codeblock() - for example_line in example_syntax.splitlines(): + for example_line in example_shorthand_syntax.splitlines(): doc.writeln(example_line) doc.style.end_codeblock() - if param is not None and param.type == 'list' and \ - param.members.type in SCALAR_TYPES: + if argument_model is not None and argument_model.type_name == 'list' and \ + argument_model.member.type_name in SCALAR_TYPES: # A list of scalars is special. While you *can* use # JSON ( ["foo", "bar", "baz"] ), you can also just # use the argparse behavior of space separated lists. @@ -373,19 +382,20 @@ def doc_option_example(self, arg_name, help_command, **kwargs): doc.write('Syntax') doc.style.start_codeblock() example_type = self._json_example_value_name( - param.members, include_enum_values=False) + argument_model.member, include_enum_values=False) doc.write('%s %s ...' % (example_type, example_type)) - if hasattr(param.members, 'enum'): + if 'enum' in argument_model.member.metadata: # If we have enum values, we can tell the user # exactly what valid values they can provide. - self._write_valid_enums(doc, param.members.enum) + enum = argument_model.member.metadata['enum'] + self._write_valid_enums(doc, enum) doc.style.end_codeblock() doc.style.new_paragraph() - elif argument.cli_type_name not in SCALAR_TYPES: + elif cli_argument.cli_type_name not in SCALAR_TYPES: doc.style.new_paragraph() doc.write('JSON Syntax') doc.style.start_codeblock() - self._json_example(doc, param) + self._json_example(doc, argument_model) doc.style.end_codeblock() doc.style.new_paragraph() @@ -396,38 +406,39 @@ def _write_valid_enums(self, doc, enum_values): doc.write(" %s\n" % value) doc.write("\n") - def _doc_member(self, doc, member_name, member): - docs = member.get('documentation', '') + def doc_output(self, help_command, event_name, **kwargs): + doc = help_command.doc + doc.style.h2('Output') + operation = help_command.obj + output_shape = operation.model.output_shape + if output_shape is None: + doc.write('None') + else: + for member_name, member_shape in output_shape.members.items(): + self._doc_member_for_output(doc, member_name, member_shape) + + def _doc_member_for_output(self, doc, member_name, member_shape): + docs = member_shape.documentation if member_name: - doc.write('%s -> (%s)' % (member_name, member['type'])) + doc.write('%s -> (%s)' % (member_name, member_shape.type_name)) else: - doc.write('(%s)' % member['type']) + doc.write('(%s)' % member_shape.type_name) doc.style.indent() doc.style.new_paragraph() doc.include_doc_string(docs) doc.style.new_paragraph() - if member['type'] == 'structure': - for sub_name in member['members']: - sub_member = member['members'][sub_name] - self._doc_member(doc, sub_name, sub_member) - elif member['type'] == 'map': - keys = member['keys'] - self._doc_member(doc, keys.get('xmlname', 'key'), keys) - members = member['members'] - self._doc_member(doc, members.get('xmlname', 'value'), members) - elif member['type'] == 'list': - self._doc_member(doc, '', member['members']) + member_type_name = member_shape.type_name + if member_type_name == 'structure': + for sub_name, sub_shape in member_shape.members.items(): + self._doc_member_for_output(doc, sub_name, sub_shape) + elif member_type_name == 'map': + key_shape = member_shape.key + key_name = key_shape.serialization.get('name', 'key') + self._doc_member_for_output(doc, key_name, key_shape) + value_shape = member_shape.value + value_name = value_shape.serialization.get('name', 'value') + self._doc_member_for_output(doc, value_name, value_shape) + elif member_type_name == 'list': + self._doc_member_for_output(doc, '', member_shape.member) doc.style.dedent() doc.style.new_paragraph() - - def doc_output(self, help_command, event_name, **kwargs): - doc = help_command.doc - doc.style.h2('Output') - operation = help_command.obj - output = operation.output - if output is None: - doc.write('None') - else: - for member_name in output['members']: - member = output['members'][member_name] - self._doc_member(doc, member_name, member) diff --git a/awscli/clidriver.py b/awscli/clidriver.py index 945a154bcec8..0dde5bde6b5c 100644 --- a/awscli/clidriver.py +++ b/awscli/clidriver.py @@ -20,6 +20,7 @@ from botocore.compat import copy_kwargs, OrderedDict from botocore.exceptions import NoCredentialsError from botocore.exceptions import NoRegionError +from botocore import parsers from awscli import EnvironmentVariables, __version__ from awscli.formatter import get_formatter @@ -42,6 +43,17 @@ LOG_FORMAT = ( '%(asctime)s - %(threadName)s - %(name)s - %(levelname)s - %(message)s') +# NOTE: this is temporary. +# Botocore now parsers timestamps to datetime.datetime objects. +# The AWS CLI has historically not parsed timestamp objects and treated +# them as strings. +# We will eventually want to allow for users to specify how to parse +# the timestamp formats, but for now, we set the default timestamp parser +# to be a noop. +# The eventual plan is to add a client option for providing a timestamp parser, +# and once the CLI has switched over to client objects we can remove this +# and set the timestamp parsing on a per client basis. +parsers.DEFAULT_TIMESTAMP_PARSER = lambda x: x def main(): @@ -136,16 +148,15 @@ def _build_argument_table(self): choices_path = choices.format(provider=provider) choices = list(self.session.get_data(choices_path)) option_params['choices'] = choices - argument_object = self._create_argument_object(option, - option_params) - argument_object.add_to_arg_table(argument_table) + cli_argument = self._create_cli_argument(option, option_params) + cli_argument.add_to_arg_table(argument_table) # Then the final step is to send out an event so handlers # can add extra arguments or modify existing arguments. self.session.emit('building-top-level-params', argument_table=argument_table) return argument_table - def _create_argument_object(self, option_name, option_params): + def _create_cli_argument(self, option_name, option_params): return CustomArgument( option_name, help_text=option_params.get('help', ''), dest=option_params.get('dest'),default=option_params.get('default'), @@ -476,35 +487,32 @@ def _build_call_parameters(self, args, arg_table): arg_object.add_to_params(service_params, value) return service_params - def _unpack_arg(self, arg_object, value): + def _unpack_arg(self, cli_argument, value): # Unpacks a commandline argument into a Python value by firing the # load-cli-arg.service-name.operation-name event. session = self._service_object.session service_name = self._service_object.endpoint_prefix operation_name = xform_name(self._operation_object.name, '-') - param = arg_object - if hasattr(param, 'argument_object') and param.argument_object: - param = param.argument_object - return unpack_argument(session, service_name, operation_name, - param, value) + cli_argument, value) def _create_argument_table(self): argument_table = OrderedDict() - # Arguments are treated a differently than service and - # operations. Instead of doing a get_parameter() we just - # load all the parameter objects up front for the operation. - # We could potentially do the same thing as service/operations - # but botocore already builds all the parameter objects - # when calling an operation so we'd have to optimize that first - # before using get_parameter() in the cli would be advantageous - for argument in self._operation_object.params: - cli_arg_name = argument.cli_name[2:] - arg_class = self.ARG_TYPES.get(argument.type, + input_shape = self._operation_object.model.input_shape + required_arguments = [] + arg_dict = {} + if input_shape is not None: + required_arguments = input_shape.required_members + arg_dict = self._operation_object.model.input_shape.members + for arg_name, arg_shape in arg_dict.items(): + cli_arg_name = xform_name(arg_name, '-') + arg_class = self.ARG_TYPES.get(arg_shape.type_name, self.DEFAULT_ARG_CLASS) - arg_object = arg_class(cli_arg_name, argument, - self._operation_object) + is_required = arg_name in required_arguments + arg_object = arg_class(cli_arg_name, arg_shape, + self._operation_object, is_required, + serialized_name=arg_name) arg_object.add_to_arg_table(argument_table) LOG.debug(argument_table) self._emit('building-argument-table.%s.%s' % (self._parent_name, diff --git a/awscli/compat.py b/awscli/compat.py index 5951a75e4aa2..2039aa282118 100644 --- a/awscli/compat.py +++ b/awscli/compat.py @@ -17,6 +17,8 @@ import locale import urllib.parse as urlparse + raw_input = input + def get_stdout_text_writer(): return sys.stdout @@ -41,6 +43,8 @@ def compat_open(filename, mode='r', encoding=None): import io import urlparse + raw_input = raw_input + def get_stdout_text_writer(): # In python3, all the sys.stdout/sys.stderr streams are in text # mode. This means they expect unicode, and will encode the diff --git a/awscli/customizations/argrename.py b/awscli/customizations/argrename.py index e9e50ae81d3d..782880811cdd 100644 --- a/awscli/customizations/argrename.py +++ b/awscli/customizations/argrename.py @@ -34,6 +34,7 @@ 'emr.*.job-flow-ids': 'cluster-ids', 'emr.*.job-flow-id': 'cluster-id', 'cloudsearchdomain.search.query': 'search-query', + 'sns.subscribe.endpoint': 'notification-endpoint', } diff --git a/awscli/customizations/commands.py b/awscli/customizations/commands.py index 0700b7cfe1d2..63a0ca7e4dc6 100644 --- a/awscli/customizations/commands.py +++ b/awscli/customizations/commands.py @@ -3,14 +3,17 @@ import bcdoc.docevents from botocore.compat import OrderedDict +from botocore import model +from botocore.validate import validate_parameters import awscli from awscli.clidocs import OperationDocumentEventHandler from awscli.argparser import ArgTableArgParser from awscli.argprocess import unpack_argument, unpack_cli_arg from awscli.clidriver import CLICommand -from awscli.arguments import CustomArgument +from awscli.arguments import CustomArgument, create_argument_model_from_schema from awscli.help import HelpCommand +from awscli.schema import SchemaTransformer LOG = logging.getLogger(__name__) @@ -129,34 +132,30 @@ def __call__(self, args, parsed_globals): # Unpack arguments for key, value in vars(parsed_args).items(): - param = None + cli_argument = None # Convert the name to use dashes instead of underscore # as these are how the parameters are stored in the # `arg_table`. xformed = key.replace('_', '-') if xformed in arg_table: - param = arg_table[xformed] + cli_argument = arg_table[xformed] value = unpack_argument( self._session, 'custom', self.name, - param, + cli_argument, value ) # If this parameter has a schema defined, then allow plugins # a chance to process and override its value. - if param and getattr(param, 'argument_object', None) is not None \ - and value is not None: - param_object = param.argument_object - - # Allow a single event handler to process the value + if self._should_allow_plugins_override(cli_argument, value): override = self._session\ .emit_first_non_none_response( 'process-cli-arg.%s.%s' % ('custom', self.name), - param=param_object, value=value, operation=None) + cli_argument=cli_argument, value=value, operation=None) if override is not None: # A plugin supplied a conversion @@ -164,10 +163,9 @@ def __call__(self, args, parsed_globals): else: # Unpack the argument, which is a string, into the # correct Python type (dict, list, etc) - value = unpack_cli_arg(param_object, value) - - # Validate param types, required keys, etc - param_object.validate(value) + value = unpack_cli_arg(cli_argument, value) + self._validate_value_against_schema( + cli_argument.argument_model, value) setattr(parsed_args, key, value) @@ -183,6 +181,15 @@ def __call__(self, args, parsed_globals): return subcommand_table[parsed_args.subcommand](remaining, parsed_globals) + def _validate_value_against_schema(self, model, value): + validate_parameters(value, model) + + def _should_allow_plugins_override(self, param, value): + if (param and param.argument_model is not None and + value is not None): + return True + return False + def _run_main(self, parsed_args, parsed_globals): # Subclasses should implement this method. # parsed_globals are the parsed global args (things like region, @@ -231,12 +238,13 @@ def arg_table(self): arg_table = OrderedDict() for arg_data in self.ARG_TABLE: - custom_argument = CustomArgument(**arg_data) - - # If a custom schema was passed in, create the argument object - # so that it can be validated and docs can be generated + # If a custom schema was passed in, create the argument_model + # so that it can be validated and docs can be generated. if 'schema' in arg_data: - custom_argument.create_argument_object() + argument_model = create_argument_model_from_schema( + arg_data.pop('schema')) + arg_data['argument_model'] = argument_model + custom_argument = CustomArgument(**arg_data) arg_table[arg_data['name']] = custom_argument return arg_table diff --git a/awscli/customizations/configure.py b/awscli/customizations/configure.py index 3467604fadb7..ab40ba4cf13b 100644 --- a/awscli/customizations/configure.py +++ b/awscli/customizations/configure.py @@ -18,12 +18,7 @@ from botocore.exceptions import ProfileNotFound from awscli.customizations.commands import BasicCommand - - -try: - raw_input = raw_input -except NameError: - raw_input = input +from awscli.compat import raw_input logger = logging.getLogger(__name__) @@ -55,7 +50,7 @@ def _mask_value(current_value): if current_value is None: return 'None' else: - return ('*' * 16) + current_value[-4:] + return ('*' * 16) + current_value[-4:] class InteractivePrompter(object): @@ -80,6 +75,32 @@ class ConfigFileWriter(object): ) def update_config(self, new_values, config_filename): + """Update config file with new values. + + This method will update a section in a config file with + new key value pairs. + + This method provides a few conveniences: + + * If the ``config_filename`` does not exist, it will + be created. Any parent directories will also be created + if necessary. + * If the section to update does not exist, it will be created. + * Any existing lines that are specified by ``new_values`` + **will not be touched**. This ensures that commented out + values are left unaltered. + + :type new_values: dict + :param new_values: The values to update. There is a special + key ``__section__``, that specifies what section in the INI + file to update. If this key is not present, then the + ``default`` section will be updated with the new values. + + :type config_filename: str + :param config_filename: The config filename where values will be + written. + + """ section_name = new_values.pop('__section__', 'default') if not os.path.isfile(config_filename): self._create_file(config_filename) @@ -98,11 +119,11 @@ def update_config(self, new_values, config_filename): def _create_file(self, config_filename): # Create the file as well as the parent dir if needed. - dirname, basename = os.path.split(config_filename) + dirname = os.path.split(config_filename)[0] if not os.path.isdir(dirname): os.makedirs(dirname) with os.fdopen(os.open(config_filename, - os.O_WRONLY|os.O_CREAT, 0o600), 'w'): + os.O_WRONLY | os.O_CREAT, 0o600), 'w'): pass def _write_new_section(self, section_name, new_values, config_filename): @@ -124,15 +145,15 @@ def _find_section_start(self, contents, section_name): if match is not None and self._matches_section(match, section_name): return i - else: - raise SectionNotFoundError(section_name) + raise SectionNotFoundError(section_name) def _update_section_contents(self, contents, section_name, new_values): # First, find the line where the section_name is defined. # This will be the value of i. new_values = new_values.copy() # ``contents`` is a list of file line contents. - section_start_line_num = self._find_section_start(contents, section_name) + section_start_line_num = self._find_section_start(contents, + section_name) # If we get here, then we've found the section. We now need # to figure out if we're updating a value or adding a new value. # There's 2 cases. Either we're setting a normal scalar value @@ -182,7 +203,8 @@ def _update_subattributes(self, index, contents, values, starting_indent): line = contents[i] match = self.OPTION_REGEX.search(line) if match is not None: - current_indent = len(match.group(1)) - len(match.group(1).lstrip()) + current_indent = len( + match.group(1)) - len(match.group(1).lstrip()) key_name = match.group(1).strip() if key_name in values: option_value = values[key_name] @@ -205,7 +227,8 @@ def _insert_new_values(self, line_number, contents, new_values, indent=''): subindent = indent + ' ' new_contents.append('%s%s =\n' % (indent, key)) for subkey, subval in list(value.items()): - new_contents.append('%s%s = %s\n' % (subindent, subkey, subval)) + new_contents.append('%s%s = %s\n' % (subindent, subkey, + subval)) else: new_contents.append('%s%s = %s\n' % (indent, key, value)) del new_values[key] @@ -302,9 +325,9 @@ def _lookup_credentials(self): # the credentials.method is sufficient to show # where the credentials are coming from. access_key = ConfigValue(credentials.access_key, - credentials.method, '') + credentials.method, '') secret_key = ConfigValue(credentials.secret_key, - credentials.method, '') + credentials.method, '') access_key.mask_value() secret_key.mask_value() return access_key, secret_key @@ -322,6 +345,7 @@ def _lookup_config(self, name): else: return ConfigValue(NOT_SET, None, None) + class ConfigureSetCommand(BasicCommand): NAME = 'set' DESCRIPTION = BasicCommand.FROM_FILE('configure', 'set', @@ -338,6 +362,10 @@ class ConfigureSetCommand(BasicCommand): 'action': 'store', 'cli_type_name': 'string', 'positional_arg': True}, ] + # Any variables specified in this list will be written to + # the ~/.aws/credentials file instead of ~/.aws/config. + _WRITE_TO_CREDS_FILE = ['aws_access_key_id', 'aws_secret_access_key', + 'aws_session_token'] def __init__(self, session, config_writer=None): super(ConfigureSetCommand, self).__init__(session) @@ -362,7 +390,6 @@ def _run_main(self, args, parsed_globals): section = 'profile %s' % self._session.profile else: # First figure out if it's been scoped to a profile. - # This will happen if parts = varname.split('.') if parts[0] in ('default', 'profile'): # Then we know we're scoped to a profile. @@ -383,6 +410,12 @@ def _run_main(self, args, parsed_globals): config_filename = os.path.expanduser( self._session.get_config_variable('config_file')) updated_config = {'__section__': section, varname: value} + if varname in self._WRITE_TO_CREDS_FILE: + config_filename = os.path.expanduser( + self._session.get_config_variable('credentials_file')) + section_name = updated_config['__section__'] + if section_name.startswith('profile '): + updated_config['__section__'] = section_name[8:] self._config_writer.update_config(updated_config, config_filename) @@ -456,7 +489,6 @@ def _get_dotted_config_value(self, varname): return value - class ConfigureCommand(BasicCommand): NAME = 'configure' DESCRIPTION = BasicCommand.FROM_FILE() @@ -522,7 +554,30 @@ def _run_main(self, parsed_args, parsed_globals): config_filename = os.path.expanduser( self._session.get_config_variable('config_file')) if new_values: + self._write_out_creds_file_values(new_values, + parsed_globals.profile) if parsed_globals.profile is not None: new_values['__section__'] = ( 'profile %s' % parsed_globals.profile) self._config_writer.update_config(new_values, config_filename) + + def _write_out_creds_file_values(self, new_values, profile_name): + # The access_key/secret_key are now *always* written to the shared + # credentials file (~/.aws/credentials), see aws/aws-cli#847. + # post-conditions: ~/.aws/credentials will have the updated credential + # file values and new_values will have the cred vars removed. + credential_file_values = {} + if 'aws_access_key_id' in new_values: + credential_file_values['aws_access_key_id'] = new_values.pop( + 'aws_access_key_id') + if 'aws_secret_access_key' in new_values: + credential_file_values['aws_secret_access_key'] = new_values.pop( + 'aws_secret_access_key') + if credential_file_values: + if profile_name is not None: + credential_file_values['__section__'] = profile_name + shared_credentials_filename = self._session.get_config_variable( + 'credentials_file') + self._config_writer.update_config( + credential_file_values, + shared_credentials_filename) diff --git a/awscli/customizations/datapipeline/__init__.py b/awscli/customizations/datapipeline/__init__.py index 5a1d9a7615de..ff561e339c49 100644 --- a/awscli/customizations/datapipeline/__init__.py +++ b/awscli/customizations/datapipeline/__init__.py @@ -172,7 +172,7 @@ def add_to_params(self, parameters, value): return parsed = json.loads(value) api_objects = translator.definition_to_api(parsed) - parameters['pipeline_objects'] = api_objects + parameters['pipelineObjects'] = api_objects class ListRunsCommand(BasicCommand): diff --git a/awscli/customizations/ec2addcount.py b/awscli/customizations/ec2addcount.py index 3ef8997104f4..e16a881e923d 100644 --- a/awscli/customizations/ec2addcount.py +++ b/awscli/customizations/ec2addcount.py @@ -12,8 +12,10 @@ # language governing permissions and limitations under the License. import logging +from botocore import model + from awscli.arguments import BaseCLIArgument -from botocore.parameters import StringParameter + logger = logging.getLogger(__name__) @@ -35,8 +37,7 @@ def ec2_add_count(argument_table, operation, **kwargs): class CountArgument(BaseCLIArgument): def __init__(self, operation, name): - param = StringParameter(operation, name='count', type='string') - self.argument_object = param + self.argument_model = model.Shape('CountArgument', {'type': 'string'}) self._operation = operation self._name = name diff --git a/awscli/customizations/ec2bundleinstance.py b/awscli/customizations/ec2bundleinstance.py index b19a0e39d6c3..d54253ee1291 100644 --- a/awscli/customizations/ec2bundleinstance.py +++ b/awscli/customizations/ec2bundleinstance.py @@ -65,9 +65,8 @@ def _add_params(argument_table, operation, **kwargs): # Add the scalar parameters and also change the complex storage # param to not be required so the user doesn't get an error from # argparse if they only supply scalar params. - storage_arg = argument_table.get('storage') - storage_param = storage_arg.argument_object - storage_param.required = False + storage_arg = argument_table['storage'] + storage_arg.required = False arg = BundleArgument(storage_param='Bucket', name='bucket', help_text=BUCKET_DOCS) diff --git a/awscli/customizations/ec2decryptpassword.py b/awscli/customizations/ec2decryptpassword.py index 840256eda705..64f97e146ff3 100644 --- a/awscli/customizations/ec2decryptpassword.py +++ b/awscli/customizations/ec2decryptpassword.py @@ -16,11 +16,14 @@ import rsa import six +from botocore import model + from awscli.arguments import BaseCLIArgument -from botocore.parameters import StringParameter + logger = logging.getLogger(__name__) + HELP = """

The file that contains the private key used to launch the instance (e.g. windows-keypair.pem). If this is supplied, the password data sent from EC2 will be decrypted before display.

""" @@ -38,11 +41,7 @@ def ec2_add_priv_launch_key(argument_table, operation, **kwargs): class LaunchKeyArgument(BaseCLIArgument): def __init__(self, operation, name): - param = StringParameter(operation, - name=name, - type='string') - self._name = name - self.argument_object = param + self.argument_model = model.Shape('LaunchKeyArgument', {'type': 'string'}) self._operation = operation self._name = name self._key_path = None diff --git a/awscli/customizations/ec2protocolarg.py b/awscli/customizations/ec2protocolarg.py index d598ad87a8a4..234e147b8e29 100644 --- a/awscli/customizations/ec2protocolarg.py +++ b/awscli/customizations/ec2protocolarg.py @@ -17,15 +17,16 @@ """ def _fix_args(operation, endpoint, params, **kwargs): - if 'protocol' in params: - if params['protocol'] == 'tcp': - params['protocol'] = '6' - elif params['protocol'] == 'udp': - params['protocol'] = '17' - elif params['protocol'] == 'icmp': - params['protocol'] = '1' - elif params['protocol'] == 'all': - params['protocol'] = '-1' + key_name = 'Protocol' + if key_name in params: + if params[key_name] == 'tcp': + params[key_name] = '6' + elif params[key_name] == 'udp': + params[key_name] = '17' + elif params[key_name] == 'icmp': + params[key_name] = '1' + elif params[key_name] == 'all': + params[key_name] = '-1' def register_protocol_args(cli): @@ -33,4 +34,3 @@ def register_protocol_args(cli): _fix_args) cli.register('before-parameter-build.ec2.ReplaceNetworkAclEntry', _fix_args) - diff --git a/awscli/customizations/ec2runinstances.py b/awscli/customizations/ec2runinstances.py index 1a0e5da23016..b27210a08797 100644 --- a/awscli/customizations/ec2runinstances.py +++ b/awscli/customizations/ec2runinstances.py @@ -91,20 +91,20 @@ def _fix_args(operation, endpoint, params, **kwargs): # allows them to specify the security group by name or by id. # However, in this scenario we can only support id because # we can't place a group name in the NetworkInterfaces structure. - if 'network_interfaces' in params: - ni = params['network_interfaces'] + if 'NetworkInterfaces' in params: + ni = params['NetworkInterfaces'] if 'AssociatePublicIpAddress' in ni[0]: - if 'subnet_id' in params: - ni[0]['SubnetId'] = params['subnet_id'] - del params['subnet_id'] - if 'security_group_ids' in params: - ni[0]['Groups'] = params['security_group_ids'] - del params['security_group_ids'] - if 'private_ip_address' in params: - ip_addr = {'PrivateIpAddress': params['private_ip_address'], + if 'SubnetId' in params: + ni[0]['SubnetId'] = params['SubnetId'] + del params['SubnetId'] + if 'SecurityGroupIds' in params: + ni[0]['Groups'] = params['SecurityGroupIds'] + del params['SecurityGroupIds'] + if 'PrivateIpAddress' in params: + ip_addr = {'PrivateIpAddress': params['PrivateIpAddress'], 'Primary': True} ni[0]['PrivateIpAddresses'] = [ip_addr] - del params['private_ip_address'] + del params['PrivateIpAddress'] EVENTS = [ @@ -122,14 +122,14 @@ def register_runinstances(event_handler): def _build_network_interfaces(params, key, value): # Build up the NetworkInterfaces data structure - if 'network_interfaces' not in params: - params['network_interfaces'] = [{'DeviceIndex': 0}] + if 'NetworkInterfaces' not in params: + params['NetworkInterfaces'] = [{'DeviceIndex': 0}] if key == 'PrivateIpAddresses': - if 'PrivateIpAddresses' not in params['network_interfaces'][0]: - params['network_interfaces'][0]['PrivateIpAddresses'] = value + if 'PrivateIpAddresses' not in params['NetworkInterfaces'][0]: + params['NetworkInterfaces'][0]['PrivateIpAddresses'] = value else: - params['network_interfaces'][0][key] = value + params['NetworkInterfaces'][0][key] = value class SecondaryPrivateIpAddressesArgument(CustomArgument): diff --git a/awscli/customizations/emr/argumentschema.py b/awscli/customizations/emr/argumentschema.py index 8aba3f23e443..476c2eff0096 100644 --- a/awscli/customizations/emr/argumentschema.py +++ b/awscli/customizations/emr/argumentschema.py @@ -246,3 +246,10 @@ } } } + +TAGS_SCHEMA = { + "type": "array", + "items": { + "type": "string" + } +} diff --git a/awscli/customizations/emr/createcluster.py b/awscli/customizations/emr/createcluster.py index 94dd7bb5b431..5504f61e7155 100644 --- a/awscli/customizations/emr/createcluster.py +++ b/awscli/customizations/emr/createcluster.py @@ -79,7 +79,8 @@ class CreateCluster(BasicCommand): {'name': 'no-enable-debugging', 'action': 'store_true', 'group_name': 'debug'}, {'name': 'tags', 'nargs': '+', - 'help_text': helptext.TAGS}, + 'help_text': helptext.TAGS, + 'schema': argumentschema.TAGS_SCHEMA}, {'name': 'bootstrap-actions', 'help_text': helptext.BOOTSTRAP_ACTIONS, 'schema': argumentschema.BOOTSTRAP_ACTIONS_SCHEMA}, diff --git a/awscli/customizations/emr/terminateclusters.py b/awscli/customizations/emr/terminateclusters.py index 44dda90d18cc..54c51927767b 100644 --- a/awscli/customizations/emr/terminateclusters.py +++ b/awscli/customizations/emr/terminateclusters.py @@ -20,14 +20,15 @@ class TerminateClusters(BasicCommand): NAME = 'terminate-clusters' DESCRIPTION = helptext.TERMINATE_CLUSTERS - ARG_TABLE = [ - {'name': 'cluster-ids', 'nargs': '+', 'required': True, - 'help_text': '

A list of clusters to terminate.

'} - ] + ARG_TABLE = [{ + 'name': 'cluster-ids', 'nargs': '+', 'required': True, + 'help_text': '

A list of clusters to terminate.

', + 'schema': {'type': 'array', 'items': {'type': 'string'}}, + }] def _run_main(self, parsed_args, parsed_globals): parameters = {'JobFlowIds': parsed_args.cluster_ids} emrutils.call_and_display_response(self._session, 'TerminateJobFlows', parameters, parsed_globals) - return 0 \ No newline at end of file + return 0 diff --git a/awscli/customizations/flatten.py b/awscli/customizations/flatten.py index 1ebb146048bb..0fff47a28e12 100644 --- a/awscli/customizations/flatten.py +++ b/awscli/customizations/flatten.py @@ -31,13 +31,13 @@ class FlattenedArgument(CustomArgument): """ def __init__(self, name, container, prop, help_text='', required=None, type=None, hydrate=None, hydrate_value=None): - super(FlattenedArgument, self).__init__(name=name, help_text=help_text, - required=required) self.type = type self._container = container self._property = prop self._hydrate = hydrate self._hydrate_value = hydrate_value + super(FlattenedArgument, self).__init__(name=name, help_text=help_text, + required=required) @property def cli_type_name(self): @@ -185,7 +185,7 @@ def flatten_args(self, operation, argument_table, **kwargs): # Handle nested arguments _arg = self._find_nested_arg( - argument_from_table.argument_object, sub_argument + argument_from_table.argument_model, sub_argument ) # Pull out docs and required attribute @@ -213,8 +213,8 @@ def _find_nested_arg(self, argument, name): # Find the actual nested argument to pull out LOG.debug('Finding nested argument in {0}'.format(name)) for piece in name.split(SEP)[:-1]: - for member in argument.members: - if member.name == piece: + for member_name, member in argument.members.items(): + if member_name == piece: argument = member break else: @@ -230,15 +230,15 @@ def _merge_member_config(self, argument, name, config): overridden in the configuration dict. Modifies the config in-place. """ # Pull out docs and required attribute - for member in argument.members: - if member.name == name.split(SEP)[-1]: + for member_name, member in argument.members.items(): + if member_name == name.split(SEP)[-1]: if 'help_text' not in config: config['help_text'] = member.documentation if 'required' not in config: - config['required'] = member.required + config['required'] = member_name in argument.required_members if 'type' not in config: - config['type'] = member.type + config['type'] = member.type_name break diff --git a/awscli/customizations/iamvirtmfa.py b/awscli/customizations/iamvirtmfa.py index 0b17cb4bfff7..271b84c0ab1b 100644 --- a/awscli/customizations/iamvirtmfa.py +++ b/awscli/customizations/iamvirtmfa.py @@ -27,6 +27,7 @@ from awscli.arguments import CustomArgument + CHOICES = ('QRCodePNG', 'Base32StringSeed') OUTPUT_HELP = ('The output path and file name where the bootstrap ' 'information will be stored.') @@ -58,7 +59,6 @@ def add_to_params(self, parameters, value): if not os.access(os.path.dirname(outfile), os.W_OK): raise ValueError('Unable to write to file: %s' % outfile) self._value = outfile - class IAMVMFAWrapper(object): diff --git a/awscli/customizations/paginate.py b/awscli/customizations/paginate.py index 66f2aa34c0a7..cfc6891092bb 100644 --- a/awscli/customizations/paginate.py +++ b/awscli/customizations/paginate.py @@ -26,8 +26,11 @@ import logging from functools import partial +from botocore import xform_name +from botocore import model + from awscli.arguments import BaseCLIArgument -from botocore.parameters import StringParameter + logger = logging.getLogger(__name__) @@ -66,22 +69,18 @@ def unify_paging_params(argument_table, operation, event_name, **kwargs): STARTING_TOKEN_HELP, operation, parse_type='string') - # Try to get the pagination parameter type - limit_param = None + input_members = operation.model.input_shape.members + type_name = 'integer' if 'limit_key' in operation.pagination: - for param in operation.params: - if param.name == operation.pagination['limit_key']: - limit_param = param - break - - type_ = limit_param and limit_param.type or 'integer' - if limit_param and limit_param.type not in PageArgument.type_map: - raise TypeError(('Unsupported pagination type {0} for operation {1}' - ' and parameter {2}').format(type_, operation.name, - limit_param.name)) + limit_key_shape = input_members[operation.pagination['limit_key']] + type_name = limit_key_shape.type_name + if type_name not in PageArgument.type_map: + raise TypeError(('Unsupported pagination type {0} for operation {1}' + ' and parameter {2}').format(type_name, operation.name, + operation.pagination['limit_key'])) argument_table['max-items'] = PageArgument('max-items', MAX_ITEMS_HELP, - operation, parse_type=type_) + operation, parse_type=type_name) def check_should_enable_pagination(input_tokens, parsed_args, parsed_globals, @@ -108,11 +107,11 @@ def _get_all_cli_input_tokens(operation): # if it exists. tokens = _get_input_tokens(operation) for token_name in tokens: - cli_name = _get_cli_name(operation.params, token_name) + cli_name = xform_name(token_name, '-') yield cli_name if 'limit_key' in operation.pagination: key_name = operation.pagination['limit_key'] - cli_name = _get_cli_name(operation.params, key_name) + cli_name = xform_name(key_name, '-') yield cli_name @@ -137,9 +136,7 @@ class PageArgument(BaseCLIArgument): } def __init__(self, name, documentation, operation, parse_type): - param = StringParameter(operation, name=name, type=parse_type) - self._name = name - self.argument_object = param + self.argument_model = model.Shape('PageArgument', {'type': 'string'}) self._name = name self._documentation = documentation self._parse_type = parse_type diff --git a/awscli/customizations/putmetricdata.py b/awscli/customizations/putmetricdata.py index 68e34a37920c..d221b052e3ad 100644 --- a/awscli/customizations/putmetricdata.py +++ b/awscli/customizations/putmetricdata.py @@ -22,6 +22,7 @@ * --unit """ +import decimal from awscli.arguments import CustomArgument from awscli.utils import split_on_commas @@ -111,7 +112,8 @@ def _add_param_timestamp(self, first_element, value): @insert_first_element('metric_data') def _add_param_value(self, first_element, value): - first_element['Value'] = value + # Use a Decimal to avoid loss in precision. + first_element['Value'] = decimal.Decimal(value) @insert_first_element('metric_data') def _add_param_dimensions(self, first_element, value): diff --git a/awscli/customizations/route53resourceid.py b/awscli/customizations/route53resourceid.py index 08d401c282f0..45b33bfd6bad 100644 --- a/awscli/customizations/route53resourceid.py +++ b/awscli/customizations/route53resourceid.py @@ -22,9 +22,8 @@ def register_resource_id(cli): def _check_for_resource_id(param, value, **kwargs): - if hasattr(param, 'shape_name'): - if param.shape_name == 'ResourceId': - orig_value = value - value = value.split('/')[-1] - logger.debug('ResourceId %s -> %s', orig_value, value) - return value + if param.name == 'ResourceId': + orig_value = value + value = value.split('/')[-1] + logger.debug('ResourceId %s -> %s', orig_value, value) + return value diff --git a/awscli/customizations/s3/executor.py b/awscli/customizations/s3/executor.py index 872f181ef055..d2f2c9b89152 100644 --- a/awscli/customizations/s3/executor.py +++ b/awscli/customizations/s3/executor.py @@ -15,8 +15,8 @@ import sys import threading -from awscli.customizations.s3.utils import uni_print, \ - IORequest, IOCloseRequest, StablePriorityQueue +from awscli.customizations.s3.utils import uni_print, bytes_print, \ + IORequest, IOCloseRequest, StablePriorityQueue from awscli.customizations.s3.tasks import OrderableTask @@ -40,18 +40,19 @@ class Executor(object): STANDARD_PRIORITY = 11 IMMEDIATE_PRIORITY= 1 - def __init__(self, num_threads, result_queue, - quiet, max_queue_size, write_queue): + def __init__(self, num_threads, result_queue, quiet, + only_show_errors, max_queue_size, write_queue): self._max_queue_size = max_queue_size self.queue = StablePriorityQueue(maxsize=self._max_queue_size, max_priority=20) self.num_threads = num_threads self.result_queue = result_queue self.quiet = quiet + self.only_show_errors = only_show_errors self.threads_list = [] self.write_queue = write_queue - self.print_thread = PrintThread(self.result_queue, - self.quiet) + self.print_thread = PrintThread(self.result_queue, self.quiet, + self.only_show_errors) self.print_thread.daemon = True self.io_thread = IOWriterThread(self.write_queue) @@ -153,15 +154,19 @@ def run(self): self._cleanup() return elif isinstance(task, IORequest): - filename, offset, data = task - fileobj = self.fd_descriptor_cache.get(filename) - if fileobj is None: - fileobj = open(filename, 'rb+') - self.fd_descriptor_cache[filename] = fileobj - fileobj.seek(offset) + filename, offset, data, is_stream = task + if is_stream: + fileobj = sys.stdout + bytes_print(data) + else: + fileobj = self.fd_descriptor_cache.get(filename) + if fileobj is None: + fileobj = open(filename, 'rb+') + self.fd_descriptor_cache[filename] = fileobj + fileobj.seek(offset) + fileobj.write(data) LOGGER.debug("Writing data to: %s, offset: %s", filename, offset) - fileobj.write(data) fileobj.flush() elif isinstance(task, IOCloseRequest): LOGGER.debug("IOCloseRequest received for %s, closing file.", @@ -226,18 +231,19 @@ class PrintThread(threading.Thread): warning. """ - def __init__(self, result_queue, quiet): + def __init__(self, result_queue, quiet, only_show_errors): threading.Thread.__init__(self) self._progress_dict = {} self._result_queue = result_queue self._quiet = quiet + self._only_show_errors = only_show_errors self._progress_length = 0 self._num_parts = 0 self._file_count = 0 self._lock = threading.Lock() self._needs_newline = False - self._total_parts = 0 + self._total_parts = '...' self._total_files = '...' # This is a public attribute that clients can inspect to determine @@ -274,15 +280,15 @@ def run(self): def _process_print_task(self, print_task): print_str = print_task.message + print_to_stderr = False if print_task.error: self.num_errors_seen += 1 - warning = False - if print_task.warning: - if print_task.warning: - warning = True - self.num_warnings_seen += 1 + print_to_stderr = True + final_str = '' - if warning: + if print_task.warning: + self.num_warnings_seen += 1 + print_to_stderr = True final_str += print_str.ljust(self._progress_length, ' ') final_str += '\n' elif print_task.total_parts: @@ -309,21 +315,30 @@ def _process_print_task(self, print_task): self._num_parts += 1 self._file_count += 1 + # If the message is an error or warning, print it to standard error. + if print_to_stderr and not self._quiet: + uni_print(final_str, sys.stderr) + final_str = '' + is_done = self._total_files == self._file_count if not is_done: - prog_str = "Completed %s " % self._num_parts - num_files = self._total_files - if self._total_files != '...': - prog_str += "of %s " % self._total_parts - num_files = self._total_files - self._file_count - prog_str += "part(s) with %s file(s) remaining" % \ - num_files - length_prog = len(prog_str) - prog_str += '\r' - prog_str = prog_str.ljust(self._progress_length, ' ') - self._progress_length = length_prog - final_str += prog_str - if not self._quiet: + final_str += self._make_progress_bar() + if not (self._quiet or self._only_show_errors): uni_print(final_str) self._needs_newline = not final_str.endswith('\n') - sys.stdout.flush() + + def _make_progress_bar(self): + """Creates the progress bar string to print out.""" + + prog_str = "Completed %s " % self._num_parts + num_files = self._total_files + if self._total_files != '...': + prog_str += "of %s " % self._total_parts + num_files = self._total_files - self._file_count + prog_str += "part(s) with %s file(s) remaining" % \ + num_files + length_prog = len(prog_str) + prog_str += '\r' + prog_str = prog_str.ljust(self._progress_length, ' ') + self._progress_length = length_prog + return prog_str diff --git a/awscli/customizations/s3/filegenerator.py b/awscli/customizations/s3/filegenerator.py index b53be0c45939..a6d53cecb259 100644 --- a/awscli/customizations/s3/filegenerator.py +++ b/awscli/customizations/s3/filegenerator.py @@ -20,7 +20,8 @@ from dateutil.tz import tzlocal from awscli.customizations.s3.utils import find_bucket_key, get_file_stat -from awscli.customizations.s3.utils import BucketLister, create_warning +from awscli.customizations.s3.utils import BucketLister, create_warning, \ + find_dest_path_comp_key from awscli.errorhandler import ClientError @@ -131,26 +132,13 @@ def call(self, files): ``dir_op`` and ``use_src_name`` flags affect which files are used and ensure the proper destination paths and compare keys are formed. """ - src = files['src'] - dest = files['dest'] - src_type = src['type'] - dest_type = dest['type'] function_table = {'s3': self.list_objects, 'local': self.list_files} - sep_table = {'s3': '/', 'local': os.sep} - source = src['path'] + source = files['src']['path'] + src_type = files['src']['type'] + dest_type = files['dest']['type'] file_list = function_table[src_type](source, files['dir_op']) for src_path, size, last_update in file_list: - if files['dir_op']: - rel_path = src_path[len(src['path']):] - else: - rel_path = src_path.split(sep_table[src_type])[-1] - compare_key = rel_path.replace(sep_table[src_type], '/') - if files['use_src_name']: - dest_path = dest['path'] - dest_path += rel_path.replace(sep_table[src_type], - sep_table[dest_type]) - else: - dest_path = dest['path'] + dest_path, compare_key = find_dest_path_comp_key(files, src_path) yield FileStat(src=src_path, dest=dest_path, compare_key=compare_key, size=size, last_update=last_update, src_type=src_type, @@ -317,7 +305,7 @@ def _list_single_object(self, s3_path): # This is what the customer is going to see so we want to # give as much detail as we have. copy_fields = e.__dict__.copy() - if not e.error_message == 'Unknown': + if not e.error_message == 'Not Found': raise if e.http_status_code == 404: # The key does not exist so we'll raise a more specific diff --git a/awscli/customizations/s3/fileinfo.py b/awscli/customizations/s3/fileinfo.py index fe482e64d13b..b30c67bcdc57 100644 --- a/awscli/customizations/s3/fileinfo.py +++ b/awscli/customizations/s3/fileinfo.py @@ -11,7 +11,7 @@ from botocore.compat import quote from awscli.customizations.s3.utils import find_bucket_key, \ check_etag, check_error, operate, uni_print, \ - guess_content_type, MD5Error + guess_content_type, MD5Error, bytes_print class CreateDirectoryError(Exception): @@ -26,7 +26,7 @@ def read_file(filename): return in_file.read() -def save_file(filename, response_data, last_update): +def save_file(filename, response_data, last_update, is_stream=False): """ This writes to the file upon downloading. It reads the data in the response. Makes a new directory if needed and then writes the @@ -35,31 +35,57 @@ def save_file(filename, response_data, last_update): """ body = response_data['Body'] etag = response_data['ETag'][1:-1] - d = os.path.dirname(filename) - try: - if not os.path.exists(d): - os.makedirs(d) - except OSError as e: - if not e.errno == errno.EEXIST: - raise CreateDirectoryError( - "Could not create directory %s: %s" % (d, e)) + if not is_stream: + d = os.path.dirname(filename) + try: + if not os.path.exists(d): + os.makedirs(d) + except OSError as e: + if not e.errno == errno.EEXIST: + raise CreateDirectoryError( + "Could not create directory %s: %s" % (d, e)) md5 = hashlib.md5() file_chunks = iter(partial(body.read, 1024 * 1024), b'') - with open(filename, 'wb') as out_file: - if not _is_multipart_etag(etag): - for chunk in file_chunks: - md5.update(chunk) - out_file.write(chunk) - else: - for chunk in file_chunks: - out_file.write(chunk) + if is_stream: + # Need to save the data to be able to check the etag for a stream + # becuase once the data is written to the stream there is no + # undoing it. + payload = write_to_file(None, etag, md5, file_chunks, True) + else: + with open(filename, 'wb') as out_file: + write_to_file(out_file, etag, md5, file_chunks) + if not _is_multipart_etag(etag): if etag != md5.hexdigest(): - os.remove(filename) + if not is_stream: + os.remove(filename) raise MD5Error(filename) - last_update_tuple = last_update.timetuple() - mod_timestamp = time.mktime(last_update_tuple) - os.utime(filename, (int(mod_timestamp), int(mod_timestamp))) + + if not is_stream: + last_update_tuple = last_update.timetuple() + mod_timestamp = time.mktime(last_update_tuple) + os.utime(filename, (int(mod_timestamp), int(mod_timestamp))) + else: + # Now write the output to stdout since the md5 is correct. + bytes_print(payload) + sys.stdout.flush() + + +def write_to_file(out_file, etag, md5, file_chunks, is_stream=False): + """ + Updates the etag for each file chunk. It will write to the file if it a + file but if it is a stream it will return a byte string to be later + written to a stream. + """ + body = b'' + for chunk in file_chunks: + if not _is_multipart_etag(etag): + md5.update(chunk) + if is_stream: + body += chunk + else: + out_file.write(chunk) + return body def _is_multipart_etag(etag): @@ -140,7 +166,7 @@ class FileInfo(TaskInfo): def __init__(self, src, dest=None, compare_key=None, size=None, last_update=None, src_type=None, dest_type=None, operation_name=None, service=None, endpoint=None, - parameters=None, source_endpoint=None): + parameters=None, source_endpoint=None, is_stream=False): super(FileInfo, self).__init__(src, src_type=src_type, operation_name=operation_name, service=service, @@ -157,6 +183,18 @@ def __init__(self, src, dest=None, compare_key=None, size=None, self.parameters = {'acl': None, 'sse': None} self.source_endpoint = source_endpoint + self.is_stream = is_stream + + def set_size_from_s3(self): + """ + This runs a ``HeadObject`` on the s3 object and sets the size. + """ + bucket, key = find_bucket_key(self.src) + params = {'endpoint': self.endpoint, + 'bucket': bucket, + 'key': key} + response_data, http = operate(self.service, 'HeadObject', params) + self.size = int(response_data['ContentLength']) def _permission_to_param(self, permission): if permission == 'read': @@ -204,24 +242,30 @@ def _handle_object_params(self, params): if self.parameters['expires']: params['expires'] = self.parameters['expires'][0] - def upload(self): + def upload(self, payload=None): """ Redirects the file to the multipart upload function if the file is large. If it is small enough, it puts the file as an object in s3. """ - with open(self.src, 'rb') as body: - bucket, key = find_bucket_key(self.dest) - params = { - 'endpoint': self.endpoint, - 'bucket': bucket, - 'key': key, - 'body': body, - } - self._handle_object_params(params) - response_data, http = operate(self.service, 'PutObject', params) - etag = response_data['ETag'][1:-1] - body.seek(0) - check_etag(etag, body) + if payload: + self._handle_upload(payload) + else: + with open(self.src, 'rb') as body: + self._handle_upload(body) + + def _handle_upload(self, body): + bucket, key = find_bucket_key(self.dest) + params = { + 'endpoint': self.endpoint, + 'bucket': bucket, + 'key': key, + 'body': body, + } + self._handle_object_params(params) + response_data, http = operate(self.service, 'PutObject', params) + etag = response_data['ETag'][1:-1] + body.seek(0) + check_etag(etag, body) def _inject_content_type(self, params, filename): # Add a content type param if we can guess the type. @@ -237,7 +281,8 @@ def download(self): bucket, key = find_bucket_key(self.src) params = {'endpoint': self.endpoint, 'bucket': bucket, 'key': key} response_data, http = operate(self.service, 'GetObject', params) - save_file(self.dest, response_data, self.last_update) + save_file(self.dest, response_data, self.last_update, + self.is_stream) def copy(self): """ diff --git a/awscli/customizations/s3/fileinfobuilder.py b/awscli/customizations/s3/fileinfobuilder.py index 8bc2042615ef..9f1c429f0fc0 100644 --- a/awscli/customizations/s3/fileinfobuilder.py +++ b/awscli/customizations/s3/fileinfobuilder.py @@ -19,13 +19,14 @@ class FileInfoBuilder(object): a ``FileInfo`` object so that the operation can be performed. """ def __init__(self, service, endpoint, source_endpoint=None, - parameters = None): + parameters = None, is_stream=False): self._service = service self._endpoint = endpoint self._source_endpoint = endpoint if source_endpoint: self._source_endpoint = source_endpoint - self._parameters = parameters + self._parameters = parameters + self._is_stream = is_stream def call(self, files): for file_base in files: @@ -46,4 +47,5 @@ def _inject_info(self, file_base): file_info_attr['endpoint'] = self._endpoint file_info_attr['source_endpoint'] = self._source_endpoint file_info_attr['parameters'] = self._parameters + file_info_attr['is_stream'] = self._is_stream return FileInfo(**file_info_attr) diff --git a/awscli/customizations/s3/s3handler.py b/awscli/customizations/s3/s3handler.py index 91f701bbd83d..bd716970e104 100644 --- a/awscli/customizations/s3/s3handler.py +++ b/awscli/customizations/s3/s3handler.py @@ -14,7 +14,9 @@ import logging import math import os +import six from six.moves import queue +import sys from awscli.customizations.s3.constants import MULTI_THRESHOLD, CHUNKSIZE, \ NUM_THREADS, MAX_UPLOAD_SIZE, MAX_QUEUE_SIZE @@ -36,6 +38,8 @@ class S3Handler(object): class pull tasks from to complete. """ MAX_IO_QUEUE_SIZE = 20 + MAX_EXECUTOR_QUEUE_SIZE = MAX_QUEUE_SIZE + EXECUTOR_NUM_THREADS = NUM_THREADS def __init__(self, session, params, result_queue=None, multi_threshold=MULTI_THRESHOLD, chunksize=CHUNKSIZE): @@ -53,7 +57,9 @@ def __init__(self, session, params, result_queue=None, 'content_type': None, 'cache_control': None, 'content_disposition': None, 'content_encoding': None, 'content_language': None, 'expires': None, - 'grants': None} + 'grants': None, 'only_show_errors': False, + 'is_stream': False, 'paths_type': None, + 'expected_size': None} self.params['region'] = params['region'] for key in self.params.keys(): if key in params: @@ -61,8 +67,11 @@ def __init__(self, session, params, result_queue=None, self.multi_threshold = multi_threshold self.chunksize = chunksize self.executor = Executor( - num_threads=NUM_THREADS, result_queue=self.result_queue, - quiet=self.params['quiet'], max_queue_size=MAX_QUEUE_SIZE, + num_threads=self.EXECUTOR_NUM_THREADS, + result_queue=self.result_queue, + quiet=self.params['quiet'], + only_show_errors=self.params['only_show_errors'], + max_queue_size=self.MAX_EXECUTOR_QUEUE_SIZE, write_queue=self.write_queue ) self._multipart_uploads = [] @@ -234,12 +243,11 @@ def _enqueue_range_download_tasks(self, filename, remove_remote_file=False): create_file_task = tasks.CreateLocalFileTask(context=context, filename=filename) self.executor.submit(create_file_task) - for i in range(num_downloads): - task = tasks.DownloadPartTask( - part_number=i, chunk_size=chunksize, - result_queue=self.result_queue, service=filename.service, - filename=filename, context=context, io_queue=self.write_queue) - self.executor.submit(task) + self._do_enqueue_range_download_tasks( + filename=filename, chunksize=chunksize, + num_downloads=num_downloads, context=context, + remove_remote_file=remove_remote_file + ) complete_file_task = tasks.CompleteDownloadTask( context=context, filename=filename, result_queue=self.result_queue, params=self.params, io_queue=self.write_queue) @@ -251,6 +259,16 @@ def _enqueue_range_download_tasks(self, filename, remove_remote_file=False): self.executor.submit(remove_task) return num_downloads + def _do_enqueue_range_download_tasks(self, filename, chunksize, + num_downloads, context, + remove_remote_file=False): + for i in range(num_downloads): + task = tasks.DownloadPartTask( + part_number=i, chunk_size=chunksize, + result_queue=self.result_queue, service=filename.service, + filename=filename, context=context, io_queue=self.write_queue) + self.executor.submit(task) + def _enqueue_multipart_upload_tasks(self, filename, remove_local_file=False): # First we need to create a CreateMultipartUpload task, @@ -295,14 +313,27 @@ def _enqueue_upload_start_task(self, chunksize, num_uploads, filename): self.executor.submit(create_multipart_upload_task) return upload_context - def _enqueue_upload_tasks(self, num_uploads, chunksize, upload_context, filename, - task_class): + def _enqueue_upload_tasks(self, num_uploads, chunksize, upload_context, + filename, task_class): for i in range(1, (num_uploads + 1)): - task = task_class( - part_number=i, chunk_size=chunksize, - result_queue=self.result_queue, upload_context=upload_context, - filename=filename) - self.executor.submit(task) + self._enqueue_upload_single_part_task( + part_number=i, + chunk_size=chunksize, + upload_context=upload_context, + filename=filename, + task_class=task_class + ) + + def _enqueue_upload_single_part_task(self, part_number, chunk_size, + upload_context, filename, task_class, + payload=None): + kwargs = {'part_number': part_number, 'chunk_size': chunk_size, + 'result_queue': self.result_queue, + 'upload_context': upload_context, 'filename': filename} + if payload: + kwargs['payload'] = payload + task = task_class(**kwargs) + self.executor.submit(task) def _enqueue_upload_end_task(self, filename, upload_context): complete_multipart_upload_task = tasks.CompleteMultipartUploadTask( @@ -311,3 +342,157 @@ def _enqueue_upload_end_task(self, filename, upload_context): self.executor.submit(complete_multipart_upload_task) self._multipart_uploads.append((upload_context, filename)) + +class S3StreamHandler(S3Handler): + """ + This class is an alternative ``S3Handler`` to be used when the operation + involves a stream since the logic is different when uploading and + downloading streams. + """ + + # This ensures that the number of multipart chunks waiting in the + # executor queue and in the threads is limited. + MAX_EXECUTOR_QUEUE_SIZE = 2 + EXECUTOR_NUM_THREADS = 6 + + def _enqueue_tasks(self, files): + total_files = 0 + total_parts = 0 + for filename in files: + num_uploads = 1 + # If uploading stream, it is required to read from the stream + # to determine if the stream needs to be multipart uploaded. + payload = None + if filename.operation_name == 'upload': + payload, is_multipart_task = \ + self._pull_from_stream(self.multi_threshold) + else: + # Set the file size for the ``FileInfo`` object since + # streams do not use a ``FileGenerator`` that usually + # determines the size. + filename.set_size_from_s3() + is_multipart_task = self._is_multipart_task(filename) + if is_multipart_task and not self.params['dryrun']: + # If we're in dryrun mode, then we don't need the + # real multipart tasks. We can just use a BasicTask + # in the else clause below, which will print out the + # fact that it's transferring a file rather than + # the specific part tasks required to perform the + # transfer. + num_uploads = self._enqueue_multipart_tasks(filename, payload) + else: + task = tasks.BasicTask( + session=self.session, filename=filename, + parameters=self.params, + result_queue=self.result_queue, + payload=payload) + self.executor.submit(task) + total_files += 1 + total_parts += num_uploads + return total_files, total_parts + + def _pull_from_stream(self, amount_requested): + """ + This function pulls data from stdin until it hits the amount + requested or there is no more left to pull in from stdin. The + function wraps the data into a ``BytesIO`` object that is returned + along with a boolean telling whether the amount requested is + the amount returned. + """ + stream_filein = sys.stdin + if six.PY3: + stream_filein = sys.stdin.buffer + payload = stream_filein.read(amount_requested) + payload_file = six.BytesIO(payload) + return payload_file, len(payload) == amount_requested + + def _enqueue_multipart_tasks(self, filename, payload=None): + num_uploads = 1 + if filename.operation_name == 'upload': + num_uploads = self._enqueue_multipart_upload_tasks(filename, + payload=payload) + elif filename.operation_name == 'download': + num_uploads = self._enqueue_range_download_tasks(filename) + return num_uploads + + def _enqueue_range_download_tasks(self, filename, remove_remote_file=False): + + # Create the context for the multipart download. + chunksize = find_chunksize(filename.size, self.chunksize) + num_downloads = int(filename.size / chunksize) + context = tasks.MultipartDownloadContext(num_downloads) + + # No file is needed for downloading a stream. So just announce + # that it has been made since it is required for the context to + # begin downloading. + context.announce_file_created() + + # Submit download part tasks to the executor. + self._do_enqueue_range_download_tasks( + filename=filename, chunksize=chunksize, + num_downloads=num_downloads, context=context, + remove_remote_file=remove_remote_file + ) + return num_downloads + + def _enqueue_multipart_upload_tasks(self, filename, payload=None): + # First we need to create a CreateMultipartUpload task, + # then create UploadTask objects for each of the parts. + # And finally enqueue a CompleteMultipartUploadTask. + + chunksize = self.chunksize + # Determine an appropriate chunksize if given an expected size. + if self.params['expected_size']: + chunksize = find_chunksize(int(self.params['expected_size']), + self.chunksize) + num_uploads = '...' + + # Submit a task to begin the multipart upload. + upload_context = self._enqueue_upload_start_task( + chunksize, num_uploads, filename) + + # Now submit a task to upload the initial chunk of data pulled + # from the stream that was used to determine if a multipart upload + # was needed. + self._enqueue_upload_single_part_task( + part_number=1, chunk_size=chunksize, + upload_context=upload_context, filename=filename, + task_class=tasks.UploadPartTask, payload=payload + ) + + # Submit tasks to upload the rest of the chunks of the data coming in + # from standard input. + num_uploads = self._enqueue_upload_tasks( + num_uploads, chunksize, upload_context, + filename, tasks.UploadPartTask + ) + + # Submit a task to notify the multipart upload is complete. + self._enqueue_upload_end_task(filename, upload_context) + + return num_uploads + + def _enqueue_upload_tasks(self, num_uploads, chunksize, upload_context, + filename, task_class): + # The previous upload occured right after the multipart + # upload started for a stream. + num_uploads = 1 + while True: + # Pull more data from standard input. + payload, is_remaining = self._pull_from_stream(chunksize) + # Submit an upload part task for the recently pulled data. + self._enqueue_upload_single_part_task( + part_number=num_uploads+1, + chunk_size=chunksize, + upload_context=upload_context, + filename=filename, + task_class=task_class, + payload=payload + ) + num_uploads += 1 + if not is_remaining: + break + # Once there is no more data left, announce to the context how + # many parts are being uploaded so it knows when it can quit. + upload_context.announce_total_parts(num_uploads) + return num_uploads diff --git a/awscli/customizations/s3/subcommands.py b/awscli/customizations/s3/subcommands.py index 6ec6856294d9..46adc0e9ea97 100644 --- a/awscli/customizations/s3/subcommands.py +++ b/awscli/customizations/s3/subcommands.py @@ -23,11 +23,11 @@ from awscli.customizations.s3.fileinfobuilder import FileInfoBuilder from awscli.customizations.s3.fileformat import FileFormat from awscli.customizations.s3.filegenerator import FileGenerator -from awscli.customizations.s3.fileinfo import TaskInfo +from awscli.customizations.s3.fileinfo import TaskInfo, FileInfo from awscli.customizations.s3.filters import create_filter -from awscli.customizations.s3.s3handler import S3Handler +from awscli.customizations.s3.s3handler import S3Handler, S3StreamHandler from awscli.customizations.s3.utils import find_bucket_key, uni_print, \ - AppendFilter + AppendFilter, find_dest_path_comp_key RECURSIVE = {'name': 'recursive', 'action': 'store_true', 'dest': 'dir_op', @@ -206,11 +206,25 @@ 'The object key name to use when ' 'a 4XX class error occurs.')} +ONLY_SHOW_ERRORS = {'name': 'only-show-errors', 'action': 'store_true', + 'help_text': ( + 'Only errors and warnings are displayed. All other ' + 'output is suppressed.')} + +EXPECTED_SIZE = {'name': 'expected-size', + 'help_text': ( + 'This argument specifies the expected size of a stream ' + 'in terms of bytes. Note that this argument is needed ' + 'only when a stream is being uploaded to s3 and the size ' + 'is larger than 5GB. Failure to include this argument ' + 'under these conditions may result in a failed upload. ' + 'due to too many parts in upload.')} + TRANSFER_ARGS = [DRYRUN, QUIET, RECURSIVE, INCLUDE, EXCLUDE, ACL, FOLLOW_SYMLINKS, NO_FOLLOW_SYMLINKS, NO_GUESS_MIME_TYPE, SSE, STORAGE_CLASS, GRANTS, WEBSITE_REDIRECT, CONTENT_TYPE, CACHE_CONTROL, CONTENT_DISPOSITION, CONTENT_ENCODING, - CONTENT_LANGUAGE, EXPIRES, SOURCE_REGION] + CONTENT_LANGUAGE, EXPIRES, SOURCE_REGION, ONLY_SHOW_ERRORS] SYNC_ARGS = [DELETE, EXACT_TIMESTAMPS, SIZE_ONLY] + TRANSFER_ARGS @@ -264,15 +278,14 @@ def _list_all_objects(self, bucket, key, page_size=None): self._display_page(response_data) def _display_page(self, response_data, use_basename=True): - common_prefixes = response_data['CommonPrefixes'] - contents = response_data['Contents'] + common_prefixes = response_data.get('CommonPrefixes', []) + contents = response_data.get('Contents', []) for common_prefix in common_prefixes: prefix_components = common_prefix['Prefix'].split('/') prefix = prefix_components[-2] pre_string = "PRE".rjust(30, " ") print_str = pre_string + ' ' + prefix + '/\n' uni_print(print_str) - sys.stdout.flush() for content in contents: last_mod_str = self._make_last_mod_str(content['LastModified']) size_str = self._make_size_str(content['Size']) @@ -284,7 +297,6 @@ def _display_page(self, response_data, use_basename=True): print_str = last_mod_str + ' ' + size_str + ' ' + \ filename + '\n' uni_print(print_str) - sys.stdout.flush() def _list_all_buckets(self): operation = self.service.get_operation('ListBuckets') @@ -294,7 +306,6 @@ def _list_all_buckets(self): last_mod_str = self._make_last_mod_str(bucket['CreationDate']) print_str = last_mod_str + ' ' + bucket['Name'] + '\n' uni_print(print_str) - sys.stdout.flush() def _list_all_objects_recursive(self, bucket, key, page_size=None): operation = self.service.get_operation('ListObjects') @@ -413,7 +424,7 @@ class CpCommand(S3TransferCommand): USAGE = " or " \ "or " ARG_TABLE = [{'name': 'paths', 'nargs': 2, 'positional_arg': True, - 'synopsis': USAGE}] + TRANSFER_ARGS + 'synopsis': USAGE}] + TRANSFER_ARGS + [EXPECTED_SIZE] EXAMPLES = BasicCommand.FROM_FILE('s3/cp.rst') @@ -434,7 +445,7 @@ class RmCommand(S3TransferCommand): USAGE = "" ARG_TABLE = [{'name': 'paths', 'nargs': 1, 'positional_arg': True, 'synopsis': USAGE}, DRYRUN, QUIET, RECURSIVE, INCLUDE, - EXCLUDE] + EXCLUDE, ONLY_SHOW_ERRORS] EXAMPLES = BasicCommand.FROM_FILE('s3/rm.rst') @@ -510,16 +521,21 @@ def create_instructions(self): instruction list because it sends the request to S3 and does not yield anything. """ - if self.cmd not in ['mb', 'rb']: + if self.needs_filegenerator(): self.instructions.append('file_generator') - if self.parameters.get('filters'): - self.instructions.append('filters') - if self.cmd == 'sync': - self.instructions.append('comparator') - if self.cmd not in ['mb', 'rb']: + if self.parameters.get('filters'): + self.instructions.append('filters') + if self.cmd == 'sync': + self.instructions.append('comparator') self.instructions.append('file_info_builder') self.instructions.append('s3_handler') + def needs_filegenerator(self): + if self.cmd in ['mb', 'rb'] or self.parameters['is_stream']: + return False + else: + return True + def run(self): """ This function wires together all of the generators and completes @@ -576,10 +592,22 @@ def run(self): operation_name=operation_name, service=self._service, endpoint=self._endpoint)] + stream_dest_path, stream_compare_key = find_dest_path_comp_key(files) + stream_file_info = [FileInfo(src=files['src']['path'], + dest=stream_dest_path, + compare_key=stream_compare_key, + src_type=files['src']['type'], + dest_type=files['dest']['type'], + operation_name=operation_name, + service=self._service, + endpoint=self._endpoint, + is_stream=True)] file_info_builder = FileInfoBuilder(self._service, self._endpoint, self._source_endpoint, self.parameters) s3handler = S3Handler(self.session, self.parameters, result_queue=result_queue) + s3_stream_handler = S3StreamHandler(self.session, self.parameters, + result_queue=result_queue) command_dict = {} if self.cmd == 'sync': @@ -591,6 +619,9 @@ def run(self): 'comparator': [Comparator(self.parameters)], 'file_info_builder': [file_info_builder], 's3_handler': [s3handler]} + elif self.cmd == 'cp' and self.parameters['is_stream']: + command_dict = {'setup': [stream_file_info], + 's3_handler': [s3_stream_handler]} elif self.cmd == 'cp': command_dict = {'setup': [files], 'file_generator': [file_generator], @@ -683,8 +714,19 @@ def add_paths(self, paths): self.parameters['dest'] = paths[1] elif len(paths) == 1: self.parameters['dest'] = paths[0] + self._validate_streaming_paths() self._validate_path_args() + def _validate_streaming_paths(self): + self.parameters['is_stream'] = False + if self.parameters['src'] == '-' or self.parameters['dest'] == '-': + self.parameters['is_stream'] = True + self.parameters['dir_op'] = False + self.parameters['only_show_errors'] = True + if self.parameters['is_stream'] and self.cmd != 'cp': + raise ValueError("Streaming currently is only compatible with " + "single file cp commands") + def _validate_path_args(self): # If we're using a mv command, you can't copy the object onto itself. params = self.parameters diff --git a/awscli/customizations/s3/tasks.py b/awscli/customizations/s3/tasks.py index 37089c42a2b9..be326f35e56d 100644 --- a/awscli/customizations/s3/tasks.py +++ b/awscli/customizations/s3/tasks.py @@ -63,7 +63,8 @@ class BasicTask(OrderableTask): attributes like ``session`` object in order for the filename to perform its designated operation. """ - def __init__(self, session, filename, parameters, result_queue): + def __init__(self, session, filename, parameters, + result_queue, payload=None): self.session = session self.service = self.session.get_service('s3') @@ -72,6 +73,7 @@ def __init__(self, session, filename, parameters, result_queue): self.parameters = parameters self.result_queue = result_queue + self.payload = payload def __call__(self): self._execute_task(attempts=3) @@ -84,9 +86,12 @@ def _execute_task(self, attempts, last_error=''): error_message=last_error) return filename = self.filename + kwargs = {} + if self.payload: + kwargs['payload'] = self.payload try: if not self.parameters['dryrun']: - getattr(filename, filename.operation_name)() + getattr(filename, filename.operation_name)(**kwargs) except requests.ConnectionError as e: connect_error = str(e) LOGGER.debug("%s %s failure: %s", @@ -195,13 +200,14 @@ class UploadPartTask(OrderableTask): complete the multipart upload initiated by the ``FileInfo`` object. """ - def __init__(self, part_number, chunk_size, - result_queue, upload_context, filename): + def __init__(self, part_number, chunk_size, result_queue, upload_context, + filename, payload=None): self._result_queue = result_queue self._upload_context = upload_context self._part_number = part_number self._chunk_size = chunk_size self._filename = filename + self._payload = payload def _read_part(self): actual_filename = self._filename.src @@ -216,9 +222,13 @@ def __call__(self): LOGGER.debug("Waiting for upload id.") upload_id = self._upload_context.wait_for_upload_id() bucket, key = find_bucket_key(self._filename.dest) - total = int(math.ceil( - self._filename.size/float(self._chunk_size))) - body = self._read_part() + if self._filename.is_stream: + body = self._payload + total = self._upload_context.expected_parts + else: + total = int(math.ceil( + self._filename.size/float(self._chunk_size))) + body = self._read_part() params = {'endpoint': self._filename.endpoint, 'bucket': bucket, 'key': key, 'part_number': self._part_number, @@ -393,16 +403,23 @@ def _queue_writes(self, body): body.set_socket_timeout(self.READ_TIMEOUT) amount_read = 0 current = body.read(iterate_chunk_size) + if self._filename.is_stream: + self._context.wait_for_turn(self._part_number) while current: offset = self._part_number * self._chunk_size + amount_read LOGGER.debug("Submitting IORequest to write queue.") - self._io_queue.put(IORequest(self._filename.dest, offset, current)) + self._io_queue.put( + IORequest(self._filename.dest, offset, current, + self._filename.is_stream) + ) LOGGER.debug("Request successfully submitted.") amount_read += len(current) current = body.read(iterate_chunk_size) # Change log message. LOGGER.debug("Done queueing writes for part number %s to file: %s", self._part_number, self._filename.dest) + if self._filename.is_stream: + self._context.done_with_turn() class CreateMultipartUploadTask(BasicTask): @@ -530,7 +547,7 @@ class MultipartUploadContext(object): _CANCELLED = '_CANCELLED' _COMPLETED = '_COMPLETED' - def __init__(self, expected_parts): + def __init__(self, expected_parts='...'): self._upload_id = None self._expected_parts = expected_parts self._parts = [] @@ -540,6 +557,10 @@ def __init__(self, expected_parts): self._upload_complete_condition = threading.Condition(self._lock) self._state = self._UNSTARTED + @property + def expected_parts(self): + return self._expected_parts + def announce_upload_id(self, upload_id): with self._upload_id_condition: self._upload_id = upload_id @@ -551,9 +572,15 @@ def announce_finished_part(self, etag, part_number): self._parts.append({'ETag': etag, 'PartNumber': part_number}) self._parts_condition.notifyAll() + def announce_total_parts(self, total_parts): + with self._parts_condition: + self._expected_parts = total_parts + self._parts_condition.notifyAll() + def wait_for_parts_to_finish(self): with self._parts_condition: - while len(self._parts) < self._expected_parts: + while self._expected_parts == '...' or \ + len(self._parts) < self._expected_parts: if self._state == self._CANCELLED: raise UploadCancelledError("Upload has been cancelled.") self._parts_condition.wait(timeout=1) @@ -653,9 +680,11 @@ def __init__(self, num_parts, lock=None): lock = threading.Lock() self._lock = lock self._created_condition = threading.Condition(self._lock) + self._submit_write_condition = threading.Condition(self._lock) self._completed_condition = threading.Condition(self._lock) self._state = self._STATES['UNSTARTED'] self._finished_parts = set() + self._current_stream_part_number = 0 def announce_completed_part(self, part_number): with self._completed_condition: @@ -685,6 +714,19 @@ def wait_for_completion(self): "Download has been cancelled.") self._completed_condition.wait(timeout=1) + def wait_for_turn(self, part_number): + with self._submit_write_condition: + while self._current_stream_part_number != part_number: + if self._state == self._STATES['CANCELLED']: + raise DownloadCancelledError( + "Download has been cancelled.") + self._submit_write_condition.wait(timeout=0.2) + + def done_with_turn(self): + with self._submit_write_condition: + self._current_stream_part_number += 1 + self._submit_write_condition.notifyAll() + def cancel(self): with self._lock: self._state = self._STATES['CANCELLED'] diff --git a/awscli/customizations/s3/utils.py b/awscli/customizations/s3/utils.py index 56fa247f2876..b5d4c797b54f 100644 --- a/awscli/customizations/s3/utils.py +++ b/awscli/customizations/s3/utils.py @@ -145,6 +145,34 @@ def get_file_stat(path): return stats.st_size, update_time +def find_dest_path_comp_key(files, src_path=None): + """ + This is a helper function that determines the destination path and compare + key given parameters received from the ``FileFormat`` class. + """ + src = files['src'] + dest = files['dest'] + src_type = src['type'] + dest_type = dest['type'] + if src_path is None: + src_path = src['path'] + + sep_table = {'s3': '/', 'local': os.sep} + + if files['dir_op']: + rel_path = src_path[len(src['path']):] + else: + rel_path = src_path.split(sep_table[src_type])[-1] + compare_key = rel_path.replace(sep_table[src_type], '/') + if files['use_src_name']: + dest_path = dest['path'] + dest_path += rel_path.replace(sep_table[src_type], + sep_table[dest_type]) + else: + dest_path = dest['path'] + return dest_path, compare_key + + def check_etag(etag, fileobj): """ This fucntion checks the etag and the md5 checksum to ensure no @@ -165,10 +193,9 @@ def check_error(response_data): response_data and raises an error when there is an error. """ if response_data: - if 'Errors' in response_data: - errors = response_data['Errors'] - for error in errors: - raise Exception("Error: %s\n" % error['Message']) + if 'Error' in response_data: + error = response_data['Error'] + raise Exception("Error: %s\n" % error['Message']) def create_warning(path, error_message): @@ -223,24 +250,42 @@ def __init__(self): self.count = 0 -def uni_print(statement): +def uni_print(statement, out_file=None): """ - This function is used to properly write unicode to stdout. It - ensures that the proper encoding is used if the statement is - not in a version type of string. The initial check is to - allow if ``sys.stdout`` does not use an encoding + This function is used to properly write unicode to a file, usually + stdout or stdderr. It ensures that the proper encoding is used if the + statement is not a string type. """ - encoding = getattr(sys.stdout, 'encoding', None) + if out_file is None: + out_file = sys.stdout + # Check for an encoding on the file. + encoding = getattr(out_file, 'encoding', None) if encoding is not None and not PY3: - sys.stdout.write(statement.encode(sys.stdout.encoding)) + out_file.write(statement.encode(out_file.encoding)) else: try: - sys.stdout.write(statement) + out_file.write(statement) except UnicodeEncodeError: # Some file like objects like cStringIO will # try to decode as ascii. Interestingly enough # this works with a normal StringIO. - sys.stdout.write(statement.encode('utf-8')) + out_file.write(statement.encode('utf-8')) + out_file.flush() + + +def bytes_print(statement): + """ + This function is used to properly write bytes to standard out. + """ + if PY3: + if getattr(sys.stdout, 'buffer', None): + sys.stdout.buffer.write(statement) + else: + # If it is not possible to write to the standard out buffer. + # The next best option is to decode and write to standard out. + sys.stdout.write(statement.decode('utf-8')) + else: + sys.stdout.write(statement) def guess_content_type(filename): @@ -354,7 +399,7 @@ def list_objects(self, bucket, prefix=None, page_size=None): True): pages = self._operation.paginate(self._endpoint, **kwargs) for response, page in pages: - contents = page['Contents'] + contents = page.get('Contents', []) for content in contents: source_path = bucket + '/' + content['Key'] size = content['Size'] @@ -362,8 +407,9 @@ def list_objects(self, bucket, prefix=None, page_size=None): yield source_path, size, last_update def _decode_keys(self, parsed, **kwargs): - for content in parsed['Contents']: - content['Key'] = unquote_str(content['Key']) + if 'Contents' in parsed: + for content in parsed['Contents']: + content['Key'] = unquote_str(content['Key']) class ScopedEventHandler(object): @@ -401,7 +447,8 @@ def __new__(cls, message, error=False, total_parts=None, warning=None): warning) -IORequest = namedtuple('IORequest', ['filename', 'offset', 'data']) +IORequest = namedtuple('IORequest', + ['filename', 'offset', 'data', 'is_stream']) # Used to signal that IO for the filename is finished, and that # any associated resources may be cleaned up. IOCloseRequest = namedtuple('IOCloseRequest', ['filename']) diff --git a/awscli/customizations/sessendemail.py b/awscli/customizations/sessendemail.py index fa92c8ec4eee..627191b6a330 100644 --- a/awscli/customizations/sessendemail.py +++ b/awscli/customizations/sessendemail.py @@ -16,7 +16,7 @@ will be:: aws ses send-email --subject SUBJECT --from FROM_EMAIL - --to-addresses addr ... --cc-addresses addr ... + --to-addresses addr ... --cc-addresses addr ... --bcc-addresses addr ... --reply-to-addresses addr ... --return-path addr --text TEXTBODY --html HTMLBODY @@ -26,6 +26,7 @@ from awscli.arguments import CustomArgument from awscli.customizations.utils import validate_mutually_exclusive_handler + TO_HELP = ('The email addresses of the primary recipients. ' 'You can specify multiple recipients as space-separated values') CC_HELP = ('The email addresses of copy recipients (Cc). ' @@ -93,7 +94,7 @@ def __init__(self, name, json_key, help_text='', dest=None, default=None, super(AddressesArgument, self).__init__(name=name, help_text=help_text, required=required, nargs='+') self._json_key = json_key - + def add_to_params(self, parameters, value): if value: _build_destination(parameters, self._json_key, value) @@ -105,8 +106,8 @@ def __init__(self, name, json_key, help_text='', required=None): super(BodyArgument, self).__init__(name=name, help_text=help_text, required=required) self._json_key = json_key - + def add_to_params(self, parameters, value): if value: _build_message(parameters, self._json_key, value) - + diff --git a/awscli/customizations/streamingoutputarg.py b/awscli/customizations/streamingoutputarg.py index 26d6ae434670..d230e7e3a45f 100644 --- a/awscli/customizations/streamingoutputarg.py +++ b/awscli/customizations/streamingoutputarg.py @@ -10,19 +10,30 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +from botocore import model + from awscli.arguments import BaseCLIArgument def add_streaming_output_arg(argument_table, operation, **kwargs): # Implementation detail: hooked up to 'building-argument-table' # event. - stream_param = operation.is_streaming() - if stream_param: + model = operation.model + if _has_streaming_output(model): + streaming_argument_name = _get_streaming_argument_name(model) argument_table['outfile'] = StreamingOutputArgument( - response_key=stream_param, operation=operation, + response_key=streaming_argument_name, operation=operation, name='outfile') +def _has_streaming_output(model): + return model.has_streaming_output + + +def _get_streaming_argument_name(model): + return model.output_shape.serialization['payload'] + + class StreamingOutputArgument(BaseCLIArgument): BUFFER_SIZE = 32768 @@ -30,7 +41,8 @@ class StreamingOutputArgument(BaseCLIArgument): def __init__(self, response_key, operation, name, buffer_size=None): self._name = name - self.argument_object = operation + self.argument_model = model.Shape('StreamingOutputArgument', + {'type': 'string'}) if buffer_size is None: buffer_size = self.BUFFER_SIZE self._buffer_size = buffer_size diff --git a/awscli/customizations/toplevelbool.py b/awscli/customizations/toplevelbool.py index 6759b540cad4..ff4958c71343 100644 --- a/awscli/customizations/toplevelbool.py +++ b/awscli/customizations/toplevelbool.py @@ -38,6 +38,16 @@ def register_bool_params(event_handler): event_handler=event_handler)) +def _qualifies_for_simplification(arg_model): + if detect_shape_structure(arg_model) == 'structure(scalar)': + members = arg_model.members + if (len(members) == 1 and + list(members.keys())[0] == 'Value' and + list(members.values())[0].type_name == 'boolean'): + return True + return False + + def pull_up_bool(argument_table, event_handler, **kwargs): # List of tuples of (positive_bool, negative_bool) # This is used to validate that we don't specify @@ -48,23 +58,20 @@ def pull_up_bool(argument_table, event_handler, **kwargs): partial(validate_boolean_mutex_groups, boolean_pairs=boolean_pairs)) for key, value in list(argument_table.items()): - if hasattr(value, 'argument_object'): - arg_object = value.argument_object - if detect_shape_structure(arg_object) == 'structure(scalar)' and \ - len(arg_object.members) == 1 and \ - arg_object.members[0].name == 'Value' and \ - arg_object.members[0].type == 'boolean': + if hasattr(value, 'argument_model'): + arg_model = value.argument_model + if _qualifies_for_simplification(arg_model): # Swap out the existing CLIArgument for two args: # one that supports --option and --option # and another arg of --no-option. new_arg = PositiveBooleanArgument( - value.name, arg_object, value.operation_object, + value.name, arg_model, value.operation_object, value.name) argument_table[value.name] = new_arg negative_name = 'no-%s' % value.name negative_arg = NegativeBooleanParameter( negative_name, new_arg.py_name, - arg_object, value.operation_object, + arg_model, value.operation_object, action='store_true', dest='no_%s' % new_arg.py_name, group_name=value.name) argument_table[negative_name] = negative_arg @@ -87,9 +94,9 @@ def validate_boolean_mutex_groups(boolean_pairs, parsed_args, **kwargs): class PositiveBooleanArgument(arguments.CLIArgument): - def __init__(self, name, argument_object, operation_object, group_name): + def __init__(self, name, argument_model, operation_object, group_name): super(PositiveBooleanArgument, self).__init__( - name, argument_object, operation_object) + name, argument_model, operation_object) self._group_name = group_name @property @@ -115,19 +122,19 @@ def add_to_params(self, parameters, value): # e.g. --boolean-parameter # which means we should add a true value # to the parameters dict. - parameters[self.argument_object.py_name] = {'Value': True} + parameters[self.py_name] = {'Value': True} else: # Otherwise the arg was specified with a value. - parameters[self.argument_object.py_name] = self._unpack_argument( + parameters[self.py_name] = self._unpack_argument( value) class NegativeBooleanParameter(arguments.BooleanArgument): def __init__(self, name, positive_py_name, - argument_object, operation_object, + argument_model, operation_object, action='store_true', dest=None, group_name=None): super(NegativeBooleanParameter, self).__init__( - name, argument_object, operation_object, default=_NOT_SPECIFIED) + name, argument_model, operation_object, default=_NOT_SPECIFIED) self._group_name = group_name self._positive_py_name = positive_py_name diff --git a/awscli/errorhandler.py b/awscli/errorhandler.py index 1f8eba4c3fd6..9aaf5a0df348 100644 --- a/awscli/errorhandler.py +++ b/awscli/errorhandler.py @@ -77,13 +77,7 @@ def __call__(self, http_response, parsed, operation, **kwargs): def _get_error_code_and_message(self, response): code = 'Unknown' message = 'Unknown' - if 'Errors' in response: - if isinstance(response['Errors'], list): - error = response['Errors'][-1] - if 'Code' in error: - code = error['Code'] - elif 'Type' in error: - code = error['Type'] - if 'Message' in error: - message = error['Message'] + if 'Error' in response: + error = response['Error'] + return error.get('Code', code), error.get('Message', message) return (code, message) diff --git a/awscli/examples/configure/_description.rst b/awscli/examples/configure/_description.rst index 040dd5b88e8c..a9c01ac3f3da 100644 --- a/awscli/examples/configure/_description.rst +++ b/awscli/examples/configure/_description.rst @@ -10,6 +10,11 @@ When you are prompted for information, the current value will be displayed in config file. It does not use any configuration values from environment variables or the IAM role. +Note: the values you provide for the AWS Access Key ID and the AWS Secret +Access Key will be written to the shared credentials file +(``~/.aws/credentials``). + + ======================= Configuration Variables ======================= diff --git a/awscli/examples/configure/set/_description.rst b/awscli/examples/configure/set/_description.rst index b4bad9e53829..b915e39680cf 100644 --- a/awscli/examples/configure/set/_description.rst +++ b/awscli/examples/configure/set/_description.rst @@ -11,3 +11,8 @@ configuration value. If the config file does not exist, one will automatically be created. If the configuration value already exists in the config file, it will updated with the new configuration value. + +Setting a value for the ``aws_access_key_id``, ``aws_secret_access_key``, or +the ``aws_session_token`` will result in the value being writen to the +shared credentials file (``~/.aws/credentials``). All other values will +be written to the config file (default location is ``~/.aws/config``). diff --git a/awscli/examples/s3/cp.rst b/awscli/examples/s3/cp.rst index 6bdf25dc0dc1..1fe488bc7751 100644 --- a/awscli/examples/s3/cp.rst +++ b/awscli/examples/s3/cp.rst @@ -101,3 +101,15 @@ Output:: upload: file.txt to s3://mybucket/file.txt +**Uploading a local file stream to S3** + +The following ``cp`` command uploads a local file stream from standard input to a specified bucket and key:: + + aws s3 cp - s3://mybucket/stream.txt + + +**Downloading a S3 object as a local file stream** + +The following ``cp`` command downloads a S3 object locally as a stream to standard output:: + + aws s3 cp s3://mybucket/stream.txt - diff --git a/awscli/formatter.py b/awscli/formatter.py index ca4c445f6ac5..385927bdae40 100644 --- a/awscli/formatter.py +++ b/awscli/formatter.py @@ -19,6 +19,7 @@ from awscli.table import MultiTable, Styler, ColorizedStyler from awscli import text from awscli import compat +from awscli.utils import json_encoder LOG = logging.getLogger(__name__) @@ -83,7 +84,7 @@ def _format_response(self, operation, response, stream): # that out to the user but other "falsey" values like an empty # dictionary should be printed. if response: - json.dump(response, stream, indent=4) + json.dump(response, stream, indent=4, default=json_encoder) stream.write('\n') diff --git a/awscli/handlers.py b/awscli/handlers.py index 6d99892b9d9e..4d1d2aaaac42 100644 --- a/awscli/handlers.py +++ b/awscli/handlers.py @@ -55,15 +55,15 @@ def awscli_initialize(event_handlers): event_handlers.register('process-cli-arg', param_shorthand) error_handler = ErrorHandler() event_handlers.register('after-call.*.*', error_handler) - # The following will get fired for every option we are - # documenting. It will attempt to add an example_fn on to - # the parameter object if the parameter supports shorthand - # syntax. The documentation event handlers will then use - # the examplefn to generate the sample shorthand syntax - # in the docs. Registering here should ensure that this - # handler gets called first but it still feels a bit brittle. - event_handlers.register('doc-option-example.*.*.*', - param_shorthand.add_example_fn) +# # The following will get fired for every option we are +# # documenting. It will attempt to add an example_fn on to +# # the parameter object if the parameter supports shorthand +# # syntax. The documentation event handlers will then use +# # the examplefn to generate the sample shorthand syntax +# # in the docs. Registering here should ensure that this +# # handler gets called first but it still feels a bit brittle. +# event_handlers.register('doc-option-example.*.*.*', +# param_shorthand.add_example_fn) event_handlers.register('doc-examples.*.*', add_examples) event_handlers.register('building-argument-table.s3api.*', diff --git a/awscli/paramfile.py b/awscli/paramfile.py index bda5e72dae99..45eb283a5f69 100644 --- a/awscli/paramfile.py +++ b/awscli/paramfile.py @@ -23,6 +23,42 @@ logger = logging.getLogger(__name__) +# These are special cased arguments that do _not_ get the +# special param file processing. This is typically because it +# refers to an actual URI of some sort and we don't want to actually +# download the content (i.e TemplateURL in cloudformation). +PARAMFILE_DISABLED = set([ + 'cloudformation.create-stack.template-url', + 'cloudformation.update-stack.template-url', + 'cloudformation.validate-template.template-url', + 'cloudformation.estimate-template-cost.template-url', + + 'cloudformation.create-stack.stack-policy-url', + 'cloudformation.update-stack.stack-policy-url', + 'cloudformation.set-stack-policy.stack-policy-url', + + 'sqs.add-permission.queue-url', + 'sqs.change-message-visibility.queue-url', + 'sqs.change-message-visibility-batch.queue-url', + 'sqs.delete-message.queue-url', + 'sqs.delete-message-batch.queue-url', + 'sqs.delete-queue.queue-url', + 'sqs.get-queue-attributes.queue-url', + 'sqs.list-dead-letter-source-queues.queue-url', + 'sqs.receive-message.queue-url', + 'sqs.remove-permission.queue-url', + 'sqs.send-message.queue-url', + 'sqs.send-message-batch.queue-url', + 'sqs.set-queue-attributes.queue-url', + + 's3.copy-object.website-redirect-location', + 's3.create-multipart-upload.website-redirect-location', + 's3.put-object.website-redirect-location', + + # Double check that this has been renamed! + 'sns.subscribe.notification-endpoint', +]) + class ResourceLoadingError(Exception): pass diff --git a/awscli/schema.py b/awscli/schema.py index 8ca6dc7f4e2d..1bf9faab1745 100644 --- a/awscli/schema.py +++ b/awscli/schema.py @@ -10,8 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +from collections import defaultdict -class ParameterRequiredError(ValueError): pass + +class ParameterRequiredError(ValueError): + pass class SchemaTransformer(object): @@ -26,8 +29,7 @@ class SchemaTransformer(object): Only a relevant subset of features is supported here: - * Types: `object`, `array`, `string`, `number`, `integer`, - `boolean` + * Types: `object`, `array`, `string`, `integer`, `boolean` * Properties: `type`, `description`, `required`, `enum` For example:: @@ -61,62 +63,98 @@ class SchemaTransformer(object): $ aws foo bar --baz arg1=Value1,arg2=5 arg1=Value2 """ - # Map schema types to internal representation types - TYPE_MAP = { - 'array': 'list', + JSON_SCHEMA_TO_AWS_TYPES = { 'object': 'structure', + 'array': 'list', } - # Map schema properties to internal representation properties - PROPERTY_MAP = { - # List item description - 'items': 'members', - # Object properties description - 'properties': 'members', - } - - # List of known properties to copy or transform without any - # other special processing. - SUPPORTED_BASIC_PROPERTIES = [ - 'type', 'description', 'required', 'enum' - ] - - def __init__(self, schema): - self.schema = schema - - def transform(self): - """Convert to an internal representation of parameters""" - return self._process_param(self.schema) - - def _process_param(self, param): - transformed = {} - - if 'type' not in param: - raise ParameterRequiredError( - 'The type property is required: {0}'.format(param)) - - # Handle basic properties which are just copied and optionally - # mapped to new values. - for basic_property in self.SUPPORTED_BASIC_PROPERTIES: - if basic_property in param: - value = param[basic_property] - - if basic_property == 'type': - value = self.TYPE_MAP.get(value, value) - - mapped = self.PROPERTY_MAP.get(basic_property, basic_property) - transformed[mapped] = value - - # Handle complex properties - if 'items' in param: - mapped = self.PROPERTY_MAP.get('items', 'items') - transformed[mapped] = self._process_param(param['items']) - - if 'properties' in param: - mapped = self.PROPERTY_MAP.get('properties', 'properties') - transformed[mapped] = {} - - for key, value in param['properties'].items(): - transformed[mapped][key] = self._process_param(value) - - return transformed + def __init__(self): + self._shape_namer = ShapeNameGenerator() + + def transform(self, schema): + """Convert JSON schema to the format used internally by the AWS CLI. + + :type schema: dict + :param schema: The JSON schema describing the argument model. + + :rtype: dict + :return: The transformed model in a form that can be consumed + internally by the AWS CLI. The dictionary returned will + have a list of shapes, where the shape representing the + transformed schema is always named ``InputShape`` in the + returned dictionary. + + """ + shapes = {} + self._transform(schema, shapes, 'InputShape') + return shapes + + def _transform(self, schema, shapes, shape_name): + if 'type' not in schema: + raise ParameterRequiredError("Missing required key: 'type'") + if schema['type'] == 'object': + shapes[shape_name] = self._transform_structure(schema, shapes) + elif schema['type'] == 'array': + shapes[shape_name] = self._transform_list(schema, shapes) + else: + shapes[shape_name] = self._transform_scalar(schema) + return shapes + + def _transform_scalar(self, schema): + return self._populate_initial_shape(schema) + + def _transform_structure(self, schema, shapes): + # Transforming a structure involves: + # 1. Generating the shape definition for the structure + # 2. Generating the shape definitions for its members + structure_shape = self._populate_initial_shape(schema) + members = {} + required_members = [] + + for key, value in schema['properties'].items(): + current_type_name = self._json_schema_to_aws_type(value) + current_shape_name = self._shape_namer.new_shape_name( + current_type_name) + members[key] = {'shape': current_shape_name} + if value.get('required', False): + required_members.append(key) + self._transform(value, shapes, current_shape_name) + structure_shape['members'] = members + if required_members: + structure_shape['required'] = required_members + return structure_shape + + def _transform_list(self, schema, shapes): + # Transforming a structure involves: + # 1. Generating the shape definition for the structure + # 2. Generating the shape definitions for its 'items' member + list_shape = self._populate_initial_shape(schema) + member_type = self._json_schema_to_aws_type(schema['items']) + member_shape_name = self._shape_namer.new_shape_name(member_type) + list_shape['member'] = {'shape': member_shape_name} + self._transform(schema['items'], shapes, member_shape_name) + return list_shape + + def _populate_initial_shape(self, schema): + shape = {'type': self._json_schema_to_aws_type(schema)} + if 'description' in schema: + shape['documentation'] = schema['description'] + if 'enum' in schema: + shape['enum'] = schema['enum'] + return shape + + def _json_schema_to_aws_type(self, schema): + if 'type' not in schema: + raise ParameterRequiredError("Missing required key: 'type'") + type_name = schema['type'] + return self.JSON_SCHEMA_TO_AWS_TYPES.get(type_name, type_name) + + +class ShapeNameGenerator(object): + def __init__(self): + self._name_cache = defaultdict(int) + + def new_shape_name(self, type_name): + self._name_cache[type_name] += 1 + current_index = self._name_cache[type_name] + return '%sType%s' % (type_name.capitalize(), current_index) diff --git a/awscli/testutils.py b/awscli/testutils.py index b6f4bd2abcc0..08e731361fe2 100644 --- a/awscli/testutils.py +++ b/awscli/testutils.py @@ -30,6 +30,7 @@ import tempfile import platform import contextlib +from pprint import pformat from subprocess import Popen, PIPE try: @@ -245,7 +246,8 @@ def before_call(self, params, **kwargs): self._store_params(params) def _store_params(self, params): - self.last_params = params + self.last_request_dict = params + self.last_params = params['body'] def patch_make_request(self): make_request_patch = self.make_request_patch.start() @@ -262,14 +264,54 @@ def assert_params_for_cmd(self, cmd, params=None, expected_rc=0, if stderr_contains is not None: self.assertIn(stderr_contains, stderr) if params is not None: - last_params = copy.copy(self.last_params) - if ignore_params is not None: + last_params = self.last_params + if isinstance(last_params, dict): + last_params = copy.copy(self.last_params) + extra_params_to_ignore = ['Action', 'Version'] + if ignore_params is None: + ignore_params = extra_params_to_ignore + else: + ignore_params.extend(extra_params_to_ignore) for key in ignore_params: try: del last_params[key] except KeyError: pass - self.assertDictEqual(params, last_params) + if params != last_params: + self.fail("Actual params did not match expected params.\n" + "Expected:\n\n" + "%s\n" + "Actual:\n\n%s\n" % ( + pformat(params), pformat(last_params))) + return stdout, stderr, rc + + def assert_params_for_cmd2(self, cmd, params=None, expected_rc=0, + stderr_contains=None, ignore_params=None): + # XXX: This has a terrible name because it's intended to be + # temporary. I want to switch everything off of + # assert_params_for_cmd and then I'll rename this to + # assert_params_for_cmd2. The difference between this command + # and the other one is that we verify the kwargs that are sent + # to botocore's Operation.call(), *not* the serialized parameters + # onto the HTTP request. We're one level up from that. + stdout, stderr, rc = self.run_cmd(cmd, expected_rc) + if stderr_contains is not None: + self.assertIn(stderr_contains, stderr) + if params is not None: + # The last kwargs of Operation.call() in botocore. + last_kwargs = copy.copy(self.last_kwargs) + if ignore_params is not None: + for key in ignore_params: + try: + del last_kwargs[key] + except KeyError: + pass + if params != last_kwargs: + self.fail("Actual params did not match expected params.\n" + "Expected:\n\n" + "%s\n" + "Actual:\n\n%s\n" % ( + pformat(params), pformat(last_kwargs))) return stdout, stderr, rc def before_parameter_build(self, params, operation, **kwargs): @@ -395,7 +437,7 @@ def _escape_quotes(command): def aws(command, collect_memory=False, env_vars=None, - wait_for_finish=True): + wait_for_finish=True, input_data=None, input_file=None): """Run an aws command. This help function abstracts the differences of running the "aws" @@ -413,6 +455,19 @@ def aws(command, collect_memory=False, env_vars=None, proper cleanup. This can be useful if you want to test timeout's or how the CLI responds to various signals. + :type input_data: string + :param input_data: This string will be communicated to the process through + the stdin of the process. It essentially allows the user to + avoid having to use a file handle to pass information to the process. + Note that this string is not passed on creation of the process, but + rather communicated to the process. + + :type input_file: a file handle + :param input_file: This is a file handle that will act as the + the stdin of the process immediately on creation. Essentially + any data written to the file will be read from stdin of the + process. This is needed if you plan to stream data into stdin while + collecting memory. """ if platform.system() == 'Windows': command = _escape_quotes(command) @@ -421,7 +476,7 @@ def aws(command, collect_memory=False, env_vars=None, else: aws_command = 'python %s' % get_aws_cmd() full_command = '%s %s' % (aws_command, command) - stdout_encoding = _get_stdout_encoding() + stdout_encoding = get_stdout_encoding() if isinstance(full_command, six.text_type) and not six.PY3: full_command = full_command.encode(stdout_encoding) INTEG_LOG.debug("Running command: %s", full_command) @@ -429,13 +484,18 @@ def aws(command, collect_memory=False, env_vars=None, env['AWS_DEFAULT_REGION'] = "us-east-1" if env_vars is not None: env = env_vars - process = Popen(full_command, stdout=PIPE, stderr=PIPE, shell=True, - env=env) + if input_file is None: + input_file = PIPE + process = Popen(full_command, stdout=PIPE, stderr=PIPE, stdin=input_file, + shell=True, env=env) if not wait_for_finish: return process memory = None if not collect_memory: - stdout, stderr = process.communicate() + kwargs = {} + if input_data: + kwargs = {'input': input_data} + stdout, stderr = process.communicate(**kwargs) else: stdout, stderr, memory = _wait_and_collect_mem(process) return Result(process.returncode, @@ -444,7 +504,7 @@ def aws(command, collect_memory=False, env_vars=None, memory) -def _get_stdout_encoding(): +def get_stdout_encoding(): encoding = getattr(sys.__stdout__, 'encoding', None) if encoding is None: encoding = 'utf-8' diff --git a/awscli/utils.py b/awscli/utils.py index bd37243fcbf5..de3fad99a4c1 100644 --- a/awscli/utils.py +++ b/awscli/utils.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import csv +import datetime import six @@ -104,3 +105,13 @@ def _find_quote_char_in_part(part): elif single_quote < double_quote: quote_char = "'" return quote_char + + +def json_encoder(obj): + """JSON encoder that formats datetimes as ISO8601 format.""" + if isinstance(obj, datetime.datetime): + return obj.isoformat() + else: + return obj + + diff --git a/doc/source/conf.py b/doc/source/conf.py index 266341c53c60..cdddac60f20b 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -50,9 +50,9 @@ # built documents. # # The short X.Y version. -version = '1.4' +version = '1.5' # The full version, including alpha/beta/rc tags. -release = '1.4.4' +release = '1.5.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/setup.py b/setup.py index df0c0059e4dc..e38bdfd00594 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ import awscli -requires = ['botocore>=0.63.0,<0.64.0', +requires = ['botocore>=0.64.0,<0.65.0', 'bcdoc>=0.12.0,<0.13.0', 'six>=1.1.0', 'colorama==0.2.5', diff --git a/tests/integration/customizations/s3/test_plugin.py b/tests/integration/customizations/s3/test_plugin.py index f2da15b9ae57..4512f797dd8c 100644 --- a/tests/integration/customizations/s3/test_plugin.py +++ b/tests/integration/customizations/s3/test_plugin.py @@ -28,7 +28,7 @@ import botocore.session import six -from awscli.testutils import unittest, FileCreator +from awscli.testutils import unittest, FileCreator, get_stdout_encoding from awscli.testutils import aws as _aws from tests.unit.customizations.s3 import create_bucket as _create_bucket from awscli.customizations.s3 import constants @@ -44,12 +44,14 @@ def cd(directory): os.chdir(original) -def aws(command, collect_memory=False, env_vars=None, wait_for_finish=True): +def aws(command, collect_memory=False, env_vars=None, wait_for_finish=True, + input_data=None, input_file=None): if not env_vars: env_vars = os.environ.copy() env_vars['AWS_DEFAULT_REGION'] = "us-west-2" return _aws(command, collect_memory=collect_memory, env_vars=env_vars, - wait_for_finish=wait_for_finish) + wait_for_finish=wait_for_finish, input_data=input_data, + input_file=input_file) class BaseS3CLICommand(unittest.TestCase): @@ -164,10 +166,10 @@ def assert_no_errors(self, p): self.assertEqual( p.rc, 0, "Non zero rc (%s) received: %s" % (p.rc, p.stdout + p.stderr)) - self.assertNotIn("Error:", p.stdout) - self.assertNotIn("failed:", p.stdout) - self.assertNotIn("client error", p.stdout) - self.assertNotIn("server error", p.stdout) + self.assertNotIn("Error:", p.stderr) + self.assertNotIn("failed:", p.stderr) + self.assertNotIn("client error", p.stderr) + self.assertNotIn("server error", p.stderr) class TestMoveCommand(BaseS3CLICommand): @@ -458,7 +460,7 @@ def test_download_non_existent_key(self): expected_err_msg = ( 'A client error (NoSuchKey) occurred when calling the ' 'HeadObject operation: Key "foo.txt" does not exist') - self.assertIn(expected_err_msg, p.stdout) + self.assertIn(expected_err_msg, p.stderr) class TestSync(BaseS3CLICommand): @@ -645,7 +647,7 @@ def testFailWithoutRegion(self): p2 = aws('s3 sync s3://%s/ s3://%s/ --region %s' % (self.src_bucket, self.dest_bucket, self.src_region)) self.assertEqual(p2.rc, 1, p2.stdout) - self.assertIn('PermanentRedirect', p2.stdout) + self.assertIn('PermanentRedirect', p2.stderr) def testCpRegion(self): self.files.create_file('foo.txt', 'foo') @@ -695,9 +697,9 @@ def extra_setup(self): def test_no_exist(self): filename = os.path.join(self.files.rootdir, "no-exists-file") p = aws('s3 cp %s s3://%s/' % (filename, self.bucket_name)) - self.assertEqual(p.rc, 2, p.stdout) + self.assertEqual(p.rc, 2, p.stderr) self.assertIn('warning: Skipping file %s. File does not exist.' % - filename, p.stdout) + filename, p.stderr) @unittest.skipIf(platform.system() not in ['Darwin', 'Linux'], 'Read permissions tests only supported on mac/linux') @@ -711,9 +713,9 @@ def test_no_read_access(self): permissions = permissions ^ stat.S_IREAD os.chmod(filename, permissions) p = aws('s3 cp %s s3://%s/' % (filename, self.bucket_name)) - self.assertEqual(p.rc, 2, p.stdout) + self.assertEqual(p.rc, 2, p.stderr) self.assertIn('warning: Skipping file %s. File/Directory is ' - 'not readable.' % filename, p.stdout) + 'not readable.' % filename, p.stderr) @unittest.skipIf(platform.system() not in ['Darwin', 'Linux'], 'Special files only supported on mac/linux') @@ -723,10 +725,10 @@ def test_is_special_file(self): sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.bind(file_path) p = aws('s3 cp %s s3://%s/' % (file_path, self.bucket_name)) - self.assertEqual(p.rc, 2, p.stdout) + self.assertEqual(p.rc, 2, p.stderr) self.assertIn(("warning: Skipping file %s. File is character " "special device, block special device, FIFO, or " - "socket." % file_path), p.stdout) + "socket." % file_path), p.stderr) @unittest.skipIf(platform.system() not in ['Darwin', 'Linux'], @@ -806,10 +808,10 @@ def test_follow_symlinks_default(self): def test_bad_symlink(self): p = aws('s3 sync %s s3://%s/' % (self.files.rootdir, self.bucket_name)) - self.assertEqual(p.rc, 2, p.stdout) + self.assertEqual(p.rc, 2, p.stderr) self.assertIn('warning: Skipping file %s. File does not exist.' % os.path.join(self.files.rootdir, 'b-badsymlink'), - p.stdout) + p.stderr) class TestUnicode(BaseS3CLICommand): @@ -945,10 +947,110 @@ def test_mb_rb(self): def test_fail_mb_rb(self): # Choose a bucket name that already exists. p = aws('s3 mb s3://mybucket') - self.assertIn("BucketAlreadyExists", p.stdout) + self.assertIn("BucketAlreadyExists", p.stderr) self.assertEqual(p.rc, 1) +class TestOutput(BaseS3CLICommand): + """ + This ensures that arguments that affect output i.e. ``--quiet`` and + ``--only-show-errors`` behave as expected. + """ + def test_normal_output(self): + # Make a bucket. + bucket_name = self.create_bucket() + foo_txt = self.files.create_file('foo.txt', 'foo contents') + + # Copy file into bucket. + p = aws('s3 cp %s s3://%s/' % (foo_txt, bucket_name)) + self.assertEqual(p.rc, 0) + # Check that there were no errors and that parts of the expected + # progress message are written to stdout. + self.assert_no_errors(p) + self.assertIn('upload', p.stdout) + self.assertIn('s3://%s/foo.txt' % bucket_name, p.stdout) + + def test_normal_output_quiet(self): + # Make a bucket. + bucket_name = self.create_bucket() + foo_txt = self.files.create_file('foo.txt', 'foo contents') + + # Copy file into bucket. + p = aws('s3 cp %s s3://%s/ --quiet' % (foo_txt, bucket_name)) + self.assertEqual(p.rc, 0) + # Check that nothing was printed to stdout. + self.assertEqual('', p.stdout) + + def test_normal_output_only_show_errors(self): + # Make a bucket. + bucket_name = self.create_bucket() + foo_txt = self.files.create_file('foo.txt', 'foo contents') + + # Copy file into bucket. + p = aws('s3 cp %s s3://%s/ --only-show-errors' % (foo_txt, bucket_name)) + self.assertEqual(p.rc, 0) + # Check that nothing was printed to stdout. + self.assertEqual('', p.stdout) + + def test_error_output(self): + foo_txt = self.files.create_file('foo.txt', 'foo contents') + + # Copy file into bucket. + p = aws('s3 cp %s s3://non-existant-bucket/' % foo_txt) + # Check that there were errors and that the error was print to stderr. + self.assertEqual(p.rc, 1) + self.assertIn('upload failed', p.stderr) + + def test_error_ouput_quiet(self): + foo_txt = self.files.create_file('foo.txt', 'foo contents') + + # Copy file into bucket. + p = aws('s3 cp %s s3://non-existant-bucket/ --quiet' % foo_txt) + # Check that there were errors and that the error was not + # print to stderr. + self.assertEqual(p.rc, 1) + self.assertEqual('', p.stderr) + + def test_error_ouput_only_show_errors(self): + foo_txt = self.files.create_file('foo.txt', 'foo contents') + + # Copy file into bucket. + p = aws('s3 cp %s s3://non-existant-bucket/ --only-show-errors' + % foo_txt) + # Check that there were errors and that the error was print to stderr. + self.assertEqual(p.rc, 1) + self.assertIn('upload failed', p.stderr) + + def test_error_and_success_output_only_show_errors(self): + # Make a bucket. + bucket_name = self.create_bucket() + + # Create one file. + self.files.create_file('f', 'foo contents') + + # Create another file that has a slightly longer name than the first. + self.files.create_file('bar.txt', 'bar contents') + + # Create a prefix that will cause the second created file to have a key + # longer than 1024 bytes which is not allowed in s3. + long_prefix = 'd' * 1022 + + p = aws('s3 cp %s s3://%s/%s/ --only-show-errors --recursive' + % (self.files.rootdir, bucket_name, long_prefix)) + + # Check that there was at least one error. + self.assertEqual(p.rc, 1) + + # Check that there was nothing written to stdout for successful upload. + self.assertEqual('', p.stdout) + + # Check that the failed message showed up in stderr. + self.assertIn('upload failed', p.stderr) + + # Ensure the expected successful key exists in the bucket. + self.assertTrue(self.key_exists(bucket_name, long_prefix + '/f')) + + class TestDryrun(BaseS3CLICommand): """ This ensures that dryrun works. @@ -1007,7 +1109,7 @@ def extra_setup(self): def assert_max_memory_used(self, process, max_mem_allowed, full_command): peak_memory = max(process.memory_usage) - if peak_memory > self.max_mem_allowed: + if peak_memory > max_mem_allowed: failure_message = ( 'Exceeded max memory allowed (%s MB) for command ' '"%s": %s MB' % (self.max_mem_allowed / 1024.0 / 1024.0, @@ -1033,6 +1135,45 @@ def test_transfer_single_large_file(self): self.assert_max_memory_used(p, self.max_mem_allowed, download_full_command) + def test_stream_large_file(self): + """ + This tests to ensure that streaming files for both uploads and + downloads do not use too much memory. Note that streaming uploads + will use slightly more memory than usual but should not put the + entire file into memory. + """ + bucket_name = self.create_bucket() + + # Create a 200 MB file that will be streamed + num_mb = 200 + foo_txt = self.files.create_file('foo.txt', '') + with open(foo_txt, 'wb') as f: + for i in range(num_mb): + f.write(b'a' * 1024 * 1024) + + # The current memory threshold is set at about the peak amount for + # performing a streaming upload of a file larger than 100 MB. So + # this maximum needs to be bumped up. The maximum memory allowance + # is increased by two chunksizes because that is the maximum + # amount of chunks that will be queued while not being operated on + # by a thread when performing a streaming multipart upload. + max_mem_allowed = self.max_mem_allowed + 2 * constants.CHUNKSIZE + + full_command = 's3 cp - s3://%s/foo.txt' % bucket_name + with open(foo_txt, 'rb') as f: + p = aws(full_command, input_file=f, collect_memory=True) + self.assert_no_errors(p) + self.assert_max_memory_used(p, max_mem_allowed, full_command) + + # Now perform a streaming download of the file. + full_command = 's3 cp s3://%s/foo.txt - > %s' % (bucket_name, foo_txt) + p = aws(full_command, collect_memory=True) + self.assert_no_errors(p) + # Use the ususal bar for maximum memory usage since a streaming + # download's memory usage should be comparable to non-streaming + # transfers. + self.assert_max_memory_used(p, self.max_mem_allowed, full_command) + class TestWebsiteConfiguration(BaseS3CLICommand): def test_create_website_index_configuration(self): @@ -1048,9 +1189,9 @@ def test_create_website_index_configuration(self): parsed = operation.call( self.endpoint, bucket=bucket_name)[1] self.assertEqual(parsed['IndexDocument']['Suffix'], 'index.html') - self.assertEqual(parsed['ErrorDocument'], {}) - self.assertEqual(parsed['RoutingRules'], []) - self.assertEqual(parsed['RedirectAllRequestsTo'], {}) + self.assertNotIn('ErrorDocument', parsed) + self.assertNotIn('RoutingRules', parsed) + self.assertNotIn('RedirectAllRequestsTo', parsed) def test_create_website_index_and_error_configuration(self): bucket_name = self.create_bucket() @@ -1065,8 +1206,8 @@ def test_create_website_index_and_error_configuration(self): self.endpoint, bucket=bucket_name)[1] self.assertEqual(parsed['IndexDocument']['Suffix'], 'index.html') self.assertEqual(parsed['ErrorDocument']['Key'], 'error.html') - self.assertEqual(parsed['RoutingRules'], []) - self.assertEqual(parsed['RedirectAllRequestsTo'], {}) + self.assertNotIn('RoutingRules', parsed) + self.assertNotIn('RedirectAllRequestsTo', parsed) class TestIncludeExcludeFilters(BaseS3CLICommand): @@ -1231,5 +1372,99 @@ def test_sync_file_with_spaces(self): self.assertEqual(p2.rc, 0) +class TestStreams(BaseS3CLICommand): + def test_upload(self): + """ + This tests uploading a small stream from stdin. + """ + bucket_name = self.create_bucket() + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=b'This is a test') + self.assert_no_errors(p) + self.assertTrue(self.key_exists(bucket_name, 'stream')) + self.assertEqual(self.get_key_contents(bucket_name, 'stream'), + 'This is a test') + + def test_unicode_upload(self): + """ + This tests being able to upload unicode from stdin. + """ + unicode_str = u'\u00e9 This is a test' + byte_str = unicode_str.encode('utf-8') + bucket_name = self.create_bucket() + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=byte_str) + self.assert_no_errors(p) + self.assertTrue(self.key_exists(bucket_name, 'stream')) + self.assertEqual(self.get_key_contents(bucket_name, 'stream'), + unicode_str) + + def test_multipart_upload(self): + """ + This tests the ability to multipart upload streams from stdin. + The data has some unicode in it to avoid having to do a seperate + multipart upload test just for unicode. + """ + + bucket_name = self.create_bucket() + data = u'\u00e9bcd' * (1024 * 1024 * 10) + data_encoded = data.encode('utf-8') + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=data_encoded) + self.assert_no_errors(p) + self.assertTrue(self.key_exists(bucket_name, 'stream')) + self.assert_key_contents_equal(bucket_name, 'stream', data) + + def test_download(self): + """ + This tests downloading a small stream from stdout. + """ + bucket_name = self.create_bucket() + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=b'This is a test') + self.assert_no_errors(p) + + p = aws('s3 cp s3://%s/stream -' % bucket_name) + self.assert_no_errors(p) + self.assertEqual(p.stdout, 'This is a test') + + def test_unicode_download(self): + """ + This tests downloading a small unicode stream from stdout. + """ + bucket_name = self.create_bucket() + + data = u'\u00e9 This is a test' + data_encoded = data.encode('utf-8') + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=data_encoded) + self.assert_no_errors(p) + + # Downloading the unicode stream to standard out. + p = aws('s3 cp s3://%s/stream -' % bucket_name) + self.assert_no_errors(p) + self.assertEqual(p.stdout, data_encoded.decode(get_stdout_encoding())) + + def test_multipart_download(self): + """ + This tests the ability to multipart download streams to stdout. + The data has some unicode in it to avoid having to do a seperate + multipart download test just for unicode. + """ + bucket_name = self.create_bucket() + + # First lets upload some data via streaming since + # its faster and we do not have to write to a file! + data = u'\u00e9bcd' * (1024 * 1024 * 10) + data_encoded = data.encode('utf-8') + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=data_encoded) + + # Download the unicode stream to standard out. + p = aws('s3 cp s3://%s/stream -' % bucket_name) + self.assert_no_errors(p) + self.assertEqual(p.stdout, data_encoded.decode(get_stdout_encoding())) + + if __name__ == "__main__": unittest.main() diff --git a/tests/integration/customizations/s3/test_s3handler.py b/tests/integration/customizations/s3/test_s3handler.py index 4d46a2381173..955f428ded7a 100644 --- a/tests/integration/customizations/s3/test_s3handler.py +++ b/tests/integration/customizations/s3/test_s3handler.py @@ -163,12 +163,12 @@ def setUp(self): self.s3_files = [self.bucket + '/text1.txt', self.bucket + '/another_directory/text2.txt'] self.output = StringIO() - self.saved_stdout = sys.stdout - sys.stdout = self.output + self.saved_stderr = sys.stderr + sys.stderr = self.output def tearDown(self): self.output.close() - sys.stdout = self.saved_stdout + sys.stderr = self.saved_stderr clean_loc_files(self.loc_files) s3_cleanup(self.bucket, self.session) @@ -215,7 +215,7 @@ def setUp(self): self.session = botocore.session.get_session(EnvironmentVariables) self.service = self.session.get_service('s3') self.endpoint = self.service.get_endpoint('us-east-1') - params = {'region': 'us-east-1', 'acl': ['private']} + params = {'region': 'us-east-1', 'acl': ['private'], 'quiet': True} self.s3_handler = S3Handler(self.session, params) self.bucket = make_s3_files(self.session, key1=u'\u2713') self.bucket2 = create_bucket(self.session) diff --git a/tests/integration/test_smoke.py b/tests/integration/test_smoke.py new file mode 100644 index 000000000000..f07f3a23bf12 --- /dev/null +++ b/tests/integration/test_smoke.py @@ -0,0 +1,126 @@ +# Copyright 2014 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import re +import random +from nose.tools import assert_equal + +from awscli.testutils import aws + + +# These are a list of commands that we should run. +# We're just verifying that we can properly send a no-arg request +# and that we can parse any response that comes back. +COMMANDS = [ + 'autoscaling describe-account-limits', + 'autoscaling describe-adjustment-types', + 'cloudformation describe-stacks', + 'cloudformation list-stacks', + 'cloudsearch describe-domains', + 'cloudsearch list-domain-names', + 'cloudtrail describe-trails', + 'cloudwatch list-metrics', + 'cognito-identity list-identity-pools --max-results 1', + 'datapipeline list-pipelines', + 'directconnect describe-connections', + 'dynamodb list-tables', + 'ec2 describe-instances', + 'ec2 describe-regions', + 'elasticache describe-cache-clusters', + 'elb describe-load-balancers', + 'emr list-clusters', + 'iam list-users', + 'kinesis list-streams', + 'logs describe-log-groups', + 'opsworks describe-stacks', + 'rds describe-db-instances', + 'redshift describe-clusters', + 'route53 list-hosted-zones', + 'route53domains list-domains', + 's3api list-buckets', + 's3 ls', + 'ses list-identities', + 'sns list-topics', + 'sqs list-queues', + 'storagegateway list-gateways', + 'swf list-domains --registration-status REGISTERED', +] + + +# A list of commands that generate error messages. The idea is to try to have +# at least one command for each service. +# +# This verifies that service errors are properly displayed to the user, as +# opposed to either silently failing or inproperly handling the error responses +# and not displaying something useful. Each command tries to call an operation +# with an identifier that does not exist, and part of the identifier is also +# randomly generated to help ensure that is the case. +ERROR_COMMANDS = [ + 'autoscaling attach-instances --auto-scaling-group-name %s', + 'cloudformation cancel-update-stack --stack-name %s', + 'cloudsearch describe-suggesters --domain-name %s', + 'cloudtrail get-trail-status --name %s', + 'cognito-identity delete-identity-pool --identity-pool-id %s', + 'datapipeline delete-pipeline --pipeline-id %s', + 'directconnect delete-connection --connection-id %s', + 'dynamodb delete-table --table-name %s', + 'ec2 terminate-instances --instance-ids %s', + 'elasticache delete-cache-cluster --cache-cluster-id %s', + 'elb describe-load-balancers --load-balancer-names %s', + 'emr list-instances --cluster-id %s', + 'iam delete-user --user-name %s', + 'kinesis delete-stream --stream-name %s', + 'logs delete-log-group --log-group-name %s', + 'opsworks delete-app --app-id %s', + 'rds delete-db-instance --db-instance-identifier %s', + 'redshift delete-cluster --cluster-identifier %s', + 'route53 delete-hosted-zone --id %s', + 'route53domains get-domain-detail --domain-name %s', + 's3api head-bucket --bucket %s', + 'ses set-identity-dkim-enabled --identity %s --dkim-enabled', + 'sns delete-endpoint --endpoint-arn %s', + 'sqs delete-queue --queue-url %s', + # --gateway-arn has min length client side validation + # so we have to generate an identifier that's long enough. + ('storagegateway delete-gateway --gateway-arn ' + 'foo-cli-test-foo-cli-test-foo-cli-test-%s'), + 'swf deprecate-domain --name %s', +] + + +def test_can_make_success_request(): + for cmd in COMMANDS: + yield _run_successful_aws_command, cmd + + +def _run_successful_aws_command(command_string): + result = aws(command_string) + assert_equal(result.rc, 0) + assert_equal(result.stderr, '') + + +def test_display_error_message(): + identifier = 'foo-awscli-test-%s' % random.randint(1000, 100000) + for cmd in ERROR_COMMANDS: + yield _run_error_aws_command, cmd % identifier + + +def _run_error_aws_command(command_string): + result = aws(command_string) + assert_equal(result.rc, 255) + error_message = re.compile( + 'A \w+ error \(.+\) occurred when calling the \w+ operation: \w+') + match = error_message.search(result.stderr) + if match is None: + raise AssertionError( + 'Error message was not displayed for command "%s": %s' % ( + command_string, result.stderr)) diff --git a/tests/unit/cloudsearch/test_cloudsearch.py b/tests/unit/cloudsearch/test_cloudsearch.py index ee988978aad0..61356ce6ffbe 100644 --- a/tests/unit/cloudsearch/test_cloudsearch.py +++ b/tests/unit/cloudsearch/test_cloudsearch.py @@ -46,7 +46,7 @@ def test_flattened(self): 'DomainName': 'abc123', 'IndexField.IndexFieldName': 'foo', 'IndexField.IndexFieldType': 'int', - 'IndexField.IntOptions.DefaultValue': '10', + 'IndexField.IntOptions.DefaultValue': 10, 'IndexField.IntOptions.SearchEnabled': 'false' } self.assert_params_for_cmd(cmdline, result) diff --git a/tests/unit/cloudwatch/test_put_metric_data.py b/tests/unit/cloudwatch/test_put_metric_data.py index 0487c21312c4..51aa3e833954 100644 --- a/tests/unit/cloudwatch/test_put_metric_data.py +++ b/tests/unit/cloudwatch/test_put_metric_data.py @@ -21,9 +21,9 @@ class TestPutMetricData(BaseAWSCommandParamsTest): expected_output = { 'MetricData.member.1.MetricName': 'FreeMemoryBytes', - 'MetricData.member.1.Timestamp': '2013-08-22T10:58:12.283000+00:00', + 'MetricData.member.1.Timestamp': '2013-08-22T10:58:12.283000Z', 'MetricData.member.1.Unit': 'Bytes', - 'MetricData.member.1.Value': '9130160128', + 'MetricData.member.1.Value': 9130160128, 'Namespace': '"Foo/Bar"' } diff --git a/tests/unit/customizations/datapipeline/test_arg_serialize.py b/tests/unit/customizations/datapipeline/test_arg_serialize.py index c9869e464ed4..8e4e649941c2 100644 --- a/tests/unit/customizations/datapipeline/test_arg_serialize.py +++ b/tests/unit/customizations/datapipeline/test_arg_serialize.py @@ -69,7 +69,7 @@ def test_put_pipeline_definition_with_json(self): }, ]}] } - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) class TestErrorMessages(BaseAWSCommandParamsTest): diff --git a/tests/unit/customizations/emr/test_add_instance_groups.py b/tests/unit/customizations/emr/test_add_instance_groups.py index bf6e473a4f68..dc15fa7da214 100644 --- a/tests/unit/customizations/emr/test_add_instance_groups.py +++ b/tests/unit/customizations/emr/test_add_instance_groups.py @@ -45,37 +45,35 @@ class TestAddInstanceGroups(BaseAWSCommandParamsTest): prefix = 'emr add-instance-groups --cluster-id J-ABCD --instance-groups' + def assert_error_message_has_field_name(self, error_msg, field_name): + self.assertIn('Missing required parameter', error_msg) + self.assertIn(field_name, error_msg) + def test_instance_groups_default_name_market(self): cmd = self.prefix cmd += ' InstanceGroupType=TASK,InstanceCount=10,InstanceType=m2.large' result = {'JobFlowId': 'J-ABCD', 'InstanceGroups': DEFAULT_INSTANCE_GROUPS} - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_instance_groups_missing_instance_group_type_error(self): cmd = self.prefix + ' Name=Task,InstanceType=m1.small,' +\ 'InstanceCount=5' - expect_error_msg = "\nThe following required parameters are missing" +\ - " for structure:: InstanceGroupType\n" result = self.run_cmd(cmd, 255) - self.assertEquals(expect_error_msg, result[1]) + self.assert_error_message_has_field_name(result[1], 'InstanceGroupType') def test_instance_groups_missing_instance_type_error(self): cmd = self.prefix + ' Name=Task,InstanceGroupType=Task,' +\ 'InstanceCount=5' - expect_error_msg = "\nThe following required parameters are missing" +\ - " for structure:: InstanceType\n" - result = self.run_cmd(cmd, 255) - self.assertEquals(expect_error_msg, result[1]) + stderr = self.run_cmd(cmd, 255)[1] + self.assert_error_message_has_field_name(stderr, 'InstanceType') def test_instance_groups_missing_instance_count_error(self): cmd = self.prefix + ' Name=Task,InstanceGroupType=Task,' +\ 'InstanceType=m1.xlarge' - expect_error_msg = "\nThe following required parameters are missing" +\ - " for structure:: InstanceCount\n" - result = self.run_cmd(cmd, 255) - self.assertEquals(expect_error_msg, result[1]) + stderr = self.run_cmd(cmd, 255)[1] + self.assert_error_message_has_field_name(stderr, 'InstanceCount') def test_instance_groups_all_fields(self): cmd = self.prefix + ' InstanceGroupType=MASTER,Name="MasterGroup",' +\ @@ -109,7 +107,7 @@ def test_instance_groups_all_fields(self): result = {'JobFlowId': 'J-ABCD', 'InstanceGroups': expected_instance_groups} - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) @patch('awscli.customizations.emr.emrutils.call') def test_constructed_result(self, call_patch): diff --git a/tests/unit/customizations/emr/test_add_steps.py b/tests/unit/customizations/emr/test_add_steps.py index e90088f07860..aa75f56e5779 100644 --- a/tests/unit/customizations/emr/test_add_steps.py +++ b/tests/unit/customizations/emr/test_add_steps.py @@ -113,7 +113,7 @@ def test_default_step_type_name_action_on_failure(self): } ] } - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_custom_jar_step_missing_jar(self): cmd = self.prefix + 'Name=CustomJarMissingJar' @@ -147,7 +147,7 @@ def test_custom_jar_step_with_all_fields(self): } ] } - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_streaming_step_with_default_fields(self): cmd = self.prefix + 'Type=Streaming,' + self.STREAMING_ARGS @@ -160,7 +160,7 @@ def test_streaming_step_with_default_fields(self): } ] } - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_streaming_step_missing_args(self): cmd = self.prefix + 'Type=Streaming' @@ -184,7 +184,7 @@ def test_streaming_jar_with_all_fields(self): } ] } - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_hive_step_with_default_fields(self): cmd = self.prefix + 'Type=Hive,' + self.HIVE_BASIC_ARGS @@ -196,7 +196,7 @@ def test_hive_step_with_default_fields(self): 'HadoopJarStep': self.HIVE_DEFAULT_HADOOP_JAR_STEP }] } - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_hive_step_missing_args(self): cmd = self.prefix + 'Type=Hive' @@ -220,7 +220,7 @@ def test_hive_step_with_all_fields(self): 'HadoopJarStep': self.HIVE_DEFAULT_HADOOP_JAR_STEP }] } - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_pig_step_with_default_fields(self): cmd = self.prefix + 'Type=Pig,' + self.PIG_BASIC_ARGS @@ -232,7 +232,7 @@ def test_pig_step_with_default_fields(self): 'HadoopJarStep': self.PIG_DEFAULT_HADOOP_JAR_STEP }] } - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_pig_missing_args(self): cmd = self.prefix + 'Type=Pig' @@ -257,7 +257,7 @@ def test_pig_step_with_all_fields(self): } ] } - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_impala_step_with_default_fields(self): test_step_config = 'Type=Impala,' + \ @@ -271,7 +271,7 @@ def test_impala_step_with_default_fields(self): 'HadoopJarStep': self.IMPALA_BASIC_HADOOP_JAR_STEP }] } - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_impala_missing_args(self): cmd = self.prefix + 'Type=Impala' @@ -296,7 +296,7 @@ def test_impala_step_with_all_fields(self): } ] } - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_all_step_types(self): test_step_config = 'Jar=s3://mybucket/mytest.jar ' + \ @@ -330,7 +330,7 @@ def test_all_step_types(self): } ] } - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_all_step_types_from_json(self): data_path = os.path.join( @@ -378,7 +378,7 @@ def test_all_step_types_from_json(self): } ] } - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) if __name__ == "__main__": unittest.main() diff --git a/tests/unit/customizations/emr/test_add_tags.py b/tests/unit/customizations/emr/test_add_tags.py index 6cbe9f1f0510..a81664c05c02 100644 --- a/tests/unit/customizations/emr/test_add_tags.py +++ b/tests/unit/customizations/emr/test_add_tags.py @@ -24,7 +24,7 @@ def test_add_tags_key_value(self): result = {'ResourceId': 'j-ABC123456', 'Tags': [{'Key': 'k1', 'Value': 'v1'}, {'Key': 'k2', 'Value': 'v2'}]} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_add_tags_key_with_empty_value(self): args = ' --resource-id j-ABC123456 --tags k1=v1 k2 k3=v3' @@ -33,7 +33,7 @@ def test_add_tags_key_with_empty_value(self): 'Tags': [{'Key': 'k1', 'Value': 'v1'}, {'Key': 'k2', 'Value': ''}, {'Key': 'k3', 'Value': 'v3'}]} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_add_tags_key_value_space(self): cmdline = ['emr', 'add-tags', '--resource-id', 'j-ABC123456', '--tags', @@ -42,7 +42,7 @@ def test_add_tags_key_value_space(self): 'Tags': [{'Key': 'k1', 'Value': 'v1'}, {'Key': 'k2', 'Value': ''}, {'Key': 'k3', 'Value': 'v3 v4'}]} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) if __name__ == "__main__": unittest.main() diff --git a/tests/unit/customizations/emr/test_create_cluster.py b/tests/unit/customizations/emr/test_create_cluster.py index 461fee776143..5d9707340dbf 100644 --- a/tests/unit/customizations/emr/test_create_cluster.py +++ b/tests/unit/customizations/emr/test_create_cluster.py @@ -339,8 +339,13 @@ def test_quick_start(self): } + + def assert_error_message_has_field_name(self, error_msg, field_name): + self.assertIn('Missing required parameter', error_msg) + self.assertIn(field_name, error_msg) + def test_default_cmd(self): - self.assert_params_for_cmd(DEFAULT_CMD, DEFAULT_RESULT) + self.assert_params_for_cmd2(DEFAULT_CMD, DEFAULT_RESULT) def test_cluster_without_service_role_and_instance_profile(self): cmd = ('emr create-cluster --ami-version 3.0.4 ' @@ -348,7 +353,7 @@ def test_cluster_without_service_role_and_instance_profile(self): result = copy.deepcopy(DEFAULT_RESULT) del result['JobFlowRole'] del result['ServiceRole'] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_cluster_with_service_role_and_instance_profile(self): cmd = ('emr create-cluster --ami-version 3.0.4' @@ -358,44 +363,44 @@ def test_cluster_with_service_role_and_instance_profile(self): result = copy.deepcopy(DEFAULT_RESULT) result['JobFlowRole'] = 'Ec2_InstanceProfile' result['ServiceRole'] = 'ServiceRole' - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_cluster_default_roles_overrides(self): cmd = (DEFAULT_CMD + '--service-role ServiceRole ' '--ec2-attributes InstanceProfile=Ec2_InstanceProfile') - self.assert_params_for_cmd(cmd, DEFAULT_RESULT) + self.assert_params_for_cmd2(cmd, DEFAULT_RESULT) def test_cluster_name_no_space(self): cmd = DEFAULT_CMD + '--name MyCluster' result = copy.deepcopy(DEFAULT_RESULT) result['Name'] = 'MyCluster' - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_cluster_name_with_space(self): cmd = DEFAULT_CMD.split() + ['--name', 'My Cluster'] result = copy.deepcopy(DEFAULT_RESULT) result['Name'] = 'My Cluster' - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_ami_version(self): cmd = DEFAULT_CMD + '--ami-version 3.0.4' result = copy.deepcopy(DEFAULT_RESULT) result['AmiVersion'] = '3.0.4' - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_log_uri(self): test_log_uri = 's3://test/logs' cmd = DEFAULT_CMD + '--log-uri ' + test_log_uri result = copy.deepcopy(DEFAULT_RESULT) result['LogUri'] = test_log_uri - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_additional_info(self): test_info = '{ami32: "ami-82e305f5"}' cmd = DEFAULT_CMD.split() + ['--additional-info', test_info] result = copy.deepcopy(DEFAULT_RESULT) result['AdditionalInfo'] = test_info - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_auto_terminte(self): cmd = ('emr create-cluster --use-default-roles --ami-version 3.0.4 ' @@ -405,7 +410,7 @@ def test_auto_terminte(self): instances = copy.deepcopy(DEFAULT_INSTANCES) instances['KeepJobFlowAliveWhenNoSteps'] = False result['Instances'] = instances - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_auto_terminate_and_no_auto_terminate(self): cmd = (DEFAULT_CMD + '--ami-version 3.0.4 ' + @@ -422,11 +427,11 @@ def test_termination_protected(self): instances = copy.deepcopy(DEFAULT_INSTANCES) instances['TerminationProtected'] = True result['Instances'] = instances - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_no_termination_protected(self): cmd = DEFAULT_CMD + '--no-termination-protected' - self.assert_params_for_cmd(cmd, DEFAULT_RESULT) + self.assert_params_for_cmd2(cmd, DEFAULT_RESULT) def test_termination_protected_and_no_termination_protected(self): cmd = DEFAULT_CMD + \ @@ -439,13 +444,13 @@ def test_termination_protected_and_no_termination_protected(self): def test_visible_to_all_users(self): cmd = DEFAULT_CMD + '--visible-to-all-users' - self.assert_params_for_cmd(cmd, DEFAULT_RESULT) + self.assert_params_for_cmd2(cmd, DEFAULT_RESULT) def test_no_visible_to_all_users(self): cmd = DEFAULT_CMD + '--no-visible-to-all-users' result = copy.deepcopy(DEFAULT_RESULT) result['VisibleToAllUsers'] = False - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_visible_to_all_users_and_no_visible_to_all_users(self): cmd = DEFAULT_CMD + '--visible-to-all-users --no-visible-to-all-users' @@ -462,7 +467,7 @@ def test_tags(self): {'Key': 'k2', 'Value': ''}, {'Key': 'k3', 'Value': 'spaces v3'}] result['Tags'] = tags - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_enable_debugging(self): cmd = DEFAULT_CMD + '--log-uri s3://test/logs --enable-debugging' @@ -481,7 +486,7 @@ def test_enable_debugging(self): } }] result['Steps'] = debugging_config + result['Steps'] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_enable_debugging_no_log_uri(self): cmd = DEFAULT_CMD + '--enable-debugging' @@ -507,7 +512,7 @@ def test_instance_groups_default_name_market(self): 'InstanceGroupType=MASTER,InstanceCount=1,InstanceType=m1.large ' 'InstanceGroupType=CORE,InstanceCount=1,InstanceType=m1.large ' 'InstanceGroupType=TASK,InstanceCount=1,InstanceType=m1.large ') - self.assert_params_for_cmd(cmd, DEFAULT_RESULT) + self.assert_params_for_cmd2(cmd, DEFAULT_RESULT) def test_instance_groups_instance_group_type_mismatch_cases(self): cmd = ( @@ -517,7 +522,7 @@ def test_instance_groups_instance_group_type_mismatch_cases(self): 'InstanceType=m1.large Name=CORE,InstanceGroupType=cORE,' 'InstanceCount=1,InstanceType=m1.large Name=TASK,' 'InstanceGroupType=tAsK,InstanceCount=1,InstanceType=m1.large') - self.assert_params_for_cmd(cmd, DEFAULT_RESULT) + self.assert_params_for_cmd2(cmd, DEFAULT_RESULT) def test_instance_groups_instance_type_and_count(self): cmd = ( @@ -534,7 +539,7 @@ def test_instance_groups_instance_type_and_count(self): 'Market': 'ON_DEMAND', 'InstanceType': 'm1.large'}] } - self.assert_params_for_cmd(cmd, expected_result) + self.assert_params_for_cmd2(cmd, expected_result) cmd = ( 'emr create-cluster --use-default-roles --ami-version 3.0.4 ' '--instance-type m1.large --instance-count 3') @@ -556,7 +561,7 @@ def test_instance_groups_instance_type_and_count(self): 'InstanceType': 'm1.large' }] } - self.assert_params_for_cmd(cmd, expected_result) + self.assert_params_for_cmd2(cmd, expected_result) def test_instance_groups_missing_required_parameter_error(self): cmd = ( @@ -609,11 +614,8 @@ def test_instance_groups_missing_instance_group_type_error(self): '--auto-terminate ' '--instance-groups ' 'Name=Master,InstanceCount=1,InstanceType=m1.small') - expect_error_msg = ( - '\nThe following required parameters are missing' - ' for structure:: InstanceGroupType\n') - result = self.run_cmd(cmd, 255) - self.assertEquals(expect_error_msg, result[1]) + stderr = self.run_cmd(cmd, 255)[1] + self.assert_error_message_has_field_name(stderr, 'InstanceGroupType') def test_instance_groups_missing_instance_type_error(self): cmd = ( @@ -624,8 +626,8 @@ def test_instance_groups_missing_instance_type_error(self): expect_error_msg = ( '\nThe following required parameters are missing' ' for structure:: InstanceType\n') - result = self.run_cmd(cmd, 255) - self.assertEquals(expect_error_msg, result[1]) + stderr = self.run_cmd(cmd, 255)[1] + self.assert_error_message_has_field_name(stderr, 'InstanceType') def test_instance_groups_missing_instance_count_error(self): cmd = ( @@ -633,11 +635,8 @@ def test_instance_groups_missing_instance_count_error(self): '--auto-terminate ' '--instance-groups ' 'Name=Master,InstanceGroupType=MASTER,InstanceType=m1.xlarge') - expect_error_msg = ( - '\nThe following required parameters are missing' - ' for structure:: InstanceCount\n') - result = self.run_cmd(cmd, 255) - self.assertEquals(expect_error_msg, result[1]) + stderr = self.run_cmd(cmd, 255)[1] + self.assert_error_message_has_field_name(stderr, 'InstanceCount') def test_instance_groups_from_json_file(self): data_path = os.path.join( @@ -667,7 +666,7 @@ def test_instance_groups_from_json_file(self): 'InstanceType': 'm1.xlarge' } ] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_ec2_attributes_no_az(self): cmd = DEFAULT_CMD + ( @@ -677,13 +676,13 @@ def test_ec2_attributes_no_az(self): result['Instances']['Ec2KeyName'] = 'testkey' result['Instances']['Ec2SubnetId'] = 'subnet-123456' result['JobFlowRole'] = 'EMR_EC2_DefaultRole' - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_ec2_attributes_az(self): cmd = DEFAULT_CMD + '--ec2-attributes AvailabilityZone=us-east-1a' result = copy.deepcopy(DEFAULT_RESULT) result['Instances']['Placement'] = {'AvailabilityZone': 'us-east-1a'} - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_ec2_attributes_subnet_az_error(self): cmd = DEFAULT_CMD + '--ec2-attributes ' + \ @@ -702,7 +701,7 @@ def test_ec2_attributes_with_subnet_from_json_file(self): result['Instances']['Ec2KeyName'] = 'testkey' result['Instances']['Ec2SubnetId'] = 'subnet-123456' result['JobFlowRole'] = 'EMR_EC2_DefaultRole' - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_ec2_attributes_with_az_from_json_file(self): data_path = os.path.join( @@ -712,16 +711,13 @@ def test_ec2_attributes_with_az_from_json_file(self): result['Instances']['Ec2KeyName'] = 'testkey' result['Instances']['Placement'] = {'AvailabilityZone': 'us-east-1a'} result['JobFlowRole'] = 'EMR_EC2_DefaultRole' - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) # Bootstrap Actions test cases def test_bootstrap_actions_missing_path_error(self): cmd = DEFAULT_CMD + '--bootstrap-actions Name=ba1,Args=arg1,arg2' - expect_error_msg = ( - '\nThe following required parameters are missing ' - 'for structure:: Path\n') - result = self.run_cmd(cmd, 255) - self.assertEquals(expect_error_msg, result[1]) + stderr = self.run_cmd(cmd, 255)[1] + self.assert_error_message_has_field_name(stderr, 'Path') def test_bootstrap_actions_with_all_fields(self): cmd = DEFAULT_CMD + ( @@ -731,7 +727,7 @@ def test_bootstrap_actions_with_all_fields(self): result = copy.deepcopy(DEFAULT_RESULT) result['BootstrapActions'] = TEST_BA - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_bootstrap_actions_exceed_maximum_error(self): cmd = DEFAULT_CMD + ' --bootstrap-actions' @@ -771,7 +767,7 @@ def test_boostrap_actions_with_default_fields(self): {'Path': 's3://test/ba2'} } ] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_bootstrap_actions_from_json_file(self): data_path = os.path.join( @@ -791,14 +787,14 @@ def test_bootstrap_actions_from_json_file(self): "Args": ["arg1", "arg2"]} } ] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) # Applications test cases def test_install_hive_with_defaults(self): cmd = DEFAULT_CMD + '--applications Name=Hive' result = copy.deepcopy(DEFAULT_RESULT) result['Steps'] = [INSTALL_HIVE_STEP] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_install_hive_with_profile_region(self): self.driver.session.set_config_variable('region', 'cn-north-1') @@ -807,37 +803,37 @@ def test_install_hive_with_profile_region(self): replace('us-east-1', 'cn-north-1') result = copy.deepcopy(DEFAULT_RESULT) result['Steps'] = [json.loads(HIVE_STEP)] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_install_hive_site(self): cmdline = (DEFAULT_CMD + '--applications Name=Hive,' 'Args=[--hive-site=s3://test/hive-conf/hive-site.xml]') result = copy.deepcopy(DEFAULT_RESULT) result['Steps'] = [INSTALL_HIVE_STEP, INSTALL_HIVE_SITE_STEP] - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) cmdline = (DEFAULT_CMD + '--applications Name=Hive,' 'Args=[--hive-site=s3://test/hive-conf/hive-site.xml,k1]') - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_install_pig_with_defaults(self): cmd = DEFAULT_CMD + '--applications Name=Pig' result = copy.deepcopy(DEFAULT_RESULT) result['Steps'] = [INSTALL_PIG_STEP] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_install_ganglia(self): cmd = DEFAULT_CMD + '--applications Name=Ganglia' result = copy.deepcopy(DEFAULT_RESULT) result['BootstrapActions'] = [INSTALL_GANGLIA_BA] result.pop('Steps') - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_install_impala_with_defaults(self): cmd = DEFAULT_CMD + '--applications Name=Impala' result = copy.deepcopy(DEFAULT_RESULT) result['BootstrapActions'] = [INSTALL_IMPALA_BA] result.pop('Steps') - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_install_impala_with_all_fields(self): cmd = DEFAULT_CMD + \ @@ -848,14 +844,14 @@ def test_install_impala_with_all_fields(self): ['--impala-conf', 'arg1', 'arg2'] result['BootstrapActions'] = [ba] result.pop('Steps') - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_install_hbase(self): cmd = DEFAULT_CMD + '--applications Name=hbase' result = copy.deepcopy(DEFAULT_RESULT) result['BootstrapActions'] = [INSTALL_HBASE_BA] result['Steps'] = [INSTALL_HBASE_STEP] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_install_mapr_with_args(self): cmd = DEFAULT_CMD + \ @@ -863,7 +859,7 @@ def test_install_mapr_with_args(self): result = copy.deepcopy(DEFAULT_RESULT) result['NewSupportedProducts'] = [INSTALL_MAPR_PRODUCT] result.pop('Steps') - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_install_mapr_without_args(self): cmd = DEFAULT_CMD + \ @@ -875,7 +871,7 @@ def test_install_mapr_without_args(self): 'Args': []} ] result.pop('Steps') - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_supported_products(self): cmd = DEFAULT_CMD + ( @@ -886,7 +882,7 @@ def test_supported_products(self): result = copy.deepcopy(DEFAULT_RESULT) result['NewSupportedProducts'] = INSTALL_SUPPORTED_PRODUCTS result.pop('Steps') - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_applications_all_types(self): cmd = DEFAULT_CMD + ( @@ -900,7 +896,7 @@ def test_applications_all_types(self): result['Steps'] = step_list result['BootstrapActions'] = ba_list result['NewSupportedProducts'] = [INSTALL_MAPR_PRODUCT] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_applications_all_types_from_json_file(self): data_path = os.path.join( @@ -917,7 +913,7 @@ def test_applications_all_types_from_json_file(self): result['Steps'] = step_list result['BootstrapActions'] = ba_list result['NewSupportedProducts'] = [INSTALL_MAPR_PRODUCT] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) # Steps test cases def test_wrong_step_type_error(self): @@ -931,7 +927,7 @@ def test_default_step_type_name_action_on_failure(self): cmd = DEFAULT_CMD + '--steps Jar=s3://mybucket/mytest.jar' result = copy.deepcopy(DEFAULT_RESULT) result['Steps'] += [CUSTOM_JAR_STEP] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_custom_jar_step_missing_jar(self): cmd = DEFAULT_CMD + '--steps Name=CustomJarMissingJar' @@ -957,7 +953,7 @@ def test_custom_jar_step_with_all_fields(self): ] result = copy.deepcopy(DEFAULT_RESULT) result['Steps'] += expected_steps - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_streaming_step_with_default_fields(self): cmd = DEFAULT_CMD + '--steps Type=Streaming,' + STREAMING_ARGS @@ -967,7 +963,7 @@ def test_streaming_step_with_default_fields(self): 'ActionOnFailure': 'CONTINUE', 'HadoopJarStep': STREAMING_HADOOP_JAR_STEP} ] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_streaming_step_missing_args(self): cmd = DEFAULT_CMD + '--steps Type=Streaming' @@ -987,14 +983,14 @@ def test_streaming_jar_with_all_fields(self): 'ActionOnFailure': 'CANCEL_AND_WAIT', 'HadoopJarStep': STREAMING_HADOOP_JAR_STEP} ] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_hive_step_with_default_fields(self): cmd = DEFAULT_CMD + ( '--applications Name=Hive --steps Type=Hive,' + HIVE_BASIC_ARGS) result = copy.deepcopy(DEFAULT_RESULT) result['Steps'] = [INSTALL_HIVE_STEP, HIVE_DEFAULT_STEP] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_hive_step_missing_args(self): cmd = DEFAULT_CMD + '--applications Name=Hive --steps Type=Hive' @@ -1012,14 +1008,14 @@ def test_hive_step_with_all_fields(self): result = copy.deepcopy(DEFAULT_RESULT) install_step = copy.deepcopy(INSTALL_HIVE_STEP) result['Steps'] = [install_step, HIVE_BASIC_STEP] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_pig_step_with_default_fields(self): cmd = DEFAULT_CMD + ( '--applications Name=Pig --steps Type=Pig,' + PIG_BASIC_ARGS) result = copy.deepcopy(DEFAULT_RESULT) result['Steps'] = [INSTALL_PIG_STEP, PIG_DEFAULT_STEP] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_pig_missing_args(self): cmd = DEFAULT_CMD + '--applications Name=Pig --steps Type=Pig' @@ -1037,7 +1033,7 @@ def test_pig_step_with_all_fields(self): result = copy.deepcopy(DEFAULT_RESULT) install_step = copy.deepcopy(INSTALL_PIG_STEP) result['Steps'] = [install_step, PIG_BASIC_STEP] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_impala_step_with_default_fields(self): cmd = DEFAULT_CMD + ( @@ -1046,7 +1042,7 @@ def test_impala_step_with_default_fields(self): result = copy.deepcopy(DEFAULT_RESULT) result['BootstrapActions'] = [INSTALL_IMPALA_BA] result['Steps'] = [IMPALA_DEFAULT_STEP] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_impala_missing_args(self): cmd = DEFAULT_CMD + '--applications Name=Impala --steps Type=Impala' @@ -1067,7 +1063,7 @@ def test_impala_step_with_all_fields(self): step['Name'] = 'ImpalaBasicStep' step['ActionOnFailure'] = 'CANCEL_AND_WAIT' result['Steps'] = [step] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_restore_from_hbase(self): cmd = DEFAULT_CMD + ( @@ -1091,13 +1087,13 @@ def test_restore_from_hbase(self): 'Jar': '/home/hadoop/lib/hbase.jar'} } ] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) data_path = os.path.join( os.path.dirname(__file__), 'input_hbase_restore_from_backup.json') cmd = DEFAULT_CMD + ( '--applications Name=hbase --restore-from-hbase-backup ' 'file://' + data_path) - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) def test_missing_applications_for_steps(self): cmd = DEFAULT_CMD +\ @@ -1191,12 +1187,12 @@ def test_emr_fs_config(self): } result = copy.deepcopy(DEFAULT_RESULT) result['BootstrapActions'] = [emf_fs_ba_config] - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) data_path = os.path.join( os.path.dirname(__file__), 'input_emr_fs.json') cmd = DEFAULT_CMD + '--emrfs file://' + data_path - self.assert_params_for_cmd(cmd, result) + self.assert_params_for_cmd2(cmd, result) if __name__ == "__main__": unittest.main() diff --git a/tests/unit/customizations/emr/test_create_hbase_backup.py b/tests/unit/customizations/emr/test_create_hbase_backup.py index 7c79f7096b6e..6283896e6263 100644 --- a/tests/unit/customizations/emr/test_create_hbase_backup.py +++ b/tests/unit/customizations/emr/test_create_hbase_backup.py @@ -37,7 +37,7 @@ def test_create_hbase_backup(self): cmdline = self.prefix + args result = {'JobFlowId': 'j-ABCD', 'Steps': self.steps} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_create_hbase_backup_consitent(self): args = ' --cluster-id j-ABCD --dir s3://abc/ --consistent' @@ -47,7 +47,7 @@ def test_create_hbase_backup_consitent(self): steps[0]['HadoopJarStep']['Args'].append('--consistent') result = {'JobFlowId': 'j-ABCD', 'Steps': steps} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) if __name__ == "__main__": diff --git a/tests/unit/customizations/emr/test_disable_hbase_backup.py b/tests/unit/customizations/emr/test_disable_hbase_backup.py index f506d4d3d665..cdf9c9e918c0 100644 --- a/tests/unit/customizations/emr/test_disable_hbase_backup.py +++ b/tests/unit/customizations/emr/test_disable_hbase_backup.py @@ -41,7 +41,7 @@ def test_disable_hbase_backups_full(self): steps[0]['HadoopJarStep']['Args'].append(self.DISABLE_FULL_BACKUP) result = {'JobFlowId': 'j-ABCD', 'Steps': steps} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_disable_hbase_backups_incremental(self): args = ' --cluster-id j-ABCD --incremental' @@ -51,7 +51,7 @@ def test_disable_hbase_backups_incremental(self): steps[0]['HadoopJarStep']['Args'].append(self.DISABLE_INCR_BACKUP) result = {'JobFlowId': 'j-ABCD', 'Steps': steps} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_disable_hbase_backups_both(self): args = ' --cluster-id j-ABCD --full --incremental' @@ -62,7 +62,7 @@ def test_disable_hbase_backups_both(self): steps[0]['HadoopJarStep']['Args'].append(self.DISABLE_INCR_BACKUP) result = {'JobFlowId': 'j-ABCD', 'Steps': steps} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_disable_hbase_backups_none(self): args = ' --cluster-id j-ABCD' diff --git a/tests/unit/customizations/emr/test_install_applications.py b/tests/unit/customizations/emr/test_install_applications.py index f66827320afe..b8beea3b31d0 100644 --- a/tests/unit/customizations/emr/test_install_applications.py +++ b/tests/unit/customizations/emr/test_install_applications.py @@ -72,16 +72,16 @@ def test_install_hive_site(self): result = {'JobFlowId': 'j-ABC123456', 'Steps': [INSTALL_HIVE_STEP, INSTALL_HIVE_SITE_STEP] } - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) cmdline = (self.prefix + 'Name=Hive,' 'Args=[--hive-site=s3://test/hive-conf/hive-site.xml,k1]') - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_install_hive_and_pig(self): cmdline = self.prefix + 'Name=Hive Name=Pig' result = {'JobFlowId': 'j-ABC123456', 'Steps': [INSTALL_HIVE_STEP, INSTALL_PIG_STEP]} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_install_pig_with_profile_region(self): self.driver.session.set_config_variable('region', 'cn-north-1') @@ -90,7 +90,7 @@ def test_install_pig_with_profile_region(self): replace('us-east-1', 'cn-north-1') result = {'JobFlowId': 'j-ABC123456', 'Steps': [json.loads(PIG_STEP)]} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_install_impala_error(self): cmdline = self.prefix + ' Name=Impala' diff --git a/tests/unit/customizations/emr/test_list_clusters.py b/tests/unit/customizations/emr/test_list_clusters.py index 9d32b06039a6..3a6bad11c155 100644 --- a/tests/unit/customizations/emr/test_list_clusters.py +++ b/tests/unit/customizations/emr/test_list_clusters.py @@ -30,25 +30,25 @@ def test_list_active_clusters(self): 'TERMINATING' ] } - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_list_terminated_clusters(self): args = '--terminated' cmdline = self.prefix + args result = {'ClusterStates': ['TERMINATED']} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_list_failed_clusters(self): args = '--failed' cmdline = self.prefix + args result = {'ClusterStates': ['TERMINATED_WITH_ERRORS']} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_list_multiple_states(self): args = '--cluster-states RUNNING WAITING TERMINATED' cmdline = self.prefix + args result = {'ClusterStates': ['RUNNING', 'WAITING', 'TERMINATED']} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_exclusive_states_filters(self): args = '--active --failed' diff --git a/tests/unit/customizations/emr/test_modify_cluster_attributes.py b/tests/unit/customizations/emr/test_modify_cluster_attributes.py index ba8c3e64f166..cb14c2b47798 100644 --- a/tests/unit/customizations/emr/test_modify_cluster_attributes.py +++ b/tests/unit/customizations/emr/test_modify_cluster_attributes.py @@ -23,25 +23,25 @@ def test_visible_to_all(self): args = ' --cluster-id j-ABC123456 --visible-to-all-users' cmdline = self.prefix + args result = {'JobFlowIds': ['j-ABC123456'], 'VisibleToAllUsers': True} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_no_visible_to_all(self): args = ' --cluster-id j-ABC123456 --no-visible-to-all-users' cmdline = self.prefix + args result = {'JobFlowIds': ['j-ABC123456'], 'VisibleToAllUsers': False} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_termination_protected(self): args = ' --cluster-id j-ABC123456 --termination-protected' cmdline = self.prefix + args result = {'JobFlowIds': ['j-ABC123456'], 'TerminationProtected': True} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_no_termination_protected(self): args = ' --cluster-id j-ABC123456 --no-termination-protected' cmdline = self.prefix + args result = {'JobFlowIds': ['j-ABC123456'], 'TerminationProtected': False} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_visible_to_all_and_no_visible_to_all(self): args = ' --cluster-id j-ABC123456 --no-visible-to-all-users'\ diff --git a/tests/unit/customizations/emr/test_restore_from_hbase_backup.py b/tests/unit/customizations/emr/test_restore_from_hbase_backup.py index 74c485201196..46d6d0272987 100644 --- a/tests/unit/customizations/emr/test_restore_from_hbase_backup.py +++ b/tests/unit/customizations/emr/test_restore_from_hbase_backup.py @@ -37,7 +37,7 @@ def test_restore_from_hbase_backup(self): cmdline = self.prefix + args result = {'JobFlowId': 'j-ABCD', 'Steps': self.steps} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_restore_from_hbase_backup_version(self): args = ' --cluster-id j-ABCD --dir s3://abc/ --backup-version DEF' @@ -48,7 +48,7 @@ def test_restore_from_hbase_backup_version(self): steps[0]['HadoopJarStep']['Args'].append('DEF') result = {'JobFlowId': 'j-ABCD', 'Steps': steps} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) if __name__ == "__main__": diff --git a/tests/unit/customizations/emr/test_schedule_hbase_backup.py b/tests/unit/customizations/emr/test_schedule_hbase_backup.py index 87c1f0b5fc8a..bf8c36d30452 100644 --- a/tests/unit/customizations/emr/test_schedule_hbase_backup.py +++ b/tests/unit/customizations/emr/test_schedule_hbase_backup.py @@ -46,7 +46,7 @@ def test_schedule_hbase_backup_full(self): cmdline = self.prefix + args result = {'JobFlowId': 'j-ABCD', 'Steps': self.default_steps} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_schedule_hbase_backup_full_upper_case(self): args = ' --cluster-id j-ABCD --dir s3://abc/ --type FULL' +\ @@ -54,7 +54,7 @@ def test_schedule_hbase_backup_full_upper_case(self): cmdline = self.prefix + args result = {'JobFlowId': 'j-ABCD', 'Steps': self.default_steps} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_schedule_hbase_backup_incremental_upper_case(self): args = ' --cluster-id j-ABCD --dir s3://abc/ --type INCREMENTAL' +\ @@ -83,7 +83,7 @@ def test_schedule_hbase_backup_incremental(self): result = {'JobFlowId': 'j-ABCD', 'Steps': steps} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_schedule_hbase_backup_wrong_type(self): args = ' --cluster-id j-ABCD --dir s3://abc/ --type wrong_type' +\ @@ -115,7 +115,7 @@ def test_schedule_hbase_backup_consistent(self): steps[0]['HadoopJarStep']['Args'].insert(5, '--consistent') result = {'JobFlowId': 'j-ABCD', 'Steps': steps} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_schedule_hbase_backup_start_time(self): args = ' --cluster-id j-ABCD --dir s3://abc/ --type full --interval' +\ @@ -126,7 +126,7 @@ def test_schedule_hbase_backup_start_time(self): steps[0]['HadoopJarStep']['Args'][10] = '2014-04-18T10:43:24-07:00' result = {'JobFlowId': 'j-ABCD', 'Steps': steps} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) if __name__ == "__main__": diff --git a/tests/unit/customizations/emr/test_terminate_clusters.py b/tests/unit/customizations/emr/test_terminate_clusters.py index 3846d0384981..f5e26e94b227 100644 --- a/tests/unit/customizations/emr/test_terminate_clusters.py +++ b/tests/unit/customizations/emr/test_terminate_clusters.py @@ -22,13 +22,13 @@ def test_cluster_id(self): args = ' --cluster-ids j-ABC123456' cmdline = self.prefix + args result = {'JobFlowIds': ['j-ABC123456']} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) def test_cluster_ids(self): args = ' --cluster-ids j-ABC123456 j-AAAAAAA' cmdline = self.prefix + args result = {'JobFlowIds': ['j-ABC123456', 'j-AAAAAAA']} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) if __name__ == "__main__": unittest.main() diff --git a/tests/unit/customizations/s3/__init__.py b/tests/unit/customizations/s3/__init__.py index 5ada082edd3e..5f3b3e1763cd 100644 --- a/tests/unit/customizations/s3/__init__.py +++ b/tests/unit/customizations/s3/__init__.py @@ -16,7 +16,7 @@ import string import six -from mock import patch +from mock import patch, Mock class S3HandlerBaseTest(unittest.TestCase): @@ -173,7 +173,7 @@ def list_contents(bucket, session): endpoint = service.get_endpoint(region) operation = service.get_operation('ListObjects') http_response, r_data = operation.call(endpoint, bucket=bucket) - return r_data['Contents'] + return r_data.get('Contents', []) def list_buckets(session): @@ -188,3 +188,24 @@ def list_buckets(session): html_response, response_data = operation.call(endpoint) contents = response_data['Buckets'] return contents + + +class MockStdIn(object): + """ + This class patches stdin in order to write a stream of bytes into + stdin. + """ + def __init__(self, input_bytes=b''): + input_data = six.BytesIO(input_bytes) + if six.PY3: + mock_object = Mock() + mock_object.buffer = input_data + else: + mock_object = input_data + self._patch = patch('sys.stdin', mock_object) + + def __enter__(self): + self._patch.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + self._patch.__exit__() diff --git a/tests/unit/customizations/s3/fake_session.py b/tests/unit/customizations/s3/fake_session.py index dfb93e42be4b..acca8d47c631 100644 --- a/tests/unit/customizations/s3/fake_session.py +++ b/tests/unit/customizations/s3/fake_session.py @@ -181,7 +181,7 @@ def put_object(self, kwargs): else: self.session.s3[bucket][key] = content else: - response_data['Errors'] = [{'Message': 'Bucket does not exist'}] + response_data['Error'] = {'Message': 'Bucket does not exist'} if self.session.md5_error: etag = "dsffsdg" # This etag should always raise an exception self.session.md5_error = False @@ -297,9 +297,9 @@ def delete_bucket(self, kwargs): if not self.session.s3[bucket]: self.session.s3.pop(bucket) else: - response_data['Errors'] = [{'Message': 'Bucket not empty'}] + response_data['Error'] = {'Message': 'Bucket not empty'} else: - response_data['Errors'] = [{'Message': 'Bucket does not exist'}] + response_data['Error'] = {'Message': 'Bucket does not exist'} response_data['ETag'] = '"%s"' % etag return FakeHttp(), response_data @@ -314,8 +314,9 @@ def list_buckets(self, kwargs): bucket_dict = {} bucket_dict['Name'] = bucket response_data['Buckets'].append(bucket_dict) - response_data['Contents'] = sorted(response_data['Buckets'], - key=lambda k: k['Name']) + if self.session.s3.keys(): + response_data['Contents'] = sorted(response_data['Buckets'], + key=lambda k: k['Name']) response_data['ETag'] = '"%s"' % etag return FakeHttp(), response_data @@ -350,6 +351,12 @@ def list_objects(self, kwargs): response_data['Contents'].append(key_dict) response_data['Contents'] = sorted(response_data['Contents'], key=lambda k: k['Key']) + contents = response_data.get('Contents', None) + if not contents and contents is not None: + response_data.pop('Contents') + common_prefixes = response_data.get('CommonPrefixes', None) + if not common_prefixes and common_prefixes is not None: + response_data.pop('CommonPrefixes') response_data['ETag'] = '"%s"' % etag return FakeHttp(), response_data diff --git a/tests/unit/customizations/s3/test_copy_params.py b/tests/unit/customizations/s3/test_copy_params.py index 69f992335ccc..c7cd37124a13 100644 --- a/tests/unit/customizations/s3/test_copy_params.py +++ b/tests/unit/customizations/s3/test_copy_params.py @@ -43,22 +43,14 @@ def setUp(self): self.parsed_response = {'ETag': '"120ea8a25e5d487bf68b5f7096440019"',} def assert_params(self, cmdline, result): - # Botocore injects the expect 100 continue header so we'll - # automatically add this here so each test doesn't need to specify - # this header. - result['headers']['Expect'] = '100-continue' - self.assert_params_for_cmd(cmdline, result, expected_rc=0, - ignore_params=['payload']) - self.assertIsInstance(self.last_params['payload'].getvalue(), - file_type) + foo = self.assert_params_for_cmd2(cmdline, result, expected_rc=0, + ignore_params=['body']) def test_simple(self): cmdline = self.prefix cmdline += self.file_path cmdline += ' s3://mybucket/mykey' - result = {'uri_params': {'Bucket': 'mybucket', - 'Key': 'mykey'}, - 'headers': {}} + result = {'bucket': u'mybucket', 'key': u'mykey'} self.assert_params(cmdline, result) def test_sse(self): @@ -66,9 +58,8 @@ def test_sse(self): cmdline += self.file_path cmdline += ' s3://mybucket/mykey' cmdline += ' --sse' - result = {'uri_params': {'Bucket': 'mybucket', - 'Key': 'mykey'}, - 'headers': {'x-amz-server-side-encryption': 'AES256'}} + result = {'bucket': u'mybucket', 'key': u'mykey', + 'server_side_encryption': 'AES256'} self.assert_params(cmdline, result) def test_storage_class(self): @@ -76,9 +67,8 @@ def test_storage_class(self): cmdline += self.file_path cmdline += ' s3://mybucket/mykey' cmdline += ' --storage-class REDUCED_REDUNDANCY' - result = {'uri_params': {'Bucket': 'mybucket', - 'Key': 'mykey'}, - 'headers': {'x-amz-storage-class': 'REDUCED_REDUNDANCY'}} + result = {'bucket': u'mybucket', 'key': u'mykey', + 'storage_class': u'REDUCED_REDUNDANCY'} self.assert_params(cmdline, result) def test_website_redirect(self): @@ -86,9 +76,9 @@ def test_website_redirect(self): cmdline += self.file_path cmdline += ' s3://mybucket/mykey' cmdline += ' --website-redirect /foobar' - result = {'uri_params': {'Bucket': 'mybucket', - 'Key': 'mykey'}, - 'headers': {'x-amz-website-redirect-location': '/foobar'}} + result = {'bucket': u'mybucket', + 'key': u'mykey', + 'website_redirect_location': u'/foobar'} self.assert_params(cmdline, result) def test_acl(self): @@ -96,9 +86,7 @@ def test_acl(self): cmdline += self.file_path cmdline += ' s3://mybucket/mykey' cmdline += ' --acl public-read' - result = {'uri_params': {'Bucket': 'mybucket', - 'Key': 'mykey'}, - 'headers': {'x-amz-acl': 'public-read'}} + result = {'bucket': 'mybucket', 'key': 'mykey', 'acl': 'public-read'} self.assert_params(cmdline, result) def test_content_params(self): @@ -109,12 +97,11 @@ def test_content_params(self): cmdline += ' --content-language piglatin' cmdline += ' --cache-control max-age=3600,must-revalidate' cmdline += ' --content-disposition attachment;filename="fname.ext"' - result = {'uri_params': {'Bucket': 'mybucket', - 'Key': 'mykey'}, - 'headers': {'Content-Encoding': 'x-gzip', - 'Content-Language': 'piglatin', - 'Content-Disposition': 'attachment;filename="fname.ext"', - 'Cache-Control': 'max-age=3600,must-revalidate'}} + result = {'bucket': 'mybucket', 'key': 'mykey', + 'content_encoding': 'x-gzip', + 'content_language': 'piglatin', + 'content_disposition': 'attachment;filename="fname.ext"', + 'cache_control': 'max-age=3600,must-revalidate'} self.assert_params(cmdline, result) def test_grants(self): @@ -123,10 +110,10 @@ def test_grants(self): cmdline += ' s3://mybucket/mykey' cmdline += ' --grants read=bob' cmdline += ' full=alice' - result = {'uri_params': {'Bucket': 'mybucket', - 'Key': 'mykey'}, - 'headers': {'x-amz-grant-full-control': 'alice', - 'x-amz-grant-read': 'bob'}} + result = {'bucket': u'mybucket', + 'grant_full_control': u'alice', + 'grant_read': u'bob', + 'key': u'mykey'} self.assert_params(cmdline, result) def test_grants_bad(self): @@ -142,9 +129,8 @@ def test_content_type(self): cmdline += self.file_path cmdline += ' s3://mybucket/mykey' cmdline += ' --content-type text/xml' - result = {'uri_params': {'Bucket': 'mybucket', - 'Key': 'mykey'}, - 'headers': {'Content-Type': 'text/xml'}} + result = {'bucket': u'mybucket', 'content_type': u'text/xml', + 'key': u'mykey'} self.assert_params(cmdline, result) diff --git a/tests/unit/customizations/s3/test_executor.py b/tests/unit/customizations/s3/test_executor.py index 9afaacd3ba22..1a559562ec32 100644 --- a/tests/unit/customizations/s3/test_executor.py +++ b/tests/unit/customizations/s3/test_executor.py @@ -15,6 +15,7 @@ import shutil import six from six.moves import queue +import sys import mock @@ -41,7 +42,7 @@ def tearDown(self): shutil.rmtree(self.temp_dir) def test_handles_io_request(self): - self.queue.put(IORequest(self.filename, 0, b'foobar')) + self.queue.put(IORequest(self.filename, 0, b'foobar', False)) self.queue.put(IOCloseRequest(self.filename)) self.queue.put(ShutdownThreadRequest()) self.io_thread.run() @@ -49,8 +50,8 @@ def test_handles_io_request(self): self.assertEqual(f.read(), b'foobar') def test_out_of_order_io_requests(self): - self.queue.put(IORequest(self.filename, 6, b'morestuff')) - self.queue.put(IORequest(self.filename, 0, b'foobar')) + self.queue.put(IORequest(self.filename, 6, b'morestuff', False)) + self.queue.put(IORequest(self.filename, 0, b'foobar', False)) self.queue.put(IOCloseRequest(self.filename)) self.queue.put(ShutdownThreadRequest()) self.io_thread.run() @@ -60,8 +61,8 @@ def test_out_of_order_io_requests(self): def test_multiple_files_in_queue(self): second_file = os.path.join(self.temp_dir, 'bar') open(second_file, 'w').close() - self.queue.put(IORequest(self.filename, 0, b'foobar')) - self.queue.put(IORequest(second_file, 0, b'otherstuff')) + self.queue.put(IORequest(self.filename, 0, b'foobar', False)) + self.queue.put(IORequest(second_file, 0, b'otherstuff', False)) self.queue.put(IOCloseRequest(second_file)) self.queue.put(IOCloseRequest(self.filename)) self.queue.put(ShutdownThreadRequest()) @@ -72,10 +73,24 @@ def test_multiple_files_in_queue(self): with open(second_file, 'rb') as f: self.assertEqual(f.read(), b'otherstuff') + def test_stream_requests(self): + # Test that offset has no affect on the order in which requests + # are written to stdout. The order of requests for a stream are + # first in first out. + self.queue.put(IORequest('nonexistant-file', 10, b'foobar', True)) + self.queue.put(IORequest('nonexistant-file', 6, b'otherstuff', True)) + # The thread should not try to close the file name because it is + # writing to stdout. If it does, the thread will fail because + # the file does not exist. + self.queue.put(ShutdownThreadRequest()) + with mock.patch('sys.stdout', new=six.StringIO()) as mock_stdout: + self.io_thread.run() + self.assertEqual(mock_stdout.getvalue(), 'foobarotherstuff') + class TestExecutor(unittest.TestCase): def test_shutdown_does_not_hang(self): - executor = Executor(2, queue.Queue(), False, + executor = Executor(2, queue.Queue(), False, False, 10, queue.Queue(maxsize=1)) with temporary_file('rb+') as f: executor.start() @@ -84,18 +99,108 @@ class FloodIOQueueTask(object): def __call__(self): for i in range(50): - executor.write_queue.put(IORequest(f.name, 0, b'foobar')) + executor.write_queue.put(IORequest(f.name, 0, + b'foobar', False)) executor.submit(FloodIOQueueTask()) executor.initiate_shutdown() executor.wait_until_shutdown() self.assertEqual(open(f.name, 'rb').read(), b'foobar') + class TestPrintThread(unittest.TestCase): + def setUp(self): + self.result_queue = queue.Queue() + + def assert_expected_output(self, print_task, expected_output, thread, + out_file): + with mock.patch(out_file, new=six.StringIO()) as mock_out: + self.result_queue.put(print_task) + self.result_queue.put(ShutdownThreadRequest()) + thread.run() + self.assertIn(expected_output, mock_out.getvalue()) + + def test_print(self): + print_task = PrintTask(message="Success", error=False) + + # Ensure a successful task is printed to stdout when + # ``quiet`` and ``only_show_errors`` is False. + thread = PrintThread(result_queue=self.result_queue, + quiet=False, only_show_errors=False) + self.assert_expected_output(print_task, 'Success', thread, + 'sys.stdout') + + def test_print_quiet(self): + print_task = PrintTask(message="Success", error=False) + + # Ensure a succesful task is not printed to stdout when + # ``quiet`` is True. + thread = PrintThread(result_queue=self.result_queue, + quiet=True, only_show_errors=False) + self.assert_expected_output(print_task, '', thread, 'sys.stdout') + + def test_print_only_show_errors(self): + print_task = PrintTask(message="Success", error=False) + + # Ensure a succesful task is not printed to stdout when + # ``only_show_errors`` is True. + thread = PrintThread(result_queue=self.result_queue, + quiet=False, only_show_errors=True) + self.assert_expected_output(print_task, '', thread, 'sys.stdout') + + def test_print_error(self): + print_task = PrintTask(message="Fail File.", error=True) + + # Ensure a failed task is printed to stderr when + # ``quiet`` and ``only_show_errors`` is False. + thread = PrintThread(result_queue=self.result_queue, + quiet=False, only_show_errors=False) + self.assert_expected_output(print_task, 'Fail File.', thread, + 'sys.stderr') + + def test_print_error_quiet(self): + print_task = PrintTask(message="Fail File.", error=True) + + # Ensure a failed task is not printed to stderr when + # ``quiet`` is True. + thread = PrintThread(result_queue=self.result_queue, + quiet=True, only_show_errors=False) + self.assert_expected_output(print_task, '', thread, 'sys.stderr') + + def test_print_error_only_show_errors(self): + print_task = PrintTask(message="Fail File.", error=True) + + # Ensure a failed task is printed to stderr when + # ``only_show_errors`` is True. + thread = PrintThread(result_queue=self.result_queue, + quiet=False, only_show_errors=True) + self.assert_expected_output(print_task, 'Fail File.', thread, + 'sys.stderr') + def test_print_warning(self): - result_queue = queue.Queue() print_task = PrintTask(message="Bad File.", warning=True) - thread = PrintThread(result_queue, False) - with mock.patch('sys.stdout', new=six.StringIO()) as mock_stdout: - thread._process_print_task(print_task) - self.assertIn("Bad File.", mock_stdout.getvalue()) + # Ensure a warned task is printed to stderr when + # ``quiet`` and ``only_show_errors`` is False. + thread = PrintThread(result_queue=self.result_queue, + quiet=False, only_show_errors=False) + self.assert_expected_output(print_task, 'Bad File.', thread, + 'sys.stderr') + + def test_print_warning_quiet(self): + print_task = PrintTask(message="Bad File.", warning=True) + + # Ensure a warned task is not printed to stderr when + # ``quiet`` is True. + thread = PrintThread(result_queue=self.result_queue, + quiet=True, only_show_errors=False) + self.assert_expected_output(print_task, '', thread, 'sys.stderr') + + def test_print_warning_only_show_errors(self): + print_task = PrintTask(message="Bad File.", warning=True) + + # Ensure a warned task is printed to stderr when + # ``only_show_errors`` is True. + thread = PrintThread(result_queue=self.result_queue, + quiet=False, only_show_errors=True) + self.assert_expected_output(print_task, 'Bad File.', thread, + 'sys.stderr') diff --git a/tests/unit/customizations/s3/test_fileinfo.py b/tests/unit/customizations/s3/test_fileinfo.py index 48a6651f42fb..bbee735fa047 100644 --- a/tests/unit/customizations/s3/test_fileinfo.py +++ b/tests/unit/customizations/s3/test_fileinfo.py @@ -21,6 +21,8 @@ from awscli.testutils import unittest from awscli.customizations.s3 import fileinfo +from awscli.customizations.s3.utils import MD5Error +from awscli.customizations.s3.fileinfo import FileInfo class TestSaveFile(unittest.TestCase): @@ -58,3 +60,25 @@ def test_makedir_other_exception(self, makedirs): fileinfo.save_file(self.filename, self.response_data, self.last_update) self.assertFalse(os.path.isfile(self.filename)) + + def test_stream_file(self): + with mock.patch('sys.stdout', new=six.StringIO()) as mock_stdout: + fileinfo.save_file(None, self.response_data, None, True) + self.assertEqual(mock_stdout.getvalue(), "foobar") + + def test_stream_file_md5_error(self): + with mock.patch('sys.stdout', new=six.StringIO()) as mock_stdout: + self.response_data['ETag'] = '"0"' + with self.assertRaises(MD5Error): + fileinfo.save_file(None, self.response_data, None, True) + # Make sure nothing is written to stdout. + self.assertEqual(mock_stdout.getvalue(), "") + + +class TestSetSizeFromS3(unittest.TestCase): + def test_set_size_from_s3(self): + file_info = FileInfo(src="bucket/key", endpoint=None) + with mock.patch('awscli.customizations.s3.fileinfo.operate') as op_mock: + op_mock.return_value = ({'ContentLength': 5}, None) + file_info.set_size_from_s3() + self.assertEqual(file_info.size, 5) diff --git a/tests/unit/customizations/s3/test_fileinfobuilder.py b/tests/unit/customizations/s3/test_fileinfobuilder.py index 439c006ad136..7d235e5728de 100644 --- a/tests/unit/customizations/s3/test_fileinfobuilder.py +++ b/tests/unit/customizations/s3/test_fileinfobuilder.py @@ -22,7 +22,8 @@ class TestFileInfoBuilder(unittest.TestCase): def test_info_setter(self): info_setter = FileInfoBuilder(service='service', endpoint='endpoint', source_endpoint='source_endpoint', - parameters='parameters') + parameters='parameters', + is_stream='is_stream') files = [FileStat(src='src', dest='dest', compare_key='compare_key', size='size', last_update='last_update', src_type='src_type', dest_type='dest_type', diff --git a/tests/unit/customizations/s3/test_s3handler.py b/tests/unit/customizations/s3/test_s3handler.py index 20bc3a62a858..2105d3495770 100644 --- a/tests/unit/customizations/s3/test_s3handler.py +++ b/tests/unit/customizations/s3/test_s3handler.py @@ -14,15 +14,19 @@ import os import random import sys -from awscli.testutils import unittest +import mock + +from awscli.testutils import unittest from awscli import EnvironmentVariables -from awscli.customizations.s3.s3handler import S3Handler +from awscli.customizations.s3.s3handler import S3Handler, S3StreamHandler from awscli.customizations.s3.fileinfo import FileInfo +from awscli.customizations.s3.tasks import CreateMultipartUploadTask, \ + UploadPartTask, CreateLocalFileTask from tests.unit.customizations.s3.fake_session import FakeSession from tests.unit.customizations.s3 import make_loc_files, clean_loc_files, \ make_s3_files, s3_cleanup, create_bucket, list_contents, list_buckets, \ - S3HandlerBaseTest + S3HandlerBaseTest, MockStdIn class S3HandlerTestDeleteList(S3HandlerBaseTest): @@ -140,7 +144,7 @@ def setUp(self): self.session = FakeSession() self.service = self.session.get_service('s3') self.endpoint = self.service.get_endpoint('us-east-1') - params = {'region': 'us-east-1', 'acl': ['private']} + params = {'region': 'us-east-1', 'acl': ['private'], 'quiet': True} self.s3_handler = S3Handler(self.session, params) self.s3_handler_multi = S3Handler(self.session, multi_threshold=10, chunksize=2, @@ -275,7 +279,7 @@ def setUp(self): self.session = FakeSession(True, True) self.service = self.session.get_service('s3') self.endpoint = self.service.get_endpoint('us-east-1') - params = {'region': 'us-east-1'} + params = {'region': 'us-east-1', 'quiet': True} self.s3_handler_multi = S3Handler(self.session, params, multi_threshold=10, chunksize=2) self.bucket = create_bucket(self.session) @@ -612,5 +616,167 @@ def test_bucket(self): self.assertEqual(orig_number_buckets, number_buckets) +class TestStreams(S3HandlerBaseTest): + def setUp(self): + super(TestStreams, self).setUp() + self.session = FakeSession() + self.service = self.session.get_service('s3') + self.endpoint = self.service.get_endpoint('us-east-1') + self.params = {'is_stream': True, 'region': 'us-east-1'} + + def test_pull_from_stream(self): + s3handler = S3StreamHandler(self.session, self.params, chunksize=2) + input_to_stdin = b'This is a test' + size = len(input_to_stdin) + # Retrieve the entire string. + with MockStdIn(input_to_stdin): + payload, is_amount_requested = s3handler._pull_from_stream(size) + data = payload.read() + self.assertTrue(is_amount_requested) + self.assertEqual(data, input_to_stdin) + # Ensure the function exits when there is nothing to read. + with MockStdIn(): + payload, is_amount_requested = s3handler._pull_from_stream(size) + data = payload.read() + self.assertFalse(is_amount_requested) + self.assertEqual(data, b'') + # Ensure the function does not grab too much out of stdin. + with MockStdIn(input_to_stdin): + payload, is_amount_requested = s3handler._pull_from_stream(size-2) + data = payload.read() + self.assertTrue(is_amount_requested) + self.assertEqual(data, input_to_stdin[:-2]) + # Retrieve the rest of standard in. + payload, is_amount_requested = s3handler._pull_from_stream(size) + data = payload.read() + self.assertFalse(is_amount_requested) + self.assertEqual(data, input_to_stdin[-2:]) + + def test_upload_stream_not_multipart_task(self): + s3handler = S3StreamHandler(self.session, self.params) + s3handler.executor = mock.Mock() + fileinfos = [FileInfo('filename', operation_name='upload', + is_stream=True, size=0)] + with MockStdIn(b'bar'): + s3handler._enqueue_tasks(fileinfos) + submitted_tasks = s3handler.executor.submit.call_args_list + # No multipart upload should have been submitted. + self.assertEqual(len(submitted_tasks), 1) + self.assertEqual(submitted_tasks[0][0][0].payload.read(), + b'bar') + + def test_upload_stream_is_multipart_task(self): + s3handler = S3StreamHandler(self.session, self.params, + multi_threshold=1) + s3handler.executor = mock.Mock() + fileinfos = [FileInfo('filename', operation_name='upload', + is_stream=True, size=0)] + with MockStdIn(b'bar'): + s3handler._enqueue_tasks(fileinfos) + submitted_tasks = s3handler.executor.submit.call_args_list + # This should be a multipart upload so multiple tasks + # should have been submitted. + self.assertEqual(len(submitted_tasks), 4) + self.assertEqual(submitted_tasks[1][0][0]._payload.read(), + b'b') + self.assertEqual(submitted_tasks[2][0][0]._payload.read(), + b'ar') + + def test_upload_stream_with_expected_size(self): + self.params['expected_size'] = 100000 + # With this large of expected size, the chunksize of 2 will have + # to change. + s3handler = S3StreamHandler(self.session, self.params, chunksize=2) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='upload', + is_stream=True) + with MockStdIn(b'bar'): + s3handler._enqueue_multipart_upload_tasks(fileinfo, b'') + submitted_tasks = s3handler.executor.submit.call_args_list + # Determine what the chunksize was changed to from one of the + # UploadPartTasks. + changed_chunk_size = submitted_tasks[1][0][0]._chunk_size + # New chunksize should have a total parts under 1000. + self.assertTrue(100000/changed_chunk_size < 1000) + + def test_upload_stream_enqueue_upload_task(self): + s3handler = S3StreamHandler(self.session, self.params) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='upload', + is_stream=True) + stdin_input = b'This is a test' + with MockStdIn(stdin_input): + num_parts = s3handler._enqueue_upload_tasks(None, 2, mock.Mock(), + fileinfo, + UploadPartTask) + submitted_tasks = s3handler.executor.submit.call_args_list + # Ensure the returned number of parts is correct. + self.assertEqual(num_parts, len(submitted_tasks) + 1) + # Ensure the number of tasks uploaded are as expected + self.assertEqual(len(submitted_tasks), 8) + index = 0 + for i in range(len(submitted_tasks)-1): + self.assertEqual(submitted_tasks[i][0][0]._payload.read(), + stdin_input[index:index+2]) + index += 2 + # Ensure that the last part is an empty string as expected. + self.assertEqual(submitted_tasks[7][0][0]._payload.read(), b'') + + def test_enqueue_upload_single_part_task_stream(self): + """ + This test ensures that a payload gets attached to a task when + it is submitted to the executor. + """ + s3handler = S3StreamHandler(self.session, self.params) + s3handler.executor = mock.Mock() + mock_task_class = mock.Mock() + s3handler._enqueue_upload_single_part_task( + part_number=1, chunk_size=2, upload_context=None, + filename=None, task_class=mock_task_class, + payload=b'This is a test' + ) + args, kwargs = mock_task_class.call_args + self.assertIn('payload', kwargs.keys()) + self.assertEqual(kwargs['payload'], b'This is a test') + + def test_enqueue_multipart_download_stream(self): + """ + This test ensures the right calls are made in ``_enqueue_tasks()`` + if the file should be a multipart download. + """ + s3handler = S3StreamHandler(self.session, self.params, + multi_threshold=5) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='download', + is_stream=True) + with mock.patch('awscli.customizations.s3.s3handler' + '.S3StreamHandler._enqueue_range_download_tasks') as \ + mock_enqueue_range_tasks: + with mock.patch('awscli.customizations.s3.fileinfo.FileInfo' + '.set_size_from_s3') as mock_set_size_from_s3: + # Set the file size to something larger than the multipart + # threshold. + fileinfo.size = 100 + # Run the main enqueue function. + s3handler._enqueue_tasks([fileinfo]) + # Assert that the size of the ``FileInfo`` object was set + # if we are downloading a stream. + self.assertTrue(mock_set_size_from_s3.called) + # Ensure that this download would have been a multipart + # download. + self.assertTrue(mock_enqueue_range_tasks.called) + + def test_enqueue_range_download_tasks_stream(self): + s3handler = S3StreamHandler(self.session, self.params, chunksize=100) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='download', + is_stream=True, size=100) + s3handler._enqueue_range_download_tasks(fileinfo) + # Ensure that no request was sent to make a file locally. + submitted_tasks = s3handler.executor.submit.call_args_list + self.assertNotEqual(type(submitted_tasks[0][0][0]), + CreateLocalFileTask) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/customizations/s3/test_subcommands.py b/tests/unit/customizations/s3/test_subcommands.py index 363fc496889f..fab4076b2752 100644 --- a/tests/unit/customizations/s3/test_subcommands.py +++ b/tests/unit/customizations/s3/test_subcommands.py @@ -114,12 +114,17 @@ def setUp(self): self.bucket = make_s3_files(self.session) self.loc_files = make_loc_files() self.output = StringIO() + self.err_output = StringIO() self.saved_stdout = sys.stdout + self.saved_stderr = sys.stderr sys.stdout = self.output + sys.stderr = self.err_output def tearDown(self): self.output.close() + self.err_output.close() sys.stdout = self.saved_stdout + sys.stderr = self.saved_stderr super(CommandArchitectureTest, self).tearDown() clean_loc_files(self.loc_files) @@ -169,12 +174,13 @@ def test_create_instructions(self): 'rb': ['s3_handler']} params = {'filters': True, 'region': 'us-east-1', 'endpoint_url': None, - 'verify_ssl': None} + 'verify_ssl': None, 'is_stream': False} for cmd in cmds: cmd_arc = CommandArchitecture(self.session, cmd, {'region': 'us-east-1', 'endpoint_url': None, - 'verify_ssl': None}) + 'verify_ssl': None, + 'is_stream': False}) cmd_arc.create_instructions() self.assertEqual(cmd_arc.instructions, instructions[cmd]) @@ -197,7 +203,8 @@ def test_run_cp_put(self): 'src': local_file, 'dest': s3_file, 'filters': filters, 'paths_type': 'locals3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'cp', params) cmd_arc.create_instructions() cmd_arc.run() @@ -213,7 +220,8 @@ def test_error_on_same_line_as_status(self): 'src': local_file, 'dest': s3_file, 'filters': filters, 'paths_type': 'locals3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'cp', params) cmd_arc.create_instructions() cmd_arc.run() @@ -222,7 +230,7 @@ def test_error_on_same_line_as_status(self): output_str = ( "upload failed: %s to %s Error: Bucket does not exist\n" % ( rel_local_file, s3_file)) - self.assertIn(output_str, self.output.getvalue()) + self.assertIn(output_str, self.err_output.getvalue()) def test_run_cp_get(self): # This ensures that the architecture sets up correctly for a ``cp`` get @@ -236,7 +244,8 @@ def test_run_cp_get(self): 'src': s3_file, 'dest': local_file, 'filters': filters, 'paths_type': 's3local', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'cp', params) cmd_arc.create_instructions() cmd_arc.run() @@ -253,7 +262,8 @@ def test_run_cp_copy(self): 'src': s3_file, 'dest': s3_file, 'filters': filters, 'paths_type': 's3s3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'cp', params) cmd_arc.create_instructions() cmd_arc.run() @@ -270,7 +280,8 @@ def test_run_mv(self): 'src': s3_file, 'dest': s3_file, 'filters': filters, 'paths_type': 's3s3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'mv', params) cmd_arc.create_instructions() cmd_arc.run() @@ -287,7 +298,8 @@ def test_run_remove(self): 'src': s3_file, 'dest': s3_file, 'filters': filters, 'paths_type': 's3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'rm', params) cmd_arc.create_instructions() cmd_arc.run() @@ -308,7 +320,8 @@ def test_run_sync(self): 'src': local_dir, 'dest': s3_prefix, 'filters': filters, 'paths_type': 'locals3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'sync', params) cmd_arc.create_instructions() cmd_arc.run() @@ -324,7 +337,7 @@ def test_run_mb(self): 'src': s3_prefix, 'dest': s3_prefix, 'paths_type': 's3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, 'follow_symlinks': True, - 'page_size': None} + 'page_size': None, 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'mb', params) cmd_arc.create_instructions() cmd_arc.run() @@ -340,7 +353,7 @@ def test_run_rb(self): 'src': s3_prefix, 'dest': s3_prefix, 'paths_type': 's3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, 'follow_symlinks': True, - 'page_size': None} + 'page_size': None, 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'rb', params) cmd_arc.create_instructions() rc = cmd_arc.run() @@ -357,12 +370,12 @@ def test_run_rb_nonzero_rc(self): 'src': s3_prefix, 'dest': s3_prefix, 'paths_type': 's3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, 'follow_symlinks': True, - 'page_size': None} + 'page_size': None, 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'rb', params) cmd_arc.create_instructions() rc = cmd_arc.run() output_str = "remove_bucket failed: %s" % s3_prefix - self.assertIn(output_str, self.output.getvalue()) + self.assertIn(output_str, self.err_output.getvalue()) self.assertEqual(rc, 1) @@ -468,6 +481,34 @@ def test_check_force(self): cmd_params.parameters['src'] = 's3://mybucket' cmd_params.check_force(None) + def test_validate_streaming_paths_upload(self): + parameters = {'src': '-', 'dest': 's3://bucket'} + cmd_params = CommandParameters(self.session, 'cp', parameters, '') + cmd_params._validate_streaming_paths() + self.assertTrue(cmd_params.parameters['is_stream']) + self.assertTrue(cmd_params.parameters['only_show_errors']) + self.assertFalse(cmd_params.parameters['dir_op']) + + def test_validate_streaming_paths_download(self): + parameters = {'src': 'localfile', 'dest': '-'} + cmd_params = CommandParameters(self.session, 'cp', parameters, '') + cmd_params._validate_streaming_paths() + self.assertTrue(cmd_params.parameters['is_stream']) + self.assertTrue(cmd_params.parameters['only_show_errors']) + self.assertFalse(cmd_params.parameters['dir_op']) + + def test_validate_no_streaming_paths(self): + parameters = {'src': 'localfile', 'dest': 's3://bucket'} + cmd_params = CommandParameters(self.session, 'cp', parameters, '') + cmd_params._validate_streaming_paths() + self.assertFalse(cmd_params.parameters['is_stream']) + + def test_validate_streaming_paths_error(self): + parameters = {'src': '-', 'dest': 's3://bucket'} + cmd_params = CommandParameters(self.session, 'sync', parameters, '') + with self.assertRaises(ValueError): + cmd_params._validate_streaming_paths() + class HelpDocTest(BaseAWSHelpOutputTest): def setUp(self): diff --git a/tests/unit/customizations/s3/test_tasks.py b/tests/unit/customizations/s3/test_tasks.py index 4451c85cb569..eda16f765778 100644 --- a/tests/unit/customizations/s3/test_tasks.py +++ b/tests/unit/customizations/s3/test_tasks.py @@ -22,6 +22,7 @@ from awscli.customizations.s3.tasks import CompleteDownloadTask from awscli.customizations.s3.tasks import DownloadPartTask from awscli.customizations.s3.tasks import MultipartUploadContext +from awscli.customizations.s3.tasks import MultipartDownloadContext from awscli.customizations.s3.tasks import UploadCancelledError from awscli.customizations.s3.tasks import print_operation from awscli.customizations.s3.tasks import RetriesExeededError @@ -163,6 +164,58 @@ def test_basic_threaded_parts(self): self.calls[2][1:], ('my_upload_id', [{'ETag': 'etag1', 'PartNumber': 1}])) + def test_streaming_threaded_parts(self): + # This is similar to the basic threaded parts test but instead + # the thread has to wait to know exactly how many parts are + # expected from the stream. This is indicated when the expected + # parts of the context changes from ... to an integer. + + self.context = MultipartUploadContext(expected_parts='...') + upload_part_thread = threading.Thread(target=self.upload_part, + args=(1,)) + # Once this thread starts it will immediately block. + self.start_thread(upload_part_thread) + + # Also, let's start the thread that will do the complete + # multipart upload. It will also block because it needs all + # the parts so it's blocked up the upload_part_thread. It also + # needs the upload_id so it's blocked on that as well. + complete_upload_thread = threading.Thread(target=self.complete_upload) + self.start_thread(complete_upload_thread) + + # Then finally the CreateMultipartUpload completes and we + # announce the upload id. + self.create_upload('my_upload_id') + # The complete upload thread should still be waiting for an expect + # parts number. + with self.call_lock: + was_completed = (len(self.calls) > 2) + + # The upload_part thread can now proceed as well as the complete + # multipart upload thread. + self.context.announce_total_parts(1) + self.join_threads() + + self.assertIsNone(self.caught_exception) + + # Make sure that the completed task was never called since it was + # waiting to announce the parts. + self.assertFalse(was_completed) + + # We can verify that the invariants still hold. + self.assertEqual(len(self.calls), 3) + # First there should be three calls, create, upload, complete. + self.assertEqual(self.calls[0][0], 'create_multipart_upload') + self.assertEqual(self.calls[1][0], 'upload_part') + self.assertEqual(self.calls[2][0], 'complete_upload') + + # Verify the correct args were used. + self.assertEqual(self.calls[0][1], 'my_upload_id') + self.assertEqual(self.calls[1][1:], (1, 'my_upload_id')) + self.assertEqual( + self.calls[2][1:], + ('my_upload_id', [{'ETag': 'etag1', 'PartNumber': 1}])) + def test_randomized_stress_test(self): # Now given that we've verified the functionality from # the two tests above, we randomize the threading to ensure @@ -279,6 +332,7 @@ def setUp(self): self.filename.size = 10 * 1024 * 1024 self.filename.src = 'bucket/key' self.filename.dest = 'local/file' + self.filename.is_stream = False self.filename.service = self.service self.filename.operation_name = 'download' self.context = mock.Mock() @@ -325,9 +379,9 @@ def test_download_queues_io_properly(self): call_args_list = self.io_queue.put.call_args_list self.assertEqual(len(call_args_list), 2) self.assertEqual(call_args_list[0], - mock.call(('local/file', 0, b'foobar'))) + mock.call(('local/file', 0, b'foobar', False))) self.assertEqual(call_args_list[1], - mock.call(('local/file', 6, b'morefoobar'))) + mock.call(('local/file', 6, b'morefoobar', False))) def test_incomplete_read_is_retried(self): self.service.get_operation.return_value.call.side_effect = \ @@ -342,6 +396,61 @@ def test_incomplete_read_is_retried(self): self.service.get_operation.call_count) +class TestMultipartDownloadContext(unittest.TestCase): + def setUp(self): + self.context = MultipartDownloadContext(num_parts=2) + self.calls = [] + self.threads = [] + self.call_lock = threading.Lock() + self.caught_exception = None + + def tearDown(self): + self.join_threads() + + def join_threads(self): + for thread in self.threads: + thread.join() + + def download_stream_part(self, part_number): + try: + self.context.wait_for_turn(part_number) + with self.call_lock: + self.calls.append(('download_part', str(part_number))) + self.context.done_with_turn() + except Exception as e: + self.caught_exception = e + return + + def start_thread(self, thread): + thread.start() + self.threads.append(thread) + + def test_stream_context(self): + part_thread = threading.Thread(target=self.download_stream_part, + args=(1,)) + # Once this thread starts it will immediately block becasue it is + # waiting for part zero to finish submitting its task. + self.start_thread(part_thread) + + # Now create the thread that should submit its task first. + part_thread2 = threading.Thread(target=self.download_stream_part, + args=(0,)) + self.start_thread(part_thread2) + self.join_threads() + + self.assertIsNone(self.caught_exception) + + # We can verify that the invariants still hold. + self.assertEqual(len(self.calls), 2) + # First there should be three calls, create, upload, complete. + self.assertEqual(self.calls[0][0], 'download_part') + self.assertEqual(self.calls[1][0], 'download_part') + + # Verify the correct order were used. + self.assertEqual(self.calls[0][1], '0') + self.assertEqual(self.calls[1][1], '1') + + class TestTaskOrdering(unittest.TestCase): def setUp(self): self.q = StablePriorityQueue(maxsize=10, max_priority=20) diff --git a/tests/unit/customizations/s3/test_website_command.py b/tests/unit/customizations/s3/test_website_command.py index dd81b17bacaf..edf7a94d2a5b 100644 --- a/tests/unit/customizations/s3/test_website_command.py +++ b/tests/unit/customizations/s3/test_website_command.py @@ -21,23 +21,18 @@ class TestWebsiteCommand(BaseAWSCommandParamsTest): def test_index_document(self): cmdline = self.prefix + 's3://mybucket --index-document index.html' - result = {'uri_params': {'Bucket': 'mybucket'}, - 'headers': {},} - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) - self.assertEqual( - self.last_kwargs, - {'website_configuration': {'IndexDocument': {'Suffix': 'index.html'}}, - 'bucket': u'mybucket'}) + result = { + 'website_configuration': + {'IndexDocument': {'Suffix': 'index.html'}}, 'bucket': u'mybucket'} + + self.assert_params_for_cmd2(cmdline, result) def test_error_document(self): cmdline = self.prefix + 's3://mybucket --error-document mykey' - result = {'uri_params': {'Bucket': 'mybucket'}, - 'headers': {},} - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) - self.assertEqual( - self.last_kwargs, - {'website_configuration': {'ErrorDocument': {'Key': 'mykey'}}, - 'bucket': u'mybucket'}) + result = { + 'website_configuration': { + 'ErrorDocument': {'Key': 'mykey'}}, 'bucket': u'mybucket'} + self.assert_params_for_cmd2(cmdline, result) if __name__ == "__main__": diff --git a/tests/unit/customizations/test_cloudsearchdomain.py b/tests/unit/customizations/test_cloudsearchdomain.py index 211ccd404e2a..e545af32fd41 100644 --- a/tests/unit/customizations/test_cloudsearchdomain.py +++ b/tests/unit/customizations/test_cloudsearchdomain.py @@ -31,17 +31,11 @@ def test_search_with_query(self): '--query-options', '{"defaultOperator":"and","fields":["directors^10"]}'] - result = { - 'headers': {}, - 'uri_params': { - 'query': 'George Lucas', - 'queryOptions': - '{"defaultOperator":"and","fields":["directors^10"]}'}} - - self.assert_params_for_cmd(cmd, result, ignore_params=['payload']) - # We ignore'd the paylod, but we can verify that there should be - # no payload for this request. - self.assertIsNone(self.last_params['payload'].getvalue()) + expected = { + 'query': u'George Lucas', + 'queryOptions': u'{"defaultOperator":"and","fields":["directors^10"]}' + } + self.assert_params_for_cmd2(cmd, expected) def test_endpoint_is_required(self): cmd = self.prefix.split() diff --git a/tests/unit/customizations/test_cloudwatch.py b/tests/unit/customizations/test_cloudwatch.py index eadcecc7676c..c1975c38918f 100644 --- a/tests/unit/customizations/test_cloudwatch.py +++ b/tests/unit/customizations/test_cloudwatch.py @@ -10,6 +10,8 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +import decimal + from awscli.testutils import unittest import mock @@ -40,7 +42,7 @@ def test_value_arg(self): parameters = {} arg.add_to_params(parameters, '123.1') self.assertEqual(parameters['metric_data'][0]['Value'], - '123.1') + decimal.Decimal('123.1')) def test_timestamp_arg(self): arg = putmetricdata.PutMetricArgument('timestamp', diff --git a/tests/unit/customizations/test_configure.py b/tests/unit/customizations/test_configure.py index 70241cba8fe0..12511139fa5f 100644 --- a/tests/unit/customizations/test_configure.py +++ b/tests/unit/customizations/test_configure.py @@ -70,6 +70,10 @@ def get_scoped_config(self): return self.config def get_config_variable(self, name, methods=None): + if name == 'credentials_file': + # The credentials_file var doesn't require a + # profile to exist. + return 'fake_credentials_file' if self.profile_does_not_exist and not name == 'config_file': raise ProfileNotFound(profile='foo') if methods is not None: @@ -98,12 +102,22 @@ def setUp(self): prompter=self.precanned, config_writer=self.writer) + def assert_credentials_file_updated_with(self, new_values): + called_args = self.writer.update_config.call_args_list + credentials_file_call = called_args[0] + self.assertEqual(credentials_file_call, + mock.call(new_values, 'fake_credentials_file')) + def test_configure_command_sends_values_to_writer(self): self.configure(args=[], parsed_globals=self.global_args) - self.writer.update_config.assert_called_with( + # Credentials are always written to the shared credentials file. + self.assert_credentials_file_updated_with( {'aws_access_key_id': 'new_value', - 'aws_secret_access_key': 'new_value', - 'region': 'new_value', + 'aws_secret_access_key': 'new_value'}) + + # Non-credentials config is written to the config file. + self.writer.update_config.assert_called_with( + {'region': 'new_value', 'output': 'new_value'}, 'myconfigfile') def test_same_values_are_not_changed(self): @@ -154,10 +168,12 @@ def test_section_name_can_be_changed_for_profiles(self): self.global_args.profile = 'myname' self.configure(args=[], parsed_globals=self.global_args) # Note the __section__ key name. + self.assert_credentials_file_updated_with( + {'aws_access_key_id': 'new_value', + 'aws_secret_access_key': 'new_value', + '__section__': 'myname'}) self.writer.update_config.assert_called_with( {'__section__': 'profile myname', - 'aws_access_key_id': 'new_value', - 'aws_secret_access_key': 'new_value', 'region': 'new_value', 'output': 'new_value'}, 'myconfigfile') @@ -173,18 +189,26 @@ def test_session_says_profile_does_not_exist(self): config_writer=self.writer) self.global_args.profile = 'profile-does-not-exist' self.configure(args=[], parsed_globals=self.global_args) + self.assert_credentials_file_updated_with( + {'aws_access_key_id': 'new_value', + 'aws_secret_access_key': 'new_value', + '__section__': 'profile-does-not-exist'}) self.writer.update_config.assert_called_with( {'__section__': 'profile profile-does-not-exist', - 'aws_access_key_id': 'new_value', - 'aws_secret_access_key': 'new_value', 'region': 'new_value', 'output': 'new_value'}, 'myconfigfile') class TestInteractivePrompter(unittest.TestCase): - @mock.patch('awscli.customizations.configure.raw_input') - def test_access_key_is_masked(self, mock_raw_input): - mock_raw_input.return_value = 'foo' + def setUp(self): + self.patch = mock.patch('awscli.customizations.configure.raw_input') + self.mock_raw_input = self.patch.start() + + def tearDown(self): + self.patch.stop() + + def test_access_key_is_masked(self): + self.mock_raw_input.return_value = 'foo' prompter = configure.InteractivePrompter() response = prompter.get_value( current_value='myaccesskey', config_name='aws_access_key_id', @@ -192,45 +216,54 @@ def test_access_key_is_masked(self, mock_raw_input): # First we should return the value from raw_input. self.assertEqual(response, 'foo') # We should also not display the entire access key. - prompt_text = mock_raw_input.call_args[0][0] + prompt_text = self.mock_raw_input.call_args[0][0] self.assertNotIn('myaccesskey', prompt_text) self.assertRegexpMatches(prompt_text, r'\[\*\*\*\*.*\]') - @mock.patch('awscli.customizations.configure.raw_input') - def test_access_key_not_masked_when_none(self, mock_raw_input): - mock_raw_input.return_value = 'foo' + def test_access_key_not_masked_when_none(self): + self.mock_raw_input.return_value = 'foo' prompter = configure.InteractivePrompter() response = prompter.get_value( current_value=None, config_name='aws_access_key_id', prompt_text='Access key') # First we should return the value from raw_input. self.assertEqual(response, 'foo') - prompt_text = mock_raw_input.call_args[0][0] + prompt_text = self.mock_raw_input.call_args[0][0] self.assertIn('[None]', prompt_text) - @mock.patch('awscli.customizations.configure.raw_input') - def test_secret_key_is_masked(self, mock_raw_input): + def test_secret_key_is_masked(self): prompter = configure.InteractivePrompter() prompter.get_value( current_value='mysupersecretkey', config_name='aws_secret_access_key', prompt_text='Secret Key') # We should also not display the entire secret key. - prompt_text = mock_raw_input.call_args[0][0] + prompt_text = self.mock_raw_input.call_args[0][0] self.assertNotIn('mysupersecretkey', prompt_text) self.assertRegexpMatches(prompt_text, r'\[\*\*\*\*.*\]') - @mock.patch('awscli.customizations.configure.raw_input') - def test_non_secret_keys_are_not_masked(self, mock_raw_input): + def test_non_secret_keys_are_not_masked(self): prompter = configure.InteractivePrompter() prompter.get_value( current_value='mycurrentvalue', config_name='not_a_secret_key', prompt_text='Enter value') # We should also not display the entire secret key. - prompt_text = mock_raw_input.call_args[0][0] + prompt_text = self.mock_raw_input.call_args[0][0] self.assertIn('mycurrentvalue', prompt_text) self.assertRegexpMatches(prompt_text, r'\[mycurrentvalue\]') + def test_user_hits_enter_returns_none(self): + # If a user hits enter, then raw_input returns the empty string. + self.mock_raw_input.return_value = '' + + prompter = configure.InteractivePrompter() + response = prompter.get_value( + current_value=None, config_name='aws_access_key_id', + prompt_text='Access key') + # We convert the empty string to None to indicate that there + # was no input. + self.assertIsNone(response) + class TestConfigFileWriter(unittest.TestCase): def setUp(self): @@ -395,6 +428,22 @@ def test_update_config_with_comments(self): 'foo = newvalue\n' ) + def test_update_config_with_commented_section(self): + original = ( + '#[default]\n' + '[default]\n' + '#foo = 1\n' + 'bar = 1\n' + ) + self.assert_update_config( + original, {'foo': 'newvalue'}, + '#[default]\n' + '[default]\n' + '#foo = 1\n' + 'bar = 1\n' + 'foo = newvalue\n' + ) + def test_spaces_around_key_names(self): original = ( '[default]\n' @@ -720,7 +769,7 @@ def test_configure_set_triple_dotted(self): {'__section__': 'default', 's3': {'signature_version': 's3v4'}}, 'myconfigfile') - def test_configure_set_with_profile(self): + def test_configure_set_with_profile_nested(self): # aws configure set default.s3.signature_version s3v4 set_command = configure.ConfigureSetCommand(self.session, self.config_writer) set_command(args=['profile.foo.s3.signature_version', 's3v4'], @@ -728,3 +777,49 @@ def test_configure_set_with_profile(self): self.config_writer.update_config.assert_called_with( {'__section__': 'profile foo', 's3': {'signature_version': 's3v4'}}, 'myconfigfile') + + def test_access_key_written_to_shared_credentials_file(self): + set_command = configure.ConfigureSetCommand(self.session, self.config_writer) + set_command(args=['aws_access_key_id', 'foo'], + parsed_globals=None) + self.config_writer.update_config.assert_called_with( + {'__section__': 'default', + 'aws_access_key_id': 'foo'}, 'fake_credentials_file') + + def test_secret_key_written_to_shared_credentials_file(self): + set_command = configure.ConfigureSetCommand(self.session, self.config_writer) + set_command(args=['aws_secret_access_key', 'foo'], + parsed_globals=None) + self.config_writer.update_config.assert_called_with( + {'__section__': 'default', + 'aws_secret_access_key': 'foo'}, 'fake_credentials_file') + + def test_session_token_written_to_shared_credentials_file(self): + set_command = configure.ConfigureSetCommand(self.session, self.config_writer) + set_command(args=['aws_session_token', 'foo'], + parsed_globals=None) + self.config_writer.update_config.assert_called_with( + {'__section__': 'default', + 'aws_session_token': 'foo'}, 'fake_credentials_file') + + def test_access_key_written_to_shared_credentials_file_profile(self): + set_command = configure.ConfigureSetCommand(self.session, self.config_writer) + set_command(args=['profile.foo.aws_access_key_id', 'bar'], + parsed_globals=None) + self.config_writer.update_config.assert_called_with( + {'__section__': 'foo', + 'aws_access_key_id': 'bar'}, 'fake_credentials_file') + +class TestConfigValueMasking(unittest.TestCase): + def test_config_value_is_masked(self): + config_value = configure.ConfigValue( + 'fake_access_key', 'config_file', 'aws_access_key_id') + self.assertEqual(config_value.value, 'fake_access_key') + config_value.mask_value() + self.assertEqual(config_value.value, '****************_key') + + def test_dont_mask_unset_value(self): + no_config = configure.ConfigValue(configure.NOT_SET, None, None) + self.assertEqual(no_config.value, configure.NOT_SET) + no_config.mask_value() + self.assertEqual(no_config.value, configure.NOT_SET) diff --git a/tests/unit/customizations/test_flatten.py b/tests/unit/customizations/test_flatten.py index 71a6f38643d4..a63797fd7d50 100644 --- a/tests/unit/customizations/test_flatten.py +++ b/tests/unit/customizations/test_flatten.py @@ -19,7 +19,6 @@ from awscli.customizations.flatten import FlattenedArgument, FlattenArguments from botocore.operation import Operation -from botocore.parameters import Parameter def _hydrate(params, container, cli_type, key, value): @@ -136,41 +135,50 @@ def test_flatten_modify_args(self): operation = mock.Mock(spec=Operation) operation.cli_name = 'command-name' - argument_object1 = mock.Mock(spec=Parameter) + argument_model1 = mock.Mock() + argument_model1.required_members = [] member_foo = mock.Mock() member_foo.name = 'ArgumentFoo' member_foo.documentation = 'Original docs' - member_foo.required = False + member_foo.required_members = [] member_bar = mock.Mock() member_bar.name = 'ArgumentBar' member_bar.documentation = 'More docs' - member_bar.required = False + member_bar.required_members = [] - argument_object1.members = [member_foo, member_bar] + argument_model1.members = { + 'ArgumentFoo': member_foo, + 'ArgumentBar': member_bar + } - argument_object2 = mock.Mock(spec=Parameter) + argument_model2 = mock.Mock() + argument_model2.required_members = [] member_baz = mock.Mock() member_baz.name = 'ArgumentBaz' member_baz.documentation = '' - member_baz.required = False + member_baz.required_members = [] member_some_value = mock.Mock() member_some_value.name = 'SomeValue' member_some_value.documenation = '' - member_some_value.require = False + member_some_value.required_members = [] - member_baz.members = [member_some_value] + member_baz.members = { + 'SomeValue': member_some_value + } - argument_object2.members = [member_baz] + argument_model2.members = { + 'ArgumentBaz': member_baz + } cli_argument1 = mock.Mock(spec=CLIArgument) - cli_argument1.argument_object = argument_object1 + cli_argument1.argument_model = argument_model1 cli_argument2 = mock.Mock(spec=CLIArgument) - cli_argument2.argument_object = argument_object2 + cli_argument2.argument_model = argument_model2 argument_table = { 'original-argument': cli_argument1, diff --git a/tests/unit/customizations/test_paginate.py b/tests/unit/customizations/test_paginate.py index 9bc44e2b6d3c..b84f8ae2ed8d 100644 --- a/tests/unit/customizations/test_paginate.py +++ b/tests/unit/customizations/test_paginate.py @@ -23,19 +23,18 @@ def setUp(self): self.operation = mock.Mock() self.operation.can_paginate = True self.foo_param = mock.Mock() - self.foo_param.cli_name = 'foo' self.foo_param.name = 'Foo' - self.foo_param.type = 'string' + self.foo_param.type_name = 'string' self.bar_param = mock.Mock() - self.bar_param.cli_name = 'bar' - self.bar_param.type = 'string' + self.bar_param.type_name = 'string' self.bar_param.name = 'Bar' self.params = [self.foo_param, self.bar_param] self.operation.pagination = { 'input_token': 'Foo', 'limit_key': 'Bar', } - self.operation.params = self.params + self.operation.model.input_shape.members = {"Foo": self.foo_param, + "Bar": self.bar_param} class TestArgumentTableModifications(TestPaginateBase): @@ -77,7 +76,7 @@ class TestStringLimitKey(TestPaginateBase): def setUp(self): super(TestStringLimitKey, self).setUp() - self.bar_param.type = 'string' + self.bar_param.type_name = 'string' def test_integer_limit_key(self): argument_table = { @@ -94,7 +93,7 @@ class TestIntegerLimitKey(TestPaginateBase): def setUp(self): super(TestIntegerLimitKey, self).setUp() - self.bar_param.type = 'integer' + self.bar_param.type_name = 'integer' def test_integer_limit_key(self): argument_table = { @@ -111,7 +110,7 @@ class TestBadLimitKey(TestPaginateBase): def setUp(self): super(TestBadLimitKey, self).setUp() - self.bar_param.type = 'bad' + self.bar_param.type_name = 'bad' def test_integer_limit_key(self): argument_table = { diff --git a/tests/unit/docs/test_help_output.py b/tests/unit/docs/test_help_output.py index ab5d2569d66f..9bbfc311cdb7 100644 --- a/tests/unit/docs/test_help_output.py +++ b/tests/unit/docs/test_help_output.py @@ -300,7 +300,8 @@ def test_enum_docs_arent_duplicated(self): contents = self.renderer.rendered_contents self.assertTrue(contents.count("CREATE_IN_PROGRESS") == 1, ("Enum param was only suppose to be appear once in " - "rendered doc output.")) + "rendered doc output, appeared: %s" % + contents.count("CREATE_IN_PROGRESS"))) class TestParametersCanBeHidden(BaseAWSHelpOutputTest): diff --git a/tests/unit/ec2/test_create_network_acl_entry.py b/tests/unit/ec2/test_create_network_acl_entry.py index d1ce14944bde..7f173124bb49 100644 --- a/tests/unit/ec2/test_create_network_acl_entry.py +++ b/tests/unit/ec2/test_create_network_acl_entry.py @@ -28,13 +28,13 @@ def test_tcp(self): cmdline += ' --port-range From=22,To=22' cmdline += ' --cidr-block 0.0.0.0/0' result = {'NetworkAclId': 'acl-12345678', - 'RuleNumber': '100', + 'RuleNumber': 100, 'Protocol': '6', 'RuleAction': 'allow', 'Egress': 'false', 'CidrBlock': '0.0.0.0/0', - 'PortRange.From': '22', - 'PortRange.To': '22' + 'PortRange.From': 22, + 'PortRange.To': 22 } self.assert_params_for_cmd(cmdline, result) @@ -48,13 +48,13 @@ def test_udp(self): cmdline += ' --port-range From=22,To=22' cmdline += ' --cidr-block 0.0.0.0/0' result = {'NetworkAclId': 'acl-12345678', - 'RuleNumber': '100', + 'RuleNumber': 100, 'Protocol': '17', 'RuleAction': 'allow', 'Egress': 'false', 'CidrBlock': '0.0.0.0/0', - 'PortRange.From': '22', - 'PortRange.To': '22' + 'PortRange.From': 22, + 'PortRange.To': 22 } self.assert_params_for_cmd(cmdline, result) @@ -68,16 +68,16 @@ def test_icmp(self): cmdline += ' --port-range From=22,To=22' cmdline += ' --cidr-block 0.0.0.0/0' result = {'NetworkAclId': 'acl-12345678', - 'RuleNumber': '100', + 'RuleNumber': 100, 'Protocol': '1', 'RuleAction': 'allow', 'Egress': 'false', 'CidrBlock': '0.0.0.0/0', - 'PortRange.From': '22', - 'PortRange.To': '22' + 'PortRange.From': 22, + 'PortRange.To': 22 } self.assert_params_for_cmd(cmdline, result) - + def test_all(self): cmdline = self.prefix cmdline += ' --network-acl-id acl-12345678' @@ -88,16 +88,16 @@ def test_all(self): cmdline += ' --port-range From=22,To=22' cmdline += ' --cidr-block 0.0.0.0/0' result = {'NetworkAclId': 'acl-12345678', - 'RuleNumber': '100', + 'RuleNumber': 100, 'Protocol': '-1', 'RuleAction': 'allow', 'Egress': 'false', 'CidrBlock': '0.0.0.0/0', - 'PortRange.From': '22', - 'PortRange.To': '22' + 'PortRange.From': 22, + 'PortRange.To': 22 } self.assert_params_for_cmd(cmdline, result) - + def test_number(self): cmdline = self.prefix cmdline += ' --network-acl-id acl-12345678' @@ -108,13 +108,13 @@ def test_number(self): cmdline += ' --port-range From=22,To=22' cmdline += ' --cidr-block 0.0.0.0/0' result = {'NetworkAclId': 'acl-12345678', - 'RuleNumber': '100', + 'RuleNumber': 100, 'Protocol': '99', 'RuleAction': 'allow', 'Egress': 'false', 'CidrBlock': '0.0.0.0/0', - 'PortRange.From': '22', - 'PortRange.To': '22' + 'PortRange.From': 22, + 'PortRange.To': 22 } self.assert_params_for_cmd(cmdline, result) - + diff --git a/tests/unit/ec2/test_describe_instances.py b/tests/unit/ec2/test_describe_instances.py index 91ac4ef86852..e8c67254402d 100644 --- a/tests/unit/ec2/test_describe_instances.py +++ b/tests/unit/ec2/test_describe_instances.py @@ -90,7 +90,7 @@ def test_multiple_filters_alternate(self): def test_page_size(self): args = ' --page-size 10' cmdline = self.prefix + args - result = {'MaxResults': '10'} + result = {'MaxResults': 10} self.assert_params_for_cmd(cmdline, result) diff --git a/tests/unit/ec2/test_replace_network_acl_entry.py b/tests/unit/ec2/test_replace_network_acl_entry.py index 4f4ad20d298a..726787902ec7 100644 --- a/tests/unit/ec2/test_replace_network_acl_entry.py +++ b/tests/unit/ec2/test_replace_network_acl_entry.py @@ -28,13 +28,13 @@ def test_tcp(self): cmdline += ' --port-range From=22,To=22' cmdline += ' --cidr-block 0.0.0.0/0' result = {'NetworkAclId': 'acl-12345678', - 'RuleNumber': '100', + 'RuleNumber': 100, 'Protocol': '6', 'RuleAction': 'allow', 'Egress': 'false', 'CidrBlock': '0.0.0.0/0', - 'PortRange.From': '22', - 'PortRange.To': '22' + 'PortRange.From': 22, + 'PortRange.To': 22 } self.assert_params_for_cmd(cmdline, result) @@ -48,13 +48,13 @@ def test_udp(self): cmdline += ' --port-range From=22,To=22' cmdline += ' --cidr-block 0.0.0.0/0' result = {'NetworkAclId': 'acl-12345678', - 'RuleNumber': '100', + 'RuleNumber': 100, 'Protocol': '17', 'RuleAction': 'allow', 'Egress': 'false', 'CidrBlock': '0.0.0.0/0', - 'PortRange.From': '22', - 'PortRange.To': '22' + 'PortRange.From': 22, + 'PortRange.To': 22 } self.assert_params_for_cmd(cmdline, result) @@ -68,16 +68,16 @@ def test_icmp(self): cmdline += ' --port-range From=22,To=22' cmdline += ' --cidr-block 0.0.0.0/0' result = {'NetworkAclId': 'acl-12345678', - 'RuleNumber': '100', + 'RuleNumber': 100, 'Protocol': '1', 'RuleAction': 'allow', 'Egress': 'false', 'CidrBlock': '0.0.0.0/0', - 'PortRange.From': '22', - 'PortRange.To': '22' + 'PortRange.From': 22, + 'PortRange.To': 22 } self.assert_params_for_cmd(cmdline, result) - + def test_all(self): cmdline = self.prefix cmdline += ' --network-acl-id acl-12345678' @@ -88,16 +88,16 @@ def test_all(self): cmdline += ' --port-range From=22,To=22' cmdline += ' --cidr-block 0.0.0.0/0' result = {'NetworkAclId': 'acl-12345678', - 'RuleNumber': '100', + 'RuleNumber': 100, 'Protocol': '-1', 'RuleAction': 'allow', 'Egress': 'false', 'CidrBlock': '0.0.0.0/0', - 'PortRange.From': '22', - 'PortRange.To': '22' + 'PortRange.From': 22, + 'PortRange.To': 22 } self.assert_params_for_cmd(cmdline, result) - + def test_number(self): cmdline = self.prefix cmdline += ' --network-acl-id acl-12345678' @@ -108,13 +108,13 @@ def test_number(self): cmdline += ' --port-range From=22,To=22' cmdline += ' --cidr-block 0.0.0.0/0' result = {'NetworkAclId': 'acl-12345678', - 'RuleNumber': '100', + 'RuleNumber': 100, 'Protocol': '99', 'RuleAction': 'allow', 'Egress': 'false', 'CidrBlock': '0.0.0.0/0', - 'PortRange.From': '22', - 'PortRange.To': '22' + 'PortRange.From': 22, + 'PortRange.To': 22 } self.assert_params_for_cmd(cmdline, result) - + diff --git a/tests/unit/ec2/test_run_instances.py b/tests/unit/ec2/test_run_instances.py index fefc8da4e653..5d578c362b97 100644 --- a/tests/unit/ec2/test_run_instances.py +++ b/tests/unit/ec2/test_run_instances.py @@ -26,8 +26,8 @@ def test_no_count(self): args_list = (self.prefix + args).split() result = { 'ImageId': 'ami-foobar', - 'MaxCount': '1', - 'MinCount': '1' + 'MaxCount': 1, + 'MinCount': 1 } self.assert_params_for_cmd(args_list, result) @@ -36,12 +36,13 @@ def test_count_scalar(self): args_list = (self.prefix + args).split() result = { 'ImageId': 'ami-foobar', - 'MaxCount': '2', - 'MinCount': '2' + 'MaxCount': 2, + 'MinCount': 2 } self.assert_params_for_cmd(args_list, result) def test_user_data(self): + return data = u'\u0039' with temporary_file('r+') as tmp: with compat_open(tmp.name, 'w') as f: @@ -51,8 +52,8 @@ def test_user_data(self): self.prefix + ' --image-id foo --user-data file://%s' % f.name) result = {'ImageId': 'foo', - 'MaxCount': '1', - 'MinCount': '1', + 'MaxCount': 1, + 'MinCount': 1, # base64 encoded content of utf-8 encoding of data. 'UserData': 'OQ=='} self.assert_params_for_cmd(args, result) @@ -62,8 +63,8 @@ def test_count_range(self): args_list = (self.prefix + args).split() result = { 'ImageId': 'ami-foobar', - 'MaxCount': '10', - 'MinCount': '5' + 'MaxCount': 10, + 'MinCount': 5 } self.assert_params_for_cmd(args_list, result) @@ -79,10 +80,10 @@ def test_block_device_mapping(self): ' [{"DeviceName":"/dev/sda1","Ebs":{"VolumeSize":20}}]') result = { 'BlockDeviceMapping.1.DeviceName': '/dev/sda1', - 'BlockDeviceMapping.1.Ebs.VolumeSize': '20', + 'BlockDeviceMapping.1.Ebs.VolumeSize': 20, 'ImageId': 'ami-foobar', - 'MaxCount': '1', - 'MinCount': '1' + 'MaxCount': 1, + 'MinCount': 1 } self.assert_params_for_cmd(args_list, result) @@ -91,12 +92,12 @@ def test_secondary_ip_address(self): args += '--secondary-private-ip-addresses 10.0.2.106' args_list = (self.prefix + args).split() result = { - 'NetworkInterface.1.DeviceIndex': '0', + 'NetworkInterface.1.DeviceIndex': 0, 'NetworkInterface.1.PrivateIpAddresses.1.Primary': 'false', 'NetworkInterface.1.PrivateIpAddresses.1.PrivateIpAddress': '10.0.2.106', 'ImageId': 'ami-foobar', - 'MaxCount': '1', - 'MinCount': '1' + 'MaxCount': 1, + 'MinCount': 1 } self.assert_params_for_cmd(args_list, result) @@ -105,27 +106,27 @@ def test_secondary_ip_addresses(self): args += '--secondary-private-ip-addresses 10.0.2.106 10.0.2.107' args_list = (self.prefix + args).split() result = { - 'NetworkInterface.1.DeviceIndex': '0', + 'NetworkInterface.1.DeviceIndex': 0, 'NetworkInterface.1.PrivateIpAddresses.1.Primary': 'false', 'NetworkInterface.1.PrivateIpAddresses.1.PrivateIpAddress': '10.0.2.106', 'NetworkInterface.1.PrivateIpAddresses.2.Primary': 'false', 'NetworkInterface.1.PrivateIpAddresses.2.PrivateIpAddress': '10.0.2.107', 'ImageId': 'ami-foobar', - 'MaxCount': '1', - 'MinCount': '1' + 'MaxCount': 1, + 'MinCount': 1 } self.assert_params_for_cmd(args_list, result) - + def test_secondary_ip_address_count(self): args = ' --image-id ami-foobar --count 1 ' args += '--secondary-private-ip-address-count 4' args_list = (self.prefix + args).split() result = { - 'NetworkInterface.1.DeviceIndex': '0', - 'NetworkInterface.1.SecondaryPrivateIpAddressCount': '4', + 'NetworkInterface.1.DeviceIndex': 0, + 'NetworkInterface.1.SecondaryPrivateIpAddressCount': 4, 'ImageId': 'ami-foobar', - 'MaxCount': '1', - 'MinCount': '1' + 'MaxCount': 1, + 'MinCount': 1 } self.assert_params_for_cmd(args_list, result) @@ -134,12 +135,12 @@ def test_associate_public_ip_address(self): args += '--associate-public-ip-address' args_list = (self.prefix + args).split() result = { - 'NetworkInterface.1.DeviceIndex': '0', + 'NetworkInterface.1.DeviceIndex': 0, 'NetworkInterface.1.AssociatePublicIpAddress': 'true', 'NetworkInterface.1.SubnetId': 'subnet-12345678', 'ImageId': 'ami-foobar', - 'MaxCount': '1', - 'MinCount': '1' + 'MaxCount': 1, + 'MinCount': 1 } self.assert_params_for_cmd(args_list, result) @@ -148,12 +149,12 @@ def test_associate_public_ip_address_switch_order(self): args += '--associate-public-ip-address --subnet-id subnet-12345678' args_list = (self.prefix + args).split() result = { - 'NetworkInterface.1.DeviceIndex': '0', + 'NetworkInterface.1.DeviceIndex': 0, 'NetworkInterface.1.AssociatePublicIpAddress': 'true', 'NetworkInterface.1.SubnetId': 'subnet-12345678', 'ImageId': 'ami-foobar', - 'MaxCount': '1', - 'MinCount': '1' + 'MaxCount': 1, + 'MinCount': 1 } self.assert_params_for_cmd(args_list, result) @@ -162,23 +163,23 @@ def test_no_associate_public_ip_address(self): args += '--no-associate-public-ip-address' args_list = (self.prefix + args).split() result = { - 'NetworkInterface.1.DeviceIndex': '0', + 'NetworkInterface.1.DeviceIndex': 0, 'NetworkInterface.1.AssociatePublicIpAddress': 'false', 'NetworkInterface.1.SubnetId': 'subnet-12345678', 'ImageId': 'ami-foobar', - 'MaxCount': '1', - 'MinCount': '1' + 'MaxCount': 1, + 'MinCount': 1 } self.assert_params_for_cmd(args_list, result) - + def test_subnet_alone(self): args = ' --image-id ami-foobar --count 1 --subnet-id subnet-12345678' args_list = (self.prefix + args).split() result = { 'SubnetId': 'subnet-12345678', 'ImageId': 'ami-foobar', - 'MaxCount': '1', - 'MinCount': '1' + 'MaxCount': 1, + 'MinCount': 1 } self.assert_params_for_cmd(args_list, result) @@ -188,13 +189,13 @@ def test_associate_public_ip_address_and_group_id(self): args += '--associate-public-ip-address --subnet-id subnet-12345678' args_list = (self.prefix + args).split() result = { - 'NetworkInterface.1.DeviceIndex': '0', + 'NetworkInterface.1.DeviceIndex': 0, 'NetworkInterface.1.AssociatePublicIpAddress': 'true', 'NetworkInterface.1.SubnetId': 'subnet-12345678', 'NetworkInterface.1.SecurityGroupId.1': 'sg-12345678', 'ImageId': 'ami-foobar', - 'MaxCount': '1', - 'MinCount': '1' + 'MaxCount': 1, + 'MinCount': 1 } self.assert_params_for_cmd(args_list, result) @@ -205,8 +206,8 @@ def test_group_id_alone(self): result = { 'SecurityGroupId.1': 'sg-12345678', 'ImageId': 'ami-foobar', - 'MaxCount': '1', - 'MinCount': '1' + 'MaxCount': 1, + 'MinCount': 1 } self.assert_params_for_cmd(args_list, result) @@ -216,14 +217,14 @@ def test_associate_public_ip_address_and_private_ip_address(self): args += '--associate-public-ip-address --subnet-id subnet-12345678' args_list = (self.prefix + args).split() result = { - 'NetworkInterface.1.DeviceIndex': '0', + 'NetworkInterface.1.DeviceIndex': 0, 'NetworkInterface.1.AssociatePublicIpAddress': 'true', 'NetworkInterface.1.SubnetId': 'subnet-12345678', 'NetworkInterface.1.PrivateIpAddresses.1.PrivateIpAddress': '10.0.0.200', 'NetworkInterface.1.PrivateIpAddresses.1.Primary': 'true', 'ImageId': 'ami-foobar', - 'MaxCount': '1', - 'MinCount': '1' + 'MaxCount': 1, + 'MinCount': 1 } self.assert_params_for_cmd(args_list, result) @@ -234,8 +235,8 @@ def test_private_ip_address_alone(self): result = { 'PrivateIpAddress': '10.0.0.200', 'ImageId': 'ami-foobar', - 'MaxCount': '1', - 'MinCount': '1' + 'MaxCount': 1, + 'MinCount': 1 } self.assert_params_for_cmd(args_list, result) diff --git a/tests/unit/ec2/test_security_group_operations.py b/tests/unit/ec2/test_security_group_operations.py index df3b4f62211d..5b6f2fb2af8e 100644 --- a/tests/unit/ec2/test_security_group_operations.py +++ b/tests/unit/ec2/test_security_group_operations.py @@ -23,8 +23,8 @@ def test_simple_cidr(self): args = ' --group-name foobar --protocol tcp --port 22-25 --cidr 0.0.0.0/0' args_list = (self.prefix + args).split() result = {'GroupName': 'foobar', - 'IpPermissions.1.FromPort': '22', - 'IpPermissions.1.ToPort': '25', + 'IpPermissions.1.FromPort': 22, + 'IpPermissions.1.ToPort': 25, 'IpPermissions.1.IpProtocol': 'tcp', 'IpPermissions.1.IpRanges.1.CidrIp': '0.0.0.0/0'} self.assert_params_for_cmd(args_list, result) @@ -33,8 +33,8 @@ def test_all_port(self): args = ' --group-name foobar --protocol tcp --port all --cidr 0.0.0.0/0' args_list = (self.prefix + args).split() result = {'GroupName': 'foobar', - 'IpPermissions.1.FromPort': '-1', - 'IpPermissions.1.ToPort': '-1', + 'IpPermissions.1.FromPort': -1, + 'IpPermissions.1.ToPort': -1, 'IpPermissions.1.IpProtocol': 'tcp', 'IpPermissions.1.IpRanges.1.CidrIp': '0.0.0.0/0'} self.assert_params_for_cmd(args_list, result) @@ -43,8 +43,11 @@ def test_all_protocol(self): args = ' --group-name foobar --protocol all --port all --cidr 0.0.0.0/0' args_list = (self.prefix + args).split() result = {'GroupName': 'foobar', - 'IpPermissions.1.FromPort': '-1', - 'IpPermissions.1.ToPort': '-1', + 'IpPermissions.1.FromPort': -1, + 'IpPermissions.1.ToPort': -1, + # This is correct, the expected value is the *string* + # '-1'. This is because the IpProtocol is modeled + # as a string. 'IpPermissions.1.IpProtocol': '-1', 'IpPermissions.1.IpRanges.1.CidrIp': '0.0.0.0/0'} self.assert_params_for_cmd(args_list, result) @@ -87,8 +90,8 @@ def test_ip_permissions(self): args = ' --group-name foobar --ip-permissions %s' % json args_list = (self.prefix + args).split() result = {'GroupName': 'foobar', - 'IpPermissions.1.FromPort': '8000', - 'IpPermissions.1.ToPort': '9000', + 'IpPermissions.1.FromPort': 8000, + 'IpPermissions.1.ToPort': 9000, 'IpPermissions.1.IpProtocol': 'tcp', 'IpPermissions.1.IpRanges.1.CidrIp': '192.168.100.0/24'} self.assert_params_for_cmd(args_list, result) @@ -98,8 +101,8 @@ def test_ip_permissions_with_group_id(self): args = ' --group-id sg-12345678 --ip-permissions %s' % json args_list = (self.prefix + args).split() result = {'GroupId': 'sg-12345678', - 'IpPermissions.1.FromPort': '8000', - 'IpPermissions.1.ToPort': '9000', + 'IpPermissions.1.FromPort': 8000, + 'IpPermissions.1.ToPort': 9000, 'IpPermissions.1.IpProtocol': 'tcp', 'IpPermissions.1.IpRanges.1.CidrIp': '192.168.100.0/24'} self.assert_params_for_cmd(args_list, result) diff --git a/tests/unit/elasticache/test_create_cache_cluster.py b/tests/unit/elasticache/test_create_cache_cluster.py index 0995f24b052c..6f7197377dd3 100644 --- a/tests/unit/elasticache/test_create_cache_cluster.py +++ b/tests/unit/elasticache/test_create_cache_cluster.py @@ -37,7 +37,7 @@ def test_create_cache_cluster(self): 'CacheSecurityGroupNames.member.2': 'group2', 'Engine': 'memcached', 'EngineVersion': '1.4.5', - 'NumCacheNodes': '1', + 'NumCacheNodes': 1, 'PreferredAvailabilityZone': 'us-east-1c', 'PreferredMaintenanceWindow': 'fri:08:00-fri:09:00'} self.assert_params_for_cmd(cmdline, result) @@ -60,7 +60,7 @@ def test_create_cache_cluster_no_auto_minor_upgrade(self): 'CacheSecurityGroupNames.member.2': 'group2', 'Engine': 'memcached', 'EngineVersion': '1.4.5', - 'NumCacheNodes': '1', + 'NumCacheNodes': 1, 'PreferredAvailabilityZone': 'us-east-1c', 'PreferredMaintenanceWindow': 'fri:08:00-fri:09:00'} self.assert_params_for_cmd(cmdline, result) @@ -84,7 +84,7 @@ def test_minor_upgrade_arg_not_specified(self): 'CacheSecurityGroupNames.member.2': 'group2', 'Engine': 'memcached', 'EngineVersion': '1.4.5', - 'NumCacheNodes': '1', + 'NumCacheNodes': 1, 'PreferredAvailabilityZone': 'us-east-1c', 'PreferredMaintenanceWindow': 'fri:08:00-fri:09:00'} self.assert_params_for_cmd(cmdline, result) diff --git a/tests/unit/elb/test_configure_health_check.py b/tests/unit/elb/test_configure_health_check.py index 144a16066140..04c0b7389051 100644 --- a/tests/unit/elb/test_configure_health_check.py +++ b/tests/unit/elb/test_configure_health_check.py @@ -25,11 +25,11 @@ def test_shorthand_basic(self): cmdline += (' --health-check Target=HTTP:80/weather/us/wa/seattle,' 'Interval=300,Timeout=60,UnhealthyThreshold=5,' 'HealthyThreshold=9') - result = {'HealthCheck.HealthyThreshold': '9', - 'HealthCheck.Interval': '300', + result = {'HealthCheck.HealthyThreshold': 9, + 'HealthCheck.Interval': 300, 'HealthCheck.Target': 'HTTP:80/weather/us/wa/seattle', - 'HealthCheck.Timeout': '60', - 'HealthCheck.UnhealthyThreshold': '5', + 'HealthCheck.Timeout': 60, + 'HealthCheck.UnhealthyThreshold': 5, 'LoadBalancerName': 'my-lb'} self.assert_params_for_cmd(cmdline, result) @@ -39,11 +39,11 @@ def test_json(self): cmdline += ('--health-check {"Target":"HTTP:80/weather/us/wa/seattle' '?a=b","Interval":300,"Timeout":60,' '"UnhealthyThreshold":5,"HealthyThreshold":9}') - result = {'HealthCheck.HealthyThreshold': '9', - 'HealthCheck.Interval': '300', + result = {'HealthCheck.HealthyThreshold': 9, + 'HealthCheck.Interval': 300, 'HealthCheck.Target': 'HTTP:80/weather/us/wa/seattle?a=b', - 'HealthCheck.Timeout': '60', - 'HealthCheck.UnhealthyThreshold': '5', + 'HealthCheck.Timeout': 60, + 'HealthCheck.UnhealthyThreshold': 5, 'LoadBalancerName': 'my-lb'} self.assert_params_for_cmd(cmdline, result) @@ -53,11 +53,11 @@ def test_shorthand_with_multiple_equals_for_value(self): cmdline += (' --health-check Target="HTTP:80/weather/us/wa/seattle?a=b"' ',Interval=300,Timeout=60,UnhealthyThreshold=5,' 'HealthyThreshold=9') - result = {'HealthCheck.HealthyThreshold': '9', - 'HealthCheck.Interval': '300', + result = {'HealthCheck.HealthyThreshold': 9, + 'HealthCheck.Interval': 300, 'HealthCheck.Target': 'HTTP:80/weather/us/wa/seattle?a=b', - 'HealthCheck.Timeout': '60', - 'HealthCheck.UnhealthyThreshold': '5', + 'HealthCheck.Timeout': 60, + 'HealthCheck.UnhealthyThreshold': 5, 'LoadBalancerName': 'my-lb'} self.assert_params_for_cmd(cmdline, result) diff --git a/tests/unit/opsworks/create_layer_attributes.json b/tests/unit/opsworks/create_layer_attributes.json index 61ddafd4a96a..941939f0cb32 100644 --- a/tests/unit/opsworks/create_layer_attributes.json +++ b/tests/unit/opsworks/create_layer_attributes.json @@ -1,20 +1,7 @@ -{"MysqlRootPasswordUbiquitous": null, - "RubygemsVersion": "1.8.24", +{"RubygemsVersion": "1.8.24", "RailsStack": "apache_passenger", - "HaproxyHealthCheckMethod": null, "RubyVersion": "1.9.3", "BundlerVersion": "1.2.3", - "HaproxyStatsPassword": null, "PassengerVersion": "3.0.17", - "MemcachedMemory": null, - "EnableHaproxyStats": null, - "ManageBundler": "true", - "NodejsVersion": null, - "HaproxyHealthCheckUrl": null, - "MysqlRootPassword": null, - "GangliaPassword": null, - "GangliaUser": null, - "HaproxyStatsUrl": null, - "GangliaUrl": null, - "HaproxyStatsUser": null + "ManageBundler": "true" } diff --git a/tests/unit/opsworks/test_create_instance.py b/tests/unit/opsworks/test_create_instance.py index 970335721c1a..ebf29361337b 100644 --- a/tests/unit/opsworks/test_create_instance.py +++ b/tests/unit/opsworks/test_create_instance.py @@ -29,7 +29,7 @@ def test_simple(self): 'Hostname': 'aws-client-instance', 'LayerIds': ['cb27894d-35f3-4435-b422-6641a785fa4a'], 'InstanceType': 'c1.medium'} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) if __name__ == "__main__": diff --git a/tests/unit/opsworks/test_create_layer.py b/tests/unit/opsworks/test_create_layer.py index c2ab63d4e622..ad45a4d82a56 100644 --- a/tests/unit/opsworks/test_create_layer.py +++ b/tests/unit/opsworks/test_create_layer.py @@ -36,27 +36,15 @@ def test_attributes_file(self): 'Name': 'Rails_App_Server', 'EnableAutoHealing': True, 'Shortname': 'foo', - 'Attributes': {"MysqlRootPasswordUbiquitous": None, - "RubygemsVersion": "1.8.24", + 'Attributes': {"RubygemsVersion": "1.8.24", "RailsStack": "apache_passenger", - "HaproxyHealthCheckMethod": None, "RubyVersion": "1.9.3", "BundlerVersion": "1.2.3", - "HaproxyStatsPassword": None, "PassengerVersion": "3.0.17", - "MemcachedMemory": None, - "EnableHaproxyStats": None, "ManageBundler": "true", - "NodejsVersion": None, - "HaproxyHealthCheckUrl": None, - "MysqlRootPassword": None, - "GangliaPassword": None, - "GangliaUser": None, - "HaproxyStatsUrl": None, - "GangliaUrl": None, - "HaproxyStatsUser": None} + } } - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) if __name__ == "__main__": diff --git a/tests/unit/opsworks/test_create_stack.py b/tests/unit/opsworks/test_create_stack.py index 6c6a6b2929b1..e793b93b794d 100644 --- a/tests/unit/opsworks/test_create_stack.py +++ b/tests/unit/opsworks/test_create_stack.py @@ -31,7 +31,7 @@ def test_attributes_file(self): 'Region': 'us-west-2', 'DefaultInstanceProfileArn': 'arn-foofoofoo' } - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) if __name__ == "__main__": diff --git a/tests/unit/opsworks/test_describe_layers.py b/tests/unit/opsworks/test_describe_layers.py index 9e519a43b26c..80a258e149dd 100644 --- a/tests/unit/opsworks/test_describe_layers.py +++ b/tests/unit/opsworks/test_describe_layers.py @@ -23,7 +23,7 @@ def test_both_params(self): cmdline = self.prefix cmdline += ' --stack-id 35959772-cd1e-4082-8346-79096d4179f2' result = {'StackId': '35959772-cd1e-4082-8346-79096d4179f2'} - self.assert_params_for_cmd(cmdline, result) + self.assert_params_for_cmd2(cmdline, result) if __name__ == "__main__": diff --git a/tests/unit/rds/test_describe_db_log_files.py b/tests/unit/rds/test_describe_db_log_files.py index 1569fd016345..da0c74bbc08f 100644 --- a/tests/unit/rds/test_describe_db_log_files.py +++ b/tests/unit/rds/test_describe_db_log_files.py @@ -23,5 +23,5 @@ def test_add_option(self): '--db-instance-identifier foo') cmdline = self.prefix + args result = {'DBInstanceIdentifier': 'foo', - 'FileLastWritten': '10'} + 'FileLastWritten': 10} self.assert_params_for_cmd(cmdline, result) diff --git a/tests/unit/route53/test_resource_id.py b/tests/unit/route53/test_resource_id.py index 81b189164aa6..ef529c9d2030 100644 --- a/tests/unit/route53/test_resource_id.py +++ b/tests/unit/route53/test_resource_id.py @@ -23,18 +23,6 @@ '{"Value":"foo-bar-com.us-west-1.elb.amazonaws.com"}' ']}}]}') -CHANGEBATCH_XML = ('' - 'string' - 'CREATE' - '' - 'test-foo.bar.com' - 'CNAME300' - '' - 'foo-bar-com.us-west-1.elb.amazonaws.com' - '' - '' - '') class TestGetHostedZone(BaseAWSCommandParamsTest): @@ -46,18 +34,16 @@ def setUp(self): def test_full_resource_id(self): args = ' --id /hostedzone/ZD3IYMVP1KDDM' cmdline = self.prefix + args - result = {'uri_params': {'Id': 'ZD3IYMVP1KDDM'}, - 'headers': {}} - self.assert_params_for_cmd(cmdline, result, expected_rc=0, - ignore_params=['payload'])[0] + expected_id = 'ZD3IYMVP1KDDM' + self.assert_params_for_cmd(cmdline, expected_rc=0) + self.assertEqual(self.last_kwargs['Id'], expected_id) def test_short_resource_id(self): args = ' --id ZD3IYMVP1KDDM' cmdline = self.prefix + args - result = {'uri_params': {'Id': 'ZD3IYMVP1KDDM'}, - 'headers': {}} - self.assert_params_for_cmd(cmdline, result, expected_rc=0, - ignore_params=['payload'])[0] + expected_id = 'ZD3IYMVP1KDDM' + self.assert_params_for_cmd(cmdline, expected_rc=0) + self.assertEqual(self.last_kwargs['Id'], expected_id) class TestChangeResourceRecord(BaseAWSCommandParamsTest): @@ -71,12 +57,11 @@ def test_full_resource_id(self): args = ' --hosted-zone-id /change/ZD3IYMVP1KDDM' args += ' --change-batch %s' % CHANGEBATCH_JSON cmdline = self.prefix + args - result = {'uri_params': {'HostedZoneId': 'ZD3IYMVP1KDDM'}, - 'headers': {}} - self.assert_params_for_cmd(cmdline, result, expected_rc=0, - ignore_params=['payload'])[0] - self.assertEqual(self.last_params['payload'].getvalue(), - CHANGEBATCH_XML) + expected_hosted_zone = 'ZD3IYMVP1KDDM' + self.assert_params_for_cmd2(cmdline, expected_rc=0) + # Verify that we used the correct value for HostedZoneId. + self.assertEqual(self.last_kwargs['HostedZoneId'], + expected_hosted_zone) class TestGetChange(BaseAWSCommandParamsTest): @@ -89,18 +74,16 @@ def setUp(self): def test_full_resource_id(self): args = ' --id /change/ZD3IYMVP1KDDM' cmdline = self.prefix + args - result = {'uri_params': {'Id': 'ZD3IYMVP1KDDM'}, - 'headers': {}} - self.assert_params_for_cmd(cmdline, result, expected_rc=0, - ignore_params=['payload'])[0] + expected_id = 'ZD3IYMVP1KDDM' + self.assert_params_for_cmd(cmdline, expected_rc=0) + self.assertEqual(self.last_kwargs['Id'], expected_id) def test_short_resource_id(self): args = ' --id ZD3IYMVP1KDDM' cmdline = self.prefix + args - result = {'uri_params': {'Id': 'ZD3IYMVP1KDDM'}, - 'headers': {}} - self.assert_params_for_cmd(cmdline, result, expected_rc=0, - ignore_params=['payload'])[0] + expected_id = 'ZD3IYMVP1KDDM' + self.assert_params_for_cmd(cmdline, expected_rc=0) + self.assertEqual(self.last_kwargs['Id'], expected_id) class TestMaxItems(BaseAWSCommandParamsTest): @@ -110,11 +93,8 @@ class TestMaxItems(BaseAWSCommandParamsTest): def test_full_resource_id(self): args = ' --hosted-zone-id /hostedzone/ABCD --max-items 1' cmdline = self.prefix + args - result = { - 'uri_params': { - 'HostedZoneId': 'ABCD', - }, - 'headers': {} - } - self.assert_params_for_cmd(cmdline, result, expected_rc=0, - ignore_params=['payload'])[0] + expected_hosted_zone = 'ABCD' + self.assert_params_for_cmd2(cmdline, expected_rc=0) + # Verify that we used the correct value for HostedZoneId. + self.assertEqual(self.last_kwargs['HostedZoneId'], + expected_hosted_zone) diff --git a/tests/unit/s3/test_get_object.py b/tests/unit/s3/test_get_object.py index 8d38b2cf2a5f..09dc221d01b4 100644 --- a/tests/unit/s3/test_get_object.py +++ b/tests/unit/s3/test_get_object.py @@ -41,7 +41,8 @@ def test_simple(self): 'Key': 'mykey'}, 'headers': {},} self.addCleanup(self.remove_file_if_exists, 'outfile') - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) + self.assert_params_for_cmd2(cmdline, {'Bucket': 'mybucket', + 'Key': 'mykey'}) def test_range(self): cmdline = self.prefix @@ -53,7 +54,9 @@ def test_range(self): 'Key': 'mykey'}, 'headers': {'Range': 'bytes=0-499'},} self.addCleanup(self.remove_file_if_exists, 'outfile') - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) + self.assert_params_for_cmd2(cmdline, {'Bucket': 'mybucket', + 'Key': 'mykey', + 'Range': 'bytes=0-499'}) def test_response_headers(self): cmdline = self.prefix @@ -68,7 +71,13 @@ def test_response_headers(self): 'ResponseContentEncoding': 'x-gzip'}, 'headers': {},} self.addCleanup(self.remove_file_if_exists, 'outfile') - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) + self.assert_params_for_cmd2( + cmdline, { + 'Bucket': 'mybucket', 'Key': 'mykey', + 'ResponseCacheControl': 'No-cache', + 'ResponseContentEncoding': 'x-gzip' + } + ) if __name__ == "__main__": diff --git a/tests/unit/s3/test_list_objects.py b/tests/unit/s3/test_list_objects.py index bfd38304eccd..6c66aa753c21 100644 --- a/tests/unit/s3/test_list_objects.py +++ b/tests/unit/s3/test_list_objects.py @@ -25,9 +25,7 @@ def setUp(self): def test_simple(self): cmdline = self.prefix cmdline += ' --bucket mybucket' - result = {'uri_params': {'Bucket': 'mybucket'}, - 'headers': {},} - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) + self.assert_params_for_cmd2(cmdline, {'Bucket': 'mybucket'}) def test_max_items(self): cmdline = self.prefix @@ -35,9 +33,7 @@ def test_max_items(self): # The max-items is a customization and therefore won't # show up in the result params. cmdline += ' --max-items 100' - result = {'uri_params': {'Bucket': 'mybucket'}, - 'headers': {},} - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) + self.assert_params_for_cmd2(cmdline, {'Bucket': 'mybucket'}) def test_page_size(self): cmdline = self.prefix @@ -45,9 +41,8 @@ def test_page_size(self): # The max-items is a customization and therefore won't # show up in the result params. cmdline += ' --page-size 100' - result = {'uri_params': {'Bucket': 'mybucket', 'MaxKeys': 100}, - 'headers': {},} - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) + self.assert_params_for_cmd2(cmdline, {'Bucket': 'mybucket', + 'MaxKeys': 100}) def test_starting_token(self): # We don't need to test this in depth because botocore @@ -56,16 +51,13 @@ def test_starting_token(self): cmdline = self.prefix cmdline += ' --bucket mybucket' cmdline += ' --starting-token foo___2' - result = {'uri_params': {'Bucket': 'mybucket', 'Marker': 'foo'}, - 'headers': {},} - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) + self.assert_params_for_cmd2(cmdline, {'Bucket': 'mybucket', + 'Marker': 'foo'}) def test_no_paginate(self): cmdline = self.prefix cmdline += ' --bucket mybucket --no-paginate' - result = {'uri_params': {'Bucket': 'mybucket'}, - 'headers': {},} - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) + self.assert_params_for_cmd2(cmdline, {'Bucket': 'mybucket'}) def test_max_keys_can_be_specified(self): cmdline = self.prefix @@ -73,9 +65,8 @@ def test_max_keys_can_be_specified(self): # but for back-compat reasons if a user specifies this, # we will automatically see this and turn auto-pagination off. cmdline += ' --bucket mybucket --max-keys 1' - result = {'uri_params': {'Bucket': 'mybucket', 'MaxKeys': 1}, - 'headers': {},} - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) + self.assert_params_for_cmd2(cmdline, {'Bucket': 'mybucket', + 'MaxKeys': 1}) self.assertEqual(len(self.operations_called), 1) self.assertEqual(len(self.operations_called), 1) self.assertEqual(self.operations_called[0][0].name, 'ListObjects') diff --git a/tests/unit/s3/test_put_bucket_tagging.py b/tests/unit/s3/test_put_bucket_tagging.py index ce80e58736f1..5c70898e4600 100644 --- a/tests/unit/s3/test_put_bucket_tagging.py +++ b/tests/unit/s3/test_put_bucket_tagging.py @@ -43,9 +43,16 @@ def test_simple(self): cmdline = self.prefix cmdline += ' --bucket mybucket' cmdline += ' --tagging %s' % TAGSET - result = {'uri_params': {'Bucket': 'mybucket'}, - 'headers': {'Content-MD5': '5s++BGwLE2moBAK9duxpFw=='}} - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) + expected = { + 'Bucket': 'mybucket', + 'Tagging': { + 'TagSet': [ + {'Key': 'key1', 'Value': 'value1'}, + {'Key': 'key2', 'Value': 'value2'}, + ] + } + } + self.assert_params_for_cmd2(cmdline, expected) if __name__ == "__main__": diff --git a/tests/unit/s3/test_put_object.py b/tests/unit/s3/test_put_object.py index 828eb22c124e..183f0d287beb 100644 --- a/tests/unit/s3/test_put_object.py +++ b/tests/unit/s3/test_put_object.py @@ -48,8 +48,12 @@ def test_simple(self): result = {'uri_params': {'Bucket': 'mybucket', 'Key': 'mykey'}, 'headers': {'Expect': '100-continue'}} - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) - self.assertIsInstance(self.last_params['payload'].getvalue(), file_type) + expected = { + 'Bucket': 'mybucket', + 'Key': 'mykey' + } + self.assert_params_for_cmd2(cmdline, expected, ignore_params=['Body']) + self.assertEqual(self.last_kwargs['Body'].name, self.file_path) def test_headers(self): cmdline = self.prefix @@ -59,15 +63,15 @@ def test_headers(self): cmdline += ' --acl public-read' cmdline += ' --content-encoding x-gzip' cmdline += ' --content-type text/plain' - result = {'uri_params': {'Bucket': 'mybucket', 'Key': 'mykey'}, - 'headers': {'x-amz-acl': 'public-read', - 'Content-Encoding': 'x-gzip', - 'Content-Type': 'text/plain', - 'Expect': '100-continue', - }} - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) - payload = self.last_params['payload'].getvalue() - self.assertEqual(payload.name, self.file_path) + expected = { + 'ACL': 'public-read', + 'Bucket': 'mybucket', + 'ContentEncoding': 'x-gzip', + 'ContentType': 'text/plain', + 'Key': 'mykey' + } + self.assert_params_for_cmd2(cmdline, expected, ignore_params=['Body']) + self.assertEqual(self.last_kwargs['Body'].name, self.file_path) def test_website_redirect(self): cmdline = self.prefix @@ -75,13 +79,13 @@ def test_website_redirect(self): cmdline += ' --key mykey' cmdline += ' --acl public-read' cmdline += ' --website-redirect-location http://www.example.com/' - result = { - 'uri_params': {'Bucket': 'mybucket', 'Key': 'mykey'}, - 'headers': { - 'x-amz-acl': 'public-read', - 'x-amz-website-redirect-location': 'http://www.example.com/', - }} - self.assert_params_for_cmd(cmdline, result, ignore_params=['payload']) + expected = { + 'ACL': 'public-read', + 'Bucket': 'mybucket', + 'Key': 'mykey', + 'WebsiteRedirectLocation': 'http://www.example.com/' + } + self.assert_params_for_cmd2(cmdline, expected) if __name__ == "__main__": diff --git a/tests/unit/sqs/test_change_message_visibility.py b/tests/unit/sqs/test_change_message_visibility.py index f654469e8060..38856ce093dd 100644 --- a/tests/unit/sqs/test_change_message_visibility.py +++ b/tests/unit/sqs/test_change_message_visibility.py @@ -28,7 +28,7 @@ def test_all_params(self): cmdline += ' --visibility-timeout 30' result = {'QueueUrl': self.queue_url, 'ReceiptHandle': self.receipt_handle, - 'VisibilityTimeout': '30'} + 'VisibilityTimeout': 30} self.assert_params_for_cmd(cmdline, result) diff --git a/tests/unit/test_argprocess.py b/tests/unit/test_argprocess.py index 41739378e228..e3fa4e2105fb 100644 --- a/tests/unit/test_argprocess.py +++ b/tests/unit/test_argprocess.py @@ -11,84 +11,92 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import json -from awscli.testutils import unittest -from awscli.testutils import BaseCLIDriverTest -from awscli.testutils import temporary_file import mock +from botocore import xform_name +from awscli.testutils import unittest +from awscli.testutils import BaseCLIDriverTest +from awscli.testutils import temporary_file from awscli.clidriver import CLIArgument from awscli.help import OperationHelpCommand from awscli.argprocess import detect_shape_structure from awscli.argprocess import unpack_cli_arg from awscli.argprocess import ParamShorthand +from awscli.argprocess import ParamShorthandDocGen from awscli.argprocess import ParamError from awscli.argprocess import ParamUnknownKeyError from awscli.argprocess import uri_param -from awscli.arguments import CustomArgument - - -MAPHELP = """--attributes key_name=string,key_name2=string -Where valid key names are: - Policy""" +from awscli.arguments import CustomArgument, CLIArgument +from awscli.arguments import ListArgument, BooleanArgument +from awscli.arguments import create_argument_model_from_schema # These tests use real service types so that we can # verify the real shapes of services. class BaseArgProcessTest(BaseCLIDriverTest): - def get_param_object(self, dotted_name): + def get_param_model(self, dotted_name): service_name, operation_name, param_name = dotted_name.split('.') - service = self.session.get_service(service_name) - operation = service.get_operation(operation_name) - for p in operation.params: - if p.name == param_name: - return p + service_model = self.session.get_service_model(service_name) + operation = service_model.operation_model(operation_name) + input_shape = operation.input_shape + required_arguments = input_shape.required_members + is_required = param_name in required_arguments + member_shape = input_shape.members[param_name] + type_name = member_shape.type_name + cli_arg_name = xform_name(param_name, '-') + if type_name == 'boolean': + cls = BooleanArgument + elif type_name == 'list': + cls = ListArgument else: - raise ValueError("Unknown param: %s" % param_name) + cls = CLIArgument + return cls(cli_arg_name, member_shape, mock.Mock(), is_required) class TestURIParams(BaseArgProcessTest): def test_uri_param(self): - p = self.get_param_object('ec2.DescribeInstances.Filters') + p = self.get_param_model('ec2.DescribeInstances.Filters') with temporary_file('r+') as f: json_argument = json.dumps([{"Name": "instance-id", "Values": ["i-1234"]}]) f.write(json_argument) f.flush() - result = uri_param(p, 'file://%s' % f.name) + result = uri_param('event-name', p, 'file://%s' % f.name) self.assertEqual(result, json_argument) def test_uri_param_no_paramfile_false(self): - p = self.get_param_object('ec2.DescribeInstances.Filters') + p = self.get_param_model('ec2.DescribeInstances.Filters') p.no_paramfile = False with temporary_file('r+') as f: json_argument = json.dumps([{"Name": "instance-id", "Values": ["i-1234"]}]) f.write(json_argument) f.flush() - result = uri_param(p, 'file://%s' % f.name) + result = uri_param('event-name', p, 'file://%s' % f.name) self.assertEqual(result, json_argument) def test_uri_param_no_paramfile_true(self): - p = self.get_param_object('ec2.DescribeInstances.Filters') + p = self.get_param_model('ec2.DescribeInstances.Filters') p.no_paramfile = True with temporary_file('r+') as f: json_argument = json.dumps([{"Name": "instance-id", "Values": ["i-1234"]}]) f.write(json_argument) f.flush() - result = uri_param(p, 'file://%s' % f.name) + result = uri_param('event-name', p, 'file://%s' % f.name) self.assertEqual(result, None) + class TestArgShapeDetection(BaseArgProcessTest): def assert_shape_type(self, spec, expected_type): - p = self.get_param_object(spec) - actual_structure = detect_shape_structure(p) + p = self.get_param_model(spec) + actual_structure = detect_shape_structure(p.argument_model) self.assertEqual(actual_structure, expected_type) def assert_custom_shape_type(self, schema, expected_type): - argument = CustomArgument('test', schema=schema) - argument.create_argument_object() - actual_structure = detect_shape_structure(argument.argument_object) + argument_model = create_argument_model_from_schema(schema) + argument = CustomArgument('test', argument_model=argument_model) + actual_structure = detect_shape_structure(argument.argument_model) self.assertEqual(actual_structure, expected_type) def test_detect_scalar(self): @@ -139,12 +147,14 @@ def test_struct_list_scalar(self): class TestParamShorthand(BaseArgProcessTest): + maxDiff = None + def setUp(self): super(TestParamShorthand, self).setUp() self.simplify = ParamShorthand() def test_simplify_structure_scalars(self): - p = self.get_param_object( + p = self.get_param_model( 'elasticbeanstalk.CreateConfigurationTemplate.SourceConfiguration') value = 'ApplicationName=foo,TemplateName=bar' json_value = '{"ApplicationName": "foo", "TemplateName": "bar"}' @@ -154,7 +164,8 @@ def test_simplify_structure_scalars(self): def test_parse_boolean_shorthand(self): bool_param = mock.Mock() - bool_param.type = 'boolean' + bool_param.cli_type_name = 'boolean' + bool_param.argument_model.type_name = 'boolean' self.assertTrue(unpack_cli_arg(bool_param, True)) self.assertTrue(unpack_cli_arg(bool_param, 'True')) self.assertTrue(unpack_cli_arg(bool_param, 'true')) @@ -164,14 +175,14 @@ def test_parse_boolean_shorthand(self): self.assertFalse(unpack_cli_arg(bool_param, 'false')) def test_simplify_map_scalar(self): - p = self.get_param_object('sqs.SetQueueAttributes.Attributes') + p = self.get_param_model('sqs.SetQueueAttributes.Attributes') returned = self.simplify(p, 'VisibilityTimeout=15') json_version = unpack_cli_arg(p, '{"VisibilityTimeout": "15"}') self.assertEqual(returned, {'VisibilityTimeout': '15'}) self.assertEqual(returned, json_version) def test_list_structure_scalars(self): - p = self.get_param_object( + p = self.get_param_model( 'elb.RegisterInstancesWithLoadBalancer.Instances') # Because this is a list type param, we'll use nargs # with argparse which means the value will be presented @@ -181,7 +192,7 @@ def test_list_structure_scalars(self): {'InstanceId': 'instance-2'}]) def test_list_structure_list_scalar(self): - p = self.get_param_object('ec2.DescribeInstances.Filters') + p = self.get_param_model('ec2.DescribeInstances.Filters') expected = [{"Name": "instance-id", "Values": ["i-1", "i-2"]}, {"Name": "architecture", "Values": ["i386"]}] returned = self.simplify( @@ -202,7 +213,7 @@ def test_list_structure_list_scalar(self): self.assertEqual(returned3, expected) def test_list_structure_list_scalar_2(self): - p = self.get_param_object('emr.ModifyInstanceGroups.InstanceGroups') + p = self.get_param_model('emr.ModifyInstanceGroups.InstanceGroups') expected = [ {"InstanceGroupId": "foo", "InstanceCount": 4}, @@ -217,56 +228,15 @@ def test_list_structure_list_scalar_2(self): self.assertEqual(simplified, expected) - def test_list_structure_list_scalar_3(self): - arg = CustomArgument('foo', schema={ - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'Name': { - 'type': 'string' - }, - 'Args': { - 'type': 'array', - 'items': { - 'type': 'string' - } - } - } - } - }) - arg.create_argument_object() - p = arg.argument_object - - expected = [ - {"Name": "foo", - "Args": ["a", "k1=v1", "b"]}, - {"Name": "bar", - "Args": ["baz"]}, - {"Name": "single_kv", - "Args": ["key=value"]}, - {"Name": "single_v", - "Args": ["value"]} - ] - - simplified = self.simplify(p, [ - "Name=foo,Args=[a,k1=v1,b]", - "Name=bar,Args=baz", - "Name=single_kv,Args=[key=value]", - "Name=single_v,Args=[value]" - ]) - - self.assertEqual(simplified, expected) - def test_list_structure_list_multiple_scalar(self): - p = self.get_param_object('elastictranscoder.CreateJob.Playlists') + p = self.get_param_model('elastictranscoder.CreateJob.Playlists') returned = self.simplify( p, ['Name=foo,Format=hslv3,OutputKeys=iphone1,iphone2']) self.assertEqual(returned, [{'OutputKeys': ['iphone1', 'iphone2'], 'Name': 'foo', 'Format': 'hslv3'}]) def test_list_structure_scalars_2(self): - p = self.get_param_object('elb.CreateLoadBalancer.Listeners') + p = self.get_param_model('elb.CreateLoadBalancer.Listeners') expected = [ {"Protocol": "protocol1", "LoadBalancerPort": 1, @@ -289,7 +259,6 @@ def test_list_structure_scalars_2(self): '"InstancePort": 4, "SSLCertificateId": ' '"ssl_certificate_id2"}', ]) - self.maxDiff = None self.assertEqual(returned, expected) simplified = self.simplify(p, [ 'Protocol=protocol1,LoadBalancerPort=1,' @@ -301,32 +270,8 @@ def test_list_structure_scalars_2(self): ]) self.assertEqual(simplified, expected) - def test_struct_list_scalars(self): - schema = { - "type": "object", - "properties": { - "Consistent": { - "type": "boolean", - }, - "Args": { - "type": "array", - "items": { - "type": "string" - } - } - } - } - - argument = CustomArgument('test', schema=schema) - argument.create_argument_object() - p = argument.argument_object - - returned = self.simplify(p, 'Consistent=true,Args=foo1,foo2') - self.assertEqual(returned, {'Consistent': True, - 'Args': ['foo1', 'foo2']}) - def test_keyval_with_long_values(self): - p = self.get_param_object( + p = self.get_param_model( 'dynamodb.UpdateTable.ProvisionedThroughput') value = 'WriteCapacityUnits=10,ReadCapacityUnits=10' returned = self.simplify(p, value) @@ -334,7 +279,7 @@ def test_keyval_with_long_values(self): 'ReadCapacityUnits': 10}) def test_error_messages_for_structure_scalar(self): - p = self.get_param_object( + p = self.get_param_model( 'elasticbeanstalk.CreateConfigurationTemplate.SourceConfiguration') value = 'ApplicationName:foo,TemplateName=bar' error_msg = "Error parsing parameter '--source-configuration'.*should be" @@ -342,7 +287,7 @@ def test_error_messages_for_structure_scalar(self): self.simplify(p, value) def test_mispelled_param_name(self): - p = self.get_param_object( + p = self.get_param_model( 'elasticbeanstalk.CreateConfigurationTemplate.SourceConfiguration') error_msg = 'valid choices.*ApplicationName' with self.assertRaisesRegexp(ParamUnknownKeyError, error_msg): @@ -352,7 +297,7 @@ def test_mispelled_param_name(self): def test_improper_separator(self): # If the user uses ':' instead of '=', we should give a good # error message. - p = self.get_param_object( + p = self.get_param_model( 'elasticbeanstalk.CreateConfigurationTemplate.SourceConfiguration') value = 'ApplicationName:foo,TemplateName:bar' error_msg = "Error parsing parameter '--source-configuration'.*should be" @@ -360,19 +305,19 @@ def test_improper_separator(self): self.simplify(p, value) def test_improper_separator_for_filters_param(self): - p = self.get_param_object('ec2.DescribeInstances.Filters') + p = self.get_param_model('ec2.DescribeInstances.Filters') error_msg = "Error parsing parameter '--filters'.*should be" with self.assertRaisesRegexp(ParamError, error_msg): self.simplify(p, ["Name:tag:Name,Values:foo"]) def test_unknown_key_for_filters_param(self): - p = self.get_param_object('ec2.DescribeInstances.Filters') + p = self.get_param_model('ec2.DescribeInstances.Filters') with self.assertRaisesRegexp(ParamUnknownKeyError, 'valid choices.*Name'): self.simplify(p, ["Names=instance-id,Values=foo,bar"]) def test_csv_syntax_escaped(self): - p = self.get_param_object('cloudformation.CreateStack.Parameters') + p = self.get_param_model('cloudformation.CreateStack.Parameters') returned = self.simplify( p, ["ParameterKey=key,ParameterValue=foo\,bar"]) expected = [{"ParameterKey": "key", @@ -380,7 +325,7 @@ def test_csv_syntax_escaped(self): self.assertEqual(returned, expected) def test_csv_syntax_double_quoted(self): - p = self.get_param_object('cloudformation.CreateStack.Parameters') + p = self.get_param_model('cloudformation.CreateStack.Parameters') returned = self.simplify( p, ['ParameterKey=key,ParameterValue="foo,bar"']) expected = [{"ParameterKey": "key", @@ -388,7 +333,7 @@ def test_csv_syntax_double_quoted(self): self.assertEqual(returned, expected) def test_csv_syntax_single_quoted(self): - p = self.get_param_object('cloudformation.CreateStack.Parameters') + p = self.get_param_model('cloudformation.CreateStack.Parameters') returned = self.simplify( p, ["ParameterKey=key,ParameterValue='foo,bar'"]) expected = [{"ParameterKey": "key", @@ -396,7 +341,7 @@ def test_csv_syntax_single_quoted(self): self.assertEqual(returned, expected) def test_csv_syntax_errors(self): - p = self.get_param_object('cloudformation.CreateStack.Parameters') + p = self.get_param_model('cloudformation.CreateStack.Parameters') error_msg = "Error parsing parameter '--parameters'.*should be" with self.assertRaisesRegexp(ParamError, error_msg): self.simplify(p, ['ParameterKey=key,ParameterValue="foo,bar']) @@ -408,6 +353,76 @@ def test_csv_syntax_errors(self): self.simplify(p, ['ParameterKey=key,ParameterValue="foo,bar\'']) +class TestParamShorthandCustomArguments(BaseArgProcessTest): + + def setUp(self): + super(TestParamShorthandCustomArguments, self).setUp() + self.simplify = ParamShorthand() + + def test_list_structure_list_scalar_custom_arg(self): + schema = { + 'type': 'array', + 'items': { + 'type': 'object', + 'properties': { + 'Name': { + 'type': 'string' + }, + 'Args': { + 'type': 'array', + 'items': { + 'type': 'string' + } + } + } + } + } + argument_model = create_argument_model_from_schema(schema) + cli_argument = CustomArgument('foo', argument_model=argument_model) + + expected = [ + {"Name": "foo", + "Args": ["a", "k1=v1", "b"]}, + {"Name": "bar", + "Args": ["baz"]}, + {"Name": "single_kv", + "Args": ["key=value"]}, + {"Name": "single_v", + "Args": ["value"]} + ] + + simplified = self.simplify(cli_argument, [ + "Name=foo,Args=[a,k1=v1,b]", + "Name=bar,Args=baz", + "Name=single_kv,Args=[key=value]", + "Name=single_v,Args=[value]" + ]) + + self.assertEqual(simplified, expected) + + def test_struct_list_scalars(self): + schema = { + "type": "object", + "properties": { + "Consistent": { + "type": "boolean", + }, + "Args": { + "type": "array", + "items": { + "type": "string" + } + } + } + } + argument_model = create_argument_model_from_schema(schema) + cli_argument = CustomArgument('test', argument_model=argument_model) + + returned = self.simplify(cli_argument, 'Consistent=true,Args=foo1,foo2') + self.assertEqual(returned, {'Consistent': True, + 'Args': ['foo1', 'foo2']}) + + class TestDocGen(BaseArgProcessTest): # These aren't very extensive doc tests, as we want to stay somewhat # flexible and allow the docs to slightly change without breaking these @@ -415,76 +430,66 @@ class TestDocGen(BaseArgProcessTest): def setUp(self): super(TestDocGen, self).setUp() self.simplify = ParamShorthand() + self.shorthand_documenter = ParamShorthandDocGen() + + def get_generated_example_for(self, argument): + # Returns a string containing the generated documentation. + return self.shorthand_documenter.generate_shorthand_example(argument) + + def assert_generated_example_is(self, argument, expected_docs): + generated_docs = self.get_generated_example_for(argument) + self.assertEqual(generated_docs, expected_docs) + + def assert_generated_example_contains(self, argument, expected_to_contain): + generated_docs = self.get_generated_example_for(argument) + self.assertIn(expected_to_contain, generated_docs) def test_gen_map_type_docs(self): - p = self.get_param_object('sqs.SetQueueAttributes.Attributes') - argument = CLIArgument(p.cli_name, p, p.operation) - help_command = OperationHelpCommand( - self.session, p.operation, None, {p.cli_name: argument}, - name='set-queue-attributes', event_class='sqs') - help_command.param_shorthand.add_example_fn(p.cli_name, help_command) - self.assertTrue(p.example_fn) - doc_string = p.example_fn(p) - self.assertIn(MAPHELP, doc_string) + argument = self.get_param_model('sqs.SetQueueAttributes.Attributes') + expected_example_str = ( + "--attributes key_name=string,key_name2=string\n" + "Where valid key names are:\n" + " Policy" + ) + self.assert_generated_example_contains(argument, expected_example_str) def test_gen_list_scalar_docs(self): - p = self.get_param_object( + argument = self.get_param_model( 'elb.RegisterInstancesWithLoadBalancer.Instances') - argument = CLIArgument(p.cli_name, p, p.operation) - help_command = OperationHelpCommand( - self.session, p.operation, None, - {p.cli_name: argument}, - name='register-instances-with-load-balancer', - event_class='elb') - help_command.param_shorthand.add_example_fn(p.cli_name, help_command) - self.assertTrue(p.example_fn) - doc_string = p.example_fn(p) - self.assertEqual(doc_string, - '--instances InstanceId1 InstanceId2 InstanceId3') + doc_string = '--instances InstanceId1 InstanceId2 InstanceId3' + self.assert_generated_example_is(argument, doc_string) def test_gen_list_structure_of_scalars_docs(self): - p = self.get_param_object('elb.CreateLoadBalancer.Listeners') - argument = CLIArgument(p.cli_name, p, p.operation) - help_command = OperationHelpCommand( - self.session, p.operation, None, {p.cli_name: argument}, - name='create-load-balancer', event_class='elb') - help_command.param_shorthand.add_example_fn(p.cli_name, help_command) - self.assertTrue(p.example_fn) - doc_string = p.example_fn(p) - self.assertIn('Key value pairs, with multiple values separated by a space.', doc_string) - self.assertIn('Protocol=string', doc_string) - self.assertIn('LoadBalancerPort=integer', doc_string) - self.assertIn('InstanceProtocol=string', doc_string) - self.assertIn('InstancePort=integer', doc_string) - self.assertIn('SSLCertificateId=string', doc_string) + argument = self.get_param_model('elb.CreateLoadBalancer.Listeners') + generated_example = self.get_generated_example_for(argument) + self.assertIn( + 'Key value pairs, with multiple values separated by a space.', + generated_example) + self.assertIn('Protocol=string', generated_example) + self.assertIn('LoadBalancerPort=integer', generated_example) + self.assertIn('InstanceProtocol=string', generated_example) + self.assertIn('InstancePort=integer', generated_example) + self.assertIn('SSLCertificateId=string', generated_example) def test_gen_list_structure_multiple_scalar_docs(self): - p = self.get_param_object('elastictranscoder.CreateJob.Playlists') - argument = CLIArgument(p.cli_name, p, p.operation) - help_command = OperationHelpCommand( - self.session, p.operation, None, {p.cli_name: argument}, - name='create-job', event_class='elastictranscoder') - help_command.param_shorthand.add_example_fn(p.cli_name, help_command) - doc_string = p.example_fn(p) - s = ('Key value pairs, where values are separated by commas, ' + argument = self.get_param_model('elastictranscoder.CreateJob.Playlists') + expected = ( + 'Key value pairs, where values are separated by commas, ' 'and multiple pairs are separated by spaces.\n' '--playlists Name=string1,Format=string1,OutputKeys=string1,string2 ' 'Name=string1,Format=string1,OutputKeys=string1,string2') - self.assertEqual(doc_string, s) + self.assert_generated_example_is(argument, expected) def test_gen_list_structure_list_scalar_scalar_docs(self): # Verify that we have *two* top level list items displayed, # so we make it clear that multiple values are separated by spaces. - p = self.get_param_object('ec2.DescribeInstances.Filters') - argument = CLIArgument(p.cli_name, p, p.operation) - help_command = OperationHelpCommand( - self.session, p.operation, None, {p.cli_name: argument}, - name='describe-instances', event_class='ec2') - help_command.param_shorthand.add_example_fn(p.cli_name, help_command) - doc_string = p.example_fn(p) - self.assertIn('multiple pairs are separated by spaces', doc_string) + argument = self.get_param_model('ec2.DescribeInstances.Filters') + generated_example = self.get_generated_example_for(argument) + self.assertIn('multiple pairs are separated by spaces', + generated_example) self.assertIn('Name=string1,Values=string1,string2 ' - 'Name=string1,Values=string1,string2', doc_string) + 'Name=string1,Values=string1,string2', + generated_example) def test_gen_structure_list_scalar_docs(self): schema = { @@ -501,20 +506,14 @@ def test_gen_structure_list_scalar_docs(self): } } } + argument_model = create_argument_model_from_schema(schema) + cli_argument = CustomArgument('test', argument_model=argument_model) - argument = CustomArgument('test', schema=schema) - argument.create_argument_object() - - p = argument.argument_object - help_command = OperationHelpCommand( - self.session, p.operation, None, {p.cli_name: argument}, - name='foo', event_class='bar') - help_command.param_shorthand.add_example_fn(p.cli_name, help_command) - - doc_string = p.example_fn(p) + generated_example = self.get_generated_example_for(cli_argument) + self.assertIn('Key value pairs', generated_example) + self.assertIn('Consistent=boolean1,Args=string1,string2', + generated_example) - self.assertIn('Key value pairs', doc_string) - self.assertIn('Consistent=boolean1,Args=string1,string2', doc_string) class TestUnpackJSONParams(BaseArgProcessTest): def setUp(self): @@ -522,7 +521,7 @@ def setUp(self): self.simplify = ParamShorthand() def test_json_with_spaces(self): - p = self.get_param_object('ec2.RunInstances.BlockDeviceMappings') + p = self.get_param_model('ec2.RunInstances.BlockDeviceMappings') # If a user specifies the json with spaces, it will show up as # a multi element list. For example: # --block-device-mappings [{ "DeviceName":"/dev/sdf", diff --git a/tests/unit/test_clidriver.py b/tests/unit/test_clidriver.py index 45fca5b15622..9c1512f617c0 100644 --- a/tests/unit/test_clidriver.py +++ b/tests/unit/test_clidriver.py @@ -18,6 +18,7 @@ import six from botocore.vendored.requests import models from botocore.exceptions import NoCredentialsError +from botocore.compat import OrderedDict import awscli from awscli.clidriver import CLIDriver @@ -126,12 +127,11 @@ def get_service(self, name): # enough of the "right stuff". service = mock.Mock() operation = mock.Mock() - param = mock.Mock() - param.type = 'string' - param.py_name = 'bucket' - param.cli_name = '--bucket' - param.name = 'bucket' - operation.params = [param] + operation.model.input_shape.members = OrderedDict([ + ('Bucket', mock.Mock()), + ('Key', mock.Mock()), + ]) + operation.model.input_shape.required_members = ['Bucket'] operation.cli_name = 'list-objects' operation.name = 'ListObjects' operation.is_streaming.return_value = False @@ -148,12 +148,6 @@ def get_service(self, name): operation.service.session = self return service - def get_service_data(self, service_name): - return {'operations': {'ListObjects': {'input': { - 'members': dict.fromkeys( - ['Bucket', 'Delimiter', 'Marker', 'MaxKeys', 'Prefix']), - }}}} - def user_agent(self): return 'user_agent' @@ -222,12 +216,12 @@ def setUp(self): self.stdout = six.StringIO() self.stderr = six.StringIO() self.stdout_patch = mock.patch('sys.stdout', self.stdout) - self.stdout_patch.start() + #self.stdout_patch.start() self.stderr_patch = mock.patch('sys.stderr', self.stderr) self.stderr_patch.start() def tearDown(self): - self.stdout_patch.stop() + #self.stdout_patch.stop() self.stderr_patch.stop() def assert_events_fired_in_order(self, events): @@ -236,7 +230,7 @@ def assert_events_fired_in_order(self, events): self.assertEqual(actual_events, events) def serialize_param(self, param, value, **kwargs): - if param.py_name == 'bucket': + if kwargs['cli_argument'].name == 'bucket': return value + '-altered!' def test_expected_events_are_emitted_in_order(self): @@ -254,6 +248,7 @@ def test_expected_events_are_emitted_in_order(self): 'operation-args-parsed.s3.list-objects', 'load-cli-arg.s3.list-objects.bucket', 'process-cli-arg.s3.list-objects', + 'load-cli-arg.s3.list-objects.key', ]) def test_create_help_command(self): @@ -282,7 +277,7 @@ def test_cli_driver_changes_args(self): self.session.emitter = emitter driver = CLIDriver(session=self.session) driver.main('s3 list-objects --bucket foo'.split()) - self.assertIn(mock.call.paginate(mock.ANY, bucket='foo-altered!'), + self.assertIn(mock.call.paginate(mock.ANY, Bucket='foo-altered!'), self.session.operation.method_calls) def test_unknown_params_raises_error(self): @@ -424,7 +419,7 @@ def test_aws_with_verify_false(self): endpoint_url=None) def test_aws_with_cacert_env_var(self): - with mock.patch('botocore.endpoint.QueryEndpoint.__init__') as endpoint: + with mock.patch('botocore.endpoint.Endpoint.__init__') as endpoint: http_response = models.Response() http_response.status_code = 200 endpoint.return_value = None @@ -438,7 +433,7 @@ def test_aws_with_cacert_env_var(self): self.assertEqual(call_args[1]['verify'], '/path/cacert.pem') def test_default_to_verifying_ssl(self): - with mock.patch('botocore.endpoint.QueryEndpoint.__init__') as endpoint: + with mock.patch('botocore.endpoint.Endpoint.__init__') as endpoint: http_response = models.Response() http_response.status_code = 200 endpoint.return_value = None @@ -539,6 +534,7 @@ def test_custom_command_paramfile(self): uri_param_mock.assert_called() + @unittest.skip def test_custom_arg_no_paramfile(self): driver = create_clidriver() driver.session.register( diff --git a/tests/unit/test_completer.py b/tests/unit/test_completer.py index fc5365b40a69..bd41e0e0f063 100644 --- a/tests/unit/test_completer.py +++ b/tests/unit/test_completer.py @@ -73,7 +73,9 @@ '--cache-control', '--content-type', '--content-disposition', '--source-region', '--content-encoding', '--content-language', - '--expires', '--grants'] + GLOBALOPTS)), + '--expires', '--grants', '--only-show-errors', + '--expected-size'] + + GLOBALOPTS)), ('aws s3 cp --quiet -', -1, set(['--no-guess-mime-type', '--dryrun', '--recursive', '--content-type', '--follow-symlinks', '--no-follow-symlinks', @@ -83,7 +85,9 @@ '--storage-class', '--sse', '--exclude', '--include', '--source-region', - '--grants'] + GLOBALOPTS)), + '--grants', '--only-show-errors', + '--expected-size'] + + GLOBALOPTS)), ('aws emr ', -1, set(['add-instance-groups', 'add-steps', 'add-tags', 'create-cluster', 'create-default-roles', 'create-hbase-backup', 'describe-cluster', diff --git a/tests/unit/test_errorhandler.py b/tests/unit/test_errorhandler.py index 5e86564bf304..6484e92f1341 100644 --- a/tests/unit/test_errorhandler.py +++ b/tests/unit/test_errorhandler.py @@ -26,12 +26,10 @@ def create_http_response(self, **kwargs): def test_error_handler_client_side(self): response = { - 'CommonPrefixes': [], - 'Contents': [], - 'Errors': [{'Code': 'AccessDenied', - 'HostId': 'foohost', - 'Message': 'Access Denied', - 'RequestId': 'requestid'}], + 'Error': {'Code': 'AccessDenied', + 'HostId': 'foohost', + 'Message': 'Access Denied', + 'RequestId': 'requestid'}, 'ResponseMetadata': {}} handler = errorhandler.ErrorHandler() http_response = self.create_http_response(status_code=403) diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index c5d1b7a7b382..0e62d1c41c22 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -10,141 +10,295 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from awscli.testutils import unittest +import pprint + +from botocore.compat import OrderedDict +from awscli.testutils import unittest from awscli.schema import ParameterRequiredError, SchemaTransformer +from awscli.schema import ShapeNameGenerator -""" -Note: this schema is currently not supported by the ParamShorthand -parser due to its complexity, but is tested here to ensure the -robustness of the transformer. -""" -INPUT_SCHEMA = { - "type": "array", - "items": { - "type": "object", - "properties": { - "Name": { - "type": "string", - "description": "The name of the step. ", - }, - "Jar": { - "type": "string", - "description": "A path to a JAR file run during the step.", - }, - "Args": { - "type": "array", - "description": - "A list of command line arguments to pass to the step.", - "items": { - "type": "string" + +MISSING_TYPE = { + "type": "object", + "properties": { + "Foo": { + "description": "I am a foo" + } + } +} + + +class TestSchemaTransformer(unittest.TestCase): + + maxDiff = None + + def test_missing_top_level_type_raises_exception(self): + transformer = SchemaTransformer() + with self.assertRaises(ParameterRequiredError): + transformer.transform({}) + + def test_missing_type_raises_exception(self): + transformer = SchemaTransformer() + + with self.assertRaises(ParameterRequiredError): + transformer.transform({ + 'type': 'object', + 'properties': { + 'Foo': { + 'description': 'foo', } + } + }) + + def assert_schema_transforms_to(self, schema, transforms_to): + transformer = SchemaTransformer() + actual = transformer.transform(schema) + if actual != transforms_to: + self.fail("Transform failed.\n\nExpected:\n%s\n\nActual:\n%s\n" % ( + pprint.pformat(transforms_to), pprint.pformat(actual))) + + def test_transforms_list_of_single_string(self): + schema = { + 'type': 'array', + 'items': { + 'type': 'string' + } + } + transforms_to = { + 'InputShape': { + 'type': 'list', + 'member': {'shape': 'StringType1'} }, - "MainClass": { - "type": "string", - "description": - "The name of the main class in the specified " - "Java file. If not specified, the JAR file should " - "specify a Main-Class in its manifest file." - }, - "Properties": { - "type": "array", - "description": - "A list of Java properties that are set when the step " - "runs. You can use these properties to pass key value " - "pairs to your main function.", - "items": { - "type": "object", - "properties": { - "Key":{ - "type": "string", - "description": - "The unique identifier of a key value pair." - }, - "Value": { - "type": "string", - "description": - "The value part of the identified key." - } + 'StringType1': {'type': 'string'} + } + self.assert_schema_transforms_to(schema, transforms_to) + + def test_transform_list_of_structures(self): + schema = { + "type": "array", + "items": { + "type": "object", + "properties": { + "arg1": { + "type": "string", + }, + "arg2": { + "type": "integer", } } } } - } -} + transforms_to = { + 'InputShape': { + 'type': 'list', + 'member': { + 'shape': 'StructureType1' + } + }, + 'StructureType1': { + 'type': 'structure', + 'members': { + 'arg1': { + 'shape': 'StringType1', + }, + 'arg2': { + 'shape': 'IntegerType1', + }, + } + }, + 'StringType1': {'type': 'string'}, + 'IntegerType1': {'type': 'integer'}, + } + self.assert_schema_transforms_to(schema, transforms_to) + + def test_transform_required_members_on_structure(self): + pass + + def test_transforms_string(self): + self.assert_schema_transforms_to( + schema={ + 'type': 'string' + }, + transforms_to={ + 'InputShape': {'type': 'string'} + } + ) + + def test_transforms_boolean(self): + self.assert_schema_transforms_to( + schema={ + 'type': 'boolean' + }, + transforms_to={ + 'InputShape': {'type': 'boolean'} + } + ) -EXPECTED_OUTPUT = { - "type": "list", - "members": { - "type": "structure", - "members": { - "Name": { - "type": "string", - "description": "The name of the step. ", + def test_transforms_integer(self): + self.assert_schema_transforms_to( + schema={ + 'type': 'integer' }, - "Jar": { - "type": "string", - "description": "A path to a JAR file run during the step.", + transforms_to={ + 'InputShape': {'type': 'integer'} + } + ) + + def test_transforms_structure(self): + self.assert_schema_transforms_to( + schema={ + "type": "object", + "properties": OrderedDict([ + ("A", {"type": "string"}), + ("B", {"type": "string"}), + ]), }, - "Args": { - "type": "list", - "description": - "A list of command line arguments to pass to the step.", - "members": { - "type": "string" + transforms_to={ + 'InputShape': { + 'type': 'structure', + 'members': { + 'A': {'shape': 'StringType1'}, + 'B': {'shape': 'StringType2'}, + } + }, + 'StringType1': {'type': 'string'}, + 'StringType2': {'type': 'string'}, + } + ) + + def test_description_on_shape_type(self): + self.assert_schema_transforms_to( + schema={ + 'type': 'string', + 'description': 'a description' + }, + transforms_to={ + 'InputShape': { + 'type': 'string', + 'documentation': 'a description' } + } + ) + + def test_enum_on_shape_type(self): + self.assert_schema_transforms_to( + schema={ + 'type': 'string', + 'enum': ['a', 'b'], }, - "MainClass": { - "type": "string", - "description": - "The name of the main class in the specified " - "Java file. If not specified, the JAR file should " - "specify a Main-Class in its manifest file." + transforms_to={ + 'InputShape': { + 'type': 'string', + 'enum': ['a', 'b'] + } + } + ) + + def test_description_on_shape_ref(self): + self.assert_schema_transforms_to( + schema={ + 'type': 'object', + 'description': 'object description', + 'properties': { + 'A': { + 'type': 'string', + 'description': 'string description', + }, + } }, - "Properties": { - "type": "list", - "description": - "A list of Java properties that are set when the step " - "runs. You can use these properties to pass key value " - "pairs to your main function.", - "members": { - "type": "structure", - "members": { - "Key":{ - "type": "string", - "description": - "The unique identifier of a key value pair." - }, - "Value": { - "type": "string", - "description": - "The value part of the identified key." - } + transforms_to={ + 'InputShape': { + 'type': 'structure', + 'documentation': 'object description', + 'members': { + 'A': {'shape': 'StringType1'}, } + }, + 'StringType1': { + 'documentation': 'string description', + 'type': 'string' } } - } - } -} + ) -MISSING_TYPE = { - "type": "object", - "properties": { - "Foo": { - "description": "I am a foo" - } - } -} + def test_required_members_on_structure(self): + # This case is interesting because we actually + # don't support a 'required' key on a member shape ref. + # Now, all the required members are added as a key on the + # parent structure shape. + self.assert_schema_transforms_to( + schema={ + 'type': 'object', + 'properties': { + 'A': {'type': 'string', 'required': True}, + } + }, + transforms_to={ + 'InputShape': { + 'type': 'structure', + # This 'required' key is the change here. + 'required': ['A'], + 'members': { + 'A': {'shape': 'StringType1'}, + } + }, + 'StringType1': {'type': 'string'}, + } + ) -class TestSchemaTransformer(unittest.TestCase): - def test_schema(self): - transformer = SchemaTransformer(INPUT_SCHEMA) - output = transformer.transform() + def test_nested_structure(self): + self.assert_schema_transforms_to( + schema={ + 'type': 'object', + 'properties': { + 'A': { + 'type': 'object', + 'properties': { + 'B': { + 'type': 'object', + 'properties': { + 'C': {'type': 'string'} + } + } + } + }, + } + }, + transforms_to={ + 'InputShape': { + 'type': 'structure', + 'members': { + 'A': {'shape': 'StructureType1'}, + } + }, + 'StructureType1': { + 'type': 'structure', + 'members': { + 'B': {'shape': 'StructureType2'} + } + }, + 'StructureType2': { + 'type': 'structure', + 'members': { + 'C': {'shape': 'StringType1'} + } + }, + 'StringType1': { + 'type': 'string', + } + } + ) - self.assertEqual(output, EXPECTED_OUTPUT) - def test_missing_type(self): - transformer = SchemaTransformer(MISSING_TYPE) +class TestShapeNameGenerator(unittest.TestCase): + def test_generate_name_types(self): + namer = ShapeNameGenerator() + self.assertEqual(namer.new_shape_name('string'), 'StringType1') + self.assertEqual(namer.new_shape_name('list'), 'ListType1') + self.assertEqual(namer.new_shape_name('structure'), 'StructureType1') - with self.assertRaises(ParameterRequiredError): - transformer.transform() + def test_generate_type_multiple_times(self): + namer = ShapeNameGenerator() + self.assertEqual(namer.new_shape_name('string'), 'StringType1') + self.assertEqual(namer.new_shape_name('string'), 'StringType2')