diff --git a/fizz/crypto/aead/AEGISCipher.cpp b/fizz/crypto/aead/AEGISCipher.cpp index 71e8a245ebd..424ab60cd8a 100644 --- a/fizz/crypto/aead/AEGISCipher.cpp +++ b/fizz/crypto/aead/AEGISCipher.cpp @@ -186,81 +186,91 @@ folly::Optional> AEGISCipher::doDecrypt( std::unique_ptr&& ciphertext, const folly::IOBuf* associatedData, folly::ByteRange iv, - folly::ByteRange key, folly::MutableByteRange tagOut, bool inPlace) const { - const uint8_t* adData = nullptr; - size_t adLen = 0; - std::unique_ptr ad; - if (associatedData) { - if (associatedData->isChained()) { - ad = associatedData->cloneCoalesced(); - adData = ad->data(); - adLen = ad->length(); - } else { - adData = associatedData->data(); - adLen = associatedData->length(); - } - } + struct AeadImpl { + const AEGISCipher& self; - auto inputLength = ciphertext->computeChainDataLength(); - // the stateInit function will skip adding aad to the state when adData is - // null and adLen is 0 - impl_->stateInit(adData, adLen, iv.data(), key.data(), inputLength); - - folly::IOBuf* input; - std::unique_ptr output; - if (!inPlace) { - output = folly::IOBuf::create(inputLength); - output->append(inputLength); - input = ciphertext.get(); - } else { - output = std::move(ciphertext); - input = output.get(); - } + explicit AeadImpl(const AEGISCipher& s) : self(s) {} + void init( + folly::ByteRange iv, + const folly::IOBuf* associatedData, + size_t ciphertextLength) { + const uint8_t* adData = nullptr; + size_t adLen = 0; + std::unique_ptr ad; + if (associatedData) { + if (associatedData->isChained()) { + ad = associatedData->cloneCoalesced(); + adData = ad->data(); + adLen = ad->length(); + } else { + adData = associatedData->data(); + adLen = associatedData->length(); + } + } - struct Impl { - const AEGISCipher& self; - const unsigned char* expectedTag{nullptr}; - - explicit Impl(const AEGISCipher& s) : self(s) {} - bool decryptUpdate( - uint8_t* plain, - const uint8_t* cipher, - size_t len, - int* outLen) { - size_t tempOutLen; - auto ret = self.impl_->decryptUpdate(plain, &tempOutLen, cipher, len); - *outLen = static_cast(tempOutLen); - return ret == 0; - } - bool setExpectedTag(int /*tagSize*/, const unsigned char* tag) { - this->expectedTag = tag; - return true; + // the stateInit function will skip adding aad to the state when adData is + // null and adLen is 0 + // @lint-ignore CLANGTIDY facebook-hte-NullableDereference + self.impl_->stateInit( + adData, + adLen, + iv.data(), + self.trafficKeyKey_.data(), + ciphertextLength); } - bool decryptFinal(unsigned char* outm, int* outLen) { - size_t tempOutLen; - auto ret = - self.impl_->decryptFinal(outm, &tempOutLen, expectedTag, kTagLength); - *outLen = static_cast(tempOutLen); - return ret == 0; + + bool decryptAndFinal( + folly::IOBuf& ciphertext, + folly::IOBuf& plaintext, + folly::MutableByteRange tagOut) { + struct EVPDecImpl { + const AEGISCipher& self; + const unsigned char* expectedTag{nullptr}; + + explicit EVPDecImpl(const AEGISCipher& s) : self(s) {} + bool decryptUpdate( + uint8_t* plain, + const uint8_t* cipher, + size_t len, + int* outLen) { + size_t tempOutLen; + auto ret = self.impl_->decryptUpdate(plain, &tempOutLen, cipher, len); + *outLen = static_cast(tempOutLen); + return ret == 0; + } + bool setExpectedTag(int /*tagSize*/, const unsigned char* tag) { + this->expectedTag = tag; + return true; + } + bool decryptFinal(unsigned char* outm, int* outLen) { + size_t tempOutLen; + auto ret = self.impl_->decryptFinal( + outm, &tempOutLen, expectedTag, kTagLength); + *outLen = static_cast(tempOutLen); + return ret == 0; + } + }; + + if (self.mms_ == AEGISCipher::kAEGIS128LMMS) { + return decFuncBlocks( + EVPDecImpl(self), ciphertext, plaintext, tagOut); + } else if (self.mms_ == AEGISCipher::kAEGIS256MMS) { + return decFuncBlocks( + EVPDecImpl(self), ciphertext, plaintext, tagOut); + } else { + throw std::runtime_error("Unsupported AEGIS state size"); + } } }; - - bool decrypted; - if (mms_ == AEGISCipher::kAEGIS128LMMS) { - decrypted = decFuncBlocks( - Impl(*this), *input, *output, tagOut); - } else if (mms_ == AEGISCipher::kAEGIS256MMS) { - decrypted = decFuncBlocks( - Impl(*this), *input, *output, tagOut); - } else { - throw std::runtime_error("Unsupported AEGIS state size"); - } - if (!decrypted) { - return folly::none; - } - return output; + return decryptHelper( + AeadImpl{*this}, + std::move(ciphertext), + associatedData, + iv, + tagOut, + inPlace); } AEGISCipher::AEGISCipher( @@ -414,12 +424,7 @@ folly::Optional> AEGISCipher::tryDecrypt( trimBytes(*ciphertext, tagOut); } return doDecrypt( - std::move(ciphertext), - associatedData, - nonce, - trafficKeyKey_, - tagOut, - inPlace); + std::move(ciphertext), associatedData, nonce, tagOut, inPlace); } size_t AEGISCipher::getCipherOverhead() const { diff --git a/fizz/crypto/aead/AEGISCipher.h b/fizz/crypto/aead/AEGISCipher.h index 15e00c62b2c..01a746b0627 100644 --- a/fizz/crypto/aead/AEGISCipher.h +++ b/fizz/crypto/aead/AEGISCipher.h @@ -60,7 +60,6 @@ class AEGISCipher : public Aead { std::unique_ptr&& ciphertext, const folly::IOBuf* associatedData, folly::ByteRange iv, - folly::ByteRange key, folly::MutableByteRange tagOut, bool inPlace) const;