diff --git a/minio/credentials/providers.py b/minio/credentials/providers.py index 80a5e6e4f..b03f82a00 100644 --- a/minio/credentials/providers.py +++ b/minio/credentials/providers.py @@ -186,7 +186,7 @@ def retrieve(self): class IAMProvider(Provider): """ - IAM EC2 credential provider. + IAM EC2/ECS credential provider. expiry_delta param is used to create a window to the token expiration time. If expiry_delta is greater than 0 the @@ -201,9 +201,20 @@ class IAMProvider(Provider): def __init__(self, endpoint=None, http_client=None, - expiry_delta=None): + expiry_delta=None, + is_ecs_task=False): - self._endpoint = endpoint or "http://169.254.169.254" + default_iam_role_endpoint = "http://169.254.169.254" + default_ecs_role_endpoint = "http://169.254.170.2" + + self._is_ecs_task = is_ecs_task + + if self._is_ecs_task: + default_endpoint = default_ecs_role_endpoint + else: + default_endpoint = default_iam_role_endpoint + + self._endpoint = endpoint or default_endpoint self._http_client = http_client or urllib3.PoolManager( retries=urllib3.Retry( total=5, @@ -217,28 +228,39 @@ def __init__(self, self._expiry_delta = expiry_delta def retrieve(self): - """Retrieve credential value and its expiry from IAM EC2.""" - # Get role names. - creds_path = "/latest/meta-data/iam/security-credentials" - url = self._endpoint + creds_path - res = self._http_client.urlopen("GET", url) - if res.status != 200: - raise HTTPError( - "request failed with status {0}".format(res.status), - ) - role_names = res.data.decode("utf-8").split("\n") - if not role_names: - raise ResponseError("no role names found in response") + """ + Retrieve credential value and its expiry from IAM EC2 instance role + or ECS task role. + """ + if not self._is_ecs_task: + # Get role names and get the first role for EC2. + creds_path = "/latest/meta-data/iam/security-credentials" + url = self._endpoint + creds_path + res = self._http_client.urlopen("GET", url) + if res.status != 200: + raise HTTPError( + "request failed with status {0}".format(res.status), + ) + role_names = res.data.decode("utf-8").split("\n") + if not role_names: + raise ResponseError("no role names found in response") + credentials_url = self._endpoint + creds_path + "/" + role_names[0] + else: + # This URL directly gives the credentials for an ECS task + relative_url_var = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" + creds_path = os.environ.get(relative_url_var) or "" + credentials_url = self._endpoint + creds_path - # Get credentials of first role. - url = self._endpoint + creds_path + "/" + role_names[0] - res = self._http_client.urlopen("GET", url) + # Get credentials of role. + res = self._http_client.urlopen("GET", credentials_url) if res.status != 200: raise HTTPError( "request failed with status {0}".format(res.status), ) data = json.loads(res.data) - if data["Code"] != "Success": + + # Note the response in ECS does not include the "Code" key. + if not self._is_ecs_task and data["Code"] != "Success": raise ResponseError( "credential retrieval failed with code {0}".format( data["Code"]), diff --git a/tests/unit/credentials_test.py b/tests/unit/credentials_test.py index 1b0f2cec9..c5f75c490 100644 --- a/tests/unit/credentials_test.py +++ b/tests/unit/credentials_test.py @@ -64,6 +64,17 @@ class CredsResponse(object): }) +class ECSCredsResponse(object): + status = 200 + data = json.dumps({ + "RoleArn": "arn:aws:iam::123456789101:role/my-ecs-role", + "AccessKeyId": "accessKey", + "SecretAccessKey": "secret", + "Token": "token", + "Expiration": "2014-12-16T01:51:37Z", + }) + + class IAMProviderTest(TestCase): @mock.patch("urllib3.PoolManager.urlopen") def test_iam(self, mock_connection): @@ -76,6 +87,19 @@ def test_iam(self, mock_connection): eq_(expiry, datetime(2014, 12, 16, 1, 46, 37)) +class IAMProviderECSTest(TestCase): + @mock.patch("urllib3.PoolManager.urlopen") + def test_iam(self, mock_connection): + mock_connection.side_effect = [ECSCredsResponse()] + provider = IAMProvider( + expiry_delta=timedelta(minutes=5), is_ecs_task=True) + creds, expiry = provider.retrieve() + eq_(creds.access_key, "accessKey") + eq_(creds.secret_key, "secret") + eq_(creds.session_token, "token") + eq_(expiry, datetime(2014, 12, 16, 1, 46, 37)) + + class ChainProviderTest(TestCase): def test_chain_retrieve(self): # clear environment