From 1a1e796f4fa6ce7385c73b031aef5d4b1877b170 Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Sat, 5 Oct 2024 09:24:10 +0530 Subject: [PATCH] Refactor --- velox/common/encode/Base64.cpp | 324 +++++++++--------- velox/common/encode/Base64.h | 166 ++++----- velox/common/encode/EncoderUtils.h | 89 ++++- velox/common/encode/tests/Base64Test.cpp | 177 +++------- .../common/encode/tests/EncoderUtilsTests.cpp | 10 +- 5 files changed, 377 insertions(+), 389 deletions(-) diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index ef6da4cb1b70..f5a0a8c54d37 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -18,7 +18,7 @@ #include #include #include -#include +#include namespace facebook::velox::encoding { @@ -85,6 +85,15 @@ constexpr const Base64::ReverseIndex kBase64UrlReverseIndexTable = { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}; +// Validate the character in charset with ReverseIndex table +constexpr bool checkForwardIndex( + uint8_t index, + const Base64::Charset& charset, + const Base64::ReverseIndex& reverseIndex) { + return (reverseIndex[static_cast(charset[index])] == index) && + (index > 0 ? checkForwardIndex(index - 1, charset, reverseIndex) : true); +} + // Verify that for every entry in kBase64Charset, the corresponding entry // in kBase64ReverseIndexTable is correct. static_assert( @@ -103,6 +112,28 @@ static_assert( kBase64UrlReverseIndexTable), "kBase64UrlCharset has incorrect entries"); +// Searches for a character within a charset up to a certain index. +constexpr bool findCharacterInCharset( + const Base64::Charset& charset, + uint8_t index, + const char targetChar) { + return index < charset.size() && + ((charset[index] == targetChar) || + findCharacterInCharset(charset, index + 1, targetChar)); +} + +// Checks the consistency of a reverse index mapping for a given character +// set. +constexpr bool checkReverseIndex( + uint8_t index, + const Base64::Charset& charset, + const Base64::ReverseIndex& reverseIndex) { + return (reverseIndex[index] == 255 + ? !findCharacterInCharset(charset, 0, static_cast(index)) + : (charset[reverseIndex[index]] == index)) && + (index > 0 ? checkReverseIndex(index - 1, charset, reverseIndex) : true); +} + // Verify that for every entry in kBase64ReverseIndexTable, the corresponding // entry in kBase64Charset is correct. static_assert( @@ -128,13 +159,11 @@ static_assert( template std::string Base64::encodeImpl( const T& input, - const Base64::Charset& charset, + const Charset& charset, bool includePadding) { - size_t outputSize = calculateEncodedSize(input.size(), includePadding); - std::string output; - output.resize(outputSize); - encodeImpl(input, charset, includePadding, output.data()); - return output; + std::string encodedResult; + (void)encodeImpl(input, charset, includePadding, encodedResult); + return encodedResult; } // static @@ -146,18 +175,19 @@ size_t Base64::calculateEncodedSize(size_t inputSize, bool includePadding) { // Calculate the output size assuming that we are including padding. size_t encodedSize = ((inputSize + 2) / 3) * 4; if (!includePadding) { + // If the padding was not requested, subtract the padding bytes. encodedSize -= (3 - (inputSize % 3)) % 3; } return encodedSize; } // static -Status Base64::encode(std::string_view input, char* output) { +Status Base64::encode(std::string_view input, std::string& output) { return encodeImpl(input, kBase64Charset, true, output); } // static -Status Base64::encodeUrl(std::string_view input, char* output) { +Status Base64::encodeUrl(std::string_view input, std::string& output) { return encodeImpl(input, kBase64UrlCharset, true, output); } @@ -165,81 +195,90 @@ Status Base64::encodeUrl(std::string_view input, char* output) { template Status Base64::encodeImpl( const T& input, - const Base64::Charset& charset, + const Charset& charset, bool includePadding, - char* output) { + std::string& output) { auto inputSize = input.size(); if (inputSize == 0) { + output.clear(); return Status::OK(); } - auto outputPtr = output; - auto dataIterator = input.begin(); + // Calculate the output size and resize the string beforehand + size_t outputSize = calculateEncodedSize(inputSize, includePadding); + output.resize(outputSize); // Resize the output string to the required size + + // Use a pointer to write into the pre-allocated buffer + auto outputPointer = output.data(); + auto inputIterator = input.begin(); + // Encode input in chunks of 3 bytes for (; inputSize > 2; inputSize -= 3) { - uint32_t currentBlock = uint8_t(*dataIterator++) << 16; - currentBlock |= uint8_t(*dataIterator++) << 8; - currentBlock |= uint8_t(*dataIterator++); - - *outputPtr++ = charset[(currentBlock >> 18) & 0x3f]; - *outputPtr++ = charset[(currentBlock >> 12) & 0x3f]; - *outputPtr++ = charset[(currentBlock >> 6) & 0x3f]; - *outputPtr++ = charset[currentBlock & 0x3f]; + uint32_t inputBlock = uint8_t(*inputIterator++) << 16; + inputBlock |= uint8_t(*inputIterator++) << 8; + inputBlock |= uint8_t(*inputIterator++); + + *outputPointer++ = charset[(inputBlock >> 18) & 0x3f]; + *outputPointer++ = charset[(inputBlock >> 12) & 0x3f]; + *outputPointer++ = charset[(inputBlock >> 6) & 0x3f]; + *outputPointer++ = charset[inputBlock & 0x3f]; } + // Handle remaining bytes (1 or 2 bytes) if (inputSize > 0) { - uint32_t currentBlock = uint8_t(*dataIterator++) << 16; - *outputPtr++ = charset[(currentBlock >> 18) & 0x3f]; + uint32_t inputBlock = uint8_t(*inputIterator++) << 16; + *outputPointer++ = charset[(inputBlock >> 18) & 0x3f]; if (inputSize > 1) { - currentBlock |= uint8_t(*dataIterator) << 8; - *outputPtr++ = charset[(currentBlock >> 12) & 0x3f]; - *outputPtr++ = charset[(currentBlock >> 6) & 0x3f]; + inputBlock |= uint8_t(*inputIterator) << 8; + *outputPointer++ = charset[(inputBlock >> 12) & 0x3f]; + *outputPointer++ = charset[(inputBlock >> 6) & 0x3f]; if (includePadding) { - *outputPtr = kPadding; + *outputPointer++ = kPadding; } } else { - *outputPtr++ = charset[(currentBlock >> 12) & 0x3f]; + *outputPointer++ = charset[(inputBlock >> 12) & 0x3f]; if (includePadding) { - *outputPtr++ = kPadding; - *outputPtr = kPadding; + *outputPointer++ = kPadding; + *outputPointer++ = kPadding; } } } + return Status::OK(); } // static -std::string Base64::encode(std::string_view text) { - return encodeImpl(text, kBase64Charset, true); +std::string Base64::encode(folly::StringPiece input) { + return encodeImpl(input, kBase64Charset, true); } // static -std::string Base64::encode(std::string_view input, size_t /*len*/) { - return encodeImpl(input, kBase64Charset, true); +std::string Base64::encode(const char* input, size_t inputSize) { + return encode(folly::StringPiece(input, inputSize)); } namespace { /** - * this is a quick and dirty iterator implementation for an IOBuf so that the - * template that uses iterators can work on IOBuf chains. It only implements - * postfix increment because that is all the algorithm needs, and it is a noop - * since the read<>() function already incremented the cursor. + * This is a quick and simple iterator implementation for an IOBuf so that the + * template that uses iterators can work on IOBuf chains. It only implements + * postfix increment because that is all the algorithm needs, and it is a no-op + * since the read<>() function already increments the cursor. */ class IOBufWrapper { private: class Iterator { public: - explicit Iterator(const folly::IOBuf* data) : cursor_(data) {} + explicit Iterator(const folly::IOBuf* inputBuffer) : cursor_(inputBuffer) {} Iterator& operator++(int32_t) { - // This is a noop since reading from the Cursor has already moved the - // position + // This is a no-op since reading from the Cursor has already moved the + // position. return *this; } uint8_t operator*() { - // This will read _and_ increment + // This will read _and_ increment the cursor. return cursor_.read(); } @@ -248,67 +287,67 @@ class IOBufWrapper { }; public: - explicit IOBufWrapper(const folly::IOBuf* data) : data_(data) {} - + explicit IOBufWrapper(const folly::IOBuf* inputBuffer) + : input_(inputBuffer) {} size_t size() const { - return data_->computeChainDataLength(); + return input_->computeChainDataLength(); } Iterator begin() const { - return Iterator(data_); + return Iterator(input_); } private: - const folly::IOBuf* data_; + const folly::IOBuf* input_; }; } // namespace // static -std::string Base64::encode(const folly::IOBuf* input) { - return encodeImpl(IOBufWrapper(input), kBase64Charset, true); +std::string Base64::encode(const folly::IOBuf* inputBuffer) { + return encodeImpl(IOBufWrapper(inputBuffer), kBase64Charset, true); } // static -std::string Base64::decode(std::string_view encoded) { - std::string output; - Base64::decode(encoded, output); - return output; +std::string Base64::decode(folly::StringPiece encodedText) { + std::string decodedOutput; + std::string_view input(encodedText.data(), encodedText.size()); + (void)decodeImpl(input, decodedOutput, kBase64ReverseIndexTable); + return decodedOutput; } // static -void Base64::decode(std::string_view input, std::string& output) { - size_t inputSize{input.size()}; - size_t decodedSize; - - (void)calculateDecodedSize(input, inputSize, decodedSize); - output.resize(decodedSize); - (void)decode(input.data(), inputSize, output.data(), output.size()); +void Base64::decode( + const std::pair& payload, + std::string& decodedOutput) { + std::string_view input(payload.first, payload.second); + (void)decodeImpl(input, decodedOutput, kBase64ReverseIndexTable); } // static -void Base64::decode(std::string_view input, size_t inputSize, char* output) { - size_t outputSize; - (void)calculateDecodedSize(input, inputSize, outputSize); - (void)decode(input, inputSize, output, outputSize); +void Base64::decode(const char* input, size_t inputSize, char* outputBuffer) { + std::string_view inputView(input, inputSize); + std::string output; + (void)decodeImpl(inputView, output, kBase64ReverseIndexTable); + memcpy(outputBuffer, output.data(), output.size()); } // static uint8_t Base64::base64ReverseLookup( - char p, - const Base64::ReverseIndex& reverseIndex, + char encodedChar, + const ReverseIndex& reverseIndex, Status& status) { - return reverseLookup(p, reverseIndex, status, Base64::kCharsetSize); + auto reverseLookupValue = reverseIndex[static_cast(encodedChar)]; + if (reverseLookupValue >= 0x40) { + status = Status::UserError( + "decode() - invalid input string: invalid characters"); + } + return reverseLookupValue; } // static -Status Base64::decode( - std::string_view input, - size_t inputSize, - char* output, - size_t outputSize) { - return decodeImpl( - input, inputSize, output, outputSize, kBase64ReverseIndexTable); +Status Base64::decode(std::string_view input, std::string& output) { + return decodeImpl(input, output, kBase64ReverseIndexTable); } // static @@ -321,27 +360,27 @@ Status Base64::calculateDecodedSize( return Status::OK(); } - // Check if the input data is padded - if (isPadded(input, inputSize)) { + // Check if the input string is padded + if (isPadded(input)) { // If padded, ensure that the string length is a multiple of the encoded // block size if (inputSize % kEncodedBlockByteSize != 0) { return Status::UserError( - "Base64::decode() - invalid input string: string length is not a multiple of 4."); + "Base64::decode() - invalid input string: " + "string length is not a multiple of 4."); } decodedSize = (inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize; - auto padding = numPadding(input, inputSize); - inputSize -= padding; + auto paddingCount = numPadding(input); + inputSize -= paddingCount; // Adjust the needed size by deducting the bytes corresponding to the // padding from the calculated size. decodedSize -= - ((padding * kBinaryBlockByteSize) + (kEncodedBlockByteSize - 1)) / + ((paddingCount * kBinaryBlockByteSize) + (kEncodedBlockByteSize - 1)) / kEncodedBlockByteSize; return Status::OK(); } - // If not padded, calculate extra bytes, if any auto extraBytes = inputSize % kEncodedBlockByteSize; decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize; @@ -350,7 +389,8 @@ Status Base64::calculateDecodedSize( if (extraBytes) { if (extraBytes == 1) { return Status::UserError( - "Base64::decode() - invalid input string: string length cannot be 1 more than a multiple of 4."); + "Base64::decode() - invalid input string: " + "string length cannot be 1 more than a multiple of 4."); } decodedSize += (extraBytes * kBinaryBlockByteSize) / kEncodedBlockByteSize; } @@ -361,28 +401,27 @@ Status Base64::calculateDecodedSize( // static Status Base64::decodeImpl( std::string_view input, - size_t inputSize, - char* output, - size_t outputSize, - const Base64::ReverseIndex& reverseIndex) { + std::string& output, + const ReverseIndex& reverseIndex) { + size_t inputSize = input.size(); if (inputSize == 0) { + output.clear(); return Status::OK(); } + // Calculate the decoded size based on the input size size_t decodedSize; - // Calculate decoded size and check for status auto status = calculateDecodedSize(input, inputSize, decodedSize); if (!status.ok()) { return status; } - if (outputSize < decodedSize) { - return Status::UserError( - "Base64::decode() - invalid output string: output string is too small."); - } + // Resize the output string to fit the decoded data + output.resize(decodedSize); + // Set up input and output pointers const char* inputPtr = input.data(); - char* outputPtr = output; + char* outputPointer = output.data(); Status lookupStatus; // Process full blocks of 4 characters @@ -393,100 +432,77 @@ Status Base64::decodeImpl( uint8_t val2 = base64ReverseLookup(inputPtr[2], reverseIndex, lookupStatus); uint8_t val3 = base64ReverseLookup(inputPtr[3], reverseIndex, lookupStatus); - uint32_t currentBlock = (val0 << 18) | (val1 << 12) | (val2 << 6) | val3; - outputPtr[0] = static_cast((currentBlock >> 16) & 0xFF); - outputPtr[1] = static_cast((currentBlock >> 8) & 0xFF); - outputPtr[2] = static_cast(currentBlock & 0xFF); + if (!lookupStatus.ok()) { + return lookupStatus; + } + + uint32_t inputBlock = (val0 << 18) | (val1 << 12) | (val2 << 6) | val3; + outputPointer[0] = static_cast((inputBlock >> 16) & 0xFF); + outputPointer[1] = static_cast((inputBlock >> 8) & 0xFF); + outputPointer[2] = static_cast(inputBlock & 0xFF); inputPtr += 4; - outputPtr += 3; + outputPointer += 3; } - // Handle the last block (2-3 characters) + // Handle remaining characters (2 or 3 characters at the end) size_t remaining = inputSize % 4; if (remaining > 1) { uint8_t val0 = base64ReverseLookup(inputPtr[0], reverseIndex, lookupStatus); uint8_t val1 = base64ReverseLookup(inputPtr[1], reverseIndex, lookupStatus); - uint32_t currentBlock = (val0 << 18) | (val1 << 12); - outputPtr[0] = static_cast((currentBlock >> 16) & 0xFF); + uint32_t inputBlock = (val0 << 18) | (val1 << 12); + outputPointer[0] = static_cast((inputBlock >> 16) & 0xFF); if (remaining == 3) { uint8_t val2 = base64ReverseLookup(inputPtr[2], reverseIndex, lookupStatus); - currentBlock |= (val2 << 6); - outputPtr[1] = static_cast((currentBlock >> 8) & 0xFF); + inputBlock |= (val2 << 6); + outputPointer[1] = static_cast((inputBlock >> 8) & 0xFF); } } - if (!lookupStatus.ok()) + + // Check for any lookup errors + if (!lookupStatus.ok()) { return lookupStatus; + } + return Status::OK(); } // static -std::string Base64::encodeUrl(std::string_view input) { - return encodeImpl(input, kBase64UrlCharset, false); +std::string Base64::encodeUrl(folly::StringPiece text) { + return encodeImpl(text, kBase64UrlCharset, false); } // static -std::string Base64::encodeUrl(const folly::IOBuf* input) { - return encodeImpl(IOBufWrapper(input), kBase64UrlCharset, false); +std::string Base64::encodeUrl(const char* input, size_t inputSize) { + return encodeUrl(folly::StringPiece(input, inputSize)); } // static -Status Base64::decodeUrl( - std::string_view input, - size_t inputSize, - char* output, - size_t outputSize) { - return decodeImpl( - input, inputSize, output, outputSize, kBase64UrlReverseIndexTable); +std::string Base64::encodeUrl(const folly::IOBuf* inputBuffer) { + return encodeImpl(IOBufWrapper(inputBuffer), kBase64UrlCharset, false); } // static -std::string Base64::decodeUrl(std::string_view input) { - std::string output; - Base64::decodeUrl(input, output); - return output; +Status Base64::decodeUrl(std::string_view input, std::string& output) { + return decodeImpl(input, output, kBase64UrlReverseIndexTable); } // static -void Base64::decodeUrl(std::string_view input, std::string& output) { - // Early exit if input is empty - if (input.empty()) { - output.clear(); - return; - } - - size_t inputSize = input.size(); - size_t outputSize; - - // Calculate the size for the decoded output - auto status = calculateDecodedSize(input, inputSize, outputSize); - if (!status.ok()) { - // status is discarded here, but could be used to handle the error - output.clear(); - return; - } - - // Resize the output string to the calculated size - output.resize(outputSize); - - // Perform the actual decoding - status = Base64::decodeImpl( - input.data(), - inputSize, - output.data(), - outputSize, - kBase64UrlReverseIndexTable); - - if (!status.ok()) { - // status is discarded here, but could be used to handle the error - output.clear(); - return; - } +std::string Base64::decodeUrl(folly::StringPiece encodedText) { + std::string decodedOutput; + std::string_view input(encodedText.data(), encodedText.size()); + (void)decodeImpl(input, decodedOutput, kBase64UrlReverseIndexTable); + return decodedOutput; +} - // Resize the output to match the actual size of the decoded data - output.resize(outputSize); +// static +void Base64::decodeUrl( + const std::pair& payload, + std::string& decodedOutput) { + std::string_view inputView(payload.first, payload.second); + (void)decodeImpl(inputView, decodedOutput, kBase64UrlReverseIndexTable); } } // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index 7041a181662e..c45e745c8e8b 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -13,17 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#pragma once -#include -#include +#pragma once #include #include - +#include +#include #include "velox/common/base/GTestMacros.h" #include "velox/common/base/Status.h" -#include "velox/common/encode/EncoderUtils.h" namespace facebook::velox::encoding { @@ -32,112 +30,98 @@ class Base64 { static const size_t kCharsetSize = 64; static const size_t kReverseIndexSize = 256; - /// Character set used for encoding purposes. - /// Contains specific characters that form the encoding scheme. + /// Character set used for Base64 encoding. using Charset = std::array; - /// Reverse lookup table for decoding purposes. - /// Maps each possible encoded character to its corresponding numeric value - /// within the encoding base. + /// Reverse lookup table for decoding. using ReverseIndex = std::array; - /// Encodes the specified number of characters from the 'input'. - static std::string encode(std::string_view input, size_t len); - - /// Encodes the specified text. - static std::string encode(std::string_view text); - - /// Encodes the specified IOBuf input. - static std::string encode(const folly::IOBuf* input); - - /// Returns encoded size for the input of the specified size. - static size_t calculateEncodedSize( - size_t inputSize, - bool includePadding = true); - - /// Encodes the specified number of characters from the 'input' and writes the - /// result to the 'output'. The output must have enough space, e.g., as - /// returned by calculateEncodedSize(). - static Status encode(std::string_view input, char* output); - - /// Decodes the specified encoded text. - static std::string decode(std::string_view encoded); - - /// Returns the actual size of the decoded data. Will also remove the padding - /// length from the 'inputSize'. - static Status calculateDecodedSize( - std::string_view input, - size_t& inputSize, - size_t& decodedSize); - - /// Decodes the specified number of characters from the 'input' and writes the - /// result to the 'output'. The output must have enough space, e.g., as - /// returned by calculateDecodedSize(). - static void decode(std::string_view input, size_t inputSize, char* output); - - static void decode(std::string_view input, std::string& output); - - /// Encodes the specified number of characters from the 'input' and writes the - /// result to the 'output' using URL encoding. The output must have enough - /// space as returned by calculateEncodedSize(). - static Status encodeUrl(std::string_view input, char* output); - - /// Encodes the specified IOBuf input using URL encoding. - static std::string encodeUrl(const folly::IOBuf* input); - - /// Encodes the specified text using URL encoding. - static std::string encodeUrl(std::string_view input); - - /// Decodes the specified URL encoded input and writes the result to the - /// 'output'. - static void decodeUrl(std::string_view input, std::string& output); - - /// Decodes the specified URL encoded text. - static std::string decodeUrl(std::string_view input); - - /// Decodes the specified number of characters from the 'input' and writes the - /// result to the 'output'. - static Status decode( - std::string_view input, - size_t inputSize, - char* output, - size_t outputSize); - - /// Decodes the specified number of characters from the 'input' using URL - /// encoding and writes the result to the 'output'. - static Status decodeUrl( - std::string_view input, - size_t inputSize, - char* output, - size_t outputSize); + /// Padding character used in encoding. + static const char kPadding = '='; + + // Encoding Functions + /// Encodes the input data using Base64 encoding. + static std::string encode(const char* input, size_t inputSize); + static std::string encode(folly::StringPiece input); + static std::string encode(const folly::IOBuf* inputBuffer); + static Status encode(std::string_view input, std::string& outputBuffer); + + /// Encodes the input data using Base64 URL encoding. + static std::string encodeUrl(const char* input, size_t inputSize); + static std::string encodeUrl(folly::StringPiece text); + static std::string encodeUrl(const folly::IOBuf* inputBuffer); + static Status encodeUrl(std::string_view input, std::string& output); + + // Decoding Functions + /// Decodes the input Base64 encoded string. + static std::string decode(folly::StringPiece encodedText); + static void decode( + const std::pair& payload, + std::string& output); + static void decode(const char* input, size_t inputSize, char* outputBuffer); + static Status decode(std::string_view input, std::string& output); + + /// Decodes the input Base64 URL encoded string. + static std::string decodeUrl(folly::StringPiece encodedText); + static void decodeUrl( + const std::pair& payload, + std::string& output); + static Status decodeUrl(std::string_view input, std::string& output); private: - // Performs a reverse lookup in the reverse index to retrieve the original - // index of a character in the base. - static uint8_t - base64ReverseLookup(char p, const ReverseIndex& reverseIndex, Status& status); + // Checks if the input Base64 string is padded. + static inline bool isPadded(std::string_view input) { + size_t inputSize{input.size()}; + return (inputSize > 0 && input[inputSize - 1] == kPadding); + } + + // Counts the number of padding characters in encoded input. + static inline size_t numPadding(std::string_view input) { + size_t numPadding{0}; + size_t inputSize{input.size()}; + while (inputSize > 0 && input[inputSize - 1] == kPadding) { + numPadding++; + inputSize--; + } + return numPadding; + } + + // Reverse lookup helper function to get the original index of a Base64 + // character. + static uint8_t base64ReverseLookup( + char encodedChar, + const ReverseIndex& reverseIndex, + Status& status); - // Encodes the specified input using the provided charset. template static std::string encodeImpl(const T& input, const Charset& charset, bool includePadding); - // Encodes the specified input using the provided charset. template static Status encodeImpl( const T& input, const Charset& charset, bool includePadding, - char* output); + std::string& output); - // Decodes the specified input using the provided reverse lookup table. static Status decodeImpl( std::string_view input, - size_t inputSize, - char* output, - size_t outputSize, + std::string& output, const ReverseIndex& reverseIndex); - VELOX_FRIEND_TEST(Base64Test, testDecodeImpl); + + // Returns the actual size of the decoded data. Will also remove the padding + // length from the 'inputSize'. + static Status calculateDecodedSize( + std::string_view input, + size_t& inputSize, + size_t& decodedSize); + + // Calculates the encoded size based on input size. + static size_t calculateEncodedSize(size_t inputSize, bool withPadding = true); + + VELOX_FRIEND_TEST(Base64Test, isPadded); + VELOX_FRIEND_TEST(Base64Test, numPadding); + VELOX_FRIEND_TEST(Base64Test, calculateDecodedSize); }; } // namespace facebook::velox::encoding diff --git a/velox/common/encode/EncoderUtils.h b/velox/common/encode/EncoderUtils.h index f14b87cf5e79..ff6244cea79b 100644 --- a/velox/common/encode/EncoderUtils.h +++ b/velox/common/encode/EncoderUtils.h @@ -23,14 +23,16 @@ namespace facebook::velox::encoding { /// Padding character used in encoding. const static char kPadding = '='; -// Checks if there is padding in encoded input. -static inline bool isPadded(std::string_view input, size_t inputSize) { +// Checks if the input Base64 string is padded. +static inline bool isPadded(std::string_view input) { + size_t inputSize{input.size()}; return (inputSize > 0 && input[inputSize - 1] == kPadding); } // Counts the number of padding characters in encoded input. -static inline size_t numPadding(std::string_view input, size_t inputSize) { +static inline size_t numPadding(std::string_view input) { size_t numPadding{0}; + size_t inputSize{input.size()}; while (inputSize > 0 && input[inputSize - 1] == kPadding) { numPadding++; inputSize--; @@ -41,11 +43,11 @@ static inline size_t numPadding(std::string_view input, size_t inputSize) { // Validate the character in charset with ReverseIndex table template constexpr bool checkForwardIndex( - uint8_t idx, + uint8_t index, const Charset& charset, const ReverseIndex& reverseIndex) { - return (reverseIndex[static_cast(charset[idx])] == idx) && - (idx > 0 ? checkForwardIndex(idx - 1, charset, reverseIndex) : true); + return (reverseIndex[static_cast(charset[index])] == index) && + (index > 0 ? checkForwardIndex(index - 1, charset, reverseIndex) : true); } // Searches for a character within a charset up to a certain index. @@ -53,10 +55,10 @@ template constexpr bool findCharacterInCharset( const Charset& charset, uint8_t index, - const char character) { + const char targetChar) { return index < charset.size() && - ((charset[index] == character) || - findCharacterInCharset(charset, index + 1, character)); + ((charset[index] == targetChar) || + findCharacterInCharset(charset, index + 1, targetChar)); } // Checks the consistency of a reverse index mapping for a given character set. @@ -73,11 +75,11 @@ constexpr bool checkReverseIndex( template uint8_t reverseLookup( - char p, + char encodedChar, const ReverseIndexType& reverseIndex, Status& status, uint8_t kBase) { - auto curr = reverseIndex[(uint8_t)p]; + auto curr = reverseIndex[static_cast(encodedChar)]; if (curr >= kBase) { status = Status::UserError("invalid input string: contains invalid characters."); @@ -86,4 +88,69 @@ uint8_t reverseLookup( return curr; } +// Returns the actual size of the decoded data. Will also remove the padding +// length from the 'inputSize'. +static Status calculateDecodedSize( + std::string_view input, + size_t& inputSize, + size_t& decodedSize, + const int binaryBlockByteSize, + const int encodedBlockByteSize) { + if (inputSize == 0) { + decodedSize = 0; + return Status::OK(); + } + + // Check if the input string is padded + if (isPadded(input)) { + // If padded, ensure that the string length is a multiple of the encoded + // block size + if (inputSize % encodedBlockByteSize != 0) { + return Status::UserError( + "calculateDecodedSize() - invalid input string: " + "string length is not a multiple of the encoded block size."); + } + + decodedSize = (inputSize * binaryBlockByteSize) / encodedBlockByteSize; + auto paddingCount = numPadding(input); + inputSize -= paddingCount; + + // Adjust the needed size by deducting the bytes corresponding to the + // padding from the calculated size. + decodedSize -= + ((paddingCount * binaryBlockByteSize) + (encodedBlockByteSize - 1)) / + encodedBlockByteSize; + } else { + decodedSize = (inputSize * binaryBlockByteSize) / encodedBlockByteSize; + } + + return Status::OK(); +} + +// Calculates the encoded size based on input size. +static size_t calculateEncodedSize( + size_t inputSize, + bool includePadding, + const int binaryBlockByteSize, + const int encodedBlockByteSize) { + if (inputSize == 0) { + return 0; + } + + // Calculate the output size assuming that we are including padding. + size_t encodedSize = + ((inputSize + binaryBlockByteSize - 1) / binaryBlockByteSize) * + encodedBlockByteSize; + + if (!includePadding) { + // If the padding was not requested, subtract the padding bytes. + size_t remainder = inputSize % binaryBlockByteSize; + if (remainder != 0) { + encodedSize -= (binaryBlockByteSize - remainder); + } + } + + return encodedSize; +} + } // namespace facebook::velox::encoding diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp index d226ce3bab39..6a429e2da1c7 100644 --- a/velox/common/encode/tests/Base64Test.cpp +++ b/velox/common/encode/tests/Base64Test.cpp @@ -17,148 +17,69 @@ #include "velox/common/encode/Base64.h" #include -#include "velox/common/base/Status.h" +#include "velox/common/base/Exceptions.h" #include "velox/common/base/tests/GTestUtils.h" namespace facebook::velox::encoding { -class Base64Test : public ::testing::Test { - protected: - void checkDecodedSize( - const std::string& encodedString, - size_t expectedEncodedSize, - size_t expectedDecodedSize) { - size_t encodedSize = expectedEncodedSize; - size_t decodedSize = 0; - EXPECT_EQ( - Status::OK(), - Base64::calculateDecodedSize(encodedString, encodedSize, decodedSize)); - EXPECT_EQ(expectedEncodedSize, encodedSize); - EXPECT_EQ(expectedDecodedSize, decodedSize); - } -}; +class Base64Test : public ::testing::Test {}; TEST_F(Base64Test, fromBase64) { - EXPECT_EQ("Hello, World!", Base64::decode("SGVsbG8sIFdvcmxkIQ==")); - EXPECT_EQ( - "Base64 encoding is fun.", - Base64::decode("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=")); - EXPECT_EQ("Simple text", Base64::decode("U2ltcGxlIHRleHQ=")); - EXPECT_EQ("1234567890", Base64::decode("MTIzNDU2Nzg5MA==")); - - // Check encoded strings without padding - EXPECT_EQ("Hello, World!", Base64::decode("SGVsbG8sIFdvcmxkIQ")); - EXPECT_EQ( - "Base64 encoding is fun.", - Base64::decode("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4")); - EXPECT_EQ("Simple text", Base64::decode("U2ltcGxlIHRleHQ")); - EXPECT_EQ("1234567890", Base64::decode("MTIzNDU2Nzg5MA")); -} - -TEST_F(Base64Test, calculateDecodedSizeProperSize) { - checkDecodedSize("SGVsbG8sIFdvcmxkIQ==", 18, 13); - checkDecodedSize("SGVsbG8sIFdvcmxkIQ", 18, 13); - checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", 31, 23); - checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", 31, 23); - checkDecodedSize("MTIzNDU2Nzg5MA==", 14, 10); - checkDecodedSize("MTIzNDU2Nzg5MA", 14, 10); -} + // Lambda function to reduce repetition in test cases + auto checkBase64Decode = [](const std::string& expected, + const std::string& encoded) { + EXPECT_EQ(expected, Base64::decode(folly::StringPiece(encoded))); + }; -TEST_F(Base64Test, calculateDecodedSizeImproperSize) { - size_t encodedSize{21}; - size_t decodedSize; + // Check encoded strings with padding + checkBase64Decode("Hello, World!", "SGVsbG8sIFdvcmxkIQ=="); + checkBase64Decode( + "Base64 encoding is fun.", "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4="); + checkBase64Decode("Simple text", "U2ltcGxlIHRleHQ="); + checkBase64Decode("1234567890", "MTIzNDU2Nzg5MA=="); - EXPECT_EQ( - Status::UserError( - "Base64::decode() - invalid input string: string length is not a multiple of 4."), - Base64::calculateDecodedSize( - "SGVsbG8sIFdvcmxkIQ===", encodedSize, decodedSize)); + // Check encoded strings without padding + checkBase64Decode("Hello, World!", "SGVsbG8sIFdvcmxkIQ"); + checkBase64Decode( + "Base64 encoding is fun.", "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4"); + checkBase64Decode("Simple text", "U2ltcGxlIHRleHQ"); + checkBase64Decode("1234567890", "MTIzNDU2Nzg5MA"); } -TEST_F(Base64Test, testDecodeImpl) { - constexpr const Base64::ReverseIndex reverseTable = { - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, - 255, 255, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, - 255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, - 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, - 25, 255, 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 32, 33, - 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, - 49, 50, 51, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255}; - - auto testDecode = [&](const std::string_view input, - char* output1, - size_t outputSize, - Status expectedStatus) { - EXPECT_EQ( - Base64::decodeImpl( - input, input.size(), output1, outputSize, reverseTable), - expectedStatus); +TEST_F(Base64Test, calculateDecodedSize) { + auto checkDecodedSize = [](std::string_view encodedString, + size_t initialEncodedSize, + size_t expectedEncodedSize, + size_t expectedDecodedSize, + Status expectedStatus = Status::OK()) { + size_t encoded_size = initialEncodedSize; + size_t decoded_size = 0; + Status status = + calculateDecodedSize(encodedString, encoded_size, decoded_size, 3, 4); + + if (expectedStatus.ok()) { + EXPECT_EQ(Status::OK(), status); + EXPECT_EQ(expectedEncodedSize, encoded_size); + EXPECT_EQ(expectedDecodedSize, decoded_size); + } else { + EXPECT_EQ(expectedStatus, status); + } }; - // Predefine buffer sizes and reuse. - char output1[20] = {}; - char output2[1] = {}; - char output3[1] = {}; - - // Invalid characters in the input string - testDecode( - "SGVsbG8gd29ybGQ$", - output1, - sizeof(output1), - Status::UserError("invalid input string: contains invalid characters.")); - - // All characters are padding characters - testDecode("====", output1, sizeof(output1), Status::OK()); - - // Invalid input size - testDecode( - "S", - output1, - sizeof(output1), - Status::UserError( - "Base64::decode() - invalid input string: string length cannot be 1 more than a multiple of 4.")); - - // Valid input without padding characters - testDecode("SGVsbG8gd29ybGQ", output1, sizeof(output1), Status::OK()); - EXPECT_STREQ(output1, "Hello world"); - - // Empty input string - testDecode("", output2, sizeof(output2), Status::OK()); - EXPECT_STREQ(output2, ""); - - // Invalid input size - testDecode( - "SGVsbG8gd29ybGQ===", - output1, - sizeof(output1), + // Using the lambda to reduce repetitive code + checkDecodedSize("SGVsbG8sIFdvcmxkIQ==", 20, 18, 13); + checkDecodedSize("SGVsbG8sIFdvcmxkIQ", 18, 18, 13); + checkDecodedSize( + "SGVsbG8sIFdvcmxkIQ===", + 21, + 0, + 0, Status::UserError( "Base64::decode() - invalid input string: string length is not a multiple of 4.")); - - // whiltespaces in the input string - testDecode( - " SGVsb G8gd2 9ybGQ= ", - output1, - sizeof(output1), - Status::UserError("invalid input string: contains invalid characters.")); - - // insufficient buffer size - testDecode( - " SGVsb G8gd2 9ybGQ= ", - output3, - sizeof(output3), - Status::UserError( - "Base64::decode() - invalid output string: output string is too small.")); + checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", 32, 31, 23); + checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", 31, 31, 23); + checkDecodedSize("MTIzNDU2Nzg5MA==", 16, 14, 10); + checkDecodedSize("MTIzNDU2Nzg5MA", 14, 14, 10); } TEST_F(Base64Test, testEncodeDecodeUrl) { diff --git a/velox/common/encode/tests/EncoderUtilsTests.cpp b/velox/common/encode/tests/EncoderUtilsTests.cpp index 22a1ce1af9fb..e112f8125349 100644 --- a/velox/common/encode/tests/EncoderUtilsTests.cpp +++ b/velox/common/encode/tests/EncoderUtilsTests.cpp @@ -22,14 +22,14 @@ namespace facebook::velox::encoding { class EncoderUtilsTest : public ::testing::Test {}; TEST_F(EncoderUtilsTest, isPadded) { - EXPECT_TRUE(isPadded("ABC=", 4)); - EXPECT_FALSE(isPadded("ABC", 3)); + EXPECT_TRUE(isPadded("ABC=")); + EXPECT_FALSE(isPadded("ABC")); } TEST_F(EncoderUtilsTest, numPadding) { - EXPECT_EQ(0, numPadding("ABC", 3)); - EXPECT_EQ(1, numPadding("ABC=", 4)); - EXPECT_EQ(2, numPadding("AB==", 4)); + EXPECT_EQ(0, numPadding("ABC")); + EXPECT_EQ(1, numPadding("ABC=")); + EXPECT_EQ(2, numPadding("AB==")); } } // namespace facebook::velox::encoding