diff --git a/iopipe/agent.py b/iopipe/agent.py index 78ee5339..5b01010c 100644 --- a/iopipe/agent.py +++ b/iopipe/agent.py @@ -89,10 +89,6 @@ def __call__(self, func): def wrapped(event, context): logger.debug("%s wrapped with IOpipe decorator" % repr(func)) - context = ContextWrapper(context, self) - - self.run_hooks("pre:invoke", event=event, context=context) - # if env var IOPIPE_ENABLED is set to False skip reporting if self.config["enabled"] is False: logger.debug("IOpipe agent disabled, skipping reporting") @@ -107,6 +103,15 @@ def wrapped(event, context): ) return func(event, context) + # If context doesn't pass validation, skip reporting + if not self.validate_context(context): + logger.debug("Invalid context, skipping reporting") + return func(event, context) + + context = ContextWrapper(context, self) + + self.run_hooks("pre:invoke", event=event, context=context) + self.report = Report(self, context) signal.signal(signal.SIGALRM, self.handle_timeout) @@ -244,7 +249,31 @@ def submit_future(self, func, *args, **kwargs): def wait_for_futures(self): """ Wait for all futures to complete. This should be done at the end of an - invocation. + an invocation. """ [future for future in futures.as_completed(self.futures)] self.futures = [] + + def validate_context(self, context): + """ + Checks to see if we're working with a valid lambda context object. + + :returns: True if valid, False if not + :rtype: bool + """ + return all( + [ + hasattr(context, attr) + for attr in [ + "aws_request_id", + "function_name", + "function_version", + "get_remaining_time_in_millis", + "invoked_function_arn", + "log_group_name", + "log_stream_name", + "memory_limit_in_mb", + "remaining_time_in_millis", + ] + ] + ) and callable(context.get_remaining_time_in_millis) diff --git a/tests/test_agent.py b/tests/test_agent.py index c646b6a7..a0d8e818 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -123,3 +123,15 @@ def test_sync_http(mock_send_report, handler_with_sync_http, mock_context): handler({}, mock_context) assert iopipe.report.sent + + +def test_validate_context(iopipe, mock_context): + """Asserts that contexts are validated correctly""" + assert iopipe.validate_context(mock_context) is True + + class InvalidContext(object): + pass + + invalid_context = InvalidContext() + + assert iopipe.validate_context(invalid_context) is False