From 9d2bf2ab66650fa926d6b065e71aac746c591c73 Mon Sep 17 00:00:00 2001 From: Viranch Mehta Date: Wed, 2 Oct 2024 18:02:18 -0700 Subject: [PATCH] Add support for mTLS-capable HTTP proxy with self-signed certs --- ns1/__init__.py | 12 ++++++++---- ns1/config.py | 15 ++++++++++++++- ns1/rest/transport/requests.py | 6 ++++++ 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/ns1/__init__.py b/ns1/__init__.py index daab292..5a6c49c 100644 --- a/ns1/__init__.py +++ b/ns1/__init__.py @@ -24,17 +24,21 @@ def __init__(self, apiKey=None, config=None, configFile=None, keyID=None): """ self.config = config - if self.config is None: - self._loadConfig(apiKey, configFile) + if not isinstance(self.config, Config): + self._loadConfig(apiKey, config, configFile) if keyID: self.config.useKeyID(keyID) - def _loadConfig(self, apiKey, configFile): + def _loadConfig(self, apiKey, config, configFile): self.config = Config() if apiKey: - self.config.createFromAPIKey(apiKey) + if config is None: + config = {} + config["apiKey"] = apiKey + + self.config.loadFromDict(config) else: configFile = ( Config.DEFAULT_CONFIG_FILE if not configFile else configFile diff --git a/ns1/config.py b/ns1/config.py index 2728b67..935f99c 100644 --- a/ns1/config.py +++ b/ns1/config.py @@ -84,6 +84,15 @@ def _doDefaults(self): if "follow_pagination" not in self._data: self._data["follow_pagination"] = False + if "http_proxy" not in self._data: + self._data["proxy"] = None + + if "client_cert" not in self._data: + self._data["client_cert"] = None + + if "cert_verify" not in self._data: + self._data["cert_verify"] = True + def createFromAPIKey(self, apikey, maybeWriteDefault=False): """ Create a basic config from a single API key @@ -109,7 +118,11 @@ def loadFromDict(self, d): :param dict d: Python dictionary containing configuration items """ - self._data = d + apikey = d.pop("apiKey", None) + if apikey: + self.createFromAPIKey(apikey) + + self._data.update(d) self._doDefaults() def loadFromString(self, body): diff --git a/ns1/rest/transport/requests.py b/ns1/rest/transport/requests.py index 65cfb9a..3f780b1 100644 --- a/ns1/rest/transport/requests.py +++ b/ns1/rest/transport/requests.py @@ -26,7 +26,13 @@ def __init__(self, config): if not have_requests: raise ImportError("requests module required for RequestsTransport") TransportBase.__init__(self, config, self.__module__) + self.session = requests.Session() + if self._config.get("http_proxy", None): + self.session.proxies = {"https": self._config["http_proxy"]} + self.session.cert = self._config.get("client_cert") + self.session.verify = self._config.get("cert_verify", True) + self.REQ_MAP = { "GET": self.session.get, "POST": self.session.post,