Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: remove extra block assumptions in mbedtls integration #1323

Merged
merged 1 commit into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions packager/media/base/aes_cryptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ bool AesCryptor::Crypt(const std::vector<uint8_t>& text,
// Save text size to make it work for in-place conversion, since the
// next statement will update the text size.
const size_t text_size = text.size();
// mbedtls requires an extra block's worth of output buffer available.
crypt_text->resize(text_size + NumPaddingBytes(text_size) + AES_BLOCK_SIZE);
crypt_text->resize(text_size + NumPaddingBytes(text_size));
size_t crypt_text_size = crypt_text->size();
if (!Crypt(text.data(), text_size, crypt_text->data(), &crypt_text_size)) {
return false;
Expand All @@ -58,8 +57,7 @@ bool AesCryptor::Crypt(const std::string& text, std::string* crypt_text) {
// Save text size to make it work for in-place conversion, since the
// next statement will update the text size.
const size_t text_size = text.size();
// mbedtls requires an extra block's worth of output buffer available.
crypt_text->resize(text_size + NumPaddingBytes(text_size) + AES_BLOCK_SIZE);
crypt_text->resize(text_size + NumPaddingBytes(text_size));
size_t crypt_text_size = crypt_text->size();
if (!Crypt(reinterpret_cast<const uint8_t*>(text.data()), text_size,
reinterpret_cast<uint8_t*>(&(*crypt_text)[0]), &crypt_text_size))
Expand Down
114 changes: 50 additions & 64 deletions packager/media/base/aes_decryptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ bool AesCbcDecryptor::InitializeWithIv(const std::vector<uint8_t>& key,
}

size_t AesCbcDecryptor::RequiredOutputSize(size_t plaintext_size) {
// mbedtls requires a buffer large enough for one extra block.
return plaintext_size + AES_BLOCK_SIZE;
return plaintext_size;
}

bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
Expand All @@ -60,14 +59,12 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
// Plaintext size is the same as ciphertext size except for pkcs5 padding.
// Will update later if using pkcs5 padding. For pkcs5 padding, we still
// need at least |ciphertext_size| bytes for intermediate operation.
// mbedtls requires a buffer large enough for one extra block.
const size_t required_plaintext_size = ciphertext_size + AES_BLOCK_SIZE;
if (*plaintext_size < required_plaintext_size) {
LOG(ERROR) << "Expecting output size of at least "
<< required_plaintext_size << " bytes.";
if (*plaintext_size < ciphertext_size) {
LOG(ERROR) << "Expecting output size of at least " << ciphertext_size
<< " bytes.";
return false;
}
*plaintext_size = required_plaintext_size - AES_BLOCK_SIZE;
*plaintext_size = ciphertext_size;

// If the ciphertext size is 0, this can be a no-op decrypt, so long as the
// padding mode isn't PKCS5.
Expand All @@ -83,15 +80,9 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,

const size_t residual_block_size = ciphertext_size % AES_BLOCK_SIZE;
const size_t cbc_size = ciphertext_size - residual_block_size;

// Copy the residual block early, since mbedtls may overwrite one extra block
// of the output, and input and output may be the same buffer.
std::vector<uint8_t> residual_block(ciphertext + cbc_size,
ciphertext + ciphertext_size);
DCHECK_EQ(residual_block.size(), residual_block_size);

if (residual_block_size == 0) {
CbcDecryptBlocks(ciphertext, ciphertext_size, plaintext);
CbcDecryptBlocks(ciphertext, ciphertext_size, plaintext,
internal_iv_.data());
if (padding_scheme_ != kPkcs5Padding)
return true;

Expand All @@ -105,10 +96,11 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
*plaintext_size -= num_padding_bytes;
return true;
} else if (padding_scheme_ == kNoPadding) {
CbcDecryptBlocks(ciphertext, cbc_size, plaintext);

if (cbc_size > 0) {
CbcDecryptBlocks(ciphertext, cbc_size, plaintext, internal_iv_.data());
}
// The residual block is not encrypted.
memcpy(plaintext + cbc_size, residual_block.data(), residual_block_size);
memcpy(plaintext + cbc_size, ciphertext + cbc_size, residual_block_size);
return true;
} else if (padding_scheme_ != kCtsPadding) {
LOG(ERROR) << "Expecting cipher text size to be multiple of "
Expand All @@ -123,49 +115,44 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
return true;
}

// Copy the next-to-last block early, since mbedtls may overwrite one extra
// block of the output, and input and output may be the same buffer.
// NOTE: Before this point, there may not be such a block. Here, we know
// this is safe.
std::vector<uint8_t> next_to_last_block(
ciphertext + cbc_size - AES_BLOCK_SIZE, ciphertext + cbc_size);

// AES-CBC decrypt everything up to the next-to-last full block.
if (cbc_size > AES_BLOCK_SIZE) {
CbcDecryptBlocks(ciphertext, cbc_size - AES_BLOCK_SIZE, plaintext);
CbcDecryptBlocks(ciphertext, cbc_size - AES_BLOCK_SIZE, plaintext,
internal_iv_.data());
}

const uint8_t* next_to_last_ciphertext_block =
ciphertext + ciphertext_size - residual_block_size - AES_BLOCK_SIZE;
uint8_t* next_to_last_plaintext_block =
plaintext + ciphertext_size - residual_block_size - AES_BLOCK_SIZE;

// Determine what the last IV should be so that we can "skip ahead" in the
// CBC decryption.
std::vector<uint8_t> last_iv(
ciphertext + ciphertext_size - residual_block_size,
ciphertext + ciphertext_size);
last_iv.resize(AES_BLOCK_SIZE, 0);

// Decrypt the next-to-last block using the IV determined above. This decrypts
// the residual block bits.
CbcDecryptBlocks(next_to_last_ciphertext_block, AES_BLOCK_SIZE,
next_to_last_plaintext_block, last_iv.data());

// Swap back the residual block bits and the next-to-last block.
if (plaintext == ciphertext) {
std::swap_ranges(next_to_last_plaintext_block,
next_to_last_plaintext_block + residual_block_size,
next_to_last_plaintext_block + AES_BLOCK_SIZE);
} else {
memcpy(next_to_last_plaintext_block + AES_BLOCK_SIZE,
next_to_last_plaintext_block, residual_block_size);
memcpy(next_to_last_plaintext_block,
next_to_last_ciphertext_block + AES_BLOCK_SIZE, residual_block_size);
}

uint8_t* next_to_last_plaintext_block = plaintext + cbc_size - AES_BLOCK_SIZE;

// The next-to-last block should be decrypted first in ECB mode, which is
// effectively what you get with an IV of all zeroes.
std::vector<uint8_t> backup_iv(internal_iv_);
internal_iv_.assign(AES_BLOCK_SIZE, 0);
// mbedtls requires a buffer large enough for one extra block.
std::vector<uint8_t> stolen_bits(AES_BLOCK_SIZE * 2);
CbcDecryptBlocks(next_to_last_block.data(), AES_BLOCK_SIZE,
stolen_bits.data());

// Reconstruct the final two blocks of ciphertext.
std::vector<uint8_t> reconstructed_blocks(AES_BLOCK_SIZE * 2);
memcpy(reconstructed_blocks.data(), residual_block.data(),
residual_block_size);
memcpy(reconstructed_blocks.data() + residual_block_size,
stolen_bits.data() + residual_block_size,
AES_BLOCK_SIZE - residual_block_size);
memcpy(reconstructed_blocks.data() + AES_BLOCK_SIZE,
next_to_last_block.data(), AES_BLOCK_SIZE);

// Decrypt the last two blocks.
internal_iv_ = backup_iv;
// mbedtls requires a buffer large enough for one extra block.
std::vector<uint8_t> final_output_blocks(AES_BLOCK_SIZE * 3);
CbcDecryptBlocks(reconstructed_blocks.data(), AES_BLOCK_SIZE * 2,
final_output_blocks.data());

// Copy the final output.
memcpy(next_to_last_plaintext_block, final_output_blocks.data(),
AES_BLOCK_SIZE + residual_block_size);
// Decrypt the next-to-last full block.
CbcDecryptBlocks(next_to_last_plaintext_block, AES_BLOCK_SIZE,
next_to_last_plaintext_block, internal_iv_.data());
return true;
}

Expand All @@ -176,7 +163,8 @@ void AesCbcDecryptor::SetIvInternal() {

void AesCbcDecryptor::CbcDecryptBlocks(const uint8_t* ciphertext,
size_t ciphertext_size,
uint8_t* plaintext) {
uint8_t* plaintext,
uint8_t* iv) {
CHECK_EQ(ciphertext_size % AES_BLOCK_SIZE, 0u);
CHECK_GT(ciphertext_size, 0u);

Expand All @@ -186,14 +174,12 @@ void AesCbcDecryptor::CbcDecryptBlocks(const uint8_t* ciphertext,
std::vector<uint8_t> next_iv(last_block, last_block + AES_BLOCK_SIZE);

size_t output_size = 0;
CHECK_EQ(mbedtls_cipher_crypt(&cipher_ctx_, internal_iv_.data(),
AES_BLOCK_SIZE, ciphertext, ciphertext_size,
plaintext, &output_size),
CHECK_EQ(mbedtls_cipher_crypt(&cipher_ctx_, iv, AES_BLOCK_SIZE, ciphertext,
ciphertext_size, plaintext, &output_size),
0);
DCHECK_EQ(output_size % AES_BLOCK_SIZE, 0u);

// Update the internal IV.
internal_iv_ = next_iv;
memcpy(iv, next_iv.data(), next_iv.size());
}

} // namespace media
Expand Down
3 changes: 2 additions & 1 deletion packager/media/base/aes_decryptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ class AesCbcDecryptor : public AesCryptor {
void SetIvInternal() override;
void CbcDecryptBlocks(const uint8_t* plaintext,
size_t plaintext_size,
uint8_t* ciphertext);
uint8_t* ciphertext,
uint8_t* iv);

const CbcPaddingScheme padding_scheme_;
// 16-byte internal iv for crypto operations.
Expand Down
46 changes: 18 additions & 28 deletions packager/media/base/aes_encryptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ namespace media {
AesCtrEncryptor::AesCtrEncryptor()
: AesCryptor(kDontUseConstantIv),
block_offset_(0),
// mbedtls requires an extra output block.
encrypted_counter_(AES_BLOCK_SIZE * 2, 0) {}
encrypted_counter_(AES_BLOCK_SIZE, 0) {}

AesCtrEncryptor::~AesCtrEncryptor() {}

Expand Down Expand Up @@ -129,8 +128,7 @@ bool AesCbcEncryptor::InitializeWithIv(const std::vector<uint8_t>& key,
}

size_t AesCbcEncryptor::RequiredOutputSize(size_t plaintext_size) {
// mbedtls requires a buffer large enough for one extra block.
return plaintext_size + NumPaddingBytes(plaintext_size) + AES_BLOCK_SIZE;
return plaintext_size + NumPaddingBytes(plaintext_size);
}

bool AesCbcEncryptor::CryptInternal(const uint8_t* plaintext,
Expand All @@ -146,19 +144,12 @@ bool AesCbcEncryptor::CryptInternal(const uint8_t* plaintext,
<< required_ciphertext_size << " bytes.";
return false;
}
*ciphertext_size = required_ciphertext_size - AES_BLOCK_SIZE;
*ciphertext_size = required_ciphertext_size;

// Encrypt everything but the residual block using CBC.
const size_t cbc_size = plaintext_size - residual_block_size;

// Copy the residual block early, since mbedtls may overwrite one extra block
// of the output, and input and output may be the same buffer.
std::vector<uint8_t> residual_block(plaintext + cbc_size,
plaintext + plaintext_size);
DCHECK_EQ(residual_block.size(), residual_block_size);

if (cbc_size != 0) {
CbcEncryptBlocks(plaintext, cbc_size, ciphertext);
CbcEncryptBlocks(plaintext, cbc_size, ciphertext, internal_iv_.data());
} else if (padding_scheme_ == kCtsPadding) {
// Don't have a full block, leave unencrypted.
memcpy(ciphertext, plaintext, plaintext_size);
Expand All @@ -175,27 +166,26 @@ bool AesCbcEncryptor::CryptInternal(const uint8_t* plaintext,
return true;
}

std::vector<uint8_t> residual_block(plaintext + cbc_size,
plaintext + plaintext_size);
DCHECK_EQ(residual_block.size(), residual_block_size);
uint8_t* residual_ciphertext_block = ciphertext + cbc_size;

if (padding_scheme_ == kPkcs5Padding) {
DCHECK_EQ(num_padding_bytes, AES_BLOCK_SIZE - residual_block_size);

// Pad residue block with PKCS5 padding.
residual_block.resize(AES_BLOCK_SIZE, static_cast<char>(num_padding_bytes));

CbcEncryptBlocks(residual_block.data(), AES_BLOCK_SIZE,
residual_ciphertext_block);
residual_ciphertext_block, internal_iv_.data());
} else {
DCHECK_EQ(num_padding_bytes, 0u);
DCHECK_EQ(padding_scheme_, kCtsPadding);

// Zero-pad the residual block and encrypt using CBC.
residual_block.resize(AES_BLOCK_SIZE, 0);
// mbedtls requires an extra block in the output buffer, and it cannot be
// the same as the input buffer.
std::vector<uint8_t> encrypted_residual_block(AES_BLOCK_SIZE * 2);

CbcEncryptBlocks(residual_block.data(), AES_BLOCK_SIZE,
encrypted_residual_block.data());
residual_block.data(), internal_iv_.data());

// Replace the last full block with the zero-padded, encrypted residual
// block, and replace the residual block with the equivalent portion of the
Expand All @@ -206,8 +196,8 @@ bool AesCbcEncryptor::CryptInternal(const uint8_t* plaintext,
// https://en.wikipedia.org/wiki/Ciphertext_stealing#CS2
memcpy(residual_ciphertext_block,
residual_ciphertext_block - AES_BLOCK_SIZE, residual_block_size);
memcpy(residual_ciphertext_block - AES_BLOCK_SIZE,
encrypted_residual_block.data(), AES_BLOCK_SIZE);
memcpy(residual_ciphertext_block - AES_BLOCK_SIZE, residual_block.data(),
AES_BLOCK_SIZE);
}
return true;
}
Expand All @@ -225,20 +215,20 @@ size_t AesCbcEncryptor::NumPaddingBytes(size_t size) const {

void AesCbcEncryptor::CbcEncryptBlocks(const uint8_t* plaintext,
size_t plaintext_size,
uint8_t* ciphertext) {
uint8_t* ciphertext,
uint8_t* iv) {
CHECK_EQ(plaintext_size % AES_BLOCK_SIZE, 0u);

size_t output_size = 0;
CHECK_EQ(
mbedtls_cipher_crypt(&cipher_ctx_, internal_iv_.data(), AES_BLOCK_SIZE,
plaintext, plaintext_size, ciphertext, &output_size),
0);
CHECK_EQ(mbedtls_cipher_crypt(&cipher_ctx_, iv, AES_BLOCK_SIZE, plaintext,
plaintext_size, ciphertext, &output_size),
0);

CHECK_EQ(output_size % AES_BLOCK_SIZE, 0u);
CHECK_GT(output_size, 0u);

uint8_t* last_block = ciphertext + output_size - AES_BLOCK_SIZE;
internal_iv_.assign(last_block, last_block + AES_BLOCK_SIZE);
memcpy(iv, last_block, AES_BLOCK_SIZE);
}

} // namespace media
Expand Down
3 changes: 2 additions & 1 deletion packager/media/base/aes_encryptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ class AesCbcEncryptor : public AesCryptor {

void CbcEncryptBlocks(const uint8_t* plaintext,
size_t plaintext_size,
uint8_t* ciphertext);
uint8_t* ciphertext,
uint8_t* iv);

const CbcPaddingScheme padding_scheme_;
// 16-byte internal iv for crypto operations.
Expand Down
5 changes: 1 addition & 4 deletions packager/media/base/playready_pssh_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ void AesEcbEncrypt(const std::vector<uint8_t>& key,
const std::vector<uint8_t>& plaintext,
std::vector<uint8_t>* ciphertext) {
CHECK_EQ(plaintext.size() % AES_BLOCK_SIZE, 0u);
// mbedtls requires an extra block worth of output buffer.
ciphertext->resize(plaintext.size() + AES_BLOCK_SIZE);
ciphertext->resize(plaintext.size());

mbedtls_cipher_context_t ctx;
mbedtls_cipher_init(&ctx);
Expand All @@ -98,8 +97,6 @@ void AesEcbEncrypt(const std::vector<uint8_t>& key,
plaintext.data(), plaintext.size(),
ciphertext->data(), &output_size),
0);
// Truncate the output to the correct size.
ciphertext->resize(plaintext.size());

mbedtls_cipher_free(&ctx);
}
Expand Down
Loading