diff --git a/awscli/clidriver.py b/awscli/clidriver.py index 9ba80784114a..3139bfbc7e73 100644 --- a/awscli/clidriver.py +++ b/awscli/clidriver.py @@ -22,6 +22,15 @@ from .help import get_provider_help, get_service_help, get_operation_help from .formatter import get_formatter from .paramfile import get_paramfile +from .plugin import load_plugins, first_non_none_response +from .hooks import BaseEventHooks + + +def main(): + session = botocore.session.get_session(EnvironmentVariables) + emitter = load_plugins(session.full_config.get('plugins', {})) + driver = CLIDriver(session=session, emitter=emitter) + return driver.main() class CLIDriver(object): @@ -41,13 +50,17 @@ class CLIDriver(object): 'double': float, 'blob': str} - def __init__(self, session=None): + def __init__(self, session=None, emitter=None): if session is None: self.session = botocore.session.get_session(EnvironmentVariables) self.session.user_agent_name = 'aws-cli' self.session.user_agent_version = __version__ else: self.session = session + if emitter is None: + self._emitter = BaseEventHooks() + else: + self._emitter = emitter self.args = None self.service = None self.region = None @@ -89,6 +102,7 @@ def create_main_parser(self): parser.add_argument(option_name, **option_data) parser.add_argument('--version', action="version", version=self.session.user_agent()) + self._emitter.emit('parser-created.main', parser=parser) return parser def create_service_parser(self, remaining, main_parser): @@ -115,6 +129,8 @@ def create_service_parser(self, remaining, main_parser): parser.add_argument('operation', help='The operation', metavar='operation', choices=operations) + self._emitter.emit('parser-created.%s' % self.service.cli_name, + parser=parser) return parser def _create_operation_parser(self, remaining, main_parser): @@ -165,7 +181,6 @@ def _create_operation_parser(self, remaining, main_parser): else: parser.add_argument(param.cli_name, help=param.documentation, - nargs=1, type=self.type_map[param.type], required=param.required, dest=param.py_name) @@ -175,6 +190,8 @@ def _create_operation_parser(self, remaining, main_parser): if 'help' in remaining: get_operation_help(self.operation) return 0 + self._emitter.emit('parser-created.%s-%s' % (self.service.cli_name, + self.operation.cli_name)) return parser def _unpack_cli_arg(self, param, s): @@ -184,17 +201,11 @@ def _unpack_cli_arg(self, param, s): to the Operation. """ if param.type == 'integer': - if isinstance(s, list): - s = s[0] return int(s) elif param.type == 'float' or param.type == 'double': # TODO: losing precision on double types - if isinstance(s, list): - s = s[0] return float(s) elif param.type == 'structure' or param.type == 'map': - if isinstance(s, list) and len(s) == 1: - s = s[0] if s[0] == '{': d = json.loads(s) else: @@ -210,38 +221,53 @@ def _unpack_cli_arg(self, param, s): return json.loads(s[0]) return [self._unpack_cli_arg(param.members, v) for v in s] elif param.type == 'blob' and param.payload and param.streaming: - if isinstance(s, list) and len(s) == 1: - file_path = s[0] - file_path = os.path.expandvars(file_path) + file_path = os.path.expandvars(s) file_path = os.path.expanduser(file_path) if not os.path.isfile(file_path): msg = 'Blob values must be a path to a file.' raise ValueError(msg) return open(file_path, 'rb') else: - if isinstance(s, list): - s = s[0] return str(s) def _build_call_parameters(self, args, param_dict): + service_name = self.service.cli_name + operation_name = self.operation.cli_name for param in self.operation.params: value = getattr(args, param.py_name) if value is not None: - # Don't include non-required boolean params whose - # values are False + # Plugins can override the cli -> python conversion + # process for CLI args. + responses = self._emitter.emit('process-cli-arg.%s.%s' % ( + service_name, operation_name), param=param, value=value, + service=self.service, operation=self.operation) + override = first_non_none_response(responses) + if override is not None: + # A plugin supplied an alternate conversion, + # use it instead. + param_dict[param.py_name] = override + continue + # Otherwise fall back to our normal built in cli -> python + # conversion process. if param.type == 'boolean' and not param.required and \ value is False: + # Don't include non-required boolean params whose + # values are False continue if not hasattr(param, 'no_paramfile'): - if isinstance(value, list) and len(value) == 1: - temp = value[0] - else: - temp = value - temp = get_paramfile(self.session, temp) - if temp: - value = temp + value = self._handle_param_file(value) param_dict[param.py_name] = self._unpack_cli_arg(param, value) + def _handle_param_file(self, value): + if isinstance(value, list) and len(value) == 1: + temp = value[0] + else: + temp = value + temp = get_paramfile(self.session, temp) + if temp: + value = temp + return value + def display_error_and_exit(self, ex): if self.args.debug: traceback.print_exc() @@ -273,20 +299,31 @@ def save_output(self, body_name, response_data, path): data = response_data[body_name].read(buffsize) del response_data[body_name] - def call(self, args): + def _call(self, args): try: params = {} self._build_call_parameters(args, params) self.endpoint = self.service.get_endpoint( self.args.region, endpoint_url=self.args.endpoint_url) self.endpoint.verify = not self.args.no_verify_ssl + self._emitter.emit( + 'before-operation.%s.%s' % (self.service.cli_name, + self.operation.cli_name), + service=self.service, operation=self.operation, + endpoint=self.endpoint, params=params) if self.operation.can_paginate: pages = self.operation.paginate(self.endpoint, **params) + self._emitter.emit( + 'after-operation.%s.%s' % (self.service.cli_name, + self.operation.cli_name), + service=self.service, operation=self.operation, + endpoint=self.endpoint, params=params) self._display_response(self.operation, pages) # TODO: need to handle http error responses. I believe # this will be addressed with the plugin refactoring, # but the other alternative is going to be that we'll need # to cache the fully buffered response. + return 0 else: http_response, response_data = self.operation.call( self.endpoint, **params) @@ -378,7 +415,16 @@ def _parse_args(self, main_parser, args): return -1 return args - def main(self): + def main(self, args=None): + """ + + :param args: List of arguments, with the 'aws' removed. For example, + the command "aws s3 list-objects --bucket foo" will have an + args list of ``['s3', 'list-objects', '--bucket', 'foo']``. + + """ + if args is None: + args = sys.argv[1:] main_parser = self.create_main_parser() - args = self._parse_args(main_parser, sys.argv[1:]) - return self.call(args) + remaining_args = self._parse_args(main_parser, args) + return self._call(remaining_args) diff --git a/awscli/hooks.py b/awscli/hooks.py index 95c2d6fdce8e..6ae1d5408cc4 100644 --- a/awscli/hooks.py +++ b/awscli/hooks.py @@ -15,7 +15,18 @@ from collections import defaultdict -class EventHooks(object): +class BaseEventHooks(object): + def emit(self, event_name, **kwargs): + return [] + + def register(self, event_name, handler): + pass + + def unregister(self, event_name, handler): + pass + + +class EventHooks(BaseEventHooks): def __init__(self): # event_name -> [handler, ...] self._handlers = defaultdict(list) @@ -75,3 +86,26 @@ def _verify_accept_kwargs(self, func): raise ValueError("Event handler %s must accept keyword " "arguments (**kwargs)" % func) + +class HierarchicalEmitter(BaseEventHooks): + def __init__(self, event_hooks): + self._event_hooks = event_hooks + + def emit(self, event, **kwargs): + responses = [] + # Invoke the event handlers from most specific + # to least specific, each time stripping off a dot. + while event: + responses.extend(self._event_hooks.emit(event, **kwargs)) + next_event = event.rsplit('.', 1) + if len(next_event) == 2: + event = next_event[0] + else: + event = None + return responses + + def register(self, event_name, handler): + return self._event_hooks.register(event_name, handler) + + def unregister(self, event_name, handler): + return self._event_hooks.unregister(event_name, handler) diff --git a/awscli/plugin.py b/awscli/plugin.py index 104b0a0fda30..f4d5e6fe869f 100644 --- a/awscli/plugin.py +++ b/awscli/plugin.py @@ -10,70 +10,77 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +import logging -from awscli.hooks import EventHooks +from awscli.hooks import EventHooks, HierarchicalEmitter +log = logging.getLogger('awscli.plugin') -def load_plugins(plugin_names, event_hooks=None): - modules = _import_plugins(plugin_names) + +def load_plugins(plugin_mapping, event_hooks=None): + """ + + :type plugin_mapping: dict + :param plugin_mapping: A dict of plugin name to import path, + e.g. ``{"plugingName": "package.modulefoo"}``. + + :type event_hooks: ``EventHooks`` + :param event_hooks: Event hook emitter. + + :rtype: HierarchicalEmitter + :return: An event emitter object. + + """ + modules = _import_plugins(plugin_mapping) if event_hooks is None: event_hooks = EventHooks() cli = CLI(event_hooks) - for plugin in modules: + for name, plugin in zip(plugin_mapping.keys(), modules): + log.debug("Initializing plugin %s: %s", name, plugin) plugin.awscli_initialize(cli) return HierarchicalEmitter(event_hooks) def _import_plugins(plugin_names): plugins = [] - for name in plugin_names: + for name, path in plugin_names.items(): + log.debug("Importing plugin %s: %s", name, path) if '.' not in name: - plugins.append(__import__(name)) + plugins.append(__import__(path)) return plugins -class HierarchicalEmitter(object): - def __init__(self, event_hooks): - self._event_hooks = event_hooks +def first_non_none_response(responses, default=None): + """Find first non None response in a list of tuples. + + This function can be used to find the first non None response from + handlers connected to an event. This is useful if you are interested + in the returned responses from event handlers. Example usage:: - def emit(self, event): - responses = [] - # Invoke the event handlers from most specific - # to least specific, each time stripping off a dot. - while event: - responses.extend(self._event_hooks.emit(event)) - next_event = event.rsplit('.', 1) - if len(next_event) == 2: - event = next_event[0] - else: - event = None - return responses + print(first_non_none_response([(func1, None), (func2, 'foo'), + (func3, 'bar')])) + # This will print 'foo' + + :type responses: list of tuples + :param responses: The responses from the ``EventHooks.emit`` method. + This is a list of tuples, and each tuple is + (handler, handler_response). + + :param default: If no non-None responses are found, then this default + value will be returned. + + :return: The first non-None response in the list of tuples. + + """ + for response in responses: + if response[1] is not None: + return response[1] + return default class CLI(object): def __init__(self, event_hooks): self._event_hooks = event_hooks - def before_call(self, handler, service_name=None, operation_name=None): - op_event_name = self._get_event_name(service_name, operation_name) - if op_event_name: - event_name = 'before_call.%s' % op_event_name - else: - event_name = 'before_call' + def register(self, event_name, handler): self._event_hooks.register(event_name, handler) - - def after_call(self, handler, service_name=None, operation_name=None): - op_event_name = self._get_event_name(service_name, operation_name) - if op_event_name: - event_name = 'after_call.%s' % op_event_name - else: - event_name = 'after_call' - self._event_hooks.register(event_name, handler) - - def _get_event_name(self, service_name, operation_name): - if service_name is None: - return '' - if service_name is not None and operation_name is None: - return service_name - elif service_name is not None and operation_name is not None: - return '%s.%s' % (service_name, operation_name) diff --git a/bin/aws b/bin/aws index a6883c8c5057..11550cfa6576 100755 --- a/bin/aws +++ b/bin/aws @@ -16,8 +16,7 @@ import awscli.clidriver def main(): - driver = awscli.clidriver.CLIDriver() - return driver.main() + return awscli.clidriver.main() if __name__ == '__main__': diff --git a/bin/aws.cmd b/bin/aws.cmd index 06fe1607e7a8..af78fbb71403 100644 --- a/bin/aws.cmd +++ b/bin/aws.cmd @@ -50,9 +50,8 @@ import awscli.clidriver def main(): - driver = awscli.clidriver.CLIDriver() - driver.main() + return awscli.clidriver.main() if __name__ == '__main__': - main() + sys.exit(main()) diff --git a/tests/unit/ec2/test_describe_instances.py b/tests/unit/ec2/test_describe_instances.py index 5f582247dfea..e3f7a459afc9 100644 --- a/tests/unit/ec2/test_describe_instances.py +++ b/tests/unit/ec2/test_describe_instances.py @@ -57,6 +57,19 @@ def test_filter_values(self): params = self.driver.test(cmdline) self.assertEqual(params, result) + def test_multiple_filters(self): + args = (' --filters {"name":"group-name","values":["foobar"]} ' + '{"name":"instance-id","values":["i-12345"]}') + cmdline = self.prefix + args + result = { + 'Filter.1.Name': 'group-name', + 'Filter.1.Value.1': 'foobar', + 'Filter.2.Name': 'instance-id', + 'Filter.2.Value.1': 'i-12345', + } + params = self.driver.test(cmdline) + self.assertEqual(params, result) + if __name__ == "__main__": diff --git a/tests/unit/test_clidriver.py b/tests/unit/test_clidriver.py index 3d666441ba8c..76e50e35fd97 100644 --- a/tests/unit/test_clidriver.py +++ b/tests/unit/test_clidriver.py @@ -10,20 +10,151 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -import unittest +from tests import unittest import mock from awscli.clidriver import CLIDriver +from awscli.hooks import HierarchicalEmitter, EventHooks + + +GET_DATA = { + 'cli': { + 'description': 'description', + 'options': { + "service_name": { + "choices": "{provider}/_services", + "metavar": "service_name" + }, + "--debug": { + "action": "store_true", + "help": "Turn on debug logging" + }, + "--output": { + "choices": [ + "json", + "text", + "table" + ], + "metavar": "output_format" + }, + "--profile": { + "help": "Use a specific profile from your credential file", + "metavar": "profile_name" + }, + "--region": { + "choices": "{provider}/_regions", + "metavar": "region_name" + }, + "--endpoint-url": { + "help": "Override service's default URL with the given URL", + "metavar": "endpoint_url" + }, + "--no-verify-ssl": { + "action": "store_true", + "help": "Override default behavior of verifying SSL certificates" + }, + } + }, + 'aws/_services': {'s3':{}}, + 'aws/_regions': {}, +} + +GET_VARIABLE = { + 'provider': 'aws', + 'output': 'json', +} + + +class FakeSession(object): + def __init__(self): + pass + + def get_data(self, name): + return GET_DATA[name] + + def get_variable(self, name): + return GET_VARIABLE[name] + + def get_service(self, name): + # Get service returns a service object, + # so we'll just return a Mock object with + # enough of the "right stuff". + service = mock.Mock() + list_objects = mock.Mock(name='operation') + list_objects.cli_name = 'list-objects' + list_objects.params = [] + operation = mock.Mock() + param = mock.Mock() + param.type = 'string' + param.py_name = 'bucket' + param.cli_name = '--bucket' + operation.params = [param] + operation.cli_name = 'list-objects' + operation.is_streaming.return_value = False + operation.paginate.return_value.build_full_result.return_value = { + 'foo': 'bar'} + service.operations = [list_objects] + service.cli_name = 's3' + service.get_operation.return_value = operation + return service + + def user_agent(self): + return 'user_agent' class TestCliDriver(unittest.TestCase): def setUp(self): - pass + self.session = FakeSession() def test_session_can_be_passed_in(self): - session = mock.Mock() - driver = CLIDriver(session=session) - self.assertEqual(driver.session, session) + driver = CLIDriver(session=self.session) + self.assertEqual(driver.session, self.session) + + def test_call(self): + driver = CLIDriver(session=self.session) + rc = driver.main('s3 list-objects --bucket foo'.split()) + self.assertEqual(rc, 0) + + +class TestCliDriverHooks(unittest.TestCase): + # These tests verify the proper hooks are emitted in clidriver. + def setUp(self): + self.session = FakeSession() + self.emitter = mock.Mock() + self.emitter.emit.return_value = [] + + def assert_events_fired_in_order(self, events): + args = self.emitter.emit.call_args_list + actual_events = [arg[0][0] for arg in args] + self.assertEqual(actual_events, events) + + def serialize_param(self, param, value, **kwargs): + if param.py_name == 'bucket': + return value + '-altered!' + + def test_expected_events_are_emitted_in_order(self): + driver = CLIDriver(session=self.session, emitter=self.emitter) + driver.main('s3 list-objects --bucket foo'.split()) + self.assert_events_fired_in_order([ + # Events fired while parser is being created. + 'parser-created.main', + 'parser-created.s3', + 'parser-created.s3-list-objects', + 'process-cli-arg.s3.list-objects', + # Events fired when operation is being invoked. + 'before-operation.s3.list-objects', + 'after-operation.s3.list-objects', + ]) + + def test_cli_driver_changes_args(self): + actual_params = [] + emitter = HierarchicalEmitter(EventHooks()) + emitter.register('process-cli-arg.s3.list-objects', self.serialize_param) + emitter.register('before-operation', + lambda params, **kwargs: actual_params.append(params)) + driver = CLIDriver(session=self.session, emitter=emitter) + driver.main('s3 list-objects --bucket foo'.split()) + self.assertEqual(actual_params, [{'bucket': 'foo-altered!'}]) if __name__ == '__main__': diff --git a/tests/unit/test_hooks.py b/tests/unit/test_hooks.py index 144f45166d6a..3cc851882409 100644 --- a/tests/unit/test_hooks.py +++ b/tests/unit/test_hooks.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. from tests import unittest -from awscli.hooks import EventHooks +from awscli.hooks import EventHooks, HierarchicalEmitter class TestEventHooks(unittest.TestCase): @@ -72,5 +72,34 @@ def test_unregister_hook_that_does_not_exist(self): self.assertFalse(self.called) +class TestHierarchicalEventEmitter(unittest.TestCase): + def setUp(self): + self.hooks = EventHooks() + self.emitter = HierarchicalEmitter(self.hooks) + self.hook_calls = [] + + def hook(self, **kwargs): + self.hook_calls.append(kwargs) + + def test_non_dot_behavior(self): + self.emitter.register('no-dot', self.hook) + self.emitter.emit('no-dot') + self.assertEqual(len(self.hook_calls), 1) + + def test_with_dots(self): + self.emitter.register('foo.bar.baz', self.hook) + self.emitter.emit('foo.bar.baz') + self.assertEqual(len(self.hook_calls), 1) + + def test_catch_all_hook(self): + self.emitter.register('foo', self.hook) + self.emitter.register('foo.bar', self.hook) + self.emitter.register('foo.bar.baz', self.hook) + self.emitter.emit('foo.bar.baz') + self.assertEqual(len(self.hook_calls), 3) + self.assertEqual([e['event_name'] for e in self.hook_calls], + ['foo.bar.baz', 'foo.bar', 'foo']) + + if __name__ == '__main__': unittest.main() diff --git a/tests/unit/test_plugin.py b/tests/unit/test_plugin.py index 01cf2c9fca3e..5a05e94017e9 100644 --- a/tests/unit/test_plugin.py +++ b/tests/unit/test_plugin.py @@ -11,9 +11,11 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import sys -import unittest +from tests import unittest from awscli import plugin +from awscli.plugin import first_non_none_response +from awscli import hooks class FakeModule(object): @@ -25,7 +27,9 @@ def __init__(self): def awscli_initialize(self, context): self.called = True self.context = context - self.context.before_call(lambda **kwargs: self.events_seen.append(kwargs)) + self.context.register( + 'before_operation', + (lambda **kwargs: self.events_seen.append(kwargs))) class TestPlugins(unittest.TestCase): @@ -38,17 +42,39 @@ def tearDown(self): del sys.modules['__fake_plugin__'] def test_plugin_register(self): - emitter = plugin.load_plugins(['__fake_plugin__']) + emitter = plugin.load_plugins({'fake_plugin': '__fake_plugin__'}) self.assertTrue(self.fake_module.called) - self.assertTrue(isinstance(emitter, plugin.HierarchicalEmitter)) + self.assertTrue(isinstance(emitter, hooks.HierarchicalEmitter)) self.assertTrue(isinstance(self.fake_module.context, plugin.CLI)) def test_event_hooks_can_be_passed_in(self): hooks = plugin.EventHooks() - emitter = plugin.load_plugins(['__fake_plugin__'], event_hooks=hooks) - emitter.emit('before_call') + emitter = plugin.load_plugins({'fake_plugin': '__fake_plugin__'}, + event_hooks=hooks) + emitter.emit('before_operation') self.assertEqual(len(self.fake_module.events_seen), 1) +class TestFirstNonNoneResponse(unittest.TestCase): + def test_all_none(self): + self.assertIsNone(first_non_none_response([])) + + def test_first_non_none(self): + correct_value = 'correct_value' + wrong_value = 'wrong_value' + # The responses are tuples of (handler, response), + # and we don't care about the handler so we just use a value of + # None. + responses = [(None, None), (None, correct_value), (None, wrong_value)] + self.assertEqual(first_non_none_response(responses), correct_value) + + def test_default_value_if_non_none_found(self): + responses = [(None, None), (None, None)] + # If no response is found and a default value is passed in, it will + # be returned. + self.assertEqual( + first_non_none_response(responses, default='notfound'), 'notfound') + + if __name__ == '__main__': unittest.main()