Skip to content

Commit

Permalink
Implement decryption for cloud blobstore download blob for cases where the blob was encrypted by cloud blobstore before uploading
Browse files Browse the repository at this point in the history
  • Loading branch information
ankagrawal committed Sep 16, 2019
1 parent a690862 commit d2545ca
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public class CloudBlobMetadata {
private String vcrKmsContext;
private String cryptoAgentFactory;
private String cloudBlobName;
private long encryptedSize;

/**
* Possible values of encryption origin for cloud stored blobs.
Expand All @@ -67,6 +68,19 @@ public enum EncryptionOrigin {
/**
 * No-arg constructor — presumably required by the JSON/document serialization framework
 * used to persist metadata records (TODO confirm; no caller is visible in this diff).
 */
public CloudBlobMetadata() {
}

/**
 * Constructor from {@link BlobId}, for blobs that were not encrypted by the VCR before upload.
 * Delegates to the full constructor with {@code null} vcrKmsContext and cryptoAgentFactory
 * and an encryptedSize of -1 (the sentinel meaning "no pre-upload encryption applied").
 * @param blobId The BlobId for metadata record.
 * @param creationTime The blob creation time.
 * @param expirationTime The blob expiration time.
 * @param size The blob size.
 * @param encryptionOrigin The blob's encryption origin.
 */
public CloudBlobMetadata(BlobId blobId, long creationTime, long expirationTime, long size,
EncryptionOrigin encryptionOrigin) {
this(blobId, creationTime, expirationTime, size, encryptionOrigin, null, null, -1);
}

/**
* Constructor from {@link BlobId}.
* @param blobId The BlobId for metadata record.
Expand All @@ -77,9 +91,11 @@ public CloudBlobMetadata() {
* @param vcrKmsContext The KMS context used to encrypt the blob. Only used when encryptionOrigin = VCR.
* @param cryptoAgentFactory The class name of the {@link CloudBlobCryptoAgentFactory} used to encrypt the blob.
* Only used when encryptionOrigin = VCR.
* @param encryptedSize The size of the uploaded blob if it was encrypted and then uploaded.
* Only used when encryptionOrigin = VCR.
*/
public CloudBlobMetadata(BlobId blobId, long creationTime, long expirationTime, long size,
EncryptionOrigin encryptionOrigin, String vcrKmsContext, String cryptoAgentFactory) {
EncryptionOrigin encryptionOrigin, String vcrKmsContext, String cryptoAgentFactory, long encryptedSize) {
this.id = blobId.getID();
this.partitionId = blobId.getPartition().toPathString();
this.accountId = blobId.getAccountId();
Expand All @@ -93,6 +109,7 @@ public CloudBlobMetadata(BlobId blobId, long creationTime, long expirationTime,
this.vcrKmsContext = vcrKmsContext;
this.cryptoAgentFactory = cryptoAgentFactory;
this.cloudBlobName = blobId.getID();
this.encryptedSize = encryptedSize;
}

/**
Expand Down Expand Up @@ -312,6 +329,22 @@ public CloudBlobMetadata setCryptoAgentFactory(String cryptoAgentFactory) {
return this;
}

/**
 * @return the encrypted size of the blob if the blob was encrypted and uploaded to cloud,
 *         -1 otherwise (i.e. the blob was uploaded without VCR encryption).
 */
public long getEncryptedSize() {
return encryptedSize;
}

/**
 * Sets the encrypted size of the blob, i.e. the size of the payload actually uploaded
 * after VCR encryption. Pass -1 when the blob was not encrypted before upload.
 * @param encryptedSize the encrypted size in bytes, or -1 if not applicable.
 * @return this instance, to allow call chaining with the other fluent setters.
 */
public CloudBlobMetadata setEncryptedSize(long encryptedSize) {
this.encryptedSize = encryptedSize;
return this;
}

@Override
public boolean equals(Object o) {
if (!(o instanceof CloudBlobMetadata)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import com.github.ambry.store.StoreStats;
import com.github.ambry.store.Write;
import com.github.ambry.utils.ByteBufferInputStream;
import com.github.ambry.utils.ByteBufferOutputStream;
import com.github.ambry.utils.Utils;
import java.io.IOException;
import java.io.OutputStream;
Expand Down Expand Up @@ -146,11 +147,27 @@ public StoreInfo get(List<? extends StoreKey> ids, EnumSet<StoreGetOptions> stor
return new StoreInfo(messageReadSet, messageInfos);
}

public void downloadBlob(BlobId blobId, OutputStream outputStream) throws StoreException {
/**
* Download the blob corresponding to the {@code blobId} from the {@code CloudDestination} to the given {@code outputStream}
* If the blob was encrypted by vcr during upload, then this method also decrypts it.
* @param cloudBlobMetadata blob metadata to determine if the blob was encrypted by vcr during upload.
* @param blobId Id of the blob to be downloaded.
* @param outputStream {@code OutputStream} of the downloaded blob.
* @throws StoreException if there is an error in downloading the blob.
*/
void downloadBlob(CloudBlobMetadata cloudBlobMetadata, BlobId blobId, OutputStream outputStream)
throws StoreException {
try {
cloudDestination.downloadBlob(blobId, outputStream);
} catch (CloudStorageException e) {
throw new StoreException("Error occurred in downloading blob for blobid :" + blobId, StoreErrorCodes.IOError);
if (cloudBlobMetadata.getEncryptionOrigin().equals(EncryptionOrigin.VCR)) {
ByteBuffer encryptedBlob = ByteBuffer.allocate((int) cloudBlobMetadata.getEncryptedSize());
cloudDestination.downloadBlob(blobId, new ByteBufferOutputStream(encryptedBlob));
ByteBuffer decryptedBlob = cryptoAgent.decrypt(encryptedBlob);
outputStream.write(decryptedBlob.array());
} else {
cloudDestination.downloadBlob(blobId, outputStream);
}
} catch (CloudStorageException | GeneralSecurityException | IOException e) {
throw new StoreException("Error occured in downloading blob for blobid :" + blobId, StoreErrorCodes.IOError);
}
}

Expand Down Expand Up @@ -233,7 +250,7 @@ private void putBlob(MessageInfo messageInfo, ByteBuffer messageBuf, long size)
String kmsContext = null;
String cryptoAgentFactoryClass = null;
EncryptionOrigin encryptionOrigin = isRouterEncrypted ? EncryptionOrigin.ROUTER : EncryptionOrigin.NONE;
boolean bufferChanged = false;
long encryptedSize = -1;
if (requireEncryption) {
if (isRouterEncrypted) {
// Nothing further needed
Expand All @@ -242,7 +259,7 @@ private void putBlob(MessageInfo messageInfo, ByteBuffer messageBuf, long size)
Timer.Context encryptionTimer = vcrMetrics.blobEncryptionTime.time();
try {
messageBuf = cryptoAgent.encrypt(messageBuf);
bufferChanged = true;
encryptedSize = messageBuf.remaining();
} catch (GeneralSecurityException ex) {
vcrMetrics.blobEncryptionErrorCount.inc();
} finally {
Expand All @@ -258,9 +275,9 @@ private void putBlob(MessageInfo messageInfo, ByteBuffer messageBuf, long size)
}
CloudBlobMetadata blobMetadata =
new CloudBlobMetadata(blobId, messageInfo.getOperationTimeMs(), messageInfo.getExpirationTimeInMs(),
messageInfo.getSize(), encryptionOrigin, kmsContext, cryptoAgentFactoryClass);
messageInfo.getSize(), encryptionOrigin, kmsContext, cryptoAgentFactoryClass, encryptedSize);
// If buffer was encrypted, we no longer know its size
long bufferlen = bufferChanged ? -1 : size;
long bufferlen = (encryptedSize == -1) ? size : encryptedSize;
cloudDestination.uploadBlob(blobId, bufferlen, blobMetadata, new ByteBufferInputStream(messageBuf));
addToCache(blobId.getID(), BlobState.CREATED);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public void prefetchBlob(CloudBlobStore blobStore) throws StoreException {
// However, if in future, if very large size of blobs are allowed, then prefetching logic should be changed.
prefetchedBuffer = ByteBuffer.allocate((int) blobMetadata.getSize());
ByteBufferOutputStream outputStream = new ByteBufferOutputStream(prefetchedBuffer);
blobStore.downloadBlob(blobId, outputStream);
blobStore.downloadBlob(blobMetadata, blobId, outputStream);
isPrefetched = true;
}

Expand All @@ -138,7 +138,7 @@ public void prefetchBlob(CloudBlobStore blobStore) throws StoreException {
* @throws StoreException if blob download fails.
*/
public void downloadBlob(CloudBlobStore blobStore, OutputStream outputStream) throws StoreException {
blobStore.downloadBlob(blobId, outputStream);
blobStore.downloadBlob(blobMetadata, blobId, outputStream);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ public void testFindMissingKeys() throws Exception {
keys.add(existentBlobId);
metadataMap.put(existentBlobId.getID(),
new CloudBlobMetadata(existentBlobId, operationTime, Utils.Infinite_Time, 1024,
CloudBlobMetadata.EncryptionOrigin.ROUTER, null, null));
CloudBlobMetadata.EncryptionOrigin.ROUTER));
// Blob without metadata
BlobId nonexistentBlobId = getUniqueId();
keys.add(nonexistentBlobId);
Expand Down Expand Up @@ -560,7 +560,7 @@ private List<CloudBlobMetadata> generateMetadataList(long startTime, long blobSi
for (int j = 0; j < count; j++) {
BlobId blobId = getUniqueId();
CloudBlobMetadata metadata = new CloudBlobMetadata(blobId, startTime, Utils.Infinite_Time, blobSize,
CloudBlobMetadata.EncryptionOrigin.NONE, null, null);
CloudBlobMetadata.EncryptionOrigin.NONE);
metadata.setUploadTime(startTime + j);
metadataList.add(metadata);
}
Expand Down Expand Up @@ -629,7 +629,16 @@ private BlobId getUniqueId(short accountId, short containerId, boolean encrypted
*/
@Test
public void testStoreGets() throws Exception {
setupCloudStore(true, false, defaultCacheLimit, true);
testStoreGets(false);
testStoreGets(true);
}

/**
* Test cloud store get method with the given encryption requirement.
* @throws Exception
*/
private void testStoreGets(boolean requireEncryption) throws Exception {
setupCloudStore(true, requireEncryption, defaultCacheLimit, true);
// Put blobs with and without expiration and encryption
MockMessageWriteSet messageWriteSet = new MockMessageWriteSet();
int count = 5;
Expand Down Expand Up @@ -803,7 +812,7 @@ private BlobId forceUploadExpiredBlob() throws CloudStorageException {
long size = 1024;
long currentTime = System.currentTimeMillis();
CloudBlobMetadata expiredBlobMetadata =
new CloudBlobMetadata(expiredBlobId, currentTime, currentTime - 1, size, null, null, null);
new CloudBlobMetadata(expiredBlobId, currentTime, currentTime - 1, size, null);
ByteBuffer buffer = ByteBuffer.wrap(TestUtils.getRandomBytes((int) size));
InputStream inputStream = new ByteBufferInputStream(buffer);
dest.uploadBlob(expiredBlobId, size, expiredBlobMetadata, inputStream);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ public void testPurge() throws Exception {
when(mockBlob.deleteIfExists(any(), any(), any(), any())).thenReturn(true);
CloudBlobMetadata cloudBlobMetadata =
new CloudBlobMetadata(blobId, System.currentTimeMillis(), Utils.Infinite_Time, blobSize,
CloudBlobMetadata.EncryptionOrigin.NONE, null, null);
CloudBlobMetadata.EncryptionOrigin.NONE);
assertTrue("Expected success", azureDest.purgeBlob(cloudBlobMetadata));
assertEquals(1, azureMetrics.blobDeletedCount.getCount());
assertEquals(0, azureMetrics.blobDeleteErrorCount.getCount());
Expand All @@ -204,7 +204,7 @@ public void testPurgeNotFound() throws Exception {
new DocumentClientException(404));
CloudBlobMetadata cloudBlobMetadata =
new CloudBlobMetadata(blobId, System.currentTimeMillis(), Utils.Infinite_Time, blobSize,
CloudBlobMetadata.EncryptionOrigin.NONE, null, null);
CloudBlobMetadata.EncryptionOrigin.NONE);
assertFalse("Expected false", azureDest.purgeBlob(cloudBlobMetadata));
assertEquals(0, azureMetrics.blobDeletedCount.getCount());
assertEquals(0, azureMetrics.blobDeleteErrorCount.getCount());
Expand Down Expand Up @@ -285,7 +285,7 @@ private void testQueryMetadata(int numBlobs, int expectedQueries) throws Excepti
BlobId blobId = generateBlobId();
blobIdList.add(blobId);
CloudBlobMetadata inputMetadata = new CloudBlobMetadata(blobId, creationTime, Utils.Infinite_Time, blobSize,
CloudBlobMetadata.EncryptionOrigin.NONE, null, null);
CloudBlobMetadata.EncryptionOrigin.NONE);
docList.add(new Document(objectMapper.writeValueAsString(inputMetadata)));
}
when(mockIterable.iterator()).thenReturn(docList.iterator());
Expand Down Expand Up @@ -332,7 +332,7 @@ public void testFindEntriesSince() throws Exception {
BlobId blobId = generateBlobId();
blobIdList.add(blobId.getID());
CloudBlobMetadata inputMetadata = new CloudBlobMetadata(blobId, creationTime, Utils.Infinite_Time, chunkSize,
CloudBlobMetadata.EncryptionOrigin.NONE, null, null);
CloudBlobMetadata.EncryptionOrigin.NONE);
inputMetadata.setUploadTime(startTime + j);
docList.add(new Document(objectMapper.writeValueAsString(inputMetadata)));
}
Expand Down Expand Up @@ -557,7 +557,7 @@ private void verifyUpdateErrorMetrics(int numUpdates, boolean isDocument) {
private boolean uploadDefaultBlob() throws CloudStorageException {
InputStream inputStream = getBlobInputStream(blobSize);
CloudBlobMetadata metadata = new CloudBlobMetadata(blobId, creationTime, Utils.Infinite_Time, blobSize,
CloudBlobMetadata.EncryptionOrigin.NONE, null, null);
CloudBlobMetadata.EncryptionOrigin.NONE);
return azureDest.uploadBlob(blobId, blobSize, metadata, inputStream);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public void testNormalFlow() throws Exception {
InputStream inputStream = new ByteArrayInputStream(uploadData);
CloudBlobMetadata cloudBlobMetadata =
new CloudBlobMetadata(blobId, System.currentTimeMillis(), Utils.Infinite_Time, blobSize,
CloudBlobMetadata.EncryptionOrigin.VCR, vcrKmsContext, cryptoAgentFactory);
CloudBlobMetadata.EncryptionOrigin.VCR, vcrKmsContext, cryptoAgentFactory, blobSize);
assertTrue("Expected upload to return true",
azureDest.uploadBlob(blobId, blobSize, cloudBlobMetadata, inputStream));

Expand Down Expand Up @@ -160,7 +160,7 @@ public void testBatchQuery() throws Exception {
byte[] randomBytes = TestUtils.getRandomBytes(blobSize);
blobIdtoDataMap.put(blobId, randomBytes);
CloudBlobMetadata cloudBlobMetadata = new CloudBlobMetadata(blobId, creationTime, Utils.Infinite_Time, blobSize,
CloudBlobMetadata.EncryptionOrigin.VCR, vcrKmsContext, cryptoAgentFactory);
CloudBlobMetadata.EncryptionOrigin.VCR, vcrKmsContext, cryptoAgentFactory, blobSize);
assertTrue("Expected upload to return true",
azureDest.uploadBlob(blobId, blobSize, cloudBlobMetadata, new ByteArrayInputStream(randomBytes)));
}
Expand Down Expand Up @@ -210,7 +210,7 @@ public void testPurgeDeadBlobs() throws Exception {
BlobDataType.DATACHUNK);
InputStream inputStream = getBlobInputStream(blobSize);
CloudBlobMetadata cloudBlobMetadata = new CloudBlobMetadata(blobId, creationTime, Utils.Infinite_Time, blobSize,
CloudBlobMetadata.EncryptionOrigin.VCR, vcrKmsContext, cryptoAgentFactory);
CloudBlobMetadata.EncryptionOrigin.VCR, vcrKmsContext, cryptoAgentFactory, blobSize);
assertTrue("Expected upload to return true",
azureDest.uploadBlob(blobId, blobSize, cloudBlobMetadata, inputStream));

Expand All @@ -220,7 +220,7 @@ public void testPurgeDeadBlobs() throws Exception {
BlobDataType.DATACHUNK);
inputStream = getBlobInputStream(blobSize);
cloudBlobMetadata = new CloudBlobMetadata(blobId, creationTime, Utils.Infinite_Time, blobSize,
CloudBlobMetadata.EncryptionOrigin.VCR, vcrKmsContext, cryptoAgentFactory);
CloudBlobMetadata.EncryptionOrigin.VCR, vcrKmsContext, cryptoAgentFactory, blobSize);
cloudBlobMetadata.setDeletionTime(timeOfDeath);
assertTrue("Expected upload to return true",
azureDest.uploadBlob(blobId, blobSize, cloudBlobMetadata, inputStream));
Expand All @@ -232,7 +232,7 @@ public void testPurgeDeadBlobs() throws Exception {
inputStream = getBlobInputStream(blobSize);
cloudBlobMetadata =
new CloudBlobMetadata(blobId, creationTime, timeOfDeath, blobSize, CloudBlobMetadata.EncryptionOrigin.VCR,
vcrKmsContext, cryptoAgentFactory);
vcrKmsContext, cryptoAgentFactory, blobSize);
assertTrue("Expected upload to return true",
azureDest.uploadBlob(blobId, blobSize, cloudBlobMetadata, inputStream));
expectedDeadBlobs++;
Expand All @@ -243,7 +243,7 @@ public void testPurgeDeadBlobs() throws Exception {
BlobDataType.DATACHUNK);
inputStream = getBlobInputStream(blobSize);
cloudBlobMetadata = new CloudBlobMetadata(blobId, creationTime, Utils.Infinite_Time, blobSize,
CloudBlobMetadata.EncryptionOrigin.VCR, vcrKmsContext, cryptoAgentFactory);
CloudBlobMetadata.EncryptionOrigin.VCR, vcrKmsContext, cryptoAgentFactory, blobSize);
cloudBlobMetadata.setDeletionTime(timeOfDeath);
assertTrue("Expected upload to return true",
azureDest.uploadBlob(blobId, blobSize, cloudBlobMetadata, inputStream));
Expand All @@ -254,7 +254,7 @@ public void testPurgeDeadBlobs() throws Exception {
inputStream = getBlobInputStream(blobSize);
cloudBlobMetadata =
new CloudBlobMetadata(blobId, creationTime, timeOfDeath, blobSize, CloudBlobMetadata.EncryptionOrigin.VCR,
vcrKmsContext, cryptoAgentFactory);
vcrKmsContext, cryptoAgentFactory, blobSize);
assertTrue("Expected upload to return true",
azureDest.uploadBlob(blobId, blobSize, cloudBlobMetadata, inputStream));
}
Expand Down Expand Up @@ -294,7 +294,7 @@ public void testFindEntriesSince() throws Exception {
BlobDataType.DATACHUNK);
InputStream inputStream = getBlobInputStream(chunkSize);
CloudBlobMetadata cloudBlobMetadata = new CloudBlobMetadata(blobId, startTime, Utils.Infinite_Time, chunkSize,
CloudBlobMetadata.EncryptionOrigin.VCR, vcrKmsContext, cryptoAgentFactory);
CloudBlobMetadata.EncryptionOrigin.VCR, vcrKmsContext, cryptoAgentFactory, chunkSize);
cloudBlobMetadata.setUploadTime(startTime + j * 1000);
assertTrue("Expected upload to return true",
azureDest.uploadBlob(blobId, blobSize, cloudBlobMetadata, inputStream));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public void setup() throws Exception {
blobId = new BlobId(BLOB_ID_V6, BlobIdType.NATIVE, dataCenterId, accountId, containerId, partitionId, false,
BlobDataType.DATACHUNK);
blobMetadata = new CloudBlobMetadata(blobId, System.currentTimeMillis(), Utils.Infinite_Time, blobSize,
CloudBlobMetadata.EncryptionOrigin.NONE, null, null);
CloudBlobMetadata.EncryptionOrigin.NONE);
azureMetrics = new AzureMetrics(new MetricRegistry());
cosmosAccessor = new CosmosDataAccessor(mockumentClient, "ambry/metadata", maxRetries, azureMetrics);
}
Expand Down

0 comments on commit d2545ca

Please sign in to comment.