Skip to content

Commit

Permalink
IAMProvider accepts ECS IAM Task roles (#960)
Browse files Browse the repository at this point in the history
  • Loading branch information
NickLavrov authored Aug 22, 2020
1 parent 20ddf2a commit ec60a37
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 19 deletions.
60 changes: 41 additions & 19 deletions minio/credentials/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def retrieve(self):

class IAMProvider(Provider):
"""
IAM EC2 credential provider.
IAM EC2/ECS credential provider.
expiry_delta param is used to create a window to the token
expiration time. If expiry_delta is greater than 0 the
Expand All @@ -201,9 +201,20 @@ class IAMProvider(Provider):
def __init__(self,
endpoint=None,
http_client=None,
expiry_delta=None):
expiry_delta=None,
is_ecs_task=False):

self._endpoint = endpoint or "http://169.254.169.254"
default_iam_role_endpoint = "http://169.254.169.254"
default_ecs_role_endpoint = "http://169.254.170.2"

self._is_ecs_task = is_ecs_task

if self._is_ecs_task:
default_endpoint = default_ecs_role_endpoint
else:
default_endpoint = default_iam_role_endpoint

self._endpoint = endpoint or default_endpoint
self._http_client = http_client or urllib3.PoolManager(
retries=urllib3.Retry(
total=5,
Expand All @@ -217,28 +228,39 @@ def __init__(self,
self._expiry_delta = expiry_delta

def retrieve(self):
"""Retrieve credential value and its expiry from IAM EC2."""
# Get role names.
creds_path = "/latest/meta-data/iam/security-credentials"
url = self._endpoint + creds_path
res = self._http_client.urlopen("GET", url)
if res.status != 200:
raise HTTPError(
"request failed with status {0}".format(res.status),
)
role_names = res.data.decode("utf-8").split("\n")
if not role_names:
raise ResponseError("no role names found in response")
"""
Retrieve credential value and its expiry from IAM EC2 instance role
or ECS task role.
"""
if not self._is_ecs_task:
# Get role names and get the first role for EC2.
creds_path = "/latest/meta-data/iam/security-credentials"
url = self._endpoint + creds_path
res = self._http_client.urlopen("GET", url)
if res.status != 200:
raise HTTPError(
"request failed with status {0}".format(res.status),
)
role_names = res.data.decode("utf-8").split("\n")
if not role_names:
raise ResponseError("no role names found in response")
credentials_url = self._endpoint + creds_path + "/" + role_names[0]
else:
# This URL directly gives the credentials for an ECS task
relative_url_var = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
creds_path = os.environ.get(relative_url_var) or ""
credentials_url = self._endpoint + creds_path

# Get credentials of first role.
url = self._endpoint + creds_path + "/" + role_names[0]
res = self._http_client.urlopen("GET", url)
# Get credentials of role.
res = self._http_client.urlopen("GET", credentials_url)
if res.status != 200:
raise HTTPError(
"request failed with status {0}".format(res.status),
)
data = json.loads(res.data)
if data["Code"] != "Success":

# Note the response in ECS does not include the "Code" key.
if not self._is_ecs_task and data["Code"] != "Success":
raise ResponseError(
"credential retrieval failed with code {0}".format(
data["Code"]),
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/credentials_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@ class CredsResponse(object):
})


class ECSCredsResponse(object):
status = 200
data = json.dumps({
"RoleArn": "arn:aws:iam::123456789101:role/my-ecs-role",
"AccessKeyId": "accessKey",
"SecretAccessKey": "secret",
"Token": "token",
"Expiration": "2014-12-16T01:51:37Z",
})


class IAMProviderTest(TestCase):
@mock.patch("urllib3.PoolManager.urlopen")
def test_iam(self, mock_connection):
Expand All @@ -76,6 +87,19 @@ def test_iam(self, mock_connection):
eq_(expiry, datetime(2014, 12, 16, 1, 46, 37))


class IAMProviderECSTest(TestCase):
@mock.patch("urllib3.PoolManager.urlopen")
def test_iam(self, mock_connection):
mock_connection.side_effect = [ECSCredsResponse()]
provider = IAMProvider(
expiry_delta=timedelta(minutes=5), is_ecs_task=True)
creds, expiry = provider.retrieve()
eq_(creds.access_key, "accessKey")
eq_(creds.secret_key, "secret")
eq_(creds.session_token, "token")
eq_(expiry, datetime(2014, 12, 16, 1, 46, 37))


class ChainProviderTest(TestCase):
def test_chain_retrieve(self):
# clear environment
Expand Down

0 comments on commit ec60a37

Please sign in to comment.