diff --git a/aws_xray_sdk/core/plugins/ec2_plugin.py b/aws_xray_sdk/core/plugins/ec2_plugin.py index c5aea567..8f5467a0 100644 --- a/aws_xray_sdk/core/plugins/ec2_plugin.py +++ b/aws_xray_sdk/core/plugins/ec2_plugin.py @@ -1,3 +1,4 @@ +import json import logging from future.standard_library import install_aliases from urllib.request import urlopen, Request @@ -19,40 +20,51 @@ def initialize(): """ global runtime_context - # Try the IMDSv2 endpoint for metadata - try: - runtime_context = {} + # get session token with 60 seconds TTL to not have the token lying around for a long time + token = get_token() + + # get instance metadata + runtime_context = get_metadata(token) + - # get session token with 60 seconds TTL to not have the token lying around for a long time +def get_token(): + token = None + try: + headers = {"X-aws-ec2-metadata-token-ttl-seconds": "60"} token = do_request(url=IMDS_URL + "api/token", - headers={"X-aws-ec2-metadata-token-ttl-seconds": "60"}, + headers=headers, method="PUT") + except Exception: + log.warning("Failed to get token for IMDSv2") + return token - # get instance-id metadata - runtime_context['instance_id'] = do_request(url=IMDS_URL + "meta-data/instance-id", - headers={"X-aws-ec2-metadata-token": token}, - method="GET") - # get availability-zone metadata - runtime_context['availability_zone'] = do_request(url=IMDS_URL + "meta-data/placement/availability-zone", - headers={"X-aws-ec2-metadata-token": token}, - method="GET") +def get_metadata(token=None): + try: + header = None + if token: + header = {"X-aws-ec2-metadata-token": token} - except Exception as e: - # Falling back to IMDSv1 endpoint - log.debug("failed to get ec2 instance metadata from IMDSv2 due to {}. Falling back to IMDSv1".format(e)) + metadata_json = do_request(url=IMDS_URL + "dynamic/instance-identity/document", + headers=header, + method="GET") - try: - runtime_context = {} + return parse_metadata_json(metadata_json) + except Exception: + log.warning("Failed to get EC2 metadata") + return {} - runtime_context['instance_id'] = do_request(url=IMDS_URL + "meta-data/instance-id") - runtime_context['availability_zone'] = do_request(url=IMDS_URL + "meta-data/placement/availability-zone-1") +def parse_metadata_json(json_str): + data = json.loads(json_str) + dict = { + 'instance_id': data['instanceId'], + 'availability_zone': data['availabilityZone'], + 'instance_type': data['instanceType'], + 'ami_id': data['imageId'] + } - except Exception as e: - runtime_context = None - log.debug("failed to get ec2 instance metadata from IMDSv1 due to {}".format(e)) - log.warning("Failed to get ec2 instance metadata") + return dict def do_request(url, headers=None, method="GET"): @@ -61,7 +73,7 @@ def do_request(url, headers=None, method="GET"): if url is None: return None - + req = Request(url=url) req.headers = headers req.method = method diff --git a/tests/test_plugins.py b/tests/test_plugins.py index ca3e73c3..032bde72 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -18,26 +18,36 @@ def test_runtime_context_available(): @patch('aws_xray_sdk.core.plugins.ec2_plugin.do_request') def test_ec2_plugin_imdsv2_success(mock_do_request): - mock_do_request.side_effect = ['token', 'i-0a1d026d92d4709cd', 'us-west-2b'] + v2_json_str = "{\"availabilityZone\" : \"us-east-2a\", \"imageId\" : \"ami-03cca83dd001d4666\"," \ + " \"instanceId\" : \"i-07a181803de94c666\", \"instanceType\" : \"t3.xlarge\"}" + + mock_do_request.side_effect = ['token', v2_json_str] ec2_plugin = get_plugin_modules(('ec2_plugin',))[0] ec2_plugin.initialize() assert hasattr(ec2_plugin, 'runtime_context') r_c = getattr(ec2_plugin, 'runtime_context') - assert r_c['instance_id'] == 'i-0a1d026d92d4709cd' - assert r_c['availability_zone'] == 'us-west-2b' + assert r_c['instance_id'] == 'i-07a181803de94c666' + assert r_c['availability_zone'] == 'us-east-2a' + assert r_c['instance_type'] == 't3.xlarge' + assert r_c['ami_id'] == 'ami-03cca83dd001d4666' @patch('aws_xray_sdk.core.plugins.ec2_plugin.do_request') def test_ec2_plugin_v2_fail_v1_success(mock_do_request): - mock_do_request.side_effect = [Exception("Boom!"), 'i-0a1d026d92d4709ab', 'us-west-2a'] + v1_json_str = "{\"availabilityZone\" : \"cn-north-1a\", \"imageId\" : \"ami-03cca83dd001d4111\"," \ + " \"instanceId\" : \"i-07a181803de94c111\", \"instanceType\" : \"t2.xlarge\"}" + + mock_do_request.side_effect = [Exception("Boom!"), v1_json_str] ec2_plugin = get_plugin_modules(('ec2_plugin',))[0] ec2_plugin.initialize() assert hasattr(ec2_plugin, 'runtime_context') r_c = getattr(ec2_plugin, 'runtime_context') - assert r_c['instance_id'] == 'i-0a1d026d92d4709ab' - assert r_c['availability_zone'] == 'us-west-2a' + assert r_c['instance_id'] == 'i-07a181803de94c111' + assert r_c['availability_zone'] == 'cn-north-1a' + assert r_c['instance_type'] == 't2.xlarge' + assert r_c['ami_id'] == 'ami-03cca83dd001d4111' @patch('aws_xray_sdk.core.plugins.ec2_plugin.do_request') @@ -48,4 +58,4 @@ def test_ec2_plugin_v2_fail_v1_fail(mock_do_request): ec2_plugin.initialize() assert hasattr(ec2_plugin, 'runtime_context') r_c = getattr(ec2_plugin, 'runtime_context') - assert r_c is None + assert r_c == {}