Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate zip file param validation for AWS Lambda #1296

Merged
merged 8 commits into from
Apr 20, 2015
34 changes: 34 additions & 0 deletions awscli/customizations/awslambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,29 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import zipfile
from contextlib import closing

from botocore.vendored import six

from awscli.arguments import CustomArgument
from awscli.customizations import utils

ERROR_MSG = (
"--zip-file must be a file with the fileb:// prefix.\n"
"Example usage: --zip-file fileb://path/to/file.zip")


def register_lambda_create_function(cli):
cli.register('building-argument-table.lambda.create-function',
_flatten_code_argument)
cli.register('process-cli-arg.lambda.update-function-code',
validate_is_zip_file)


def validate_is_zip_file(cli_argument, value, **kwargs):
if cli_argument.name == 'zip-file':
_should_contain_zip_content(value)


def _flatten_code_argument(argument_table, **kwargs):
Expand All @@ -27,7 +43,25 @@ def _flatten_code_argument(argument_table, **kwargs):
del argument_table['code']


def _should_contain_zip_content(value):
if not isinstance(value, bytes):
# If it's not bytes it's basically impossible for
# this to be valid zip content, but we'll at least
# still try to load the contents as a zip file
# to be absolutely sure.
value = value.encode('utf-8')
fileobj = six.BytesIO(value)
try:
with closing(zipfile.ZipFile(fileobj)) as f:
f.infolist()
except zipfile.BadZipfile:
raise ValueError(ERROR_MSG)


class ZipFileArgument(CustomArgument):
def add_to_params(self, parameters, value):
if value is None:
return
_should_contain_zip_content(value)
zip_file_param = {'ZipFile': value}
parameters['Code'] = zip_file_param
45 changes: 26 additions & 19 deletions awscli/paramfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import logging
import os

Expand Down Expand Up @@ -74,32 +73,43 @@ class ResourceLoadingError(Exception):


def get_paramfile(path):
"""
"""Load parameter based on a resource URI.

It is possible to pass parameters to operations by referring
to files or URI's. If such a reference is detected, this
function attempts to retrieve the data from the file or URI
and returns it. If there are any errors or if the ``path``
does not appear to refer to a file or URI, a ``None`` is
returned.

:type path: str
:param path: The resource URI, e.g. file://foo.txt. This value
may also be a non resource URI, in which case ``None`` is returned.

:return: The loaded value associated with the resource URI.
If the provided ``path`` is not a resource URI, then a
value of ``None`` is returned.

