From 3e01e7ef1a9d2f72b6ffa39043a3cbd250ab5395 Mon Sep 17 00:00:00 2001 From: James Hageman Date: Thu, 10 Jan 2019 15:14:36 -0800 Subject: [PATCH] Add Stripe client telemetry to request headers --- stripe/__init__.py | 1 + stripe/api_requestor.py | 23 +++++++++++++++- stripe/request_metrics.py | 13 +++++++++ tests/test_api_requestor.py | 55 +++++++++++++++++++++++++++++++++++-- 4 files changed, 89 insertions(+), 3 deletions(-) create mode 100644 stripe/request_metrics.py diff --git a/stripe/__init__.py b/stripe/__init__.py index a64425e56..463c418a4 100644 --- a/stripe/__init__.py +++ b/stripe/__init__.py @@ -21,6 +21,7 @@ proxy = None default_http_client = None app_info = None +enable_telemetry = False max_network_retries = 0 ca_bundle_path = os.path.join( os.path.dirname(__file__), 'data/ca-certificates.crt') diff --git a/stripe/api_requestor.py b/stripe/api_requestor.py index 5ad75b6dc..44d5be2f8 100644 --- a/stripe/api_requestor.py +++ b/stripe/api_requestor.py @@ -11,6 +11,7 @@ from stripe import error, oauth_error, http_client, version, util, six from stripe.multipart_data_generator import MultipartDataGenerator from stripe.six.moves.urllib.parse import urlencode, urlsplit, urlunsplit +from stripe.request_metrics import RequestMetrics from stripe.stripe_response import StripeResponse @@ -30,6 +31,10 @@ def _encode_nested_dict(key, data, fmt='%s[%s]'): return d +def _now_ms(): + return int(round(time.time() * 1000)) + + def _api_encode(data): for key, value in six.iteritems(data): key = util.utf8(key) @@ -80,6 +85,8 @@ def __init__(self, key=None, client=None, api_base=None, api_version=None, http_client.new_default_http_client( verify_ssl_certs=verify, proxy=proxy) + self._last_request_metrics = None + @classmethod def format_app_info(cls, info): str = info['name'] @@ -226,6 +233,11 @@ def request_headers(self, api_key, method): if self.api_version is not None: headers['Stripe-Version'] = self.api_version + if stripe.enable_telemetry and self._last_request_metrics: + headers['X-Stripe-Client-Telemetry'] = json.dumps({ + 'last_request_metrics': self._last_request_metrics.payload() + }) + return headers def request_raw(self, method, url, params=None, supplied_headers=None): @@ -287,15 +299,24 @@ def request_raw(self, method, url, params=None, supplied_headers=None): 'Post details', post_data=encoded_params, api_version=self.api_version) + request_start = _now_ms() + rbody, rcode, rheaders = self._client.request_with_retries( method, abs_url, headers, post_data) util.log_info( 'Stripe API response', path=abs_url, response_code=rcode) util.log_debug('API response body', body=rbody) + if 'Request-Id' in rheaders: + request_id = rheaders['Request-Id'] util.log_debug('Dashboard link for request', - link=util.dashboard_link(rheaders['Request-Id'])) + link=util.dashboard_link(request_id)) + if stripe.enable_telemetry: + request_duration_ms = _now_ms() - request_start + self._last_request_metrics = RequestMetrics( + request_id, request_duration_ms) + return rbody, rcode, rheaders, my_api_key def interpret_response(self, rbody, rcode, rheaders): diff --git a/stripe/request_metrics.py b/stripe/request_metrics.py new file mode 100644 index 000000000..34af5d333 --- /dev/null +++ b/stripe/request_metrics.py @@ -0,0 +1,13 @@ +from __future__ import absolute_import, division, print_function + + +class RequestMetrics(object): + def __init__(self, request_id, request_duration_ms): + self.request_id = request_id + self.request_duration_ms = request_duration_ms + + def payload(self): + return { + 'request_id': self.request_id, + 'request_duration_ms': self.request_duration_ms, + } diff --git a/tests/test_api_requestor.py b/tests/test_api_requestor.py index 4fe59e5f8..5839da85b 100644 --- a/tests/test_api_requestor.py +++ b/tests/test_api_requestor.py @@ -39,13 +39,15 @@ class APIHeaderMatcher(object): METHOD_EXTRA_KEYS = {"post": ["Content-Type", "Idempotency-Key"]} def __init__(self, api_key=None, extra={}, request_method=None, - user_agent=None, app_info=None, idempotency_key=None): + user_agent=None, app_info=None, idempotency_key=None, + client_telemetry=None): self.request_method = request_method self.api_key = api_key or stripe.api_key self.extra = extra self.user_agent = user_agent self.app_info = app_info self.idempotency_key = idempotency_key + self.client_telemetry = client_telemetry def __eq__(self, other): return (self._keys_match(other) and @@ -53,7 +55,8 @@ def __eq__(self, other): self._user_agent_match(other) and self._x_stripe_ua_contains_app_info(other) and self._idempotency_key_match(other) and - self._extra_match(other)) + self._extra_match(other) and + self._check_telemetry(other)) def __repr__(self): return ("APIHeaderMatcher(request_method=%s, api_key=%s, extra=%s, " @@ -68,6 +71,8 @@ def _keys_match(self, other): if self.request_method is not None and self.request_method in \ self.METHOD_EXTRA_KEYS: expected_keys.extend(self.METHOD_EXTRA_KEYS[self.request_method]) + if self.client_telemetry: + expected_keys.append('X-Stripe-Client-Telemetry') return sorted(other.keys()) == sorted(expected_keys) def _auth_match(self, other): @@ -100,6 +105,24 @@ def _extra_match(self, other): return True + def _check_telemetry(self, other): + if not self.client_telemetry: + return 'X-Stripe-Client-Telemetry' not in other + + if 'X-Stripe-Client-Telemetry' not in other: + return False + + telemetry = json.loads(other['X-Stripe-Client-Telemetry']) + req_id = telemetry['last_request_metrics']['request_id'] + + if req_id != self.client_telemetry['request_id']: + return False + + if 'request_duration_ms' not in telemetry['last_request_metrics']: + return False + + return True + class QueryMatcher(object): @@ -198,12 +221,15 @@ def setup_stripe(self): orig_attrs = { 'api_key': stripe.api_key, 'api_version': stripe.api_version, + 'enable_telemetry': stripe.enable_telemetry, } stripe.api_key = 'sk_test_123' stripe.api_version = '2017-12-14' + stripe.enable_telemetry = False yield stripe.api_key = orig_attrs['api_key'] stripe.api_version = orig_attrs['api_version'] + stripe.enable_telemetry = orig_attrs['enable_telemetry'] @pytest.fixture def http_client(self, mocker): @@ -368,6 +394,31 @@ def test_uses_headers(self, requestor, mock_response, check_call): requestor.request('get', self.valid_path, {}, {'foo': 'bar'}) check_call('get', headers=APIHeaderMatcher(extra={'foo': 'bar'})) + def test_telemetry_headers_disabled(self, requestor, mock_response, + check_call): + mock_response('{}', 200, headers={'Request-Id': 1}) + requestor.request('get', self.valid_path, {}) + check_call('get', headers=APIHeaderMatcher(client_telemetry=None)) + + mock_response('{}', 200, headers={'Request-Id': 2}) + requestor.request('get', self.valid_path, {}) + check_call('get', headers=APIHeaderMatcher(client_telemetry=None)) + + def test_telemetry_headers_enabled(self, requestor, mock_response, + check_call): + stripe.enable_telemetry = True + + mock_response('{}', 200, headers={'Request-Id': 1}) + requestor.request('get', self.valid_path, {}) + check_call('get', headers=APIHeaderMatcher(client_telemetry=None)) + + mock_response('{}', 200, headers={'Request-Id': 2}) + requestor.request('get', self.valid_path, {}) + check_call( + 'get', + headers=APIHeaderMatcher(client_telemetry={'request_id': 1}) + ) + def test_uses_instance_key(self, http_client, mock_response, check_call): key = 'fookey' requestor = stripe.api_requestor.APIRequestor(key,