diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py index 8ce74d2b4eedc..265d4e56af517 100644 --- a/airflow/contrib/hooks/aws_hook.py +++ b/airflow/contrib/hooks/aws_hook.py @@ -17,7 +17,6 @@ # specific language governing permissions and limitations # under the License. - import boto3 import configparser import logging @@ -164,17 +163,17 @@ def _get_credentials(self, region_name): aws_session_token=aws_session_token, region_name=region_name), endpoint_url - def get_client_type(self, client_type, region_name=None): + def get_client_type(self, client_type, region_name=None, config=None): session, endpoint_url = self._get_credentials(region_name) return session.client(client_type, endpoint_url=endpoint_url, - verify=self.verify) + config=config, verify=self.verify) - def get_resource_type(self, resource_type, region_name=None): + def get_resource_type(self, resource_type, region_name=None, config=None): session, endpoint_url = self._get_credentials(region_name) return session.resource(resource_type, endpoint_url=endpoint_url, - verify=self.verify) + config=config, verify=self.verify) def get_session(self, region_name=None): """Get the underlying boto3.session.""" @@ -191,3 +190,16 @@ def get_credentials(self, region_name=None): # secret key separately can lead to a race condition. # See https://stackoverflow.com/a/36291428/8283373 return session.get_credentials().get_frozen_credentials() + + def expand_role(self, role): + """ + Expand an IAM role name to an IAM role ARN. If role is already an IAM ARN, + no change is made. + + :param role: IAM role name or ARN + :return: IAM role ARN + """ + if '/' in role: + return role + else: + return self.get_client_type('iam').get_role(RoleName=role)['Role']['Arn'] diff --git a/airflow/contrib/hooks/sagemaker_hook.py b/airflow/contrib/hooks/sagemaker_hook.py index bc096ff55fccd..823f430db994f 100644 --- a/airflow/contrib/hooks/sagemaker_hook.py +++ b/airflow/contrib/hooks/sagemaker_hook.py @@ -16,299 +16,746 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import copy +import tarfile +import tempfile import time +import os +import collections + +import botocore.config from botocore.exceptions import ClientError from airflow.exceptions import AirflowException from airflow.contrib.hooks.aws_hook import AwsHook from airflow.hooks.S3_hook import S3Hook +from airflow.utils import timezone + + +class LogState(object): + STARTING = 1 + WAIT_IN_PROGRESS = 2 + TAILING = 3 + JOB_COMPLETE = 4 + COMPLETE = 5 + + +# Position is a tuple that includes the last read timestamp and the number of items that were read +# at that time. This is used to figure out which event to start with on the next read. +Position = collections.namedtuple('Position', ['timestamp', 'skip']) + + +def argmin(arr, f): + """Return the index, i, in arr that minimizes f(arr[i])""" + m = None + i = None + for idx, item in enumerate(arr): + if item is not None: + if m is None or f(item) < m: + m = f(item) + i = idx + return i + + +def secondary_training_status_changed(current_job_description, prev_job_description): + """ + Returns true if training job's secondary status message has changed. + + :param current_job_description: Current job description, returned from DescribeTrainingJob call. + :type current_job_description: dict + :param prev_job_description: Previous job description, returned from DescribeTrainingJob call. 
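The `'/'` check in the new `AwsHook.expand_role()` above is the entire name-versus-ARN heuristic: role ARNs (and role paths) always contain a slash, while bare role names never do. A minimal sketch of that behaviour, with a fabricated account id:

```python
# Sketch only: mirrors the expand_role() heuristic added to AwsHook above.
def expand_role_sketch(role, iam_client):
    if '/' in role:  # already an ARN, e.g. arn:aws:iam::123456789012:role/my-role
        return role
    # A bare role name: resolve it with a single IAM lookup.
    return iam_client.get_role(RoleName=role)['Role']['Arn']
```

Anything already containing a slash passes through untouched, so calling it on an already-expanded value is safe.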
+    :type prev_job_description: dict
+
+    :return: Whether the secondary status message of a training job changed or not.
+    """
+    current_secondary_status_transitions = current_job_description.get('SecondaryStatusTransitions')
+    if current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0:
+        return False
+
+    prev_job_secondary_status_transitions = prev_job_description.get('SecondaryStatusTransitions') \
+        if prev_job_description is not None else None
+
+    last_message = prev_job_secondary_status_transitions[-1]['StatusMessage'] \
+        if prev_job_secondary_status_transitions is not None \
+        and len(prev_job_secondary_status_transitions) > 0 else ''
+
+    message = current_job_description['SecondaryStatusTransitions'][-1]['StatusMessage']
+
+    return message != last_message
+
+
+def secondary_training_status_message(job_description, prev_description):
+    """
+    Return a string containing the start time and the secondary training job status message.
+
+    :param job_description: Returned response from DescribeTrainingJob call
+    :type job_description: dict
+    :param prev_description: Previous job description from DescribeTrainingJob call
+    :type prev_description: dict
+
+    :return: Job status string to be printed.
+    """
+
+    if job_description is None or job_description.get('SecondaryStatusTransitions') is None\
+            or len(job_description.get('SecondaryStatusTransitions')) == 0:
+        return ''
+
+    prev_description_secondary_transitions = prev_description.get('SecondaryStatusTransitions')\
+        if prev_description is not None else None
+    prev_transitions_num = len(prev_description['SecondaryStatusTransitions'])\
+        if prev_description_secondary_transitions is not None else 0
+    current_transitions = job_description['SecondaryStatusTransitions']
+
+    transitions_to_print = current_transitions[-1:] if len(current_transitions) == prev_transitions_num else \
+        current_transitions[prev_transitions_num - len(current_transitions):]
+
+    status_strs = []
+    for transition in transitions_to_print:
+        message = transition['StatusMessage']
+        time_str = timezone.convert_to_utc(job_description['LastModifiedTime']).strftime('%Y-%m-%d %H:%M:%S')
+        status_strs.append('{} {} - {}'.format(time_str, transition['Status'], message))
+
+    return '\n'.join(status_strs)
 
 
 class SageMakerHook(AwsHook):
     """
     Interact with Amazon SageMaker.
- sagemaker_conn_id is required for using - the config stored in db for training/tuning """ - non_terminal_states = {'InProgress', 'Stopping', 'Stopped'} + non_terminal_states = {'InProgress', 'Stopping'} + endpoint_non_terminal_states = {'Creating', 'Updating', 'SystemUpdating', + 'RollingBack', 'Deleting'} failed_states = {'Failed'} def __init__(self, - sagemaker_conn_id=None, - use_db_config=False, - region_name=None, - check_interval=5, - max_ingestion_time=None, *args, **kwargs): super(SageMakerHook, self).__init__(*args, **kwargs) - self.sagemaker_conn_id = sagemaker_conn_id - self.use_db_config = use_db_config - self.region_name = region_name - self.check_interval = check_interval - self.max_ingestion_time = max_ingestion_time - self.conn = self.get_conn() + self.s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) + + def tar_and_s3_upload(self, path, key, bucket): + """ + Tar the local file or directory and upload to s3 + + :param path: local file or directory + :type path: str + :param key: s3 key + :type key: str + :param bucket: s3 bucket + :type bucket: str + :return: None + """ + with tempfile.TemporaryFile() as temp_file: + if os.path.isdir(path): + files = [os.path.join(path, name) for name in os.listdir(path)] + else: + files = [path] + with tarfile.open(mode='w:gz', fileobj=temp_file) as tar_file: + for f in files: + tar_file.add(f, arcname=os.path.basename(f)) + temp_file.seek(0) + self.s3_hook.load_file_obj(temp_file, key, bucket, replace=True) + + def configure_s3_resources(self, config): + """ + Extract the S3 operations from the configuration and execute them. - def check_for_url(self, s3url): + :param config: config of SageMaker operation + :type config: dict + :return: dict + """ + s3_operations = config.pop('S3Operations', None) + + if s3_operations is not None: + create_bucket_ops = s3_operations.get('S3CreateBucket', []) + upload_ops = s3_operations.get('S3Upload', []) + for op in create_bucket_ops: + self.s3_hook.create_bucket(bucket_name=op['Bucket']) + for op in upload_ops: + if op['Tar']: + self.tar_and_s3_upload(op['Path'], op['Key'], + op['Bucket']) + else: + self.s3_hook.load_file(op['Path'], op['Key'], + op['Bucket']) + + def check_s3_url(self, s3url): """ - check if the s3url exists + Check if an S3 URL exists + :param s3url: S3 url :type s3url:str :return: bool """ bucket, key = S3Hook.parse_s3_url(s3url) - s3hook = S3Hook(aws_conn_id=self.aws_conn_id) - if not s3hook.check_for_bucket(bucket_name=bucket): + if not self.s3_hook.check_for_bucket(bucket_name=bucket): raise AirflowException( "The input S3 Bucket {} does not exist ".format(bucket)) - if key and not s3hook.check_for_key(key=key, bucket_name=bucket)\ - and not s3hook.check_for_prefix( + if key and not self.s3_hook.check_for_key(key=key, bucket_name=bucket)\ + and not self.s3_hook.check_for_prefix( prefix=key, bucket_name=bucket, delimiter='/'): # check if s3 key exists in the case user provides a single file - # or if s3 prefix exists in the case user provides a prefix for files + # or if s3 prefix exists in the case user provides multiple files in + # a prefix raise AirflowException("The input S3 Key " "or Prefix {} does not exist in the Bucket {}" .format(s3url, bucket)) return True - def check_valid_training_input(self, training_config): + def check_training_config(self, training_config): """ - Run checks before a training starts + Check if a training configuration is valid + :param training_config: training_config :type training_config: dict :return: None """ for channel in 
training_config['InputDataConfig']: - self.check_for_url(channel['DataSource'] - ['S3DataSource']['S3Uri']) + self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri']) - def check_valid_tuning_input(self, tuning_config): + def check_tuning_config(self, tuning_config): """ - Run checks before a tuning job starts + Check if a tuning configuration is valid + :param tuning_config: tuning_config :type tuning_config: dict :return: None """ for channel in tuning_config['TrainingJobDefinition']['InputDataConfig']: - self.check_for_url(channel['DataSource'] - ['S3DataSource']['S3Uri']) + self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri']) - def check_status(self, non_terminal_states, - failed_state, key, - describe_function, *args): - """ - :param non_terminal_states: the set of non_terminal states - :type non_terminal_states: set - :param failed_state: the set of failed states - :type failed_state: set - :param key: the key of the response dict - that points to the state - :type key: str - :param describe_function: the function used to retrieve the status - :type describe_function: python callable - :param args: the arguments for the function - :return: None + def get_conn(self): """ - sec = 0 - running = True - - while running: - - sec = sec + self.check_interval - - if self.max_ingestion_time and sec > self.max_ingestion_time: - # ensure that the job gets killed if the max ingestion time is exceeded - raise AirflowException("SageMaker job took more than " - "%s seconds", self.max_ingestion_time) + Establish an AWS connection for SageMaker - time.sleep(self.check_interval) - try: - response = describe_function(*args) - status = response[key] - self.log.info("Job still running for %s seconds... " - "current status is %s" % (sec, status)) - except KeyError: - raise AirflowException("Could not get status of the SageMaker job") - except ClientError: - raise AirflowException("AWS request failed, check log for more info") - - if status in non_terminal_states: - running = True - elif status in failed_state: - raise AirflowException("SageMaker job failed because %s" - % response['FailureReason']) - else: - running = False - - self.log.info('SageMaker Job Compeleted') + :return: a boto3 SageMaker client + """ + return self.get_client_type('sagemaker') - def get_conn(self): + def get_log_conn(self): """ - Establish an AWS connection - :return: a boto3 SageMaker client + Establish an AWS connection for retrieving logs during training + + :return: a boto3 CloudWatchLog client """ - return self.get_client_type('sagemaker', region_name=self.region_name) + config = botocore.config.Config(retries={'max_attempts': 15}) + return self.get_client_type('logs', config=config) - def list_training_job(self, name_contains=None, status_equals=None): + def log_stream(self, log_group, stream_name, start_time=0, skip=0): """ - List the training jobs associated with the given input - :param name_contains: A string in the training job name - :type name_contains: str - :param status_equals: 'InProgress'|'Completed' - |'Failed'|'Stopping'|'Stopped' - :return:dict + A generator for log items in a single stream. This will yield all the + items that are available at the current moment. + + :param log_group: The name of the log group. + :type log_group: str + :param stream_name: The name of the specific stream. + :type stream_name: str + :param start_time: The time stamp value to start reading the logs from (default: 0). 
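Taken together, `log_stream()` here and `multi_stream_iter()` just below implement a timestamp-ordered merge across the per-instance CloudWatch streams. A toy version of that merge, with fabricated streams and events:

```python
# Toy restatement of the merge multi_stream_iter() performs: repeatedly
# pick the stream whose head event has the smallest timestamp.
streams = {'algo-1': [{'timestamp': 1, 'message': 'a'}, {'timestamp': 3, 'message': 'c'}],
           'algo-2': [{'timestamp': 2, 'message': 'b'}]}
iters = {name: iter(events) for name, events in streams.items()}
heads = {name: next(it, None) for name, it in iters.items()}
while any(heads.values()):
    name = min((n for n in heads if heads[n] is not None),
               key=lambda n: heads[n]['timestamp'])
    print(name, heads[name]['message'])  # algo-1 a, algo-2 b, algo-1 c
    heads[name] = next(iters[name], None)
```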
+        :type start_time: int
+        :param skip: The number of log entries to skip at the start (default: 0).
+            This is for when there are multiple entries at the same timestamp.
+        :type skip: int
+        :return: A CloudWatch log event with the following key-value pairs:
+            'timestamp' (int): The time in milliseconds of the event.
+            'message' (str): The log event data.
+            'ingestionTime' (int): The time in milliseconds the event was ingested.
         """
-        return self.conn.list_training_jobs(
-            NameContains=name_contains, StatusEquals=status_equals)
 
-    def list_tuning_job(self, name_contains=None, status_equals=None):
+        next_token = None
+
+        event_count = 1
+        while event_count > 0:
+            if next_token is not None:
+                token_arg = {'nextToken': next_token}
+            else:
+                token_arg = {}
+
+            response = self.get_log_conn().get_log_events(logGroupName=log_group,
+                                                          logStreamName=stream_name,
+                                                          startTime=start_time,
+                                                          startFromHead=True,
+                                                          **token_arg)
+            next_token = response['nextForwardToken']
+            events = response['events']
+            event_count = len(events)
+            if event_count > skip:
+                events = events[skip:]
+                skip = 0
+            else:
+                skip = skip - event_count
+                events = []
+            for ev in events:
+                yield ev
+
+    def multi_stream_iter(self, log_group, streams, positions=None):
         """
-        List the tuning jobs associated with the given input
-        :param name_contains: A string in the training job name
-        :type name_contains: str
-        :param status_equals: 'InProgress'|'Completed'
-        |'Failed'|'Stopping'|'Stopped'
-        :return:dict
+        Iterate over the available events coming from a set of log streams in a single log group
+        interleaving the events from each stream so they're yielded in timestamp order.
+
+        :param log_group: The name of the log group.
+        :type log_group: str
+        :param streams: A list of the log stream names. The position of the stream in this list is
+            the stream number.
+        :type streams: list
+        :param positions: A mapping of stream name to a Position (timestamp, skip) pair,
+            representing the last record read from that stream.
+        :type positions: dict
+        :return: A tuple of (stream number, cloudwatch log event).
         """
-        return self.conn.list_hyper_parameter_tuning_job(
-            NameContains=name_contains, StatusEquals=status_equals)
+        positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
+        event_iters = [self.log_stream(log_group, s, positions[s].timestamp, positions[s].skip)
+                       for s in streams]
+        events = [next(s) if s else None for s in event_iters]
+
+        while any(events):
+            i = argmin(events, lambda x: x['timestamp'] if x else 9999999999)
+            yield (i, events[i])
+            try:
+                events[i] = next(event_iters[i])
+            except StopIteration:
+                events[i] = None
 
-    def create_training_job(self, training_job_config, wait_for_completion=True):
+    def create_training_job(self, config, wait_for_completion=True, print_log=True,
+                            check_interval=30, max_ingestion_time=None):
         """
         Create a training job
 
-        :param training_job_config: the config for training
-        :type training_job_config: dict
+        :param config: the config for training
+        :type config: dict
         :param wait_for_completion: if the program should keep running until job finishes
         :type wait_for_completion: bool
+        :param print_log: if the CloudWatch log of the training job should be printed
+        :type print_log: bool
+        :param check_interval: the time interval in seconds which the operator
+            will check the status of any SageMaker job
+        :type check_interval: int
+        :param max_ingestion_time: the maximum ingestion time in seconds. Any
+            SageMaker jobs that run longer than this will fail. Setting this to
+            None implies no timeout for any SageMaker job.
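For orientation, an illustrative (not exhaustive) shape for the training config this method consumes; every value below is a placeholder, and the optional `S3Operations` section is the hook's own extension that `configure_s3_resources()` strips out before the API call:

```python
# Placeholder config, shaped like a CreateTrainingJob API request.
training_config = {
    'TrainingJobName': 'example-job',
    'AlgorithmSpecification': {'TrainingImage': '<ecr-image-uri>',
                               'TrainingInputMode': 'File'},
    'RoleArn': 'arn:aws:iam::123456789012:role/example-role',  # fabricated
    'InputDataConfig': [{'ChannelName': 'train',
                         'DataSource': {'S3DataSource': {
                             'S3DataType': 'S3Prefix',
                             'S3Uri': 's3://example-bucket/train/'}}}],
    'OutputDataConfig': {'S3OutputPath': 's3://example-bucket/output/'},
    'ResourceConfig': {'InstanceCount': 1, 'InstanceType': 'ml.m4.xlarge',
                       'VolumeSizeInGB': 10},
    'StoppingCondition': {'MaxRuntimeInSeconds': 3600}
}
```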
+ :type max_ingestion_time: int + :return: A response to training job creation """ - if self.use_db_config: - if not self.sagemaker_conn_id: - raise AirflowException("SageMaker connection id must be present to read \ - SageMaker training jobs configuration.") - sagemaker_conn = self.get_connection(self.sagemaker_conn_id) - config = copy.deepcopy(sagemaker_conn.extra_dejson) - training_job_config.update(config) + self.check_training_config(config) + + response = self.get_conn().create_training_job(**config) + if print_log: + self.check_training_status_with_log(config['TrainingJobName'], + self.non_terminal_states, + self.failed_states, + wait_for_completion, + check_interval, max_ingestion_time + ) + elif wait_for_completion: + describe_response = self.check_status(config['TrainingJobName'], + 'TrainingJobStatus', + self.describe_training_job, + check_interval, max_ingestion_time + ) + + billable_time = \ + (describe_response['TrainingEndTime'] - describe_response['TrainingStartTime']) * \ + describe_response['ResourceConfig']['InstanceCount'] + self.log.info('Billable seconds:{}'.format(int(billable_time.total_seconds()) + 1)) - self.check_valid_training_input(training_job_config) - - response = self.conn.create_training_job( - **training_job_config) - if wait_for_completion: - self.check_status(SageMakerHook.non_terminal_states, - SageMakerHook.failed_states, - 'TrainingJobStatus', - self.describe_training_job, - training_job_config['TrainingJobName']) return response - def create_tuning_job(self, tuning_job_config, wait_for_completion=True): + def create_tuning_job(self, config, wait_for_completion=True, + check_interval=30, max_ingestion_time=None): """ Create a tuning job - :param tuning_job_config: the config for tuning - :type tuning_job_config: dict + + :param config: the config for tuning + :type config: dict :param wait_for_completion: if the program should keep running until job finishes :param wait_for_completion: bool - :return: A dict that contains ARN of the tuning job. + :param check_interval: the time interval in seconds which the operator + will check the status of any SageMaker job + :type check_interval: int + :param max_ingestion_time: the maximum ingestion time in seconds. Any + SageMaker jobs that run longer than this will fail. Setting this to + None implies no timeout for any SageMaker job. 
+ :type max_ingestion_time: int + :return: A response to tuning job creation """ - if self.use_db_config: - if not self.sagemaker_conn_id: - raise AirflowException( - "SageMaker connection id must be present to \ - read SageMaker tunning job configuration.") - sagemaker_conn = self.get_connection(self.sagemaker_conn_id) + self.check_tuning_config(config) - config = sagemaker_conn.extra_dejson.copy() - tuning_job_config.update(config) - - self.check_valid_tuning_input(tuning_job_config) - - response = self.conn.create_hyper_parameter_tuning_job( - **tuning_job_config) + response = self.get_conn().create_hyper_parameter_tuning_job(**config) if wait_for_completion: - self.check_status(SageMakerHook.non_terminal_states, - SageMakerHook.failed_states, + self.check_status(config['HyperParameterTuningJobName'], 'HyperParameterTuningJobStatus', self.describe_tuning_job, - tuning_job_config['HyperParameterTuningJobName']) + check_interval, max_ingestion_time + ) return response - def create_transform_job(self, transform_job_config, wait_for_completion=True): + def create_transform_job(self, config, wait_for_completion=True, + check_interval=30, max_ingestion_time=None): """ Create a transform job - :param transform_job_config: the config for transform job - :type transform_job_config: dict - :param wait_for_completion: - if the program should keep running until job finishes + + :param config: the config for transform job + :type config: dict + :param wait_for_completion: if the program should keep running until job finishes :type wait_for_completion: bool - :return: A dict that contains ARN of the transform job. + :param check_interval: the time interval in seconds which the operator + will check the status of any SageMaker job + :type check_interval: int + :param max_ingestion_time: the maximum ingestion time in seconds. Any + SageMaker jobs that run longer than this will fail. Setting this to + None implies no timeout for any SageMaker job. + :type max_ingestion_time: int + :return: A response to transform job creation """ - if self.use_db_config: - if not self.sagemaker_conn_id: - raise AirflowException( - "SageMaker connection id must be present to \ - read SageMaker transform job configuration.") - - sagemaker_conn = self.get_connection(self.sagemaker_conn_id) - - config = sagemaker_conn.extra_dejson.copy() - transform_job_config.update(config) - self.check_for_url(transform_job_config - ['TransformInput']['DataSource'] - ['S3DataSource']['S3Uri']) + self.check_s3_url(config['TransformInput']['DataSource']['S3DataSource']['S3Uri']) - response = self.conn.create_transform_job( - **transform_job_config) + response = self.get_conn().create_transform_job(**config) if wait_for_completion: - self.check_status(SageMakerHook.non_terminal_states, - SageMakerHook.failed_states, + self.check_status(config['TransformJobName'], 'TransformJobStatus', self.describe_transform_job, - transform_job_config['TransformJobName']) + check_interval, max_ingestion_time + ) return response - def create_model(self, model_config): + def create_model(self, config): """ Create a model job - :param model_config: the config for model - :type model_config: dict - :return: A dict that contains ARN of the model. 
+ + :param config: the config for model + :type config: dict + :return: A response to model creation """ - return self.conn.create_model( - **model_config) + return self.get_conn().create_model(**config) - def describe_training_job(self, training_job_name): + def create_endpoint_config(self, config): """ - :param training_job_name: the name of the training job - :type training_job_name: str - Return the training job info associated with the current job_name + Create an endpoint config + + :param config: the config for endpoint-config + :type config: dict + :return: A response to endpoint config creation + """ + + return self.get_conn().create_endpoint_config(**config) + + def create_endpoint(self, config, wait_for_completion=True, + check_interval=30, max_ingestion_time=None): + """ + Create an endpoint + + :param config: the config for endpoint + :type config: dict + :param wait_for_completion: if the program should keep running until job finishes + :type wait_for_completion: bool + :param check_interval: the time interval in seconds which the operator + will check the status of any SageMaker job + :type check_interval: int + :param max_ingestion_time: the maximum ingestion time in seconds. Any + SageMaker jobs that run longer than this will fail. Setting this to + None implies no timeout for any SageMaker job. + :type max_ingestion_time: int + :return: A response to endpoint creation + """ + + response = self.get_conn().create_endpoint(**config) + if wait_for_completion: + self.check_status(config['EndpointName'], + 'EndpointStatus', + self.describe_endpoint, + check_interval, max_ingestion_time, + non_terminal_states=self.endpoint_non_terminal_states + ) + return response + + def update_endpoint(self, config, wait_for_completion=True, + check_interval=30, max_ingestion_time=None): + """ + Update an endpoint + + :param config: the config for endpoint + :type config: dict + :param wait_for_completion: if the program should keep running until job finishes + :type wait_for_completion: bool + :param check_interval: the time interval in seconds which the operator + will check the status of any SageMaker job + :type check_interval: int + :param max_ingestion_time: the maximum ingestion time in seconds. Any + SageMaker jobs that run longer than this will fail. Setting this to + None implies no timeout for any SageMaker job. 
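The three `create_*` hooks above compose into the usual deployment sequence, model, then endpoint config, then endpoint. A sketch with placeholder names and a fabricated role ARN; each call is a thin passthrough to the corresponding boto3 API:

```python
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook

hook = SageMakerHook(aws_conn_id='aws_default')
hook.create_model({'ModelName': 'example-model',
                   'PrimaryContainer': {'Image': '<ecr-image-uri>',
                                        'ModelDataUrl': 's3://example-bucket/model.tar.gz'},
                   'ExecutionRoleArn': 'arn:aws:iam::123456789012:role/example-role'})
hook.create_endpoint_config({'EndpointConfigName': 'example-config',
                             'ProductionVariants': [{'VariantName': 'AllTraffic',
                                                     'ModelName': 'example-model',
                                                     'InitialInstanceCount': 1,
                                                     'InstanceType': 'ml.m4.xlarge'}]})
hook.create_endpoint({'EndpointName': 'example-endpoint',
                      'EndpointConfigName': 'example-config'})
```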
+ :type max_ingestion_time: int + :return: A response to endpoint update + """ + + response = self.get_conn().update_endpoint(**config) + if wait_for_completion: + self.check_status(config['EndpointName'], + 'EndpointStatus', + self.describe_endpoint, + check_interval, max_ingestion_time, + non_terminal_states=self.endpoint_non_terminal_states + ) + return response + + def describe_training_job(self, name): + """ + Return the training job info associated with the name + + :param name: the name of the training job + :type name: str :return: A dict contains all the training job info """ - return self.conn\ - .describe_training_job(TrainingJobName=training_job_name) - def describe_tuning_job(self, tuning_job_name): + return self.get_conn().describe_training_job(TrainingJobName=name) + + def describe_training_job_with_log(self, job_name, positions, stream_names, + instance_count, state, last_description, + last_describe_job_call): + """ + Return the training job info associated with job_name and print CloudWatch logs + """ + log_group = '/aws/sagemaker/TrainingJobs' + + if len(stream_names) < instance_count: + # Log streams are created whenever a container starts writing to stdout/err, so this list + # may be dynamic until we have a stream for every instance. + logs_conn = self.get_log_conn() + try: + streams = logs_conn.describe_log_streams( + logGroupName=log_group, + logStreamNamePrefix=job_name + '/', + orderBy='LogStreamName', + limit=instance_count + ) + stream_names = [s['logStreamName'] for s in streams['logStreams']] + positions.update([(s, Position(timestamp=0, skip=0)) + for s in stream_names if s not in positions]) + except logs_conn.exceptions.ResourceNotFoundException: + # On the very first training job run on an account, there's no log group until + # the container starts logging, so ignore any errors thrown about that + pass + + if len(stream_names) > 0: + for idx, event in self.multi_stream_iter(log_group, stream_names, positions): + self.log.info(event['message']) + ts, count = positions[stream_names[idx]] + if event['timestamp'] == ts: + positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1) + else: + positions[stream_names[idx]] = Position(timestamp=event['timestamp'], skip=1) + + if state == LogState.COMPLETE: + return state, last_description, last_describe_job_call + + if state == LogState.JOB_COMPLETE: + state = LogState.COMPLETE + elif time.time() - last_describe_job_call >= 30: + description = self.describe_training_job(job_name) + last_describe_job_call = time.time() + + if secondary_training_status_changed(description, last_description): + self.log.info(secondary_training_status_message(description, last_description)) + last_description = description + + status = description['TrainingJobStatus'] + + if status not in self.non_terminal_states: + state = LogState.JOB_COMPLETE + return state, last_description, last_describe_job_call + + def describe_tuning_job(self, name): """ - :param tuning_job_name: the name of the tuning job - :type tuning_job_name: string - Return the tuning job info associated with the current job_name + Return the tuning job info associated with the name + + :param name: the name of the tuning job + :type name: string :return: A dict contains all the tuning job info """ - return self.conn\ - .describe_hyper_parameter_tuning_job( - HyperParameterTuningJobName=tuning_job_name) - def describe_transform_job(self, transform_job_name): + return self.get_conn().describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=name) + + def 
describe_model(self, name):
+        """
+        Return the SageMaker model info associated with the name
+
+        :param name: the name of the SageMaker model
+        :type name: string
+        :return: A dict contains all the model info
+        """
+
+        return self.get_conn().describe_model(ModelName=name)
+
+    def describe_transform_job(self, name):
         """
-        :param transform_job_name: the name of the transform job
-        :type transform_job_name: string
-        Return the transform job info associated with the current job_name
+        Return the transform job info associated with the name
+
+        :param name: the name of the transform job
+        :type name: string
         :return: A dict contains all the transform job info
         """
-        return self.conn\
-            .describe_transform_job(
-                TransformJobName=transform_job_name)
+
+        return self.get_conn().describe_transform_job(TransformJobName=name)
+
+    def describe_endpoint_config(self, name):
+        """
+        Return the endpoint config info associated with the name
+
+        :param name: the name of the endpoint config
+        :type name: string
+        :return: A dict contains all the endpoint config info
+        """
+
+        return self.get_conn().describe_endpoint_config(EndpointConfigName=name)
+
+    def describe_endpoint(self, name):
+        """
+        Return the endpoint info associated with the name
+
+        :param name: the name of the endpoint
+        :type name: string
+        :return: A dict contains all the endpoint info
+        """
+
+        return self.get_conn().describe_endpoint(EndpointName=name)
+
+    def check_status(self, job_name, key,
+                     describe_function, check_interval,
+                     max_ingestion_time,
+                     non_terminal_states=None):
+        """
+        Check status of a SageMaker job
+
+        :param job_name: name of the job to check status
+        :type job_name: str
+        :param key: the key of the response dict
+            that points to the state
+        :type key: str
+        :param describe_function: the function used to retrieve the status
+        :type describe_function: python callable
+        :param check_interval: the time interval in seconds which the operator
+            will check the status of any SageMaker job
+        :type check_interval: int
+        :param max_ingestion_time: the maximum ingestion time in seconds. Any
+            SageMaker jobs that run longer than this will fail. Setting this to
+            None implies no timeout for any SageMaker job.
+        :type max_ingestion_time: int
+        :param non_terminal_states: the set of nonterminal states
+        :type non_terminal_states: set
+        :return: response of describe call after job is done
+        """
+        if not non_terminal_states:
+            non_terminal_states = self.non_terminal_states
+
+        sec = 0
+        running = True
+
+        while running:
+            time.sleep(check_interval)
+            sec = sec + check_interval
+
+            try:
+                response = describe_function(job_name)
+                status = response[key]
+                self.log.info('Job still running for %s seconds... '
+                              'current status is %s' % (sec, status))
+            except KeyError:
+                raise AirflowException('Could not get status of the SageMaker job')
+            except ClientError:
+                raise AirflowException('AWS request failed, check logs for more info')
+
+            if status in non_terminal_states:
+                running = True
+            elif status in self.failed_states:
+                raise AirflowException('SageMaker job failed because %s' % response['FailureReason'])
+            else:
+                running = False
+
+            if max_ingestion_time and sec > max_ingestion_time:
+                # ensure that the job gets killed if the max ingestion time is exceeded
+                raise AirflowException('SageMaker job took more than %s seconds' % max_ingestion_time)
+
+        self.log.info('SageMaker Job Completed')
+        response = describe_function(job_name)
+        return response
+
+    def check_training_status_with_log(self, job_name, non_terminal_states, failed_states,
+                                       wait_for_completion, check_interval, max_ingestion_time):
+        """
+        Display the logs for a given training job, optionally tailing them until the
+        job is complete.
+
+        :param job_name: name of the training job to check status and display logs for
+        :type job_name: str
+        :param non_terminal_states: the set of non_terminal states
+        :type non_terminal_states: set
+        :param failed_states: the set of failed states
+        :type failed_states: set
+        :param wait_for_completion: Whether to keep looking for new log entries
+            until the job completes
+        :type wait_for_completion: bool
+        :param check_interval: The interval in seconds between polling for new log entries and job completion
+        :type check_interval: int
+        :param max_ingestion_time: the maximum ingestion time in seconds. Any
+            SageMaker jobs that run longer than this will fail. Setting this to
+            None implies no timeout for any SageMaker job.
+        :type max_ingestion_time: int
+        :return: None
+        """
+
+        sec = 0
+        description = self.describe_training_job(job_name)
+        self.log.info(secondary_training_status_message(description, None))
+        instance_count = description['ResourceConfig']['InstanceCount']
+        status = description['TrainingJobStatus']
+
+        stream_names = []  # The list of log streams
+        positions = {}  # The current position in each stream, map of stream name -> position
+
+        job_already_completed = status not in non_terminal_states
+
+        state = LogState.TAILING if wait_for_completion and not job_already_completed else LogState.COMPLETE
+
+        # The loop below implements a state machine that alternates between checking the job status and
+        # reading whatever is available in the logs at this point. Note that if we were called with
+        # wait_for_completion == False, we never check the job status.
+        #
+        # If wait_for_completion == TRUE and job is not completed, the initial state is TAILING
+        # If wait_for_completion == FALSE, the initial state is COMPLETE
+        # (doesn't matter if the job really is complete).
+        #
+        # The state table:
+        #
+        # STATE               ACTIONS                        CONDITION           NEW STATE
+        # ----------------    ----------------               -----------------   ----------------
+        # TAILING             Read logs, Pause, Get status   Job complete        JOB_COMPLETE
+        #                                                    Else                TAILING
+        # JOB_COMPLETE        Read logs, Pause               Any                 COMPLETE
+        # COMPLETE            Read logs, Exit                N/A
+        #
+        # Notes:
+        # - The JOB_COMPLETE state forces us to do an extra pause and read any items that
+        #   got to Cloudwatch after the job was marked complete.
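A dry run of that state table, assuming the status poll reports the job finished on the second check (no AWS calls involved):

```python
from airflow.contrib.hooks.sagemaker_hook import LogState

state = LogState.TAILING
for job_finished in (False, True, True):      # successive status polls
    # ... read whatever log events are available at this point ...
    if state == LogState.JOB_COMPLETE:
        state = LogState.COMPLETE             # one final catch-up read, then stop
    elif job_finished:
        state = LogState.JOB_COMPLETE         # job done; read the stragglers once more
assert state == LogState.COMPLETE
```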
+        last_describe_job_call = time.time()
+        last_description = description
+
+        while True:
+            time.sleep(check_interval)
+            sec = sec + check_interval
+
+            state, last_description, last_describe_job_call = \
+                self.describe_training_job_with_log(job_name, positions, stream_names,
+                                                    instance_count, state, last_description,
+                                                    last_describe_job_call)
+            if state == LogState.COMPLETE:
+                break
+
+            if max_ingestion_time and sec > max_ingestion_time:
+                # ensure that the job gets killed if the max ingestion time is exceeded
+                raise AirflowException('SageMaker job took more than %s seconds' % max_ingestion_time)
+
+        if wait_for_completion:
+            status = last_description['TrainingJobStatus']
+            if status in failed_states:
+                reason = last_description.get('FailureReason', '(No reason provided)')
+                raise AirflowException('Error training {}: {} Reason: {}'.format(job_name, status, reason))
+            billable_time = (last_description['TrainingEndTime'] - last_description['TrainingStartTime']) \
+                * instance_count
+            self.log.info('Billable seconds:{}'.format(int(billable_time.total_seconds()) + 1))
diff --git a/airflow/contrib/operators/sagemaker_base_operator.py b/airflow/contrib/operators/sagemaker_base_operator.py
new file mode 100644
index 0000000000000..cf1e59387a784
--- /dev/null
+++ b/airflow/contrib/operators/sagemaker_base_operator.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+
+from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class SageMakerBaseOperator(BaseOperator):
+    """
+    This is the base operator for all SageMaker operators.
+
+    :param config: The configuration used by the SageMaker operation (templated)
+    :type config: dict
+    :param aws_conn_id: The AWS connection ID to use.
+    :type aws_conn_id: str
+    """
+
+    template_fields = ['config']
+    template_ext = ()
+    ui_color = '#ededed'
+
+    integer_fields = []
+
+    @apply_defaults
+    def __init__(self,
+                 config,
+                 aws_conn_id='aws_default',
+                 *args, **kwargs):
+        super(SageMakerBaseOperator, self).__init__(*args, **kwargs)
+
+        self.aws_conn_id = aws_conn_id
+        self.config = config
+        self.hook = None
+
+    def parse_integer(self, config, field):
+        if len(field) == 1:
+            if isinstance(config, list):
+                for sub_config in config:
+                    self.parse_integer(sub_config, field)
+                return
+            head = field[0]
+            if head in config:
+                config[head] = int(config[head])
+            return
+
+        if isinstance(config, list):
+            for sub_config in config:
+                self.parse_integer(sub_config, field)
+            return
+
+        head, tail = field[0], field[1:]
+        if head in config:
+            self.parse_integer(config[head], tail)
+        return
+
+    def parse_config_integers(self):
+        # Parse the integer fields of training config to integers
+        # in case the config is rendered by Jinja and all fields are str
+        for field in self.integer_fields:
+            self.parse_integer(self.config, field)
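The `integer_fields`/`parse_integer` machinery above exists because Jinja templating turns every scalar in the config into a string. Flattened to the dict-only case (the real method also recurses into lists), the cast works like this:

```python
# Dict-only restatement of parse_integer(); the values are fabricated
# strings, as they would look after Jinja rendering.
config = {'ResourceConfig': {'InstanceCount': '2', 'VolumeSizeInGB': '10'}}
for path in [['ResourceConfig', 'InstanceCount'], ['ResourceConfig', 'VolumeSizeInGB']]:
    node = config
    for key in path[:-1]:
        node = node[key]
    node[path[-1]] = int(node[path[-1]])
print(config)  # {'ResourceConfig': {'InstanceCount': 2, 'VolumeSizeInGB': 10}}
```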
+    def expand_role(self):
+        raise NotImplementedError('Please implement expand_role() in sub class!')
+
+    def preprocess_config(self):
+        self.log.info(
+            'Preprocessing the config and doing required S3 operations'
+        )
+        self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
+
+        self.hook.configure_s3_resources(self.config)
+        self.parse_config_integers()
+        self.expand_role()
+
+        self.log.info(
+            'After preprocessing the config is:\n {}'.format(
+                json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': ')))
+        )
+
+    def execute(self, context):
+        raise NotImplementedError('Please implement execute() in sub class!')
diff --git a/airflow/contrib/operators/sagemaker_create_training_job_operator.py b/airflow/contrib/operators/sagemaker_create_training_job_operator.py
deleted file mode 100644
index 279220867956d..0000000000000
--- a/airflow/contrib/operators/sagemaker_create_training_job_operator.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
-from airflow.models import BaseOperator
-from airflow.utils.decorators import apply_defaults
-from airflow.exceptions import AirflowException
-
-
-class SageMakerCreateTrainingJobOperator(BaseOperator):
-
-    """
-    Initiate a SageMaker training
-
-    This operator returns The ARN of the model created in Amazon SageMaker
-
-    :param training_job_config:
-        The configuration necessary to start a training job (templated)
-    :type training_job_config: dict
-    :param region_name: The AWS region_name
-    :type region_name: str
-    :param sagemaker_conn_id: The SageMaker connection ID to use.
- :type sagemaker_conn_id: str - :param use_db_config: Whether or not to use db config - associated with sagemaker_conn_id. - If set to true, will automatically update the training config - with what's in db, so the db config doesn't need to - included everything, but what's there does replace the ones - in the training_job_config, so be careful - :type use_db_config: bool - :param aws_conn_id: The AWS connection ID to use. - :type aws_conn_id: str - :param wait_for_completion: if the operator should block - until training job finishes - :type wait_for_completion: bool - :param check_interval: if wait is set to be true, this is the time interval - in seconds which the operator will check the status of the training job - :type check_interval: int - :param max_ingestion_time: if wait is set to be true, the operator will fail - if the training job hasn't finish within the max_ingestion_time - (Caution: be careful to set this parameters because training can take very long) - :type max_ingestion_time: int - - **Example**: - The following operator would start a training job when executed - - sagemaker_training = - SageMakerCreateTrainingJobOperator( - task_id='sagemaker_training', - training_job_config=config, - region_name='us-west-2' - sagemaker_conn_id='sagemaker_customers_conn', - use_db_config=True, - aws_conn_id='aws_customers_conn' - ) - """ - - template_fields = ['training_job_config'] - template_ext = () - ui_color = '#ededed' - - @apply_defaults - def __init__(self, - training_job_config=None, - region_name=None, - sagemaker_conn_id=None, - use_db_config=False, - wait_for_completion=True, - check_interval=5, - max_ingestion_time=None, - *args, **kwargs): - super(SageMakerCreateTrainingJobOperator, self).__init__(*args, **kwargs) - - self.sagemaker_conn_id = sagemaker_conn_id - self.training_job_config = training_job_config - self.use_db_config = use_db_config - self.region_name = region_name - self.wait_for_completion = wait_for_completion - self.check_interval = check_interval - self.max_ingestion_time = max_ingestion_time - - def execute(self, context): - sagemaker = SageMakerHook( - sagemaker_conn_id=self.sagemaker_conn_id, - use_db_config=self.use_db_config, - region_name=self.region_name, - check_interval=self.check_interval, - max_ingestion_time=self.max_ingestion_time - ) - - self.log.info( - "Creating SageMaker Training Job %s." - % self.training_job_config['TrainingJobName'] - ) - response = sagemaker.create_training_job( - self.training_job_config, - wait_for_completion=self.wait_for_completion) - if not response['ResponseMetadata']['HTTPStatusCode'] \ - == 200: - raise AirflowException( - 'Sagemaker Training Job creation failed: %s' % response) - else: - return response diff --git a/airflow/contrib/operators/sagemaker_create_transform_job_operator.py b/airflow/contrib/operators/sagemaker_create_transform_job_operator.py deleted file mode 100644 index 22c8c2b4ba297..0000000000000 --- a/airflow/contrib/operators/sagemaker_create_transform_job_operator.py +++ /dev/null @@ -1,132 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from airflow.contrib.hooks.sagemaker_hook import SageMakerHook -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults -from airflow.exceptions import AirflowException - - -class SageMakerCreateTransformJobOperator(BaseOperator): - """ - Initiate a SageMaker transform - - This operator returns The ARN of the model created in Amazon SageMaker - - :param sagemaker_conn_id: The SageMaker connection ID to use. - :type sagemaker_conn_id: string - :param transform_job_config: - The configuration necessary to start a transform job (templated) - :type transform_job_config: dict - :param model_config: - The configuration necessary to create a model, the default is none - which means that user should provide a created model in transform_job_config - If given, will be used to create a model before creating transform job - :type model_config: dict - :param use_db_config: Whether or not to use db config - associated with sagemaker_conn_id. - If set to true, will automatically update the transform config - with what's in db, so the db config doesn't need to - included everything, but what's there does replace the ones - in the transform_job_config, so be careful - :type use_db_config: bool - :param region_name: The AWS region_name - :type region_name: string - :param wait_for_completion: if the program should keep running until job finishes - :type wait_for_completion: bool - :param check_interval: if wait is set to be true, this is the time interval - in seconds which the operator will check the status of the transform job - :type check_interval: int - :param max_ingestion_time: if wait is set to be true, the operator will fail - if the transform job hasn't finish within the max_ingestion_time - (Caution: be careful to set this parameters because transform can take very long) - :type max_ingestion_time: int - :param aws_conn_id: The AWS connection ID to use. 
- :type aws_conn_id: string - - **Example**: - The following operator would start a transform job when executed - - sagemaker_transform = - SageMakerCreateTransformJobOperator( - task_id='sagemaker_transform', - transform_job_config=config_transform, - model_config=config_model, - region_name='us-west-2' - sagemaker_conn_id='sagemaker_customers_conn', - use_db_config=True, - aws_conn_id='aws_customers_conn' - ) - """ - - template_fields = ['transform_job_config'] - template_ext = () - ui_color = '#ededed' - - @apply_defaults - def __init__(self, - sagemaker_conn_id=None, - transform_job_config=None, - model_config=None, - use_db_config=False, - region_name=None, - wait_for_completion=True, - check_interval=2, - max_ingestion_time=None, - *args, **kwargs): - super(SageMakerCreateTransformJobOperator, self).__init__(*args, **kwargs) - - self.sagemaker_conn_id = sagemaker_conn_id - self.transform_job_config = transform_job_config - self.model_config = model_config - self.use_db_config = use_db_config - self.region_name = region_name - self.wait_for_completion = wait_for_completion - self.check_interval = check_interval - self.max_ingestion_time = max_ingestion_time - - def execute(self, context): - sagemaker = SageMakerHook( - sagemaker_conn_id=self.sagemaker_conn_id, - use_db_config=self.use_db_config, - region_name=self.region_name, - check_interval=self.check_interval, - max_ingestion_time=self.max_ingestion_time - ) - - if self.model_config: - self.log.info( - "Creating SageMaker Model %s for transform job" - % self.model_config['ModelName'] - ) - sagemaker.create_model(self.model_config) - - self.log.info( - "Creating SageMaker transform Job %s." - % self.transform_job_config['TransformJobName'] - ) - response = sagemaker.create_transform_job( - self.transform_job_config, - wait_for_completion=self.wait_for_completion) - if not response['ResponseMetadata']['HTTPStatusCode'] \ - == 200: - raise AirflowException( - 'Sagemaker transform Job creation failed: %s' % response) - else: - return response diff --git a/airflow/contrib/operators/sagemaker_create_tuning_job_operator.py b/airflow/contrib/operators/sagemaker_create_tuning_job_operator.py deleted file mode 100644 index d5f4396375993..0000000000000 --- a/airflow/contrib/operators/sagemaker_create_tuning_job_operator.py +++ /dev/null @@ -1,121 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from airflow.contrib.hooks.sagemaker_hook import SageMakerHook -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults -from airflow.exceptions import AirflowException - - -class SageMakerCreateTuningJobOperator(BaseOperator): - - """ - Initiate a SageMaker HyperParameter Tuning Job - - This operator returns The ARN of the model created in Amazon SageMaker - - :param sagemaker_conn_id: The SageMaker connection ID to use. - :type sagemaker_conn_id: str - :param region_name: The AWS region_name - :type region_name: str - :param tuning_job_config: - The configuration necessary to start a tuning job (templated) - :type tuning_job_config: dict - :param use_db_config: Whether or not to use db config - associated with sagemaker_conn_id. - If set to true, will automatically update the tuning config - with what's in db, so the db config doesn't need to - included everything, but what's there does replace the ones - in the tuning_job_config, so be careful - :type use_db_config: bool - :param wait_for_completion: if the operator should block - until tuning job finishes - :type wait_for_completion: bool - :param check_interval: if wait is set to be true, this is the time interval - in seconds which the operator will check the status of the tuning job - :type check_interval: int - :param max_ingestion_time: if wait is set to be true, the operator will fail - if the tuning job hasn't finish within the max_ingestion_time - (Caution: be careful to set this parameters because tuning can take very long) - :type max_ingestion_time: int - :param aws_conn_id: The AWS connection ID to use. - :type aws_conn_id: str - - **Example**: - The following operator would start a tuning job when executed - - sagemaker_tuning = - SageMakerCreateTuningJobOperator( - task_id='sagemaker_tuning', - sagemaker_conn_id='sagemaker_customers_conn', - tuning_job_config=config, - check_interval=2, - max_ingestion_time=3600, - aws_conn_id='aws_customers_conn', - ) - """ - - template_fields = ['tuning_job_config'] - template_ext = () - ui_color = '#ededed' - - @apply_defaults - def __init__(self, - sagemaker_conn_id=None, - region_name=None, - tuning_job_config=None, - use_db_config=False, - wait_for_completion=True, - check_interval=5, - max_ingestion_time=None, - *args, **kwargs): - super(SageMakerCreateTuningJobOperator, self)\ - .__init__(*args, **kwargs) - - self.sagemaker_conn_id = sagemaker_conn_id - self.region_name = region_name - self.tuning_job_config = tuning_job_config - self.use_db_config = use_db_config - self.wait_for_completion = wait_for_completion - self.check_interval = check_interval - self.max_ingestion_time = max_ingestion_time - - def execute(self, context): - sagemaker = SageMakerHook(sagemaker_conn_id=self.sagemaker_conn_id, - region_name=self.region_name, - use_db_config=self.use_db_config, - check_interval=self.check_interval, - max_ingestion_time=self.max_ingestion_time - ) - - self.log.info( - "Creating SageMaker Hyper Parameter Tunning Job %s" - % self.tuning_job_config['HyperParameterTuningJobName'] - ) - - response = sagemaker.create_tuning_job( - self.tuning_job_config, - wait_for_completion=self.wait_for_completion - ) - if not response['ResponseMetadata']['HTTPStatusCode'] \ - == 200: - raise AirflowException( - "Sagemaker Tuning Job creation failed: %s" % response) - else: - return response diff --git a/airflow/contrib/operators/sagemaker_training_operator.py b/airflow/contrib/operators/sagemaker_training_operator.py new file mode 100644 index 
0000000000000..69036925f34f7
--- /dev/null
+++ b/airflow/contrib/operators/sagemaker_training_operator.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
+from airflow.utils.decorators import apply_defaults
+from airflow.exceptions import AirflowException
+
+
+class SageMakerTrainingOperator(SageMakerBaseOperator):
+    """
+    Initiate a SageMaker training job.
+
+    This operator returns the ARN of the training job created in Amazon SageMaker.
+
+    :param config: The configuration necessary to start a training job (templated)
+    :type config: dict
+    :param aws_conn_id: The AWS connection ID to use.
+    :type aws_conn_id: str
+    :param wait_for_completion: if the operator should block until training job finishes
+    :type wait_for_completion: bool
+    :param print_log: if the operator should print the cloudwatch log during training
+    :type print_log: bool
+    :param check_interval: if wait is set to be true, this is the time interval
+        in seconds which the operator will check the status of the training job
+    :type check_interval: int
+    :param max_ingestion_time: if wait is set to be true, the operator will fail
+        if the training job hasn't finished within the max_ingestion_time in seconds
+        (Caution: be careful to set this parameter because training can take very long)
+        Setting it to None implies no timeout.
+    :type max_ingestion_time: int
+    """
+
+    integer_fields = [
+        ['ResourceConfig', 'InstanceCount'],
+        ['ResourceConfig', 'VolumeSizeInGB'],
+        ['StoppingCondition', 'MaxRuntimeInSeconds']
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 config,
+                 wait_for_completion=True,
+                 print_log=True,
+                 check_interval=30,
+                 max_ingestion_time=None,
+                 *args, **kwargs):
+        super(SageMakerTrainingOperator, self).__init__(config=config,
+                                                        *args, **kwargs)
+
+        self.wait_for_completion = wait_for_completion
+        self.print_log = print_log
+        self.check_interval = check_interval
+        self.max_ingestion_time = max_ingestion_time
+
+    def expand_role(self):
+        if 'RoleArn' in self.config:
+            hook = AwsHook(self.aws_conn_id)
+            self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
+
+    def execute(self, context):
+        self.preprocess_config()
+
+        self.log.info('Creating SageMaker Training Job %s.', self.config['TrainingJobName'])
+
+        response = self.hook.create_training_job(
+            self.config,
+            wait_for_completion=self.wait_for_completion,
+            print_log=self.print_log,
+            check_interval=self.check_interval,
+            max_ingestion_time=self.max_ingestion_time
+        )
+        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+            raise AirflowException(
+                'Sagemaker Training Job creation failed: %s' % response)
+        else:
+            return {
+                'Training': self.hook.describe_training_job(
+                    self.config['TrainingJobName']
+                )
+            }
diff --git a/airflow/contrib/operators/sagemaker_transform_operator.py b/airflow/contrib/operators/sagemaker_transform_operator.py
new file mode 100644
index 0000000000000..7be570cdacd6f
--- /dev/null
+++ b/airflow/contrib/operators/sagemaker_transform_operator.py
@@ -0,0 +1,112 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
+from airflow.utils.decorators import apply_defaults
+from airflow.exceptions import AirflowException
+
+
+class SageMakerTransformOperator(SageMakerBaseOperator):
+    """
+    Initiate a SageMaker transform job.
+
+    This operator returns the ARN of the model created in Amazon SageMaker.
+
+    :param config: The configuration necessary to start a transform job (templated).
+        If a SageMaker model needs to be created first, pass a dict with the two
+        keys 'Model' and 'Transform': the value of 'Model' is used to create the
+        SageMaker model before the transform job starts, and the value of
+        'Transform' configures the transform job itself. Otherwise, pass the
+        transform configuration directly; the SageMaker model it references must
+        already exist.
+    :type config: dict
+    :param aws_conn_id: The AWS connection ID to use.
+    :type aws_conn_id: string
+    :param wait_for_completion: if the program should keep running until job finishes
+    :type wait_for_completion: bool
+    :param check_interval: if wait is set to be true, this is the time interval
+        in seconds which the operator will check the status of the transform job
+    :type check_interval: int
+    :param max_ingestion_time: if wait is set to be true, the operator will fail
+        if the transform job hasn't finished within the max_ingestion_time in seconds
+        (Caution: be careful to set this parameter because transform can take very long)
+    :type max_ingestion_time: int
+    """
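An illustrative shape for the nested variant described above; every value is a placeholder, and the layout matches what `execute()` below unpacks:

```python
# Placeholder config for SageMakerTransformOperator using the nested
# 'Model'/'Transform' layout.
config = {
    'Model': {
        'ModelName': 'example-model',
        'PrimaryContainer': {'Image': '<ecr-image-uri>',
                             'ModelDataUrl': 's3://example-bucket/model.tar.gz'},
        'ExecutionRoleArn': 'example-role'  # expand_role() resolves bare names to ARNs
    },
    'Transform': {
        'TransformJobName': 'example-transform',
        'ModelName': 'example-model',
        'TransformInput': {'DataSource': {'S3DataSource': {
            'S3DataType': 'S3Prefix', 'S3Uri': 's3://example-bucket/input/'}}},
        'TransformOutput': {'S3OutputPath': 's3://example-bucket/output/'},
        'TransformResources': {'InstanceCount': 1, 'InstanceType': 'ml.m4.xlarge'}
    }
}
```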
+    :type aws_conn_id: str
+    :param wait_for_completion: if the operator should block until the transform job finishes
+    :type wait_for_completion: bool
+    :param check_interval: if wait_for_completion is True, the time interval,
+        in seconds, at which the operator checks the status of the transform job
+    :type check_interval: int
+    :param max_ingestion_time: if wait_for_completion is True, the operator fails
+        if the transform job has not finished within max_ingestion_time seconds.
+        (Caution: be careful when setting this parameter because a transform job can take very long.)
+    :type max_ingestion_time: int
+    """
+
+    @apply_defaults
+    def __init__(self,
+                 config,
+                 wait_for_completion=True,
+                 check_interval=30,
+                 max_ingestion_time=None,
+                 *args, **kwargs):
+        super(SageMakerTransformOperator, self).__init__(config=config,
+                                                         *args, **kwargs)
+        self.config = config
+        self.wait_for_completion = wait_for_completion
+        self.check_interval = check_interval
+        self.max_ingestion_time = max_ingestion_time
+        self.create_integer_fields()
+
+    def create_integer_fields(self):
+        self.integer_fields = [
+            ['Transform', 'TransformResources', 'InstanceCount'],
+            ['Transform', 'MaxConcurrentTransforms'],
+            ['Transform', 'MaxPayloadInMB']
+        ]
+        if 'Transform' not in self.config:
+            for field in self.integer_fields:
+                field.pop(0)
+
+    def expand_role(self):
+        if 'Model' not in self.config:
+            return
+        config = self.config['Model']
+        if 'ExecutionRoleArn' in config:
+            hook = AwsHook(self.aws_conn_id)
+            config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
+
+    def execute(self, context):
+        self.preprocess_config()
+
+        model_config = self.config.get('Model')
+        transform_config = self.config.get('Transform', self.config)
+
+        if model_config:
+            self.log.info('Creating SageMaker Model %s for transform job', model_config['ModelName'])
+            self.hook.create_model(model_config)
+
+        self.log.info('Creating SageMaker transform Job %s.', transform_config['TransformJobName'])
+        response = self.hook.create_transform_job(
+            transform_config,
+            wait_for_completion=self.wait_for_completion,
+            check_interval=self.check_interval,
+            max_ingestion_time=self.max_ingestion_time)
+        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+            raise AirflowException('Sagemaker transform Job creation failed: %s' % response)
+        else:
+            return {
+                'Model': self.hook.describe_model(
+                    transform_config['ModelName']
+                ),
+                'Transform': self.hook.describe_transform_job(
+                    transform_config['TransformJobName']
+                )
+            }
diff --git a/airflow/contrib/operators/sagemaker_tuning_operator.py b/airflow/contrib/operators/sagemaker_tuning_operator.py
new file mode 100644
index 0000000000000..94c995072a8da
--- /dev/null
+++ b/airflow/contrib/operators/sagemaker_tuning_operator.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator
+from airflow.utils.decorators import apply_defaults
+from airflow.exceptions import AirflowException
+
+
+class SageMakerTuningOperator(SageMakerBaseOperator):
+    """
+    Initiate a SageMaker hyper-parameter tuning job.
+
+    This operator returns the ARN of the tuning job created in Amazon SageMaker.
+
+    :param config: The configuration necessary to start a tuning job (templated)
+    :type config: dict
+    :param aws_conn_id: The AWS connection ID to use.
+    :type aws_conn_id: str
+    :param wait_for_completion: if the operator should block until the tuning job finishes
+    :type wait_for_completion: bool
+    :param check_interval: if wait_for_completion is True, the time interval,
+        in seconds, at which the operator checks the status of the tuning job
+    :type check_interval: int
+    :param max_ingestion_time: if wait_for_completion is True, the operator fails
+        if the tuning job has not finished within max_ingestion_time seconds.
+        (Caution: be careful when setting this parameter because tuning can take very long.)
+    :type max_ingestion_time: int
+    """
+
+    integer_fields = [
+        ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
+        ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
+        ['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
+        ['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
+        ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds']
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 config,
+                 wait_for_completion=True,
+                 check_interval=30,
+                 max_ingestion_time=None,
+                 *args, **kwargs):
+        super(SageMakerTuningOperator, self).__init__(config=config,
+                                                      *args, **kwargs)
+        self.config = config
+        self.wait_for_completion = wait_for_completion
+        self.check_interval = check_interval
+        self.max_ingestion_time = max_ingestion_time
+
+    def expand_role(self):
+        if 'TrainingJobDefinition' in self.config:
+            config = self.config['TrainingJobDefinition']
+            if 'RoleArn' in config:
+                hook = AwsHook(self.aws_conn_id)
+                config['RoleArn'] = hook.expand_role(config['RoleArn'])
+
+    def execute(self, context):
+        self.preprocess_config()
+
+        self.log.info(
+            'Creating SageMaker Hyper-Parameter Tuning Job %s', self.config['HyperParameterTuningJobName']
+        )
+
+        response = self.hook.create_tuning_job(
+            self.config,
+            wait_for_completion=self.wait_for_completion,
+            check_interval=self.check_interval,
+            max_ingestion_time=self.max_ingestion_time
+        )
+        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+            raise AirflowException(
+                'Sagemaker Tuning Job creation failed: %s' % response)
+        else:
+            return {
+                'Tuning': self.hook.describe_tuning_job(
+                    self.config['HyperParameterTuningJobName']
+                )
+            }
diff --git a/airflow/contrib/sensors/sagemaker_base_sensor.py b/airflow/contrib/sensors/sagemaker_base_sensor.py
index 149c2a1aab124..10dd6b2357a66 100644
--- a/airflow/contrib/sensors/sagemaker_base_sensor.py
+++ b/airflow/contrib/sensors/sagemaker_base_sensor.py
@@ -28,7 +28,7 @@ class SageMakerBaseSensor(BaseSensorOperator):
     and state_from_response() methods.
     Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE methods.
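Because the hunk below tightens the SageMakerBaseSensor contract (NotImplementedError instead of AirflowException), a hypothetical minimal subclass may help show what implementers owe the base class; none of this code ships with the patch:

from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor


class ExampleSageMakerJobSensor(SageMakerBaseSensor):
    """Illustrative subclass supplying the four required methods."""

    def __init__(self, job_name, *args, **kwargs):
        super(ExampleSageMakerJobSensor, self).__init__(*args, **kwargs)
        self.job_name = job_name

    def non_terminal_states(self):
        return SageMakerHook.non_terminal_states

    def failed_states(self):
        return SageMakerHook.failed_states

    def get_sagemaker_response(self):
        hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
        return hook.describe_training_job(self.job_name)

    def state_from_response(self, response):
        return response['TrainingJobStatus']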
""" - ui_color = '#66c3ff' + ui_color = '#ededed' @apply_defaults def __init__( @@ -54,23 +54,21 @@ def poke(self, context): if state in self.failed_states(): failed_reason = self.get_failed_reason_from_response(response) - raise AirflowException("Sagemaker job failed for the following reason: %s" + raise AirflowException('Sagemaker job failed for the following reason: %s' % failed_reason) return True def non_terminal_states(self): - raise AirflowException("Non Terminal States need to be specified in subclass") + raise NotImplementedError('Please implement non_terminal_states() in subclass') def failed_states(self): - raise AirflowException("Failed States need to be specified in subclass") + raise NotImplementedError('Please implement failed_states() in subclass') def get_sagemaker_response(self): - raise AirflowException( - "Method get_sagemaker_response()not implemented.") + raise NotImplementedError('Please implement get_sagemaker_response() in subclass') def get_failed_reason_from_response(self, response): return 'Unknown' def state_from_response(self, response): - raise AirflowException( - "Method state_from_response()not implemented.") + raise NotImplementedError('Please implement state_from_response() in subclass') diff --git a/airflow/contrib/sensors/sagemaker_training_sensor.py b/airflow/contrib/sensors/sagemaker_training_sensor.py index 5e83e846efec5..2d820111a08c8 100644 --- a/airflow/contrib/sensors/sagemaker_training_sensor.py +++ b/airflow/contrib/sensors/sagemaker_training_sensor.py @@ -17,7 +17,9 @@ # specific language governing permissions and limitations # under the License. -from airflow.contrib.hooks.sagemaker_hook import SageMakerHook +import time + +from airflow.contrib.hooks.sagemaker_hook import SageMakerHook, LogState from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor from airflow.utils.decorators import apply_defaults @@ -27,8 +29,10 @@ class SageMakerTrainingSensor(SageMakerBaseSensor): Asks for the state of the training state until it reaches a terminal state. If it fails the sensor errors, failing the task. 
-    :param job_name: job_name of the training instance to check the state of
+    :param job_name: name of the SageMaker training job to check the state of
     :type job_name: str
+    :param print_log: if the operator should print the CloudWatch log
+    :type print_log: bool
     """
 
     template_fields = ['job_name']
@@ -37,12 +41,30 @@ class SageMakerTrainingSensor(SageMakerBaseSensor):
     @apply_defaults
     def __init__(self,
                  job_name,
-                 region_name=None,
+                 print_log=True,
                  *args, **kwargs):
         super(SageMakerTrainingSensor, self).__init__(*args, **kwargs)
         self.job_name = job_name
-        self.region_name = region_name
+        self.print_log = print_log
+        self.positions = {}
+        self.stream_names = []
+        self.instance_count = None
+        self.state = None
+        self.last_description = None
+        self.last_describe_job_call = None
+        self.log_resource_inited = False
+
+    def init_log_resource(self, hook):
+        description = hook.describe_training_job(self.job_name)
+        self.instance_count = description['ResourceConfig']['InstanceCount']
+
+        status = description['TrainingJobStatus']
+        job_already_completed = status not in self.non_terminal_states()
+        self.state = LogState.TAILING if not job_already_completed else LogState.COMPLETE
+        self.last_description = description
+        self.last_describe_job_call = time.time()
+        self.log_resource_inited = True
 
     def non_terminal_states(self):
         return SageMakerHook.non_terminal_states
@@ -51,13 +73,27 @@ def failed_states(self):
         return SageMakerHook.failed_states
 
     def get_sagemaker_response(self):
-        sagemaker = SageMakerHook(
-            aws_conn_id=self.aws_conn_id,
-            region_name=self.region_name
-        )
+        sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
+        if self.print_log:
+            if not self.log_resource_inited:
+                self.init_log_resource(sagemaker_hook)
+            self.state, self.last_description, self.last_describe_job_call = \
+                sagemaker_hook.describe_training_job_with_log(self.job_name,
+                                                              self.positions, self.stream_names,
+                                                              self.instance_count, self.state,
+                                                              self.last_description,
+                                                              self.last_describe_job_call)
+        else:
+            self.last_description = sagemaker_hook.describe_training_job(self.job_name)
+
+        status = self.state_from_response(self.last_description)
+        if status not in self.non_terminal_states() and status not in self.failed_states():
+            billable_time = \
+                (self.last_description['TrainingEndTime'] - self.last_description['TrainingStartTime']) * \
+                self.last_description['ResourceConfig']['InstanceCount']
+            self.log.info('Billable seconds: {}'.format(int(billable_time.total_seconds()) + 1))
 
-        self.log.info('Poking Sagemaker Training Job %s', self.job_name)
-        return sagemaker.describe_training_job(self.job_name)
+        return self.last_description
 
     def get_failed_reason_from_response(self, response):
         return response['FailureReason']
diff --git a/airflow/contrib/sensors/sagemaker_transform_sensor.py b/airflow/contrib/sensors/sagemaker_transform_sensor.py
index 68ef1d8dd7b05..f64724bde9b24 100644
--- a/airflow/contrib/sensors/sagemaker_transform_sensor.py
+++ b/airflow/contrib/sensors/sagemaker_transform_sensor.py
@@ -30,8 +30,6 @@ class SageMakerTransformSensor(SageMakerBaseSensor):
     :param job_name: job_name of the transform job instance to check the state of
     :type job_name: string
-    :param region_name: The AWS region_name
-    :type region_name: string
     """
 
     template_fields = ['job_name']
@@ -40,12 +38,10 @@
     @apply_defaults
     def __init__(self,
                  job_name,
-                 region_name=None,
                  *args, **kwargs):
         super(SageMakerTransformSensor, self).__init__(*args, **kwargs)
         self.job_name = job_name
-        self.region_name = region_name
 
     def non_terminal_states(self):
         return SageMakerHook.non_terminal_states
@@ -54,10 +50,7 @@ def failed_states(self):
         return SageMakerHook.failed_states
 
     def get_sagemaker_response(self):
-        sagemaker = SageMakerHook(
-            aws_conn_id=self.aws_conn_id,
-            region_name=self.region_name
-        )
+        sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
         self.log.info('Poking Sagemaker Transform Job %s', self.job_name)
         return sagemaker.describe_transform_job(self.job_name)
diff --git a/airflow/contrib/sensors/sagemaker_tuning_sensor.py b/airflow/contrib/sensors/sagemaker_tuning_sensor.py
index ec90964266a11..8c835216d6bb3 100644
--- a/airflow/contrib/sensors/sagemaker_tuning_sensor.py
+++ b/airflow/contrib/sensors/sagemaker_tuning_sensor.py
@@ -30,8 +30,6 @@ class SageMakerTuningSensor(SageMakerBaseSensor):
     :param job_name: job_name of the tuning instance to check the state of
     :type job_name: str
-    :param region_name: The AWS region_name
-    :type region_name: str
     """
 
     template_fields = ['job_name']
@@ -40,12 +38,10 @@
     @apply_defaults
     def __init__(self,
                  job_name,
-                 region_name=None,
                  *args, **kwargs):
         super(SageMakerTuningSensor, self).__init__(*args, **kwargs)
         self.job_name = job_name
-        self.region_name = region_name
 
     def non_terminal_states(self):
         return SageMakerHook.non_terminal_states
@@ -54,10 +50,7 @@ def failed_states(self):
         return SageMakerHook.failed_states
 
     def get_sagemaker_response(self):
-        sagemaker = SageMakerHook(
-            aws_conn_id=self.aws_conn_id,
-            region_name=self.region_name
-        )
+        sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
         self.log.info('Poking Sagemaker Tuning Job %s', self.job_name)
         return sagemaker.describe_tuning_job(self.job_name)
diff --git a/airflow/hooks/S3_hook.py b/airflow/hooks/S3_hook.py
index 35d26fd2b50fc..c9e81e9d7330e 100644
--- a/airflow/hooks/S3_hook.py
+++ b/airflow/hooks/S3_hook.py
@@ -69,6 +69,26 @@ def get_bucket(self, bucket_name):
         s3 = self.get_resource_type('s3')
         return s3.Bucket(bucket_name)
 
+    def create_bucket(self, bucket_name, region_name=None):
+        """
+        Creates an Amazon S3 bucket
+
+        :param bucket_name: the name of the bucket
+        :type bucket_name: str
+        :param region_name: the name of the AWS region
+        :type region_name: str
+        """
+        s3_conn = self.get_conn()
+        if not region_name:
+            region_name = s3_conn.meta.region_name
+        if region_name == 'us-east-1':
+            s3_conn.create_bucket(Bucket=bucket_name)
+        else:
+            s3_conn.create_bucket(Bucket=bucket_name,
+                                  CreateBucketConfiguration={
+                                      'LocationConstraint': region_name
+                                  })
+
     def check_for_prefix(self, bucket_name, prefix, delimiter):
         """
         Checks that a prefix exists in a bucket
@@ -401,6 +421,41 @@ def load_bytes(self,
         client = self.get_conn()
         client.upload_fileobj(filelike_buffer, bucket_name, key, ExtraArgs=extra_args)
 
+    def load_file_obj(self,
+                      file_obj,
+                      key,
+                      bucket_name=None,
+                      replace=False,
+                      encrypt=False):
+        """
+        Loads a file object to S3
+
+        :param file_obj: file-like object to set as content for the key.
+        :type file_obj: file-like object
+        :param key: S3 key that will point to the file
+        :type key: str
+        :param bucket_name: Name of the bucket in which to store the file
+        :type bucket_name: str
+        :param replace: A flag to decide whether or not to overwrite the key
+            if it already exists
+        :type replace: bool
+        :param encrypt: If True, the file will be encrypted on the server-side
+            by S3 and will be stored in an encrypted form while at rest in S3.
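Two reviewer notes on the S3Hook additions above. First, the us-east-1 special case in create_bucket exists because S3 rejects a CreateBucketConfiguration whose LocationConstraint is 'us-east-1'; the configuration must simply be omitted in that region. Second, load_file_obj streams any file-like object without staging a file on disk; an invented example:

import io

from airflow.hooks.S3_hook import S3Hook

hook = S3Hook()  # default aws_default connection
buf = io.BytesIO(b'example payload')
hook.load_file_obj(buf,
                   key='test/data/payload.bin',
                   bucket_name='example-bucket',
                   replace=True,  # overwrite instead of raising ValueError
                   encrypt=True)  # request AES256 server-side encryption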
+        :type encrypt: bool
+        """
+        if not bucket_name:
+            (bucket_name, key) = self.parse_s3_url(key)
+
+        if not replace and self.check_for_key(key, bucket_name):
+            raise ValueError("The key {key} already exists.".format(key=key))
+
+        extra_args = {}
+        if encrypt:
+            extra_args['ServerSideEncryption'] = "AES256"
+
+        client = self.get_conn()
+        client.upload_fileobj(file_obj, bucket_name, key, ExtraArgs=extra_args)
+
     def copy_object(self,
                     source_bucket_key,
                     dest_bucket_key,
diff --git a/docs/code.rst b/docs/code.rst
index 5b4a494911726..817c5046ea29d 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -189,6 +189,10 @@ Operators
 .. autoclass:: airflow.contrib.operators.s3_delete_objects_operator.S3DeleteObjectsOperator
 .. autoclass:: airflow.contrib.operators.s3_list_operator.S3ListOperator
 .. autoclass:: airflow.contrib.operators.s3_to_gcs_operator.S3ToGoogleCloudStorageOperator
+.. autoclass:: airflow.contrib.operators.sagemaker_base_operator.SageMakerBaseOperator
+.. autoclass:: airflow.contrib.operators.sagemaker_training_operator.SageMakerTrainingOperator
+.. autoclass:: airflow.contrib.operators.sagemaker_transform_operator.SageMakerTransformOperator
+.. autoclass:: airflow.contrib.operators.sagemaker_tuning_operator.SageMakerTuningOperator
 .. autoclass:: airflow.contrib.operators.segment_track_event_operator.SegmentTrackEventOperator
 .. autoclass:: airflow.contrib.operators.sftp_operator.SFTPOperator
 .. autoclass:: airflow.contrib.operators.slack_webhook_operator.SlackWebhookOperator
@@ -226,6 +230,10 @@ Sensors
 .. autoclass:: airflow.contrib.sensors.pubsub_sensor.PubSubPullSensor
 .. autoclass:: airflow.contrib.sensors.qubole_sensor.QuboleSensor
 .. autoclass:: airflow.contrib.sensors.redis_key_sensor.RedisKeySensor
+.. autoclass:: airflow.contrib.sensors.sagemaker_base_sensor.SageMakerBaseSensor
+.. autoclass:: airflow.contrib.sensors.sagemaker_training_sensor.SageMakerTrainingSensor
+.. autoclass:: airflow.contrib.sensors.sagemaker_transform_sensor.SageMakerTransformSensor
+.. autoclass:: airflow.contrib.sensors.sagemaker_tuning_sensor.SageMakerTuningSensor
 .. autoclass:: airflow.contrib.sensors.sftp_sensor.SFTPSensor
 .. autoclass:: airflow.contrib.sensors.wasb_sensor.WasbBlobSensor
@@ -407,6 +415,7 @@ Community contributed hooks
 .. autoclass:: airflow.contrib.hooks.qubole_hook.QuboleHook
 .. autoclass:: airflow.contrib.hooks.redis_hook.RedisHook
 .. autoclass:: airflow.contrib.hooks.redshift_hook.RedshiftHook
+.. autoclass:: airflow.contrib.hooks.sagemaker_hook.SageMakerHook
 .. autoclass:: airflow.contrib.hooks.salesforce_hook.SalesforceHook
 .. autoclass:: airflow.contrib.hooks.segment_hook.SegmentHook
 ..
autoclass:: airflow.contrib.hooks.sftp_hook.SFTPHook diff --git a/tests/contrib/hooks/test_aws_hook.py b/tests/contrib/hooks/test_aws_hook.py index eaadc5fbff413..addee85109194 100644 --- a/tests/contrib/hooks/test_aws_hook.py +++ b/tests/contrib/hooks/test_aws_hook.py @@ -35,11 +35,12 @@ mock = None try: - from moto import mock_emr, mock_dynamodb2, mock_sts + from moto import mock_emr, mock_dynamodb2, mock_sts, mock_iam except ImportError: mock_emr = None mock_dynamodb2 = None mock_sts = None + mock_iam = None class TestAwsHook(unittest.TestCase): @@ -204,6 +205,16 @@ def test_get_credentials_from_role_arn_with_external_id(self, mock_get_connectio 'gRmpRV3zrkuWJOgQs8IZZaIv2BXIa2R4OlgkBN9bkUDNCJiBeb/AXlzBBko7b15' 'fjrBs2+cTQtpZ3CYWFXG8C5zqx37wnOE49mRl/+OtkIKGO7fAE') + @unittest.skipIf(mock_iam is None, 'mock_iam package not present') + @mock_iam + def test_expand_role(self): + conn = boto3.client('iam', region_name='us-east-1') + conn.create_role(RoleName='test-role', AssumeRolePolicyDocument='some policy') + hook = AwsHook() + arn = hook.expand_role('test-role') + expect_arn = conn.get_role(RoleName='test-role').get('Role').get('Arn') + self.assertEqual(arn, expect_arn) + if __name__ == '__main__': unittest.main() diff --git a/tests/contrib/hooks/test_sagemaker_hook.py b/tests/contrib/hooks/test_sagemaker_hook.py index 3a863b3cb0dc7..bec00bf601a8f 100644 --- a/tests/contrib/hooks/test_sagemaker_hook.py +++ b/tests/contrib/hooks/test_sagemaker_hook.py @@ -18,10 +18,11 @@ # under the License. # - -import json import unittest -import copy +import time +from datetime import datetime +from tzlocal import get_localzone + try: from unittest import mock except ImportError: @@ -31,200 +32,221 @@ mock = None from airflow import configuration -from airflow import models -from airflow.utils import db -from airflow.contrib.hooks.sagemaker_hook import SageMakerHook +from airflow.contrib.hooks.sagemaker_hook import (SageMakerHook, secondary_training_status_changed, + secondary_training_status_message, LogState) from airflow.hooks.S3_hook import S3Hook from airflow.exceptions import AirflowException -role = 'test-role' +role = 'arn:aws:iam:role/test-role' +path = 'local/data' bucket = 'test-bucket' - key = 'test/data' data_url = 's3://{}/{}'.format(bucket, key) -job_name = 'test-job-name' - -model_name = 'test-model-name' +job_name = 'test-job' +model_name = 'test-model' +config_name = 'test-endpoint-config' +endpoint_name = 'test-endpoint' image = 'test-image' +test_arn_return = {'Arn': 'testarn'} +output_url = 's3://{}/test/output'.format(bucket) -test_arn_return = {'TrainingJobArn': 'testarn'} - -test_list_training_job_return = { - 'TrainingJobSummaries': [ - { - 'TrainingJobName': job_name, - 'TrainingJobStatus': 'InProgress' - }, - ], - 'NextToken': 'test-token' -} - -test_list_tuning_job_return = { - 'TrainingJobSummaries': [ +create_training_params = { + 'AlgorithmSpecification': { + 'TrainingImage': image, + 'TrainingInputMode': 'File' + }, + 'RoleArn': role, + 'OutputDataConfig': { + 'S3OutputPath': output_url + }, + 'ResourceConfig': { + 'InstanceCount': 2, + 'InstanceType': 'ml.c4.8xlarge', + 'VolumeSizeInGB': 50 + }, + 'TrainingJobName': job_name, + 'HyperParameters': { + 'k': '10', + 'feature_dim': '784', + 'mini_batch_size': '500', + 'force_dense': 'True' + }, + 'StoppingCondition': { + 'MaxRuntimeInSeconds': 60 * 60 + }, + 'InputDataConfig': [ { - 'TrainingJobName': job_name, - 'TrainingJobArn': 'testarn', - 'TunedHyperParameters': { - 'k': '3' + 'ChannelName': 'train', + 'DataSource': { + 
'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': data_url, + 'S3DataDistributionType': 'FullyReplicated' + } }, - 'TrainingJobStatus': 'InProgress' - }, - ], - 'NextToken': 'test-token' + 'CompressionType': 'None', + 'RecordWrapperType': 'None' + } + ] } -output_url = 's3://{}/test/output'.format(bucket) -create_training_params = \ - { - 'AlgorithmSpecification': { - 'TrainingImage': image, - 'TrainingInputMode': 'File' +create_tuning_params = { + 'HyperParameterTuningJobName': job_name, + 'HyperParameterTuningJobConfig': { + 'Strategy': 'Bayesian', + 'HyperParameterTuningJobObjective': { + 'Type': 'Maximize', + 'MetricName': 'test_metric' }, - 'RoleArn': role, - 'OutputDataConfig': { - 'S3OutputPath': output_url + 'ResourceLimits': { + 'MaxNumberOfTrainingJobs': 123, + 'MaxParallelTrainingJobs': 123 }, - 'ResourceConfig': { - 'InstanceCount': 2, - 'InstanceType': 'ml.c4.8xlarge', - 'VolumeSizeInGB': 50 - }, - 'TrainingJobName': job_name, - 'HyperParameters': { - 'k': '10', - 'feature_dim': '784', - 'mini_batch_size': '500', - 'force_dense': 'True' - }, - 'StoppingCondition': { - 'MaxRuntimeInSeconds': 60 * 60 - }, - 'InputDataConfig': [ - { - 'ChannelName': 'train', - 'DataSource': { - 'S3DataSource': { - 'S3DataType': 'S3Prefix', - 'S3Uri': data_url, - 'S3DataDistributionType': 'FullyReplicated' - } + 'ParameterRanges': { + 'IntegerParameterRanges': [ + { + 'Name': 'k', + 'MinValue': '2', + 'MaxValue': '10' }, - 'CompressionType': 'None', - 'RecordWrapperType': 'None' - } - ] - } -create_tuning_params = \ - { - 'HyperParameterTuningJobName': job_name, - 'HyperParameterTuningJobConfig': { - 'Strategy': 'Bayesian', - 'HyperParameterTuningJobObjective': { - 'Type': 'Maximize', - 'MetricName': 'test_metric' - }, - 'ResourceLimits': { - 'MaxNumberOfTrainingJobs': 123, - 'MaxParallelTrainingJobs': 123 - }, - 'ParameterRanges': { - 'IntegerParameterRanges': [ - { - 'Name': 'k', - 'MinValue': '2', - 'MaxValue': '10' - }, - - ] - } - }, - 'TrainingJobDefinition': { - 'StaticHyperParameters': create_training_params['HyperParameters'], - 'AlgorithmSpecification': create_training_params['AlgorithmSpecification'], - 'RoleArn': 'string', - 'InputDataConfig': create_training_params['InputDataConfig'], - 'OutputDataConfig': create_training_params['OutputDataConfig'], - 'ResourceConfig': create_training_params['ResourceConfig'], - 'StoppingCondition': dict(MaxRuntimeInSeconds=60 * 60) + ] } + }, + 'TrainingJobDefinition': { + 'StaticHyperParameters': create_training_params['HyperParameters'], + 'AlgorithmSpecification': create_training_params['AlgorithmSpecification'], + 'RoleArn': 'string', + 'InputDataConfig': create_training_params['InputDataConfig'], + 'OutputDataConfig': create_training_params['OutputDataConfig'], + 'ResourceConfig': create_training_params['ResourceConfig'], + 'StoppingCondition': dict(MaxRuntimeInSeconds=60 * 60) } +} -create_transform_params = \ - { - 'TransformJobName': job_name, - 'ModelName': model_name, - 'BatchStrategy': 'MultiRecord', - 'TransformInput': { - 'DataSource': { - 'S3DataSource': { - 'S3DataType': 'S3Prefix', - 'S3Uri': data_url - } +create_transform_params = { + 'TransformJobName': job_name, + 'ModelName': model_name, + 'BatchStrategy': 'MultiRecord', + 'TransformInput': { + 'DataSource': { + 'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': data_url } - }, - 'TransformOutput': { - 'S3OutputPath': output_url, - }, - 'TransformResources': { - 'InstanceType': 'ml.m4.xlarge', - 'InstanceCount': 123 } + }, + 'TransformOutput': { + 'S3OutputPath': 
output_url, + }, + 'TransformResources': { + 'InstanceType': 'ml.m4.xlarge', + 'InstanceCount': 123 } +} -create_model_params = \ - { - 'ModelName': model_name, - 'PrimaryContainer': { - 'Image': image, - 'ModelDataUrl': output_url, - }, - 'ExecutionRoleArn': role - } +create_model_params = { + 'ModelName': model_name, + 'PrimaryContainer': { + 'Image': image, + 'ModelDataUrl': output_url, + }, + 'ExecutionRoleArn': role +} -db_config = { - 'Tags': [ +create_endpoint_config_params = { + 'EndpointConfigName': config_name, + 'ProductionVariants': [ { - 'Key': 'test-db-key', - 'Value': 'test-db-value', - - }, + 'VariantName': 'AllTraffic', + 'ModelName': model_name, + 'InitialInstanceCount': 1, + 'InstanceType': 'ml.c4.xlarge' + } ] } -DESCRIBE_TRAINING_INPROGRESS_RETURN = { - 'TrainingJobStatus': 'InProgress', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - } +create_endpoint_params = { + 'EndpointName': endpoint_name, + 'EndpointConfigName': config_name } + +update_endpoint_params = create_endpoint_params + DESCRIBE_TRAINING_COMPELETED_RETURN = { - 'TrainingJobStatus': 'Compeleted', + 'TrainingJobStatus': 'Completed', + 'ResourceConfig': { + 'InstanceCount': 1, + 'InstanceType': 'ml.c4.xlarge', + 'VolumeSizeInGB': 10 + }, + 'TrainingStartTime': datetime(2018, 2, 17, 7, 15, 0, 103000), + 'TrainingEndTime': datetime(2018, 2, 17, 7, 19, 34, 953000), 'ResponseMetadata': { 'HTTPStatusCode': 200, } } -DESCRIBE_TRAINING_FAILED_RETURN = { - 'TrainingJobStatus': 'Failed', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - }, - 'FailureReason': 'Unknown' + +DESCRIBE_TRAINING_INPROGRESS_RETURN = dict(DESCRIBE_TRAINING_COMPELETED_RETURN) +DESCRIBE_TRAINING_INPROGRESS_RETURN.update({'TrainingJobStatus': 'InProgress'}) + +DESCRIBE_TRAINING_FAILED_RETURN = dict(DESCRIBE_TRAINING_COMPELETED_RETURN) +DESCRIBE_TRAINING_FAILED_RETURN.update({'TrainingJobStatus': 'Failed', + 'FailureReason': 'Unknown'}) + +DESCRIBE_TRAINING_STOPPING_RETURN = dict(DESCRIBE_TRAINING_COMPELETED_RETURN) +DESCRIBE_TRAINING_STOPPING_RETURN.update({'TrainingJobStatus': 'Stopping'}) + +message = 'message' +status = 'status' +SECONDARY_STATUS_DESCRIPTION_1 = { + 'SecondaryStatusTransitions': [{'StatusMessage': message, 'Status': status}] } -DESCRIBE_TRAINING_STOPPING_RETURN = { - 'TrainingJobStatus': 'Stopping', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - } +SECONDARY_STATUS_DESCRIPTION_2 = { + 'SecondaryStatusTransitions': [{'StatusMessage': 'different message', 'Status': status}] } -DESCRIBE_TRAINING_STOPPED_RETURN = { - 'TrainingJobStatus': 'Stopped', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, + +DEFAULT_LOG_STREAMS = {'logStreams': [{'logStreamName': job_name + '/xxxxxxxxx'}]} +LIFECYCLE_LOG_STREAMS = [DEFAULT_LOG_STREAMS, + DEFAULT_LOG_STREAMS, + DEFAULT_LOG_STREAMS, + DEFAULT_LOG_STREAMS, + DEFAULT_LOG_STREAMS, + DEFAULT_LOG_STREAMS] + +DEFAULT_LOG_EVENTS = [{'nextForwardToken': None, 'events': [{'timestamp': 1, 'message': 'hi there #1'}]}, + {'nextForwardToken': None, 'events': []}] +STREAM_LOG_EVENTS = [{'nextForwardToken': None, 'events': [{'timestamp': 1, 'message': 'hi there #1'}]}, + {'nextForwardToken': None, 'events': []}, + {'nextForwardToken': None, 'events': [{'timestamp': 1, 'message': 'hi there #1'}, + {'timestamp': 2, 'message': 'hi there #2'}]}, + {'nextForwardToken': None, 'events': []}, + {'nextForwardToken': None, 'events': [{'timestamp': 2, 'message': 'hi there #2'}, + {'timestamp': 2, 'message': 'hi there #2a'}, + {'timestamp': 3, 'message': 'hi there #3'}]}, + {'nextForwardToken': None, 
'events': []}] + +test_evaluation_config = { + 'Image': image, + 'Role': role, + 'S3Operations': { + 'S3CreateBucket': [ + { + 'Bucket': bucket + } + ], + 'S3Upload': [ + { + 'Path': path, + 'Bucket': bucket, + 'Key': key, + 'Tar': False + } + ] } } @@ -233,94 +255,63 @@ class TestSageMakerHook(unittest.TestCase): def setUp(self): configuration.load_test_config() - db.merge_conn( - models.Connection( - conn_id='sagemaker_test_conn_id', - conn_type='sagemaker', - login='access_id', - password='access_key', - extra=json.dumps(db_config) - ) - ) + + @mock.patch.object(S3Hook, 'create_bucket') + @mock.patch.object(S3Hook, 'load_file') + def test_configure_s3_resources(self, mock_load_file, mock_create_bucket): + hook = SageMakerHook() + evaluation_result = { + 'Image': image, + 'Role': role + } + hook.configure_s3_resources(test_evaluation_config) + self.assertEqual(test_evaluation_config, evaluation_result) + mock_create_bucket.assert_called_once_with(bucket_name=bucket) + mock_load_file.assert_called_once_with(path, key, bucket) @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(S3Hook, 'check_for_key') @mock.patch.object(S3Hook, 'check_for_bucket') @mock.patch.object(S3Hook, 'check_for_prefix') - def test_check_for_url(self, - mock_check_prefix, - mock_check_bucket, - mock_check_key, - mock_client): + def test_check_s3_url(self, + mock_check_prefix, + mock_check_bucket, + mock_check_key, + mock_client): mock_client.return_value = None hook = SageMakerHook() mock_check_bucket.side_effect = [False, True, True, True] mock_check_key.side_effect = [False, True, False] mock_check_prefix.side_effect = [False, True, True] self.assertRaises(AirflowException, - hook.check_for_url, data_url) + hook.check_s3_url, data_url) self.assertRaises(AirflowException, - hook.check_for_url, data_url) - self.assertEqual(hook.check_for_url(data_url), True) - self.assertEqual(hook.check_for_url(data_url), True) + hook.check_s3_url, data_url) + self.assertEqual(hook.check_s3_url(data_url), True) + self.assertEqual(hook.check_s3_url(data_url), True) @mock.patch.object(SageMakerHook, 'get_conn') - @mock.patch.object(SageMakerHook, 'check_for_url') + @mock.patch.object(SageMakerHook, 'check_s3_url') def test_check_valid_training(self, mock_check_url, mock_client): mock_client.return_value = None hook = SageMakerHook() - hook.check_valid_training_input(create_training_params) + hook.check_training_config(create_training_params) mock_check_url.assert_called_once_with(data_url) @mock.patch.object(SageMakerHook, 'get_conn') - @mock.patch.object(SageMakerHook, 'check_for_url') + @mock.patch.object(SageMakerHook, 'check_s3_url') def test_check_valid_tuning(self, mock_check_url, mock_client): mock_client.return_value = None hook = SageMakerHook() - hook.check_valid_tuning_input(create_tuning_params) + hook.check_tuning_config(create_tuning_params) mock_check_url.assert_called_once_with(data_url) @mock.patch.object(SageMakerHook, 'get_client_type') def test_conn(self, mock_get_client): - hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id', - region_name='us-east-1' - ) - self.assertEqual(hook.sagemaker_conn_id, 'sagemaker_test_conn_id') - mock_get_client.assert_called_once_with('sagemaker', - region_name='us-east-1' - ) + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') + self.assertEqual(hook.aws_conn_id, 'sagemaker_test_conn_id') - @mock.patch.object(SageMakerHook, 'get_conn') - def test_list_training_job(self, mock_client): - mock_session = mock.Mock() - attrs = 
{'list_training_jobs.return_value': - test_list_training_job_return} - mock_session.configure_mock(**attrs) - mock_client.return_value = mock_session - hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') - response = hook.list_training_job(name_contains=job_name, - status_equals='InProgress') - mock_session.list_training_jobs. \ - assert_called_once_with(NameContains=job_name, - StatusEquals='InProgress') - self.assertEqual(response, test_list_training_job_return) - - @mock.patch.object(SageMakerHook, 'get_conn') - def test_list_tuning_job(self, mock_client): - mock_session = mock.Mock() - attrs = {'list_hyper_parameter_tuning_job.return_value': - test_list_tuning_job_return} - mock_session.configure_mock(**attrs) - mock_client.return_value = mock_session - hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') - response = hook.list_tuning_job(name_contains=job_name, - status_equals='InProgress') - mock_session.list_hyper_parameter_tuning_job. \ - assert_called_once_with(NameContains=job_name, - StatusEquals='InProgress') - self.assertEqual(response, test_list_tuning_job_return) - - @mock.patch.object(SageMakerHook, 'check_valid_training_input') + @mock.patch.object(SageMakerHook, 'check_training_config') @mock.patch.object(SageMakerHook, 'get_conn') def test_create_training_job(self, mock_client, mock_check_training): mock_check_training.return_value = True @@ -329,71 +320,56 @@ def test_create_training_job(self, mock_client, mock_check_training): test_arn_return} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session - hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') response = hook.create_training_job(create_training_params, - wait_for_completion=False) + wait_for_completion=False, + print_log=False) mock_session.create_training_job.assert_called_once_with(**create_training_params) self.assertEqual(response, test_arn_return) - @mock.patch.object(SageMakerHook, 'check_valid_training_input') - @mock.patch.object(SageMakerHook, 'get_conn') - def test_create_training_job_db_config(self, mock_client, mock_check_training): - mock_check_training.return_value = True - mock_session = mock.Mock() - attrs = {'create_training_job.return_value': - test_arn_return} - mock_session.configure_mock(**attrs) - mock_client.return_value = mock_session - hook_use_db_config = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id', - use_db_config=True) - response = hook_use_db_config.create_training_job(create_training_params, - wait_for_completion=False) - updated_config = copy.deepcopy(create_training_params) - updated_config.update(db_config) - mock_session.create_training_job.assert_called_once_with(**updated_config) - self.assertEqual(response, test_arn_return) - - @mock.patch.object(SageMakerHook, 'check_valid_training_input') + @mock.patch.object(SageMakerHook, 'check_training_config') @mock.patch.object(SageMakerHook, 'get_conn') - def test_training_ends_with_wait_on(self, mock_client, mock_check_training): + def test_training_ends_with_wait(self, mock_client, mock_check_training): mock_check_training.return_value = True mock_session = mock.Mock() attrs = {'create_training_job.return_value': test_arn_return, 'describe_training_job.side_effect': - [DESCRIBE_TRAINING_INPROGRESS_RETURN, - DESCRIBE_TRAINING_STOPPING_RETURN, - DESCRIBE_TRAINING_STOPPED_RETURN, - DESCRIBE_TRAINING_COMPELETED_RETURN] + [DESCRIBE_TRAINING_INPROGRESS_RETURN, + DESCRIBE_TRAINING_STOPPING_RETURN, + 
DESCRIBE_TRAINING_COMPELETED_RETURN, + DESCRIBE_TRAINING_COMPELETED_RETURN] } mock_session.configure_mock(**attrs) mock_client.return_value = mock_session - hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id_1') - hook.create_training_job(create_training_params, wait_for_completion=True) + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1') + hook.create_training_job(create_training_params, wait_for_completion=True, + print_log=False, check_interval=1) self.assertEqual(mock_session.describe_training_job.call_count, 4) - @mock.patch.object(SageMakerHook, 'check_valid_training_input') + @mock.patch.object(SageMakerHook, 'check_training_config') @mock.patch.object(SageMakerHook, 'get_conn') - def test_training_throws_error_when_failed_with_wait_on( + def test_training_throws_error_when_failed_with_wait( self, mock_client, mock_check_training): mock_check_training.return_value = True mock_session = mock.Mock() attrs = {'create_training_job.return_value': test_arn_return, 'describe_training_job.side_effect': - [DESCRIBE_TRAINING_INPROGRESS_RETURN, - DESCRIBE_TRAINING_STOPPING_RETURN, - DESCRIBE_TRAINING_STOPPED_RETURN, - DESCRIBE_TRAINING_FAILED_RETURN] + [DESCRIBE_TRAINING_INPROGRESS_RETURN, + DESCRIBE_TRAINING_STOPPING_RETURN, + DESCRIBE_TRAINING_FAILED_RETURN, + DESCRIBE_TRAINING_COMPELETED_RETURN] } mock_session.configure_mock(**attrs) mock_client.return_value = mock_session - hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id_1') + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1') self.assertRaises(AirflowException, hook.create_training_job, - create_training_params, wait_for_completion=True) - self.assertEqual(mock_session.describe_training_job.call_count, 4) + create_training_params, wait_for_completion=True, + print_log=False, check_interval=1) + self.assertEqual(mock_session.describe_training_job.call_count, 3) - @mock.patch.object(SageMakerHook, 'check_valid_tuning_input') + @mock.patch.object(SageMakerHook, 'check_tuning_config') @mock.patch.object(SageMakerHook, 'get_conn') def test_create_tuning_job(self, mock_client, mock_check_tuning): mock_session = mock.Mock() @@ -401,33 +377,14 @@ def test_create_tuning_job(self, mock_client, mock_check_tuning): test_arn_return} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session - hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') response = hook.create_tuning_job(create_tuning_params, wait_for_completion=False) mock_session.create_hyper_parameter_tuning_job.\ assert_called_once_with(**create_tuning_params) self.assertEqual(response, test_arn_return) - @mock.patch.object(SageMakerHook, 'check_valid_tuning_input') - @mock.patch.object(SageMakerHook, 'get_conn') - def test_create_tuning_job_db_config(self, mock_client, mock_check_tuning): - mock_check_tuning.return_value = True - mock_session = mock.Mock() - attrs = {'create_hyper_parameter_tuning_job.return_value': - test_arn_return} - mock_session.configure_mock(**attrs) - mock_client.return_value = mock_session - hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id', - use_db_config=True) - response = hook.create_tuning_job(create_tuning_params, - wait_for_completion=False) - updated_config = copy.deepcopy(create_tuning_params) - updated_config.update(db_config) - mock_session.create_hyper_parameter_tuning_job. 
\ - assert_called_once_with(**updated_config) - self.assertEqual(response, test_arn_return) - - @mock.patch.object(SageMakerHook, 'check_for_url') + @mock.patch.object(SageMakerHook, 'check_s3_url') @mock.patch.object(SageMakerHook, 'get_conn') def test_create_transform_job(self, mock_client, mock_check_url): mock_check_url.return_value = True @@ -436,41 +393,64 @@ def test_create_transform_job(self, mock_client, mock_check_url): test_arn_return} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session - hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') response = hook.create_transform_job(create_transform_params, wait_for_completion=False) mock_session.create_transform_job.assert_called_once_with( **create_transform_params) self.assertEqual(response, test_arn_return) - @mock.patch.object(SageMakerHook, 'check_for_url') @mock.patch.object(SageMakerHook, 'get_conn') - def test_create_transform_job_db_config(self, mock_client, mock_check_url): - mock_check_url.return_value = True + def test_create_model(self, mock_client): mock_session = mock.Mock() - attrs = {'create_transform_job.return_value': + attrs = {'create_model.return_value': test_arn_return} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session - hook_use_db_config = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id', - use_db_config=True) - response = hook_use_db_config.create_transform_job( - create_transform_params, wait_for_completion=False) - updated_config = copy.deepcopy(create_transform_params) - updated_config.update(db_config) - mock_session.create_transform_job.assert_called_once_with(**updated_config) + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') + response = hook.create_model(create_model_params) + mock_session.create_model.assert_called_once_with(**create_model_params) self.assertEqual(response, test_arn_return) @mock.patch.object(SageMakerHook, 'get_conn') - def test_create_model(self, mock_client): + def test_create_endpoint_config(self, mock_client): mock_session = mock.Mock() - attrs = {'create_model.return_value': + attrs = {'create_endpoint_config.return_value': test_arn_return} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session - hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') - response = hook.create_model(create_model_params) - mock_session.create_model.assert_called_once_with(**create_model_params) + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') + response = hook.create_endpoint_config(create_endpoint_config_params) + mock_session.create_endpoint_config\ + .assert_called_once_with(**create_endpoint_config_params) + self.assertEqual(response, test_arn_return) + + @mock.patch.object(SageMakerHook, 'get_conn') + def test_create_endpoint(self, mock_client): + mock_session = mock.Mock() + attrs = {'create_endpoint.return_value': + test_arn_return} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') + response = hook.create_endpoint(create_endpoint_params, + wait_for_completion=False) + mock_session.create_endpoint\ + .assert_called_once_with(**create_endpoint_params) + self.assertEqual(response, test_arn_return) + + @mock.patch.object(SageMakerHook, 'get_conn') + def test_update_endpoint(self, mock_client): + mock_session = mock.Mock() + attrs = {'update_endpoint.return_value': + test_arn_return} + 
mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') + response = hook.update_endpoint(update_endpoint_params, + wait_for_completion=False) + mock_session.update_endpoint\ + .assert_called_once_with(**update_endpoint_params) self.assertEqual(response, test_arn_return) @mock.patch.object(SageMakerHook, 'get_conn') @@ -479,7 +459,7 @@ def test_describe_training_job(self, mock_client): attrs = {'describe_training_job.return_value': 'InProgress'} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session - hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') response = hook.describe_training_job(job_name) mock_session.describe_training_job.\ assert_called_once_with(TrainingJobName=job_name) @@ -492,7 +472,7 @@ def test_describe_tuning_job(self, mock_client): 'InProgress'} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session - hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') response = hook.describe_tuning_job(job_name) mock_session.describe_hyper_parameter_tuning_job.\ assert_called_once_with(HyperParameterTuningJobName=job_name) @@ -505,12 +485,188 @@ def test_describe_transform_job(self, mock_client): 'InProgress'} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session - hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') response = hook.describe_transform_job(job_name) mock_session.describe_transform_job.\ assert_called_once_with(TransformJobName=job_name) self.assertEqual(response, 'InProgress') + @mock.patch.object(SageMakerHook, 'get_conn') + def test_describe_model(self, mock_client): + mock_session = mock.Mock() + attrs = {'describe_model.return_value': + model_name} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') + response = hook.describe_model(model_name) + mock_session.describe_model.\ + assert_called_once_with(ModelName=model_name) + self.assertEqual(response, model_name) + + @mock.patch.object(SageMakerHook, 'get_conn') + def test_describe_endpoint_config(self, mock_client): + mock_session = mock.Mock() + attrs = {'describe_endpoint_config.return_value': + config_name} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') + response = hook.describe_endpoint_config(config_name) + mock_session.describe_endpoint_config.\ + assert_called_once_with(EndpointConfigName=config_name) + self.assertEqual(response, config_name) + + @mock.patch.object(SageMakerHook, 'get_conn') + def test_describe_endpoint(self, mock_client): + mock_session = mock.Mock() + attrs = {'describe_endpoint.return_value': + 'InProgress'} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') + response = hook.describe_endpoint(endpoint_name) + mock_session.describe_endpoint.\ + assert_called_once_with(EndpointName=endpoint_name) + self.assertEqual(response, 'InProgress') + + def test_secondary_training_status_changed_true(self): + changed = secondary_training_status_changed(SECONDARY_STATUS_DESCRIPTION_1, + SECONDARY_STATUS_DESCRIPTION_2) + self.assertTrue(changed) + + def 
test_secondary_training_status_changed_false(self): + changed = secondary_training_status_changed(SECONDARY_STATUS_DESCRIPTION_1, + SECONDARY_STATUS_DESCRIPTION_1) + self.assertFalse(changed) + + def test_secondary_training_status_message_status_changed(self): + now = datetime.now(get_localzone()) + SECONDARY_STATUS_DESCRIPTION_1['LastModifiedTime'] = now + expected = '{} {} - {}'.format( + datetime.utcfromtimestamp(time.mktime(now.timetuple())).strftime('%Y-%m-%d %H:%M:%S'), + status, + message + ) + self.assertEqual( + secondary_training_status_message(SECONDARY_STATUS_DESCRIPTION_1, SECONDARY_STATUS_DESCRIPTION_2), + expected) + + @mock.patch.object(SageMakerHook, 'get_log_conn') + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(time, 'time') + def test_describe_training_job_with_logs_in_progress(self, mock_time, mock_client, mock_log_client): + mock_session = mock.Mock() + mock_log_session = mock.Mock() + attrs = {'describe_training_job.return_value': + DESCRIBE_TRAINING_COMPELETED_RETURN + } + log_attrs = {'describe_log_streams.side_effect': + LIFECYCLE_LOG_STREAMS, + 'get_log_events.side_effect': + STREAM_LOG_EVENTS + } + mock_time.return_value = 50 + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + mock_log_session.configure_mock(**log_attrs) + mock_log_client.return_value = mock_log_session + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') + response = hook.describe_training_job_with_log(job_name=job_name, + positions={}, + stream_names=[], + instance_count=1, + state=LogState.WAIT_IN_PROGRESS, + last_description={}, + last_describe_job_call=0) + self.assertEqual(response, (LogState.JOB_COMPLETE, {}, 50)) + + @mock.patch.object(SageMakerHook, 'get_log_conn') + @mock.patch.object(SageMakerHook, 'get_conn') + def test_describe_training_job_with_logs_job_complete(self, mock_client, mock_log_client): + mock_session = mock.Mock() + mock_log_session = mock.Mock() + attrs = {'describe_training_job.return_value': + DESCRIBE_TRAINING_COMPELETED_RETURN + } + log_attrs = {'describe_log_streams.side_effect': + LIFECYCLE_LOG_STREAMS, + 'get_log_events.side_effect': + STREAM_LOG_EVENTS + } + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + mock_log_session.configure_mock(**log_attrs) + mock_log_client.return_value = mock_log_session + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') + response = hook.describe_training_job_with_log(job_name=job_name, + positions={}, + stream_names=[], + instance_count=1, + state=LogState.JOB_COMPLETE, + last_description={}, + last_describe_job_call=0) + self.assertEqual(response, (LogState.COMPLETE, {}, 0)) + + @mock.patch.object(SageMakerHook, 'get_log_conn') + @mock.patch.object(SageMakerHook, 'get_conn') + def test_describe_training_job_with_logs_complete(self, mock_client, mock_log_client): + mock_session = mock.Mock() + mock_log_session = mock.Mock() + attrs = {'describe_training_job.return_value': + DESCRIBE_TRAINING_COMPELETED_RETURN + } + log_attrs = {'describe_log_streams.side_effect': + LIFECYCLE_LOG_STREAMS, + 'get_log_events.side_effect': + STREAM_LOG_EVENTS + } + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + mock_log_session.configure_mock(**log_attrs) + mock_log_client.return_value = mock_log_session + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') + response = hook.describe_training_job_with_log(job_name=job_name, + positions={}, + stream_names=[], + instance_count=1, + state=LogState.COMPLETE, + 
last_description={}, + last_describe_job_call=0) + self.assertEqual(response, (LogState.COMPLETE, {}, 0)) + + @mock.patch.object(SageMakerHook, 'check_training_config') + @mock.patch.object(SageMakerHook, 'get_log_conn') + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(SageMakerHook, 'describe_training_job_with_log') + def test_training_with_logs(self, mock_describe, mock_client, mock_log_client, mock_check_training): + mock_check_training.return_value = True + mock_describe.side_effect = \ + [(LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RETURN, 0), + (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RETURN, 0), + (LogState.COMPLETE, DESCRIBE_TRAINING_COMPELETED_RETURN, 0)] + mock_session = mock.Mock() + mock_log_session = mock.Mock() + attrs = {'create_training_job.return_value': + test_arn_return, + 'describe_training_job.return_value': + DESCRIBE_TRAINING_COMPELETED_RETURN + } + log_attrs = {'describe_log_streams.side_effect': + LIFECYCLE_LOG_STREAMS, + 'get_log_events.side_effect': + STREAM_LOG_EVENTS + } + mock_session.configure_mock(**attrs) + mock_log_session.configure_mock(**log_attrs) + mock_client.return_value = mock_session + mock_log_client.return_value = mock_log_session + hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1') + hook.create_training_job(create_training_params, wait_for_completion=True, + print_log=True, check_interval=1) + self.assertEqual(mock_describe.call_count, 3) + self.assertEqual(mock_session.describe_training_job.call_count, 1) + if __name__ == '__main__': unittest.main() diff --git a/tests/contrib/operators/test_sagemaker_base_operator.py b/tests/contrib/operators/test_sagemaker_base_operator.py new file mode 100644 index 0000000000000..996e61e20f2d1 --- /dev/null +++ b/tests/contrib/operators/test_sagemaker_base_operator.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# 'License'); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
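For anyone tracing the log-tailing tests above: describe_training_job_with_log returns a (state, last_description, last_describe_job_call) triple that the caller threads back into the next call, which is exactly what the side_effect sequences simulate. An illustrative polling loop (connection ID, job name, and interval are invented; the sensor and create_training_job contain the real loops):

import time

from airflow.contrib.hooks.sagemaker_hook import SageMakerHook, LogState

hook = SageMakerHook(aws_conn_id='aws_default')
state, description, last_call = LogState.TAILING, {}, time.time()
positions, stream_names = {}, []

while state != LogState.COMPLETE:
    state, description, last_call = hook.describe_training_job_with_log(
        'example-job', positions, stream_names,
        1,  # instance_count
        state, description, last_call)
    time.sleep(30)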
+ +import unittest +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + +from airflow import configuration +from airflow.contrib.operators.sagemaker_base_operator import SageMakerBaseOperator + +config = { + 'key1': '1', + 'key2': { + 'key3': '3', + 'key4': '4' + }, + 'key5': [ + { + 'key6': '6' + }, + { + 'key6': '7' + } + ] +} + +parsed_config = { + 'key1': 1, + 'key2': { + 'key3': 3, + 'key4': 4 + }, + 'key5': [ + { + 'key6': 6 + }, + { + 'key6': 7 + } + ] +} + + +class TestSageMakerBaseOperator(unittest.TestCase): + + def setUp(self): + configuration.load_test_config() + self.sagemaker = SageMakerBaseOperator( + task_id='test_sagemaker_operator', + aws_conn_id='sagemaker_test_id', + config=config + ) + + def test_parse_integer(self): + self.sagemaker.integer_fields = [ + ['key1'], ['key2', 'key3'], ['key2', 'key4'], ['key5', 'key6'] + ] + self.sagemaker.parse_config_integers() + self.assertEqual(self.sagemaker.config, parsed_config) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/contrib/operators/test_sagemaker_create_training_job_operator.py b/tests/contrib/operators/test_sagemaker_create_training_job_operator.py deleted file mode 100644 index 156c9d74c79ec..0000000000000 --- a/tests/contrib/operators/test_sagemaker_create_training_job_operator.py +++ /dev/null @@ -1,141 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
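The base-operator test above pins down the integer_fields contract, including the list-descent case: the path ['key5', 'key6'] casts key6 inside every element of the key5 list. A simplified re-implementation for illustration only; SageMakerBaseOperator.parse_config_integers is the authoritative version:

def parse_integer(config, field_path):
    # Walk one ['outer', ..., 'leaf'] path and cast the leaf to int; when a
    # node is a list, apply the remaining path to every element.
    head, rest = field_path[0], field_path[1:]
    if isinstance(config, list):
        for element in config:
            parse_integer(element, field_path)
        return
    if not rest:
        if head in config:
            config[head] = int(config[head])
        return
    if head in config:
        parse_integer(config[head], rest)


config = {'key1': '1', 'key5': [{'key6': '6'}, {'key6': '7'}]}
for path in [['key1'], ['key5', 'key6']]:
    parse_integer(config, path)
assert config == {'key1': 1, 'key5': [{'key6': 6}, {'key6': 7}]}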
- -import unittest -try: - from unittest import mock -except ImportError: - try: - import mock - except ImportError: - mock = None - -from airflow import configuration -from airflow.contrib.hooks.sagemaker_hook import SageMakerHook -from airflow.contrib.operators.sagemaker_create_training_job_operator \ - import SageMakerCreateTrainingJobOperator -from airflow.exceptions import AirflowException - -role = "test-role" - -bucket = "test-bucket" - -key = "test/data" -data_url = "s3://{}/{}".format(bucket, key) - -job_name = "test-job-name" - -image = "test-image" - -output_url = "s3://{}/test/output".format(bucket) -create_training_params = \ - { - "AlgorithmSpecification": { - "TrainingImage": image, - "TrainingInputMode": "File" - }, - "RoleArn": role, - "OutputDataConfig": { - "S3OutputPath": output_url - }, - "ResourceConfig": { - "InstanceCount": 2, - "InstanceType": "ml.c4.8xlarge", - "VolumeSizeInGB": 50 - }, - "TrainingJobName": job_name, - "HyperParameters": { - "k": "10", - "feature_dim": "784", - "mini_batch_size": "500", - "force_dense": "True" - }, - "StoppingCondition": { - "MaxRuntimeInSeconds": 60 * 60 - }, - "InputDataConfig": [ - { - "ChannelName": "train", - "DataSource": { - "S3DataSource": { - "S3DataType": "S3Prefix", - "S3Uri": data_url, - "S3DataDistributionType": "FullyReplicated" - } - }, - "CompressionType": "None", - "RecordWrapperType": "None" - } - ] - } - - -class TestSageMakerTrainingOperator(unittest.TestCase): - - def setUp(self): - configuration.load_test_config() - self.sagemaker = SageMakerCreateTrainingJobOperator( - task_id='test_sagemaker_operator', - sagemaker_conn_id='sagemaker_test_id', - training_job_config=create_training_params, - region_name='us-west-2', - use_db_config=True, - wait_for_completion=False, - check_interval=5 - ) - - @mock.patch.object(SageMakerHook, 'get_conn') - @mock.patch.object(SageMakerHook, 'create_training_job') - @mock.patch.object(SageMakerHook, '__init__') - def test_hook_init(self, hook_init, mock_training, mock_client): - mock_training.return_value = {"TrainingJobArn": "testarn", - "ResponseMetadata": - {"HTTPStatusCode": 200}} - hook_init.return_value = None - self.sagemaker.execute(None) - hook_init.assert_called_once_with( - sagemaker_conn_id='sagemaker_test_id', - region_name='us-west-2', - use_db_config=True, - check_interval=5, - max_ingestion_time=None - ) - - @mock.patch.object(SageMakerHook, 'get_conn') - @mock.patch.object(SageMakerHook, 'create_training_job') - def test_execute_without_failure(self, mock_training, mock_client): - mock_training.return_value = {"TrainingJobArn": "testarn", - "ResponseMetadata": - {"HTTPStatusCode": 200}} - self.sagemaker.execute(None) - mock_training.assert_called_once_with(create_training_params, - wait_for_completion=False - ) - - @mock.patch.object(SageMakerHook, 'get_conn') - @mock.patch.object(SageMakerHook, 'create_training_job') - def test_execute_with_failure(self, mock_training, mock_client): - mock_training.return_value = {"TrainingJobArn": "testarn", - "ResponseMetadata": - {"HTTPStatusCode": 404}} - self.assertRaises(AirflowException, self.sagemaker.execute, None) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/contrib/operators/test_sagemaker_training_operator.py b/tests/contrib/operators/test_sagemaker_training_operator.py new file mode 100644 index 0000000000000..147b7f1bb7fac --- /dev/null +++ b/tests/contrib/operators/test_sagemaker_training_operator.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software 
diff --git a/tests/contrib/operators/test_sagemaker_training_operator.py b/tests/contrib/operators/test_sagemaker_training_operator.py
new file mode 100644
index 0000000000000..147b7f1bb7fac
--- /dev/null
+++ b/tests/contrib/operators/test_sagemaker_training_operator.py
@@ -0,0 +1,134 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# 'License'); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+from airflow import configuration
+from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
+from airflow.contrib.operators.sagemaker_training_operator \
+    import SageMakerTrainingOperator
+from airflow.exceptions import AirflowException
+
+role = 'arn:aws:iam:role/test-role'
+
+bucket = 'test-bucket'
+
+key = 'test/data'
+data_url = 's3://{}/{}'.format(bucket, key)
+
+job_name = 'test-job-name'
+
+image = 'test-image'
+
+output_url = 's3://{}/test/output'.format(bucket)
+create_training_params = \
+    {
+        'AlgorithmSpecification': {
+            'TrainingImage': image,
+            'TrainingInputMode': 'File'
+        },
+        'RoleArn': role,
+        'OutputDataConfig': {
+            'S3OutputPath': output_url
+        },
+        'ResourceConfig': {
+            'InstanceCount': '2',
+            'InstanceType': 'ml.c4.8xlarge',
+            'VolumeSizeInGB': '50'
+        },
+        'TrainingJobName': job_name,
+        'HyperParameters': {
+            'k': '10',
+            'feature_dim': '784',
+            'mini_batch_size': '500',
+            'force_dense': 'True'
+        },
+        'StoppingCondition': {
+            'MaxRuntimeInSeconds': '3600'
+        },
+        'InputDataConfig': [
+            {
+                'ChannelName': 'train',
+                'DataSource': {
+                    'S3DataSource': {
+                        'S3DataType': 'S3Prefix',
+                        'S3Uri': data_url,
+                        'S3DataDistributionType': 'FullyReplicated'
+                    }
+                },
+                'CompressionType': 'None',
+                'RecordWrapperType': 'None'
+            }
+        ]
+    }
+
+
+class TestSageMakerTrainingOperator(unittest.TestCase):
+
+    def setUp(self):
+        configuration.load_test_config()
+        self.sagemaker = SageMakerTrainingOperator(
+            task_id='test_sagemaker_operator',
+            aws_conn_id='sagemaker_test_id',
+            config=create_training_params,
+            wait_for_completion=False,
+            check_interval=5
+        )
+
+    def test_parse_config_integers(self):
+        self.sagemaker.parse_config_integers()
+        self.assertEqual(self.sagemaker.config['ResourceConfig']['InstanceCount'],
+                         int(self.sagemaker.config['ResourceConfig']['InstanceCount']))
+        self.assertEqual(self.sagemaker.config['ResourceConfig']['VolumeSizeInGB'],
+                         int(self.sagemaker.config['ResourceConfig']['VolumeSizeInGB']))
+        self.assertEqual(self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds'],
+                         int(self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds']))
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_training_job')
+    def test_execute(self, mock_training, mock_client):
+        mock_training.return_value = {'TrainingJobArn': 'testarn',
+                                      'ResponseMetadata':
+                                      {'HTTPStatusCode': 200}}
+        self.sagemaker.execute(None)
+        mock_training.assert_called_once_with(create_training_params,
+                                              wait_for_completion=False,
+                                              print_log=True,
+                                              check_interval=5,
+                                              max_ingestion_time=None
+                                              )
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_training_job')
+    def test_execute_with_failure(self, mock_training, mock_client):
+        mock_training.return_value = {'TrainingJobArn': 'testarn',
+                                      'ResponseMetadata':
+                                      {'HTTPStatusCode': 404}}
+        self.assertRaises(AirflowException, self.sagemaker.execute, None)
+
+
+if __name__ == '__main__':
+    unittest.main()
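The test_execute assertion above fixes the operator-to-hook contract. Roughly, execute() is expected to behave like the sketch below, paraphrased from the test rather than copied from the operator source; the preprocess_config() and self.hook names are assumptions for illustration:

    def execute(self, context):
        self.preprocess_config()  # hypothetical; includes parse_config_integers()
        response = self.hook.create_training_job(
            self.config,
            wait_for_completion=self.wait_for_completion,
            print_log=self.print_log,
            check_interval=self.check_interval,
            max_ingestion_time=self.max_ingestion_time)
        # A non-200 response surfaces as an AirflowException, per test_execute_with_failure.
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException('SageMaker training job creation failed: %s' % response)
        return response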
diff --git a/tests/contrib/operators/test_sagemaker_transform_operator.py b/tests/contrib/operators/test_sagemaker_transform_operator.py
new file mode 100644
index 0000000000000..e6bc4272b1cb0
--- /dev/null
+++ b/tests/contrib/operators/test_sagemaker_transform_operator.py
@@ -0,0 +1,136 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+from airflow import configuration
+from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
+from airflow.contrib.operators.sagemaker_transform_operator \
+    import SageMakerTransformOperator
+from airflow.exceptions import AirflowException
+
+role = 'arn:aws:iam:role/test-role'
+
+bucket = 'test-bucket'
+
+key = 'test/data'
+data_url = 's3://{}/{}'.format(bucket, key)
+
+job_name = 'test-job-name'
+
+model_name = 'test-model-name'
+
+image = 'test-image'
+
+output_url = 's3://{}/test/output'.format(bucket)
+
+create_transform_params = {
+    'TransformJobName': job_name,
+    'ModelName': model_name,
+    'MaxConcurrentTransforms': '12',
+    'MaxPayloadInMB': '6',
+    'BatchStrategy': 'MultiRecord',
+    'TransformInput': {
+        'DataSource': {
+            'S3DataSource': {
+                'S3DataType': 'S3Prefix',
+                'S3Uri': data_url
+            }
+        }
+    },
+    'TransformOutput': {
+        'S3OutputPath': output_url,
+    },
+    'TransformResources': {
+        'InstanceType': 'ml.m4.xlarge',
+        'InstanceCount': '3'
+    }
+}
+
+create_model_params = {
+    'ModelName': model_name,
+    'PrimaryContainer': {
+        'Image': image,
+        'ModelDataUrl': output_url,
+    },
+    'ExecutionRoleArn': role
+}
+
+config = {
+    'Model': create_model_params,
+    'Transform': create_transform_params
+}
+
+
+class TestSageMakerTransformOperator(unittest.TestCase):
+
+    def setUp(self):
+        configuration.load_test_config()
+        self.sagemaker = SageMakerTransformOperator(
+            task_id='test_sagemaker_operator',
+            aws_conn_id='sagemaker_test_id',
+            config=config,
+            wait_for_completion=False,
+            check_interval=5
+        )
+
+    def test_parse_config_integers(self):
+        self.sagemaker.parse_config_integers()
+        test_config = self.sagemaker.config['Transform']
+        self.assertEqual(test_config['TransformResources']['InstanceCount'],
+                         int(test_config['TransformResources']['InstanceCount']))
+        self.assertEqual(test_config['MaxConcurrentTransforms'],
+                         int(test_config['MaxConcurrentTransforms']))
+        self.assertEqual(test_config['MaxPayloadInMB'],
+                         int(test_config['MaxPayloadInMB']))
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_model')
+    @mock.patch.object(SageMakerHook, 'create_transform_job')
+    def test_execute(self, mock_transform, mock_model, mock_client):
+        mock_transform.return_value = {'TransformJobArn': 'testarn',
+                                       'ResponseMetadata':
+                                       {'HTTPStatusCode': 200}}
+        self.sagemaker.execute(None)
+        mock_model.assert_called_once_with(create_model_params)
+        mock_transform.assert_called_once_with(create_transform_params,
+                                               wait_for_completion=False,
+                                               check_interval=5,
+                                               max_ingestion_time=None
+                                               )
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_model')
+    @mock.patch.object(SageMakerHook, 'create_transform_job')
+    def test_execute_with_failure(self, mock_transform, mock_model, mock_client):
+        mock_transform.return_value = {'TransformJobArn': 'testarn',
+                                       'ResponseMetadata':
+                                       {'HTTPStatusCode': 404}}
+        self.assertRaises(AirflowException, self.sagemaker.execute, None)
+
+
+if __name__ == '__main__':
+    unittest.main()
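Worth noting: the transform operator takes a single dict with 'Model' and 'Transform' sections, and test_execute asserts that create_model() runs before create_transform_job(). A minimal usage sketch, with placeholder task and connection ids and reusing the fixture dicts above for illustration:

    from airflow.contrib.operators.sagemaker_transform_operator import SageMakerTransformOperator

    transform_op = SageMakerTransformOperator(
        task_id='sagemaker_batch_transform',
        aws_conn_id='aws_default',
        config={
            'Model': create_model_params,          # registered first via create_model()
            'Transform': create_transform_params,  # then submitted via create_transform_job()
        },
        wait_for_completion=True,
        check_interval=30)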
diff --git a/tests/contrib/operators/test_sagemaker_create_tuning_job_operator.py b/tests/contrib/operators/test_sagemaker_tuning_operator.py
similarity index 71%
rename from tests/contrib/operators/test_sagemaker_create_tuning_job_operator.py
rename to tests/contrib/operators/test_sagemaker_tuning_operator.py
index d317cff6f2289..9ec533fa869b4 100644
--- a/tests/contrib/operators/test_sagemaker_create_tuning_job_operator.py
+++ b/tests/contrib/operators/test_sagemaker_tuning_operator.py
@@ -28,11 +28,11 @@
 
 from airflow import configuration
 from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
-from airflow.contrib.operators.sagemaker_create_tuning_job_operator \
-    import SageMakerCreateTuningJobOperator
+from airflow.contrib.operators.sagemaker_tuning_operator \
+    import SageMakerTuningOperator
 from airflow.exceptions import AirflowException
 
-role = 'test-role'
+role = 'arn:aws:iam:role/test-role'
 
 bucket = 'test-bucket'
 
@@ -53,8 +53,8 @@
             'MetricName': 'test_metric'
         },
         'ResourceLimits': {
-            'MaxNumberOfTrainingJobs': 123,
-            'MaxParallelTrainingJobs': 123
+            'MaxNumberOfTrainingJobs': '123',
+            'MaxParallelTrainingJobs': '123'
         },
         'ParameterRanges': {
             'IntegerParameterRanges': [
@@ -79,7 +79,7 @@
             'TrainingImage': image,
             'TrainingInputMode': 'File'
         },
-        'RoleArn': 'string',
+        'RoleArn': role,
         'InputDataConfig': [
             {
@@ -102,55 +102,58 @@
         },
         'ResourceConfig': {
-            'InstanceCount': 2,
+            'InstanceCount': '2',
             'InstanceType': 'ml.c4.8xlarge',
-            'VolumeSizeInGB': 50
+            'VolumeSizeInGB': '50'
         },
         'StoppingCondition': dict(MaxRuntimeInSeconds=60 * 60)
     }
 }
 
 
-class TestSageMakerTrainingOperator(unittest.TestCase):
+class TestSageMakerTuningOperator(unittest.TestCase):
 
     def setUp(self):
         configuration.load_test_config()
-        self.sagemaker = SageMakerCreateTuningJobOperator(
+        self.sagemaker = SageMakerTuningOperator(
             task_id='test_sagemaker_operator',
-            sagemaker_conn_id='sagemaker_test_conn',
-            tuning_job_config=create_tuning_params,
-            region_name='us-east-1',
-            use_db_config=False,
+            aws_conn_id='sagemaker_test_conn',
+            config=create_tuning_params,
             wait_for_completion=False,
             check_interval=5
         )
 
-    @mock.patch.object(SageMakerHook, 'get_conn')
-    @mock.patch.object(SageMakerHook, 'create_tuning_job')
-    @mock.patch.object(SageMakerHook, '__init__')
-    def test_hook_init(self, hook_init, mock_tuning, mock_client):
-        mock_tuning.return_value = {'TrainingJobArn': 'testarn',
-                                    'ResponseMetadata':
-                                    {'HTTPStatusCode': 200}}
-        hook_init.return_value = None
-        self.sagemaker.execute(None)
-        hook_init.assert_called_once_with(
-            sagemaker_conn_id='sagemaker_test_conn',
-            region_name='us-east-1',
-            use_db_config=False,
-            check_interval=5,
-            max_ingestion_time=None
-        )
+    def test_parse_config_integers(self):
+        self.sagemaker.parse_config_integers()
+        self.assertEqual(self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']
+                         ['InstanceCount'],
+                         int(self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']
+                             ['InstanceCount']))
+        self.assertEqual(self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']
+                         ['VolumeSizeInGB'],
+                         int(self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']
+                             ['VolumeSizeInGB']))
+        self.assertEqual(self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits']
+                         ['MaxNumberOfTrainingJobs'],
+                         int(self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits']
+                             ['MaxNumberOfTrainingJobs']))
+        self.assertEqual(self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits']
+                         ['MaxParallelTrainingJobs'],
+                         int(self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits']
+                             ['MaxParallelTrainingJobs']))
 
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, 'create_tuning_job')
-    def test_execute_without_failure(self, mock_tuning, mock_client):
+    def test_execute(self, mock_tuning, mock_client):
         mock_tuning.return_value = {'TrainingJobArn': 'testarn',
                                     'ResponseMetadata':
                                     {'HTTPStatusCode': 200}}
         self.sagemaker.execute(None)
         mock_tuning.assert_called_once_with(create_tuning_params,
-                                            wait_for_completion=False)
+                                            wait_for_completion=False,
+                                            check_interval=5,
+                                            max_ingestion_time=None
+                                            )
 
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, 'create_tuning_job')
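The four assertions in test_parse_config_integers above imply the tuning operator declares nested integer paths roughly like the list below. This is inferred from the test, not quoted from sagemaker_tuning_operator.py; the real declaration may order or extend it differently:

    integer_fields = [
        ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
        ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
        ['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
        ['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
    ]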
diff --git a/tests/contrib/sensors/test_sagemaker_base_sensor.py b/tests/contrib/sensors/test_sagemaker_base_sensor.py
index bc8cbe349858f..5870544838033 100644
--- a/tests/contrib/sensors/test_sagemaker_base_sensor.py
+++ b/tests/contrib/sensors/test_sagemaker_base_sensor.py
@@ -28,7 +28,7 @@ class TestSagemakerBaseSensor(unittest.TestCase):
     def setUp(self):
         configuration.load_test_config()
 
-    def test_subclasses_succeed_when_response_is_good(self):
+    def test_execute(self):
         class SageMakerBaseSensorSubclass(SageMakerBaseSensor):
             def non_terminal_states(self):
                 return ['PENDING', 'RUNNING', 'CONTINUE']
@@ -53,7 +53,7 @@ def state_from_response(self, response):
 
         sensor.execute(None)
 
-    def test_poke_returns_false_when_state_is_a_non_terminal_state(self):
+    def test_poke_with_unfinished_job(self):
         class SageMakerBaseSensorSubclass(SageMakerBaseSensor):
             def non_terminal_states(self):
                 return ['PENDING', 'RUNNING', 'CONTINUE']
@@ -78,7 +78,7 @@ def state_from_response(self, response):
 
         self.assertEqual(sensor.poke(None), False)
 
-    def test_poke_raise_exception_when_method_not_implemented(self):
+    def test_poke_with_not_implemented_method(self):
         class SageMakerBaseSensorSubclass(SageMakerBaseSensor):
             def non_terminal_states(self):
                 return ['PENDING', 'RUNNING', 'CONTINUE']
@@ -92,9 +92,9 @@ def failed_states(self):
             aws_conn_id='aws_test'
         )
 
-        self.assertRaises(AirflowException, sensor.poke, None)
+        self.assertRaises(NotImplementedError, sensor.poke, None)
 
-    def test_poke_returns_false_when_http_response_is_bad(self):
+    def test_poke_with_bad_response(self):
         class SageMakerBaseSensorSubclass(SageMakerBaseSensor):
             def non_terminal_states(self):
                 return ['PENDING', 'RUNNING', 'CONTINUE']
@@ -119,7 +119,7 @@ def state_from_response(self, response):
 
         self.assertEqual(sensor.poke(None), False)
 
-    def test_poke_raises_error_when_job_has_failed(self):
+    def test_poke_with_job_failure(self):
         class SageMakerBaseSensorSubclass(SageMakerBaseSensor):
             def non_terminal_states(self):
                 return ['PENDING', 'RUNNING', 'CONTINUE']
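For reference, the subclass contract these tests exercise boils down to four overrides. A minimal concrete sensor, modeled directly on the test fixtures (the job API in get_sagemaker_response is hypothetical; a real sensor would call the hook there):

    from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor

    class MyJobSensor(SageMakerBaseSensor):
        def non_terminal_states(self):
            return ['PENDING', 'RUNNING', 'CONTINUE']

        def failed_states(self):
            return ['FAILED']

        def get_sagemaker_response(self):
            # Hypothetical describe call; shaped like the fixtures above.
            return {'SomeKey': {'State': 'COMPLETED'},
                    'ResponseMetadata': {'HTTPStatusCode': 200}}

        def state_from_response(self, response):
            return response['SomeKey']['State']

Per the tests: poke() returns False for a non-terminal state or a non-200 response, raises AirflowException for a failed state, and raises NotImplementedError when the overrides are missing.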
diff --git a/tests/contrib/sensors/test_sagemaker_training_sensor.py b/tests/contrib/sensors/test_sagemaker_training_sensor.py
index fb966f60afbf0..5861d7a6fdb15 100644
--- a/tests/contrib/sensors/test_sagemaker_training_sensor.py
+++ b/tests/contrib/sensors/test_sagemaker_training_sensor.py
@@ -18,6 +18,7 @@
 # under the License.
 
 import unittest
+from datetime import datetime
 
 try:
     from unittest import mock
@@ -30,55 +31,51 @@
 from airflow import configuration
 from airflow.contrib.sensors.sagemaker_training_sensor \
     import SageMakerTrainingSensor
-from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
+from airflow.contrib.hooks.sagemaker_hook import SageMakerHook, LogState
 from airflow.exceptions import AirflowException
 
-DESCRIBE_TRAINING_INPROGRESS_RETURN = {
-    'TrainingJobStatus': 'InProgress',
-    'ResponseMetadata': {
-        'HTTPStatusCode': 200,
-    }
-}
-DESCRIBE_TRAINING_COMPELETED_RETURN = {
-    'TrainingJobStatus': 'Compeleted',
-    'ResponseMetadata': {
-        'HTTPStatusCode': 200,
-    }
-}
-DESCRIBE_TRAINING_FAILED_RETURN = {
-    'TrainingJobStatus': 'Failed',
-    'ResponseMetadata': {
-        'HTTPStatusCode': 200,
-    },
-    'FailureReason': 'Unknown'
-}
-DESCRIBE_TRAINING_STOPPING_RETURN = {
-    'TrainingJobStatus': 'Stopping',
-    'ResponseMetadata': {
-        'HTTPStatusCode': 200,
-    }
-}
-DESCRIBE_TRAINING_STOPPED_RETURN = {
-    'TrainingJobStatus': 'Stopped',
+DESCRIBE_TRAINING_COMPLETED_RESPONSE = {
+    'TrainingJobStatus': 'Completed',
+    'ResourceConfig': {
+        'InstanceCount': 1,
+        'InstanceType': 'ml.c4.xlarge',
+        'VolumeSizeInGB': 10
+    },
+    'TrainingStartTime': datetime(2018, 2, 17, 7, 15, 0, 103000),
+    'TrainingEndTime': datetime(2018, 2, 17, 7, 19, 34, 953000),
     'ResponseMetadata': {
         'HTTPStatusCode': 200,
     }
 }
 
+DESCRIBE_TRAINING_INPROGRESS_RESPONSE = dict(DESCRIBE_TRAINING_COMPLETED_RESPONSE)
+DESCRIBE_TRAINING_INPROGRESS_RESPONSE.update({'TrainingJobStatus': 'InProgress'})
+
+DESCRIBE_TRAINING_FAILED_RESPONSE = dict(DESCRIBE_TRAINING_COMPLETED_RESPONSE)
+DESCRIBE_TRAINING_FAILED_RESPONSE.update({'TrainingJobStatus': 'Failed',
+                                          'FailureReason': 'Unknown'})
+
+DESCRIBE_TRAINING_STOPPING_RESPONSE = dict(DESCRIBE_TRAINING_COMPLETED_RESPONSE)
+DESCRIBE_TRAINING_STOPPING_RESPONSE.update({'TrainingJobStatus': 'Stopping'})
+
 
 class TestSageMakerTrainingSensor(unittest.TestCase):
     def setUp(self):
         configuration.load_test_config()
 
     @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, '__init__')
     @mock.patch.object(SageMakerHook, 'describe_training_job')
-    def test_raises_errors_failed_state(self, mock_describe_job, mock_client):
-        mock_describe_job.side_effect = [DESCRIBE_TRAINING_FAILED_RETURN]
+    def test_sensor_with_failure(self, mock_describe_job, hook_init, mock_client):
+        hook_init.return_value = None
+
+        mock_describe_job.side_effect = [DESCRIBE_TRAINING_FAILED_RESPONSE]
         sensor = SageMakerTrainingSensor(
             task_id='test_task',
             poke_interval=2,
             aws_conn_id='aws_test',
-            job_name='test_job_name'
+            job_name='test_job_name',
+            print_log=False
         )
         self.assertRaises(AirflowException, sensor.execute, None)
         mock_describe_job.assert_called_once_with('test_job_name')
@@ -86,32 +83,59 @@ def test_raises_errors_failed_state(self, mock_describe_job, mock_client):
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, '__init__')
     @mock.patch.object(SageMakerHook, 'describe_training_job')
-    def test_calls_until_a_terminal_state(self,
-                                          mock_describe_job, hook_init, mock_client):
+    def test_sensor(self, mock_describe_job, hook_init, mock_client):
         hook_init.return_value = None
 
         mock_describe_job.side_effect = [
-            DESCRIBE_TRAINING_INPROGRESS_RETURN,
-            DESCRIBE_TRAINING_STOPPING_RETURN,
-            DESCRIBE_TRAINING_STOPPED_RETURN,
-            DESCRIBE_TRAINING_COMPELETED_RETURN
+            DESCRIBE_TRAINING_INPROGRESS_RESPONSE,
+            DESCRIBE_TRAINING_STOPPING_RESPONSE,
+            DESCRIBE_TRAINING_COMPLETED_RESPONSE
         ]
         sensor = SageMakerTrainingSensor(
             task_id='test_task',
             poke_interval=2,
             aws_conn_id='aws_test',
             job_name='test_job_name',
-            region_name='us-east-1'
+            print_log=False
        )
 
         sensor.execute(None)
 
-        # make sure we called 4 times(terminated when its compeleted)
-        self.assertEqual(mock_describe_job.call_count, 4)
+        # make sure we called 3 times (terminated when the job is completed)
+        self.assertEqual(mock_describe_job.call_count, 3)
 
         # make sure the hook was initialized with the specific params
-        hook_init.assert_called_with(aws_conn_id='aws_test',
-                                     region_name='us-east-1')
+        hook_init.assert_called_with(aws_conn_id='aws_test')
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'get_log_conn')
+    @mock.patch.object(SageMakerHook, '__init__')
+    @mock.patch.object(SageMakerHook, 'describe_training_job_with_log')
+    @mock.patch.object(SageMakerHook, 'describe_training_job')
+    def test_sensor_with_log(self, mock_describe_job, mock_describe_job_with_log,
+                             hook_init, mock_log_client, mock_client):
+        hook_init.return_value = None
+
+        mock_describe_job.return_value = DESCRIBE_TRAINING_COMPLETED_RESPONSE
+        mock_describe_job_with_log.side_effect = [
+            (LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RESPONSE, 0),
+            (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RESPONSE, 0),
+            (LogState.COMPLETE, DESCRIBE_TRAINING_COMPLETED_RESPONSE, 0)
+        ]
+        sensor = SageMakerTrainingSensor(
+            task_id='test_task',
+            poke_interval=2,
+            aws_conn_id='aws_test',
+            job_name='test_job_name',
+            print_log=True
+        )
+
+        sensor.execute(None)
+
+        self.assertEqual(mock_describe_job_with_log.call_count, 3)
+        self.assertEqual(mock_describe_job.call_count, 1)
+
+        hook_init.assert_called_with(aws_conn_id='aws_test')
 
 
 if __name__ == '__main__':
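The test_sensor_with_log side_effect above encodes the state machine the log-tailing path walks through. A runnable illustration of how a polling loop consumes those (state, description, timestamp) tuples, using the same three transitions the mock feeds in (the argument list of describe_training_job_with_log is deliberately elided because this test does not pin it down):

    from airflow.contrib.hooks.sagemaker_hook import LogState

    polls = iter([
        (LogState.WAIT_IN_PROGRESS, {'TrainingJobStatus': 'InProgress'}, 0),
        (LogState.JOB_COMPLETE, {'TrainingJobStatus': 'Stopping'}, 0),
        (LogState.COMPLETE, {'TrainingJobStatus': 'Completed'}, 0),
    ])

    state, description, timestamp = next(polls)
    while state != LogState.COMPLETE:  # COMPLETE is the only terminal log state
        state, description, timestamp = next(polls)
    print(description['TrainingJobStatus'])  # -> Completed, after three polls

Three iterations of this loop are exactly why the test asserts mock_describe_job_with_log.call_count == 3.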
diff --git a/tests/contrib/sensors/test_sagemaker_transform_sensor.py b/tests/contrib/sensors/test_sagemaker_transform_sensor.py
index bb4a184bb2797..1394920d5dc3d 100644
--- a/tests/contrib/sensors/test_sagemaker_transform_sensor.py
+++ b/tests/contrib/sensors/test_sagemaker_transform_sensor.py
@@ -33,37 +33,31 @@
 from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
 from airflow.exceptions import AirflowException
 
-DESCRIBE_TRANSFORM_INPROGRESS_RETURN = {
+DESCRIBE_TRANSFORM_INPROGRESS_RESPONSE = {
     'TransformJobStatus': 'InProgress',
     'ResponseMetadata': {
         'HTTPStatusCode': 200,
     }
 }
-DESCRIBE_TRANSFORM_COMPELETED_RETURN = {
-    'TransformJobStatus': 'Compeleted',
+DESCRIBE_TRANSFORM_COMPLETED_RESPONSE = {
+    'TransformJobStatus': 'Completed',
     'ResponseMetadata': {
         'HTTPStatusCode': 200,
     }
 }
-DESCRIBE_TRANSFORM_FAILED_RETURN = {
+DESCRIBE_TRANSFORM_FAILED_RESPONSE = {
     'TransformJobStatus': 'Failed',
     'ResponseMetadata': {
         'HTTPStatusCode': 200,
     },
     'FailureReason': 'Unknown'
 }
-DESCRIBE_TRANSFORM_STOPPING_RETURN = {
+DESCRIBE_TRANSFORM_STOPPING_RESPONSE = {
     'TransformJobStatus': 'Stopping',
     'ResponseMetadata': {
         'HTTPStatusCode': 200,
     }
 }
-DESCRIBE_TRANSFORM_STOPPED_RETURN = {
-    'TransformJobStatus': 'Stopped',
-    'ResponseMetadata': {
-        'HTTPStatusCode': 200,
-    }
-}
 
 
 class TestSageMakerTransformSensor(unittest.TestCase):
@@ -72,8 +66,8 @@ def setUp(self):
 
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, 'describe_transform_job')
-    def test_raises_errors_failed_state(self, mock_describe_job, mock_client):
-        mock_describe_job.side_effect = [DESCRIBE_TRANSFORM_FAILED_RETURN]
+    def test_sensor_with_failure(self, mock_describe_job, mock_client):
+        mock_describe_job.side_effect = [DESCRIBE_TRANSFORM_FAILED_RESPONSE]
         sensor = SageMakerTransformSensor(
             task_id='test_task',
             poke_interval=2,
@@ -86,32 +80,28 @@ def test_raises_errors_failed_state(self, mock_describe_job, mock_client):
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, '__init__')
     @mock.patch.object(SageMakerHook, 'describe_transform_job')
-    def test_calls_until_a_terminal_state(self,
-                                          mock_describe_job, hook_init, mock_client):
+    def test_sensor(self, mock_describe_job, hook_init, mock_client):
         hook_init.return_value = None
 
         mock_describe_job.side_effect = [
-            DESCRIBE_TRANSFORM_INPROGRESS_RETURN,
-            DESCRIBE_TRANSFORM_STOPPING_RETURN,
-            DESCRIBE_TRANSFORM_STOPPED_RETURN,
-            DESCRIBE_TRANSFORM_COMPELETED_RETURN
+            DESCRIBE_TRANSFORM_INPROGRESS_RESPONSE,
+            DESCRIBE_TRANSFORM_STOPPING_RESPONSE,
+            DESCRIBE_TRANSFORM_COMPLETED_RESPONSE
         ]
         sensor = SageMakerTransformSensor(
             task_id='test_task',
             poke_interval=2,
             aws_conn_id='aws_test',
-            job_name='test_job_name',
-            region_name='us-east-1'
+            job_name='test_job_name'
         )
 
         sensor.execute(None)
 
-        # make sure we called 4 times(terminated when its compeleted)
-        self.assertEqual(mock_describe_job.call_count, 4)
+        # make sure we called 3 times (terminated when the job is completed)
+        self.assertEqual(mock_describe_job.call_count, 3)
 
         # make sure the hook was initialized with the specific params
-        hook_init.assert_called_with(aws_conn_id='aws_test',
-                                     region_name='us-east-1')
+        hook_init.assert_called_with(aws_conn_id='aws_test')
 
 
 if __name__ == '__main__':
diff --git a/tests/contrib/sensors/test_sagemaker_tuning_sensor.py b/tests/contrib/sensors/test_sagemaker_tuning_sensor.py
index 49f9b41b07c89..8c0ba11380c1a 100644
--- a/tests/contrib/sensors/test_sagemaker_tuning_sensor.py
+++ b/tests/contrib/sensors/test_sagemaker_tuning_sensor.py
@@ -33,37 +33,34 @@
 from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
 from airflow.exceptions import AirflowException
 
-DESCRIBE_TUNING_INPROGRESS_RETURN = {
+DESCRIBE_TUNING_INPROGRESS_RESPONSE = {
     'HyperParameterTuningJobStatus': 'InProgress',
     'ResponseMetadata': {
         'HTTPStatusCode': 200,
     }
 }
-DESCRIBE_TUNING_COMPELETED_RETURN = {
-    'HyperParameterTuningJobStatus': 'Compeleted',
+
+DESCRIBE_TUNING_COMPLETED_RESPONSE = {
+    'HyperParameterTuningJobStatus': 'Completed',
     'ResponseMetadata': {
         'HTTPStatusCode': 200,
     }
 }
-DESCRIBE_TUNING_FAILED_RETURN = {
+
+DESCRIBE_TUNING_FAILED_RESPONSE = {
     'HyperParameterTuningJobStatus': 'Failed',
     'ResponseMetadata': {
         'HTTPStatusCode': 200,
     },
     'FailureReason': 'Unknown'
 }
-DESCRIBE_TUNING_STOPPING_RETURN = {
+
+DESCRIBE_TUNING_STOPPING_RESPONSE = {
     'HyperParameterTuningJobStatus': 'Stopping',
     'ResponseMetadata': {
         'HTTPStatusCode': 200,
     }
 }
-DESCRIBE_TUNING_STOPPED_RETURN = {
-    'HyperParameterTuningJobStatus': 'Stopped',
-    'ResponseMetadata': {
-        'HTTPStatusCode': 200,
-    }
-}
 
 
 class TestSageMakerTuningSensor(unittest.TestCase):
@@ -72,8 +69,8 @@ def setUp(self):
 
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, 'describe_tuning_job')
-    def test_raises_errors_failed_state(self, mock_describe_job, mock_client):
-        mock_describe_job.side_effect = [DESCRIBE_TUNING_FAILED_RETURN]
+    def test_sensor_with_failure(self, mock_describe_job, mock_client):
+        mock_describe_job.side_effect = [DESCRIBE_TUNING_FAILED_RESPONSE]
         sensor = SageMakerTuningSensor(
             task_id='test_task',
             poke_interval=2,
@@ -86,32 +83,28 @@ def test_raises_errors_failed_state(self, mock_describe_job, mock_client):
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, '__init__')
     @mock.patch.object(SageMakerHook, 'describe_tuning_job')
-    def test_calls_until_a_terminal_state(self,
-                                          mock_describe_job, hook_init, mock_client):
+    def test_sensor(self, mock_describe_job, hook_init, mock_client):
         hook_init.return_value = None
 
         mock_describe_job.side_effect = [
-            DESCRIBE_TUNING_INPROGRESS_RETURN,
-            DESCRIBE_TUNING_STOPPING_RETURN,
-            DESCRIBE_TUNING_STOPPED_RETURN,
-            DESCRIBE_TUNING_COMPELETED_RETURN
+            DESCRIBE_TUNING_INPROGRESS_RESPONSE,
+            DESCRIBE_TUNING_STOPPING_RESPONSE,
+            DESCRIBE_TUNING_COMPLETED_RESPONSE
         ]
         sensor = SageMakerTuningSensor(
             task_id='test_task',
             poke_interval=2,
             aws_conn_id='aws_test',
-            job_name='test_job_name',
-            region_name='us-east-1'
+            job_name='test_job_name'
        )
 
         sensor.execute(None)
 
-        # make sure we called 4 times(terminated when its compeleted)
-        self.assertEqual(mock_describe_job.call_count, 4)
+        # make sure we called 3 times (terminated when the job is completed)
+        self.assertEqual(mock_describe_job.call_count, 3)
 
         # make sure the hook was initialized with the specific params
-        hook_init.assert_called_with(aws_conn_id='aws_test',
-                                     region_name='us-east-1')
+        hook_init.assert_called_with(aws_conn_id='aws_test')
 
 
 if __name__ == '__main__':
diff --git a/tests/hooks/test_s3_hook.py b/tests/hooks/test_s3_hook.py
index 27dfba49a5197..e0f9e8a3eca53 100644
--- a/tests/hooks/test_s3_hook.py
+++ b/tests/hooks/test_s3_hook.py
@@ -19,6 +19,7 @@
 #
 
 import mock
+import tempfile
 import unittest
 
 from botocore.exceptions import NoCredentialsError
@@ -74,6 +75,31 @@ def test_get_bucket(self):
         b = hook.get_bucket('bucket')
         self.assertIsNotNone(b)
 
+    @mock_s3
+    def test_create_bucket_default_region(self):
+        hook = S3Hook(aws_conn_id=None)
+        hook.create_bucket(bucket_name='new_bucket')
+        b = hook.get_bucket('new_bucket')
+        self.assertIsNotNone(b)
+
+    @mock_s3
+    def test_create_bucket_us_standard_region(self):
+        hook = S3Hook(aws_conn_id=None)
+        hook.create_bucket(bucket_name='new_bucket', region_name='us-east-1')
+        b = hook.get_bucket('new_bucket')
+        self.assertIsNotNone(b)
+        region = b.meta.client.get_bucket_location(Bucket=b.name).get('LocationConstraint', None)
+        self.assertEqual(region, 'us-east-1')
+
+    @mock_s3
+    def test_create_bucket_other_region(self):
+        hook = S3Hook(aws_conn_id=None)
+        hook.create_bucket(bucket_name='new_bucket', region_name='us-east-2')
+        b = hook.get_bucket('new_bucket')
+        self.assertIsNotNone(b)
+        region = b.meta.client.get_bucket_location(Bucket=b.name).get('LocationConstraint', None)
+        self.assertEqual(region, 'us-east-2')
+
     @mock_s3
     def test_check_for_prefix(self):
         hook = S3Hook(aws_conn_id=None)
@@ -255,6 +281,21 @@ def test_load_bytes(self):
 
         self.assertEqual(body, b'Content')
 
+    @mock_s3
+    def test_load_fileobj(self):
+        hook = S3Hook(aws_conn_id=None)
+        conn = hook.get_conn()
+        # We need to create the bucket since this is all in Moto's 'virtual'
+        # AWS account
+        conn.create_bucket(Bucket="mybucket")
+        with tempfile.TemporaryFile() as temp_file:
+            temp_file.write(b"Content")
+            temp_file.seek(0)
+            hook.load_file_obj(temp_file, "my_key", "mybucket")
+            body = boto3.resource('s3').Object('mybucket', 'my_key').get()['Body'].read()
+
+            self.assertEqual(body, b'Content')
+
 
 if __name__ == '__main__':
     unittest.main()
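The two new S3Hook capabilities covered above, regional bucket creation and file-object upload, combine naturally. A small usage sketch; the bucket, key, and connection id are placeholders, and it assumes load_file_obj accepts any file-like object, as the TemporaryFile test suggests:

    import io
    from airflow.hooks.S3_hook import S3Hook

    hook = S3Hook(aws_conn_id='aws_default')
    # Create the bucket in an explicit region (omit region_name for the default).
    hook.create_bucket(bucket_name='my-report-bucket', region_name='us-east-2')
    # Upload an in-memory buffer without touching the local filesystem;
    # arguments follow the positional order used in test_load_fileobj.
    buf = io.BytesIO(b'col1,col2\n1,2\n')
    hook.load_file_obj(buf, 'reports/latest.csv', 'my-report-bucket')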
diff --git a/tests/operators/test_sagemaker_create_transform_job_operator.py b/tests/operators/test_sagemaker_create_transform_job_operator.py
deleted file mode 100644
index a8701530d9daa..0000000000000
--- a/tests/operators/test_sagemaker_create_transform_job_operator.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import unittest
-try:
-    from unittest import mock
-except ImportError:
-    try:
-        import mock
-    except ImportError:
-        mock = None
-
-from airflow import configuration
-from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
-from airflow.contrib.operators.sagemaker_create_transform_job_operator \
-    import SageMakerCreateTransformJobOperator
-from airflow.exceptions import AirflowException
-
-role = 'test-role'
-
-bucket = 'test-bucket'
-
-key = 'test/data'
-data_url = 's3://{}/{}'.format(bucket, key)
-
-job_name = 'test-job-name'
-
-model_name = 'test-model-name'
-
-image = 'test-image'
-
-output_url = 's3://{}/test/output'.format(bucket)
-
-create_transform_params = \
-    {
-        'TransformJobName': job_name,
-        'ModelName': model_name,
-        'BatchStrategy': 'MultiRecord',
-        'TransformInput': {
-            'DataSource': {
-                'S3DataSource': {
-                    'S3DataType': 'S3Prefix',
-                    'S3Uri': data_url
-                }
-            }
-        },
-        'TransformOutput': {
-            'S3OutputPath': output_url,
-        },
-        'TransformResources': {
-            'InstanceType': 'ml.m4.xlarge',
-            'InstanceCount': 123
-        }
-    }
-
-create_model_params = \
-    {
-        'ModelName': model_name,
-        'PrimaryContainer': {
-            'Image': image,
-            'ModelDataUrl': output_url,
-        },
-        'ExecutionRoleArn': role
-    }
-
-
-class TestSageMakertransformOperator(unittest.TestCase):
-
-    def setUp(self):
-        configuration.load_test_config()
-        self.sagemaker = SageMakerCreateTransformJobOperator(
-            task_id='test_sagemaker_operator',
-            sagemaker_conn_id='sagemaker_test_id',
-            transform_job_config=create_transform_params,
-            model_config=create_model_params,
-            region_name='us-west-2',
-            use_db_config=True,
-            wait_for_completion=False,
-            check_interval=5
-        )
-
-    @mock.patch.object(SageMakerHook, 'get_conn')
-    @mock.patch.object(SageMakerHook, 'create_model')
-    @mock.patch.object(SageMakerHook, 'create_transform_job')
-    @mock.patch.object(SageMakerHook, '__init__')
-    def test_hook_init(self, hook_init, mock_transform, mock_model, mock_client):
-        mock_transform.return_value = {"TransformJobArn": "testarn",
-                                       "ResponseMetadata":
-                                       {"HTTPStatusCode": 200}}
-        hook_init.return_value = None
-        self.sagemaker.execute(None)
-        hook_init.assert_called_once_with(
-            sagemaker_conn_id='sagemaker_test_id',
-            region_name='us-west-2',
-            use_db_config=True,
-            check_interval=5,
-            max_ingestion_time=None
-        )
-
-    @mock.patch.object(SageMakerHook, 'get_conn')
-    @mock.patch.object(SageMakerHook, 'create_model')
-    @mock.patch.object(SageMakerHook, 'create_transform_job')
-    def test_execute_without_failure(self, mock_transform, mock_model, mock_client):
-        mock_transform.return_value = {"TransformJobArn": "testarn",
-                                       "ResponseMetadata":
-                                       {"HTTPStatusCode": 200}}
-        self.sagemaker.execute(None)
-        mock_model.assert_called_once_with(create_model_params)
-        mock_transform.assert_called_once_with(create_transform_params,
-                                               wait_for_completion=False
-                                               )
-
-    @mock.patch.object(SageMakerHook, 'get_conn')
-    @mock.patch.object(SageMakerHook, 'create_model')
-    @mock.patch.object(SageMakerHook, 'create_transform_job')
-    def test_execute_with_failure(self, mock_transform, mock_model, mock_client):
-        mock_transform.return_value = {"TransformJobArn": "testarn",
-                                       "ResponseMetadata":
-                                       {"HTTPStatusCode": 404}}
-        self.assertRaises(AirflowException, self.sagemaker.execute, None)
-
-
-if __name__ == '__main__':
-    unittest.main()