"""
data = None
if isinstance(path, six.string_types):
for prefix in PrefixMap:
for prefix, function_spec in PREFIX_MAP.items():
if path.startswith(prefix):
kwargs = KwargsMap.get(prefix, {})
data = PrefixMap[prefix](prefix, path, **kwargs)
function, kwargs = function_spec
data = function(prefix, path, **kwargs)
return data


def get_file(prefix, path, mode):
file_path = path[len(prefix):]
file_path = os.path.expanduser(file_path)
file_path = os.path.expandvars(file_path)
if not os.path.isfile(file_path):
raise ResourceLoadingError("file does not exist: %s" % file_path)
file_path = os.path.expandvars(os.path.expanduser(path[len(prefix):]))
try:
with compat_open(file_path, mode) as f:
return f.read()
except UnicodeDecodeError:
raise ResourceLoadingError(
'Unable to load paramfile (%s), text contents could '
'not be decoded. If this is a binary file, please use the '
'fileb:// prefix instead of the file:// prefix.' % file_path)
except (OSError, IOError) as e:
raise ResourceLoadingError('Unable to load paramfile %s: %s' % (
path, e))
Expand All @@ -118,12 +128,9 @@ def get_uri(prefix, uri):
raise ResourceLoadingError('Unable to retrieve %s: %s' % (uri, e))


PrefixMap = {'file://': get_file,
'fileb://': get_file,
'http://': get_uri,
'https://': get_uri}

KwargsMap = {'file://': {'mode': 'r'},
'fileb://': {'mode': 'rb'},
'http://': {},
'https://': {}}
PREFIX_MAP = {
'file://': (get_file, {'mode': 'r'}),
'fileb://': (get_file, {'mode': 'rb'}),
'http://': (get_uri, {}),
'https://': (get_uri, {}),
}
87 changes: 64 additions & 23 deletions tests/unit/customizations/test_awslambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,52 +10,48 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import os
import zipfile
from contextlib import closing

from awscli.testutils import unittest
from awscli.testutils import BaseAWSCommandParamsTest
from awscli.testutils import FileCreator


class TestCreateFunction(BaseAWSCommandParamsTest):

prefix = 'lambda create-function'
class BaseLambdaTests(BaseAWSCommandParamsTest):

def setUp(self):
super(TestCreateFunction, self).setUp()

# Make a temporary file
super(BaseLambdaTests, self).setUp()
self.files = FileCreator()
self.contents_of_file = 'myzipcontents'
self.temp_file = self.files.create_file(
'foo', self.contents_of_file)
'foo', 'mycontents')
self.zip_file = os.path.join(self.files.rootdir, 'foo.zip')
with closing(zipfile.ZipFile(self.zip_file, 'w')) as f:
f.write(self.temp_file)
with open(self.zip_file, 'rb') as f:
self.zip_file_contents = f.read()

def tearDown(self):
super(TestCreateFunction, self).tearDown()
super(BaseLambdaTests, self).tearDown()
self.files.remove_all()

def test_create_function(self):
cmdline = self.prefix
cmdline += ' --function-name myfunction --runtime myruntime'
cmdline += ' --role myrole --handler myhandler --zip-file myzip'
result = {
'FunctionName': 'myfunction',
'Runtime': 'myruntime',
'Role': 'myrole',
'Handler': 'myhandler',
'Code': {'ZipFile': 'myzip'}
}
self.assert_params_for_cmd(cmdline, result)

class TestCreateFunction(BaseLambdaTests):

prefix = 'lambda create-function'

def test_create_function_with_file(self):
cmdline = self.prefix
cmdline += ' --function-name myfunction --runtime myruntime'
cmdline += ' --role myrole --handler myhandler'
cmdline += ' --zip-file file://%s' % self.temp_file
cmdline += ' --zip-file fileb://%s' % self.zip_file
result = {
'FunctionName': 'myfunction',
'Runtime': 'myruntime',
'Role': 'myrole',
'Handler': 'myhandler',
'Code': {'ZipFile': self.contents_of_file}
'Code': {'ZipFile': self.zip_file_contents}
}
self.assert_params_for_cmd(cmdline, result)

Expand All @@ -66,3 +62,48 @@ def test_create_function_code_argument_cause_error(self):
cmdline += ' --code mycode'
stdout, stderr, rc = self.run_cmd(cmdline, expected_rc=255)
self.assertIn('Unknown options: --code', stderr)

def test_create_function_with_invalid_file_contents(self):
cmdline = self.prefix
cmdline += ' --function-name myfunction --runtime myruntime'
cmdline += ' --role myrole --handler myhandler'
cmdline += ' --zip-file filename_instead_of_contents.zip'
stdout, stderr, rc = self.run_cmd(cmdline, expected_rc=255)
self.assertIn('must be a file with the fileb:// prefix', stderr)
# Should also give a pointer to fileb:// for them.
self.assertIn('fileb://', stderr)

def test_not_using_fileb_prefix(self):
cmdline = self.prefix
cmdline += ' --function-name myfunction --runtime myruntime'
cmdline += ' --role myrole --handler myhandler'
# Note file:// instead of fileb://
cmdline += ' --zip-file file://%s' % self.zip_file
stdout, stderr, rc = self.run_cmd(cmdline, expected_rc=255)
# Ensure we mention fileb:// to give the user an idea of
# where to go next.
self.assertIn('fileb://', stderr)


class TestUpdateFunctionCode(BaseLambdaTests):

prefix = 'lambda update-function-code'

def test_not_using_fileb_prefix(self):
cmdline = self.prefix + ' --function-name foo'
cmdline += ' --zip-file filename_instead_of_contents.zip'
stdout, stderr, rc = self.run_cmd(cmdline, expected_rc=255)
self.assertIn('must be a file with the fileb:// prefix', stderr)
# Should also give a pointer to fileb:// for them.
self.assertIn('fileb://', stderr)

def test_using_fileb_prefix_succeeds(self):
cmdline = self.prefix
cmdline += ' --function-name myfunction'
cmdline += ' --zip-file fileb://%s' % self.zip_file
result = {
'FunctionName': 'myfunction',
'ZipFile': self.zip_file_contents,
}
self.assert_params_for_cmd(cmdline, result)

6 changes: 4 additions & 2 deletions tests/unit/test_clidriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,9 +543,11 @@ def test_file_param_does_not_exist(self):
rc = driver.main('ec2 describe-instances '
'--filters file://does/not/exist.json'.split())
self.assertEqual(rc, 255)
error_msg = self.stderr.getvalue()
self.assertIn("Error parsing parameter '--filters': "
"file does not exist: does/not/exist.json",
self.stderr.getvalue())
"Unable to load paramfile file://does/not/exist.json",
error_msg)
self.assertIn("No such file or directory", error_msg)

def test_aws_configure_in_error_message_no_credentials(self):
driver = create_clidriver()
Expand Down
55 changes: 53 additions & 2 deletions tests/unit/test_paramfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import mock
from awscli.compat import six

from awscli.paramfile import get_paramfile
from awscli.testutils import unittest, FileCreator

from awscli.paramfile import get_paramfile, ResourceLoadingError


class TestParamFile(unittest.TestCase):
def setUp(self):
Expand All @@ -38,3 +39,53 @@ def test_binary_file(self):
data = get_paramfile(prefixed_filename)
self.assertEqual(data, b'This is a test')
self.assertIsInstance(data, six.binary_type)

def test_cannot_load_text_file(self):
contents = b'\xbfX\xac\xbe'
filename = self.files.create_file('foo', contents, mode='wb')
prefixed_filename = 'file://' + filename
with self.assertRaises(ResourceLoadingError):
get_paramfile(prefixed_filename)

def test_file_does_not_exist_raises_error(self):
with self.assertRaises(ResourceLoadingError):
get_paramfile('file://file/does/not/existsasdf.txt')

def test_no_match_uris_returns_none(self):
self.assertIsNone(get_paramfile('foobar://somewhere.bar'))

def test_non_string_type_returns_none(self):
self.assertIsNone(get_paramfile(100))


class TestHTTPBasedResourceLoading(unittest.TestCase):
def setUp(self):
self.requests_patch = mock.patch('awscli.paramfile.requests')
self.requests_mock = self.requests_patch.start()
self.response = mock.Mock(status_code=200)
self.requests_mock.get.return_value = self.response

def tearDown(self):
self.requests_patch.stop()

def test_resource_from_http(self):
self.response.text = 'http contents'
loaded = get_paramfile('http://foo.bar.baz')
self.assertEqual(loaded, 'http contents')
self.requests_mock.get.assert_called_with('http://foo.bar.baz')

def test_resource_from_https(self):
self.response.text = 'http contents'
loaded = get_paramfile('https://foo.bar.baz')
self.assertEqual(loaded, 'http contents')
self.requests_mock.get.assert_called_with('https://foo.bar.baz')

def test_non_200_raises_error(self):
self.response.status_code = 500
with self.assertRaisesRegexp(ResourceLoadingError, 'foo\.bar\.baz'):
get_paramfile('https://foo.bar.baz')

def test_connection_error_raises_error(self):
self.requests_mock.get.side_effect = Exception("Connection error.")
with self.assertRaisesRegexp(ResourceLoadingError, 'foo\.bar\.baz'):
get_paramfile('https://foo.bar.baz')