Skip to content

Commit

Permalink
Fix incorrect downloaded file size when size of the file is not a mul…
Browse files Browse the repository at this point in the history
…tiple of the block size (#22036)

* convert apply method of SourceContentValidation to private

* fixed buffer length while batch download

* add another test case and fixed typos

* Fixing buffer size

* PR comment

---------

Co-authored-by: Sourav Gupta <souravgupta@microsoft.com>
  • Loading branch information
tanyasethi-msft and souravgupta-msft authored Nov 22, 2023
1 parent 6a155b3 commit c933ba3
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
1 change: 1 addition & 0 deletions sdk/storage/azblob/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* Fixed SharedKeyMissingError when using client.BlobClient().GetSASURL() method
* Fixed an issue that would cause metadata keys with empty values to be omitted when enumerating blobs.
* Fixed an issue where passing empty map to set blob tags API was causing panic. Fixes [#21869](https://github.com/Azure/azure-sdk-for-go/issues/21869).
* Fixed an issue where downloaded file has incorrect size when not a multiple of block size. Fixes [#21995](https://github.com/Azure/azure-sdk-for-go/issues/21995).

### Other Changes

Expand Down
20 changes: 11 additions & 9 deletions sdk/storage/azblob/blob/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ func (b *Client) downloadFile(ctx context.Context, writer io.Writer, o downloadO

buffers := shared.NewMMBPool(int(o.Concurrency), o.BlockSize)
defer buffers.Free()
aquireBuffer := func() ([]byte, error) {
acquireBuffer := func() ([]byte, error) {
select {
case b := <-buffers.Acquire():
// got a buffer
Expand All @@ -489,21 +489,23 @@ func (b *Client) downloadFile(ctx context.Context, writer io.Writer, o downloadO
/*
* We have created as many channels as the number of chunks we have.
* Each downloaded block will be sent to the channel matching its
* sequece number, i.e. 0th block is sent to 0th channel, 1st block
* sequence number, i.e. 0th block is sent to 0th channel, 1st block
* to 1st channel and likewise. The blocks are then read and written
* to the file serially by below goroutine. Do note that the blocks
* blocks are still downloaded parallelly from n/w, only serailized
* are still downloaded parallelly from n/w, only serialized
* and written to file here.
*/
writerError := make(chan error)
writeSize := int64(0)
go func(ch chan error) {
for _, block := range blocks {
select {
case <-ctx.Done():
return
case block := <-block:
_, err := writer.Write(block)
buffers.Release(block)
n, err := writer.Write(block)
writeSize += int64(n)
buffers.Release(block[:cap(block)])
if err != nil {
ch <- err
return
Expand All @@ -521,7 +523,7 @@ func (b *Client) downloadFile(ctx context.Context, writer io.Writer, o downloadO
NumChunks: numChunks,
Concurrency: o.Concurrency,
Operation: func(ctx context.Context, chunkStart int64, count int64) error {
buff, err := aquireBuffer()
buff, err := acquireBuffer()
if err != nil {
return err
}
Expand All @@ -538,8 +540,8 @@ func (b *Client) downloadFile(ctx context.Context, writer io.Writer, o downloadO
return err
}

blockIndex := (chunkStart / o.BlockSize)
blocks[blockIndex] <- buff
blockIndex := chunkStart / o.BlockSize
blocks[blockIndex] <- buff[:count]
return nil
},
})
Expand All @@ -551,7 +553,7 @@ func (b *Client) downloadFile(ctx context.Context, writer io.Writer, o downloadO
if err = <-writerError; err != nil {
return 0, err
}
return count, nil
return writeSize, nil
}

// DownloadStream reads a range of bytes from a blob. The response also includes the blob's properties and metadata.
Expand Down
10 changes: 9 additions & 1 deletion sdk/storage/azblob/blob/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ func (s *BlobUnrecordedTestsSuite) TestUploadDownloadBlockBlob() {
_, err = srcBlob.Upload(context.Background(), body, nil)
_require.NoError(err)

// downlod to a temp file and verify contents
// download to a temp file and verify contents
tmp, err := os.CreateTemp("", "")
_require.NoError(err)
defer tmp.Close()
Expand All @@ -299,6 +299,10 @@ func (s *BlobUnrecordedTestsSuite) TestUploadDownloadBlockBlob() {
_require.NoError(err)
_require.Equal(int64(contentSize), n)

stat, err := tmp.Stat()
_require.NoError(err)
_require.Equal(int64(contentSize), stat.Size())

// Compute md5 of file, and verify it against stored value.
_, _ = tmp.Seek(0, io.SeekStart)
buff := make([]byte, contentSize)
Expand Down Expand Up @@ -327,6 +331,10 @@ func (s *BlobUnrecordedTestsSuite) TestUploadDownloadBlockBlob() {

// 199 MB file, more blocks than threads
testUploadDownload(199 * MiB)

testUploadDownload(7 * MiB)

testUploadDownload(8241066)
}

func (s *BlobRecordedTestsSuite) TestBlobStartCopyDestEmpty() {
Expand Down

0 comments on commit c933ba3

Please sign in to comment.