Merge pull request #903 from kyleknap/streams
Added the ability to stream data using ``cp``.
kyleknap committed Sep 29, 2014
2 parents ab363c3 + 4716948 commit 23a1aac
Showing 20 changed files with 977 additions and 140 deletions.
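
The commit message above describes streaming through ``cp``. A brief usage sketch follows (illustrative only: ``mybucket`` and ``stream.txt`` are placeholder names, and the single ``-`` argument standing in for standard input or standard output is assumed to be the syntax this pull request introduces):

    # upload: read the object body from standard input
    cat local-file.txt | aws s3 cp - s3://mybucket/stream.txt

    # download: write the object body to standard output
    aws s3 cp s3://mybucket/stream.txt - > local-copy.txt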
4 changes: 4 additions & 0 deletions CHANGELOG.rst
@@ -14,6 +14,10 @@ Next Release (TBD)
(`issue 919 <https://github.com/aws/aws-cli/pull/919>`__)
* feature:``aws s3``: Add ``--only-show-errors`` option that displays
errors and warnings but suppresses all other output.
* feature:``aws s3 cp``: Added ability to upload local
file streams from standard input to s3 and download s3
objects as local file streams to standard output.
(`issue 903 <https://github.com/aws/aws-cli/pull/903>`__)


1.4.4
24 changes: 14 additions & 10 deletions awscli/customizations/s3/executor.py
@@ -15,8 +15,8 @@
import sys
import threading

from awscli.customizations.s3.utils import uni_print, \
IORequest, IOCloseRequest, StablePriorityQueue
from awscli.customizations.s3.utils import uni_print, bytes_print, \
IORequest, IOCloseRequest, StablePriorityQueue
from awscli.customizations.s3.tasks import OrderableTask


@@ -154,15 +154,19 @@ def run(self):
self._cleanup()
return
elif isinstance(task, IORequest):
filename, offset, data = task
fileobj = self.fd_descriptor_cache.get(filename)
if fileobj is None:
fileobj = open(filename, 'rb+')
self.fd_descriptor_cache[filename] = fileobj
fileobj.seek(offset)
filename, offset, data, is_stream = task
if is_stream:
fileobj = sys.stdout
bytes_print(data)
else:
fileobj = self.fd_descriptor_cache.get(filename)
if fileobj is None:
fileobj = open(filename, 'rb+')
self.fd_descriptor_cache[filename] = fileobj
fileobj.seek(offset)
fileobj.write(data)
LOGGER.debug("Writing data to: %s, offset: %s",
filename, offset)
fileobj.write(data)
fileobj.flush()
elif isinstance(task, IOCloseRequest):
LOGGER.debug("IOCloseRequest received for %s, closing file.",
@@ -239,7 +243,7 @@ def __init__(self, result_queue, quiet, only_show_errors):
self._lock = threading.Lock()
self._needs_newline = False

self._total_parts = 0
self._total_parts = '...'
self._total_files = '...'

# This is a public attribute that clients can inspect to determine
24 changes: 6 additions & 18 deletions awscli/customizations/s3/filegenerator.py
@@ -20,7 +20,8 @@
from dateutil.tz import tzlocal

from awscli.customizations.s3.utils import find_bucket_key, get_file_stat
from awscli.customizations.s3.utils import BucketLister, create_warning
from awscli.customizations.s3.utils import BucketLister, create_warning, \
find_dest_path_comp_key
from awscli.errorhandler import ClientError


@@ -131,26 +132,13 @@ def call(self, files):
``dir_op`` and ``use_src_name`` flags affect which files are used and
ensure the proper destination paths and compare keys are formed.
"""
src = files['src']
dest = files['dest']
src_type = src['type']
dest_type = dest['type']
function_table = {'s3': self.list_objects, 'local': self.list_files}
sep_table = {'s3': '/', 'local': os.sep}
source = src['path']
source = files['src']['path']
src_type = files['src']['type']
dest_type = files['dest']['type']
file_list = function_table[src_type](source, files['dir_op'])
for src_path, size, last_update in file_list:
if files['dir_op']:
rel_path = src_path[len(src['path']):]
else:
rel_path = src_path.split(sep_table[src_type])[-1]
compare_key = rel_path.replace(sep_table[src_type], '/')
if files['use_src_name']:
dest_path = dest['path']
dest_path += rel_path.replace(sep_table[src_type],
sep_table[dest_type])
else:
dest_path = dest['path']
dest_path, compare_key = find_dest_path_comp_key(files, src_path)
yield FileStat(src=src_path, dest=dest_path,
compare_key=compare_key, size=size,
last_update=last_update, src_type=src_type,
121 changes: 83 additions & 38 deletions awscli/customizations/s3/fileinfo.py
@@ -11,7 +11,7 @@
from botocore.compat import quote
from awscli.customizations.s3.utils import find_bucket_key, \
check_etag, check_error, operate, uni_print, \
guess_content_type, MD5Error
guess_content_type, MD5Error, bytes_print


class CreateDirectoryError(Exception):
@@ -26,7 +26,7 @@ def read_file(filename):
return in_file.read()


def save_file(filename, response_data, last_update):
def save_file(filename, response_data, last_update, is_stream=False):
"""
This writes to the file upon downloading. It reads the data in the
response. Makes a new directory if needed and then writes the
@@ -35,31 +35,57 @@ def save_file(filename, response_data, last_update):
"""
body = response_data['Body']
etag = response_data['ETag'][1:-1]
d = os.path.dirname(filename)
try:
if not os.path.exists(d):
os.makedirs(d)
except OSError as e:
if not e.errno == errno.EEXIST:
raise CreateDirectoryError(
"Could not create directory %s: %s" % (d, e))
if not is_stream:
d = os.path.dirname(filename)
try:
if not os.path.exists(d):
os.makedirs(d)
except OSError as e:
if not e.errno == errno.EEXIST:
raise CreateDirectoryError(
"Could not create directory %s: %s" % (d, e))
md5 = hashlib.md5()
file_chunks = iter(partial(body.read, 1024 * 1024), b'')
with open(filename, 'wb') as out_file:
if not _is_multipart_etag(etag):
for chunk in file_chunks:
md5.update(chunk)
out_file.write(chunk)
else:
for chunk in file_chunks:
out_file.write(chunk)
if is_stream:
# Need to save the data to be able to check the etag for a stream
# because once the data is written to the stream there is no
# undoing it.
payload = write_to_file(None, etag, md5, file_chunks, True)
else:
with open(filename, 'wb') as out_file:
write_to_file(out_file, etag, md5, file_chunks)

if not _is_multipart_etag(etag):
if etag != md5.hexdigest():
os.remove(filename)
if not is_stream:
os.remove(filename)
raise MD5Error(filename)
last_update_tuple = last_update.timetuple()
mod_timestamp = time.mktime(last_update_tuple)
os.utime(filename, (int(mod_timestamp), int(mod_timestamp)))

if not is_stream:
last_update_tuple = last_update.timetuple()
mod_timestamp = time.mktime(last_update_tuple)
os.utime(filename, (int(mod_timestamp), int(mod_timestamp)))
else:
# Now write the output to stdout since the md5 is correct.
bytes_print(payload)
sys.stdout.flush()


def write_to_file(out_file, etag, md5, file_chunks, is_stream=False):
"""
Updates the etag for each file chunk. It will write to the file if it is a
file but if it is a stream it will return a byte string to be later
written to a stream.
"""
body = b''
for chunk in file_chunks:
if not _is_multipart_etag(etag):
md5.update(chunk)
if is_stream:
body += chunk
else:
out_file.write(chunk)
return body


def _is_multipart_etag(etag):
@@ -140,7 +166,7 @@ class FileInfo(TaskInfo):
def __init__(self, src, dest=None, compare_key=None, size=None,
last_update=None, src_type=None, dest_type=None,
operation_name=None, service=None, endpoint=None,
parameters=None, source_endpoint=None):
parameters=None, source_endpoint=None, is_stream=False):
super(FileInfo, self).__init__(src, src_type=src_type,
operation_name=operation_name,
service=service,
@@ -157,6 +183,18 @@ def __init__(self, src, dest=None, compare_key=None, size=None,
self.parameters = {'acl': None,
'sse': None}
self.source_endpoint = source_endpoint
self.is_stream = is_stream

def set_size_from_s3(self):
"""
This runs a ``HeadObject`` on the s3 object and sets the size.
"""
bucket, key = find_bucket_key(self.src)
params = {'endpoint': self.endpoint,
'bucket': bucket,
'key': key}
response_data, http = operate(self.service, 'HeadObject', params)
self.size = int(response_data['ContentLength'])

def _permission_to_param(self, permission):
if permission == 'read':
@@ -204,24 +242,30 @@ def _handle_object_params(self, params):
if self.parameters['expires']:
params['expires'] = self.parameters['expires'][0]

def upload(self):
def upload(self, payload=None):
"""
Redirects the file to the multipart upload function if the file is
large. If it is small enough, it puts the file as an object in s3.
"""
with open(self.src, 'rb') as body:
bucket, key = find_bucket_key(self.dest)
params = {
'endpoint': self.endpoint,
'bucket': bucket,
'key': key,
'body': body,
}
self._handle_object_params(params)
response_data, http = operate(self.service, 'PutObject', params)
etag = response_data['ETag'][1:-1]
body.seek(0)
check_etag(etag, body)
if payload:
self._handle_upload(payload)
else:
with open(self.src, 'rb') as body:
self._handle_upload(body)

def _handle_upload(self, body):
bucket, key = find_bucket_key(self.dest)
params = {
'endpoint': self.endpoint,
'bucket': bucket,
'key': key,
'body': body,
}
self._handle_object_params(params)
response_data, http = operate(self.service, 'PutObject', params)
etag = response_data['ETag'][1:-1]
body.seek(0)
check_etag(etag, body)

def _inject_content_type(self, params, filename):
# Add a content type param if we can guess the type.
@@ -237,7 +281,8 @@ def download(self):
bucket, key = find_bucket_key(self.src)
params = {'endpoint': self.endpoint, 'bucket': bucket, 'key': key}
response_data, http = operate(self.service, 'GetObject', params)
save_file(self.dest, response_data, self.last_update)
save_file(self.dest, response_data, self.last_update,
self.is_stream)

def copy(self):
"""
6 changes: 4 additions & 2 deletions awscli/customizations/s3/fileinfobuilder.py
@@ -19,13 +19,14 @@ class FileInfoBuilder(object):
a ``FileInfo`` object so that the operation can be performed.
"""
def __init__(self, service, endpoint, source_endpoint=None,
parameters = None):
parameters = None, is_stream=False):
self._service = service
self._endpoint = endpoint
self._source_endpoint = endpoint
if source_endpoint:
self._source_endpoint = source_endpoint
self._parameters = parameters
self._parameters = parameters
self._is_stream = is_stream

def call(self, files):
for file_base in files:
Expand All @@ -46,4 +47,5 @@ def _inject_info(self, file_base):
file_info_attr['endpoint'] = self._endpoint
file_info_attr['source_endpoint'] = self._source_endpoint
file_info_attr['parameters'] = self._parameters
file_info_attr['is_stream'] = self._is_stream
return FileInfo(**file_info_attr)