From ca97302f1ef3df724129b36e6c25099f11afc819 Mon Sep 17 00:00:00 2001 From: Marko Bastovanovic Date: Fri, 1 Mar 2019 10:49:57 +0100 Subject: [PATCH] Switch to using aws cli credentials + Add support for profiles --- README.md | 11 +-- emr_cost_calculator.py | 207 +++++++++++++++++++++++------------------ 2 files changed, 117 insertions(+), 101 deletions(-) diff --git a/README.md b/README.md index 9dd035b..a970ab9 100644 --- a/README.md +++ b/README.md @@ -17,21 +17,14 @@ This module is using [docopt](http://docopt.org/) to parse command line argument It currently supports two operations: 1. Get the total cost of an EMR workflow for a given period of days - * `emr_cost_calculator.py total --region= --created_after= --created_before=` + * `emr_cost_calculator.py total --created_after= --created_before=` 2. Get the cost of an EMR cluster given the cluster id - * `emr_cost_calculator.py cluster --region= --cluster_id=` + * `emr_cost_calculator.py cluster --cluster_id=` Authentication to AWS API is done using credentials of AWS CLI which are configured by executing `aws configure` -Alternatively, you can provide credentials by using aws_access_key_id and the aws_secret_access_key script parameters. - -Or, you can set the environment variables: - -`AWS_ACCESS_KEY_ID - Your AWS Access Key ID -AWS_SECRET_ACCESS_KEY - Your AWS Secret Access Key` - ### Install To install all requirements it's best to use diff --git a/emr_cost_calculator.py b/emr_cost_calculator.py index 8bbe17b..736739e 100755 --- a/emr_cost_calculator.py +++ b/emr_cost_calculator.py @@ -2,30 +2,26 @@ """EMR cost calculator Usage: - emr_cost_calculator.py total --region= \ ---created_after= --created_before= \ -[--aws_access_key_id= --aws_secret_access_key=] - emr_cost_calculator.py cluster --region= --cluster_id= \ -[--aws_access_key_id= --aws_secret_access_key=] + emr_cost_calculator.py total --created_after= --created_before= + [--profile=] + emr_cost_calculator.py cluster --cluster_id= [--profile=] emr_cost_calculator.py -h | --help Options: -h --help Show this screen - total Calculate the total EMR cost \ -for a period of time - cluster Calculate the cost of single \ -cluster given the cluster id - --region= The aws region that the \ -cluster was launched on - --aws_access_key_id= Self-explanatory - --aws_secret_access_key= Self-explanatory - --created_after= The calculator will compute \ -the cost for all the cluster created after the created_after day - --created_before= The calculator will compute \ -the cost for all the cluster created before the created_before day - --cluster_id= The id of the cluster you want to \ -calculate the cost for + total Calculate the total EMR cost for a period of + time + cluster Calculate the cost of single cluster given + the cluster id + --profile= Use a specific AWS profile from your + credential file instead of a default one. + --created_after= The calculator will compute the cost for all + the cluster created after the created_after day + --created_before= The calculator will compute the cost for all + the cluster created before the created_before day + --cluster_id= The id of the cluster you want to calculate + the cost for """ from docopt import docopt @@ -75,39 +71,55 @@ def __init__(self, group_id, instance_type, market_type, group_type): class Ec2EmrPricing: - def __init__(self, region): - url_base = 'https://pricing.us-east-1.amazonaws.com' + def __init__(self): + my_session = boto3.session.Session() + region = my_session.region_name + url_base = 'https://pricing.' + region + '.amazonaws.com' index_response = requests.get(url_base + '/offers/v1.0/aws/index.json') index = index_response.json() - emr_regions_response = requests.get(url_base + index['offers']['ElasticMapReduce']['currentRegionIndexUrl']) - emr_region_url = url_base + emr_regions_response.json()['regions'][region]['currentVersionUrl'] + emr_regions_response = requests.get( + url_base + index['offers']['ElasticMapReduce'][ + 'currentRegionIndexUrl']) + emr_region_url = url_base + \ + emr_regions_response.json()['regions'][region][ + 'currentVersionUrl'] emr_pricing = requests.get(emr_region_url).json() sku_to_instance_type = {} for sku in emr_pricing['products']: - if emr_pricing['products'][sku]['attributes']['softwareType'] == 'EMR': - sku_to_instance_type[sku] = emr_pricing['products'][sku]['attributes']['instanceType'] + if emr_pricing['products'][sku]['attributes'][ + 'softwareType'] == 'EMR': + sku_to_instance_type[sku] = \ + emr_pricing['products'][sku]['attributes']['instanceType'] self.emr_prices = {} for sku in sku_to_instance_type.keys(): instance_type = sku_to_instance_type.get(sku) - price = float(emr_pricing['terms']['OnDemand'][sku].itervalues().next()['priceDimensions'] - .itervalues().next()['pricePerUnit']['USD']) + price = float( + emr_pricing['terms']['OnDemand'][sku].itervalues().next()[ + 'priceDimensions'] + .itervalues().next()['pricePerUnit']['USD']) self.emr_prices[instance_type] = price - ec2_regions_response = requests.get(url_base + index['offers']['AmazonEC2']['currentRegionIndexUrl']) - ec2_region_url = url_base + ec2_regions_response.json()['regions'][region]['currentVersionUrl'] + ec2_regions_response = requests.get( + url_base + index['offers']['AmazonEC2']['currentRegionIndexUrl']) + ec2_region_url = url_base + \ + ec2_regions_response.json()['regions'][region][ + 'currentVersionUrl'] ec2_pricing = requests.get(ec2_region_url).json() ec2_sku_to_instance_type = {} for sku in ec2_pricing['products']: try: - if (ec2_pricing['products'][sku]['attributes']['tenancy'] == 'Shared' and - ec2_pricing['products'][sku]['attributes']['operatingSystem'] == 'Linux'): - ec2_sku_to_instance_type[sku] = ec2_pricing['products'][sku]['attributes']['instanceType'] + if (ec2_pricing['products'][sku]['attributes'][ + 'tenancy'] == 'Shared' and + ec2_pricing['products'][sku]['attributes'][ + 'operatingSystem'] == 'Linux'): + ec2_sku_to_instance_type[sku] = \ + ec2_pricing['products'][sku]['attributes']['instanceType'] except KeyError: pass @@ -115,8 +127,10 @@ def __init__(self, region): self.ec2_prices = {} for sku in ec2_sku_to_instance_type.keys(): instance_type = ec2_sku_to_instance_type.get(sku) - price = float(ec2_pricing['terms']['OnDemand'][sku].itervalues().next()['priceDimensions'] - .itervalues().next()['pricePerUnit']['USD']) + price = float( + ec2_pricing['terms']['OnDemand'][sku].itervalues().next()[ + 'priceDimensions'] + .itervalues().next()['pricePerUnit']['USD']) self.ec2_prices[instance_type] = price def get_emr_price(self, instance_type): @@ -127,21 +141,11 @@ def get_ec2_price(self, instance_type): class EmrCostCalculator: - def __init__( - self, - region, - aws_access_key_id=None, - aws_secret_access_key=None): - + def __init__(self): + my_session = boto3.session.Session() + region = my_session.region_name try: - print >> sys.stderr, \ - '[INFO] Retrieving cost in region %s' % region - self.conn = \ - boto3.client('emr', - region_name=region, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key - ) + self.conn = boto3.client('emr', region_name=region) except Exception as e: print >> sys.stderr, \ '[ERROR] Could not establish connection with EMR API' @@ -149,13 +153,12 @@ def __init__( sys.exit() try: - self.spot_pricing = SpotPricing(region, aws_access_key_id, - aws_secret_access_key) + self.spot_pricing = SpotPricing() except: print >> sys.stderr, \ '[ERROR] Could not establish connection with EC2 API' - self.ec2_emr_pricing = Ec2EmrPricing(region) + self.ec2_emr_pricing = Ec2EmrPricing() def get_total_cost_by_dates(self, created_after, created_before): total_cost = 0 @@ -195,8 +198,10 @@ def get_cluster_cost(self, cluster_id): cost_dict.setdefault(group_type + ".EC2", 0) cost_dict[group_type + ".EC2"] += cost cost_dict.setdefault(group_type + ".EMR", 0) - hours_run = ((instance.termination_ts - instance.creation_ts).total_seconds() / 3600) - emr_cost = self.ec2_emr_pricing.get_emr_price(instance.instance_type) * hours_run + hours_run = (( + instance.termination_ts - instance.creation_ts).total_seconds() / 3600) + emr_cost = self.ec2_emr_pricing.get_emr_price( + instance.instance_type) * hours_run cost_dict[group_type + ".EMR"] += emr_cost cost_dict.setdefault('TOTAL', 0) cost_dict['TOTAL'] += cost + emr_cost @@ -206,17 +211,21 @@ def get_cluster_cost(self, cluster_id): def _get_instance_cost(self, instance, availability_zone): if instance.market_type == "SPOT": return self.spot_pricing.get_billed_price_for_period( - instance.instance_type, availability_zone, instance.creation_ts, instance.termination_ts) + instance.instance_type, availability_zone, instance.creation_ts, + instance.termination_ts) elif instance.market_type == "ON_DEMAND": - ec2_price = self.ec2_emr_pricing.get_ec2_price(instance.instance_type) - return ec2_price * ((instance.termination_ts - instance.creation_ts).total_seconds() / 3600) + ec2_price = self.ec2_emr_pricing.get_ec2_price( + instance.instance_type) + return ec2_price * (( + instance.termination_ts - instance.creation_ts).total_seconds() / 3600) def _get_cluster_list(self, created_after, created_before): """ :return: An iterator of cluster ids for the specified dates """ - kwargs = {'CreatedAfter': created_after, 'CreatedBefore': created_before} + kwargs = {'CreatedAfter': created_after, + 'CreatedBefore': created_before} while True: cluster_list = self.conn.list_clusters(**kwargs) for cluster in cluster_list['Clusters']: @@ -231,7 +240,8 @@ def _get_instance_groups(self, cluster_id): Invokes the EMR api and gets a list of the cluster's instance groups. :return: List of our custom InstanceGroup objects """ - groups = self.conn.list_instance_groups(ClusterId=cluster_id)['InstanceGroups'] + groups = self.conn.list_instance_groups(ClusterId=cluster_id)[ + 'InstanceGroups'] instance_groups = [] for group in groups: inst_group = InstanceGroup( @@ -253,7 +263,8 @@ def _get_instances(self, instance_group, cluster_id): :return: An iterator of our custom Ec2Instance objects. """ instance_list = [] - list_instances_args = {'ClusterId': cluster_id, 'InstanceGroupId': instance_group.group_id} + list_instances_args = {'ClusterId': cluster_id, + 'InstanceGroupId': instance_group.group_id} while True: batch = self.conn.list_instances(**list_instances_args) instance_list.extend(batch['Instances']) @@ -263,12 +274,15 @@ def _get_instances(self, instance_group, cluster_id): break for instance_info in instance_list: try: - creation_time = instance_info['Status']['Timeline']['CreationDateTime'] + creation_time = instance_info['Status']['Timeline'][ + 'CreationDateTime'] try: - end_date_time = instance_info['Status']['Timeline']['EndDateTime'] + end_date_time = instance_info['Status']['Timeline'][ + 'EndDateTime'] except KeyError: # use same TZ as one in creation time. By default datetime.now() is not TZ aware - end_date_time = datetime.datetime.now(tz=creation_time.tzinfo) + end_date_time = datetime.datetime.now( + tz=creation_time.tzinfo) inst = Ec2Instance( instance_info['Status']['Timeline']['CreationDateTime'], @@ -285,24 +299,26 @@ def _get_instances(self, instance_group, cluster_id): def _get_availability_zone(self, cluster_id): cluster_description = self.conn.describe_cluster(ClusterId=cluster_id) - return cluster_description['Cluster']['Ec2InstanceAttributes']['Ec2AvailabilityZone'] + return cluster_description['Cluster']['Ec2InstanceAttributes'][ + 'Ec2AvailabilityZone'] class SpotPricing: - def __init__(self, region, aws_access_key_id, aws_secret_access_key): + + def __init__(self): + my_session = boto3.session.Session() + region = my_session.region_name self.all_prices = {} - self.client_ec2 = boto3.client( - 'ec2', - region_name=region, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key) + self.client_ec2 = boto3.client('ec2', region_name=region) - def _populate_all_prices_if_needed(self, instance_id, availability_zone, start_time, end_time): + def _populate_all_prices_if_needed(self, instance_id, availability_zone, + start_time, end_time): previous_ts = None if (instance_id, availability_zone) in self.all_prices: prices = self.all_prices[(instance_id, availability_zone)] - if (end_time - sorted(prices.keys())[-1] < datetime.timedelta(days=1, hours=1) and + if (end_time - sorted(prices.keys())[-1] < datetime.timedelta( + days=1, hours=1) and sorted(prices.keys())[0] < start_time): # this means we already have requested dates. Nothing to do return @@ -322,11 +338,13 @@ def _populate_all_prices_if_needed(self, instance_id, availability_zone, start_t for price in prices_response['SpotPriceHistory']: if previous_ts is None: previous_ts = price['Timestamp'] - if previous_ts - price['Timestamp'] > datetime.timedelta(days=1, hours=1): + if previous_ts - price['Timestamp'] > datetime.timedelta(days=1, + hours=1): print >> sys.stderr, \ "[ERROR] Expecting maximum of 1 day 1 hour difference between spot price entries. Two dates " \ "causing problems: %s AND %s Diff is: %s" % ( - previous_ts, price['Timestamp'], previous_ts - price['Timestamp']) + previous_ts, price['Timestamp'], + previous_ts - price['Timestamp']) quit(-1) prices[price['Timestamp']] = float(price['SpotPrice']) previous_ts = price['Timestamp'] @@ -337,8 +355,10 @@ def _populate_all_prices_if_needed(self, instance_id, availability_zone, start_t self.all_prices[(instance_id, availability_zone)] = prices - def get_billed_price_for_period(self, instance_id, availability_zone, start_time, end_time): - self._populate_all_prices_if_needed(instance_id, availability_zone, start_time, end_time) + def get_billed_price_for_period(self, instance_id, availability_zone, + start_time, end_time): + self._populate_all_prices_if_needed(instance_id, availability_zone, + start_time, end_time) prices = self.all_prices[(instance_id, availability_zone)] @@ -348,35 +368,38 @@ def get_billed_price_for_period(self, instance_id, availability_zone, start_time summed_until_timestamp = start_time for key_id in range(0, len(sorted_price_timestamps)): price_timestamp = sorted_price_timestamps[key_id] - if key_id == len(sorted_price_timestamps) - 1 or end_time < sorted_price_timestamps[key_id + 1]: + if key_id == len(sorted_price_timestamps) - 1 or end_time < \ + sorted_price_timestamps[key_id + 1]: # this is the last price measurement we want: add final part of price segment and exit - seconds_passed = (end_time - summed_until_timestamp).total_seconds() - summed_price = summed_price + (float(seconds_passed) * prices[price_timestamp] / 3600.0) + seconds_passed = ( + end_time - summed_until_timestamp).total_seconds() + summed_price = summed_price + (float(seconds_passed) * prices[ + price_timestamp] / 3600.0) return summed_price - if sorted_price_timestamps[key_id] < summed_until_timestamp < sorted_price_timestamps[key_id + 1]: - seconds_passed = (sorted_price_timestamps[key_id + 1] - summed_until_timestamp).total_seconds() - summed_price = summed_price + (float(seconds_passed) * prices[price_timestamp] / 3600.0) + if sorted_price_timestamps[key_id] < summed_until_timestamp < \ + sorted_price_timestamps[key_id + 1]: + seconds_passed = (sorted_price_timestamps[ + key_id + 1] - summed_until_timestamp).total_seconds() + summed_price = summed_price + (float(seconds_passed) * prices[ + price_timestamp] / 3600.0) summed_until_timestamp = sorted_price_timestamps[key_id + 1] if __name__ == '__main__': args = docopt(__doc__) + profile = args.get('--profile') + if profile is not None: + boto3.setup_default_session(profile_name=profile) + if args.get('total'): created_after_arg = validate_date(args.get('--created_after')) created_before_arg = validate_date(args.get('--created_before')) - calc = EmrCostCalculator( - args.get('--region'), - args.get('--aws_access_key_id'), - args.get('--aws_secret_access_key') - ) - print "TOTAL COST: %.2f" % (calc.get_total_cost_by_dates(created_after_arg, created_before_arg)) + calc = EmrCostCalculator() + print "TOTAL COST: %.2f" % (calc.get_total_cost_by_dates( + created_after_arg, created_before_arg)) elif args.get('cluster'): - calc = EmrCostCalculator( - args.get('--region'), - args.get('--aws_access_key_id'), - args.get('--aws_secret_access_key') - ) + calc = EmrCostCalculator() calculated_prices = calc.get_cluster_cost(args.get('--cluster_id')) for key in sorted(calculated_prices.keys()): print "%12s: %6.2f" % (key, calculated_prices[key])