diff --git a/luigi/contrib/hadoop.py b/luigi/contrib/hadoop.py index fbf2b2e26d..0abe156638 100644 --- a/luigi/contrib/hadoop.py +++ b/luigi/contrib/hadoop.py @@ -51,8 +51,8 @@ import luigi.task import luigi.contrib.gcs import luigi.contrib.hdfs -import luigi.s3 -from luigi import mrrunner +import luigi.contrib.s3 +from luigi.contrib import mrrunner if six.PY2: from itertools import imap as map @@ -460,7 +460,7 @@ def run_job(self, job, tracking_url_callback=None): # atomic output: replace output with a temporary work directory if self.end_job_with_atomic_move_dir: illegal_targets = ( - luigi.s3.S3FlagTarget, luigi.contrib.gcs.GCSFlagTarget) + luigi.contrib.s3.S3FlagTarget, luigi.contrib.gcs.GCSFlagTarget) if isinstance(job.output(), illegal_targets): raise TypeError("end_job_with_atomic_move_dir is not supported" " for {}".format(illegal_targets)) @@ -533,7 +533,7 @@ def run_job(self, job, tracking_url_callback=None): allowed_input_targets = ( luigi.contrib.hdfs.HdfsTarget, - luigi.s3.S3Target, + luigi.contrib.s3.S3Target, luigi.contrib.gcs.GCSTarget) for target in luigi.task.flatten(job.input_hadoop()): if not isinstance(target, allowed_input_targets): @@ -543,7 +543,7 @@ def run_job(self, job, tracking_url_callback=None): allowed_output_targets = ( luigi.contrib.hdfs.HdfsTarget, - luigi.s3.S3FlagTarget, + luigi.contrib.s3.S3FlagTarget, luigi.contrib.gcs.GCSFlagTarget) if not isinstance(job.output(), allowed_output_targets): raise TypeError('output must be one of: {}'.format( diff --git a/luigi/mrrunner.py b/luigi/contrib/mrrunner.py similarity index 92% rename from luigi/mrrunner.py rename to luigi/contrib/mrrunner.py index 6d9412c626..b446c6e860 100644 --- a/luigi/mrrunner.py +++ b/luigi/contrib/mrrunner.py @@ -17,6 +17,10 @@ # """ +Since Luigi 2.5.0, this is a private Luigi module. Luigi users should +not rely on being able to import this module. Furthermore, "luigi mr streaming" +has been largely superseded by technologies like Spark, Hive, etc. + The hadoop runner. This module contains the main() method which will be used to run the diff --git a/luigi/contrib/postgres.py b/luigi/contrib/postgres.py new file mode 100644 index 0000000000..a00c6d574a --- /dev/null +++ b/luigi/contrib/postgres.py @@ -0,0 +1,390 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Implements a subclass of :py:class:`~luigi.target.Target` that writes data to Postgres. +Also provides a helper task to copy data into a Postgres table. +""" + +import datetime +import logging +import re +import tempfile + +from luigi import six + +import luigi +from luigi.contrib import rdbms + +logger = logging.getLogger('luigi-interface') + +try: + import psycopg2 + import psycopg2.errorcodes + import psycopg2.extensions +except ImportError: + logger.warning("Loading postgres module without psycopg2 installed. 
Will crash at runtime if postgres functionality is used.") + + +class MultiReplacer(object): + """ + Object for one-pass replace of multiple words + + Substituted parts will not be matched against other replace patterns, as opposed to when using multipass replace. + The order of the items in the replace_pairs input will dictate replacement precedence. + + Constructor arguments: + replace_pairs -- list of 2-tuples which hold strings to be replaced and replace string + + Usage: + + .. code-block:: python + + >>> replace_pairs = [("a", "b"), ("b", "c")] + >>> MultiReplacer(replace_pairs)("abcd") + 'bccd' + >>> replace_pairs = [("ab", "x"), ("a", "x")] + >>> MultiReplacer(replace_pairs)("ab") + 'x' + >>> replace_pairs.reverse() + >>> MultiReplacer(replace_pairs)("ab") + 'xb' + """ +# TODO: move to misc/util module + + def __init__(self, replace_pairs): + """ + Initializes a MultiReplacer instance. + + :param replace_pairs: list of 2-tuples which hold strings to be replaced and replace string. + :type replace_pairs: tuple + """ + replace_list = list(replace_pairs) # make a copy in case input is iterable + self._replace_dict = dict(replace_list) + pattern = '|'.join(re.escape(x) for x, y in replace_list) + self._search_re = re.compile(pattern) + + def _replacer(self, match_object): + # this method is used as the replace function in the re.sub below + return self._replace_dict[match_object.group()] + + def __call__(self, search_string): + # using function replacing for a per-result replace + return self._search_re.sub(self._replacer, search_string) + + +# these are the escape sequences recognized by postgres COPY +# according to http://www.postgresql.org/docs/8.1/static/sql-copy.html +default_escape = MultiReplacer([('\\', '\\\\'), + ('\t', '\\t'), + ('\n', '\\n'), + ('\r', '\\r'), + ('\v', '\\v'), + ('\b', '\\b'), + ('\f', '\\f') + ]) + + +class PostgresTarget(luigi.Target): + """ + Target for a resource in Postgres. + + This will rarely have to be directly instantiated by the user. + """ + marker_table = luigi.configuration.get_config().get('postgres', 'marker-table', 'table_updates') + + # Use DB side timestamps or client side timestamps in the marker_table + use_db_timestamps = True + + def __init__( + self, host, database, user, password, table, update_id, port=None + ): + """ + Args: + host (str): Postgres server address. Possibly a host:port string. + database (str): Database name + user (str): Database user + password (str): Password for specified user + update_id (str): An identifier for this data set + port (int): Postgres server port. + + """ + if ':' in host: + self.host, self.port = host.split(':') + else: + self.host = host + self.port = port + self.database = database + self.user = user + self.password = password + self.table = table + self.update_id = update_id + + def touch(self, connection=None): + """ + Mark this update as complete. + + Important: If the marker table doesn't exist, the connection transaction will be aborted + and the connection reset. + Then the marker table will be created. 
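        An illustrative sketch of typical use (host, credentials and update_id below are hypothetical; a reachable Postgres instance and psycopg2 are assumed):

        .. code-block:: python

            >>> target = PostgresTarget(host='db.example.com:5432', database='mydb',
            ...                         user='luigi', password='secret',
            ...                         table='my_table', update_id='my_task_2017_01_01')
            >>> target.touch()    # creates the marker table if needed, then records the update
            >>> target.exists()
            True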
+ """ + self.create_marker_table() + + if connection is None: + # TODO: test this + connection = self.connect() + connection.autocommit = True # if connection created here, we commit it here + + if self.use_db_timestamps: + connection.cursor().execute( + """INSERT INTO {marker_table} (update_id, target_table) + VALUES (%s, %s) + """.format(marker_table=self.marker_table), + (self.update_id, self.table)) + else: + connection.cursor().execute( + """INSERT INTO {marker_table} (update_id, target_table, inserted) + VALUES (%s, %s, %s); + """.format(marker_table=self.marker_table), + (self.update_id, self.table, + datetime.datetime.now())) + + # make sure update is properly marked + assert self.exists(connection) + + def exists(self, connection=None): + if connection is None: + connection = self.connect() + connection.autocommit = True + cursor = connection.cursor() + try: + cursor.execute("""SELECT 1 FROM {marker_table} + WHERE update_id = %s + LIMIT 1""".format(marker_table=self.marker_table), + (self.update_id,) + ) + row = cursor.fetchone() + except psycopg2.ProgrammingError as e: + if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE: + row = None + else: + raise + return row is not None + + def connect(self): + """ + Get a psycopg2 connection object to the database where the table is. + """ + connection = psycopg2.connect( + host=self.host, + port=self.port, + database=self.database, + user=self.user, + password=self.password) + connection.set_client_encoding('utf-8') + return connection + + def create_marker_table(self): + """ + Create marker table if it doesn't exist. + + Using a separate connection since the transaction might have to be reset. + """ + connection = self.connect() + connection.autocommit = True + cursor = connection.cursor() + if self.use_db_timestamps: + sql = """ CREATE TABLE {marker_table} ( + update_id TEXT PRIMARY KEY, + target_table TEXT, + inserted TIMESTAMP DEFAULT NOW()) + """.format(marker_table=self.marker_table) + else: + sql = """ CREATE TABLE {marker_table} ( + update_id TEXT PRIMARY KEY, + target_table TEXT, + inserted TIMESTAMP); + """.format(marker_table=self.marker_table) + try: + cursor.execute(sql) + except psycopg2.ProgrammingError as e: + if e.pgcode == psycopg2.errorcodes.DUPLICATE_TABLE: + pass + else: + raise + connection.close() + + def open(self, mode): + raise NotImplementedError("Cannot open() PostgresTarget") + + +class CopyToTable(rdbms.CopyToTable): + """ + Template task for inserting a data set into Postgres + + Usage: + Subclass and override the required `host`, `database`, `user`, + `password`, `table` and `columns` attributes. + + To customize how to access data from an input task, override the `rows` method + with a generator that yields each row as a tuple with fields ordered according to `columns`. + """ + + def rows(self): + """ + Return/yield tuples or lists corresponding to each row to be inserted. + """ + with self.input().open('r') as fobj: + for line in fobj: + yield line.strip('\n').split('\t') + + def map_column(self, value): + """ + Applied to each column of every row returned by `rows`. + + Default behaviour is to escape special characters and identify any self.null_values. + """ + if value in self.null_values: + return r'\\N' + else: + return default_escape(six.text_type(value)) + +# everything below will rarely have to be overridden + + def output(self): + """ + Returns a PostgresTarget representing the inserted dataset. + + Normally you don't override this. 
+ """ + return PostgresTarget( + host=self.host, + database=self.database, + user=self.user, + password=self.password, + table=self.table, + update_id=self.update_id + ) + + def copy(self, cursor, file): + if isinstance(self.columns[0], six.string_types): + column_names = self.columns + elif len(self.columns[0]) == 2: + column_names = [c[0] for c in self.columns] + else: + raise Exception('columns must consist of column strings or (column string, type string) tuples (was %r ...)' % (self.columns[0],)) + cursor.copy_from(file, self.table, null=r'\\N', sep=self.column_separator, columns=column_names) + + def run(self): + """ + Inserts data generated by rows() into target table. + + If the target table doesn't exist, self.create_table will be called to attempt to create the table. + + Normally you don't want to override this. + """ + if not (self.table and self.columns): + raise Exception("table and columns need to be specified") + + connection = self.output().connect() + # transform all data generated by rows() using map_column and write data + # to a temporary file for import using postgres COPY + tmp_dir = luigi.configuration.get_config().get('postgres', 'local-tmp-dir', None) + tmp_file = tempfile.TemporaryFile(dir=tmp_dir) + n = 0 + for row in self.rows(): + n += 1 + if n % 100000 == 0: + logger.info("Wrote %d lines", n) + rowstr = self.column_separator.join(self.map_column(val) for val in row) + rowstr += "\n" + tmp_file.write(rowstr.encode('utf-8')) + + logger.info("Done writing, importing at %s", datetime.datetime.now()) + tmp_file.seek(0) + + # attempt to copy the data into postgres + # if it fails because the target table doesn't exist + # try to create it by running self.create_table + for attempt in range(2): + try: + cursor = connection.cursor() + self.init_copy(connection) + self.copy(cursor, tmp_file) + self.post_copy(connection) + except psycopg2.ProgrammingError as e: + if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE and attempt == 0: + # if first attempt fails with "relation not found", try creating table + logger.info("Creating table %s", self.table) + connection.reset() + self.create_table(connection) + else: + raise + else: + break + + # mark as complete in same transaction + self.output().touch(connection) + + # commit and clean up + connection.commit() + connection.close() + tmp_file.close() + + +class PostgresQuery(rdbms.Query): + """ + Template task for querying a Postgres compatible database + + Usage: + Subclass and override the required `host`, `database`, `user`, `password`, `table`, and `query` attributes. + + Override the `run` method if your use case requires some action with the query result. + + Task instances require a dynamic `update_id`, e.g. via parameter(s), otherwise the query will only execute once + + To customize the query signature as recorded in the database marker table, override the `update_id` property. + """ + + def run(self): + connection = self.output().connect() + cursor = connection.cursor() + sql = self.query + + logger.info('Executing query from task: {name}'.format(name=self.__class__)) + cursor.execute(sql) + + # Update marker table + self.output().touch(connection) + + # commit and close connection + connection.commit() + connection.close() + + def output(self): + """ + Returns a PostgresTarget representing the executed query. + + Normally you don't override this. 
+ """ + return PostgresTarget( + host=self.host, + database=self.database, + user=self.user, + password=self.password, + table=self.table, + update_id=self.update_id + ) diff --git a/luigi/contrib/s3.py b/luigi/contrib/s3.py new file mode 100644 index 0000000000..d01995fb79 --- /dev/null +++ b/luigi/contrib/s3.py @@ -0,0 +1,794 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Implementation of Simple Storage Service support. +:py:class:`S3Target` is a subclass of the Target class to support S3 file +system operations. The `boto` library is required to use S3 targets. +""" + +from __future__ import division + +import datetime +import itertools +import logging +import os +import os.path + +import time +from multiprocessing.pool import ThreadPool + +try: + from urlparse import urlsplit +except ImportError: + from urllib.parse import urlsplit +import warnings + +try: + from ConfigParser import NoSectionError +except ImportError: + from configparser import NoSectionError + +from luigi import six +from luigi.six.moves import range + +from luigi import configuration +from luigi.format import get_default_format +from luigi.parameter import Parameter +from luigi.target import FileAlreadyExists, FileSystem, FileSystemException, FileSystemTarget, AtomicLocalFile, MissingParentDirectory +from luigi.task import ExternalTask + +logger = logging.getLogger('luigi-interface') + + +# two different ways of marking a directory +# with a suffix in S3 +S3_DIRECTORY_MARKER_SUFFIX_0 = '_$folder$' +S3_DIRECTORY_MARKER_SUFFIX_1 = '/' + + +class InvalidDeleteException(FileSystemException): + pass + + +class FileNotFoundException(FileSystemException): + pass + + +class S3Client(FileSystem): + """ + boto-powered S3 client. 
+ """ + + def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, + **kwargs): + # only import boto when needed to allow top-lvl s3 module import + import boto + import boto.s3.connection + from boto.s3.key import Key + + options = self._get_s3_config() + options.update(kwargs) + # Removing key args would break backwards compability + role_arn = options.get('aws_role_arn') + role_session_name = options.get('aws_role_session_name') + + aws_session_token = None + + if role_arn and role_session_name: + from boto import sts + + sts_client = sts.STSConnection() + assumed_role = sts_client.assume_role(role_arn, role_session_name) + aws_secret_access_key = assumed_role.credentials.secret_key + aws_access_key_id = assumed_role.credentials.access_key + aws_session_token = assumed_role.credentials.session_token + + else: + if not aws_access_key_id: + aws_access_key_id = options.get('aws_access_key_id') + + if not aws_secret_access_key: + aws_secret_access_key = options.get('aws_secret_access_key') + + for key in ['aws_access_key_id', 'aws_secret_access_key', 'aws_role_session_name', 'aws_role_arn']: + if key in options: + options.pop(key) + + self.s3 = boto.s3.connection.S3Connection(aws_access_key_id, + aws_secret_access_key, + security_token=aws_session_token, + **options) + self.Key = Key + + def exists(self, path): + """ + Does provided path exist on S3? + """ + (bucket, key) = self._path_to_bucket_and_key(path) + + # grab and validate the bucket + s3_bucket = self.s3.get_bucket(bucket, validate=True) + + # root always exists + if self._is_root(key): + return True + + # file + s3_key = s3_bucket.get_key(key) + if s3_key: + return True + + if self.isdir(path): + return True + + logger.debug('Path %s does not exist', path) + return False + + def remove(self, path, recursive=True): + """ + Remove a file or directory from S3. + """ + if not self.exists(path): + logger.debug('Could not delete %s; path does not exist', path) + return False + + (bucket, key) = self._path_to_bucket_and_key(path) + + # root + if self._is_root(key): + raise InvalidDeleteException('Cannot delete root of bucket at path %s' % path) + + # grab and validate the bucket + s3_bucket = self.s3.get_bucket(bucket, validate=True) + + # file + s3_key = s3_bucket.get_key(key) + if s3_key: + s3_bucket.delete_key(s3_key) + logger.debug('Deleting %s from bucket %s', key, bucket) + return True + + if self.isdir(path) and not recursive: + raise InvalidDeleteException('Path %s is a directory. Must use recursive delete' % path) + + delete_key_list = [ + k for k in s3_bucket.list(self._add_path_delimiter(key))] + + # delete the directory marker file if it exists + s3_dir_with_suffix_key = s3_bucket.get_key(key + S3_DIRECTORY_MARKER_SUFFIX_0) + if s3_dir_with_suffix_key: + delete_key_list.append(s3_dir_with_suffix_key) + + if len(delete_key_list) > 0: + for k in delete_key_list: + logger.debug('Deleting %s from bucket %s', k, bucket) + s3_bucket.delete_keys(delete_key_list) + return True + + return False + + def get_key(self, path): + """ + Returns just the key from the path. + + An s3 path is composed of a bucket and a key. + + Suppose we have a path `s3://my_bucket/some/files/my_file`. The key is `some/files/my_file`. + """ + (bucket, key) = self._path_to_bucket_and_key(path) + + s3_bucket = self.s3.get_bucket(bucket, validate=True) + + return s3_bucket.get_key(key) + + def put(self, local_path, destination_s3_path, **kwargs): + """ + Put an object stored locally to an S3 path. 
+ + :param kwargs: Keyword arguments are passed to the boto function `set_contents_from_filename` + """ + (bucket, key) = self._path_to_bucket_and_key(destination_s3_path) + + # grab and validate the bucket + s3_bucket = self.s3.get_bucket(bucket, validate=True) + + # put the file + s3_key = self.Key(s3_bucket) + s3_key.key = key + s3_key.set_contents_from_filename(local_path, **kwargs) + + def put_string(self, content, destination_s3_path, **kwargs): + """ + Put a string to an S3 path. + + :param kwargs: Keyword arguments are passed to the boto function `set_contents_from_string` + """ + (bucket, key) = self._path_to_bucket_and_key(destination_s3_path) + # grab and validate the bucket + s3_bucket = self.s3.get_bucket(bucket, validate=True) + + # put the content + s3_key = self.Key(s3_bucket) + s3_key.key = key + s3_key.set_contents_from_string(content, **kwargs) + + def put_multipart(self, local_path, destination_s3_path, part_size=67108864, **kwargs): + """ + Put an object stored locally to an S3 path + using S3 multi-part upload (for files > 5GB). + + :param local_path: Path to source local file + :param destination_s3_path: URL for target S3 location + :param part_size: Part size in bytes. Default: 67108864 (64MB), must be >= 5MB and <= 5 GB. + :param kwargs: Keyword arguments are passed to the boto function `initiate_multipart_upload` + """ + # calculate number of parts to upload + # based on the size of the file + source_size = os.stat(local_path).st_size + + if source_size <= part_size: + # fallback to standard, non-multipart strategy + return self.put(local_path, destination_s3_path, **kwargs) + + (bucket, key) = self._path_to_bucket_and_key(destination_s3_path) + + # grab and validate the bucket + s3_bucket = self.s3.get_bucket(bucket, validate=True) + + # calculate the number of parts (int division). + # use modulo to avoid float precision issues + # for exactly-sized fits + num_parts = (source_size + part_size - 1) // part_size + + mp = None + try: + mp = s3_bucket.initiate_multipart_upload(key, **kwargs) + + for i in range(num_parts): + # upload a part at a time to S3 + offset = part_size * i + bytes = min(part_size, source_size - offset) + with open(local_path, 'rb') as fp: + part_num = i + 1 + logger.info('Uploading part %s/%s to %s', part_num, num_parts, destination_s3_path) + fp.seek(offset) + mp.upload_part_from_file(fp, part_num=part_num, size=bytes) + + # finish the upload, making the file available in S3 + mp.complete_upload() + except BaseException: + if mp: + logger.info('Canceling multipart s3 upload for %s', destination_s3_path) + # cancel the upload so we don't get charged for + # storage consumed by uploaded parts + mp.cancel_upload() + raise + + def get(self, s3_path, destination_local_path): + """ + Get an object stored in S3 and write it to a local path. + """ + (bucket, key) = self._path_to_bucket_and_key(s3_path) + + # grab and validate the bucket + s3_bucket = self.s3.get_bucket(bucket, validate=True) + + # download the file + s3_key = self.Key(s3_bucket) + s3_key.key = key + s3_key.get_contents_to_filename(destination_local_path) + + def get_as_string(self, s3_path): + """ + Get the contents of an object stored in S3 as a string. 
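        For example (hypothetical path; with boto the result is raw bytes on Python 3):

        .. code-block:: python

            >>> contents = S3Client().get_as_string('s3://my-bucket/logs/2017-01-01.log')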
+ """ + (bucket, key) = self._path_to_bucket_and_key(s3_path) + + # grab and validate the bucket + s3_bucket = self.s3.get_bucket(bucket, validate=True) + + # get the content + s3_key = self.Key(s3_bucket) + s3_key.key = key + contents = s3_key.get_contents_as_string() + + return contents + + def copy(self, source_path, destination_path, threads=100, start_time=None, end_time=None, part_size=67108864, **kwargs): + """ + Copy object(s) from one S3 location to another. Works for individual keys or entire directories. + + When files are larger than `part_size`, multipart uploading will be used. + + :param source_path: The `s3://` path of the directory or key to copy from + :param destination_path: The `s3://` path of the directory or key to copy to + :param threads: Optional argument to define the number of threads to use when copying (min: 3 threads) + :param start_time: Optional argument to copy files with modified dates after start_time + :param end_time: Optional argument to copy files with modified dates before end_time + :param part_size: Part size in bytes. Default: 67108864 (64MB), must be >= 5MB and <= 5 GB. + :param kwargs: Keyword arguments are passed to the boto function `copy_key` + + :returns tuple (number_of_files_copied, total_size_copied_in_bytes) + """ + start = datetime.datetime.now() + + (src_bucket, src_key) = self._path_to_bucket_and_key(source_path) + (dst_bucket, dst_key) = self._path_to_bucket_and_key(destination_path) + + # As the S3 copy command is completely server side, there is no issue with issuing a lot of threads + # to issue a single API call per copy, however, this may in theory cause issues on systems with low ulimits for + # number of threads when copying really large files (e.g. with a ~100GB file this will open ~1500 + # threads), or large directories. Around 100 threads seems to work well. + + threads = 3 if threads < 3 else threads # don't allow threads to be less than 3 + total_keys = 0 + + copy_pool = ThreadPool(processes=threads) + + if self.isdir(source_path): + # The management pool is to ensure that there's no deadlock between the s3 copying threads, and the + # multipart_copy threads that monitors each group of s3 copy threads and returns a success once the entire file + # is copied. Without this, we could potentially fill up the pool with threads waiting to check if the s3 copies + # have completed, leaving no available threads to actually perform any copying. 
+ copy_jobs = [] + management_pool = ThreadPool(processes=threads) + + (bucket, key) = self._path_to_bucket_and_key(source_path) + key_path = self._add_path_delimiter(key) + key_path_len = len(key_path) + + total_size_bytes = 0 + src_prefix = self._add_path_delimiter(src_key) + dst_prefix = self._add_path_delimiter(dst_key) + for item in self.list(source_path, start_time=start_time, end_time=end_time, return_key=True): + path = item.key[key_path_len:] + # prevents copy attempt of empty key in folder + if path != '' and path != '/': + total_keys += 1 + total_size_bytes += item.size + job = management_pool.apply_async(self.__copy_multipart, + args=(copy_pool, + src_bucket, src_prefix + path, + dst_bucket, dst_prefix + path, + part_size), + kwds=kwargs) + copy_jobs.append(job) + + # Wait for the pools to finish scheduling all the copies + management_pool.close() + management_pool.join() + copy_pool.close() + copy_pool.join() + + # Raise any errors encountered in any of the copy processes + for result in copy_jobs: + result.get() + + end = datetime.datetime.now() + duration = end - start + logger.info('%s : Complete : %s total keys copied in %s' % + (datetime.datetime.now(), total_keys, duration)) + + return total_keys, total_size_bytes + + # If the file isn't a directory just perform a simple copy + else: + self.__copy_multipart(copy_pool, src_bucket, src_key, dst_bucket, dst_key, part_size, **kwargs) + # Close the pool + copy_pool.close() + copy_pool.join() + + def __copy_multipart(self, pool, src_bucket, src_key, dst_bucket, dst_key, part_size, **kwargs): + """ + Copy a single S3 object to another S3 object, falling back to multipart copy where necessary + + NOTE: This is a private method and should only be called from within the `luigi.s3.copy` method + + :param pool: The threadpool to put the s3 copy processes onto + :param src_bucket: source bucket name + :param src_key: source key name + :param dst_bucket: destination bucket name + :param dst_key: destination key name + :param key_size: size of the key to copy in bytes + :param part_size: Part size in bytes. Must be >= 5MB and <= 5 GB. + :param kwargs: Keyword arguments are passed to the boto function `initiate_multipart_upload` + """ + + source_bucket = self.s3.get_bucket(src_bucket, validate=True) + dest_bucket = self.s3.get_bucket(dst_bucket, validate=True) + + key_size = source_bucket.lookup(src_key).size + + # We can't do a multipart copy on an empty Key, so handle this specially. 
+ # Also, don't bother using the multipart machinery if we're only dealing with a small non-multipart file + if key_size == 0 or key_size <= part_size: + result = pool.apply_async(dest_bucket.copy_key, args=(dst_key, src_bucket, src_key), kwds=kwargs) + # Bubble up any errors we may encounter + return result.get() + + mp = None + + try: + mp = dest_bucket.initiate_multipart_upload(dst_key, **kwargs) + cur_pos = 0 + + # Store the results from the apply_async in a list so we can check for failures + results = [] + + # Calculate the number of chunks the file will be + num_parts = (key_size + part_size - 1) // part_size + + for i in range(num_parts): + # Issue an S3 copy request, one part at a time, from one S3 object to another + part_start = cur_pos + cur_pos += part_size + part_end = min(cur_pos - 1, key_size - 1) + part_num = i + 1 + results.append(pool.apply_async(mp.copy_part_from_key, args=(src_bucket, src_key, part_num, part_start, part_end))) + logger.info('Requesting copy of %s/%s to %s/%s', part_num, num_parts, dst_bucket, dst_key) + + logger.info('Waiting for multipart copy of %s/%s to finish', dst_bucket, dst_key) + + # This will raise any exceptions in any of the copy threads + for result in results: + result.get() + + # finish the copy, making the file available in S3 + mp.complete_upload() + return mp.key_name + + except: + logger.info('Error during multipart s3 copy for %s/%s to %s/%s...', src_bucket, src_key, dst_bucket, dst_key) + # cancel the copy so we don't get charged for storage consumed by copied parts + if mp: + mp.cancel_upload() + raise + + def move(self, source_path, destination_path, **kwargs): + """ + Rename/move an object from one S3 location to another. + + :param kwargs: Keyword arguments are passed to the boto function `copy_key` + """ + self.copy(source_path, destination_path, **kwargs) + self.remove(source_path) + + def listdir(self, path, start_time=None, end_time=None, return_key=False): + """ + Get an iterable with S3 folder contents. + Iterable contains paths relative to queried path. 
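        For example (hypothetical bucket; pass return_key=True to get boto Key objects instead of paths):

        .. code-block:: python

            >>> client = S3Client()
            >>> paths = list(client.listdir('s3://my-bucket/logs/'))
            >>> keys = list(client.listdir('s3://my-bucket/logs/', return_key=True))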
+ + :param start_time: Optional argument to list files with modified dates after start_time + :param end_time: Optional argument to list files with modified dates before end_time + :param return_key: Optional argument, when set to True will return a boto.s3.key.Key (instead of the filename) + """ + (bucket, key) = self._path_to_bucket_and_key(path) + + # grab and validate the bucket + s3_bucket = self.s3.get_bucket(bucket, validate=True) + + key_path = self._add_path_delimiter(key) + key_path_len = len(key_path) + for item in s3_bucket.list(prefix=key_path): + last_modified_date = time.strptime(item.last_modified, "%Y-%m-%dT%H:%M:%S.%fZ") + if ( + (not start_time and not end_time) or # neither are defined, list all + (start_time and not end_time and start_time < last_modified_date) or # start defined, after start + (not start_time and end_time and last_modified_date < end_time) or # end defined, prior to end + (start_time and end_time and start_time < last_modified_date < end_time) # both defined, between + ): + if return_key: + yield item + else: + yield self._add_path_delimiter(path) + item.key[key_path_len:] + + def list(self, path, start_time=None, end_time=None, return_key=False): # backwards compat + key_path_len = len(self._add_path_delimiter(path)) + for item in self.listdir(path, start_time=start_time, end_time=end_time, return_key=return_key): + if return_key: + yield item + else: + yield item[key_path_len:] + + def isdir(self, path): + """ + Is the parameter S3 path a directory? + """ + (bucket, key) = self._path_to_bucket_and_key(path) + + # grab and validate the bucket + s3_bucket = self.s3.get_bucket(bucket, validate=True) + + # root is a directory + if self._is_root(key): + return True + + for suffix in (S3_DIRECTORY_MARKER_SUFFIX_0, + S3_DIRECTORY_MARKER_SUFFIX_1): + s3_dir_with_suffix_key = s3_bucket.get_key(key + suffix) + if s3_dir_with_suffix_key: + return True + + # files with this prefix + key_path = self._add_path_delimiter(key) + s3_bucket_list_result = list(itertools.islice(s3_bucket.list(prefix=key_path), 1)) + if s3_bucket_list_result: + return True + + return False + + is_dir = isdir # compatibility with old version. + + def mkdir(self, path, parents=True, raise_if_exists=False): + if raise_if_exists and self.isdir(path): + raise FileAlreadyExists() + + _, key = self._path_to_bucket_and_key(path) + if self._is_root(key): + return # isdir raises if the bucket doesn't exist; nothing to do here. + + key = self._add_path_delimiter(key) + + if not parents and not self.isdir(os.path.dirname(key)): + raise MissingParentDirectory() + + return self.put_string("", self._add_path_delimiter(path)) + + def _get_s3_config(self, key=None): + try: + config = dict(configuration.get_config().items('s3')) + except NoSectionError: + return {} + # So what ports etc can be read without us having to specify all dtypes + for k, v in six.iteritems(config): + try: + config[k] = int(v) + except ValueError: + pass + if key: + return config.get(key) + return config + + def _path_to_bucket_and_key(self, path): + (scheme, netloc, path, query, fragment) = urlsplit(path) + path_without_initial_slash = path[1:] + return netloc, path_without_initial_slash + + def _is_root(self, key): + return (len(key) == 0) or (key == '/') + + def _add_path_delimiter(self, key): + return key if key[-1:] == '/' or key == '' else key + '/' + + +class AtomicS3File(AtomicLocalFile): + """ + An S3 file that writes to a temp file and puts to S3 on close. 
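    In practice this class is mostly used indirectly, through an S3Target opened for writing (hypothetical path; the upload to S3 happens when the file is closed):

    .. code-block:: python

        >>> target = S3Target('s3://my-bucket/output/report.tsv')
        >>> with target.open('w') as out_file:
        ...     out_file.write('hello\tworld\n')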
+ + :param kwargs: Keyword arguments are passed to the boto function `initiate_multipart_upload` + """ + + def __init__(self, path, s3_client, **kwargs): + self.s3_client = s3_client + super(AtomicS3File, self).__init__(path) + self.s3_options = kwargs + + def move_to_final_destination(self): + self.s3_client.put_multipart(self.tmp_path, self.path, **self.s3_options) + + +class ReadableS3File(object): + def __init__(self, s3_key): + self.s3_key = s3_key + self.buffer = [] + self.closed = False + self.finished = False + + def read(self, size=0): + f = self.s3_key.read(size=size) + + # boto will loop on the key forever and it's not what is expected by + # the python io interface + # boto/boto#2805 + if f == b'': + self.finished = True + if self.finished: + return b'' + + return f + + def close(self): + self.s3_key.close() + self.closed = True + + def __del__(self): + self.close() + + def __exit__(self, exc_type, exc, traceback): + self.close() + + def __enter__(self): + return self + + def _add_to_buffer(self, line): + self.buffer.append(line) + + def _flush_buffer(self): + output = b''.join(self.buffer) + self.buffer = [] + return output + + def readable(self): + return True + + def writable(self): + return False + + def seekable(self): + return False + + def __iter__(self): + key_iter = self.s3_key.__iter__() + + has_next = True + while has_next: + try: + # grab the next chunk + chunk = next(key_iter) + + # split on newlines, preserving the newline + for line in chunk.splitlines(True): + + if not line.endswith(os.linesep): + # no newline, so store in buffer + self._add_to_buffer(line) + else: + # newline found, send it out + if self.buffer: + self._add_to_buffer(line) + yield self._flush_buffer() + else: + yield line + except StopIteration: + # send out anything we have left in the buffer + output = self._flush_buffer() + if output: + yield output + has_next = False + self.close() + + +class S3Target(FileSystemTarget): + """ + Target S3 file object + + :param kwargs: Keyword arguments are passed to the boto function `initiate_multipart_upload` + """ + + fs = None + + def __init__(self, path, format=None, client=None, **kwargs): + super(S3Target, self).__init__(path) + if format is None: + format = get_default_format() + + self.path = path + self.format = format + self.fs = client or S3Client() + self.s3_options = kwargs + + def open(self, mode='r'): + if mode not in ('r', 'w'): + raise ValueError("Unsupported open mode '%s'" % mode) + + if mode == 'r': + s3_key = self.fs.get_key(self.path) + if not s3_key: + raise FileNotFoundException("Could not find file at %s" % self.path) + + fileobj = ReadableS3File(s3_key) + return self.format.pipe_reader(fileobj) + else: + return self.format.pipe_writer(AtomicS3File(self.path, self.fs, **self.s3_options)) + + +class S3FlagTarget(S3Target): + """ + Defines a target directory with a flag-file (defaults to `_SUCCESS`) used + to signify job success. + + This checks for two things: + + * the path exists (just like the S3Target) + * the _SUCCESS file exists within the directory. + + Because Hadoop outputs into a directory and not a single file, + the path is assumed to be a directory. + + This is meant to be a handy alternative to AtomicS3File. + + The AtomicFile approach can be burdensome for S3 since there are no directories, per se. + + If we have 1,000,000 output files, then we have to rename 1,000,000 objects. + """ + + fs = None + + def __init__(self, path, format=None, client=None, flag='_SUCCESS'): + """ + Initializes a S3FlagTarget. 
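        A small sketch (hypothetical output directory; note the required trailing slash):

        .. code-block:: python

            >>> target = S3FlagTarget('s3://my-bucket/output/2017-01-01/')
            >>> target.exists()   # checks for 's3://my-bucket/output/2017-01-01/_SUCCESS'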
+ + :param path: the directory where the files are stored. + :type path: str + :param client: + :type client: + :param flag: + :type flag: str + """ + if format is None: + format = get_default_format() + + if path[-1] != "/": + raise ValueError("S3FlagTarget requires the path to be to a " + "directory. It must end with a slash ( / ).") + super(S3FlagTarget, self).__init__(path, format, client) + self.flag = flag + + def exists(self): + hadoopSemaphore = self.path + self.flag + return self.fs.exists(hadoopSemaphore) + + +class S3EmrTarget(S3FlagTarget): + """ + Deprecated. Use :py:class:`S3FlagTarget` + """ + + def __init__(self, *args, **kwargs): + warnings.warn("S3EmrTarget is deprecated. Please use S3FlagTarget") + super(S3EmrTarget, self).__init__(*args, **kwargs) + + +class S3PathTask(ExternalTask): + """ + A external task that to require existence of a path in S3. + """ + path = Parameter() + + def output(self): + return S3Target(self.path) + + +class S3EmrTask(ExternalTask): + """ + An external task that requires the existence of EMR output in S3. + """ + path = Parameter() + + def output(self): + return S3EmrTarget(self.path) + + +class S3FlagTask(ExternalTask): + """ + An external task that requires the existence of EMR output in S3. + """ + path = Parameter() + flag = Parameter(default=None) + + def output(self): + return S3FlagTarget(self.path, flag=self.flag) diff --git a/luigi/postgres.py b/luigi/postgres.py index a00c6d574a..561d27aac7 100644 --- a/luigi/postgres.py +++ b/luigi/postgres.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# Copyright 2012-2015 Spotify AB +# Copyright 2017 VNG Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,376 +15,12 @@ # limitations under the License. # """ -Implements a subclass of :py:class:`~luigi.target.Target` that writes data to Postgres. -Also provides a helper task to copy data into a Postgres table. +luigi.postgres has moved to :py:mod:`luigi.contrib.postgres` """ +# Delete this file any time after 24 march 2017 -import datetime -import logging -import re -import tempfile +import warnings -from luigi import six - -import luigi -from luigi.contrib import rdbms - -logger = logging.getLogger('luigi-interface') - -try: - import psycopg2 - import psycopg2.errorcodes - import psycopg2.extensions -except ImportError: - logger.warning("Loading postgres module without psycopg2 installed. Will crash at runtime if postgres functionality is used.") - - -class MultiReplacer(object): - """ - Object for one-pass replace of multiple words - - Substituted parts will not be matched against other replace patterns, as opposed to when using multipass replace. - The order of the items in the replace_pairs input will dictate replacement precedence. - - Constructor arguments: - replace_pairs -- list of 2-tuples which hold strings to be replaced and replace string - - Usage: - - .. code-block:: python - - >>> replace_pairs = [("a", "b"), ("b", "c")] - >>> MultiReplacer(replace_pairs)("abcd") - 'bccd' - >>> replace_pairs = [("ab", "x"), ("a", "x")] - >>> MultiReplacer(replace_pairs)("ab") - 'x' - >>> replace_pairs.reverse() - >>> MultiReplacer(replace_pairs)("ab") - 'xb' - """ -# TODO: move to misc/util module - - def __init__(self, replace_pairs): - """ - Initializes a MultiReplacer instance. - - :param replace_pairs: list of 2-tuples which hold strings to be replaced and replace string. 
- :type replace_pairs: tuple - """ - replace_list = list(replace_pairs) # make a copy in case input is iterable - self._replace_dict = dict(replace_list) - pattern = '|'.join(re.escape(x) for x, y in replace_list) - self._search_re = re.compile(pattern) - - def _replacer(self, match_object): - # this method is used as the replace function in the re.sub below - return self._replace_dict[match_object.group()] - - def __call__(self, search_string): - # using function replacing for a per-result replace - return self._search_re.sub(self._replacer, search_string) - - -# these are the escape sequences recognized by postgres COPY -# according to http://www.postgresql.org/docs/8.1/static/sql-copy.html -default_escape = MultiReplacer([('\\', '\\\\'), - ('\t', '\\t'), - ('\n', '\\n'), - ('\r', '\\r'), - ('\v', '\\v'), - ('\b', '\\b'), - ('\f', '\\f') - ]) - - -class PostgresTarget(luigi.Target): - """ - Target for a resource in Postgres. - - This will rarely have to be directly instantiated by the user. - """ - marker_table = luigi.configuration.get_config().get('postgres', 'marker-table', 'table_updates') - - # Use DB side timestamps or client side timestamps in the marker_table - use_db_timestamps = True - - def __init__( - self, host, database, user, password, table, update_id, port=None - ): - """ - Args: - host (str): Postgres server address. Possibly a host:port string. - database (str): Database name - user (str): Database user - password (str): Password for specified user - update_id (str): An identifier for this data set - port (int): Postgres server port. - - """ - if ':' in host: - self.host, self.port = host.split(':') - else: - self.host = host - self.port = port - self.database = database - self.user = user - self.password = password - self.table = table - self.update_id = update_id - - def touch(self, connection=None): - """ - Mark this update as complete. - - Important: If the marker table doesn't exist, the connection transaction will be aborted - and the connection reset. - Then the marker table will be created. - """ - self.create_marker_table() - - if connection is None: - # TODO: test this - connection = self.connect() - connection.autocommit = True # if connection created here, we commit it here - - if self.use_db_timestamps: - connection.cursor().execute( - """INSERT INTO {marker_table} (update_id, target_table) - VALUES (%s, %s) - """.format(marker_table=self.marker_table), - (self.update_id, self.table)) - else: - connection.cursor().execute( - """INSERT INTO {marker_table} (update_id, target_table, inserted) - VALUES (%s, %s, %s); - """.format(marker_table=self.marker_table), - (self.update_id, self.table, - datetime.datetime.now())) - - # make sure update is properly marked - assert self.exists(connection) - - def exists(self, connection=None): - if connection is None: - connection = self.connect() - connection.autocommit = True - cursor = connection.cursor() - try: - cursor.execute("""SELECT 1 FROM {marker_table} - WHERE update_id = %s - LIMIT 1""".format(marker_table=self.marker_table), - (self.update_id,) - ) - row = cursor.fetchone() - except psycopg2.ProgrammingError as e: - if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE: - row = None - else: - raise - return row is not None - - def connect(self): - """ - Get a psycopg2 connection object to the database where the table is. 
- """ - connection = psycopg2.connect( - host=self.host, - port=self.port, - database=self.database, - user=self.user, - password=self.password) - connection.set_client_encoding('utf-8') - return connection - - def create_marker_table(self): - """ - Create marker table if it doesn't exist. - - Using a separate connection since the transaction might have to be reset. - """ - connection = self.connect() - connection.autocommit = True - cursor = connection.cursor() - if self.use_db_timestamps: - sql = """ CREATE TABLE {marker_table} ( - update_id TEXT PRIMARY KEY, - target_table TEXT, - inserted TIMESTAMP DEFAULT NOW()) - """.format(marker_table=self.marker_table) - else: - sql = """ CREATE TABLE {marker_table} ( - update_id TEXT PRIMARY KEY, - target_table TEXT, - inserted TIMESTAMP); - """.format(marker_table=self.marker_table) - try: - cursor.execute(sql) - except psycopg2.ProgrammingError as e: - if e.pgcode == psycopg2.errorcodes.DUPLICATE_TABLE: - pass - else: - raise - connection.close() - - def open(self, mode): - raise NotImplementedError("Cannot open() PostgresTarget") - - -class CopyToTable(rdbms.CopyToTable): - """ - Template task for inserting a data set into Postgres - - Usage: - Subclass and override the required `host`, `database`, `user`, - `password`, `table` and `columns` attributes. - - To customize how to access data from an input task, override the `rows` method - with a generator that yields each row as a tuple with fields ordered according to `columns`. - """ - - def rows(self): - """ - Return/yield tuples or lists corresponding to each row to be inserted. - """ - with self.input().open('r') as fobj: - for line in fobj: - yield line.strip('\n').split('\t') - - def map_column(self, value): - """ - Applied to each column of every row returned by `rows`. - - Default behaviour is to escape special characters and identify any self.null_values. - """ - if value in self.null_values: - return r'\\N' - else: - return default_escape(six.text_type(value)) - -# everything below will rarely have to be overridden - - def output(self): - """ - Returns a PostgresTarget representing the inserted dataset. - - Normally you don't override this. - """ - return PostgresTarget( - host=self.host, - database=self.database, - user=self.user, - password=self.password, - table=self.table, - update_id=self.update_id - ) - - def copy(self, cursor, file): - if isinstance(self.columns[0], six.string_types): - column_names = self.columns - elif len(self.columns[0]) == 2: - column_names = [c[0] for c in self.columns] - else: - raise Exception('columns must consist of column strings or (column string, type string) tuples (was %r ...)' % (self.columns[0],)) - cursor.copy_from(file, self.table, null=r'\\N', sep=self.column_separator, columns=column_names) - - def run(self): - """ - Inserts data generated by rows() into target table. - - If the target table doesn't exist, self.create_table will be called to attempt to create the table. - - Normally you don't want to override this. 
- """ - if not (self.table and self.columns): - raise Exception("table and columns need to be specified") - - connection = self.output().connect() - # transform all data generated by rows() using map_column and write data - # to a temporary file for import using postgres COPY - tmp_dir = luigi.configuration.get_config().get('postgres', 'local-tmp-dir', None) - tmp_file = tempfile.TemporaryFile(dir=tmp_dir) - n = 0 - for row in self.rows(): - n += 1 - if n % 100000 == 0: - logger.info("Wrote %d lines", n) - rowstr = self.column_separator.join(self.map_column(val) for val in row) - rowstr += "\n" - tmp_file.write(rowstr.encode('utf-8')) - - logger.info("Done writing, importing at %s", datetime.datetime.now()) - tmp_file.seek(0) - - # attempt to copy the data into postgres - # if it fails because the target table doesn't exist - # try to create it by running self.create_table - for attempt in range(2): - try: - cursor = connection.cursor() - self.init_copy(connection) - self.copy(cursor, tmp_file) - self.post_copy(connection) - except psycopg2.ProgrammingError as e: - if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE and attempt == 0: - # if first attempt fails with "relation not found", try creating table - logger.info("Creating table %s", self.table) - connection.reset() - self.create_table(connection) - else: - raise - else: - break - - # mark as complete in same transaction - self.output().touch(connection) - - # commit and clean up - connection.commit() - connection.close() - tmp_file.close() - - -class PostgresQuery(rdbms.Query): - """ - Template task for querying a Postgres compatible database - - Usage: - Subclass and override the required `host`, `database`, `user`, `password`, `table`, and `query` attributes. - - Override the `run` method if your use case requires some action with the query result. - - Task instances require a dynamic `update_id`, e.g. via parameter(s), otherwise the query will only execute once - - To customize the query signature as recorded in the database marker table, override the `update_id` property. - """ - - def run(self): - connection = self.output().connect() - cursor = connection.cursor() - sql = self.query - - logger.info('Executing query from task: {name}'.format(name=self.__class__)) - cursor.execute(sql) - - # Update marker table - self.output().touch(connection) - - # commit and close connection - connection.commit() - connection.close() - - def output(self): - """ - Returns a PostgresTarget representing the executed query. - - Normally you don't override this. - """ - return PostgresTarget( - host=self.host, - database=self.database, - user=self.user, - password=self.password, - table=self.table, - update_id=self.update_id - ) +from luigi.contrib.postgres import * # NOQA +warnings.warn("luigi.postgres module has been moved to luigi.contrib.postgres", + DeprecationWarning) diff --git a/luigi/s3.py b/luigi/s3.py index d01995fb79..4f710600db 100644 --- a/luigi/s3.py +++ b/luigi/s3.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# Copyright 2012-2015 Spotify AB +# Copyright 2017 VNG Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,780 +15,12 @@ # limitations under the License. # """ -Implementation of Simple Storage Service support. -:py:class:`S3Target` is a subclass of the Target class to support S3 file -system operations. The `boto` library is required to use S3 targets. 
+luigi.s3 has moved to :py:mod:`luigi.contrib.s3` """ +# Delete this file any time after 24 march 2017 -from __future__ import division - -import datetime -import itertools -import logging -import os -import os.path - -import time -from multiprocessing.pool import ThreadPool - -try: - from urlparse import urlsplit -except ImportError: - from urllib.parse import urlsplit import warnings -try: - from ConfigParser import NoSectionError -except ImportError: - from configparser import NoSectionError - -from luigi import six -from luigi.six.moves import range - -from luigi import configuration -from luigi.format import get_default_format -from luigi.parameter import Parameter -from luigi.target import FileAlreadyExists, FileSystem, FileSystemException, FileSystemTarget, AtomicLocalFile, MissingParentDirectory -from luigi.task import ExternalTask - -logger = logging.getLogger('luigi-interface') - - -# two different ways of marking a directory -# with a suffix in S3 -S3_DIRECTORY_MARKER_SUFFIX_0 = '_$folder$' -S3_DIRECTORY_MARKER_SUFFIX_1 = '/' - - -class InvalidDeleteException(FileSystemException): - pass - - -class FileNotFoundException(FileSystemException): - pass - - -class S3Client(FileSystem): - """ - boto-powered S3 client. - """ - - def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, - **kwargs): - # only import boto when needed to allow top-lvl s3 module import - import boto - import boto.s3.connection - from boto.s3.key import Key - - options = self._get_s3_config() - options.update(kwargs) - # Removing key args would break backwards compability - role_arn = options.get('aws_role_arn') - role_session_name = options.get('aws_role_session_name') - - aws_session_token = None - - if role_arn and role_session_name: - from boto import sts - - sts_client = sts.STSConnection() - assumed_role = sts_client.assume_role(role_arn, role_session_name) - aws_secret_access_key = assumed_role.credentials.secret_key - aws_access_key_id = assumed_role.credentials.access_key - aws_session_token = assumed_role.credentials.session_token - - else: - if not aws_access_key_id: - aws_access_key_id = options.get('aws_access_key_id') - - if not aws_secret_access_key: - aws_secret_access_key = options.get('aws_secret_access_key') - - for key in ['aws_access_key_id', 'aws_secret_access_key', 'aws_role_session_name', 'aws_role_arn']: - if key in options: - options.pop(key) - - self.s3 = boto.s3.connection.S3Connection(aws_access_key_id, - aws_secret_access_key, - security_token=aws_session_token, - **options) - self.Key = Key - - def exists(self, path): - """ - Does provided path exist on S3? - """ - (bucket, key) = self._path_to_bucket_and_key(path) - - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - - # root always exists - if self._is_root(key): - return True - - # file - s3_key = s3_bucket.get_key(key) - if s3_key: - return True - - if self.isdir(path): - return True - - logger.debug('Path %s does not exist', path) - return False - - def remove(self, path, recursive=True): - """ - Remove a file or directory from S3. 
- """ - if not self.exists(path): - logger.debug('Could not delete %s; path does not exist', path) - return False - - (bucket, key) = self._path_to_bucket_and_key(path) - - # root - if self._is_root(key): - raise InvalidDeleteException('Cannot delete root of bucket at path %s' % path) - - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - - # file - s3_key = s3_bucket.get_key(key) - if s3_key: - s3_bucket.delete_key(s3_key) - logger.debug('Deleting %s from bucket %s', key, bucket) - return True - - if self.isdir(path) and not recursive: - raise InvalidDeleteException('Path %s is a directory. Must use recursive delete' % path) - - delete_key_list = [ - k for k in s3_bucket.list(self._add_path_delimiter(key))] - - # delete the directory marker file if it exists - s3_dir_with_suffix_key = s3_bucket.get_key(key + S3_DIRECTORY_MARKER_SUFFIX_0) - if s3_dir_with_suffix_key: - delete_key_list.append(s3_dir_with_suffix_key) - - if len(delete_key_list) > 0: - for k in delete_key_list: - logger.debug('Deleting %s from bucket %s', k, bucket) - s3_bucket.delete_keys(delete_key_list) - return True - - return False - - def get_key(self, path): - """ - Returns just the key from the path. - - An s3 path is composed of a bucket and a key. - - Suppose we have a path `s3://my_bucket/some/files/my_file`. The key is `some/files/my_file`. - """ - (bucket, key) = self._path_to_bucket_and_key(path) - - s3_bucket = self.s3.get_bucket(bucket, validate=True) - - return s3_bucket.get_key(key) - - def put(self, local_path, destination_s3_path, **kwargs): - """ - Put an object stored locally to an S3 path. - - :param kwargs: Keyword arguments are passed to the boto function `set_contents_from_filename` - """ - (bucket, key) = self._path_to_bucket_and_key(destination_s3_path) - - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - - # put the file - s3_key = self.Key(s3_bucket) - s3_key.key = key - s3_key.set_contents_from_filename(local_path, **kwargs) - - def put_string(self, content, destination_s3_path, **kwargs): - """ - Put a string to an S3 path. - - :param kwargs: Keyword arguments are passed to the boto function `set_contents_from_string` - """ - (bucket, key) = self._path_to_bucket_and_key(destination_s3_path) - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - - # put the content - s3_key = self.Key(s3_bucket) - s3_key.key = key - s3_key.set_contents_from_string(content, **kwargs) - - def put_multipart(self, local_path, destination_s3_path, part_size=67108864, **kwargs): - """ - Put an object stored locally to an S3 path - using S3 multi-part upload (for files > 5GB). - - :param local_path: Path to source local file - :param destination_s3_path: URL for target S3 location - :param part_size: Part size in bytes. Default: 67108864 (64MB), must be >= 5MB and <= 5 GB. - :param kwargs: Keyword arguments are passed to the boto function `initiate_multipart_upload` - """ - # calculate number of parts to upload - # based on the size of the file - source_size = os.stat(local_path).st_size - - if source_size <= part_size: - # fallback to standard, non-multipart strategy - return self.put(local_path, destination_s3_path, **kwargs) - - (bucket, key) = self._path_to_bucket_and_key(destination_s3_path) - - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - - # calculate the number of parts (int division). 
- # use modulo to avoid float precision issues - # for exactly-sized fits - num_parts = (source_size + part_size - 1) // part_size - - mp = None - try: - mp = s3_bucket.initiate_multipart_upload(key, **kwargs) - - for i in range(num_parts): - # upload a part at a time to S3 - offset = part_size * i - bytes = min(part_size, source_size - offset) - with open(local_path, 'rb') as fp: - part_num = i + 1 - logger.info('Uploading part %s/%s to %s', part_num, num_parts, destination_s3_path) - fp.seek(offset) - mp.upload_part_from_file(fp, part_num=part_num, size=bytes) - - # finish the upload, making the file available in S3 - mp.complete_upload() - except BaseException: - if mp: - logger.info('Canceling multipart s3 upload for %s', destination_s3_path) - # cancel the upload so we don't get charged for - # storage consumed by uploaded parts - mp.cancel_upload() - raise - - def get(self, s3_path, destination_local_path): - """ - Get an object stored in S3 and write it to a local path. - """ - (bucket, key) = self._path_to_bucket_and_key(s3_path) - - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - - # download the file - s3_key = self.Key(s3_bucket) - s3_key.key = key - s3_key.get_contents_to_filename(destination_local_path) - - def get_as_string(self, s3_path): - """ - Get the contents of an object stored in S3 as a string. - """ - (bucket, key) = self._path_to_bucket_and_key(s3_path) - - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - - # get the content - s3_key = self.Key(s3_bucket) - s3_key.key = key - contents = s3_key.get_contents_as_string() - - return contents - - def copy(self, source_path, destination_path, threads=100, start_time=None, end_time=None, part_size=67108864, **kwargs): - """ - Copy object(s) from one S3 location to another. Works for individual keys or entire directories. - - When files are larger than `part_size`, multipart uploading will be used. - - :param source_path: The `s3://` path of the directory or key to copy from - :param destination_path: The `s3://` path of the directory or key to copy to - :param threads: Optional argument to define the number of threads to use when copying (min: 3 threads) - :param start_time: Optional argument to copy files with modified dates after start_time - :param end_time: Optional argument to copy files with modified dates before end_time - :param part_size: Part size in bytes. Default: 67108864 (64MB), must be >= 5MB and <= 5 GB. - :param kwargs: Keyword arguments are passed to the boto function `copy_key` - - :returns tuple (number_of_files_copied, total_size_copied_in_bytes) - """ - start = datetime.datetime.now() - - (src_bucket, src_key) = self._path_to_bucket_and_key(source_path) - (dst_bucket, dst_key) = self._path_to_bucket_and_key(destination_path) - - # As the S3 copy command is completely server side, there is no issue with issuing a lot of threads - # to issue a single API call per copy, however, this may in theory cause issues on systems with low ulimits for - # number of threads when copying really large files (e.g. with a ~100GB file this will open ~1500 - # threads), or large directories. Around 100 threads seems to work well. 
- - threads = 3 if threads < 3 else threads # don't allow threads to be less than 3 - total_keys = 0 - - copy_pool = ThreadPool(processes=threads) - - if self.isdir(source_path): - # The management pool is to ensure that there's no deadlock between the s3 copying threads, and the - # multipart_copy threads that monitors each group of s3 copy threads and returns a success once the entire file - # is copied. Without this, we could potentially fill up the pool with threads waiting to check if the s3 copies - # have completed, leaving no available threads to actually perform any copying. - copy_jobs = [] - management_pool = ThreadPool(processes=threads) - - (bucket, key) = self._path_to_bucket_and_key(source_path) - key_path = self._add_path_delimiter(key) - key_path_len = len(key_path) - - total_size_bytes = 0 - src_prefix = self._add_path_delimiter(src_key) - dst_prefix = self._add_path_delimiter(dst_key) - for item in self.list(source_path, start_time=start_time, end_time=end_time, return_key=True): - path = item.key[key_path_len:] - # prevents copy attempt of empty key in folder - if path != '' and path != '/': - total_keys += 1 - total_size_bytes += item.size - job = management_pool.apply_async(self.__copy_multipart, - args=(copy_pool, - src_bucket, src_prefix + path, - dst_bucket, dst_prefix + path, - part_size), - kwds=kwargs) - copy_jobs.append(job) - - # Wait for the pools to finish scheduling all the copies - management_pool.close() - management_pool.join() - copy_pool.close() - copy_pool.join() - - # Raise any errors encountered in any of the copy processes - for result in copy_jobs: - result.get() - - end = datetime.datetime.now() - duration = end - start - logger.info('%s : Complete : %s total keys copied in %s' % - (datetime.datetime.now(), total_keys, duration)) - - return total_keys, total_size_bytes - - # If the file isn't a directory just perform a simple copy - else: - self.__copy_multipart(copy_pool, src_bucket, src_key, dst_bucket, dst_key, part_size, **kwargs) - # Close the pool - copy_pool.close() - copy_pool.join() - - def __copy_multipart(self, pool, src_bucket, src_key, dst_bucket, dst_key, part_size, **kwargs): - """ - Copy a single S3 object to another S3 object, falling back to multipart copy where necessary - - NOTE: This is a private method and should only be called from within the `luigi.s3.copy` method - - :param pool: The threadpool to put the s3 copy processes onto - :param src_bucket: source bucket name - :param src_key: source key name - :param dst_bucket: destination bucket name - :param dst_key: destination key name - :param key_size: size of the key to copy in bytes - :param part_size: Part size in bytes. Must be >= 5MB and <= 5 GB. - :param kwargs: Keyword arguments are passed to the boto function `initiate_multipart_upload` - """ - - source_bucket = self.s3.get_bucket(src_bucket, validate=True) - dest_bucket = self.s3.get_bucket(dst_bucket, validate=True) - - key_size = source_bucket.lookup(src_key).size - - # We can't do a multipart copy on an empty Key, so handle this specially. 
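# --- Illustrative sketch, not part of this patch: a recursive directory copy
# --- using the threaded, server-side copy described above. The paths and the
# --- thread count are placeholders; for directory copies the method returns a
# --- (keys_copied, bytes_copied) tuple, per the docstring.
from luigi.contrib.s3 import S3Client

client = S3Client()  # credential lookup from config/env is an assumption here
n_keys, n_bytes = client.copy(
    's3://my-bucket/raw/2016-01-01/',
    's3://my-bucket/archive/2016-01-01/',
    threads=20,                   # values below 3 are bumped up to 3 internally
    part_size=64 * 1024 * 1024,   # keys larger than this use multipart copy
)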
- # Also, don't bother using the multipart machinery if we're only dealing with a small non-multipart file - if key_size == 0 or key_size <= part_size: - result = pool.apply_async(dest_bucket.copy_key, args=(dst_key, src_bucket, src_key), kwds=kwargs) - # Bubble up any errors we may encounter - return result.get() - - mp = None - - try: - mp = dest_bucket.initiate_multipart_upload(dst_key, **kwargs) - cur_pos = 0 - - # Store the results from the apply_async in a list so we can check for failures - results = [] - - # Calculate the number of chunks the file will be - num_parts = (key_size + part_size - 1) // part_size - - for i in range(num_parts): - # Issue an S3 copy request, one part at a time, from one S3 object to another - part_start = cur_pos - cur_pos += part_size - part_end = min(cur_pos - 1, key_size - 1) - part_num = i + 1 - results.append(pool.apply_async(mp.copy_part_from_key, args=(src_bucket, src_key, part_num, part_start, part_end))) - logger.info('Requesting copy of %s/%s to %s/%s', part_num, num_parts, dst_bucket, dst_key) - - logger.info('Waiting for multipart copy of %s/%s to finish', dst_bucket, dst_key) - - # This will raise any exceptions in any of the copy threads - for result in results: - result.get() - - # finish the copy, making the file available in S3 - mp.complete_upload() - return mp.key_name - - except: - logger.info('Error during multipart s3 copy for %s/%s to %s/%s...', src_bucket, src_key, dst_bucket, dst_key) - # cancel the copy so we don't get charged for storage consumed by copied parts - if mp: - mp.cancel_upload() - raise - - def move(self, source_path, destination_path, **kwargs): - """ - Rename/move an object from one S3 location to another. - - :param kwargs: Keyword arguments are passed to the boto function `copy_key` - """ - self.copy(source_path, destination_path, **kwargs) - self.remove(source_path) - - def listdir(self, path, start_time=None, end_time=None, return_key=False): - """ - Get an iterable with S3 folder contents. - Iterable contains paths relative to queried path. 
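# --- Illustrative sketch, not part of this patch: the inclusive byte ranges
# --- handed to copy_part_from_key in __copy_multipart above, shown standalone.
# --- Each part ends at the next offset minus one, capped at key_size - 1; the
# --- 12 MB key size and 5 MB part size are made-up numbers.
key_size = 12 * 1024 * 1024
part_size = 5 * 1024 * 1024

num_parts = (key_size + part_size - 1) // part_size          # -> 3
cur_pos = 0
for part_num in range(1, num_parts + 1):
    part_start = cur_pos
    cur_pos += part_size
    part_end = min(cur_pos - 1, key_size - 1)
    print('part %d: bytes %d-%d' % (part_num, part_start, part_end))
# part 1: bytes 0-5242879
# part 2: bytes 5242880-10485759
# part 3: bytes 10485760-12582911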
- - :param start_time: Optional argument to list files with modified dates after start_time - :param end_time: Optional argument to list files with modified dates before end_time - :param return_key: Optional argument, when set to True will return a boto.s3.key.Key (instead of the filename) - """ - (bucket, key) = self._path_to_bucket_and_key(path) - - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - - key_path = self._add_path_delimiter(key) - key_path_len = len(key_path) - for item in s3_bucket.list(prefix=key_path): - last_modified_date = time.strptime(item.last_modified, "%Y-%m-%dT%H:%M:%S.%fZ") - if ( - (not start_time and not end_time) or # neither are defined, list all - (start_time and not end_time and start_time < last_modified_date) or # start defined, after start - (not start_time and end_time and last_modified_date < end_time) or # end defined, prior to end - (start_time and end_time and start_time < last_modified_date < end_time) # both defined, between - ): - if return_key: - yield item - else: - yield self._add_path_delimiter(path) + item.key[key_path_len:] - - def list(self, path, start_time=None, end_time=None, return_key=False): # backwards compat - key_path_len = len(self._add_path_delimiter(path)) - for item in self.listdir(path, start_time=start_time, end_time=end_time, return_key=return_key): - if return_key: - yield item - else: - yield item[key_path_len:] - - def isdir(self, path): - """ - Is the parameter S3 path a directory? - """ - (bucket, key) = self._path_to_bucket_and_key(path) - - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - - # root is a directory - if self._is_root(key): - return True - - for suffix in (S3_DIRECTORY_MARKER_SUFFIX_0, - S3_DIRECTORY_MARKER_SUFFIX_1): - s3_dir_with_suffix_key = s3_bucket.get_key(key + suffix) - if s3_dir_with_suffix_key: - return True - - # files with this prefix - key_path = self._add_path_delimiter(key) - s3_bucket_list_result = list(itertools.islice(s3_bucket.list(prefix=key_path), 1)) - if s3_bucket_list_result: - return True - - return False - - is_dir = isdir # compatibility with old version. - - def mkdir(self, path, parents=True, raise_if_exists=False): - if raise_if_exists and self.isdir(path): - raise FileAlreadyExists() - - _, key = self._path_to_bucket_and_key(path) - if self._is_root(key): - return # isdir raises if the bucket doesn't exist; nothing to do here. - - key = self._add_path_delimiter(key) - - if not parents and not self.isdir(os.path.dirname(key)): - raise MissingParentDirectory() - - return self.put_string("", self._add_path_delimiter(path)) - - def _get_s3_config(self, key=None): - try: - config = dict(configuration.get_config().items('s3')) - except NoSectionError: - return {} - # So what ports etc can be read without us having to specify all dtypes - for k, v in six.iteritems(config): - try: - config[k] = int(v) - except ValueError: - pass - if key: - return config.get(key) - return config - - def _path_to_bucket_and_key(self, path): - (scheme, netloc, path, query, fragment) = urlsplit(path) - path_without_initial_slash = path[1:] - return netloc, path_without_initial_slash - - def _is_root(self, key): - return (len(key) == 0) or (key == '/') - - def _add_path_delimiter(self, key): - return key if key[-1:] == '/' or key == '' else key + '/' - - -class AtomicS3File(AtomicLocalFile): - """ - An S3 file that writes to a temp file and puts to S3 on close. 
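# --- Illustrative sketch, not part of this patch: how _path_to_bucket_and_key
# --- above splits an s3:// URL. urlsplit puts the bucket in netloc and the key
# --- in path, minus the leading slash. The example path is made up.
from six.moves.urllib.parse import urlsplit

scheme, netloc, path, query, fragment = urlsplit('s3://my-bucket/some/files/my_file')
bucket, key = netloc, path[1:]
print(bucket, key)   # my-bucket some/files/my_file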
- - :param kwargs: Keyword arguments are passed to the boto function `initiate_multipart_upload` - """ - - def __init__(self, path, s3_client, **kwargs): - self.s3_client = s3_client - super(AtomicS3File, self).__init__(path) - self.s3_options = kwargs - - def move_to_final_destination(self): - self.s3_client.put_multipart(self.tmp_path, self.path, **self.s3_options) - - -class ReadableS3File(object): - def __init__(self, s3_key): - self.s3_key = s3_key - self.buffer = [] - self.closed = False - self.finished = False - - def read(self, size=0): - f = self.s3_key.read(size=size) - - # boto will loop on the key forever and it's not what is expected by - # the python io interface - # boto/boto#2805 - if f == b'': - self.finished = True - if self.finished: - return b'' - - return f - - def close(self): - self.s3_key.close() - self.closed = True - - def __del__(self): - self.close() - - def __exit__(self, exc_type, exc, traceback): - self.close() - - def __enter__(self): - return self - - def _add_to_buffer(self, line): - self.buffer.append(line) - - def _flush_buffer(self): - output = b''.join(self.buffer) - self.buffer = [] - return output - - def readable(self): - return True - - def writable(self): - return False - - def seekable(self): - return False - - def __iter__(self): - key_iter = self.s3_key.__iter__() - - has_next = True - while has_next: - try: - # grab the next chunk - chunk = next(key_iter) - - # split on newlines, preserving the newline - for line in chunk.splitlines(True): - - if not line.endswith(os.linesep): - # no newline, so store in buffer - self._add_to_buffer(line) - else: - # newline found, send it out - if self.buffer: - self._add_to_buffer(line) - yield self._flush_buffer() - else: - yield line - except StopIteration: - # send out anything we have left in the buffer - output = self._flush_buffer() - if output: - yield output - has_next = False - self.close() - - -class S3Target(FileSystemTarget): - """ - Target S3 file object - - :param kwargs: Keyword arguments are passed to the boto function `initiate_multipart_upload` - """ - - fs = None - - def __init__(self, path, format=None, client=None, **kwargs): - super(S3Target, self).__init__(path) - if format is None: - format = get_default_format() - - self.path = path - self.format = format - self.fs = client or S3Client() - self.s3_options = kwargs - - def open(self, mode='r'): - if mode not in ('r', 'w'): - raise ValueError("Unsupported open mode '%s'" % mode) - - if mode == 'r': - s3_key = self.fs.get_key(self.path) - if not s3_key: - raise FileNotFoundException("Could not find file at %s" % self.path) - - fileobj = ReadableS3File(s3_key) - return self.format.pipe_reader(fileobj) - else: - return self.format.pipe_writer(AtomicS3File(self.path, self.fs, **self.s3_options)) - - -class S3FlagTarget(S3Target): - """ - Defines a target directory with a flag-file (defaults to `_SUCCESS`) used - to signify job success. - - This checks for two things: - - * the path exists (just like the S3Target) - * the _SUCCESS file exists within the directory. - - Because Hadoop outputs into a directory and not a single file, - the path is assumed to be a directory. - - This is meant to be a handy alternative to AtomicS3File. - - The AtomicFile approach can be burdensome for S3 since there are no directories, per se. - - If we have 1,000,000 output files, then we have to rename 1,000,000 objects. - """ - - fs = None - - def __init__(self, path, format=None, client=None, flag='_SUCCESS'): - """ - Initializes a S3FlagTarget. 
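# --- Illustrative sketch, not part of this patch: reading and writing through
# --- S3Target as defined above. The path is a placeholder, the default (text)
# --- format is assumed, and writes are uploaded on close via AtomicS3File.
from luigi.contrib.s3 import S3Target

target = S3Target('s3://my-bucket/out/data.txt')

with target.open('w') as out:            # buffered locally, put to S3 on close
    out.write('line 1\nline 2\n')

with target.open('r') as inp:            # streamed back via ReadableS3File
    for line in inp:
        print(line.rstrip('\n'))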
- - :param path: the directory where the files are stored. - :type path: str - :param client: - :type client: - :param flag: - :type flag: str - """ - if format is None: - format = get_default_format() - - if path[-1] != "/": - raise ValueError("S3FlagTarget requires the path to be to a " - "directory. It must end with a slash ( / ).") - super(S3FlagTarget, self).__init__(path, format, client) - self.flag = flag - - def exists(self): - hadoopSemaphore = self.path + self.flag - return self.fs.exists(hadoopSemaphore) - - -class S3EmrTarget(S3FlagTarget): - """ - Deprecated. Use :py:class:`S3FlagTarget` - """ - - def __init__(self, *args, **kwargs): - warnings.warn("S3EmrTarget is deprecated. Please use S3FlagTarget") - super(S3EmrTarget, self).__init__(*args, **kwargs) - - -class S3PathTask(ExternalTask): - """ - A external task that to require existence of a path in S3. - """ - path = Parameter() - - def output(self): - return S3Target(self.path) - - -class S3EmrTask(ExternalTask): - """ - An external task that requires the existence of EMR output in S3. - """ - path = Parameter() - - def output(self): - return S3EmrTarget(self.path) - - -class S3FlagTask(ExternalTask): - """ - An external task that requires the existence of EMR output in S3. - """ - path = Parameter() - flag = Parameter(default=None) - - def output(self): - return S3FlagTarget(self.path, flag=self.flag) +from luigi.contrib.s3 import * # NOQA +warnings.warn("luigi.s3 module has been moved to luigi.contrib.s3", + DeprecationWarning) diff --git a/test/contrib/hadoop_test.py b/test/contrib/hadoop_test.py index fed99e211a..ad69be3a6e 100644 --- a/test/contrib/hadoop_test.py +++ b/test/contrib/hadoop_test.py @@ -24,7 +24,7 @@ import luigi.format import luigi.contrib.hadoop import luigi.contrib.hdfs -import luigi.mrrunner +import luigi.contrib.mrrunner import luigi.notifications import minicluster import mock diff --git a/test/postgres_test.py b/test/contrib/postgres_test.py similarity index 100% rename from test/postgres_test.py rename to test/contrib/postgres_test.py diff --git a/test/postgres_with_server_test.py b/test/contrib/postgres_with_server_test.py similarity index 100% rename from test/postgres_with_server_test.py rename to test/contrib/postgres_with_server_test.py diff --git a/test/s3_test.py b/test/contrib/s3_test.py similarity index 100% rename from test/s3_test.py rename to test/contrib/s3_test.py diff --git a/test/contrib/streaming_test.py b/test/contrib/streaming_test.py index 4a8d2a0311..c1b4ac7569 100644 --- a/test/contrib/streaming_test.py +++ b/test/contrib/streaming_test.py @@ -3,7 +3,8 @@ import unittest -from luigi import mrrunner, Parameter +from luigi import Parameter +from luigi.contrib import mrrunner from luigi.contrib.hadoop import HadoopJobRunner, JobTask from luigi.contrib.hdfs import HdfsTarget
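# --- Illustrative sketch, not part of this patch: the flag-target pattern the
# --- classes above implement -- a directory of output files plus a _SUCCESS
# --- marker checked by S3FlagTarget. The task name, bucket and prefixes are
# --- placeholders. Note the shim that now replaces the old module body:
# --- `import luigi.s3` keeps working for now but emits a DeprecationWarning
# --- and forwards everything to luigi.contrib.s3.
import luigi
from luigi.contrib.s3 import S3FlagTarget, S3PathTask


class AggregateLogs(luigi.Task):
    date = luigi.DateParameter()

    def requires(self):
        return S3PathTask('s3://my-bucket/logs/%s/' % self.date)

    def output(self):
        # S3FlagTarget requires a trailing slash; exists() checks <path>_SUCCESS
        return S3FlagTarget('s3://my-bucket/aggregates/%s/' % self.date)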