diff --git a/minio/credentials/providers.py b/minio/credentials/providers.py index 65981a440..80a5e6e4f 100644 --- a/minio/credentials/providers.py +++ b/minio/credentials/providers.py @@ -185,9 +185,24 @@ def retrieve(self): class IAMProvider(Provider): - """IAM EC2 credential provider.""" + """ + IAM EC2 credential provider. + + expiry_delta param is used to create a window to the token + expiration time. If expiry_delta is greater than 0 the + expiration time will be reduced by the delta value. + + Using a delta value is helpful to trigger credentials to + expire sooner than the expiration time given to ensure no + requests are made with expired token. + + """ + + def __init__(self, + endpoint=None, + http_client=None, + expiry_delta=None): - def __init__(self, endpoint=None, http_client=None): self._endpoint = endpoint or "http://169.254.169.254" self._http_client = http_client or urllib3.PoolManager( retries=urllib3.Retry( @@ -196,6 +211,10 @@ def __init__(self, endpoint=None, http_client=None): status_forcelist=[500, 502, 503, 504], ), ) + if expiry_delta is None: + self._expiry_delta = timedelta(seconds=10) + else: + self._expiry_delta = expiry_delta def retrieve(self): """Retrieve credential value and its expiry from IAM EC2.""" @@ -233,7 +252,7 @@ def retrieve(self): data["AccessKeyId"], data["SecretAccessKey"], session_token=data["Token"], - ), expiration + timedelta(minutes=5) + ), expiration - self._expiry_delta class Static(Provider): diff --git a/tests/unit/credentials_test.py b/tests/unit/credentials_test.py index 2087188b3..1b0f2cec9 100644 --- a/tests/unit/credentials_test.py +++ b/tests/unit/credentials_test.py @@ -16,7 +16,7 @@ import json import os -from datetime import datetime +from datetime import datetime, timedelta from unittest import TestCase from nose.tools import eq_, raises @@ -68,12 +68,12 @@ class IAMProviderTest(TestCase): @mock.patch("urllib3.PoolManager.urlopen") def test_iam(self, mock_connection): mock_connection.side_effect = [CredListResponse(), CredsResponse()] - provider = IAMProvider() + provider = IAMProvider(expiry_delta=timedelta(minutes=5)) creds, expiry = provider.retrieve() eq_(creds.access_key, "accessKey") eq_(creds.secret_key, "secret") eq_(creds.session_token, "token") - eq_(expiry, datetime(2014, 12, 16, 1, 56, 37)) + eq_(expiry, datetime(2014, 12, 16, 1, 46, 37)) class ChainProviderTest(TestCase):