diff --git a/config/incluster_config.py b/config/incluster_config.py index 80853c28..16a21dae 100644 --- a/config/incluster_config.py +++ b/config/incluster_config.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import datetime +import os from kubernetes.client import Configuration @@ -35,30 +35,39 @@ def _join_host_port(host, port): class InClusterConfigLoader(object): - - def __init__(self, token_filename, - cert_filename, environ=os.environ): + def __init__(self, + token_filename, + cert_filename, + try_refresh_token=True, + environ=os.environ): self._token_filename = token_filename self._cert_filename = cert_filename self._environ = environ + self._try_refresh_token = try_refresh_token self._token_refresh_period = datetime.timedelta(minutes=1) - def load_and_set(self, refresh_token=True): + def load_and_set(self, client_configuration=None): + try_set_default = False + if client_configuration is None: + config = type.__call__(Configuration) + try_set_default = True self._load_config() - self._set_config(refresh_token=refresh_token) + self._set_config(client_configuration) + if try_set_default: + Configuration.set_default(config) def _load_config(self): - if (SERVICE_HOST_ENV_NAME not in self._environ or - SERVICE_PORT_ENV_NAME not in self._environ): + if (SERVICE_HOST_ENV_NAME not in self._environ + or SERVICE_PORT_ENV_NAME not in self._environ): raise ConfigException("Service host/port is not set.") - if (not self._environ[SERVICE_HOST_ENV_NAME] or - not self._environ[SERVICE_PORT_ENV_NAME]): + if (not self._environ[SERVICE_HOST_ENV_NAME] + or not self._environ[SERVICE_PORT_ENV_NAME]): raise ConfigException("Service host/port is set but empty.") - self.host = ( - "https://" + _join_host_port(self._environ[SERVICE_HOST_ENV_NAME], - self._environ[SERVICE_PORT_ENV_NAME])) + self.host = ("https://" + + _join_host_port(self._environ[SERVICE_HOST_ENV_NAME], + self._environ[SERVICE_PORT_ENV_NAME])) if not os.path.isfile(self._token_filename): raise ConfigException("Service token file does not exists.") @@ -75,37 +84,38 @@ def _load_config(self): self.ssl_ca_cert = self._cert_filename - def _set_config(self, refresh_token): - configuration = Configuration() - configuration.host = self.host - configuration.ssl_ca_cert = self.ssl_ca_cert - configuration.api_key['authorization'] = "bearer " + self.token - Configuration.set_default(configuration) - if not refresh_token: + def _set_config(self, client_configuration): + client_configuration.host = self.host + client_configuration.ssl_ca_cert = self.ssl_ca_cert + if self.token is not None: + client_configuration.api_key['authorization'] = self.token + if not self._try_refresh_token: return - def wrap(f): - in_cluster_config = self - def wrapped(self, identifier): - if identifier == 'authorization' and identifier in self.api_key and in_cluster_config.token_expires_at <= datetime.datetime.now(): - in_cluster_config._read_token_file() - self.api_key[identifier] = "bearer " + in_cluster_config.token - return f(self, identifier) - return wrapped - Configuration.get_api_key_with_prefix = wrap(Configuration.get_api_key_with_prefix) + + def load_token_from_file(*args): + if self.token_expires_at <= datetime.datetime.now(): + self._read_token_file() + return self.token + + client_configuration.get_api_key_with_prefix = load_token_from_file def _read_token_file(self): with open(self._token_filename) as f: - self.token = f.read() - self.token_expires_at = datetime.datetime.now() + self._token_refresh_period - if not self.token: + content = f.read() + if not content: raise ConfigException("Token file exists but empty.") + self.token = "bearer " + content + self.token_expires_at = datetime.datetime.now( + ) + self._token_refresh_period -def load_incluster_config(refresh_token=True): +def load_incluster_config(client_configuration=None, try_refresh_token=True): """ Use the service account kubernetes gives to pods to connect to kubernetes cluster. It's intended for clients that expect to be running inside a pod running on kubernetes. It will raise an exception if called from a process not running in a kubernetes environment.""" - InClusterConfigLoader(token_filename=SERVICE_TOKEN_FILENAME, - cert_filename=SERVICE_CERT_FILENAME).load_and_set(refresh_token=refresh_token) + InClusterConfigLoader( + token_filename=SERVICE_TOKEN_FILENAME, + cert_filename=SERVICE_CERT_FILENAME, + try_refresh_token=try_refresh_token).load_and_set(client_configuration) diff --git a/config/incluster_config_test.py b/config/incluster_config_test.py index e5698021..ef7468d7 100644 --- a/config/incluster_config_test.py +++ b/config/incluster_config_test.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime import os import tempfile -import unittest -import datetime import time +import unittest from kubernetes.client import Configuration @@ -33,14 +33,17 @@ _TEST_IPV6_HOST = "::1" _TEST_IPV6_HOST_PORT = "[::1]:80" -_TEST_ENVIRON = {SERVICE_HOST_ENV_NAME: _TEST_HOST, - SERVICE_PORT_ENV_NAME: _TEST_PORT} -_TEST_IPV6_ENVIRON = {SERVICE_HOST_ENV_NAME: _TEST_IPV6_HOST, - SERVICE_PORT_ENV_NAME: _TEST_PORT} +_TEST_ENVIRON = { + SERVICE_HOST_ENV_NAME: _TEST_HOST, + SERVICE_PORT_ENV_NAME: _TEST_PORT +} +_TEST_IPV6_ENVIRON = { + SERVICE_HOST_ENV_NAME: _TEST_IPV6_HOST, + SERVICE_PORT_ENV_NAME: _TEST_PORT +} class InClusterConfigTest(unittest.TestCase): - def setUp(self): self._temp_files = [] @@ -55,25 +58,18 @@ def _create_file_with_temp_content(self, content=""): os.close(handler) return name - def _overwrite_file_with_content(self, name, content=""): - handler = os.open(name, os.O_RDWR) - os.truncate(name, 0) - os.write(handler, str.encode(content)) - os.close(handler) - - def get_test_loader( - self, - token_filename=None, - cert_filename=None, - environ=_TEST_ENVIRON): + def get_test_loader(self, + token_filename=None, + cert_filename=None, + environ=_TEST_ENVIRON): if not token_filename: token_filename = self._create_file_with_temp_content(_TEST_TOKEN) if not cert_filename: cert_filename = self._create_file_with_temp_content(_TEST_CERT) - return InClusterConfigLoader( - token_filename=token_filename, - cert_filename=cert_filename, - environ=environ) + return InClusterConfigLoader(token_filename=token_filename, + cert_filename=cert_filename, + try_refresh_token=True, + environ=environ) def test_join_host_port(self): self.assertEqual(_TEST_HOST_PORT, @@ -87,25 +83,29 @@ def test_load_config(self): loader._load_config() self.assertEqual("https://" + _TEST_HOST_PORT, loader.host) self.assertEqual(cert_filename, loader.ssl_ca_cert) - self.assertEqual(_TEST_TOKEN, loader.token) + self.assertEqual('bearer ' + _TEST_TOKEN, loader.token) def test_refresh_token(self): loader = self.get_test_loader() - loader._token_refresh_period = datetime.timedelta(seconds=5) - loader.load_and_set() config = Configuration() + loader.load_and_set(config) - self.assertEqual('bearer '+_TEST_TOKEN, config.get_api_key_with_prefix('authorization')) - self.assertEqual(_TEST_TOKEN, loader.token) + self.assertEqual('bearer ' + _TEST_TOKEN, + config.get_api_key_with_prefix('authorization')) + self.assertEqual('bearer ' + _TEST_TOKEN, loader.token) self.assertIsNotNone(loader.token_expires_at) old_token = loader.token old_token_expires_at = loader.token_expires_at - self._overwrite_file_with_content(loader._token_filename, _TEST_NEW_TOKEN) - time.sleep(5) - - self.assertEqual('bearer '+_TEST_NEW_TOKEN, config.get_api_key_with_prefix('authorization')) - self.assertEqual(_TEST_NEW_TOKEN, loader.token) + loader._token_filename = self._create_file_with_temp_content( + _TEST_NEW_TOKEN) + self.assertEqual('bearer ' + _TEST_TOKEN, + config.get_api_key_with_prefix('authorization')) + + loader.token_expires_at = datetime.datetime.now() + self.assertEqual('bearer ' + _TEST_NEW_TOKEN, + config.get_api_key_with_prefix('authorization')) + self.assertEqual('bearer ' + _TEST_NEW_TOKEN, loader.token) self.assertGreater(loader.token_expires_at, old_token_expires_at) def _should_fail_load(self, config_loader, reason): @@ -122,9 +122,10 @@ def test_no_port(self): self._should_fail_load(loader, "no port specified") def test_empty_port(self): - loader = self.get_test_loader( - environ={SERVICE_HOST_ENV_NAME: _TEST_HOST, - SERVICE_PORT_ENV_NAME: ""}) + loader = self.get_test_loader(environ={ + SERVICE_HOST_ENV_NAME: _TEST_HOST, + SERVICE_PORT_ENV_NAME: "" + }) self._should_fail_load(loader, "empty port specified") def test_no_host(self): @@ -133,9 +134,10 @@ def test_no_host(self): self._should_fail_load(loader, "no host specified") def test_empty_host(self): - loader = self.get_test_loader( - environ={SERVICE_HOST_ENV_NAME: "", - SERVICE_PORT_ENV_NAME: _TEST_PORT}) + loader = self.get_test_loader(environ={ + SERVICE_HOST_ENV_NAME: "", + SERVICE_PORT_ENV_NAME: _TEST_PORT + }) self._should_fail_load(loader, "empty host specified") def test_no_cert_file(self):