From d444d2554e7fbd7ac7eedf6b285a0ab715437743 Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Thu, 1 Feb 2024 11:28:20 +0530 Subject: [PATCH] Clean up Base64 --- velox/common/encode/Base64.cpp | 86 +++++++-------------- velox/common/encode/Base64.h | 35 +-------- velox/common/encode/tests/Base64Test.cpp | 2 +- velox/functions/prestosql/BinaryFunctions.h | 2 +- 4 files changed, 30 insertions(+), 95 deletions(-) diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index f683e0450693..9473ef7008aa 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -18,7 +18,6 @@ #include #include #include -#include namespace facebook::velox::encoding { @@ -28,20 +27,23 @@ constexpr static int kBinaryBlockSize = 3; // Size of the encoded block after encoding. constexpr static int kEncodedBlockSize = 4; -constexpr const Base64::Charset kBase64Charset = { +// Encoding base to be used. +constexpr static int kBase = 64; + +constexpr const Charset kBase64Charset = { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'}; -constexpr const Base64::Charset kBase64UrlCharset = { +constexpr const Charset kBase64UrlCharset = { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-', '_'}; -constexpr const Base64::ReverseIndex kBase64ReverseIndexTable = { +constexpr const ReverseIndex kBase64ReverseIndexTable = { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, @@ -60,7 +62,7 @@ constexpr const Base64::ReverseIndex kBase64ReverseIndexTable = { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}; -constexpr const Base64::ReverseIndex kBase64UrlReverseIndexTable = { +constexpr const ReverseIndex kBase64UrlReverseIndexTable = { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, @@ -80,13 +82,6 @@ constexpr const Base64::ReverseIndex kBase64UrlReverseIndexTable = { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}; -constexpr bool checkForwardIndex( - uint8_t idx, - const Base64::Charset& charset, - const Base64::ReverseIndex& table) { - return (table[static_cast(charset[idx])] == idx) && - (idx > 0 ? checkForwardIndex(idx - 1, charset, table) : true); -} // Verify that for every entry in kBase64Charset, the corresponding entry // in kBase64ReverseIndexTable is correct. static_assert( @@ -103,34 +98,17 @@ static_assert( kBase64UrlCharset, kBase64UrlReverseIndexTable), "kBase64UrlCharset has incorrect entries"); -// Similar to strchr(), but for null-terminated const strings. -// Another difference is that we do not consider "\0" to be present in the -// string. -// Returns true if "str" contains the character c. -constexpr bool constCharsetContains( - const Base64::Charset& charset, - uint8_t idx, - const char c) { - return idx < charset.size() && - ((charset[idx] == c) || constCharsetContains(charset, idx + 1, c)); -} -constexpr bool checkReverseIndex( - uint8_t idx, - const Base64::Charset& charset, - const Base64::ReverseIndex& table) { - return (table[idx] == 255 - ? !constCharsetContains(charset, 0, static_cast(idx)) - : (charset[table[idx]] == idx)) && - (idx > 0 ? checkReverseIndex(idx - 1, charset, table) : true); -} + // Verify that for every entry in kBase64ReverseIndexTable, the corresponding // entry in kBase64Charset is correct. static_assert( checkReverseIndex( sizeof(kBase64ReverseIndexTable) - 1, kBase64Charset, + kBase, kBase64ReverseIndexTable), "kBase64ReverseIndexTable has incorrect entries."); + // Verify that for every entry in kBase64ReverseIndexTable, the corresponding // entry in kBase64Charset is correct. // We can't run this check as the URL version has two duplicate entries so that @@ -217,13 +195,13 @@ template *wp++ = charset[(curr >> 12) & 0x3f]; *wp++ = charset[(curr >> 6) & 0x3f]; if (include_pad) { - *wp = kBase64Pad; + *wp = kPadding; } } else { *wp++ = charset[(curr >> 12) & 0x3f]; if (include_pad) { - *wp++ = kBase64Pad; - *wp = kBase64Pad; + *wp++ = kPadding; + *wp = kPadding; } } } @@ -315,18 +293,6 @@ void Base64::decode(const char* data, size_t size, char* output) { Base64::decode(data, size, output, out_len); } -uint8_t Base64::Base64ReverseLookup( - char p, - const Base64::ReverseIndex& reverse_lookup) { - auto curr = reverse_lookup[(uint8_t)p]; - if (curr >= 0x40) { - throw Base64Exception( - "Base64::decode() - invalid input string: invalid characters"); - } - - return curr; -} - size_t Base64::decode(const char* src, size_t src_len, char* dst, size_t dst_len) { return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable); @@ -338,12 +304,12 @@ size_t Base64::calculateDecodedSize(const char* data, size_t& size) { return 0; } - // Check if the input data is padded + // Check if the input data is padded. if (isPadded(data, size)) { // If padded, ensure that the string length is a multiple of the encoded // block size if (size % kEncodedBlockSize != 0) { - throw Base64Exception( + throw EncoderException( "Base64::decode() - invalid input string: " "string length is not a multiple of the encoded block size."); } @@ -352,7 +318,7 @@ size_t Base64::calculateDecodedSize(const char* data, size_t& size) { auto padding = countPadding(data, size); size -= padding; - // Adjust the needed size for padding + // Adjust the needed size for padding. return needed - ceil((padding * kBinaryBlockSize) / static_cast(kEncodedBlockSize)); @@ -364,7 +330,7 @@ size_t Base64::calculateDecodedSize(const char* data, size_t& size) { // Adjust the needed size for extra bytes, if present if (extra) { if (extra == 1) { - throw Base64Exception( + throw EncoderException( "Base64::decode() - invalid input string: " "string length cannot be 1 more than a multiple of 4."); } @@ -386,7 +352,7 @@ size_t Base64::decodeImpl( auto needed = calculateDecodedSize(src, src_len); if (dst_len < needed) { - throw Base64Exception( + throw EncoderException( "Base64::decode() - invalid output string: " "output string is too small."); } @@ -396,10 +362,10 @@ size_t Base64::decodeImpl( // Each character of the 4 encode 6 bits of the original, grab each with // the appropriate shifts to rebuild the original and then split that back // into the original 8 bit bytes. - uint32_t last = (Base64ReverseLookup(src[0], reverse_lookup) << 18) | - (Base64ReverseLookup(src[1], reverse_lookup) << 12) | - (Base64ReverseLookup(src[2], reverse_lookup) << 6) | - Base64ReverseLookup(src[3], reverse_lookup); + uint32_t last = (baseReverseLookup(kBase,src[0], reverse_lookup) << 18) | + (baseReverseLookup(kBase,src[1], reverse_lookup) << 12) | + (baseReverseLookup(kBase,src[2], reverse_lookup) << 6) | + baseReverseLookup(kBase,src[3], reverse_lookup); dst[0] = (last >> 16) & 0xff; dst[1] = (last >> 8) & 0xff; dst[2] = last & 0xff; @@ -408,14 +374,14 @@ size_t Base64::decodeImpl( // Handle the last 2-4 characters. This is similar to the above, but the // last 2 characters may or may not exist. DCHECK(src_len >= 2); - uint32_t last = (Base64ReverseLookup(src[0], reverse_lookup) << 18) | - (Base64ReverseLookup(src[1], reverse_lookup) << 12); + uint32_t last = (baseReverseLookup(kBase,src[0], reverse_lookup) << 18) | + (baseReverseLookup(kBase,src[1], reverse_lookup) << 12); dst[0] = (last >> 16) & 0xff; if (src_len > 2) { - last |= Base64ReverseLookup(src[2], reverse_lookup) << 6; + last |= baseReverseLookup(kBase,src[2], reverse_lookup) << 6; dst[1] = (last >> 8) & 0xff; if (src_len > 3) { - last |= Base64ReverseLookup(src[3], reverse_lookup); + last |= baseReverseLookup(kBase,src[3], reverse_lookup); dst[2] = last & 0xff; } } diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index b954766e682a..15b2d8ae8ca5 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -21,30 +21,17 @@ #include #include +#include "velox/common/encode/EncoderUtils.h" namespace facebook::velox::encoding { -class Base64Exception : public std::exception { - public: - explicit Base64Exception(const char* msg) : msg_(msg) {} - const char* what() const noexcept override { - return msg_; - } - - protected: - const char* msg_; -}; - class Base64 { public: - using Charset = std::array; - using ReverseIndex = std::array; - static std::string encode(const char* data, size_t len); static std::string encode(folly::StringPiece text); static std::string encode(const folly::IOBuf* text); - /// Returns encoded size for the input of the specified size. + // Returns encoded size for the input of the specified size. static size_t calculateEncodedSize(size_t size, bool withPadding = true); /// Encodes the specified number of characters from the 'data' and writes the @@ -91,25 +78,7 @@ class Base64 { static void decodeUrl(const char* src, size_t src_len, char* dst, size_t dst_len); - constexpr static char kBase64Pad = '='; - private: - static inline bool isPadded(const char* data, size_t len) { - return (len > 0 && data[len - 1] == kBase64Pad) ? true : false; - } - - static inline size_t countPadding(const char* src, size_t len) { - size_t numPadding{0}; - while (len > 0 && src[len - 1] == kBase64Pad) { - numPadding++; - len--; - } - - return numPadding; - } - - static uint8_t Base64ReverseLookup(char p, const ReverseIndex& table); - template static std::string encodeImpl(const T& data, const Charset& charset, bool include_pad); diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp index caa5c73db252..eb7575834aab 100644 --- a/velox/common/encode/tests/Base64Test.cpp +++ b/velox/common/encode/tests/Base64Test.cpp @@ -61,7 +61,7 @@ TEST_F(Base64Test, calculateDecodedSizeProperSize) { encoded_size = 21; EXPECT_THROW( Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size), - facebook::velox::encoding::Base64Exception); + facebook::velox::encoding::EncoderException); encoded_size = 32; EXPECT_EQ( diff --git a/velox/functions/prestosql/BinaryFunctions.h b/velox/functions/prestosql/BinaryFunctions.h index 94180fe2ab3b..9d7c7009ae6a 100644 --- a/velox/functions/prestosql/BinaryFunctions.h +++ b/velox/functions/prestosql/BinaryFunctions.h @@ -293,7 +293,7 @@ struct FromBase64Function { encoding::Base64::calculateDecodedSize(input.data(), inputSize)); encoding::Base64::decode( input.data(), inputSize, result.data(), result.size()); - } catch (const encoding::Base64Exception& e) { + } catch (const encoding::EncoderException& e) { VELOX_USER_FAIL(e.what()); } }